mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Add Tree-of-Thought method (#173)
* Add ToT method * Update ToT * Update ToT * Update ToT * Update ToT * Update ToT * Update ToT * Update ToT * Update chain_of_thought.md * Update icl_tot_inferencer.py --------- Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
This commit is contained in:
parent
ff5ab92331
commit
02ce139bc6
4
configs/datasets/game24/game24_gen.py
Normal file
4
configs/datasets/game24/game24_gen.py
Normal file
@ -0,0 +1,4 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .game24_gen_8dfde3 import game24_datasets # noqa: F401, F403
|
34
configs/datasets/game24/game24_gen_8dfde3.py
Normal file
34
configs/datasets/game24/game24_gen_8dfde3.py
Normal file
@ -0,0 +1,34 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import ToTInferencer
|
||||
from opencompass.datasets import (Game24Dataset, game24_postprocess,
|
||||
Game24Evaluator, Game24PromptWrapper)
|
||||
|
||||
generation_kwargs = dict(do_sample=False, temperature=0.7)
|
||||
|
||||
game24_reader_cfg = dict(
|
||||
input_columns=['input'],
|
||||
output_column='output')
|
||||
|
||||
game24_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template='{input}'),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=ToTInferencer, generation_kwargs=generation_kwargs, method_generate='propose',
|
||||
method_evaluate='value', method_select='greedy', n_evaluate_sample=3, n_select_sample=5, prompt_wrapper=dict(type=Game24PromptWrapper)))
|
||||
|
||||
game24_eval_cfg = dict(
|
||||
evaluator=dict(type=Game24Evaluator),
|
||||
pred_postprocessor=dict(type=game24_postprocess),
|
||||
)
|
||||
|
||||
game24_datasets = [
|
||||
dict(
|
||||
abbr='game24',
|
||||
type=Game24Dataset,
|
||||
path='./data/game24/game24.csv',
|
||||
reader_cfg=game24_reader_cfg,
|
||||
infer_cfg=game24_infer_cfg,
|
||||
eval_cfg=game24_eval_cfg)
|
||||
]
|
@ -4,6 +4,8 @@
|
||||
|
||||
During the process of reasoning, CoT (Chain of Thought) method is an efficient way to help LLMs deal complex questions, for example: math problem and relation inference. In OpenCompass, we support multiple types of CoT method.
|
||||
|
||||

|
||||
|
||||
## 1. Zero Shot CoT
|
||||
|
||||
You can change the `PromptTemplate` of the dataset config, by simply add *Let's think step by step* to realize a Zero-Shot CoT prompt for your evaluation:
|
||||
@ -73,3 +75,53 @@ Where `SAMPLE_SIZE` is the number of reasoning paths in Self-Consistency, higher
|
||||

