mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support chat style inferencer. (#643)
* [Feature] Support chat style inferencer. * [Fix] use new prompt * [Fix] use new prompt --------- Co-authored-by: yingfhu <yingfhu@gmail.com>
This commit is contained in:
parent
5933c04fda
commit
6aaf3b91ec
55
configs/datasets/gsm8k/gsm8k_agent_gen_3ac57d.py
Normal file
55
configs/datasets/gsm8k/gsm8k_agent_gen_3ac57d.py
Normal file
@ -0,0 +1,55 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import AgentInferencer
|
||||
from opencompass.datasets import (
|
||||
GSM8KDataset,
|
||||
gsm8k_postprocess,
|
||||
gsm8k_dataset_postprocess,
|
||||
Gsm8kAgentEvaluator,
|
||||
)
|
||||
|
||||
gsm8k_reader_cfg = dict(input_columns=["question"], output_column="answer")
|
||||
|
||||
gsm8k_infer_cfg = dict(
|
||||
prompt_template=dict(
|
||||
type=PromptTemplate,
|
||||
template=dict(
|
||||
round=[
|
||||
# # ################################### NEW SHOT ###################################
|
||||
dict(role='HUMAN', prompt='Mark\'s basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What\'s the total number of points scored by both teams added together?'),
|
||||
dict(role='BOT', prompt='Tool:PythonInterpreter\nTool Input:def solution():\n mark_pointers_2 = 25 * 2\n mark_pointers_3 = 8 * 3\n mark_free_throws = 10 * 1\n mark_points_scored = mark_pointers_2 + mark_pointers_3 + mark_free_throws\n opponents_pointers_2 = mark_pointers_2 * 2\n opponents_pointers_3 = mark_pointers_3 / 2\n opponents_free_throws = mark_free_throws / 2\n opponents_points_scored = opponents_pointers_2 + opponents_pointers_3 + opponents_free_throws\n total_points_scored = mark_points_scored + opponents_points_scored\n result = total_points_scored\n return result'),
|
||||
dict(role='SYSTEM', prompt='Response:210'),
|
||||
dict(role='BOT', prompt='Thought: According to the response, I got the answer\nFinalAnswer: 210'),
|
||||
|
||||
dict(role='HUMAN', prompt='Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?'),
|
||||
dict(role='BOT', prompt='Tool:PythonInterpreter\nTool Input:def solution():\n marbles = 60\n num_increased_marbles = marbles * 2 / 5\n num_total_marbles = marbles + num_increased_marbles\n frisbees = marbles / 2\n num_increased_frisbees = frisbees * 2 / 5\n num_total_frisbees = frisbees + num_increased_frisbees\n deck_cards = frisbees - 20\n num_increased_deck_cards = deck_cards * 2 / 5\n num_total_deck_cards = deck_cards + num_increased_deck_cards\n num_total = num_total_marbles + num_total_frisbees + num_total_deck_cards\n result = num_total\n return result'),
|
||||
dict(role='SYSTEM', prompt='Response:140'),
|
||||
dict(role='BOT', prompt='Thought: According to the response, I got the answer\nFinalAnswer: 140'),
|
||||
|
||||
dict(role='HUMAN', prompt='A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?'),
|
||||
dict(role='BOT', prompt="""Tool:PythonInterpreter\nTool Input:def solution():\n num_fruits_per_first_three_basket = 9 + 15 + 14\n num_fruits_first_three_basket = num_fruits_per_first_three_basket * 3\n num_apple_fourth_basket = 9 - 2\n num_orange_fourth_basket = 15 - 2\n num_banana_fourth_basket = 14 - 2\n num_fruits_fourth_basket = num_apple_fourth_basket + num_orange_fourth_basket + num_banana_fourth_basket\n num_fruits_total = num_fruits_first_three_basket + num_fruits_fourth_basket\n result = num_fruits_total\n return result"""),
|
||||
dict(role='SYSTEM', prompt='Response:146'),
|
||||
dict(role='BOT', prompt='Thought: According to the response, I got the answer\nFinalAnswer: 146'),
|
||||
|
||||
dict(role='HUMAN', prompt='{question}'),
|
||||
])),
|
||||
retriever=dict(type=ZeroRetriever),
|
||||
inferencer=dict(type=AgentInferencer),
|
||||
)
|
||||
|
||||
gsm8k_eval_cfg = dict(
|
||||
evaluator=dict(type=Gsm8kAgentEvaluator),
|
||||
pred_postprocessor=dict(type=gsm8k_postprocess),
|
||||
dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
|
||||
)
|
||||
|
||||
gsm8k_datasets = [
|
||||
dict(
|
||||
abbr='gsm8k',
|
||||
type=GSM8KDataset,
|
||||
path='./data/gsm8k',
|
||||
reader_cfg=gsm8k_reader_cfg,
|
||||
infer_cfg=gsm8k_infer_cfg,
|
||||
eval_cfg=gsm8k_eval_cfg,
|
||||
)
|
||||
]
|
55
configs/eval_chat_agent.py
Normal file
55
configs/eval_chat_agent.py
Normal file
@ -0,0 +1,55 @@
|
||||
from mmengine.config import read_base
|
||||
from opencompass.models.openai_api import OpenAI
|
||||
from opencompass.partitioners import SizePartitioner
|
||||
from opencompass.runners import LocalRunner
|
||||
from opencompass.tasks import OpenICLInferTask
|
||||
from opencompass.models.lagent import LagentAgent
|
||||
from lagent import PythonInterpreter, ReAct
|
||||
from lagent.agents.react import ReActProtocol
|
||||
|
||||
with read_base():
|
||||
from .datasets.gsm8k.gsm8k_agent_gen_3ac57d import gsm8k_datasets as datasets
|
||||
|
||||
system_prompt = """You are a helpful assistant which use tools to solve mathematical reasoning questions. The code must be a function, and the function name must be 'solution'. For mathematics, please use code tool to calculate. The example format is as follows:
|
||||
```
|
||||
def solution():
|
||||
variable_names_with_real_meaning = func(variable)
|
||||
return variable_names_with_real_meaning
|
||||
```"""
|
||||
|
||||
protocol = dict(
|
||||
type=ReActProtocol,
|
||||
action=dict(role="ACTION", begin="Tool:", end="\n"),
|
||||
action_input=dict(role="ARGS", begin="Tool Input:", end="\n"),
|
||||
finish=dict(role="FINISH", begin="FinalAnswer:", end="\n"),
|
||||
call_protocol=system_prompt,
|
||||
)
|
||||
|
||||
models = [
|
||||
dict(
|
||||
abbr='gpt-3.5-react',
|
||||
type=LagentAgent,
|
||||
agent_type=ReAct,
|
||||
max_turn=3,
|
||||
llm=dict(
|
||||
type=OpenAI,
|
||||
path='gpt-3.5-turbo',
|
||||
key='ENV',
|
||||
query_per_second=1,
|
||||
max_seq_len=4096,
|
||||
),
|
||||
actions=[
|
||||
dict(type=PythonInterpreter),
|
||||
],
|
||||
protocol=protocol,
|
||||
batch_size=1,
|
||||
),
|
||||
]
|
||||
|
||||
infer = dict(
|
||||
partitioner=dict(type=SizePartitioner, max_task_size=1000),
|
||||
runner=dict(
|
||||
type=LocalRunner,
|
||||
max_num_workers=16,
|
||||
task=dict(type=OpenICLInferTask)),
|
||||
)
|
82
configs/eval_chat_cibench.py
Normal file
82
configs/eval_chat_cibench.py
Normal file
@ -0,0 +1,82 @@
|
||||
from lagent.agents.react import ReActProtocol
|
||||
from mmengine.config import read_base
|
||||
|
||||
from opencompass.lagent.actions.ipython_interpreter import IPythonInterpreter
|
||||
from opencompass.lagent.agents.react import CIReAct
|
||||
from opencompass.models.lagent import CodeAgent
|
||||
from opencompass.models.openai_api import OpenAI
|
||||
from opencompass.partitioners import SizePartitioner
|
||||
from opencompass.runners import LocalRunner
|
||||
from opencompass.tasks import OpenICLInferTask
|
||||
|
||||
with read_base():
|
||||
from .datasets.CIBench.CIBench_gen_eb42f9 import \
|
||||
cibench_datasets as datasets
|
||||
|
||||
FORCE_STOP_PROMPT_EN = """You should directly give results based on history information."""
|
||||
|
||||
FEWSHOT_INSTRUCTION = """\
|
||||
You are an assistant who can utilize external tools.
|
||||
{tool_description}
|
||||
To use a tool, please response with the following format:
|
||||
```
|
||||
{thought} Think what you need to solve, do you need to use tools?
|
||||
{action} The tool name, should be one of [{action_names}].
|
||||
{action_input} The input to the tool that you want to use.
|
||||
```
|
||||
The tool will give you response after your response using the following format:
|
||||
```
|
||||
{response} the results after call the tool.
|
||||
```
|
||||
Therefore DO NOT generate tool response by yourself.
|
||||
|
||||
Also please follow the guidelines:
|
||||
1. Always use code interpreter to solve the problem.
|
||||
2. The generated codes should always in a markdown code block format.
|
||||
3. The generated codes will be executed in an ipython manner and the results will be cached.
|
||||
4. Your responded code should always be simple and only solves the problem in current step.
|
||||
|
||||
Begin!
|
||||
"""
|
||||
|
||||
IPYTHON_INTERPRETER_DESCRIPTION = '''\
|
||||
It can run Python code in a manner as jupyter notebook. The code must be a valid code that contains only python method.'''
|
||||
|
||||
models = [
|
||||
dict(
|
||||
abbr='gpt-3.5-code',
|
||||
type=CodeAgent,
|
||||
agent_type=CIReAct,
|
||||
max_turn=3,
|
||||
llm=dict(
|
||||
type=OpenAI,
|
||||
path='gpt-3.5-turbo',
|
||||
key='ENV',
|
||||
query_per_second=1,
|
||||
max_seq_len=4096,
|
||||
),
|
||||
actions=[
|
||||
dict(type=IPythonInterpreter,
|
||||
description=IPYTHON_INTERPRETER_DESCRIPTION)
|
||||
],
|
||||
protocol=dict(
|
||||
type=ReActProtocol,
|
||||
call_protocol=FEWSHOT_INSTRUCTION,
|
||||
force_stop=FORCE_STOP_PROMPT_EN,
|
||||
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
|
||||
),
|
||||
batch_size=1,
|
||||
),
|
||||
]
|
||||
|
||||
for dataset in datasets:
|
||||
# Evaluate on every assistant response
|
||||
dataset['infer_cfg']['inferencer']['infer_mode'] = 'every'
|
||||
|
||||
infer = dict(
|
||||
partitioner=dict(type=SizePartitioner, max_task_size=1000),
|
||||
runner=dict(
|
||||
type=LocalRunner,
|
||||
max_num_workers=16,
|
||||
task=dict(type=OpenICLInferTask)),
|
||||
)
|
35
configs/eval_chat_last.py
Normal file
35
configs/eval_chat_last.py
Normal file
@ -0,0 +1,35 @@
|
||||
from mmengine.config import read_base
|
||||
|
||||
from opencompass.models.openai_api import OpenAI
|
||||
from opencompass.openicl import ChatInferencer
|
||||
from opencompass.partitioners import SizePartitioner
|
||||
from opencompass.runners import LocalRunner
|
||||
from opencompass.tasks import OpenICLInferTask
|
||||
|
||||
with read_base():
|
||||
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets as datasets
|
||||
|
||||
models = [
|
||||
dict(
|
||||
abbr='gpt-3.5',
|
||||
type=OpenAI,
|
||||
path='gpt-3.5-turbo',
|
||||
key='ENV',
|
||||
max_out_len=100,
|
||||
max_seq_len=2048,
|
||||
batch_size=16,
|
||||
run_cfg=dict(num_gpus=1, num_procs=1),
|
||||
)
|
||||
]
|
||||
|
||||
for dataset in datasets:
|
||||
# Use ChatInferencer instead of GenInferencer
|
||||
dataset['infer_cfg']['inferencer'] = dict(type=ChatInferencer)
|
||||
|
||||
infer = dict(
|
||||
partitioner=dict(type=SizePartitioner, max_task_size=1000),
|
||||
runner=dict(
|
||||
type=LocalRunner,
|
||||
max_num_workers=16,
|
||||
task=dict(type=OpenICLInferTask)),
|
||||
)
|
@ -1,148 +0,0 @@
|
||||
from mmengine.config import read_base
|
||||
from opencompass.partitioners import SizePartitioner
|
||||
from opencompass.runners import LocalRunner
|
||||
from opencompass.tasks import OpenICLInferTask
|
||||
from opencompass.openicl import AgentInferencer
|
||||
|
||||
with read_base():
|
||||
from .summarizers.medium import summarizer
|
||||
from .datasets.gsm8k.gsm8k_gen import gsm8k_datasets as datasets
|
||||
|
||||
from opencompass.models.lagent import LagentAgent
|
||||
from lagent.llms import GPTAPI
|
||||
from lagent.agents.react import ReAct, ReActProtocol
|
||||
from lagent.actions import PythonInterpreter
|
||||
|
||||
FORCE_STOP_PROMPT_EN = """You should directly give results based on history information."""
|
||||
|
||||
FEWSHOT_INSTRUCTION = """\
|
||||
You are a assistant who can utilize external tools.
|
||||
{tool_description}
|
||||
To use a tool, please use the following format:
|
||||
```
|
||||
{thought} Think what you need to solve, do you need to use tools?
|
||||
{action} the tool name, should be one of [{action_names}]
|
||||
{action_input} the input to the action
|
||||
```
|
||||
I will give you response after utilizing tools should using the following format:
|
||||
```
|
||||
{response} the results after call the tool.
|
||||
``
|
||||
If you already know the answer, or you do not need to use tools,
|
||||
please using the following format to reply:
|
||||
```
|
||||
{thought} the thought process to get the final answer
|
||||
{finish} final answer
|
||||
```
|
||||
|
||||
Examples:
|
||||
|
||||
<HUMAN>A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?
|
||||
<ASSISTANT>{thought} We need to calculate the total number of fruits. The total number of fruits in the first three baskets is given, while for the fourth basket, we need to subtract 2 from each fruit category. We can solve this problem using simple arithmetic.
|
||||
{action} PythonInterpreter
|
||||
{action_input}
|
||||
```python
|
||||
def solution():
|
||||
# Fruits in the first three baskets
|
||||
apples_first_three = 9
|
||||
oranges_first_three = 15
|
||||
bananas_first_three = 14
|
||||
|
||||
# Fruits in the fourth basket
|
||||
apples_fourth = apples_first_three - 2
|
||||
oranges_fourth = oranges_first_three - 2
|
||||
bananas_fourth = bananas_first_three - 2
|
||||
|
||||
# Total fruits
|
||||
total_fruits = ((apples_first_three + oranges_first_three + bananas_first_three) * 3 +
|
||||
apples_fourth + oranges_fourth + bananas_fourth)
|
||||
|
||||
return {{"total_fruits": total_fruits}}
|
||||
```
|
||||
<SYSTEM>{response}{{'total_fruits': 146}}
|
||||
<ASSISTANT> {thought} By adding the given numbers of apples, oranges, and bananas in the first three baskets, then subtracting 2 from each category for the fourth basket, we have found the total number of fruits.
|
||||
{finish} 146
|
||||
|
||||
<HUMAN>Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?
|
||||
<ASSISTANT>{thought} This is a problem that requires solving equations. We know the relationship between the number of marbles, frisbees, and deck cards. Bella has twice as many marbles as frisbees, and 20 more frisbees than deck cards. Finally, we are told Bella buys 2/5 times more of each item. This purchasing will increase the number of each type of item.
|
||||
{action} PythonInterpreter
|
||||
{action_input}
|
||||
```python
|
||||
def solution():
|
||||
# Given number of marbles
|
||||
marbles_now = 60
|
||||
|
||||
# Calculate number of frisbees and deck cards now
|
||||
frisbees_now = marbles_now / 2
|
||||
cards_now = frisbees_now - 20
|
||||
|
||||
# Calculate number of each item after buying more
|
||||
marbles_then = marbles_now + (2/5) * marbles_now
|
||||
frisbees_then = frisbees_now + (2/5) * frisbees_now
|
||||
cards_then = cards_now + (2/5)*cards_now
|
||||
|
||||
# Total number of items then
|
||||
total_items = marbles_then + frisbees_then + cards_then
|
||||
|
||||
return {{"total_items": total_items}}
|
||||
```
|
||||
<SYSTEM>{response}{{'total_items': 140.0}}
|
||||
<ASSISTANT>{thought} By establishing the relationships between the numbers of marbles, frisbees, and deck cards that Bella currently has, we can calculate how many of each item she will have after buying 2/5 more of each. Adding these quantities together gives us the total number of items.
|
||||
{finish} 140
|
||||
|
||||
Begin!
|
||||
"""
|
||||
|
||||
PYTHON_INTERPRETER_DESCRIPTION = '''\
|
||||
It can run a Python code. The code must be a valid code that contains only python method, and the method' name must be 'solution' and returns a dict, which key is variable name. The libraries I recommend are sympy and scipy. the format is:
|
||||
```python
|
||||
# import packages
|
||||
import xxx
|
||||
def solution():
|
||||
# initialize some variables
|
||||
variable_names_with_real_meaning = xxx
|
||||
# middle steps
|
||||
mid_variable = func(mid_variable)
|
||||
# final answer
|
||||
final_answer = func(mid_variable)
|
||||
return final_answer
|
||||
```'''
|
||||
|
||||
models = [
|
||||
dict(abbr='gpt-3.5-react',
|
||||
type=LagentAgent,
|
||||
agent_type=ReAct,
|
||||
max_turn=3,
|
||||
llm=dict(
|
||||
type=GPTAPI,
|
||||
model_type='gpt-3.5-turbo',
|
||||
key='ENV',
|
||||
query_per_second=1,
|
||||
max_seq_len=4096,
|
||||
),
|
||||
actions=[
|
||||
dict(type=PythonInterpreter,
|
||||
description=PYTHON_INTERPRETER_DESCRIPTION),
|
||||
],
|
||||
protocol=dict(
|
||||
type=ReActProtocol,
|
||||
call_protocol=FEWSHOT_INSTRUCTION,
|
||||
force_stop=FORCE_STOP_PROMPT_EN,
|
||||
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
|
||||
),
|
||||
batch_size=8),
|
||||
]
|
||||
|
||||
for dataset in datasets:
|
||||
# Use AgentInferencer instead of GenInferencer
|
||||
dataset['infer_cfg']['inferencer'] = dict(type=AgentInferencer)
|
||||
# Use the question as agent input directly.
|
||||
dataset['infer_cfg']['prompt_template']['template'] = "{question}"
|
||||
|
||||
infer = dict(
|
||||
partitioner=dict(type=SizePartitioner, max_task_size=1000),
|
||||
runner=dict(
|
||||
type=LocalRunner,
|
||||
max_num_workers=16,
|
||||
task=dict(type=OpenICLInferTask)),
|
||||
)
|
@ -43,7 +43,10 @@ def load_experiment(file: str) -> dict:
|
||||
outputs.append(None)
|
||||
return dict(
|
||||
experiment=file,
|
||||
questions=questions,
|
||||
questions=sum(([
|
||||
dict(role='user', content=question),
|
||||
dict(role='assistant', content=output)
|
||||
] for question, output in zip(questions, outputs)), []),
|
||||
references=dict(outputs=outputs, tags=tags, experiment=file),
|
||||
)
|
||||
|
||||
|
@ -72,6 +72,8 @@ class IPythonInterpreter(BaseAction):
|
||||
user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
|
||||
self.user_data_dir = user_data_dir
|
||||
self._initialized = False
|
||||
if not os.path.exists(WORK_DIR):
|
||||
os.mkdir(WORK_DIR)
|
||||
|
||||
@staticmethod
|
||||
def start_kernel():
|
||||
|
@ -1,133 +1,7 @@
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from lagent.actions import ActionExecutor
|
||||
from lagent.agents.base_agent import BaseAgent
|
||||
from lagent.agents.react import ReActProtocol
|
||||
from lagent.llms.base_api import BaseAPIModel
|
||||
from lagent.llms.base_llm import BaseModel
|
||||
from lagent.agents.react import ReAct
|
||||
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn
|
||||
|
||||
|
||||
class ReAct(BaseAgent):
|
||||
"""An implementation of ReAct (https://arxiv.org/abs/2210.03629)
|
||||
|
||||
Args:
|
||||
llm (BaseModel or BaseAPIModel): a LLM service which can chat
|
||||
and act as backend.
|
||||
action_executor (ActionExecutor): an action executor to manage
|
||||
all actions and their response.
|
||||
protocol (ReActProtocol): a wrapper to generate prompt and
|
||||
parse the response from LLM / actions.
|
||||
max_turn (int): the maximum number of trails for LLM to generate
|
||||
plans that can be successfully parsed by ReWOO protocol.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
llm: Union[BaseModel, BaseAPIModel],
|
||||
action_executor: ActionExecutor,
|
||||
protocol: ReActProtocol = ReActProtocol(),
|
||||
max_turn: int = 2) -> None:
|
||||
self.max_turn = max_turn
|
||||
super().__init__(llm=llm,
|
||||
action_executor=action_executor,
|
||||
protocol=protocol)
|
||||
|
||||
def reset(self):
|
||||
"""Reset history."""
|
||||
self._session_history = []
|
||||
|
||||
def opencompass_adapter(self, prompt):
|
||||
# adapter for prompt parsing
|
||||
if isinstance(prompt, list):
|
||||
system_prompt = []
|
||||
merged_prompt = []
|
||||
for p in prompt:
|
||||
tmp_p = p.copy()
|
||||
if 'content' in tmp_p:
|
||||
tmp_p['prompt'] = tmp_p.pop('content')
|
||||
if 'role' in tmp_p:
|
||||
if tmp_p['role'] == 'system':
|
||||
# skip system prompt
|
||||
system_prompt.append(tmp_p['prompt'])
|
||||
continue
|
||||
# no system for meta template temperaily
|
||||
if tmp_p['role'] == 'assistant':
|
||||
tmp_p['role'] = 'BOT'
|
||||
if tmp_p['role'] == 'user':
|
||||
# merge previous system prompt to user
|
||||
system_str = ''.join(system_prompt)
|
||||
tmp_p['prompt'] = system_str + tmp_p['prompt']
|
||||
tmp_p['role'] = 'HUMAN'
|
||||
system_prompt = []
|
||||
merged_prompt.append(tmp_p)
|
||||
|
||||
# merge if system still exists
|
||||
if system_prompt:
|
||||
if 'role' in merged_prompt[-1]:
|
||||
if merged_prompt[-1]['role'] == 'HUMAN':
|
||||
# append to the final human prompt
|
||||
merged_prompt[-1]['prompt'] += ''.join(system_prompt)
|
||||
else:
|
||||
# create a human prompt behind
|
||||
merged_prompt.append(
|
||||
dict(role='HUMAN', prompt=''.join(system_prompt)))
|
||||
|
||||
from opencompass.utils.prompt import PromptList
|
||||
new_prompt = PromptList()
|
||||
# adapter for meta template
|
||||
new_prompt.append(dict(section='round', pos='begin'))
|
||||
new_prompt.extend(merged_prompt)
|
||||
new_prompt.append(dict(section='round', pos='end'))
|
||||
|
||||
return new_prompt
|
||||
|
||||
def chat(self, message: str) -> AgentReturn:
|
||||
self._inner_history = []
|
||||
self._inner_history.append(dict(role='user', content=message))
|
||||
agent_return = AgentReturn()
|
||||
force_stop = False
|
||||
default_response = '对不起,我无法回答你的问题'
|
||||
for turn in range(self.max_turn):
|
||||
prompt = self._protocol.format(
|
||||
chat_history=self.session_history,
|
||||
inner_step=self._inner_history,
|
||||
action_executor=self._action_executor,
|
||||
force_stop=force_stop)
|
||||
prompt = self.opencompass_adapter(prompt)
|
||||
# allow single generation
|
||||
response = self._llm.generate_from_template([prompt], 512)[0]
|
||||
self._inner_history.append(dict(role='assistant',
|
||||
content=response))
|
||||
thought, action, action_input = self._protocol.parse(
|
||||
response, self._action_executor)
|
||||
|
||||
# TODO: hard code here
|
||||
action_input = re.sub('<eoa>', '', action_input)
|
||||
|
||||
if 'tensorflow' in action_input:
|
||||
# skip tensorflow currently
|
||||
break
|
||||
action_return: ActionReturn = self._action_executor(
|
||||
action, action_input)
|
||||
action_return.thought = thought
|
||||
agent_return.actions.append(action_return)
|
||||
if action_return.type == self._action_executor.finish_action.name:
|
||||
agent_return.response = action_return.result['text']
|
||||
return agent_return
|
||||
self._inner_history.append(
|
||||
dict(role='system',
|
||||
content=self._protocol.format_response(action_return)))
|
||||
if turn == self.max_turn - 1:
|
||||
force_stop = True
|
||||
agent_return.response = default_response
|
||||
# only append the user and final response
|
||||
self._session_history.append(dict(role='user', content=message))
|
||||
self._session_history.append(
|
||||
dict(role='assistant', content=agent_return.response))
|
||||
return agent_return
|
||||
|
||||
|
||||
class CIReAct(ReAct):
|
||||
"""Code Interpreter version of ReAct. The success state is different from
|
||||
ReAct.
|
||||
@ -165,9 +39,7 @@ class CIReAct(ReAct):
|
||||
inner_step=self._inner_history,
|
||||
action_executor=self._action_executor,
|
||||
force_stop=force_stop)
|
||||
prompt = self.opencompass_adapter(prompt)
|
||||
# allow single generation
|
||||
response = self._llm.generate_from_template([prompt], 512)[0]
|
||||
response = self._llm.generate_from_template(prompt, 512)
|
||||
self._inner_history.append(dict(role='assistant',
|
||||
content=response))
|
||||
thought, action, action_input = self._protocol.parse(
|
||||
@ -179,7 +51,7 @@ class CIReAct(ReAct):
|
||||
if action_return.state == ActionStatusCode.SUCCESS:
|
||||
# if success, stash model response and system response
|
||||
self._session_history.append(
|
||||
dict(role='assistant', content=action_return.args['text']))
|
||||
dict(role='assistant', content=response))
|
||||
self._session_history.append(
|
||||
dict(
|
||||
role='system',
|
||||
|
@ -1,10 +1,8 @@
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple
|
||||
|
||||
from mmengine.registry import Registry
|
||||
|
||||
from opencompass.lagent.agents.react import ReAct
|
||||
from opencompass.utils import get_logger
|
||||
|
||||
REGISTRY = Registry('helper')
|
||||
|
||||
|
||||
@ -13,45 +11,55 @@ class LagentAgent:
|
||||
|
||||
https://github.com/InternLM/lagent.
|
||||
"""
|
||||
is_api = True
|
||||
|
||||
def __init__(self,
|
||||
agent_type,
|
||||
llm,
|
||||
actions=None,
|
||||
protocol=None,
|
||||
mutli_rounds=False,
|
||||
**kwargs):
|
||||
def __init__(self, agent_type, llm, actions=None, protocol=None, **kwargs):
|
||||
llm = REGISTRY.build(llm)
|
||||
agent_cfg = {'type': agent_type, 'llm': llm, **kwargs}
|
||||
|
||||
if actions is not None:
|
||||
from lagent.actions import ActionExecutor
|
||||
|
||||
executor = ActionExecutor(
|
||||
[REGISTRY.build(action) for action in actions])
|
||||
executor = ActionExecutor([])
|
||||
for action in actions:
|
||||
action = REGISTRY.build(action)
|
||||
if 'agentlego' in type(action).__module__:
|
||||
action = action.to_lagent()
|
||||
executor.add_action(action)
|
||||
agent_cfg['action_executor'] = executor
|
||||
if protocol is not None:
|
||||
protocol = REGISTRY.build(protocol)
|
||||
agent_cfg['protocol'] = protocol
|
||||
|
||||
self.agent = REGISTRY.build(agent_cfg)
|
||||
self.mutli_rounds = mutli_rounds
|
||||
from lagent import BaseAgent
|
||||
self.agent: BaseAgent = REGISTRY.build(agent_cfg)
|
||||
|
||||
def add_example(self, example):
|
||||
# format example in protocol if needed
|
||||
call_protocol = self.agent._protocol.call_protocol
|
||||
if '{example}' in call_protocol:
|
||||
self.agent._protocol.call_protocol = call_protocol.format(
|
||||
example=example)
|
||||
else:
|
||||
get_logger().warning('Protocal template does not have example'
|
||||
' placeholder, please check your template.')
|
||||
def reset(self):
|
||||
self.agent._session_history = []
|
||||
for action in self.agent._action_executor.actions:
|
||||
if hasattr(action, 'reset'):
|
||||
action.reset()
|
||||
|
||||
def set_history(self, history):
|
||||
self.agent._session_history = deepcopy(history)
|
||||
|
||||
@property
|
||||
def template_parser(self):
|
||||
return self.agent._llm.template_parser
|
||||
|
||||
@template_parser.setter
|
||||
def template_parser(self, value):
|
||||
self.agent._llm.template_parser = value
|
||||
|
||||
def chat(self,
|
||||
user_input: str,
|
||||
history: List[dict] = None) -> Tuple[str, List[dict]]:
|
||||
"""Chat with agent."""
|
||||
if history:
|
||||
self.agent._session_history = history
|
||||
|
||||
def one_round_chat(self, user_input, ice=None) -> Tuple[str, List[dict]]:
|
||||
"""One round chat with agent."""
|
||||
from lagent.schema import ActionReturn, AgentReturn
|
||||
|
||||
generation: AgentReturn = self.agent.chat(user_input)
|
||||
|
||||
answer = generation.response
|
||||
steps = []
|
||||
|
||||
@ -67,19 +75,7 @@ class LagentAgent:
|
||||
errmsg=step.errmsg,
|
||||
valid=int(step.valid),
|
||||
))
|
||||
return answer, steps
|
||||
|
||||
def chat(self, user_input, ice=None) -> Tuple[str, List[dict]]:
|
||||
"""Chat with agent."""
|
||||
if self.mutli_rounds:
|
||||
steps = []
|
||||
for single_input in user_input:
|
||||
answer, one_round_steps = self.one_round_chat(single_input)
|
||||
steps.append(one_round_steps)
|
||||
else:
|
||||
answer, steps = self.one_round_chat(user_input)
|
||||
|
||||
self.agent.reset() # clear agent history
|
||||
return answer, steps
|
||||
|
||||
|
||||
@ -88,25 +84,25 @@ FORCE_STOP_PROMPT_EN = (
|
||||
)
|
||||
|
||||
FEWSHOT_INSTRUCTION = """\
|
||||
You are an assistant who can utilize external tools.
|
||||
{{tool_description}}
|
||||
To use a tool, please use the following format:
|
||||
You are a assistant who can utilize external tools.
|
||||
{tool_description}
|
||||
To use a tool, please response with the following format:
|
||||
```
|
||||
{{thought}} Think what you need to solve, do you need to use tools?
|
||||
{{action}} the tool name, should be one of [{{action_names}}]
|
||||
{{action_input}} the input to the action
|
||||
{thought} Think what you need to solve, do you need to use tools?
|
||||
{action} The tool name, should be one of [{action_names}].
|
||||
{action_input} The input to the tool that you want to use.
|
||||
```
|
||||
I will give you response after utilizing tools should using the following format:
|
||||
The tool will give you response after your response using the following format:
|
||||
```
|
||||
{{response}} the results after call the tool.
|
||||
``
|
||||
If you already know the answer, or you do not need to use tools,
|
||||
please using the following format to reply:
|
||||
{response} the results after call the tool.
|
||||
```
|
||||
{{thought}} the thought process to get the final answer
|
||||
{{finish}} final answer
|
||||
```
|
||||
{example}
|
||||
Therefore DO NOT generate tool response by yourself.
|
||||
|
||||
Also please follow the guidelines:
|
||||
1. Always use code interpreter to solve the problem.
|
||||
2. The generated codes should always in a markdown code block format.
|
||||
3. The generated codes will be executed in an ipython manner and the results will be cached.
|
||||
4. Your responded code should always be simple and only solves the problem in current step.
|
||||
|
||||
Begin!
|
||||
""" # noqa
|
||||
@ -127,16 +123,13 @@ def solution():
|
||||
```""" # noqa
|
||||
|
||||
|
||||
class CodeAgent:
|
||||
class CodeAgent(LagentAgent):
|
||||
"""Code Agent wrapper for Lagent."""
|
||||
|
||||
def __new__(self, llm, **kwargs):
|
||||
def __init__(self, llm, **kwargs):
|
||||
from lagent import PythonInterpreter, ReAct
|
||||
from lagent.agents.react import ReActProtocol
|
||||
|
||||
from opencompass.lagent.actions.python_interpreter import \
|
||||
PythonInterpreter
|
||||
|
||||
mutli_rounds = kwargs.pop('mutli_rounds', False)
|
||||
agent_type = kwargs.pop('agent_type', ReAct)
|
||||
max_turn = kwargs.pop('max_turn', 3)
|
||||
actions = kwargs.pop(
|
||||
@ -155,10 +148,9 @@ class CodeAgent:
|
||||
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
|
||||
),
|
||||
)
|
||||
return LagentAgent(agent_type=agent_type,
|
||||
llm=llm,
|
||||
max_turn=max_turn,
|
||||
actions=actions,
|
||||
protocol=protocol,
|
||||
mutli_rounds=mutli_rounds,
|
||||
**kwargs)
|
||||
super().__init__(agent_type=agent_type,
|
||||
llm=llm,
|
||||
actions=actions,
|
||||
protocol=protocol,
|
||||
max_turn=max_turn,
|
||||
**kwargs)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from .icl_agent_inferencer import AgentInferencer # noqa
|
||||
from .icl_attack_inferencer import AttackInferencer # noqa
|
||||
from .icl_base_inferencer import BaseInferencer # noqa
|
||||
from .icl_chat_inferencer import ChatInferencer # noqa
|
||||
from .icl_clp_inferencer import CLPInferencer # noqa
|
||||
from .icl_gen_inferencer import GenInferencer # noqa
|
||||
from .icl_ppl_inferencer import PPLInferencer # noqa
|
||||
|
@ -1,124 +1,16 @@
|
||||
"""Agent Inferencer."""
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List, Optional
|
||||
|
||||
import mmengine
|
||||
from mmengine.registry import Registry
|
||||
from tqdm import tqdm
|
||||
import types
|
||||
from typing import List
|
||||
|
||||
from opencompass.models.lagent import LagentAgent
|
||||
from opencompass.registry import ICL_INFERENCERS
|
||||
|
||||
from ..icl_prompt_template import PromptTemplate
|
||||
from ..icl_retriever import BaseRetriever
|
||||
from ..utils.logging import get_logger
|
||||
from .icl_base_inferencer import BaseInferencer, dump_results_dict
|
||||
from .icl_base_inferencer import dump_results_dict
|
||||
from .icl_chat_inferencer import ChatInferencer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
REGISTRY = Registry('helper')
|
||||
|
||||
|
||||
@ICL_INFERENCERS.register_module()
|
||||
class AgentInferencer(BaseInferencer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
output_json_filepath: Optional[str] = './icl_inference_output',
|
||||
output_json_filename: Optional[str] = 'predictions',
|
||||
save_every: Optional[int] = 1,
|
||||
example: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
model=model,
|
||||
output_json_filename=output_json_filename,
|
||||
output_json_filepath=output_json_filepath,
|
||||
**kwargs,
|
||||
)
|
||||
self.save_every = save_every
|
||||
# example in agent usage for protocol illustration
|
||||
self.example = example
|
||||
if example:
|
||||
self.agent.add_example(example)
|
||||
|
||||
@property
|
||||
def agent(self):
|
||||
return self.model
|
||||
|
||||
def inference(self,
|
||||
retriever: BaseRetriever,
|
||||
ice_template: Optional[PromptTemplate] = None,
|
||||
prompt_template: Optional[PromptTemplate] = None,
|
||||
output_json_filepath: Optional[str] = None,
|
||||
output_json_filename: Optional[str] = None) -> List:
|
||||
# 1. Preparation for output logs
|
||||
output_handler = AgentInferencerOutputHandler()
|
||||
|
||||
if output_json_filepath is None:
|
||||
output_json_filepath = self.output_json_filepath
|
||||
if output_json_filename is None:
|
||||
output_json_filename = self.output_json_filename
|
||||
|
||||
# 2. Get results of retrieval process
|
||||
ice_idx_list = retriever.retrieve()
|
||||
|
||||
# Create tmp json file for saving intermediate results and future
|
||||
# resuming
|
||||
start = 0
|
||||
tmp_json_filepath = os.path.join(output_json_filepath,
|
||||
'tmp_' + output_json_filename)
|
||||
if osp.exists(tmp_json_filepath):
|
||||
# TODO: move resume to output handler
|
||||
tmp_result_dict = mmengine.load(tmp_json_filepath)
|
||||
output_handler.results_dict = tmp_result_dict
|
||||
start = len(tmp_result_dict)
|
||||
|
||||
# 3. Inference sample by sample
|
||||
logger.info('Starting inference process...')
|
||||
for idx, ice_indices in tqdm(enumerate(ice_idx_list[start:], start),
|
||||
disable=not self.is_main_process):
|
||||
# TODO: This will break the Prompt template
|
||||
# get user input directly without formatting prompt
|
||||
#
|
||||
# user_input = retriever.generate_prompt_for_generate_task(
|
||||
# idx, ice='', prompt_template=prompt_template)
|
||||
user_input = retriever.dataset_reader.dataset['test'][
|
||||
retriever.dataset_reader.input_columns[0]][idx]
|
||||
gold = retriever.dataset_reader.dataset['test'][
|
||||
retriever.dataset_reader.output_column][idx]
|
||||
|
||||
if len(ice_indices) > 0:
|
||||
assert ice_template is not None
|
||||
ice = [
|
||||
ice_template.generate_ice_item(ice_idx)
|
||||
for ice_idx in ice_indices
|
||||
]
|
||||
else:
|
||||
ice = None
|
||||
|
||||
answer, steps = self.agent.chat(user_input=user_input, ice=ice)
|
||||
|
||||
# Save current output
|
||||
output_handler.save_results(user_input, answer, steps, idx, gold)
|
||||
|
||||
# Save intermediate results
|
||||
if (self.save_every is not None and start % self.save_every == 0
|
||||
and self.is_main_process):
|
||||
output_handler.write_to_json(output_json_filepath,
|
||||
'tmp_' + output_json_filename)
|
||||
|
||||
# 4. Output
|
||||
if self.is_main_process:
|
||||
os.makedirs(output_json_filepath, exist_ok=True)
|
||||
output_handler.write_to_json(output_json_filepath,
|
||||
output_json_filename)
|
||||
if osp.exists(tmp_json_filepath):
|
||||
os.remove(tmp_json_filepath)
|
||||
|
||||
return [
|
||||
sample['prediction']
|
||||
for sample in output_handler.results_dict.values()
|
||||
]
|
||||
|
||||
|
||||
class AgentInferencerOutputHandler:
|
||||
@ -130,10 +22,115 @@ class AgentInferencerOutputHandler:
|
||||
"""Dump the result to a json file."""
|
||||
dump_results_dict(self.results_dict, osp.join(save_dir, filename))
|
||||
|
||||
def save_results(self, user_input, answer, steps, idx, gold):
|
||||
self.results_dict[str(idx)] = {
|
||||
'origin_prompt': user_input,
|
||||
'prediction': answer,
|
||||
def save_results(self,
|
||||
origin_prompt: list,
|
||||
prediction: str,
|
||||
steps: list,
|
||||
idx: int,
|
||||
gold: str = None):
|
||||
result_dict = {}
|
||||
if gold:
|
||||
result_dict['gold'] = gold
|
||||
result_dict.update({
|
||||
'prediction': prediction,
|
||||
'origin_prompt': origin_prompt,
|
||||
'steps': steps,
|
||||
'gold': gold,
|
||||
}
|
||||
})
|
||||
self.results_dict[str(idx)] = result_dict
|
||||
|
||||
def save_multiround_results(self,
|
||||
origin_prompt: list,
|
||||
prediction: str,
|
||||
steps: list,
|
||||
idx: int,
|
||||
gold: str = None):
|
||||
result_dict = self.results_dict.get(str(idx), {
|
||||
'gold': [],
|
||||
'prediction': [],
|
||||
'origin_prompt': [],
|
||||
'steps': [],
|
||||
})
|
||||
result_dict['gold'].append(gold)
|
||||
result_dict['prediction'].append(prediction)
|
||||
result_dict['origin_prompt'].append(origin_prompt)
|
||||
result_dict['steps'].append(steps)
|
||||
self.results_dict[str(idx)] = result_dict
|
||||
|
||||
|
||||
def model_adapter(model):
|
||||
"""Modify the generate method to accept and return single item."""
|
||||
if getattr(model, '_generate_is_wrapped', False):
|
||||
# Avoid wrap twice.
|
||||
return model
|
||||
|
||||
origin_generate = model.generate
|
||||
|
||||
def generate(self, inputs, *args, **kwargs):
|
||||
return origin_generate([inputs], *args, **kwargs)[0]
|
||||
|
||||
model.generate = types.MethodType(generate, model)
|
||||
setattr(model, '_generate_is_wrapped', True)
|
||||
return model
|
||||
|
||||
|
||||
@ICL_INFERENCERS.register_module()
|
||||
class AgentInferencer(ChatInferencer):
|
||||
HandlerType = AgentInferencerOutputHandler
|
||||
|
||||
def __init__(self, model, **kwargs) -> None:
|
||||
model.agent._llm = model_adapter(model.agent._llm)
|
||||
super().__init__(model, **kwargs)
|
||||
self.model: LagentAgent
|
||||
|
||||
def infer_last(self, chat: List[dict], index: int, output_handler):
|
||||
assistant_indices = [
|
||||
i for i, item in enumerate(chat) if item['role'] == 'assistant'
|
||||
]
|
||||
|
||||
user_idx = assistant_indices[-1] - 1
|
||||
self.model.set_history(chat[:user_idx])
|
||||
answer, steps = self.model.chat(chat[user_idx]['content'])
|
||||
output_handler.save_results(
|
||||
origin_prompt=chat[user_idx]['content'],
|
||||
prediction=answer,
|
||||
steps=steps,
|
||||
idx=index,
|
||||
gold=chat[assistant_indices[-1]]['content'],
|
||||
)
|
||||
self.model.reset()
|
||||
|
||||
def infer_every(self, chat: List[dict], index: int, output_handler):
|
||||
assistant_indices = [
|
||||
i for i, item in enumerate(chat) if item['role'] == 'assistant'
|
||||
]
|
||||
|
||||
self.model.set_history(chat[:assistant_indices[0] - 1])
|
||||
|
||||
for i in assistant_indices:
|
||||
answer, steps = self.model.chat(chat[i - 1]['content'])
|
||||
output_handler.save_multiround_results(
|
||||
origin_prompt=chat[i - 1]['content'],
|
||||
prediction=answer,
|
||||
steps=steps,
|
||||
idx=index,
|
||||
gold=chat[i]['content'],
|
||||
)
|
||||
self.model.reset()
|
||||
|
||||
def infer_every_with_gt(self, chat: List[dict], index: int,
|
||||
output_handler):
|
||||
assistant_indices = [
|
||||
i for i, item in enumerate(chat) if item['role'] == 'assistant'
|
||||
]
|
||||
|
||||
for i in assistant_indices:
|
||||
self.model.set_history(chat[:i - 1])
|
||||
answer, steps = self.model.chat(chat[i - 1]['content'])
|
||||
output_handler.save_multiround_results(
|
||||
origin_prompt=chat[i - 1]['content'],
|
||||
prediction=answer,
|
||||
steps=steps,
|
||||
idx=index,
|
||||
gold=chat[i]['content'],
|
||||
)
|
||||
self.model.reset()
|
||||
|
368
opencompass/openicl/icl_inferencer/icl_chat_inferencer.py
Normal file
368
opencompass/openicl/icl_inferencer/icl_chat_inferencer.py
Normal file
@ -0,0 +1,368 @@
|
||||
"""Chat Inferencer."""
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import mmengine
|
||||
from mmengine import is_list_of
|
||||
from tqdm import tqdm
|
||||
|
||||
from opencompass.models import APITemplateParser as _APITemplateParser
|
||||
from opencompass.models import BaseModel
|
||||
from opencompass.models import LMTemplateParser as _LMTemplateParser
|
||||
from opencompass.registry import ICL_INFERENCERS
|
||||
from opencompass.utils.prompt import PromptList
|
||||
|
||||
from ..icl_prompt_template import PromptTemplate
|
||||
from ..icl_retriever import BaseRetriever
|
||||
from ..utils.logging import get_logger
|
||||
from .icl_base_inferencer import BaseInferencer, dump_results_dict
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def promptlist_to_openai(prompt: Union[str, PromptList]):
|
||||
output = []
|
||||
if isinstance(prompt, str):
|
||||
return [dict(role='user', content=prompt)]
|
||||
|
||||
for item in prompt:
|
||||
if 'section' in item:
|
||||
continue
|
||||
if isinstance(item, str) and item:
|
||||
output.append(dict(role='user', content=item))
|
||||
elif item['role'] == 'SYSTEM':
|
||||
output.append(dict(role='system', content=item['prompt']))
|
||||
elif item['role'] == 'HUMAN':
|
||||
output.append(dict(role='user', content=item['prompt']))
|
||||
elif item['role'] == 'BOT':
|
||||
output.append(dict(role='assistant', content=item['prompt']))
|
||||
return output
|
||||
|
||||
|
||||
class LMTemplateParser:
|
||||
"""LMTemplateParser accepts OpenAI format dialog inputs."""
|
||||
|
||||
def __init__(self, meta_template: Optional[dict] = None):
|
||||
self.meta_template = meta_template
|
||||
self.roles = {}
|
||||
role_mapping = {
|
||||
'SYSTEM': 'system',
|
||||
'HUMAN': 'user',
|
||||
'BOT': 'assistant',
|
||||
}
|
||||
if meta_template:
|
||||
for item in meta_template.get('round', []):
|
||||
role = role_mapping.get(item['role'], item['role'])
|
||||
self.roles[role] = item.copy()
|
||||
for item in meta_template.get('reserved_roles', []):
|
||||
role = role_mapping.get(item['role'], item['role'])
|
||||
self.roles[role] = item.copy()
|
||||
|
||||
def parse_template(self, chat: List[dict], mode='gen') -> str:
|
||||
if is_list_of(chat, list):
|
||||
# Handle batch inputs
|
||||
return [self.parse_template(item) for item in chat]
|
||||
|
||||
assert is_list_of(chat, dict)
|
||||
prompt = ''
|
||||
if self.roles:
|
||||
for dialog in chat:
|
||||
role_cfg = self.roles.get(dialog['role'])
|
||||
prompt += role_cfg['begin']
|
||||
prompt += (dialog.get('content') or '')
|
||||
prompt += role_cfg['end']
|
||||
prompt += self.roles['assistant']['begin']
|
||||
else:
|
||||
# in case the model does not have any meta template
|
||||
last_sep = ''
|
||||
for item in chat:
|
||||
prompt += last_sep + (item.get('content') or '')
|
||||
last_sep = '\n'
|
||||
return prompt
|
||||
|
||||
|
||||
class APITemplateParser:
|
||||
"""APITemplateParser accepts OpenAI format dialog inputs."""
|
||||
|
||||
def __init__(self, meta_template: Optional[dict] = None):
|
||||
self.meta_template = meta_template
|
||||
self.roles = {}
|
||||
role_mapping = {
|
||||
'SYSTEM': 'system',
|
||||
'HUMAN': 'user',
|
||||
'BOT': 'assistant',
|
||||
}
|
||||
if meta_template:
|
||||
for item in meta_template.get('round', []):
|
||||
role = role_mapping.get(item['role'], item['role'])
|
||||
self.roles[role] = item.copy()
|
||||
for item in meta_template.get('reserved_roles', []):
|
||||
role = role_mapping.get(item['role'], item['role'])
|
||||
self.roles[role] = item.copy()
|
||||
else:
|
||||
self.roles = dict(
|
||||
system=dict(api_role='SYSTEM'),
|
||||
user=dict(api_role='HUMAN'),
|
||||
assistant=dict(api_role='BOT', generate=True),
|
||||
)
|
||||
|
||||
def parse_template(self, chat: List[dict], mode='gen') -> str:
|
||||
if is_list_of(chat, list):
|
||||
# Handle batch inputs
|
||||
return [self.parse_template(item) for item in chat]
|
||||
|
||||
assert is_list_of(chat, dict)
|
||||
prompt = []
|
||||
for dialog in chat:
|
||||
if dialog['role'] in self.roles:
|
||||
role = self.roles[dialog['role']]['api_role']
|
||||
else:
|
||||
role = dialog['role']
|
||||
prompt.append(dict(role=role, prompt=dialog.get('content') or ''))
|
||||
return PromptList(prompt)
|
||||
|
||||
|
||||
class ChatOutputHandler:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.results_dict = {}
|
||||
|
||||
def write_to_json(self, save_dir: str, filename: str):
|
||||
"""Dump the result to a json file."""
|
||||
dump_results_dict(self.results_dict, osp.join(save_dir, filename))
|
||||
|
||||
def save_results(self,
|
||||
origin_prompt: list,
|
||||
prediction: str,
|
||||
idx: int,
|
||||
gold: str = None):
|
||||
result_dict = {}
|
||||
if gold:
|
||||
result_dict['gold'] = gold
|
||||
result_dict.update({
|
||||
'prediction': prediction,
|
||||
'origin_prompt': origin_prompt,
|
||||
})
|
||||
self.results_dict[str(idx)] = result_dict
|
||||
|
||||
def save_multiround_results(self,
|
||||
origin_prompt: list,
|
||||
prediction: str,
|
||||
idx: int,
|
||||
gold: str = None):
|
||||
result_dict = self.results_dict.get(str(idx), {
|
||||
'gold': [],
|
||||
'prediction': [],
|
||||
'origin_prompt': [],
|
||||
})
|
||||
result_dict['gold'].append(gold)
|
||||
result_dict['prediction'].append(prediction)
|
||||
result_dict['origin_prompt'].append(origin_prompt)
|
||||
self.results_dict[str(idx)] = result_dict
|
||||
|
||||
|
||||
@ICL_INFERENCERS.register_module()
|
||||
class ChatInferencer(BaseInferencer):
|
||||
HandlerType = ChatOutputHandler
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
output_json_filepath: Optional[str] = './icl_inference_output',
|
||||
output_json_filename: Optional[str] = 'predictions',
|
||||
save_every: Optional[int] = 1,
|
||||
infer_mode: str = 'last',
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
model=model,
|
||||
output_json_filename=output_json_filename,
|
||||
output_json_filepath=output_json_filepath,
|
||||
**kwargs,
|
||||
)
|
||||
assert infer_mode in ['last', 'every', 'every_with_gt']
|
||||
self.infer_mode = infer_mode
|
||||
self.model: BaseModel
|
||||
self._set_meta_template(self.model)
|
||||
|
||||
if self.model.is_api and save_every is None:
|
||||
save_every = 1
|
||||
self.save_every = save_every
|
||||
|
||||
def _set_meta_template(self, model):
|
||||
origin = model.template_parser
|
||||
if isinstance(origin, _APITemplateParser):
|
||||
model.template_parser = APITemplateParser(origin.meta_template)
|
||||
if isinstance(origin, _LMTemplateParser):
|
||||
model.template_parser = LMTemplateParser(origin.meta_template)
|
||||
|
||||
def inference(self,
|
||||
retriever: BaseRetriever,
|
||||
ice_template: Optional[PromptTemplate] = None,
|
||||
prompt_template: Optional[PromptTemplate] = None,
|
||||
output_json_filepath: Optional[str] = None,
|
||||
output_json_filename: Optional[str] = None) -> dict:
|
||||
# 1. Preparation for output logs
|
||||
output_handler = self.HandlerType()
|
||||
|
||||
if output_json_filepath is None:
|
||||
output_json_filepath = self.output_json_filepath
|
||||
if output_json_filename is None:
|
||||
output_json_filename = self.output_json_filename
|
||||
|
||||
# 2. Get results of retrieval process
|
||||
ice_idx_list = retriever.retrieve()
|
||||
|
||||
# 3. Generate prompts for testing input
|
||||
chat_list = self.get_chat_list(
|
||||
ice_idx_list,
|
||||
retriever,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
# Create tmp json file for saving intermediate results and future
|
||||
# resuming
|
||||
index = 0
|
||||
tmp_json_filepath = os.path.join(output_json_filepath,
|
||||
'tmp_' + output_json_filename)
|
||||
if osp.exists(tmp_json_filepath):
|
||||
# TODO: move resume to output handler
|
||||
tmp_result_dict = mmengine.load(tmp_json_filepath)
|
||||
output_handler.results_dict = tmp_result_dict
|
||||
index = len(tmp_result_dict)
|
||||
|
||||
# 4. Wrap prompts with Dataloader
|
||||
dataloader = self.get_dataloader(chat_list[index:], batch_size=1)
|
||||
|
||||
# 5. Inference for prompts in each batch
|
||||
logger.info('Starting inference process...')
|
||||
for datum in tqdm(dataloader, disable=not self.is_main_process):
|
||||
chat = datum[0]
|
||||
|
||||
if self.infer_mode == 'last':
|
||||
self.infer_last(chat, index, output_handler)
|
||||
elif self.infer_mode == 'every':
|
||||
self.infer_every(chat, index, output_handler)
|
||||
elif self.infer_mode == 'every_with_gt':
|
||||
self.infer_every_with_gt(chat, index, output_handler)
|
||||
index += 1
|
||||
|
||||
# Save intermediate results
|
||||
if (self.save_every is not None and index % self.save_every == 0
|
||||
and self.is_main_process):
|
||||
output_handler.write_to_json(output_json_filepath,
|
||||
'tmp_' + output_json_filename)
|
||||
|
||||
# 4. Output
|
||||
if self.is_main_process:
|
||||
os.makedirs(output_json_filepath, exist_ok=True)
|
||||
output_handler.write_to_json(output_json_filepath,
|
||||
output_json_filename)
|
||||
if osp.exists(tmp_json_filepath):
|
||||
os.remove(tmp_json_filepath)
|
||||
|
||||
return output_handler.results_dict
|
||||
|
||||
def get_chat_list(self,
|
||||
ice_idx_list: List[List[int]],
|
||||
retriever: BaseRetriever,
|
||||
prompt_template: Optional[PromptTemplate] = None):
|
||||
prompt_list = []
|
||||
input_columns = retriever.dataset_reader.input_columns
|
||||
output_column = retriever.dataset_reader.output_column
|
||||
|
||||
def chat_from_entry(entry):
|
||||
if prompt_template is None and len(input_columns) == 1:
|
||||
# Directly use the input column as the user input
|
||||
user = entry.get(input_columns[0])
|
||||
assistant = entry.get(output_column, '')
|
||||
return [
|
||||
dict(role='user', content=user),
|
||||
dict(role='assistant', content=assistant),
|
||||
]
|
||||
elif prompt_template is not None:
|
||||
# Use prompt template to generate chat history
|
||||
chat = promptlist_to_openai(
|
||||
prompt_template.generate_item(entry))
|
||||
gold = entry.get(output_column, '')
|
||||
if chat[-1]['role'] != 'assistant':
|
||||
chat.append(dict(role='assistant', content=gold))
|
||||
return chat
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
for idx, ice_idx in enumerate(ice_idx_list):
|
||||
# NOTE: The in-context examples won't be used by now.
|
||||
|
||||
item = {
|
||||
k: v
|
||||
for k, v in retriever.test_ds[idx].items()
|
||||
if k in input_columns or k == output_column
|
||||
}
|
||||
if all(isinstance(value, str) for value in item.values()):
|
||||
# Every column is a single string
|
||||
chat = chat_from_entry(item)
|
||||
elif all(is_list_of(value, str) for value in item.values()):
|
||||
# Every column is a list of string for multi-round chat
|
||||
entries = [dict(zip(item, v)) for v in zip(*item.values())]
|
||||
chat = sum((chat_from_entry(entry) for entry in entries), [])
|
||||
elif len(input_columns) == 1 and is_list_of(
|
||||
item[input_columns[0]], dict):
|
||||
# Single input column and it's already a chat.
|
||||
chat = item[input_columns[0]]
|
||||
else:
|
||||
raise ValueError('Cannot construct chat from the dataset.')
|
||||
|
||||
prompt_list.append(chat)
|
||||
return prompt_list
|
||||
|
||||
def infer_last(self, chat: List[dict], index: int, output_handler):
|
||||
assistant_indices = [
|
||||
i for i, item in enumerate(chat) if item['role'] == 'assistant'
|
||||
]
|
||||
|
||||
history = chat[:assistant_indices[-1]]
|
||||
output = self.model.generate_from_template([history],
|
||||
max_out_len=512)[0]
|
||||
output_handler.save_results(
|
||||
origin_prompt=history,
|
||||
prediction=output,
|
||||
idx=index,
|
||||
gold=chat[assistant_indices[-1]]['content'],
|
||||
)
|
||||
|
||||
def infer_every(self, chat: List[dict], index: int, output_handler):
|
||||
assistant_indices = [
|
||||
i for i, item in enumerate(chat) if item['role'] == 'assistant'
|
||||
]
|
||||
|
||||
for i in assistant_indices:
|
||||
history = chat[:i]
|
||||
output = self.model.generate_from_template([history],
|
||||
max_out_len=512)[0]
|
||||
output_handler.save_multiround_results(
|
||||
origin_prompt=history[-1]['content'],
|
||||
prediction=output,
|
||||
idx=index,
|
||||
gold=chat[i]['content'],
|
||||
)
|
||||
chat[i]['content'] = output
|
||||
index += 1
|
||||
|
||||
def infer_every_with_gt(self, chat: List[dict], index: int,
|
||||
output_handler):
|
||||
assistant_indices = [
|
||||
i for i, item in enumerate(chat) if item['role'] == 'assistant'
|
||||
]
|
||||
|
||||
for i in assistant_indices:
|
||||
history = chat[:i]
|
||||
output = self.model.generate_from_template([history],
|
||||
max_out_len=512)[0]
|
||||
output_handler.save_multiround_results(
|
||||
origin_prompt=history[-1]['content'],
|
||||
prediction=output,
|
||||
idx=index,
|
||||
gold=chat[i]['content'],
|
||||
)
|
||||
index += 1
|
Loading…
Reference in New Issue
Block a user