OpenCompass/opencompass/utils/text_postprocessors.py
2024-10-29 18:37:15 +08:00

201 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from typing import Callable, Optional, Union
from opencompass.registry import TEXT_POSTPROCESSORS
@TEXT_POSTPROCESSORS.register_module('general')
def general_postprocess(text: str) -> str:
    """Normalize a free-form answer for string comparison.

    Steps: keep only the text before the first newline, period or comma;
    remove every character that is neither a word character nor
    whitespace; delete the English articles "a", "an" and "the"
    (case-insensitive, whole words only); collapse runs of whitespace into
    a single space and strip both ends.

    Args:
        text: Raw model output.

    Returns:
        The normalized string (possibly empty).
    """
    # Cut off at the first newline, period, or comma. ``maxsplit`` is passed
    # by keyword because positional use is deprecated since Python 3.13.
    truncated_text = re.split(r'[\n.,]', text, maxsplit=1)[0]
    # Remove punctuation.
    no_punctuation = re.sub(r'[^\w\s]', '', truncated_text)
    # Remove articles.
    no_articles = re.sub(r'\b(a|an|the)\b',
                         '',
                         no_punctuation,
                         flags=re.IGNORECASE)
    # Collapse duplicated whitespace.
    return re.sub(r'\s+', ' ', no_articles).strip()
@TEXT_POSTPROCESSORS.register_module('general_cn')
def general_cn_postprocess(text: str) -> str:
    """Normalize an answer like ``general_postprocess``, then word-segment it.

    The text is cut at the first newline, period or comma, stripped of
    punctuation (anything that is neither a word character nor whitespace),
    cleared of the English articles "a"/"an"/"the", and whitespace-collapsed.
    The cleaned text is then segmented with ``jieba.cut`` and the resulting
    tokens are joined with single spaces.

    Args:
        text: Raw model output.

    Returns:
        Space-separated ``jieba`` tokens of the cleaned text.

    Raises:
        ImportError: If ``jieba`` is not installed.
    """
    # ``maxsplit`` by keyword: positional use is deprecated since Python 3.13.
    truncated_text = re.split(r'[\n.,]', text, maxsplit=1)[0]
    no_punctuation = re.sub(r'[^\w\s]', '', truncated_text)
    no_articles = re.sub(r'\b(a|an|the)\b',
                         '',
                         no_punctuation,
                         flags=re.IGNORECASE)
    cleaned_text = re.sub(r'\s+', ' ', no_articles).strip()
    # Imported here so the module itself does not depend on jieba.
    import jieba
    # Segment the *cleaned* text. The previous version segmented the raw
    # ``text``, silently discarding every normalization step above.
    return ' '.join(jieba.cut(cleaned_text))
@TEXT_POSTPROCESSORS.register_module('first-capital')
def first_capital_postprocess(text: str) -> str:
    """Return the first uppercase character of ``text``.

    Scans left to right and yields the earliest character whose
    ``str.isupper()`` is true; returns '' when there is none.
    """
    return next((ch for ch in text if ch.isupper()), '')
@TEXT_POSTPROCESSORS.register_module('last-capital')
def last_capital_postprocess(text: str) -> str:
    """Return the last uppercase character of ``text``.

    Scans right to left and yields the first character whose
    ``str.isupper()`` is true; returns '' when there is none.
    """
    return next((ch for ch in reversed(text) if ch.isupper()), '')
