mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* update medbench * medbench update * format medbench * format * Update * update * update * update suffix --------- Co-authored-by: 施晓明 <PJLAB\shixiaoming@pjnl104220118l.pjlab.org> Co-authored-by: Leymore <zfz-960727@163.com>
201 lines
6.9 KiB
Python
201 lines
6.9 KiB
Python
# flake8: noqa
|
||
import json
|
||
import re
|
||
|
||
from . import dataset_loader
|
||
|
||
|
||
def extract_last_line(string):
    """Return the last non-blank line of ``string``.

    The text is split on newline characters and scanned from the bottom up.
    A line counts as blank when it is empty or contains only whitespace.

    Args:
        string: Text that may span several lines.

    Returns:
        The last line with visible content (not stripped), or ``string``
        itself, unchanged, when every line is blank.
    """
    # Scan bottom-up so the first hit is the final meaningful line.
    for candidate in reversed(string.split('\n')):
        if candidate.strip():
            return candidate
    return string
|
||
|
||
|
||
def remove_few_shot_prefix(string: str):
    """Drop the few-shot answer lead-in from ``string``.

    Two lead-ins are handled in order: ``'The answer is therefore'`` and
    ``'答案是'``. For each one:

    * if the text starts with it, that leading occurrence is cut off;
    * otherwise, if it appears anywhere, everything up to and including its
      last occurrence is cut off.

    The remainder is whitespace-stripped after each cut.

    Args:
        string: Model output to clean.

    Returns:
        The text following the lead-in(s), or the input unchanged when
        neither lead-in is present.
    """
    for marker in ('The answer is therefore', '答案是'):
        if string.startswith(marker):
            cut = len(marker)
        elif marker in string:
            # Not at the start: keep only what follows the last occurrence.
            cut = string.rfind(marker) + len(marker)
        else:
            continue
        string = string[cut:].strip()
    return string
|
||
|
||
|
||
def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    """Try to pull a single option letter out of a few-shot style answer.

    For the ``'few-shot-CoT'`` setting only the last non-blank line of the
    output is inspected. The text is then searched for the answer lead-in
    (``'answer is '`` for English, ``'答案是'`` for Chinese) followed, as
    closely as possible, by a capital letter in the range A-G.

    Args:
        string: Model output.
        setting_name: Evaluation setting; ``'few-shot-CoT'`` triggers
            last-line extraction.
        language: ``'en'`` or ``'zh'``.

    Returns:
        The matched option letter, or ``None`` when the pattern is absent.

    Raises:
        ValueError: If ``language`` is neither ``'en'`` nor ``'zh'``.
    """
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)

    if language not in ('en', 'zh'):
        raise ValueError('Unknown language {0}'.format(language))

    if language == 'en':
        pattern = 'answer is .*?([A-G])'
    else:
        pattern = '答案是.*?([A-G])'

    found = re.search(pattern, string)
    return found.group(1) if found else None
|
||
|
||
|
||
def try_parse_few_shot_pattern(string: str, dataset_name, setting_name):
    """Report whether ``string`` follows the expected few-shot answer format.

    For the ``'few-shot-CoT'`` setting only the last non-blank line is
    checked. The expected format depends on which ``dataset_loader`` group
    the dataset belongs to:

    * Chinese cloze: the text starts with ``'答案是'``.
    * English cloze: the text starts with ``'The answer is therefore'``.
    * Chinese QA: ``'答案是'`` followed somewhere by a letter A-G.
    * English QA: ``'answer is '`` followed somewhere by a letter A-G.

    Args:
        string: Model output.
        dataset_name: Name looked up in the ``dataset_loader`` groups.
        setting_name: Evaluation setting name.

    Returns:
        ``True`` when the format matches, ``False`` otherwise (including
        when the dataset belongs to none of the groups).
    """
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)

    # Guard-clause style: each dataset family returns as soon as it matches.
    if dataset_name in dataset_loader.chinese_cloze_datasets:
        return string.startswith('答案是')
    if dataset_name in dataset_loader.english_cloze_datasets:
        return string.startswith('The answer is therefore')
    if dataset_name in dataset_loader.chinese_qa_datasets:
        return re.search('答案是.*?([A-G])', string) is not None
    if dataset_name in dataset_loader.english_qa_datasets:
        return re.search('answer is .*?([A-G])', string) is not None
    return False