|
||||
|
||||
From the figure, it can be seen that in different reasoning tasks, performance tends to improve as the number of reasoning paths increases. However, for some tasks, increasing the number of reasoning paths may reach a limit, and further increasing the number of paths may not bring significant performance improvement. Therefore, it is necessary to conduct experiments and adjustments on specific tasks to find the optimal number of reasoning paths that best suit the task.
|
||||
|
||||
## 4. Tree-of-Thoughts
|
||||
|
||||
In contrast to the conventional CoT approach that considers only a single reasoning path, Tree-of-Thoughts (ToT) allows the language model to explore multiple diverse reasoning paths simultaneously. The model evaluates the reasoning process through self-assessment and makes global choices by conducting lookahead or backtracking when necessary. Specifically, this process is divided into the following four stages:
|
||||
|
||||
**1. Thought Decomposition**
|
||||
|
||||
Based on the nature of the problem, break down the problem into multiple intermediate steps. Each step can be a phrase, equation, or writing plan, depending on the nature of the problem.
|
||||
|
||||
**2. Thought Generation**
|
||||
|
||||
Assuming that solving the problem requires k steps, there are two methods to generate reasoning content:
|
||||
|
||||
- Independent sampling: For each state, the model independently extracts k reasoning contents from the CoT prompts, without relying on other reasoning contents.
|
||||
- Sequential generation: Sequentially use "prompts" to guide the generation of reasoning content, where each reasoning content may depend on the previous one.
|
||||
|
||||
**3. Heuristic Evaluation**
|
||||
|
||||
Use heuristic methods to evaluate the contribution of each generated reasoning content to problem-solving. This self-evaluation is based on the model's self-feedback and involves designing prompts to have the model score multiple generated results.
|
||||
|
||||
**4. Search Algorithm Selection**
|
||||
|
||||
Based on the methods of generating and evaluating reasoning content, select an appropriate search algorithm. For example, you can use breadth-first search (BFS) or depth-first search (DFS) algorithms to systematically explore the thought tree, conducting lookahead and backtracking.
|
||||
|
||||
In OpenCompass, ToT parameters need to be set according to the requirements. Below is an example configuration for the 24-Point game from the [official paper](https://arxiv.org/pdf/2305.10601.pdf). Currently, ToT inference is supported only with Huggingface models:
|
||||
|
||||
```python
|
||||
# This ToT Game24 config can be found at: opencompass/configs/datasets/game24/game24_gen_8dfde3.py.
|
||||
from opencompass.datasets import (Game24Dataset, game24_postprocess,
|
||||
Game24Evaluator, Game24PromptWrapper)
|
||||
|
||||
generation_kwargs = dict(temperature=0.7)
|
||||
|
||||
game24_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template='{input}'), # Directly pass the input content, as the Prompt needs to be specified in steps
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=ToTInferencer, # Replace GenInferencer with ToTInferencer
|
||||
generation_kwargs=generation_kwargs,
|
||||
method_generate='propose', # Method for generating reasoning content, can be independent sampling (sample) or sequential generation (propose)
|
||||
method_evaluate='value', # Method for evaluating reasoning content, can be voting (vote) or scoring (value)
|
||||
method_select='greedy', # Method for selecting reasoning content, can be greedy (greedy) or random (sample)
|
||||
n_evaluate_sample=3,
|
||||
n_select_sample=5,
|
||||
task_wrapper=dict(type=Game24PromptWrapper) # This Wrapper class includes the prompts for each step and methods for generating and evaluating reasoning content, needs customization according to the task
|
||||
))
|
||||
```
|
||||
|
||||
If you want to use the ToT method on a custom dataset, you'll need to make additional configurations in the `opencompass.datasets.YourDataConfig.py` file to set up the `YourDataPromptWrapper` class. This is required for handling the thought generation and heuristic evaluation step within the ToT framework. For reasoning tasks similar to the game 24-Point, you can refer to the implementation in `opencompass/datasets/game24.py` for guidance.
|
||||
|
@ -4,6 +4,8 @@
|
||||
|
||||
CoT(思维链)是帮助大型语言模型解决如数学问题和关系推理问题等复杂问题的有效方式,在OpenCompass中,我们支持多种类型的CoT方法。
|
||||
|
||||

|
||||
|
||||
## 1. 零样本思维链
|
||||
|
||||
可以通过在数据集配置中简单地添加 “Let's think step by step",来更改数据集配置的 PromptTemplate,从而实现 零样本 CoT prompt 以进行评估:
|
||||
@ -49,7 +51,7 @@ Question: {question}\nLet's think step by step:\n{answer}
|
||||
|
||||
## 3. Self-Consistency
|
||||
|
||||
SC (Self-Consistency) 方法是在 [此文章](https://arxiv.org/abs/2203.11171) 中提出的,该方法会为问题生成多个不同的推理路径,并对生成的答案进行众数投票。这种方法在复杂推理任务中表现出了显著的能力,但由于需要推理多次来采样多条推理链,所以可能会消耗很多的时间和资源。在 OpenCompass 中,您可以通过在数据集配置中将 `GenInferencer` 替换为 `SCInferencer` 并设置相应的参数参数来简单地实现 SC 方法,例如:
|
||||
SC (Self-Consistency) 方法是在 [此文章](https://arxiv.org/abs/2203.11171) 中提出的,该方法会为问题生成多条不同的推理路径,并对生成的答案进行众数投票。这种方法在复杂推理任务中表现出了显著的能力,但由于需要推理多次来采样多条推理链,所以可能会消耗很多的时间和资源。在 OpenCompass 中,您可以通过在数据集配置中将 `GenInferencer` 替换为 `SCInferencer` 并设置相应的参数参数来简单地实现 SC 方法,例如:
|
||||
|
||||
```python
|
||||
# 此SC版gsm8k测试配置可以在: opencompass.configs.datasets.gsm8k.gsm8k_gen_a3e34a.py 中找到。
|
||||
@ -73,3 +75,54 @@ gsm8k_eval_cfg = dict(sc_size=SAMPLE_SIZE)
|
||||

|
||||
|
||||
从图中可以看出,在不同的推理任务中,随着推理路径数量的增加,性能呈现出增长的趋势。但是,对于某些任务,增加推理路径的数量可能达到一个极限,进一步增加推理路径的数量可能不会带来更多的性能提升。因此,需要在具体任务中进行实验和调整,找到最适合任务的推理路径数量。
|
||||
|
||||
## 4. Tree-of-Thoughts
|
||||
|
||||
相比一般的CoT方法采样一条推理路径,ToT(Tree-of-Thoughts)允许语言模型同时考虑多种不同的推理路径,通过对推理过程进行自我评估,以及在必要时进行前瞻或回溯以做出全局选择。具体的,分为下面四个阶段:
|
||||
|
||||
**1. 问题分解 (Thought Decomposition)**
|
||||
|
||||
根据问题的特点,将问题分解成多个中间步骤。每个步骤可以是短语、算式或写作计划,这取决于问题的性质。
|
||||
|
||||
**2. 推理过程生成 (Thought Generation)**
|
||||
|
||||
假设解决问题需要k个步骤,有两种方法生成推理内容:
|
||||
|
||||
- 独立采样:对于每个状态,模型会独立地从CoT提示中完整抽取k个推理内容,不依赖于其他的推理内容。
|
||||
- 顺序生成:顺序地使用“提示”来逐步引导推理内容生成,每个推理内容都可能依赖于前一个推理内容。
|
||||
|
||||
**3. 启发式评估 (Heuristic Evaluation)**
|
||||
|
||||
使用启发式方法评估每个生成的推理内容对问题解决的贡献,这种自我评估基于语言模型的自我反馈,如设计Prompt让模型对多个生成结果进行打分。
|
||||
|
||||
**4. 选择搜索算法 (Search Algorithm)**
|
||||
|
||||
根据生成和评估推理内容的方法,选择适当的搜索算法。例如,可以使用广度优先搜索(BFS)或深度优先搜索(DFS)等算法来系统地探索思考树,并进行前瞻和回溯。
|
||||
|
||||
在OpenCompass中,需要根据需要设置ToT参数,以下是[ToT论文](https://arxiv.org/pdf/2305.10601.pdf)中24点游戏的样例配置,目前支持Huggingface模型进行ToT推理:
|
||||
|
||||
```python
|
||||
# 此 ToT Game24 配置可以在以下路径找到:opencompass/configs/datasets/game24/game24_gen_8dfde3.py。
|
||||
from opencompass.datasets import (Game24Dataset, game24_postprocess,
|
||||
Game24Evaluator, Game24PromptWrapper)
|
||||
|
||||
generation_kwargs = dict(temperature=0.7)
|
||||
|
||||
game24_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template='{input}'), #直接传入input内容,因为Prompt需要分段指定
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=ToTInferencer, # 替换GenInferencer为ToTInferencer
|
||||
generation_kwargs=generation_kwargs,
|
||||
method_generate='propose', # 生成推理内容的方法,可以是独立采样(sample)或顺序生成(propose)
|
||||
method_evaluate='value', # 评估推理内容的方法,可以是投票 (vote)或打分(value)
|
||||
method_select='greedy', # 选择推理内容的方法,可以是贪心(greedy)或随机(sample)
|
||||
n_evaluate_sample=3,
|
||||
n_select_sample=5,
|
||||
task_wrapper=dict(type=Game24PromptWrapper) # 该Wrapper类包含每个步骤的Prompt和推理内容的生成及评估方法,需要根据任务进行自定义
|
||||
))
|
||||
|
||||
```
|
||||
|
||||
如果要在自定义的数据集上使用ToT方法,相比普通评测方式,需要在`opencompass.datasets.YourDataConfig.py`中额外设置`YourDataPromptWrapper`类,以进行ToT中的推理生成和启发式评估。对于类似游戏24点的推理任务,具体可以参考`opencompass/datasets/game24.py`。
|
||||
|
@ -25,6 +25,7 @@ from .drcd import * # noqa: F401, F403
|
||||
from .drop import * # noqa: F401, F403
|
||||
from .eprstmt import * # noqa: F401, F403
|
||||
from .flores import * # noqa: F401, F403
|
||||
from .game24 import * # noqa: F401, F403
|
||||
from .GaokaoBench import * # noqa: F401, F403
|
||||
from .govrepcrs import * # noqa: F401, F403
|
||||
from .gsm8k import * # noqa: F401, F403
|
||||
|
256
opencompass/datasets/game24.py
Normal file
256
opencompass/datasets/game24.py
Normal file
@ -0,0 +1,256 @@
|
||||
# Prompt and utils of Game24 dataset, edited from:
|
||||
# https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/tasks/game24.py
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
import sympy
|
||||
from datasets import Dataset
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
# 5-shot
|
||||
standard_prompt = '''Use numbers and basic arithmetic operations \
|
||||
(+ - * /) to obtain 24.
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) = 24
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * 12 * (10 - 9) = 24
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 9) * (10 - 4) = 24
|
||||
Input: 1 4 8 8
|
||||
Answer: (8 / 4 + 1) * 8 = 24
|
||||
Input: 5 5 5 9
|
||||
Answer: 5 + 5 + 5 + 9 = 24
|
||||
Input: {input}
|
||||
'''
|
||||
|
||||
# 5-shot
|
||||
cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to \
|
||||
obtain 24. Each step, you are only allowed to choose two of the remaining \
|
||||
numbers to obtain a new number.
|
||||
Input: 4 4 6 8
|
||||
Steps:
|
||||
4 + 8 = 12 (left: 4 6 12)
|
||||
6 - 4 = 2 (left: 2 12)
|
||||
2 * 12 = 24 (left: 24)
|
||||
Answer: (6 - 4) * (4 + 8) = 24
|
||||
Input: 2 9 10 12
|
||||
Steps:
|
||||
12 * 2 = 24 (left: 9 10 24)
|
||||
10 - 9 = 1 (left: 1 24)
|
||||
24 * 1 = 24 (left: 24)
|
||||
Answer: (12 * 2) * (10 - 9) = 24
|
||||
Input: 4 9 10 13
|
||||
Steps:
|
||||
13 - 10 = 3 (left: 3 4 9)
|
||||
9 - 3 = 6 (left: 4 6)
|
||||
4 * 6 = 24 (left: 24)
|
||||
Answer: 4 * (9 - (13 - 10)) = 24
|
||||
Input: 1 4 8 8
|
||||
Steps:
|
||||
8 / 4 = 2 (left: 1 2 8)
|
||||
1 + 2 = 3 (left: 3 8)
|
||||
3 * 8 = 24 (left: 24)
|
||||
Answer: (1 + 8 / 4) * 8 = 24
|
||||
Input: 5 5 5 9
|
||||
Steps:
|
||||
5 + 5 = 10 (left: 5 9 10)
|
||||
10 + 5 = 15 (left: 9 15)
|
||||
15 + 9 = 24 (left: 24)
|
||||
Answer: ((5 + 5) + 5) + 9 = 24
|
||||
Input: {input}
|
||||
'''
|
||||
|
||||
# 1-shot
|
||||
propose_prompt = '''Input: 2 8 8 14
|
||||
Possible next steps:
|
||||
2 + 8 = 10 (left: 8 10 14)
|
||||
8 / 2 = 4 (left: 4 8 14)
|
||||
14 + 2 = 16 (left: 8 8 16)
|
||||
2 * 8 = 16 (left: 8 14 16)
|
||||
8 - 2 = 6 (left: 6 8 14)
|
||||
14 - 8 = 6 (left: 2 6 8)
|
||||
14 / 2 = 7 (left: 7 8 8)
|
||||
14 - 2 = 12 (left: 8 8 12)
|
||||
Input: {input}
|
||||
Possible next steps:
|
||||
'''
|
||||
|
||||
value_prompt = '''Evaluate if given numbers can reach 24 \
|
||||
(sure/likely/impossible)
|
||||
10 14
|
||||
10 + 14 = 24
|
||||
sure
|
||||
11 12
|
||||
11 + 12 = 23
|
||||
12 - 11 = 1
|
||||
11 * 12 = 132
|
||||
11 / 12 = 0.91
|
||||
impossible
|
||||
4 4 10
|
||||
4 + 4 + 10 = 8 + 10 = 18
|
||||
4 * 10 - 4 = 40 - 4 = 36
|
||||
(10 - 4) * 4 = 6 * 4 = 24
|
||||
sure
|
||||
4 9 11
|
||||
9 + 11 + 4 = 20 + 4 = 24
|
||||
sure
|
||||
5 7 8
|
||||
5 + 7 + 8 = 12 + 8 = 20
|
||||
(8 - 5) * 7 = 3 * 7 = 21
|
||||
I cannot obtain 24 now, but numbers are within a reasonable range
|
||||
likely
|
||||
5 6 6
|
||||
5 + 6 + 6 = 17
|
||||
(6 - 5) * 6 = 1 * 6 = 6
|
||||
I cannot obtain 24 now, but numbers are within a reasonable range
|
||||
likely
|
||||
10 10 11
|
||||
10 + 10 + 11 = 31
|
||||
(11 - 10) * 10 = 10
|
||||
10 10 10 are all too big
|
||||
impossible
|
||||
1 3 3
|
||||
1 * 3 * 3 = 9
|
||||
(1 + 3) * 3 = 12
|
||||
1 3 3 are all too small
|
||||
impossible
|
||||
{input}
|
||||
'''
|
||||
|
||||
value_last_step_prompt = '''Use numbers and basic arithmetic operations \
|
||||
(+ - * /) to obtain 24. Given an input and an answer, give a judgement \
|
||||
(sure/impossible) if the answer is correct, i.e. it uses each input exactly \
|
||||
once and no other numbers, and reach 24.
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * 12 * (10 - 9) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 9) * (10 - 4) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) + 1 = 25
|
||||
Judge:
|
||||
impossible
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * (12 - 10) = 24
|
||||
Judge:
|
||||
impossible
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 4) * (10 - 9) = 24
|
||||
Judge:
|
||||
impossible
|
||||
Input: {input}
|
||||
Answer: {answer}
|
||||
Judge:'''
|
||||
|
||||
|
||||
def get_current_numbers(y: str) -> str:
|
||||
last_line = y.strip().split('\n')[-1]
|
||||
return last_line.split('left: ')[-1].split(')')[0]
|
||||
|
||||
|
||||
class Game24Dataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str):
|
||||
data = list(pd.read_csv(path)['Puzzles'])
|
||||
data = [{'input': i, 'output': i} for i in data]
|
||||
return Dataset.from_list(data[900:905])
|
||||
|
||||
|
||||
class Game24PromptWrapper:
|
||||
"""Wrapper for Game24 prompts and outputs.
|
||||
|
||||
standard_prompt_wrap、cot_prompt_wrap、propose_prompt_wrap:
|
||||
Get prompts for different sample method.
|
||||
value_prompt_wrap:
|
||||
Get prompts for value-based evaluation method.
|
||||
value_outputs_unwrap:
|
||||
Calculate total value score for value outputs.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.value_cache = {}
|
||||
self.steps = 4
|
||||
self.stops = ['\n'] * 4
|
||||
|
||||
@staticmethod
|
||||
def standard_prompt_wrap(x: str, y: str = '') -> str:
|
||||
return standard_prompt.format(input=x) + y
|
||||
|
||||
@staticmethod
|
||||
def cot_prompt_wrap(x: str, y: str = '') -> str:
|
||||
return cot_prompt.format(input=x) + y
|
||||
|
||||
@staticmethod
|
||||
def propose_prompt_wrap(x: str, y: str = '') -> str:
|
||||
current_numbers = get_current_numbers(y if y else x)
|
||||
if current_numbers == '24':
|
||||
prompt = cot_prompt.format(input=x) + 'Steps:' + y
|
||||
else:
|
||||
prompt = propose_prompt.format(input=current_numbers)
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def value_prompt_wrap(x: str, y: str) -> str:
|
||||
last_line = y.strip().split('\n')[-1]
|
||||
if 'left: ' not in last_line: # last step
|
||||
ans = last_line.lower().replace('answer: ', '')
|
||||
return value_last_step_prompt.format(input=x, answer=ans)
|
||||
current_numbers = get_current_numbers(y)
|
||||
return value_prompt.format(input=current_numbers)
|
||||
|
||||
@staticmethod
|
||||
def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
|
||||
if len(y.strip().split('\n')) == 4 and 'answer' not in y.lower():
|
||||
return 0
|
||||
value_names = [_.split('\n')[-1] for _ in value_outputs]
|
||||
value_map = {
|
||||
'impossible': 0.001,
|
||||
'likely': 1,
|
||||
'sure': 20
|
||||
} # TODO: ad hoc
|
||||
value = sum(value * value_names.count(name)
|
||||
for name, value in value_map.items())
|
||||
return value
|
||||
|
||||
|
||||
def game24_postprocess(output: str):
|
||||
expression = output.strip().split('\n')[-1].lower().replace(
|
||||
'answer: ', '').split('=')[0]
|
||||
return expression
|
||||
|
||||
|
||||
class Game24Evaluator(BaseEvaluator):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def check_nums(self, prediction, reference):
|
||||
numbers = re.findall(r'\d+', prediction)
|
||||
problem_numbers = re.findall(r'\d+', reference)
|
||||
if sorted(numbers) != sorted(problem_numbers):
|
||||
return 0
|
||||
try:
|
||||
return int(sympy.simplify(prediction) == 24)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def score(self, predictions: List, references: List) -> dict:
|
||||
n = len(predictions)
|
||||
acc = 0
|
||||
for prediction, reference in zip(predictions, references):
|
||||
if self.check_nums(prediction, reference):
|
||||
acc += 1
|
||||
|
||||
return {'acc score': acc / n}
|
@ -3,3 +3,4 @@ from .icl_clp_inferencer import CLPInferencer # noqa
|
||||
from .icl_gen_inferencer import GenInferencer # noqa
|
||||
from .icl_ppl_inferencer import PPLInferencer # noqa
|
||||
from .icl_sc_inferencer import SCInferencer # noqa
|
||||
from .icl_tot_inferencer import ToTInferencer # noqa
|
||||
|
380
opencompass/openicl/icl_inferencer/icl_tot_inferencer.py
Normal file
380
opencompass/openicl/icl_inferencer/icl_tot_inferencer.py
Normal file
@ -0,0 +1,380 @@
|
||||
"""Tree-of-Thought Generation Inferencer."""
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List, Optional
|
||||
|
||||
import mmengine
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from opencompass.models.base import BaseModel
|
||||
from opencompass.registry import ICL_INFERENCERS, TOT_WRAPPER
|
||||
|
||||
from ..icl_prompt_template import PromptTemplate
|
||||
from ..icl_retriever import BaseRetriever
|
||||
from ..utils.logging import get_logger
|
||||
from .icl_gen_inferencer import GenInferencer, GenInferencerOutputHandler
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ICL_INFERENCERS.register_module()
|
||||
class ToTInferencer(GenInferencer):
|
||||
"""Tree-of-Thought Inferencer class to evaluate by tree style reasoning
|
||||
paths.
|
||||
Doc: https://opencompass.readthedocs.io/en/latest/prompt/
|
||||
chain_of_thought.html
|
||||
Official tot paper: https://arxiv.org/pdf/2305.10601.pdf
|
||||
|
||||
|
||||
Attributes:
|
||||
model (:obj:`BaseModelWrapper`, optional): The module to inference.
|
||||
max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
|
||||
allowed by the LM.
|
||||
batch_size (:obj:`int`, optional): Batch size for the
|
||||
:obj:`DataLoader`.
|
||||
output_json_filepath (:obj:`str`, optional): File path for output
|
||||
`JSON` file.
|
||||
output_json_filename (:obj:`str`, optional): File name for output
|
||||
`JSON` file.
|
||||
gen_field_replace_token (:obj:`str`, optional): Used to replace the
|
||||
generation field token when generating prompts.
|
||||
save_every (:obj:`int`, optional): Save intermediate results every
|
||||
`save_every` epochs.
|
||||
generation_kwargs (:obj:`Dict`, optional): Parameters for the
|
||||
:obj:`model.generate()` method.
|
||||
fix_id_list (:obj:`List[int]`, optional): List of indices to fix
|
||||
naive_run (:obj:`bool`): if True, run naive IO/CoT sampling instead of
|
||||
ToT + BFS.
|
||||
prompt_wrapper (:obj:`dict`): wrapper for prompts
|
||||
prompt_sample (:obj:`str`): (choices=[standard, cot]) sampling prompt
|
||||
method_generate (:obj:`str`): (choices=[sample, propose])
|
||||
thought generator,whether to sample independent thoughts (used in
|
||||
Creative Writing task) or propose sequential thoughts (used in Game
|
||||
of 24)
|
||||
method_evaluate (:obj:`str`): (choices=[value, vote]) state evaluator,
|
||||
whether to use the value states independently (used in Game of 24)
|
||||
or vote on states together (used in Creative Writing)
|
||||
n_generate_sample (:obj:`int`): number of times to prompt for
|
||||
thought generation
|
||||
n_evaluate_sample(:obj:`int`): number of times to prompt for
|
||||
state evaluation
|
||||
n_select_sample (:obj:`int`): number of states to keep from each step
|
||||
(i.e. b in the Tree-of-Thought paper's ToT + BFS algorithm)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseModel,
|
||||
max_out_len: int,
|
||||
max_seq_len: Optional[int] = None,
|
||||
batch_size: Optional[int] = 1,
|
||||
gen_field_replace_token: Optional[str] = '',
|
||||
output_json_filepath: Optional[str] = './icl_inference_output',
|
||||
output_json_filename: Optional[str] = 'predictions',
|
||||
save_every: Optional[int] = None,
|
||||
fix_id_list: Optional[List[int]] = None,
|
||||
naive_run: bool = False,
|
||||
prompt_wrapper: dict = {},
|
||||
prompt_sample: str = 'standard',
|
||||
method_generate: str = 'sample',
|
||||
method_evaluate: str = 'value',
|
||||
method_select: str = 'greedy',
|
||||
n_generate_sample: int = 1,
|
||||
n_evaluate_sample: int = 1,
|
||||
n_select_sample: int = 1,
|
||||
generation_kwargs: dict = {},
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
model=model,
|
||||
max_out_len=max_out_len,
|
||||
max_seq_len=max_seq_len,
|
||||
batch_size=batch_size,
|
||||
gen_field_replace_token=gen_field_replace_token,
|
||||
output_json_filename=output_json_filename,
|
||||
output_json_filepath=output_json_filepath,
|
||||
save_every=save_every,
|
||||
fix_id_list=fix_id_list,
|
||||
sc_size=n_evaluate_sample,
|
||||
**kwargs,
|
||||
)
|
||||
self.max_out_len = max_out_len
|
||||
self.prompt_wrapper = TOT_WRAPPER.build(prompt_wrapper)
|
||||
self.naive_run = naive_run
|
||||
self.prompt_sample = prompt_sample
|
||||
self.method_generate = method_generate
|
||||
self.method_evaluate = method_evaluate
|
||||
self.method_select = method_select
|
||||
self.n_generate_sample = n_generate_sample
|
||||
self.n_evaluate_sample = n_evaluate_sample
|
||||
self.n_select_sample = n_select_sample
|
||||
self.generation_kwargs = generation_kwargs
|
||||
|
||||
def get_value(self,
|
||||
x: str,
|
||||
y: str,
|
||||
n_evaluate_sample: int,
|
||||
cache_value: bool = True) -> str:
|
||||
"""Get evaluation value of a partial output.
|
||||
|
||||
Args:
|
||||
x (str): The input text to be evaluated.
|
||||
y (str): The partial output to be evaluated.
|
||||
n_evaluate_sample (int): Times to evaluate each partial output.
|
||||
cache_value (bool): Cache to avoid duplicate candidates.
|
||||
Defaults to True.
|
||||
Returns:
|
||||
str: Value of evaluated partial outputs.
|
||||
"""
|
||||
value_prompt = self.prompt_wrapper.value_prompt_wrap(x, y)
|
||||
if cache_value and value_prompt in self.prompt_wrapper.value_cache:
|
||||
return self.prompt_wrapper.value_cache[value_prompt]
|
||||
value_outputs = self.model.generate_from_template(
|
||||
[value_prompt],
|
||||
max_out_len=self.max_out_len,
|
||||
num_beams=n_evaluate_sample,
|
||||
num_return_sequences=n_evaluate_sample,
|
||||
**self.generation_kwargs)
|
||||
value = self.prompt_wrapper.value_outputs_unwrap(x, y, value_outputs)
|
||||
if cache_value:
|
||||
self.prompt_wrapper.value_cache[value_prompt] = value
|
||||
return value
|
||||
|
||||
def get_values(self,
|
||||
x: str,
|
||||
ys: List[str],
|
||||
n_evaluate_sample: int,
|
||||
cache_value: bool = True) -> List[str]:
|
||||
"""Get evaluation values of partial outputs.
|
||||
|
||||
Args:
|
||||
x (str): The input text to be solved.
|
||||
ys (List[str]): The partial outputs to be evaluated.
|
||||
n_evaluate_sample (int): Times to evaluate each partial output.
|
||||
cache_value (bool): Cache to avoid duplicate candidates.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
List[str]: Values of evaluated partial outputs.
|
||||
"""
|
||||
values = []
|
||||
local_value_cache = {}
|
||||
for y in ys: # each partial output
|
||||
if y in local_value_cache: # avoid duplicate candidates
|
||||
value = 0
|
||||
else:
|
||||
value = self.get_value(x,
|
||||
y,
|
||||
n_evaluate_sample,
|
||||
cache_value=cache_value)
|
||||
local_value_cache[y] = value
|
||||
values.append(value)
|
||||
return values
|
||||
|
||||
def get_votes(self, x: str, ys: List[str],
|
||||
n_evaluate_sample: int) -> List[str]:
|
||||
"""Get votes of partial outputs.
|
||||
|
||||
Args:
|
||||
x (str): The input text to be solved.
|
||||
ys (List[str]): The partial outputs to be evaluated.
|
||||
n_evaluate_sample (int): Times to evaluate each partial output.
|
||||
|
||||
Returns:
|
||||
List[str]: Values of evaluated partial outputs.
|
||||
"""
|
||||
vote_prompt = self.prompt_wrapper.vote_prompt_wrap(x, ys)
|
||||
vote_outputs = self.model.generate_from_template(
|
||||
[vote_prompt],
|
||||
max_out_len=self.max_out_len,
|
||||
num_beams=n_evaluate_sample,
|
||||
num_return_sequences=n_evaluate_sample,
|
||||
**self.generation_kwargs)
|
||||
values = self.prompt_wrapper.vote_outputs_unwrap(vote_outputs, len(ys))
|
||||
return values
|
||||
|
||||
def get_proposals(self, x: str, y: str) -> List[str]:
|
||||
"""Get proposal prompts.
|
||||
|
||||
Args:
|
||||
x (str): The input text to be solved.
|
||||
y (str): The partial output.
|
||||
|
||||
Returns:
|
||||
List[str]: Proposal prompts.
|
||||
"""
|
||||
propose_prompt = self.prompt_wrapper.propose_prompt_wrap(x, y)
|
||||
proposals = self.model.generate_from_template(
|
||||
[propose_prompt],
|
||||
max_out_len=self.max_out_len,
|
||||
num_beams=1,
|
||||
num_return_sequences=1,
|
||||
**self.generation_kwargs)[0].split('\n')
|
||||
return [y + _ + '\n' for _ in proposals]
|
||||
|
||||
def get_samples(self, x: str, y: str, n_generate_sample: int,
|
||||
prompt_sample: str):
|
||||
"""Get samples from a partial output.
|
||||
|
||||
Args:
|
||||
x (str): The input text to be solved.
|
||||
y (str): The partial output.
|
||||
n_generate_sample (int): Times to generate samples.
|
||||
prompt_sample (str): (choices=[standard, cot]) sampling prompt
|
||||
|
||||
Returns:
|
||||
List[str]: Samples from a partial output.
|
||||
"""
|
||||
if prompt_sample == 'standard':
|
||||
prompt = self.prompt_wrapper.standard_prompt_wrap(x, y)
|
||||
elif prompt_sample == 'cot':
|
||||
prompt = self.prompt_wrapper.cot_prompt_wrap(x, y)
|
||||
else:
|
||||
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
|
||||
samples = self.model.generate_from_template(
|
||||
[prompt],
|
||||
max_out_len=self.max_out_len,
|
||||
num_beams=n_generate_sample,
|
||||
num_return_sequences=n_generate_sample,
|
||||
**self.generation_kwargs)
|
||||
return [y + _ for _ in samples]
|
||||
|
||||
def tot_solve(self, x: str) -> str:
|
||||
"""Solve a problem using Tree-of-Thought algorithm.
|
||||
|
||||
Args:
|
||||
x (str): The input text to be solved.
|
||||
|
||||
Returns:
|
||||
str: Final answer of the problem.
|
||||
"""
|
||||
ys = [''] # current output candidates
|
||||
infos = []
|
||||
for step in range(self.prompt_wrapper.steps):
|
||||
logger.info(f'\n-- step {str(step)} --\n')
|
||||
# generation
|
||||
if self.method_generate == 'sample':
|
||||
new_ys = [
|
||||
self.get_samples(x,
|
||||
y,
|
||||
self.n_generate_sample,
|
||||
prompt_sample=self.prompt_sample)
|
||||
for y in ys
|
||||
]
|
||||
elif self.method_generate == 'propose':
|
||||
new_ys = [self.get_proposals(x, y) for y in ys]
|
||||
new_ys = list(itertools.chain(*new_ys))
|
||||
ids = list(range(len(new_ys)))
|
||||
# evaluation
|
||||
if self.method_evaluate == 'vote':
|
||||
values = self.get_votes(x, new_ys, self.n_evaluate_sample)
|
||||
elif self.method_evaluate == 'value':
|
||||
values = self.get_values(x, new_ys, self.n_evaluate_sample)
|
||||
|
||||
# selection
|
||||
if self.method_select == 'sample':
|
||||
ps = np.array(values) / sum(values)
|
||||
select_ids = np.random.choice(ids,
|
||||
size=self.n_select_sample,
|
||||
p=ps).tolist()
|
||||
elif self.method_select == 'greedy':
|
||||
select_ids = sorted(ids, key=lambda x: values[x],
|
||||
reverse=True)[:self.n_select_sample]
|
||||
select_new_ys = [new_ys[select_id] for select_id in select_ids]
|
||||
|
||||
# log
|
||||
sorted_new_ys, sorted_values = zip(
|
||||
*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
|
||||
logger.info(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: '
|
||||
f'{sorted_values}\n-- choices --: {select_new_ys}\n')
|
||||
|
||||
infos.append({
|
||||
'step': step,
|
||||
'x': x,
|
||||
'ys': ys,
|
||||
'new_ys': new_ys,
|
||||
'values': values,
|
||||
'select_new_ys': select_new_ys
|
||||
})
|
||||
ys = select_new_ys
|
||||
logger.info(ys)
|
||||
|
||||
return ys
|
||||
|
||||
def inference(self,
|
||||
retriever: BaseRetriever,
|
||||
ice_template: Optional[PromptTemplate] = None,
|
||||
prompt_template: Optional[PromptTemplate] = None,
|
||||
output_json_filepath: Optional[str] = None,
|
||||
output_json_filename: Optional[str] = None) -> List:
|
||||
# 1. Preparation for output logs
|
||||
output_handler = GenInferencerOutputHandler()
|
||||
|
||||
if output_json_filepath is None:
|
||||
output_json_filepath = self.output_json_filepath
|
||||
if output_json_filename is None:
|
||||
output_json_filename = self.output_json_filename
|
||||
|
||||
# 2. Get results of retrieval process
|
||||
if 'Fix' in retriever.__class__.__name__:
|
||||
ice_idx_list = retriever.retrieve(self.fix_id_list)
|
||||
else:
|
||||
ice_idx_list = retriever.retrieve()
|
||||
|
||||
# 3. Generate prompts for testing input
|
||||
prompt_list = self.get_generation_prompt_list_from_retriever_indices(
|
||||
ice_idx_list,
|
||||
retriever,
|
||||
self.gen_field_replace_token,
|
||||
max_seq_len=self.max_seq_len,
|
||||
ice_template=ice_template,
|
||||
prompt_template=prompt_template)
|
||||
|
||||
# Create tmp json file for saving intermediate results and future
|
||||
# resuming
|
||||
index = 0
|
||||
tmp_json_filepath = os.path.join(output_json_filepath,
|
||||
'tmp_' + output_json_filename)
|
||||
if osp.exists(tmp_json_filepath):
|
||||
# TODO: move resume to output handler
|
||||
tmp_result_dict = mmengine.load(tmp_json_filepath)
|
||||
output_handler.results_dict = tmp_result_dict
|
||||
index = len(tmp_result_dict)
|
||||
|
||||
# 4. Wrap prompts with Dataloader
|
||||
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
|
||||
|
||||
# 5. Inference for prompts in each batch
|
||||
logger.info('Starting ToT inference process...')
|
||||
for entries in tqdm(dataloader, disable=not self.is_main_process):
|
||||
# 5-1. Inference with ToT and local model
|
||||
with torch.no_grad():
|
||||
parsed_entries = self.model.parse_template(entries, mode='gen')
|
||||
generated = [self.tot_solve(entry) for entry in entries]
|
||||
|
||||
# 5-2. Save current output
|
||||
for prompt, prediction in zip(parsed_entries, generated):
|
||||
output_handler.save_results(prompt, prediction, index)
|
||||
index = index + 1
|
||||
|
||||
# 5-3. Save intermediate results
|
||||
if (self.save_every is not None and index % self.save_every == 0
|
||||
and self.is_main_process):
|
||||
output_handler.write_to_json(output_json_filepath,
|
||||
'tmp_' + output_json_filename)
|
||||
|
||||
# 6. Output
|
||||
if self.is_main_process:
|
||||
os.makedirs(output_json_filepath, exist_ok=True)
|
||||
output_handler.write_to_json(output_json_filepath,
|
||||
output_json_filename)
|
||||
if osp.exists(tmp_json_filepath):
|
||||
os.remove(tmp_json_filepath)
|
||||
|
||||
return [
|
||||
sample['prediction']
|
||||
for sample in output_handler.results_dict.values()
|
||||
]
|
@ -34,3 +34,4 @@ METRICS = Registry('metric',
|
||||
MM_MODELS = Registry('mm_model',
|
||||
parent=MMENGINE_MODELS,
|
||||
locations=['opencompass.multimodal.models'])
|
||||
TOT_WRAPPER = Registry('tot_wrapper', locations=['opencompass.datasets'])
|
||||
|
Loading…
Reference in New Issue
Block a user