OpenCompass/opencompass/utils/text_postprocessors.py

204 lines
6.7 KiB
Python
Raw Normal View History

2023-07-04 21:34:55 +08:00
import re
from typing import Callable, Optional, Union
2023-07-04 21:34:55 +08:00
from opencompass.registry import TEXT_POSTPROCESSORS
@TEXT_POSTPROCESSORS.register_module('general')
def general_postprocess(text: str) -> str:
    """Normalise a free-form answer for lenient string comparison.

    Only the part of ``text`` before the first newline, period or comma is
    kept. Punctuation and the English articles ``a``/``an``/``the``
    (matched case-insensitively) are then removed, and runs of whitespace
    are collapsed into single spaces.

    Args:
        text (str): Raw text to normalise.

    Returns:
        str: The normalised text, without leading/trailing whitespace.
    """
    head = re.split(r'[\n.,]', text, maxsplit=1)[0]
    substitutions = (
        (r'[^\w\s]', '', 0),  # drop punctuation
        (r'\b(a|an|the)\b', '', re.IGNORECASE),  # drop articles
        (r'\s+', ' ', 0),  # squeeze duplicated blanks
    )
    for pattern, replacement, flags in substitutions:
        head = re.sub(pattern, replacement, head, flags=flags)
    return head.strip()
@TEXT_POSTPROCESSORS.register_module('general_cn')
def general_cn_postprocess(text: str) -> str:
    """Normalise a free-form answer and segment it into words with jieba.

    The text is first cleaned the same way as the ``general`` postprocessor:
    it is cut at the first newline, period or comma, punctuation and the
    English articles ``a``/``an``/``the`` are removed, and whitespace runs
    are collapsed. The cleaned text is then segmented with ``jieba.cut`` and
    the segments are joined with single spaces.

    Args:
        text (str): Raw text to normalise.

    Returns:
        str: The space-joined jieba segments of the cleaned text.
    """
    # Imported inside the function so jieba is only needed when this
    # postprocessor is actually called.
    import jieba

    truncated_text = re.split(r'[\n.,]', text, maxsplit=1)[0]
    no_punctuation = re.sub(r'[^\w\s]', '', truncated_text)
    no_articles = re.sub(r'\b(a|an|the)\b',
                         '',
                         no_punctuation,
                         flags=re.IGNORECASE)
    cleaned_text = re.sub(r'\s+', ' ', no_articles).strip()

    # Segment the *cleaned* text. The previous code segmented the raw
    # ``text`` and thereby discarded every cleanup step above.
    return ' '.join(jieba.cut(cleaned_text))
@TEXT_POSTPROCESSORS.register_module('first-capital')
def first_capital_postprocess(text: str) -> str:
    """Return the first uppercase character in ``text``.

    Args:
        text (str): Text to scan from left to right.

    Returns:
        str: The first character for which ``str.isupper()`` is true, or an
            empty string when there is none.
    """
    return next((char for char in text if char.isupper()), '')
2023-12-11 17:42:53 +08:00
@TEXT_POSTPROCESSORS.register_module('last-capital')
def last_capital_postprocess(text: str) -> str:
    """Return the last uppercase character in ``text``.

    Args:
        text (str): Text to scan from right to left.

    Returns:
        str: The last character for which ``str.isupper()`` is true, or an
            empty string when there is none.
    """
    return next((char for char in reversed(text) if char.isupper()), '')