|
||
|
||
|
||
def parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    """Extract a single option letter from a few-shot answer.

    The structured few-shot pattern is tried first via
    ``try_parse_few_shot_qa_single_answer``; when it yields nothing, the
    result of ``find_first_capital_letter`` on the full text is used.

    Args:
        string: Model output.
        setting_name: Evaluation setting name.
        language: ``'en'`` or ``'zh'``.

    Returns:
        Whatever the structured parser found, otherwise the fallback result
        (which may be an empty string).
    """
    parsed = try_parse_few_shot_qa_single_answer(string, setting_name,
                                                 language)
    if parsed is not None:
        return parsed
    # Pattern missing: fall back to scanning the whole prediction.
    return find_first_capital_letter(string)
|
||
|
||
|
||
def find_first_capital_letter(answer):
    """Return the first character of ``answer`` that is one of A-F.

    Args:
        answer: Text to scan, left to right.

    Returns:
        The first option letter among ``A``..``F``, or ``''`` when none
        occurs.
    """
    options = {'A', 'B', 'C', 'D', 'E', 'F'}
    return next((ch for ch in answer if ch in options), '')
|
||
|
||
|
||
def extract_answer_in_bracket(answer, prefix='【', suffix='】'):
    """Return the text enclosed by the first ``prefix`` ... ``suffix`` pair.

    The closing delimiter is searched for only after the opening one, so a
    stray ``suffix`` that appears earlier in the text is ignored.

    Args:
        answer: Text to search.
        prefix: Opening delimiter.
        suffix: Closing delimiter.

    Returns:
        The text between the first ``prefix`` and the next ``suffix`` that
        follows it, or ``''`` when either delimiter is missing.
    """
    start = answer.find(prefix)
    if start < 0:
        # print("doesn't found special tokens in:", answer)
        return ''
    start += len(prefix)
    # Look for the closing delimiter only after the opening one; a missing
    # delimiter yields '' instead of raising ValueError.
    end = answer.find(suffix, start)
    if end < 0:
        return ''
    return answer[start:end]
|
||
|
||
|
||
def parse_math_answer(setting_name, raw_string):
    """Extract the final answer from a cloze / math style prediction.

    Few-shot settings (``'few-shot'`` and ``'few-shot-CoT'``) return the
    text that follows the few-shot answer lead-in; for ``'few-shot-CoT'``
    only the last non-blank line is considered.

    For every other setting the answer is looked for, in order:

    1. inside the last LaTeX ``boxed{...}`` expression, when the text
       contains one (no further fallback is attempted in that case);
    2. between dollar signs;
    3. after the last ``=`` sign, or else as the last number in the text.

    Args:
        setting_name: Evaluation setting name.
        raw_string: Model output.

    Returns:
        The extracted answer string, or ``None`` when nothing could be
        extracted in a non-few-shot setting.
    """
    if setting_name == 'few-shot-CoT':
        raw_string = extract_last_line(raw_string)
    if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
        raw_string = remove_few_shot_prefix(raw_string)
        return raw_string

    def remove_boxed(s):
        # Unwrap '\boxed{...}'; returns None when `s` is None or is not of
        # exactly that shape. Explicit checks replace assert + bare except,
        # so validation also holds when Python runs with -O.
        left = '\\boxed{'
        if s is None or not s.startswith(left) or not s.endswith('}'):
            return None
        answer = s[len(left):-1]
        if '=' in answer:
            # Keep only the right-hand side of an equation.
            answer = answer.split('=')[-1].lstrip(' ')
        return answer

    def last_boxed_only_string(string):
        # Return the last '\boxed...' (or '\fbox...') expression up to the
        # '}' that brings the brace depth back to zero, or None.
        idx = string.rfind('\\boxed')
        if idx < 0:
            idx = string.rfind('\\fbox')
            if idx < 0:
                return None
        depth = 0
        for i in range(idx, len(string)):
            ch = string[i]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return string[idx:i + 1]
        # Braces never balanced (or none present).
        return None

    def get_answer_with_dollar_sign(s):
        # Last '$...$' span on a line; for an equation keep the right side.
        last_match = None
        matches = re.findall(r'\$(.*)\$', s)
        if matches:
            last_match = matches[-1]
            if '=' in last_match:
                last_match = last_match.split('=')[-1].lstrip(' ')
        return last_match

    def get_answer_without_dollar_sign(s):
        # Text after the last '=' (first line only, trailing dots removed),
        # otherwise the last number found in `s`.
        last_match = None
        if '=' in s:
            last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
            if '\n' in last_match:
                last_match = last_match.split('\n')[0]
        else:
            pattern = r'(?:\$)?\d+(?:\.\d+)?(?![\w\d])'
            matches = re.findall(pattern, s)
            if matches:
                last_match = matches[-1]
        return last_match

    raw_string = remove_few_shot_prefix(raw_string)
    if '\\boxed' in raw_string:
        answer = remove_boxed(last_boxed_only_string(raw_string))
    else:
        answer = get_answer_with_dollar_sign(raw_string)
        if not answer:
            answer = get_answer_without_dollar_sign(raw_string)
    return answer
