OpenCompass/opencompass/utils/text_postprocessors.py

386 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from typing import Callable, Optional, Union
from opencompass.registry import TEXT_POSTPROCESSORS
@TEXT_POSTPROCESSORS.register_module('general')
def general_postprocess(text: str) -> str:
    """Normalise a free-form answer string.

    Only the part of ``text`` before the first newline, period or comma is
    kept. Punctuation, the English articles a/an/the (matched
    case-insensitively) and redundant whitespace are then removed.

    Args:
        text (str): Text to be postprocessed.

    Returns:
        str: The normalised text.
    """
    # Keep only what precedes the first newline, period or comma.
    head = re.split(r'[\n.,]', text, maxsplit=1)[0]
    # Drop every character that is neither a word character nor whitespace.
    head = re.sub(r'[^\w\s]', '', head)
    # Drop English articles regardless of case.
    head = re.sub(r'\b(a|an|the)\b', '', head, flags=re.IGNORECASE)
    # Collapse whitespace runs into one space and trim both ends.
    return re.sub(r'\s+', ' ', head).strip()
@TEXT_POSTPROCESSORS.register_module('general_cn')
def general_cn_postprocess(text: str) -> str:
    """Segment ``text`` with jieba and join the tokens with single spaces.

    Args:
        text (str): Text to be postprocessed.

    Returns:
        str: The tokens produced by ``jieba.cut(text)`` joined by ' '.
    """
    # jieba is imported inside the function, so it is only needed when this
    # postprocessor is actually called.
    import jieba

    # NOTE(review): the previous implementation also built a truncated,
    # punctuation- and article-free copy of ``text`` here, but its result was
    # overwritten before use, so segmentation has always run on the raw
    # ``text``. That dead computation was removed without changing the
    # output. Confirm whether the cleaned copy was meant to be segmented
    # instead.
    return ' '.join(jieba.cut(text))
@TEXT_POSTPROCESSORS.register_module('first-capital')
def first_capital_postprocess(text: str) -> str:
    """Return the first uppercase character of ``text``, or '' if none."""
    return next((char for char in text if char.isupper()), '')
@TEXT_POSTPROCESSORS.register_module('last-capital')
def last_capital_postprocess(text: str) -> str:
    """Return the last uppercase character of ``text``, or '' if none."""
    # Scan from the end so the first hit is the last capital in the text.
    return next((char for char in text[::-1] if char.isupper()), '')
