OpenCompass/opencompass/datasets/TheoremQA/utils.py
Linchen Xiao 455bb05d1b
[Update] Update dataset configs (#2030)
* [Update] Update dataset configs

* Fix lint
2025-04-21 18:55:06 +08:00

116 lines
3.7 KiB
Python

import re
from .number_utils import clean_units, compare_two_numbers, compare_two_list, number_it
import contextlib
import signal
@contextlib.contextmanager
def time_limit(seconds: float):
    """Raise ``ValueError`` inside the ``with`` body once *seconds* have elapsed.

    Implemented with ``SIGALRM`` and ``signal.setitimer(ITIMER_REAL)``, so it
    only works where those exist and only from the main thread
    (``signal.signal`` raises ``ValueError`` elsewhere).

    Args:
        seconds: Wall-clock budget for the body; fractional values are allowed.

    Raises:
        ValueError: When the budget is exceeded while the body is running.
    """

    def signal_handler(signum, frame):
        raise ValueError(f'Timed out after {seconds} seconds')

    # Install the handler *before* arming the timer. Arming first leaves a
    # window in which SIGALRM reaches the previous (possibly default,
    # process-terminating) handler, and if signal.signal() fails the armed
    # timer would never be cancelled.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        # Cancel the timer first so it cannot fire while restoring.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal() returns None when the old handler was not installed
        # from Python; fall back to the default disposition in that case.
        signal.signal(
            signal.SIGALRM,
            previous_handler if previous_handler is not None else signal.SIG_DFL,
        )
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Normalise a raw prediction string into a comparable answer string.

    Args:
        pred: Candidate answer text (usually the text after an answer trigger).
        answer_flag: True when an answer trigger was found upstream; False
            means *pred* is the whole completion and only the last number-like
            token is kept.

    Returns:
        ``'True'`` / ``'False'`` for yes/no style answers, *pred* unchanged
        when it contains an option label ``(a)``-``(f)``, otherwise a numeric
        string (possibly ``''`` when ``answer_flag`` is False and no number
        is present).
    """
    lowered = pred.lower()
    # NOTE(review): these are plain substring tests, so 'no' also matches
    # inside words such as 'another' or 'nodes'. 'yes'/'true' are checked
    # before 'no'/'false'. Consider word-boundary matching if score parity
    # with existing results is not required.
    if any(option in lowered for option in ('yes', 'true')):
        return 'True'
    if any(option in lowered for option in ('no', 'false')):
        return 'False'
    if any(option in lowered
           for option in ('(a)', '(b)', '(c)', '(d)', '(e)', '(f)')):
        # Multiple-choice answer: returned as-is.
        return pred

    if not answer_flag:
        # Desperate search: no trigger was found, take the last number.
        numbers = re.findall(r'-?\d*\.?\d+', pred)
        return numbers[-1] if numbers else ''

    # Imported lazily: only the numeric path needs this optional dependency.
    from latex2sympy2_extended import latex2sympy

    # Keep only the right-hand side of the last '=' and strip units.
    pred = clean_units(pred.split('=')[-1].strip())
    try:
        with time_limit(1):
            # NOTE(review): eval() runs on text derived from model output.
            value = eval(str(latex2sympy(pred)))
            converted = (str(list(value)) if isinstance(value, tuple)
                         else str(value))
    except Exception:
        # `pred` is still the cleaned string here (results live in
        # temporaries), so the regex fallback always receives a str.
        # A number followed by one unit-like suffix: keep the number.
        if (re.match(r'-?[\d\.]+\s\D+$', pred)
                or re.match(r'-?[\d\.]+\s[^\s]+$', pred)):
            pred = pred.split(' ')[0]
        return pred
    return converted
def answer_clean(direct_answer_trigger_for_fewshot: tuple, pred: str):
    """Extract the final answer from a model completion.

    Args:
        direct_answer_trigger_for_fewshot: Phrases that introduce the final
            answer (e.g. ``'The answer is'``). They are matched as literal
            text, both when counting and when splitting. Empty strings are
            ignored.
        pred: Raw model completion.

    Returns:
        The string produced by ``extract_theoremqa_answer`` for the text
        after the last trigger (or for the whole completion when no trigger
        occurs), with trailing ``'.'`` and ``'/'`` removed.
    """
    triggers = tuple(t for t in direct_answer_trigger_for_fewshot if t)
    pred = pred.strip('\n')

    # Determine if this is ICL: a trigger appearing more than once means the
    # model kept producing few-shot examples, so keep only the first chunk.
    if any(pred.count(trigger) > 1 for trigger in triggers):
        pred = pred.split('\n\n')[0]

    # Split on the triggers to find the answer. Escape them so they match
    # literally (consistent with str.count above). With no triggers at all,
    # re.split('') would split at every position, so skip splitting.
    if triggers:
        preds = re.split('|'.join(map(re.escape, triggers)), pred)
    else:
        preds = [pred]
    answer_flag = len(preds) > 1
    if answer_flag:
        pred = preds[-1]

    pred = pred.strip('\n').rstrip('.').rstrip('/').strip(' ')
    pred = extract_theoremqa_answer(pred, answer_flag)

    # Remove the period at the end, again!
    return pred.rstrip('.').rstrip('/')
def compare_answer_with_groundtruth(answer: str, groundtruth_str: str, groundtruth_num=None):
    """Decide whether an extracted answer matches the ground truth.

    Args:
        answer: Extracted prediction string.
        groundtruth_str: Ground-truth answer as text.
        groundtruth_num: Optional numeric ground truth. An int/float is
            compared with ``compare_two_numbers``; any other non-None value
            is treated as a sequence and compared with ``compare_two_list``.

    Returns:
        True when the answer matches, False otherwise (including when the
        answer cannot be parsed).
    """
    if groundtruth_str.lower() in ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']:
        # Multiple choice: the option label only has to appear in the answer.
        return groundtruth_str.lower() in answer.lower()
    if answer.lower() == groundtruth_str.lower():
        return True
    if groundtruth_num is None:
        return False
    if isinstance(groundtruth_num, (int, float)):
        return compare_two_numbers(number_it(answer), groundtruth_num)

    # Sequence ground truth: accept '(a, b)' as well as '[a, b]'; the latter
    # is how extract_theoremqa_answer renders tuple results (str(list(...))).
    is_tuple = answer.startswith('(') and answer.endswith(')')
    is_list = answer.startswith('[') and answer.endswith(']')
    if not (is_tuple or is_list):
        return False
    try:
        # NOTE(review): eval() on model-derived text; kept because answers
        # may contain arithmetic (e.g. '(1/2, 3)') that ast.literal_eval
        # rejects.
        items = [number_it(a) for a in list(eval(answer))]
    except Exception:
        return False
    return compare_two_list(items, groundtruth_num)