def first_option_postprocess(text: str, options: str, cushion=True) -> str:
    """Find first valid option for text.

    A fixed, ordered list of Chinese and English answer phrasings is tried
    against the stripped ``text``. For the first pattern whose captured text
    contains one of the option characters, that option is returned.

    Args:
        text (str): Text to extract the option from.
        options (str): Candidate option characters, e.g. ``'ABCD'``. The
            string is interpolated into regex character classes.
        cushion (bool): Whether to also try the loose fallback patterns
            (a bare option character, optionally followed by ``:``) after
            all specific patterns failed. Defaults to True.

    Returns:
        str: The first character of ``options`` (in the order of
            ``options``) found in the captured text of the earliest
            matching pattern, or an empty string if no pattern yields one.
    """
    # Raw f-strings keep the regex escapes (``\s``, ``\(``, ``\.`` ...)
    # verbatim without relying on Python's invalid-escape fallback, which
    # raises a SyntaxWarning on recent interpreters.
    # yapf: disable
    # flake8: noqa: W605
    patterns = [
        rf'答案是?\s*([{options}])',
        rf'答案是?\s*\s*([{options}])',
        rf'答案是?\s*:\s*([{options}])',
        rf'答案选项应?该?是\s*([{options}])',
        rf'答案选项应?该?为\s*([{options}])',
        rf'答案应该?是\s*([{options}])',
        rf'答案应该?选\s*([{options}])',
        rf'答案选项为?\s*\s*([{options}])',
        rf'答案选项为?\s+\(?\*?\*?([{options}])\*?\*?\)?',
        rf'答案选项是?\s*:\s*([{options}])',
        rf'答案为\s*([{options}])',
        rf'答案选\s*([{options}])',
        rf'选择?\s*([{options}])',
        # The comma after the next entry used to be missing, which silently
        # concatenated it with the following pattern into one that could
        # practically never match.
        rf'故选?\s*([{options}])',
        rf'只有选?项?\s?([{options}])\s?是?对',
        rf'只有选?项?\s?([{options}])\s?是?错',
        rf'只有选?项?\s?([{options}])\s?不?正确',
        rf'只有选?项?\s?([{options}])\s?错误',
        rf'说法不?对选?项?的?是\s?([{options}])',
        rf'说法不?正确选?项?的?是\s?([{options}])',
        rf'说法错误选?项?的?是\s?([{options}])',
        rf'([{options}])\s?是正确的',
        rf'([{options}])\s?是正确答案',
        rf'选项\s?([{options}])\s?正确',
        rf'所以答\s?([{options}])',
        rf'所以\s?([{options}][.。$]?$)',
        rf'所有\s?([{options}][.。$]?$)',
        rf'[\s:,]([{options}])[。,,\.]?$',
        rf'[\s,:][故即]([{options}])[。\.]?$',
        rf'[\s,:]因此([{options}])[。\.]?$',
        rf'[是为。]\s?([{options}])[。\.]?$',
        rf'因此\s?([{options}])[。\.]?$',
        rf'显然\s?([{options}])[。\.]?$',
        rf'答案是\s?(\S+)(?:。|$)',
        rf'答案应该是\s?(\S+)(?:。|$)',
        rf'答案为\s?(\S+)(?:。|$)',
        rf'(?i)ANSWER\s*:\s*([{options}])',
        rf'[Tt]he answer is:?\s+\(?([{options}])\)?',
        rf'[Tt]he answer is:?\s+\(?\*?\*?([{options}])\*?\*?\)?',
        rf'[Tt]he answer is option:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is option:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he correct option is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he correct answer option is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he answer to the question is:?\s+\(?([{options}])\)?',
        rf'^选项\s?([{options}])',
        rf'^([{options}])\s?选?项',
        # NOTE(review): group 1 here is the leading whitespace/start anchor,
        # not the option. When it is empty the whole match is inspected
        # instead; when it is a whitespace character no option is found in
        # it and the search moves on to the next pattern.
        rf'(\s|^)[{options}][\s。,:\.$]',
        rf'1.\s?(.*?)$',
        rf'1.\s?([{options}])[.。$]?$',
    ]
    cushion_patterns = [
        rf'([{options}]):',
        rf'([{options}])',
    ]
    # flake8: noqa
    # yapf: enable

    if cushion:
        patterns.extend(cushion_patterns)

    # Stripping does not depend on the pattern, so do it once up front.
    text = text.strip()
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL)
        if not match:
            continue
        # Prefer the first capture group; fall back to the whole match when
        # the group is unset or empty.
        outputs = match.group(1) or match.group(0)
        for i in options:
            if i in outputs:
                return i
    return ''