def first_option_postprocess(text: str, options: str, cushion=True) -> str:
    """Find first valid option for text.

    The stripped ``text`` is searched with an ordered list of regular
    expressions. For the first pattern whose match contains one of the
    characters in ``options``, that option is returned.

    Args:
        text (str): Model output to search.
        options (str): Candidate option characters, e.g. 'ABCD'. They are
            interpolated into regex character classes as-is.
        cushion (bool): When True, two loose fallback patterns (a bare
            option, optionally followed by ':') are tried last.

    Returns:
        str: The matched option, or '' when no pattern yields one.
    """
    # Raw f-strings keep regex escapes such as ``\s`` and ``\(`` intact
    # without relying on Python tolerating invalid string escapes.
    # NOTE(review): a few patterns contain a doubled ``\s*\s*``; a separator
    # character may have been lost between them - confirm against the
    # upstream source.
    # yapf: disable
    patterns = [
        rf'答案是?\s*([{options}])',
        rf'答案是?\s*\s*([{options}])',
        rf'答案是?\s*:\s*([{options}])',
        rf'答案选项应?该?是\s*([{options}])',
        rf'答案选项应?该?为\s*([{options}])',
        rf'答案应该?是\s*([{options}])',
        rf'答案应该?选\s*([{options}])',
        rf'答案选项为?\s*\s*([{options}])',
        rf'答案选项为?\s+\(?\*?\*?([{options}])\*?\*?\)?',
        rf'答案选项是?\s*:\s*([{options}])',
        rf'答案为\s*([{options}])',
        rf'答案选\s*([{options}])',
        rf'选择?\s*([{options}])',
        # The comma after this entry was missing, which silently fused it
        # with the following pattern via implicit string concatenation.
        rf'故选?\s*([{options}])',
        rf'只有选?项?\s?([{options}])\s?是?对',
        rf'只有选?项?\s?([{options}])\s?是?错',
        rf'只有选?项?\s?([{options}])\s?不?正确',
        rf'只有选?项?\s?([{options}])\s?错误',
        rf'说法不?对选?项?的?是\s?([{options}])',
        rf'说法不?正确选?项?的?是\s?([{options}])',
        rf'说法错误选?项?的?是\s?([{options}])',
        rf'([{options}])\s?是正确的',
        rf'([{options}])\s?是正确答案',
        rf'选项\s?([{options}])\s?正确',
        rf'所以答\s?([{options}])',
        rf'所以\s?([{options}][.。$]?$)',
        rf'所有\s?([{options}][.。$]?$)',
        rf'[\s:,]([{options}])[。,,\.]?$',
        rf'[\s,:][故即]([{options}])[。\.]?$',
        rf'[\s,:]因此([{options}])[。\.]?$',
        rf'[是为。]\s?([{options}])[。\.]?$',
        rf'因此\s?([{options}])[。\.]?$',
        rf'显然\s?([{options}])[。\.]?$',
        rf'答案是\s?(\S+)(?:。|$)',
        rf'答案应该是\s?(\S+)(?:。|$)',
        rf'答案为\s?(\S+)(?:。|$)',
        rf'(?i)ANSWER\s*:\s*([{options}])',
        rf'[Tt]he answer is:?\s+\(?([{options}])\)?',
        rf'[Tt]he answer is:?\s+\(?\*?\*?([{options}])\*?\*?\)?',
        rf'[Tt]he answer is option:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is option:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he correct option is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he correct answer option is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he answer to the question is:?\s+\(?([{options}])\)?',
        rf'^选项\s?([{options}])',
        rf'^([{options}])\s?选?项',
        rf'(\s|^)[{options}][\s。,:\.$]',
        rf'1.\s?(.*?)$',
        rf'1.\s?([{options}])[.。$]?$',
    ]
    cushion_patterns = [
        rf'([{options}]):',
        rf'([{options}])',
    ]
    # flake8: noqa
    # yapf: enable

    if cushion:
        patterns.extend(cushion_patterns)

    # Stripping is loop-invariant, so do it once up front.
    text = text.strip()
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL)
        if not match:
            continue
        # Use the first capture group when it is non-empty, otherwise fall
        # back to the whole match.
        outputs = match.group(1) or match.group(0)
        for option in options:
            if option in outputs:
                return option
    return ''
@TEXT_POSTPROCESSORS.register_module('first-capital-multi')
def first_capital_postprocess_multi(text: str) -> str:
    """Return the first run of consecutive A-D letters, or '' if none."""
    found = re.search(r'([A-D]+)', text)
    return found.group(1) if found else ''
def last_option_postprocess(text: str, options: str) -> str:
    """Return the last character of ``text`` that is one of ``options``.

    Args:
        text (str): Text to search.
        options (str): Candidate option characters, interpolated into a
            regex character class as-is.

    Returns:
        str: The last matching option, or '' when none is present.
    """
    hits = re.findall(rf'([{options}])', text)
    return hits[-1] if hits else ''
def first_number_postprocess(text: str) -> Optional[float]:
    """Return the first number in a string.

    Integers and decimals are both recognised, with an optional leading
    minus sign.

    Args:
        text (str): Text to search.

    Returns:
        Optional[float]: The first number converted to float, or None when
        the pattern does not match anywhere in ``text``.
    """
    # Optional sign, optional integer part, optional dot, then at least one
    # digit - so '7', '-3.5' and '.5' all match.
    match = re.search(r'(-?\d*\.?\d+)', text)
    return float(match.group(1)) if match else None
@TEXT_POSTPROCESSORS.register_module('multiple-select')
def multiple_select_postprocess(text: str) -> str:
    """Return the distinct uppercase characters of ``text`` in sorted order."""
    capitals = {char for char in text if char.isupper()}
    return ''.join(sorted(capitals))