|
||
|
||
|
||
def parse_qa_multiple_answer(string, setting_name=None):
    """Collect every option letter mentioned in a multi-answer prediction.

    A fixed list of abbreviations and terms containing capital letters
    (for example ``'CT'``, ``'NYHA'``, ``'FDA'``) is removed first, in list
    order, so that their letters are not collected as options. Every
    remaining capital letter A-Z, optionally wrapped in parentheses, is then
    returned.

    Args:
        string: Model output.
        setting_name: Accepted for compatibility with callers that pass the
            evaluation setting as a second argument; it is not used.

    Returns:
        A list of single-letter strings in order of appearance; empty when
        no capital letter remains.
    """
    # if setting_name == 'few-shot-CoT':
    #     string = extract_last_line(string)
    # The order of this list matters: earlier entries are removed first.
    for x in ['CC', 'CA', 'AC', 'POMES', 'AI', 'MIBG', 'CF', 'CTE', 'AD', 'CB', 'BG', 'BD', 'BE', 'BH', 'CTB', 'BI', 'CE', 'Pugh', 'Child', 'CTI', 'CTA', 'TACE', 'PPD', 'Castleman', 'BA', 'CH', 'AB', 'CTC', 'CT', 'CTH', 'CD', 'AH', 'AE', 'AA', 'AF', 'BC', 'CG', 'BB', 'CI', 'BF', 'CTF', 'CTG', 'AG', 'CTD', '分级C', '分级A', 'I131', '分级B', '分级D', '131I‐MIBG', 'NYHA', 'IPF', 'DIP', 'Lambert-Eaton', 'Graves', 'IIA期', 'CKD', 'FDA', 'A级', 'B级', 'C级', 'D级', '维生素D']:
        string = string.replace(x, '')
    # findall already yields [] when nothing matches.
    return re.findall(r'\(*([A-Z])\)*', string)
|
||
|
||
|
||
def post_process(dataset_name, setting_name, prediction):
    """Turn a raw model prediction into the answer form used for scoring.

    Dispatch order:

    1. Cloze datasets (English or Chinese): ``parse_math_answer``.
    2. ``'jec-qa-kd'``, ``'jec-qa-ca'``, ``'gaokao-physics'``: the list of
       option letters from ``parse_qa_multiple_answer``.
    3. Any setting whose name contains ``'zero-shot'``: the result of
       ``find_first_capital_letter``.
    4. Few-shot single-answer QA datasets:
       ``parse_few_shot_qa_single_answer`` with the dataset's language.

    Args:
        dataset_name: Name looked up in the ``dataset_loader`` groups.
        setting_name: Evaluation setting name.
        prediction: Raw model output.

    Returns:
        The parsed answer; its type depends on the branch taken.

    Raises:
        ValueError: If a few-shot prediction belongs to a dataset that is in
            neither QA group.
    """
    if (dataset_name in dataset_loader.english_cloze_datasets
            or dataset_name in dataset_loader.chinese_cloze_datasets):
        return parse_math_answer(setting_name, prediction)

    if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']:
        # The multi-answer parser works on the prediction alone; passing
        # setting_name as a second positional argument raised TypeError.
        return parse_qa_multiple_answer(prediction)

    # all other datasets are QA problems with single answer
    if 'zero-shot' in setting_name:
        return find_first_capital_letter(prediction)

    # all other datasets are QA problems with single answer and setting_name
    # are few-shot
    if dataset_name in dataset_loader.english_qa_datasets:
        language = 'en'
    elif dataset_name in dataset_loader.chinese_qa_datasets:
        language = 'zh'
    else:
        raise ValueError(f'Unsupported dataset name {dataset_name}')
    return parse_few_shot_qa_single_answer(prediction, setting_name,
                                           language)
|