mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* Add ToT method * Update ToT * Update ToT * Update ToT * Update ToT * Update ToT * Update ToT * Update ToT * Update chain_of_thought.md * Update icl_tot_inferencer.py --------- Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
257 lines
6.3 KiB
Python
257 lines
6.3 KiB
Python
# Prompt and utils of Game24 dataset, edited from:
# https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/tasks/game24.py
|
|
import re
|
|
from typing import List
|
|
|
|
import pandas as pd
|
|
import sympy
|
|
from datasets import Dataset
|
|
|
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
|
|
|
from .base import BaseDataset
|
|
|
|
# 5-shot direct-answer prompt. The ``{input}`` placeholder is filled through
# ``str.format`` with the puzzle numbers; the completion is appended after it.
standard_prompt = '''Use numbers and basic arithmetic operations \
(+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
'''
|
|
|
|
# 5-shot chain-of-thought prompt: every example lists three intermediate
# steps, each ending with the numbers still available ("left: ..."), followed
# by the final answer. ``{input}`` is filled through ``str.format``.
cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to \
obtain 24. Each step, you are only allowed to choose two of the remaining \
numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
'''
|
|
|
|
# 1-shot proposal prompt: given the remaining numbers (``{input}``), the
# model is asked to list candidate next steps in the "a op b = c (left: ...)"
# format shown in the example.
propose_prompt = '''Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 / 2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Input: {input}
Possible next steps:
'''
|
|
|
|
# Few-shot valuation prompt: the model judges whether the remaining numbers
# (``{input}``) can still reach 24 and ends its reply with one of
# sure / likely / impossible. The example wording is kept verbatim.
value_prompt = '''Evaluate if given numbers can reach 24 \
(sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{input}
'''
|
|
|
|
# Few-shot prompt for judging a finished answer: given the puzzle
# (``{input}``) and a candidate expression (``{answer}``), the model replies
# sure or impossible. The literal deliberately ends right after "Judge:".
value_last_step_prompt = '''Use numbers and basic arithmetic operations \
(+ - * /) to obtain 24. Given an input and an answer, give a judgement \
(sure/impossible) if the answer is correct, i.e. it uses each input exactly \
once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge:
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge:
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge:
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge:
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge:
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge:
impossible
Input: {input}
Answer: {answer}
Judge:'''
|
|
|
|
|
|
def get_current_numbers(y: str) -> str:
    """Extract the numbers that are still available from the latest step.

    Only the last line of the stripped text is inspected. The result is
    whatever follows the final ``'left: '`` marker up to (not including) the
    first ``')'``; when the marker is absent the line itself is used.
    """
    tail = y.strip().rpartition('\n')[2]
    after_marker = tail.rpartition('left: ')[2]
    remaining, _, _ = after_marker.partition(')')
    return remaining
|
|
|
|
|
|
class Game24Dataset(BaseDataset):
    """Game of 24 puzzles read from a CSV file."""

    @staticmethod
    def load(path: str):
        """Build the evaluation subset from the CSV at ``path``.

        The ``Puzzles`` column is read and rows 900..904 (five puzzles) are
        kept; each puzzle string is stored as both ``input`` and ``output``.
        """
        frame = pd.read_csv(path)
        selected = list(frame['Puzzles'])[900:905]
        records = []
        for puzzle in selected:
            records.append({'input': puzzle, 'output': puzzle})
        return Dataset.from_list(records)
|
|
|
|
|
|
class Game24PromptWrapper:
    """Build Game24 prompts and score value outputs.

    standard_prompt_wrap, cot_prompt_wrap, propose_prompt_wrap:
        Build the prompt for the corresponding sampling method.
    value_prompt_wrap:
        Build the prompt used for value-based evaluation.
    value_outputs_unwrap:
        Turn a list of value outputs into a single numeric score.
    """

    def __init__(self):
        self.steps = 4
        self.stops = ['\n', '\n', '\n', '\n']
        self.value_cache = {}

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        """Return the standard prompt for puzzle ``x`` followed by ``y``."""
        filled = standard_prompt.format(input=x)
        return filled + y

    @staticmethod
    def cot_prompt_wrap(x: str, y: str = '') -> str:
        """Return the chain-of-thought prompt for ``x`` followed by ``y``."""
        filled = cot_prompt.format(input=x)
        return filled + y

    @staticmethod
    def propose_prompt_wrap(x: str, y: str = '') -> str:
        """Return the prompt asking for the next step(s).

        The remaining numbers come from ``y`` when it is non-empty, otherwise
        from ``x``. Once only ``24`` is left, the chain-of-thought prompt with
        the steps so far is returned instead of the proposal prompt.
        """
        source = y if y else x
        remaining = get_current_numbers(source)
        if remaining != '24':
            return propose_prompt.format(input=remaining)
        return cot_prompt.format(input=x) + 'Steps:' + y

    @staticmethod
    def value_prompt_wrap(x: str, y: str) -> str:
        """Return the prompt used to judge the partial or final output ``y``."""
        final_line = y.strip().rpartition('\n')[2]
        if 'left: ' in final_line:
            # Intermediate step: judge the numbers that are still available.
            return value_prompt.format(input=get_current_numbers(y))
        # Last step: judge the proposed answer expression itself.
        candidate = final_line.lower().replace('answer: ', '')
        return value_last_step_prompt.format(input=x, answer=candidate)

    @staticmethod
    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
        """Aggregate the judgements in ``value_outputs`` into one score.

        Four lines in ``y`` without any "answer" yield 0. Otherwise the last
        line of every output is matched against the known labels and the
        label weights are summed.
        """
        four_lines = y.strip().count('\n') == 3
        if four_lines and 'answer' not in y.lower():
            return 0
        verdicts = [output.rpartition('\n')[2] for output in value_outputs]
        weights = {
            'impossible': 0.001,
            'likely': 1,
            'sure': 20
        }  # TODO: ad hoc
        total = 0
        for label, weight in weights.items():
            total += weight * verdicts.count(label)
        return total
|
|
|
|
|
|
def game24_postprocess(output: str):
    """Reduce a model output to the expression on its final line.

    The last line of the stripped output is lower-cased, every
    ``'answer: '`` occurrence is removed, and everything from the first
    ``'='`` onwards is dropped. Surrounding spaces are not trimmed.
    """
    final_line = output.strip().rpartition('\n')[2]
    without_label = final_line.lower().replace('answer: ', '')
    expression, _, _ = without_label.partition('=')
    return expression
|
|
|
|
|
|
class Game24Evaluator(BaseEvaluator):
    """Accuracy evaluator for Game of 24 answers.

    A prediction is correct when the integers appearing in it are exactly the
    integers of the reference puzzle and the expression evaluates to 24.
    """

    def __init__(self) -> None:
        super().__init__()

    def check_nums(self, prediction, reference):
        """Return 1 if ``prediction`` solves the puzzle ``reference``, else 0.

        Args:
            prediction: Expression text to evaluate.
            reference: Puzzle text whose digit runs are the allowed numbers.

        Returns:
            int: 1 when the numbers match and the expression equals 24,
            otherwise 0 (including when the expression cannot be evaluated).
        """
        numbers = re.findall(r'\d+', prediction)
        problem_numbers = re.findall(r'\d+', reference)
        # Compare as multisets: every puzzle number must be used exactly once.
        if sorted(numbers) != sorted(problem_numbers):
            return 0
        try:
            # NOTE(security): sympy parses string input with an eval-based
            # parser, and ``prediction`` is model-generated text. Only run
            # this on outputs you are willing to evaluate; consider a
            # restricted arithmetic parser if that is not acceptable.
            return int(sympy.simplify(prediction) == 24)
        except Exception:
            # Best effort: anything sympy cannot parse counts as wrong.
            return 0

    def score(self, predictions: List, references: List) -> dict:
        """Compute the share of predictions that solve their puzzle.

        Args:
            predictions: Expressions, one per sample.
            references: Puzzle texts aligned with ``predictions``.

        Returns:
            dict: ``{'acc score': correct / len(predictions)}``; 0.0 when
            there are no predictions (avoids a division by zero).
        """
        n = len(predictions)
        if n == 0:
            return {'acc score': 0.0}
        acc = sum(1 for prediction, reference in zip(predictions, references)
                  if self.check_nums(prediction, reference))
        return {'acc score': acc / n}
|