@TEXT_POSTPROCESSORS.register_module('specific-xml-tag')
def xml_tag_postprocessor(text, tag):
    """Extracts content enclosed within a specified XML-style tag from a
    string.

    Args:
        text: The input string containing XML-style tags.
        tag: The XML-style opening tag to extract content from
            (e.g., "<conclude>"). Must include the angle brackets; the
            closing tag is derived from it by dropping the first and last
            character.

    Returns:
        The stripped content of the last occurrence of the tag, or the
        string 'NO ANSWER FOUND' if the tag is not found.
    """
    closing_tag = f'</{tag[1:-1]}>'
    # Escape both tags so any regex metacharacters in them are matched
    # literally. re.DOTALL lets the content span multiple lines.
    matches = re.findall(
        rf'{re.escape(tag)}(.*?){re.escape(closing_tag)}', text, re.DOTALL)
    if not matches:
        return 'NO ANSWER FOUND'
    # The tag might appear multiple times; only keep the last one.
    return matches[-1].strip()
def general_eval_wrapper_postprocess(text: str,
                                     postprocess: Optional[Union[
                                         str, Callable]] = None,
                                     **kwargs) -> str:
    """Wrapper for eval text repr. Especially for chatglmpro.

    ``text`` is first decoded as a Python literal (e.g. a quoted string
    repr). If that fails, the original text is kept. The optional
    ``postprocess`` is then applied to the result.

    Args:
        text(str): Text to be postprocessed.
        postprocess(Callable, optional): Original post processing function,
            or the name under which it is registered in
            ``TEXT_POSTPROCESSORS``. Defaults to None.
        **kwargs: Other necessary kwargs for post processing function.

    Returns:
        The result of ``postprocess`` when one is given, otherwise the
        decoded (or original) text.
    """
    import ast

    # SECURITY: this used to call eval() on the text, which executes
    # arbitrary expressions embedded in model output. ast.literal_eval only
    # accepts Python literals, which still covers decoding a text repr.
    try:
        text = ast.literal_eval(text)
    except Exception:
        # in case empty input or other error, skip eval
        pass

    if not postprocess:
        return text
    if isinstance(postprocess, str):
        postprocess = TEXT_POSTPROCESSORS.get(postprocess)
    return postprocess(text, **kwargs)
def match_answer_pattern(response_text: str, answer_pattern: str):
    """Return the first capture group of ``answer_pattern`` in the response.

    Args:
        response_text (str): Text to search.
        answer_pattern (str): Regular expression with at least one capture
            group.

    Returns:
        The content of group 1 of the first match, or '' when the pattern
        does not match.
    """
    found = re.search(answer_pattern, response_text)
    if not found:
        return ''
    return found.group(1)
@TEXT_POSTPROCESSORS.register_module('rm_<think>_before_eval')
def remove_reasoning_part_before_evaluation(text: str):
    """Drop a leading ``<think>...</think>`` section from ``text``.

    The text is returned unchanged unless it starts with '<think>' and
    contains a '</think>'. Otherwise everything after the last '</think>'
    is returned.
    """
    closing_tag = '</think>'
    if not text.startswith('<think>'):
        return text
    closing_pos = text.rfind(closing_tag)
    if closing_pos == -1:
        return text
    # Skip past the closing tag itself.
    return text[closing_pos + len(closing_tag):]
@TEXT_POSTPROCESSORS.register_module(
    'extract_qwq_answer_before_eval_for_huproverbrea')
def extract_answer_before_evaluation(text: str):
    """Overall, there are three situations in responses of QWQ:
    1. There is a **Final Answer** title in the whole context.
    2. There is only one sentence in the context.
    3. There are more than one sentences in the context, \
        and the last one is the answer.

    Returns:
        The text after the last '\\n\\n**Final Answer**\\n\\n' title when
        '**Final Answer**' is present. Otherwise the last '.'-delimited
        sentence (with a trailing '.') of the part before any
        '\\n\\nHuman', or '' when there is no text at all.
    """
    text = text.strip('\n')
    if '**Final Answer**' in text:
        return text.split('\n\n**Final Answer**\n\n')[-1]

    if '\n\nHuman' in text:
        text = text.split('\n\nHuman')[0]

    sentences = text.split('.')
    if sentences[-1] != '':
        return sentences[-1] + '.'
    if len(sentences) < 2:
        # Empty text splits into [''] only; there is no previous sentence,
        # and indexing [-2] here used to raise IndexError.
        return ''
    # The text ended with '.', so the real last sentence is one before.
    return sentences[-2] + '.'
