mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
![]() |
# flake8: noqa: E501
|
||
|
import json
|
||
|
import re
|
||
|
|
||
|
from .common_answers import (add_quotes_to_unquoted, change_quotation,
|
||
|
is_numeric)
|
||
|
|
||
|
|
||
|
def get_gt_label(item):
|
||
|
return item['gt_answer'].strip().lower()
|
||
|
|
||
|
|
||
|
def extract_answer(model_response, item, prompt_style, type):
|
||
|
model_response += '}'
|
||
|
if 'CoT' in prompt_style and any(
|
||
|
match in type for match in ['NIE', 'NDE', 'ETT', 'CDE', 'ATE']):
|
||
|
matches = re.findall(r'\{\"answer\":.*?\}', model_response, re.DOTALL)
|
||
|
else:
|
||
|
matches = re.findall(r'\{+.*?\}+', model_response,
|
||
|
re.DOTALL | re.IGNORECASE)
|
||
|
matched_str = None
|
||
|
for match in matches:
|
||
|
if match:
|
||
|
matched_str = match.lower()
|
||
|
if matched_str.startswith('{{') and matched_str.endswith('}}}'):
|
||
|
matched_str = matched_str[1:-2]
|
||
|
elif matched_str.startswith('{{') and matched_str.endswith('}}'):
|
||
|
matched_str = matched_str[1:-1]
|
||
|
elif matched_str.startswith('{{') and matched_str.endswith('}'):
|
||
|
matched_str = matched_str[1:]
|
||
|
elif matched_str.startswith('{') and matched_str.endswith('}}'):
|
||
|
matched_str = matched_str[:-1]
|
||
|
else:
|
||
|
matched_str = None
|
||
|
|
||
|
if matched_str:
|
||
|
try:
|
||
|
inner_json_obj = json.loads(matched_str)
|
||
|
except json.JSONDecodeError:
|
||
|
# If parsing fails, try adding quotes to unquoted words and parse again
|
||
|
fixed_json_str = add_quotes_to_unquoted(matched_str)
|
||
|
fixed_json_str = change_quotation(fixed_json_str)
|
||
|
try:
|
||
|
inner_json_obj = json.loads(fixed_json_str)
|
||
|
except:
|
||
|
inner_json_obj = {}
|
||
|
|
||
|
prob_str_value = inner_json_obj.get('answer', None)
|
||
|
if prob_str_value is not None:
|
||
|
break
|
||
|
if matched_str is None:
|
||
|
prob_str_value = None
|
||
|
|
||
|
return prob_str_value
|
||
|
|
||
|
|
||
|
def get_pred_label(model_response, item, prompt_style, type):
|
||
|
model_response = model_response.strip().lower()
|
||
|
pred = extract_answer(model_response, item, prompt_style, type)
|
||
|
return pred
|