2024-04-22 15:22:04 +08:00
|
|
|
import re
|
|
|
|
from .number_utils import clean_units, compare_two_numbers, compare_two_list, number_it
|
|
|
|
import contextlib
|
|
|
|
import signal
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def time_limit(seconds: float):
    """Raise ``ValueError`` inside the managed block once ``seconds`` elapse.

    Implemented with ``SIGALRM`` and ``ITIMER_REAL``, so it only works on
    Unix and only from the main thread (``signal.signal`` raises
    ``ValueError`` elsewhere). Any previously installed ``SIGALRM`` handler
    is restored on exit.

    Args:
        seconds: Wall-clock time limit; ``0`` disables the timer.
    """
    def signal_handler(signum, frame):
        raise ValueError(f'Timed out after {seconds} seconds')

    # Install the handler BEFORE arming the timer: otherwise a very short
    # timer could fire while the default SIGALRM action (terminate) is active.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        # Disarm first so the alarm cannot fire once our handler is removed.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # getsignal/signal return None when the old handler was not set from
        # Python; fall back to the default action in that case.
        signal.signal(
            signal.SIGALRM,
            previous_handler if previous_handler is not None else signal.SIG_DFL,
        )
|
|
|
|
|
|
|
|
|
|
|
|
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Normalize a raw answer string for TheoremQA-style scoring.

    The prediction is mapped, in this order of priority, to:

    * ``'True'`` if it contains ``'yes'`` or ``'true'`` (case-insensitive);
    * ``'False'`` if it contains ``'no'`` or ``'false'``;
    * itself, unchanged, if it contains a choice marker ``(a)``..``(f)``;
    * otherwise a numeric/expression string:

      - ``answer_flag`` true: the text after the last ``=`` is passed through
        ``clean_units`` and then ``latex2sympy`` + ``eval`` under a 1-second
        ``time_limit``; tuple results become list strings such as
        ``'[1, 2]'``. On any failure, a leading number followed by a
        whitespace and a unit-like tail is reduced to the number.
      - ``answer_flag`` false: the last number-looking token in the text,
        or ``''`` if there is none.

    Args:
        pred: Answer text (typically the part after the answer trigger).
        answer_flag: Whether an explicit answer trigger was found upstream.

    Returns:
        The normalized answer string.
    """
    lowered = pred.lower()
    # NOTE(review): these are plain substring tests, so words such as
    # 'nanometers', 'cannot' or 'known' are classified as 'no'. Kept as-is
    # to preserve existing benchmark scores.
    if any(option in lowered for option in ('yes', 'true')):
        return 'True'
    if any(option in lowered for option in ('no', 'false')):
        return 'False'
    if any(option in lowered for option in ('(a)', '(b)', '(c)', '(d)', '(e)', '(f)')):
        return pred

    if not answer_flag:
        # Desperate search: fall back to the last number in the text.
        numbers = re.findall(r'-?\d*\.?\d+', pred)
        return numbers[-1] if numbers else ''

    # Imported lazily and only here: the other paths do not need the
    # (heavy, optional) LaTeX parser.
    from latex2sympy2_extended import latex2sympy

    # Keep only the right-hand side of the last '=' and strip units.
    pred = pred.split('=')[-1].strip()
    pred = clean_units(pred)
    try:
        with time_limit(1):
            tmp = str(latex2sympy(pred))
            # SECURITY: eval() runs on text derived from model output.
            # Acceptable only for trusted/sandboxed evaluation runs.
            # A separate name keeps `pred` a str if anything below raises.
            value = eval(tmp)
            pred = str(list(value)) if isinstance(value, tuple) else str(value)
    except Exception:
        # Parsing failed or timed out: accept "<number> <unit...>" and keep
        # only the number. split() (not split(' ')) so any whitespace matched
        # by \s separates the number from the unit.
        if (re.match(r'-?[\d\.]+\s\D+$', pred)
                or re.match(r'-?[\d\.]+\s[^\s]+$', pred)):
            pred = pred.split()[0]

    return pred
|
|
|
|
|
|
|
|
def answer_clean(direct_answer_trigger_for_fewshot: tuple, pred: str):
    """Extract and normalize the final answer from a model completion.

    Args:
        direct_answer_trigger_for_fewshot: Literal phrases (e.g. an
            "answer is" marker) that introduce the final answer. Empty
            strings are ignored.
        pred: Raw model completion.

    Returns:
        The answer normalized by ``extract_theoremqa_answer``, with trailing
        '.' and '/' characters removed.
    """
    # Empty triggers would make the split pattern match at every position
    # (and pred.count('') is always > 1), so drop them up front.
    triggers = [trigger for trigger in direct_answer_trigger_for_fewshot if trigger]

    pred = pred.strip('\n')

    # A trigger appearing more than once suggests the model continued with
    # in-context (few-shot) examples; keep only the first paragraph.
    if any(pred.count(trigger) > 1 for trigger in triggers):
        pred = pred.split('\n\n')[0]

    # Take the text after the last trigger occurrence, if any. Triggers are
    # escaped so they match literally, consistent with pred.count above.
    answer_flag = False
    if triggers:
        pattern = '|'.join(re.escape(trigger) for trigger in triggers)
        pieces = re.split(pattern, pred)
        if len(pieces) > 1:
            answer_flag = True
            pred = pieces[-1]

    pred = pred.strip('\n').rstrip('.').rstrip('/').strip(' ')
    pred = extract_theoremqa_answer(pred, answer_flag)

    # Strip a trailing period/slash once more on the extracted answer.
    return pred.rstrip('.').rstrip('/')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compare_answer_with_groundtruth(answer: str, groundtruth_str: str, groundtruth_num = None):
    """Return whether a normalized answer matches the ground truth.

    Args:
        answer: Normalized model answer.
        groundtruth_str: Ground-truth answer as a string.
        groundtruth_num: Optional numeric ground truth: a single int/float,
            or (presumably) a list of numbers for list-valued answers.

    Returns:
        True if the answer matches, False otherwise.
    """
    choices = ('(a)', '(b)', '(c)', '(d)', '(e)', '(f)')
    groundtruth_lower = groundtruth_str.lower()
    if groundtruth_lower in choices:
        # Multiple choice: the chosen option only has to appear in the answer.
        return groundtruth_lower in answer.lower()
    if answer.lower() == groundtruth_lower:
        return True
    if groundtruth_num is None:
        return False
    if isinstance(groundtruth_num, (int, float)):
        return compare_two_numbers(number_it(answer), groundtruth_num)

    # List-valued ground truth. Accept both "(...)" and "[...]" answers:
    # extract_theoremqa_answer renders tuple results as "[...]" strings.
    is_tuple_text = answer.startswith('(') and answer.endswith(')')
    is_list_text = answer.startswith('[') and answer.endswith(']')
    if not (is_tuple_text or is_list_text):
        return False
    try:
        # SECURITY: eval() on model-generated text; acceptable only for
        # trusted/sandboxed evaluation runs.
        values = [number_it(item) for item in list(eval(answer))]
    except Exception:
        return False
    return compare_two_list(values, groundtruth_num)
|