def first_option_postprocess(text: str, options: str, cushion=True) -> str:
    """Find the first valid option letter mentioned in ``text``.

    The regexes below are tried in order, most specific first. For the first
    pattern that matches AND contains one of ``options``, the option is
    returned. If a matching pattern contains no option, the search moves on
    to the next pattern. When ``cushion`` is true, two very permissive
    fallback patterns are appended that accept any bare option letter.

    Args:
        text: Model output to parse. Leading/trailing whitespace is ignored.
        options: The valid option letters, e.g. ``'ABCD'``. They are inserted
            verbatim into regex character classes.
        cushion: Whether to append the permissive fallback patterns.

    Returns:
        The first option, in ``options`` order, found in the captured
        group(s) of the matching pattern, falling back to the whole match
        when the groups hold no option. Returns '' if nothing matches.
    """
    # Raw f-strings: the regex text is byte-identical to the previous
    # non-raw literals, minus the invalid-escape warnings.
    # yapf: disable
    patterns = [
        rf'答案是?\s*([{options}])',
        # NOTE(review): identical to the pattern above (``\s*\s*`` is
        # ``\s*``); it probably lost a full-width colon -- confirm.
        rf'答案是?\s*\s*([{options}])',
        rf'答案是?\s*:\s*([{options}])',
        rf'答案选项应?该?是\s*([{options}])',
        rf'答案选项应?该?为\s*([{options}])',
        rf'答案应该?是\s*([{options}])',
        rf'答案应该?选\s*([{options}])',
        # NOTE(review): the doubled ``\s*`` suggests a lost character here
        # too (possibly a full-width colon) -- confirm before changing.
        rf'答案选项为?\s*\s*([{options}])',
        rf'答案选项为?\s+\(?\*?\*?([{options}])\*?\*?\)?',
        rf'答案选项是?\s*:\s*([{options}])',
        rf'答案为\s*([{options}])',
        rf'答案选\s*([{options}])',
        rf'选择?\s*([{options}])',
        # The comma after this pattern was missing, which silently glued it
        # to the next one and broke both.
        rf'故选?\s*([{options}])',
        rf'只有选?项?\s?([{options}])\s?是?对',
        rf'只有选?项?\s?([{options}])\s?是?错',
        rf'只有选?项?\s?([{options}])\s?不?正确',
        rf'只有选?项?\s?([{options}])\s?错误',
        rf'说法不?对选?项?的?是\s?([{options}])',
        rf'说法不?正确选?项?的?是\s?([{options}])',
        rf'说法错误选?项?的?是\s?([{options}])',
        rf'([{options}])\s?是正确的',
        rf'([{options}])\s?是正确答案',
        rf'选项\s?([{options}])\s?正确',
        rf'所以答\s?([{options}])',
        rf'所以\s?([{options}][.。$]?$)',
        rf'所有\s?([{options}][.。$]?$)',
        # NOTE(review): ``,,`` repeats a character in the class; one was
        # probably a full-width comma -- confirm.
        rf'[\s:,]([{options}])[。,,\.]?$',
        rf'[\s,:][故即]([{options}])[。\.]?$',
        rf'[\s,:]因此([{options}])[。\.]?$',
        rf'[是为。]\s?([{options}])[。\.]?$',
        rf'因此\s?([{options}])[。\.]?$',
        rf'显然\s?([{options}])[。\.]?$',
        rf'答案是\s?(\S+)(?:。|$)',
        rf'答案应该是\s?(\S+)(?:。|$)',
        rf'答案为\s?(\S+)(?:。|$)',
        rf'(?i)ANSWER\s*:\s*([{options}])',
        rf'[Tt]he answer is:?\s+\(?([{options}])\)?',
        rf'[Tt]he answer is:?\s+\(?\*?\*?([{options}])\*?\*?\)?',
        rf'[Tt]he answer is option:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is option:?\s+\(?([{options}])\)?',
        rf'[Tt]he correct answer is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he correct option is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he correct answer option is:?.*?boxed{{([{options}])}}',
        rf'[Tt]he answer to the question is:?\s+\(?([{options}])\)?',
        rf'^选项\s?([{options}])',
        rf'^([{options}])\s?选?项',
        rf'(\s|^)[{options}][\s。,:\.$]',
        rf'1.\s?(.*?)$',
        rf'1.\s?([{options}])[.。$]?$',
    ]
    cushion_patterns = [
        f'([{options}]):',
        f'([{options}])',
    ]
    # flake8: noqa
    # yapf: enable

    if cushion:
        patterns.extend(cushion_patterns)

    # Hoisted out of the loop: stripping is idempotent.
    text = text.strip()
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL)
        if match is None:
            continue
        # Search the captured group(s) before the whole match. The full match
        # can contain unrelated capitals (e.g. 'Analysis' in front of
        # '\boxed{C}'), which previously made this return 'A'. The whole
        # match stays as a fallback for patterns whose groups hold no
        # option, e.g. '(\s|^)[...]', where group 1 is only whitespace.
        candidates = [group for group in match.groups() if group]
        candidates.append(match.group(0))
        for candidate in candidates:
            for option in options:
                if option in candidate:
                    return option
    return ''