@TEXT_POSTPROCESSORS.register_module(
    'extract_qwq_answer_before_eval_for_husimpleqa')
def extract_answer_before_evaluation(text: str):
    """Extract the answer section from a QwQ response for HuSimpleQA.

    The text is split on blank lines and at most the last six pieces are
    scanned from the end. Scanning stops at the first piece by which both
    an answer keyword and a score keyword have been seen. The result is the
    text from one piece before that point to the end. If the keywords are
    not both found, it starts one piece before the scanned window.
    """
    window = 6
    score_flags = [
        'score', 'Score', 'confidence', 'Confidence', 'szcore', 'Szcore',
        'pontszám', 'Pontszám', 'Biztonság', 'biztonság', 'Biztoságskor',
        'biztoságskor', 'Biztoság', '信心', '分数'
    ]
    answer_flags = ['answer', 'Answer', 'Válasz', 'válasz', '答案', '回答']

    pieces = text.split('\n\n')
    lowest_idx = max(len(pieces) - window, 0)
    start_idx = lowest_idx
    seen_answer = False
    seen_score = False

    # Walk backwards over the window; the flags accumulate across pieces.
    for pos in range(len(pieces) - 1, lowest_idx - 1, -1):
        piece = pieces[pos]
        seen_score = seen_score or any(flag in piece for flag in score_flags)
        seen_answer = seen_answer or any(
            flag in piece for flag in answer_flags)
        if seen_answer and seen_score:
            start_idx = pos
            break

    # Include one extra piece of context before the detected start.
    return '\n\n'.join(pieces[max(start_idx - 1, 0):])
@TEXT_POSTPROCESSORS.register_module(
    'extract_qwq_answer_before_eval_for_humatchingfib')
def extract_answer_before_evaluation(text: str):
    """Extract the answer section from a QwQ response.

    The text is split on blank lines when it contains more than two of
    them, otherwise on single newlines. At most the last 30 pieces are
    scanned from the end for an answer/summary keyword. The result is the
    text from one piece before the first hit to the end, joined with blank
    lines. Without a hit, it starts one piece before the scanned window.
    """
    window = 30
    answer_flags = [
        'answer', 'Answer', 'Válasz', 'válasz', '答案', '回答', 'Summar', 'summar',
        'Summar', 'summar'
    ]

    separator = '\n\n' if text.count('\n\n') > 2 else '\n'
    pieces = text.split(separator)
    lowest_idx = max(len(pieces) - window, 0)
    start_idx = lowest_idx

    # Walk backwards over the window and stop at the first keyword hit.
    for pos in range(len(pieces) - 1, lowest_idx - 1, -1):
        if any(flag in pieces[pos] for flag in answer_flags):
            start_idx = pos
            break

    # Include one extra piece of context before the detected start.
    return '\n\n'.join(pieces[max(start_idx - 1, 0):])
@TEXT_POSTPROCESSORS.register_module(
    'extract_qwq_answer_before_eval_for_hustandardfib')
def extract_answer_before_evaluation(text: str):
    """Extract the answer section from a QwQ response.

    The text is split on blank lines when it contains more than two of
    them, otherwise on single newlines. At most the last 70 pieces are
    scanned from the end for an answer marker. The result is the text from
    one piece before the first hit to the end, joined with blank lines.
    Without a hit, it starts one piece before the scanned window.
    """
    window = 70
    answer_flags = [
        '#0#', '#0', 'summar', 'Final Answer', 'final answer', 'Final\nAnswer'
    ]

    separator = '\n\n' if text.count('\n\n') > 2 else '\n'
    pieces = text.split(separator)
    lowest_idx = max(len(pieces) - window, 0)

    # First piece, scanning backwards over the window, that holds a marker;
    # defaults to the start of the window when there is none.
    backwards = range(len(pieces) - 1, lowest_idx - 1, -1)
    start_idx = next(
        (pos for pos in backwards
         if any(flag in pieces[pos] for flag in answer_flags)),
        lowest_idx,
    )

    # Include one extra piece of context before the detected start.
    return '\n\n'.join(pieces[max(start_idx - 1, 0):])