2023-07-04 21:34:55 +08:00
@TEXT_POSTPROCESSORS.register_module('first-capital-multi')
def first_capital_postprocess_multi(text: str) -> str:
    """Return the first run of consecutive ``A``-``D`` letters in ``text``.

    Args:
        text (str): Text to search.

    Returns:
        str: The first maximal run of the uppercase letters A, B, C and D,
            or an empty string when ``text`` contains none of them.
    """
    found = re.search(r'([A-D]+)', text)
    return found.group(1) if found else ''
def last_option_postprocess(text: str, options: str) -> str:
    """Return the last occurrence of any option character in ``text``.

    Args:
        text (str): Text to search.
        options (str): Candidate option characters, e.g. ``'ABCD'``. The
            string is interpolated into a regex character class.

    Returns:
        str: The option character that appears last in ``text``, or an
            empty string when none appears.
    """
    hits = re.findall(rf'([{options}])', text)
    if not hits:
        return ''
    return hits[-1]
def first_number_postprocess(text: str) -> Optional[float]:
    """Return the first number in a string.

    Integers and decimals are recognised, with an optional leading minus
    sign and an optional integer part (``'.5'`` is read as ``0.5``).

    Args:
        text (str): Text to search.

    Returns:
        Optional[float]: The first number found, converted to ``float``, or
            ``None`` when ``text`` contains no digits.
    """
    # Optional sign, optional integer part, optional dot, mandatory digits.
    match = re.search(r'(-?\d*\.?\d+)', text)
    if match is None:
        return None
    return float(match.group(1))
2023-11-13 00:09:05 +08:00
@TEXT_POSTPROCESSORS.register_module('multiple-select')
def multiple_select_postprocess(text: str) -> str:
    """Collect the distinct uppercase characters of ``text`` in sorted order.

    Args:
        text (str): Text to scan.

    Returns:
        str: Every distinct character for which ``str.isupper()`` is true,
            sorted and concatenated; an empty string when there is none.
    """
    unique_capitals = {char for char in text if char.isupper()}
    return ''.join(sorted(unique_capitals))
def general_eval_wrapper_postprocess(text: str,
                                     postprocess: Optional[Union[
                                         str, Callable]] = None,
                                     **kwargs) -> str:
    """Wrapper for eval text repr. Especially for chatglmpro.

    ``text`` is first passed through ``eval`` so that a Python literal
    representation is turned back into the value it denotes. If evaluation
    fails for any reason the text is used unchanged. The result is then
    handed to ``postprocess`` when one is given.

    Args:
        text(str): Text to be postprocessed.
        postprocess(Callable, optional): Original post processing function,
            or the name under which it can be looked up in
            ``TEXT_POSTPROCESSORS``. Defaults to None.
        **kwargs: Other necessary kwargs for post processing function.

    Returns:
        The output of ``postprocess`` when it is given, otherwise the
        evaluated (or, on failure, original) text.
    """
    # NOTE(security): eval() runs arbitrary Python expressions. If ``text``
    # can come from an untrusted source, a literal-only parser such as
    # ``ast.literal_eval`` should be considered -- TODO confirm with callers.
    try:
        text = eval(text)
    except Exception:
        # in case empty input or other error, skip eval
        pass

    if not postprocess:
        return text

    # A string is treated as a registry key and resolved to the callable.
    if isinstance(postprocess, str):
        postprocess = TEXT_POSTPROCESSORS.get(postprocess)
    return postprocess(text, **kwargs)
def match_answer_pattern(response_text: str, answer_pattern: str):
    """Extract the first capture group of ``answer_pattern`` from a response.

    Args:
        response_text (str): Text to search.
        answer_pattern (str): Regular expression with at least one capture
            group.

    Returns:
        The text captured by group 1 of the first match, or an empty string
        when the pattern does not match.
    """
    found = re.search(answer_pattern, response_text)
    if not found:
        return ''
    return found.group(1)