@TEXT_POSTPROCESSORS.register_module('first-capital-multi')
def first_capital_postprocess_multi(text: str) -> str:
    """Return the first run of consecutive letters from A to D in ``text``.

    For example 'so ACD, then B' yields 'ACD'. Returns '' when ``text``
    contains no uppercase A-D.
    """
    found = re.search(r'([A-D]+)', text)
    return '' if found is None else found.group(1)
def last_option_postprocess(text: str, options: str) -> str:
match = re.findall(rf'([{options}])', text)
if match:
return match[-1]
return ''
def first_number_postprocess(text: str) -> Optional[float]:
    """Return the first number in ``text`` as a float, or ``None``.

    The pattern accepts an optional leading minus, optional integer digits,
    an optional decimal point and at least one digit. So '-3', '2.5' and
    '.5' are recognised. Exponents and thousands separators are not:
    '1e5' and '1,000' both yield 1.0.

    Args:
        text: Text to scan.

    Returns:
        The first number found, or ``None`` if ``text`` contains no digits.
    """
    match = re.search(r'(-?\d*\.?\d+)', text)
    return None if match is None else float(match.group(1))
@TEXT_POSTPROCESSORS.register_module('multiple-select')
def multiple_select_postprocess(text: str) -> str:
    """Collect every distinct uppercase character in ``text``.

    Returns:
        The distinct uppercase characters, sorted and concatenated
        (e.g. 'DB, B and A' -> 'ABD').
    """
    capitals = {ch for ch in text if ch.isupper()}
    return ''.join(sorted(capitals))
def general_eval_wrapper_postprocess(text: str,
                                     postprocess: Optional[Union[
                                         str, Callable]] = None,
                                     **kwargs) -> str:
    """Evaluate ``text`` as a Python expression, then optionally post-process.

    Especially useful for chatglmpro, whose outputs arrive as a text repr.
    If evaluation fails for any reason (empty input, syntax error, ...),
    the raw ``text`` is used unchanged.

    Args:
        text: Text to be post-processed, possibly a Python repr.
        postprocess: Post-processing step: either a callable or the name of
            a function registered in ``TEXT_POSTPROCESSORS``. Defaults to
            None (no post-processing).
        **kwargs: Extra keyword arguments forwarded to ``postprocess``.

    Returns:
        ``postprocess`` applied to the (possibly evaluated) text, or that
        value itself when ``postprocess`` is falsy.
    """
    # SECURITY(review): ``eval`` executes arbitrary Python taken from model
    # output. If only literal reprs are expected, ``ast.literal_eval`` would
    # be safer. Left as-is here to preserve behavior.
    try:
        value = eval(text)
    except Exception:
        # Best effort: empty or malformed input is passed through untouched.
        value = text

    if not postprocess:
        return value

    fn = (TEXT_POSTPROCESSORS.get(postprocess)
          if isinstance(postprocess, str) else postprocess)
    return fn(value, **kwargs)
def match_answer_pattern(response_text: str, answer_pattern: str):
    """Extract the first capture group of ``answer_pattern`` from a response.

    Args:
        response_text: Text to search.
        answer_pattern: Regex with at least one capturing group. A pattern
            without groups raises ``IndexError`` when it matches.

    Returns:
        The text captured by group 1 of the first match, or '' if the
        pattern does not match.
    """
    found = re.search(answer_pattern, response_text)
    if found is None:
        return ''
    return found.group(1)