[Feature] Add Tree-of-Thought method (#173)

* Add ToT method

* Update ToT

* Update ToT

* Update ToT

* Update ToT

* Update ToT

* Update ToT

* Update ToT

* Update chain_of_thought.md

* Update icl_tot_inferencer.py

---------

Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
liushz 2023-08-23 12:23:05 +08:00 committed by GitHub
parent ff5ab92331
commit 02ce139bc6
9 changed files with 783 additions and 1 deletions

View File

@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .game24_gen_8dfde3 import game24_datasets # noqa: F401, F403

View File

@ -0,0 +1,34 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ToTInferencer
from opencompass.datasets import (Game24Dataset, game24_postprocess,
Game24Evaluator, Game24PromptWrapper)
generation_kwargs = dict(do_sample=False, temperature=0.7)
game24_reader_cfg = dict(
input_columns=['input'],
output_column='output')
game24_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template='{input}'),
retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=ToTInferencer,
                    generation_kwargs=generation_kwargs,
                    method_generate='propose',
                    method_evaluate='value',
                    method_select='greedy',
                    n_evaluate_sample=3,
                    n_select_sample=5,
                    prompt_wrapper=dict(type=Game24PromptWrapper)))
game24_eval_cfg = dict(
evaluator=dict(type=Game24Evaluator),
pred_postprocessor=dict(type=game24_postprocess),
)
game24_datasets = [
dict(
abbr='game24',
type=Game24Dataset,
path='./data/game24/game24.csv',
reader_cfg=game24_reader_cfg,
infer_cfg=game24_infer_cfg,
eval_cfg=game24_eval_cfg)
]

View File

@ -4,6 +4,8 @@
During reasoning, the CoT (Chain of Thought) method is an efficient way to help LLMs deal with complex questions, such as math problems and relational inference. OpenCompass supports multiple types of CoT methods.
![image](https://github.com/InternLM/opencompass/assets/28834990/45d60e0e-02a1-49aa-b792-40a1f95f9b9e)
## 1. Zero Shot CoT
You can change the `PromptTemplate` of the dataset config by simply adding *Let's think step by step* to realize a Zero-Shot CoT prompt for your evaluation:
@ -73,3 +75,53 @@ Where `SAMPLE_SIZE` is the number of reasoning paths in Self-Consistency, higher
![image](https://github.com/InternLM/opencompass/assets/28834990/05c7d850-7076-43ca-b165-e6251f9b3001)
From the figure, it can be seen that in different reasoning tasks, performance tends to improve as the number of reasoning paths increases. However, for some tasks, increasing the number of reasoning paths may reach a limit, and further increasing the number of paths may not bring significant performance improvement. Therefore, it is necessary to conduct experiments and adjustments on specific tasks to find the optimal number of reasoning paths that best suit the task.
## 4. Tree-of-Thoughts
In contrast to the conventional CoT approach that considers only a single reasoning path, Tree-of-Thoughts (ToT) allows the language model to explore multiple diverse reasoning paths simultaneously. The model evaluates the reasoning process through self-assessment and makes global choices by conducting lookahead or backtracking when necessary. Specifically, this process is divided into the following four stages:
**1. Thought Decomposition**
Based on the nature of the problem, break down the problem into multiple intermediate steps. Each step can be a phrase, equation, or writing plan, depending on the nature of the problem.
**2. Thought Generation**
Assuming that solving the problem requires k steps, there are two methods to generate reasoning content:
- Independent sampling: for each state, the model independently samples k pieces of reasoning content from the CoT prompt, without relying on the other pieces.
- Sequential generation: prompts are used sequentially to guide the generation of reasoning content step by step, where each piece may depend on the previous one.
**3. Heuristic Evaluation**
Use heuristic methods to evaluate the contribution of each generated reasoning content to problem-solving. This self-evaluation is based on the model's self-feedback and involves designing prompts to have the model score multiple generated results.
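For instance, the Game of 24 task added in this PR implements value-style evaluation by prompting the model several times to judge whether the remaining numbers can still reach 24, then mapping each verdict (sure / likely / impossible) to an ad-hoc weight. Below is a minimal sketch of that scoring step, adapted from `Game24PromptWrapper.value_outputs_unwrap` in `opencompass/datasets/game24.py`:

```python
# Sketch of value-style heuristic evaluation, adapted from
# Game24PromptWrapper.value_outputs_unwrap in opencompass/datasets/game24.py.
def score_value_outputs(value_outputs: list) -> float:
    value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # ad-hoc weights
    # The last line of each sampled evaluation is treated as the verdict.
    verdicts = [output.split('\n')[-1] for output in value_outputs]
    return sum(weight * verdicts.count(name) for name, weight in value_map.items())

# Three sampled judgements ending in "sure", "likely", "sure" score 20 + 1 + 20 = 41.0
print(score_value_outputs(['10 + 14 = 24\nsure', '5 + 7 + 8 = 20\nlikely', '10 + 14 = 24\nsure']))
```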
**4. Search Algorithm Selection**
Based on the methods of generating and evaluating reasoning content, select an appropriate search algorithm. For example, you can use breadth-first search (BFS) or depth-first search (DFS) algorithms to systematically explore the thought tree, conducting lookahead and backtracking.
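Putting the four stages together, the `ToTInferencer.tot_solve` method added in this PR runs ToT with BFS as a generate–evaluate–select loop. The sketch below is a simplification of that loop; `generate` and `evaluate` stand in for the inferencer's `get_samples`/`get_proposals` and `get_values`/`get_votes` calls:

```python
import itertools

# Simplified sketch of the ToT + BFS loop in ToTInferencer.tot_solve.
# generate(x, y) returns new partial solutions extending y; evaluate(x, ys)
# returns one heuristic score per candidate (both are stand-ins here).
def tot_bfs(x, steps, n_select_sample, generate, evaluate):
    ys = ['']  # frontier of partial solutions (current level of the thought tree)
    for step in range(steps):
        # Thought generation: expand every kept candidate.
        new_ys = list(itertools.chain(*[generate(x, y) for y in ys]))
        # Heuristic evaluation: score each candidate via the model's self-feedback.
        values = evaluate(x, new_ys)
        # Selection (greedy): keep the b = n_select_sample best states.
        ranked = sorted(range(len(new_ys)), key=lambda i: values[i], reverse=True)
        ys = [new_ys[i] for i in ranked[:n_select_sample]]
    return ys
```

With `method_select='sample'`, the greedy cut is replaced by sampling candidates in proportion to their scores, and `naive_run` skips the tree search entirely in favour of plain IO/CoT sampling.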
In OpenCompass, the ToT parameters need to be configured for your task. Below is an example configuration for the Game of 24 from the [official paper](https://arxiv.org/pdf/2305.10601.pdf). Currently, ToT inference is supported only with Huggingface models:
```python
# This ToT Game24 config can be found at: opencompass/configs/datasets/game24/game24_gen_8dfde3.py.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ToTInferencer
from opencompass.datasets import (Game24Dataset, game24_postprocess,
                                  Game24Evaluator, Game24PromptWrapper)
generation_kwargs = dict(temperature=0.7)
game24_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template='{input}'), # Directly pass the input content, as the Prompt needs to be specified in steps
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=ToTInferencer, # Replace GenInferencer with ToTInferencer
generation_kwargs=generation_kwargs,
method_generate='propose', # Method for generating reasoning content, can be independent sampling (sample) or sequential generation (propose)
method_evaluate='value', # Method for evaluating reasoning content, can be voting (vote) or scoring (value)
method_select='greedy', # Method for selecting reasoning content, can be greedy (greedy) or random (sample)
n_evaluate_sample=3,
n_select_sample=5,
                prompt_wrapper=dict(type=Game24PromptWrapper)  # This wrapper class includes the prompts for each step and the methods for generating and evaluating reasoning content; it needs to be customized for the task
))
```
If you want to use the ToT method on a custom dataset, you'll need to make additional configurations in the `opencompass.datasets.YourDataConfig.py` file to set up the `YourDataPromptWrapper` class. This is required for handling the thought generation and heuristic evaluation step within the ToT framework. For reasoning tasks similar to the game 24-Point, you can refer to the implementation in `opencompass/datasets/game24.py` for guidance.
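As a rough guide, the wrapper only has to expose the hooks that `ToTInferencer` calls. The skeleton below is a hypothetical `YourDataPromptWrapper` for the `propose`/`value`/`greedy` configuration; the class name, prompt strings, and step count are placeholders, not part of the OpenCompass API:

```python
# Hypothetical skeleton of a custom ToT prompt wrapper, modeled on
# Game24PromptWrapper in opencompass/datasets/game24.py. Only the attribute
# and method names mirror what ToTInferencer uses; everything else is a
# placeholder to be filled in for your task.
class YourDataPromptWrapper:

    def __init__(self):
        self.value_cache = {}    # caches value prompts that were already scored
        self.steps = 3           # number of thought steps for your task (placeholder)
        self.stops = ['\n'] * 3  # per-step stop strings, if your task needs them

    @staticmethod
    def propose_prompt_wrap(x: str, y: str = '') -> str:
        """Prompt asking the model for possible next steps (thought generation)."""
        return f'Input: {x}\nPartial solution:\n{y}\nPossible next steps:\n'

    @staticmethod
    def value_prompt_wrap(x: str, y: str) -> str:
        """Prompt asking the model to judge a partial output (heuristic evaluation)."""
        return f'Evaluate whether this partial solution can still solve the task:\n{y}\n'

    @staticmethod
    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
        """Turn the model's judgements into one numeric score used for selection."""
        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # placeholder weights
        verdicts = [o.split('\n')[-1].strip().lower() for o in value_outputs]
        return sum(w * verdicts.count(k) for k, w in value_map.items())
```

The class can then be passed directly in the inferencer config, e.g. `prompt_wrapper=dict(type=YourDataPromptWrapper)`, just as the Game24 config passes `Game24PromptWrapper`.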

View File

@ -4,6 +4,8 @@
Chain of Thought (CoT) is an effective way to help large language models solve complex problems such as math problems and relational reasoning. OpenCompass supports multiple types of CoT methods.
![image](https://github.com/InternLM/opencompass/assets/28834990/45d60e0e-02a1-49aa-b792-40a1f95f9b9e)
## 1. Zero-Shot CoT
You can realize a Zero-Shot CoT prompt for your evaluation by simply adding "Let's think step by step" to the `PromptTemplate` of the dataset config:
@ -49,7 +51,7 @@ Question: {question}\nLet's think step by step:\n{answer}
## 3. Self-Consistency
The SC (Self-Consistency) method was proposed in [this paper](https://arxiv.org/abs/2203.11171). It generates multiple different reasoning paths for a question and takes a majority vote over the generated answers. This approach shows remarkable capability on complex reasoning tasks, but because it has to run inference multiple times to sample several reasoning chains, it can consume a lot of time and resources. In OpenCompass, you can implement SC simply by replacing `GenInferencer` with `SCInferencer` in the dataset config and setting the corresponding parameters, for example:
```python
# This SC version of the gsm8k config can be found at: opencompass.configs.datasets.gsm8k.gsm8k_gen_a3e34a.py.
@ -73,3 +75,54 @@ gsm8k_eval_cfg = dict(sc_size=SAMPLE_SIZE)
![image](https://github.com/InternLM/opencompass/assets/28834990/05c7d850-7076-43ca-b165-e6251f9b3001)
As shown in the figure, across different reasoning tasks, performance tends to improve as the number of reasoning paths increases. However, for some tasks, increasing the number of reasoning paths may hit a ceiling, beyond which adding more paths brings little additional gain. You therefore need to experiment on the specific task to find the number of reasoning paths that suits it best.
## 4. Tree-of-Thoughts
In contrast to the standard CoT approach, which samples a single reasoning path, Tree-of-Thoughts (ToT) allows the language model to consider multiple different reasoning paths at the same time. The model evaluates its own reasoning process and, when necessary, performs lookahead or backtracking to make global choices. Specifically, the process is divided into the following four stages:
**1. Thought Decomposition**
Based on the characteristics of the problem, decompose it into multiple intermediate steps. Each step can be a phrase, an equation, or a writing plan, depending on the nature of the problem.
**2. Thought Generation**
Assuming that solving the problem requires k steps, there are two ways to generate reasoning content:
- Independent sampling: for each state, the model independently samples k complete pieces of reasoning content from the CoT prompt, without relying on the other pieces.
- Sequential generation: prompts are used sequentially to guide the generation of reasoning content step by step, where each piece may depend on the previous one.
**3. Heuristic Evaluation**
Use heuristic methods to evaluate the contribution of each generated piece of reasoning content to solving the problem. This self-evaluation relies on the language model's own feedback, for example by designing prompts that ask the model to score multiple generated results.
**4. Search Algorithm Selection**
Select an appropriate search algorithm based on how the reasoning content is generated and evaluated. For example, breadth-first search (BFS) or depth-first search (DFS) can be used to systematically explore the thought tree, with lookahead and backtracking.
In OpenCompass, the ToT parameters need to be configured for your task. Below is the example configuration for the Game of 24 from the [ToT paper](https://arxiv.org/pdf/2305.10601.pdf). Currently, only Huggingface models are supported for ToT inference:
```python
# This ToT Game24 config can be found at: opencompass/configs/datasets/game24/game24_gen_8dfde3.py.
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import ToTInferencer
from opencompass.datasets import (Game24Dataset, game24_postprocess,
                                  Game24Evaluator, Game24PromptWrapper)
generation_kwargs = dict(temperature=0.7)
game24_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template='{input}'),  # Pass the input directly, since the prompt has to be specified step by step
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=ToTInferencer,  # Replace GenInferencer with ToTInferencer
                    generation_kwargs=generation_kwargs,
                    method_generate='propose',  # Method for generating reasoning content: independent sampling (sample) or sequential generation (propose)
                    method_evaluate='value',  # Method for evaluating reasoning content: voting (vote) or scoring (value)
                    method_select='greedy',  # Method for selecting reasoning content: greedy (greedy) or random (sample)
                    n_evaluate_sample=3,
                    n_select_sample=5,
                    prompt_wrapper=dict(type=Game24PromptWrapper)  # This wrapper class contains the prompts for each step and the methods for generating and evaluating reasoning content; it needs to be customized for the task
))
```
To use the ToT method on a custom dataset, compared with the ordinary evaluation setup you additionally need to define a `YourDataPromptWrapper` class in `opencompass.datasets.YourDataConfig.py` to handle thought generation and heuristic evaluation in ToT. For reasoning tasks similar to the Game of 24, see `opencompass/datasets/game24.py` for reference.

View File

@ -25,6 +25,7 @@ from .drcd import * # noqa: F401, F403
from .drop import * # noqa: F401, F403
from .eprstmt import * # noqa: F401, F403
from .flores import * # noqa: F401, F403
from .game24 import * # noqa: F401, F403
from .GaokaoBench import * # noqa: F401, F403
from .govrepcrs import * # noqa: F401, F403
from .gsm8k import * # noqa: F401, F403

View File

@ -0,0 +1,256 @@
# Prompt and utils of Game24 dataset, edited from:
# https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/tasks/game24.py
import re
from typing import List
import pandas as pd
import sympy
from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from .base import BaseDataset
# 5-shot
standard_prompt = '''Use numbers and basic arithmetic operations \
(+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
'''
# 5-shot
cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to \
obtain 24. Each step, you are only allowed to choose two of the remaining \
numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
'''
# 1-shot
propose_prompt = '''Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 / 2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Input: {input}
Possible next steps:
'''
value_prompt = '''Evaluate if given numbers can reach 24 \
(sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{input}
'''
value_last_step_prompt = '''Use numbers and basic arithmetic operations \
(+ - * /) to obtain 24. Given an input and an answer, give a judgement \
(sure/impossible) if the answer is correct, i.e. it uses each input exactly \
once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge:
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge:
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge:
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge:
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge:
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge:
impossible
Input: {input}
Answer: {answer}
Judge:'''
def get_current_numbers(y: str) -> str:
last_line = y.strip().split('\n')[-1]
return last_line.split('left: ')[-1].split(')')[0]
class Game24Dataset(BaseDataset):
@staticmethod
def load(path: str):
data = list(pd.read_csv(path)['Puzzles'])
data = [{'input': i, 'output': i} for i in data]
return Dataset.from_list(data[900:905])
class Game24PromptWrapper:
"""Wrapper for Game24 prompts and outputs.
    standard_prompt_wrap, cot_prompt_wrap, propose_prompt_wrap:
        Get prompts for different sample methods.
value_prompt_wrap:
Get prompts for value-based evaluation method.
value_outputs_unwrap:
Calculate total value score for value outputs.
"""
def __init__(self):
self.value_cache = {}
self.steps = 4
self.stops = ['\n'] * 4
@staticmethod
def standard_prompt_wrap(x: str, y: str = '') -> str:
return standard_prompt.format(input=x) + y
@staticmethod
def cot_prompt_wrap(x: str, y: str = '') -> str:
return cot_prompt.format(input=x) + y
@staticmethod
def propose_prompt_wrap(x: str, y: str = '') -> str:
current_numbers = get_current_numbers(y if y else x)
if current_numbers == '24':
prompt = cot_prompt.format(input=x) + 'Steps:' + y
else:
prompt = propose_prompt.format(input=current_numbers)
return prompt
@staticmethod
def value_prompt_wrap(x: str, y: str) -> str:
last_line = y.strip().split('\n')[-1]
if 'left: ' not in last_line: # last step
ans = last_line.lower().replace('answer: ', '')
return value_last_step_prompt.format(input=x, answer=ans)
current_numbers = get_current_numbers(y)
return value_prompt.format(input=current_numbers)
@staticmethod
def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
if len(y.strip().split('\n')) == 4 and 'answer' not in y.lower():
return 0
value_names = [_.split('\n')[-1] for _ in value_outputs]
value_map = {
'impossible': 0.001,
'likely': 1,
'sure': 20
} # TODO: ad hoc
value = sum(value * value_names.count(name)
for name, value in value_map.items())
return value
def game24_postprocess(output: str):
expression = output.strip().split('\n')[-1].lower().replace(
'answer: ', '').split('=')[0]
return expression
class Game24Evaluator(BaseEvaluator):
def __init__(self) -> None:
super().__init__()
def check_nums(self, prediction, reference):
numbers = re.findall(r'\d+', prediction)
problem_numbers = re.findall(r'\d+', reference)
if sorted(numbers) != sorted(problem_numbers):
return 0
try:
return int(sympy.simplify(prediction) == 24)
except Exception:
return 0
def score(self, predictions: List, references: List) -> dict:
n = len(predictions)
acc = 0
for prediction, reference in zip(predictions, references):
if self.check_nums(prediction, reference):
acc += 1
return {'acc score': acc / n}

View File

@ -3,3 +3,4 @@ from .icl_clp_inferencer import CLPInferencer # noqa
from .icl_gen_inferencer import GenInferencer # noqa
from .icl_ppl_inferencer import PPLInferencer # noqa
from .icl_sc_inferencer import SCInferencer # noqa
from .icl_tot_inferencer import ToTInferencer # noqa

View File

@ -0,0 +1,380 @@
"""Tree-of-Thought Generation Inferencer."""
import itertools
import os
import os.path as osp
from typing import List, Optional
import mmengine
import numpy as np
import torch
from tqdm import tqdm
from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS, TOT_WRAPPER
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_gen_inferencer import GenInferencer, GenInferencerOutputHandler
logger = get_logger(__name__)
@ICL_INFERENCERS.register_module()
class ToTInferencer(GenInferencer):
"""Tree-of-Thought Inferencer class to evaluate by tree style reasoning
paths.
Doc: https://opencompass.readthedocs.io/en/latest/prompt/
chain_of_thought.html
Official tot paper: https://arxiv.org/pdf/2305.10601.pdf
Attributes:
model (:obj:`BaseModelWrapper`, optional): The module to inference.
max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
allowed by the LM.
batch_size (:obj:`int`, optional): Batch size for the
:obj:`DataLoader`.
output_json_filepath (:obj:`str`, optional): File path for output
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
gen_field_replace_token (:obj:`str`, optional): Used to replace the
generation field token when generating prompts.
save_every (:obj:`int`, optional): Save intermediate results every
`save_every` epochs.
generation_kwargs (:obj:`Dict`, optional): Parameters for the
:obj:`model.generate()` method.
fix_id_list (:obj:`List[int]`, optional): List of indices to fix
naive_run (:obj:`bool`): if True, run naive IO/CoT sampling instead of
ToT + BFS.
prompt_wrapper (:obj:`dict`): wrapper for prompts
prompt_sample (:obj:`str`): (choices=[standard, cot]) sampling prompt
method_generate (:obj:`str`): (choices=[sample, propose])
            thought generator, whether to sample independent thoughts (used in
Creative Writing task) or propose sequential thoughts (used in Game
of 24)
method_evaluate (:obj:`str`): (choices=[value, vote]) state evaluator,
whether to use the value states independently (used in Game of 24)
or vote on states together (used in Creative Writing)
n_generate_sample (:obj:`int`): number of times to prompt for
thought generation
n_evaluate_sample(:obj:`int`): number of times to prompt for
state evaluation
n_select_sample (:obj:`int`): number of states to keep from each step
(i.e. b in the Tree-of-Thought paper's ToT + BFS algorithm)
"""
def __init__(
self,
model: BaseModel,
max_out_len: int,
max_seq_len: Optional[int] = None,
batch_size: Optional[int] = 1,
gen_field_replace_token: Optional[str] = '',
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
fix_id_list: Optional[List[int]] = None,
naive_run: bool = False,
prompt_wrapper: dict = {},
prompt_sample: str = 'standard',
method_generate: str = 'sample',
method_evaluate: str = 'value',
method_select: str = 'greedy',
n_generate_sample: int = 1,
n_evaluate_sample: int = 1,
n_select_sample: int = 1,
generation_kwargs: dict = {},
**kwargs) -> None:
super().__init__(
model=model,
max_out_len=max_out_len,
max_seq_len=max_seq_len,
batch_size=batch_size,
gen_field_replace_token=gen_field_replace_token,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
save_every=save_every,
fix_id_list=fix_id_list,
sc_size=n_evaluate_sample,
**kwargs,
)
self.max_out_len = max_out_len
self.prompt_wrapper = TOT_WRAPPER.build(prompt_wrapper)
self.naive_run = naive_run
self.prompt_sample = prompt_sample
self.method_generate = method_generate
self.method_evaluate = method_evaluate
self.method_select = method_select
self.n_generate_sample = n_generate_sample
self.n_evaluate_sample = n_evaluate_sample
self.n_select_sample = n_select_sample
self.generation_kwargs = generation_kwargs
def get_value(self,
x: str,
y: str,
n_evaluate_sample: int,
cache_value: bool = True) -> str:
"""Get evaluation value of a partial output.
Args:
x (str): The input text to be evaluated.
y (str): The partial output to be evaluated.
n_evaluate_sample (int): Times to evaluate each partial output.
cache_value (bool): Cache to avoid duplicate candidates.
Defaults to True.
Returns:
str: Value of evaluated partial outputs.
"""
value_prompt = self.prompt_wrapper.value_prompt_wrap(x, y)
if cache_value and value_prompt in self.prompt_wrapper.value_cache:
return self.prompt_wrapper.value_cache[value_prompt]
value_outputs = self.model.generate_from_template(
[value_prompt],
max_out_len=self.max_out_len,
num_beams=n_evaluate_sample,
num_return_sequences=n_evaluate_sample,
**self.generation_kwargs)
value = self.prompt_wrapper.value_outputs_unwrap(x, y, value_outputs)
if cache_value:
self.prompt_wrapper.value_cache[value_prompt] = value
return value
def get_values(self,
x: str,
ys: List[str],
n_evaluate_sample: int,
cache_value: bool = True) -> List[str]:
"""Get evaluation values of partial outputs.
Args:
x (str): The input text to be solved.
ys (List[str]): The partial outputs to be evaluated.
n_evaluate_sample (int): Times to evaluate each partial output.
cache_value (bool): Cache to avoid duplicate candidates.
Defaults to True.
Returns:
List[str]: Values of evaluated partial outputs.
"""
values = []
local_value_cache = {}
for y in ys: # each partial output
if y in local_value_cache: # avoid duplicate candidates
value = 0
else:
value = self.get_value(x,
y,
n_evaluate_sample,
cache_value=cache_value)
local_value_cache[y] = value
values.append(value)
return values
def get_votes(self, x: str, ys: List[str],
n_evaluate_sample: int) -> List[str]:
"""Get votes of partial outputs.
Args:
x (str): The input text to be solved.
ys (List[str]): The partial outputs to be evaluated.
n_evaluate_sample (int): Times to evaluate each partial output.
Returns:
List[str]: Values of evaluated partial outputs.
"""
vote_prompt = self.prompt_wrapper.vote_prompt_wrap(x, ys)
vote_outputs = self.model.generate_from_template(
[vote_prompt],
max_out_len=self.max_out_len,
num_beams=n_evaluate_sample,
num_return_sequences=n_evaluate_sample,
**self.generation_kwargs)
values = self.prompt_wrapper.vote_outputs_unwrap(vote_outputs, len(ys))
return values
def get_proposals(self, x: str, y: str) -> List[str]:
"""Get proposal prompts.
Args:
x (str): The input text to be solved.
y (str): The partial output.
Returns:
List[str]: Proposal prompts.
"""
propose_prompt = self.prompt_wrapper.propose_prompt_wrap(x, y)
proposals = self.model.generate_from_template(
[propose_prompt],
max_out_len=self.max_out_len,
num_beams=1,
num_return_sequences=1,
**self.generation_kwargs)[0].split('\n')
return [y + _ + '\n' for _ in proposals]
def get_samples(self, x: str, y: str, n_generate_sample: int,
prompt_sample: str):
"""Get samples from a partial output.
Args:
x (str): The input text to be solved.
y (str): The partial output.
n_generate_sample (int): Times to generate samples.
prompt_sample (str): (choices=[standard, cot]) sampling prompt
Returns:
List[str]: Samples from a partial output.
"""
if prompt_sample == 'standard':
prompt = self.prompt_wrapper.standard_prompt_wrap(x, y)
elif prompt_sample == 'cot':
prompt = self.prompt_wrapper.cot_prompt_wrap(x, y)
else:
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
samples = self.model.generate_from_template(
[prompt],
max_out_len=self.max_out_len,
num_beams=n_generate_sample,
num_return_sequences=n_generate_sample,
**self.generation_kwargs)
return [y + _ for _ in samples]
def tot_solve(self, x: str) -> str:
"""Solve a problem using Tree-of-Thought algorithm.
Args:
x (str): The input text to be solved.
Returns:
str: Final answer of the problem.
"""
ys = [''] # current output candidates
infos = []
for step in range(self.prompt_wrapper.steps):
logger.info(f'\n-- step {str(step)} --\n')
# generation
if self.method_generate == 'sample':
new_ys = [
self.get_samples(x,
y,
self.n_generate_sample,
prompt_sample=self.prompt_sample)
for y in ys
]
elif self.method_generate == 'propose':
new_ys = [self.get_proposals(x, y) for y in ys]
new_ys = list(itertools.chain(*new_ys))
ids = list(range(len(new_ys)))
# evaluation
if self.method_evaluate == 'vote':
values = self.get_votes(x, new_ys, self.n_evaluate_sample)
elif self.method_evaluate == 'value':
values = self.get_values(x, new_ys, self.n_evaluate_sample)
# selection
if self.method_select == 'sample':
ps = np.array(values) / sum(values)
select_ids = np.random.choice(ids,
size=self.n_select_sample,
p=ps).tolist()
elif self.method_select == 'greedy':
select_ids = sorted(ids, key=lambda x: values[x],
reverse=True)[:self.n_select_sample]
select_new_ys = [new_ys[select_id] for select_id in select_ids]
# log
sorted_new_ys, sorted_values = zip(
*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
logger.info(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: '
f'{sorted_values}\n-- choices --: {select_new_ys}\n')
infos.append({
'step': step,
'x': x,
'ys': ys,
'new_ys': new_ys,
'values': values,
'select_new_ys': select_new_ys
})
ys = select_new_ys
logger.info(ys)
return ys
def inference(self,
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = GenInferencerOutputHandler()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
if 'Fix' in retriever.__class__.__name__:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(
ice_idx_list,
retriever,
self.gen_field_replace_token,
max_seq_len=self.max_seq_len,
ice_template=ice_template,
prompt_template=prompt_template)
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
tmp_result_dict = mmengine.load(tmp_json_filepath)
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
# 5. Inference for prompts in each batch
logger.info('Starting ToT inference process...')
for entries in tqdm(dataloader, disable=not self.is_main_process):
# 5-1. Inference with ToT and local model
with torch.no_grad():
parsed_entries = self.model.parse_template(entries, mode='gen')
generated = [self.tot_solve(entry) for entry in entries]
# 5-2. Save current output
for prompt, prediction in zip(parsed_entries, generated):
output_handler.save_results(prompt, prediction, index)
index = index + 1
# 5-3. Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return [
sample['prediction']
for sample in output_handler.results_dict.values()
]

View File

@ -34,3 +34,4 @@ METRICS = Registry('metric',
MM_MODELS = Registry('mm_model',
parent=MMENGINE_MODELS,
locations=['opencompass.multimodal.models'])
TOT_WRAPPER = Registry('tot_wrapper', locations=['opencompass.datasets'])