[Feature] Support chat style inferencer. (#643)

* [Feature] Support chat style inferencer.

* [Fix] use new prompt

* [Fix] use new prompt

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>
Ma Zerun 2023-11-30 14:00:06 +08:00 committed by GitHub
parent 5933c04fda
commit 6aaf3b91ec
12 changed files with 780 additions and 466 deletions


@@ -0,0 +1,55 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import AgentInferencer
from opencompass.datasets import (
GSM8KDataset,
gsm8k_postprocess,
gsm8k_dataset_postprocess,
Gsm8kAgentEvaluator,
)
gsm8k_reader_cfg = dict(input_columns=["question"], output_column="answer")
gsm8k_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
# # ################################### NEW SHOT ###################################
dict(role='HUMAN', prompt='Mark\'s basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What\'s the total number of points scored by both teams added together?'),
dict(role='BOT', prompt='Tool:PythonInterpreter\nTool Input:def solution():\n mark_pointers_2 = 25 * 2\n mark_pointers_3 = 8 * 3\n mark_free_throws = 10 * 1\n mark_points_scored = mark_pointers_2 + mark_pointers_3 + mark_free_throws\n opponents_pointers_2 = mark_pointers_2 * 2\n opponents_pointers_3 = mark_pointers_3 / 2\n opponents_free_throws = mark_free_throws / 2\n opponents_points_scored = opponents_pointers_2 + opponents_pointers_3 + opponents_free_throws\n total_points_scored = mark_points_scored + opponents_points_scored\n result = total_points_scored\n return result'),
dict(role='SYSTEM', prompt='Response:210'),
dict(role='BOT', prompt='Thought: According to the response, I got the answer\nFinalAnswer: 210'),
dict(role='HUMAN', prompt='Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?'),
dict(role='BOT', prompt='Tool:PythonInterpreter\nTool Input:def solution():\n marbles = 60\n num_increased_marbles = marbles * 2 / 5\n num_total_marbles = marbles + num_increased_marbles\n frisbees = marbles / 2\n num_increased_frisbees = frisbees * 2 / 5\n num_total_frisbees = frisbees + num_increased_frisbees\n deck_cards = frisbees - 20\n num_increased_deck_cards = deck_cards * 2 / 5\n num_total_deck_cards = deck_cards + num_increased_deck_cards\n num_total = num_total_marbles + num_total_frisbees + num_total_deck_cards\n result = num_total\n return result'),
dict(role='SYSTEM', prompt='Response:140'),
dict(role='BOT', prompt='Thought: According to the response, I got the answer\nFinalAnswer: 140'),
dict(role='HUMAN', prompt='A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?'),
dict(role='BOT', prompt="""Tool:PythonInterpreter\nTool Input:def solution():\n num_fruits_per_first_three_basket = 9 + 15 + 14\n num_fruits_first_three_basket = num_fruits_per_first_three_basket * 3\n num_apple_fourth_basket = 9 - 2\n num_orange_fourth_basket = 15 - 2\n num_banana_fourth_basket = 14 - 2\n num_fruits_fourth_basket = num_apple_fourth_basket + num_orange_fourth_basket + num_banana_fourth_basket\n num_fruits_total = num_fruits_first_three_basket + num_fruits_fourth_basket\n result = num_fruits_total\n return result"""),
dict(role='SYSTEM', prompt='Response:146'),
dict(role='BOT', prompt='Thought: According to the response, I got the answer\nFinalAnswer: 146'),
dict(role='HUMAN', prompt='{question}'),
])),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=AgentInferencer),
)
gsm8k_eval_cfg = dict(
evaluator=dict(type=Gsm8kAgentEvaluator),
pred_postprocessor=dict(type=gsm8k_postprocess),
dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
)
gsm8k_datasets = [
dict(
abbr='gsm8k',
type=GSM8KDataset,
path='./data/gsm8k',
reader_cfg=gsm8k_reader_cfg,
infer_cfg=gsm8k_infer_cfg,
eval_cfg=gsm8k_eval_cfg,
)
]
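For reference, the HUMAN/BOT/SYSTEM roles in the round template above correspond to OpenAI-style chat messages once the new chat-style pipeline converts the rendered prompt (see promptlist_to_openai in icl_chat_inferencer.py later in this commit). A minimal sketch of that mapping, with the few-shot contents shortened for brevity and not part of this commit:

# Sketch only: SYSTEM -> system, HUMAN -> user, BOT -> assistant.
few_shot_chat = [
    dict(role='user', content="Mark's basketball team scores 25 2 pointers, ..."),
    dict(role='assistant', content='Tool:PythonInterpreter\nTool Input:def solution(): ...'),
    dict(role='system', content='Response:210'),
    dict(role='assistant', content='Thought: According to the response, I got the answer\nFinalAnswer: 210'),
    dict(role='user', content='{question}'),
]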


@@ -0,0 +1,55 @@
from mmengine.config import read_base
from opencompass.models.openai_api import OpenAI
from opencompass.partitioners import SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.models.lagent import LagentAgent
from lagent import PythonInterpreter, ReAct
from lagent.agents.react import ReActProtocol
with read_base():
from .datasets.gsm8k.gsm8k_agent_gen_3ac57d import gsm8k_datasets as datasets
system_prompt = """You are a helpful assistant which use tools to solve mathematical reasoning questions. The code must be a function, and the function name must be 'solution'. For mathematics, please use code tool to calculate. The example format is as follows:
```
def solution():
variable_names_with_real_meaning = func(variable)
return variable_names_with_real_meaning
```"""
protocol = dict(
type=ReActProtocol,
action=dict(role="ACTION", begin="Tool:", end="\n"),
action_input=dict(role="ARGS", begin="Tool Input:", end="\n"),
finish=dict(role="FINISH", begin="FinalAnswer:", end="\n"),
call_protocol=system_prompt,
)
models = [
dict(
abbr='gpt-3.5-react',
type=LagentAgent,
agent_type=ReAct,
max_turn=3,
llm=dict(
type=OpenAI,
path='gpt-3.5-turbo',
key='ENV',
query_per_second=1,
max_seq_len=4096,
),
actions=[
dict(type=PythonInterpreter),
],
protocol=protocol,
batch_size=1,
),
]
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000),
runner=dict(
type=LocalRunner,
max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)
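The protocol above tells ReActProtocol where each field starts and stops in a model reply: an action is whatever follows "Tool:" up to the newline, its argument is whatever follows "Tool Input:", and "FinalAnswer:" marks the finish. A minimal, hand-rolled sketch of that slicing (illustrative only, not the actual ReActProtocol.parse implementation):

# Illustrative parsing of one BOT turn under the begin/end markers configured above.
reply = 'Tool:PythonInterpreter\nTool Input:def solution():\n    return 210\n'
action = reply.split('Tool:')[1].split('\n')[0]   # -> 'PythonInterpreter'
action_input = reply.split('Tool Input:')[1]      # -> the code handed to the interpreter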


@@ -0,0 +1,82 @@
from lagent.agents.react import ReActProtocol
from mmengine.config import read_base
from opencompass.lagent.actions.ipython_interpreter import IPythonInterpreter
from opencompass.lagent.agents.react import CIReAct
from opencompass.models.lagent import CodeAgent
from opencompass.models.openai_api import OpenAI
from opencompass.partitioners import SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
with read_base():
from .datasets.CIBench.CIBench_gen_eb42f9 import \
cibench_datasets as datasets
FORCE_STOP_PROMPT_EN = """You should directly give results based on history information."""
FEWSHOT_INSTRUCTION = """\
You are an assistant who can utilize external tools.
{tool_description}
To use a tool, please respond with the following format:
```
{thought} Think what you need to solve, do you need to use tools?
{action} The tool name, should be one of [{action_names}].
{action_input} The input to the tool that you want to use.
```
The tool will give you a response after your response, using the following format:
```
{response} the results after calling the tool.
```
Therefore DO NOT generate the tool response by yourself.
Also, please follow these guidelines:
1. Always use the code interpreter to solve the problem.
2. The generated code should always be in a markdown code block.
3. The generated code will be executed in an IPython manner and the results will be cached.
4. Your code should always be simple and should only solve the problem in the current step.
Begin!
"""
IPYTHON_INTERPRETER_DESCRIPTION = '''\
It can run Python code in the manner of a Jupyter notebook. The code must be valid code that contains only Python methods.'''
models = [
dict(
abbr='gpt-3.5-code',
type=CodeAgent,
agent_type=CIReAct,
max_turn=3,
llm=dict(
type=OpenAI,
path='gpt-3.5-turbo',
key='ENV',
query_per_second=1,
max_seq_len=4096,
),
actions=[
dict(type=IPythonInterpreter,
description=IPYTHON_INTERPRETER_DESCRIPTION)
],
protocol=dict(
type=ReActProtocol,
call_protocol=FEWSHOT_INSTRUCTION,
force_stop=FORCE_STOP_PROMPT_EN,
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
),
batch_size=1,
),
]
for dataset in datasets:
# Evaluate on every assistant response
dataset['infer_cfg']['inferencer']['infer_mode'] = 'every'
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000),
runner=dict(
type=LocalRunner,
max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)

configs/eval_chat_last.py

@@ -0,0 +1,35 @@
from mmengine.config import read_base
from opencompass.models.openai_api import OpenAI
from opencompass.openicl import ChatInferencer
from opencompass.partitioners import SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
with read_base():
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets as datasets
models = [
dict(
abbr='gpt-3.5',
type=OpenAI,
path='gpt-3.5-turbo',
key='ENV',
max_out_len=100,
max_seq_len=2048,
batch_size=16,
run_cfg=dict(num_gpus=1, num_procs=1),
)
]
for dataset in datasets:
# Use ChatInferencer instead of GenInferencer
dataset['infer_cfg']['inferencer'] = dict(type=ChatInferencer)
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000),
runner=dict(
type=LocalRunner,
max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)
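ChatInferencer defaults to infer_mode='last', i.e. it only generates the final assistant turn of each conversation (see icl_chat_inferencer.py later in this commit). The other modes can be selected the same way; a sketch, not part of this commit:

for dataset in datasets:
    dataset['infer_cfg']['inferencer'] = dict(
        type=ChatInferencer,
        infer_mode='last',  # or 'every' / 'every_with_gt'
    )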


@@ -1,148 +0,0 @@
from mmengine.config import read_base
from opencompass.partitioners import SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.openicl import AgentInferencer
with read_base():
from .summarizers.medium import summarizer
from .datasets.gsm8k.gsm8k_gen import gsm8k_datasets as datasets
from opencompass.models.lagent import LagentAgent
from lagent.llms import GPTAPI
from lagent.agents.react import ReAct, ReActProtocol
from lagent.actions import PythonInterpreter
FORCE_STOP_PROMPT_EN = """You should directly give results based on history information."""
FEWSHOT_INSTRUCTION = """\
You are a assistant who can utilize external tools.
{tool_description}
To use a tool, please use the following format:
```
{thought} Think what you need to solve, do you need to use tools?
{action} the tool name, should be one of [{action_names}]
{action_input} the input to the action
```
I will give you response after utilizing tools should using the following format:
```
{response} the results after call the tool.
``
If you already know the answer, or you do not need to use tools,
please using the following format to reply:
```
{thought} the thought process to get the final answer
{finish} final answer
```
Examples:
<HUMAN>A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?
<ASSISTANT>{thought} We need to calculate the total number of fruits. The total number of fruits in the first three baskets is given, while for the fourth basket, we need to subtract 2 from each fruit category. We can solve this problem using simple arithmetic.
{action} PythonInterpreter
{action_input}
```python
def solution():
# Fruits in the first three baskets
apples_first_three = 9
oranges_first_three = 15
bananas_first_three = 14
# Fruits in the fourth basket
apples_fourth = apples_first_three - 2
oranges_fourth = oranges_first_three - 2
bananas_fourth = bananas_first_three - 2
# Total fruits
total_fruits = ((apples_first_three + oranges_first_three + bananas_first_three) * 3 +
apples_fourth + oranges_fourth + bananas_fourth)
return {{"total_fruits": total_fruits}}
```
<SYSTEM>{response}{{'total_fruits': 146}}
<ASSISTANT> {thought} By adding the given numbers of apples, oranges, and bananas in the first three baskets, then subtracting 2 from each category for the fourth basket, we have found the total number of fruits.
{finish} 146
<HUMAN>Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?
<ASSISTANT>{thought} This is a problem that requires solving equations. We know the relationship between the number of marbles, frisbees, and deck cards. Bella has twice as many marbles as frisbees, and 20 more frisbees than deck cards. Finally, we are told Bella buys 2/5 times more of each item. This purchasing will increase the number of each type of item.
{action} PythonInterpreter
{action_input}
```python
def solution():
# Given number of marbles
marbles_now = 60
# Calculate number of frisbees and deck cards now
frisbees_now = marbles_now / 2
cards_now = frisbees_now - 20
# Calculate number of each item after buying more
marbles_then = marbles_now + (2/5) * marbles_now
frisbees_then = frisbees_now + (2/5) * frisbees_now
cards_then = cards_now + (2/5)*cards_now
# Total number of items then
total_items = marbles_then + frisbees_then + cards_then
return {{"total_items": total_items}}
```
<SYSTEM>{response}{{'total_items': 140.0}}
<ASSISTANT>{thought} By establishing the relationships between the numbers of marbles, frisbees, and deck cards that Bella currently has, we can calculate how many of each item she will have after buying 2/5 more of each. Adding these quantities together gives us the total number of items.
{finish} 140
Begin!
"""
PYTHON_INTERPRETER_DESCRIPTION = '''\
It can run a Python code. The code must be a valid code that contains only python method, and the method' name must be 'solution' and returns a dict, which key is variable name. The libraries I recommend are sympy and scipy. the format is:
```python
# import packages
import xxx
def solution():
# initialize some variables
variable_names_with_real_meaning = xxx
# middle steps
mid_variable = func(mid_variable)
# final answer
final_answer = func(mid_variable)
return final_answer
```'''
models = [
dict(abbr='gpt-3.5-react',
type=LagentAgent,
agent_type=ReAct,
max_turn=3,
llm=dict(
type=GPTAPI,
model_type='gpt-3.5-turbo',
key='ENV',
query_per_second=1,
max_seq_len=4096,
),
actions=[
dict(type=PythonInterpreter,
description=PYTHON_INTERPRETER_DESCRIPTION),
],
protocol=dict(
type=ReActProtocol,
call_protocol=FEWSHOT_INSTRUCTION,
force_stop=FORCE_STOP_PROMPT_EN,
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
),
batch_size=8),
]
for dataset in datasets:
# Use AgentInferencer instead of GenInferencer
dataset['infer_cfg']['inferencer'] = dict(type=AgentInferencer)
# Use the question as agent input directly.
dataset['infer_cfg']['prompt_template']['template'] = "{question}"
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000),
runner=dict(
type=LocalRunner,
max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)


@@ -43,7 +43,10 @@ def load_experiment(file: str) -> dict:
outputs.append(None)
return dict(
experiment=file,
questions=questions,
questions=sum(([
dict(role='user', content=question),
dict(role='assistant', content=output)
] for question, output in zip(questions, outputs)), []),
references=dict(outputs=outputs, tags=tags, experiment=file),
)
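The sum(..., []) idiom flattens one (question, output) pair per step into a single interleaved chat, which is what the new ChatInferencer expects as input. A minimal sketch with illustrative values (not from the dataset):

# Two notebook steps become one interleaved user/assistant chat.
questions = ['Load data.csv with pandas', 'Plot the first column']
outputs = ['import pandas as pd\ndf = pd.read_csv("data.csv")', None]
chat = sum(([dict(role='user', content=q),
             dict(role='assistant', content=o)]
            for q, o in zip(questions, outputs)), [])
# chat == [user step 1, assistant step 1, user step 2, assistant step 2]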


@@ -72,6 +72,8 @@ class IPythonInterpreter(BaseAction):
user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
self.user_data_dir = user_data_dir
self._initialized = False
if not os.path.exists(WORK_DIR):
os.mkdir(WORK_DIR)
@staticmethod
def start_kernel():


@@ -1,133 +1,7 @@
import re
from typing import Union
from lagent.actions import ActionExecutor
from lagent.agents.base_agent import BaseAgent
from lagent.agents.react import ReActProtocol
from lagent.llms.base_api import BaseAPIModel
from lagent.llms.base_llm import BaseModel
from lagent.agents.react import ReAct
from lagent.schema import ActionReturn, ActionStatusCode, AgentReturn
class ReAct(BaseAgent):
"""An implementation of ReAct (https://arxiv.org/abs/2210.03629)
Args:
llm (BaseModel or BaseAPIModel): a LLM service which can chat
and act as backend.
action_executor (ActionExecutor): an action executor to manage
all actions and their response.
protocol (ReActProtocol): a wrapper to generate prompt and
parse the response from LLM / actions.
max_turn (int): the maximum number of trails for LLM to generate
plans that can be successfully parsed by ReWOO protocol.
"""
def __init__(self,
llm: Union[BaseModel, BaseAPIModel],
action_executor: ActionExecutor,
protocol: ReActProtocol = ReActProtocol(),
max_turn: int = 2) -> None:
self.max_turn = max_turn
super().__init__(llm=llm,
action_executor=action_executor,
protocol=protocol)
def reset(self):
"""Reset history."""
self._session_history = []
def opencompass_adapter(self, prompt):
# adapter for prompt parsing
if isinstance(prompt, list):
system_prompt = []
merged_prompt = []
for p in prompt:
tmp_p = p.copy()
if 'content' in tmp_p:
tmp_p['prompt'] = tmp_p.pop('content')
if 'role' in tmp_p:
if tmp_p['role'] == 'system':
# skip system prompt
system_prompt.append(tmp_p['prompt'])
continue
# no system for meta template temperaily
if tmp_p['role'] == 'assistant':
tmp_p['role'] = 'BOT'
if tmp_p['role'] == 'user':
# merge previous system prompt to user
system_str = ''.join(system_prompt)
tmp_p['prompt'] = system_str + tmp_p['prompt']
tmp_p['role'] = 'HUMAN'
system_prompt = []
merged_prompt.append(tmp_p)
# merge if system still exists
if system_prompt:
if 'role' in merged_prompt[-1]:
if merged_prompt[-1]['role'] == 'HUMAN':
# append to the final human prompt
merged_prompt[-1]['prompt'] += ''.join(system_prompt)
else:
# create a human prompt behind
merged_prompt.append(
dict(role='HUMAN', prompt=''.join(system_prompt)))
from opencompass.utils.prompt import PromptList
new_prompt = PromptList()
# adapter for meta template
new_prompt.append(dict(section='round', pos='begin'))
new_prompt.extend(merged_prompt)
new_prompt.append(dict(section='round', pos='end'))
return new_prompt
def chat(self, message: str) -> AgentReturn:
self._inner_history = []
self._inner_history.append(dict(role='user', content=message))
agent_return = AgentReturn()
force_stop = False
default_response = '对不起,我无法回答你的问题'  # i.e. 'Sorry, I cannot answer your question.'
for turn in range(self.max_turn):
prompt = self._protocol.format(
chat_history=self.session_history,
inner_step=self._inner_history,
action_executor=self._action_executor,
force_stop=force_stop)
prompt = self.opencompass_adapter(prompt)
# allow single generation
response = self._llm.generate_from_template([prompt], 512)[0]
self._inner_history.append(dict(role='assistant',
content=response))
thought, action, action_input = self._protocol.parse(
response, self._action_executor)
# TODO: hard code here
action_input = re.sub('<eoa>', '', action_input)
if 'tensorflow' in action_input:
# skip tensorflow currently
break
action_return: ActionReturn = self._action_executor(
action, action_input)
action_return.thought = thought
agent_return.actions.append(action_return)
if action_return.type == self._action_executor.finish_action.name:
agent_return.response = action_return.result['text']
return agent_return
self._inner_history.append(
dict(role='system',
content=self._protocol.format_response(action_return)))
if turn == self.max_turn - 1:
force_stop = True
agent_return.response = default_response
# only append the user and final response
self._session_history.append(dict(role='user', content=message))
self._session_history.append(
dict(role='assistant', content=agent_return.response))
return agent_return
class CIReAct(ReAct):
"""Code Interpreter version of ReAct. The success state is different from
ReAct.
@@ -165,9 +39,7 @@ class CIReAct(ReAct):
inner_step=self._inner_history,
action_executor=self._action_executor,
force_stop=force_stop)
prompt = self.opencompass_adapter(prompt)
# allow single generation
response = self._llm.generate_from_template([prompt], 512)[0]
response = self._llm.generate_from_template(prompt, 512)
self._inner_history.append(dict(role='assistant',
content=response))
thought, action, action_input = self._protocol.parse(
@@ -179,7 +51,7 @@ class CIReAct(ReAct):
if action_return.state == ActionStatusCode.SUCCESS:
# if success, stash model response and system response
self._session_history.append(
dict(role='assistant', content=action_return.args['text']))
dict(role='assistant', content=response))
self._session_history.append(
dict(
role='system',


@@ -1,10 +1,8 @@
from copy import deepcopy
from typing import List, Tuple
from mmengine.registry import Registry
from opencompass.lagent.agents.react import ReAct
from opencompass.utils import get_logger
REGISTRY = Registry('helper')
@@ -13,45 +11,55 @@ class LagentAgent:
https://github.com/InternLM/lagent.
"""
is_api = True
def __init__(self,
agent_type,
llm,
actions=None,
protocol=None,
mutli_rounds=False,
**kwargs):
def __init__(self, agent_type, llm, actions=None, protocol=None, **kwargs):
llm = REGISTRY.build(llm)
agent_cfg = {'type': agent_type, 'llm': llm, **kwargs}
if actions is not None:
from lagent.actions import ActionExecutor
executor = ActionExecutor(
[REGISTRY.build(action) for action in actions])
executor = ActionExecutor([])
for action in actions:
action = REGISTRY.build(action)
if 'agentlego' in type(action).__module__:
action = action.to_lagent()
executor.add_action(action)
agent_cfg['action_executor'] = executor
if protocol is not None:
protocol = REGISTRY.build(protocol)
agent_cfg['protocol'] = protocol
self.agent = REGISTRY.build(agent_cfg)
self.mutli_rounds = mutli_rounds
from lagent import BaseAgent
self.agent: BaseAgent = REGISTRY.build(agent_cfg)
def add_example(self, example):
# format example in protocol if needed
call_protocol = self.agent._protocol.call_protocol
if '{example}' in call_protocol:
self.agent._protocol.call_protocol = call_protocol.format(
example=example)
else:
get_logger().warning('Protocal template does not have example'
' placeholder, please check your template.')
def reset(self):
self.agent._session_history = []
for action in self.agent._action_executor.actions:
if hasattr(action, 'reset'):
action.reset()
def set_history(self, history):
self.agent._session_history = deepcopy(history)
@property
def template_parser(self):
return self.agent._llm.template_parser
@template_parser.setter
def template_parser(self, value):
self.agent._llm.template_parser = value
def chat(self,
user_input: str,
history: List[dict] = None) -> Tuple[str, List[dict]]:
"""Chat with agent."""
if history:
self.agent._session_history = history
def one_round_chat(self, user_input, ice=None) -> Tuple[str, List[dict]]:
"""One round chat with agent."""
from lagent.schema import ActionReturn, AgentReturn
generation: AgentReturn = self.agent.chat(user_input)
answer = generation.response
steps = []
@@ -67,19 +75,7 @@ class LagentAgent:
errmsg=step.errmsg,
valid=int(step.valid),
))
return answer, steps
def chat(self, user_input, ice=None) -> Tuple[str, List[dict]]:
"""Chat with agent."""
if self.mutli_rounds:
steps = []
for single_input in user_input:
answer, one_round_steps = self.one_round_chat(single_input)
steps.append(one_round_steps)
else:
answer, steps = self.one_round_chat(user_input)
self.agent.reset() # clear agent history
return answer, steps
@@ -88,25 +84,25 @@ FORCE_STOP_PROMPT_EN = (
)
FEWSHOT_INSTRUCTION = """\
You are an assistant who can utilize external tools.
{{tool_description}}
To use a tool, please use the following format:
You are a assistant who can utilize external tools.
{tool_description}
To use a tool, please response with the following format:
```
{{thought}} Think what you need to solve, do you need to use tools?
{{action}} the tool name, should be one of [{{action_names}}]
{{action_input}} the input to the action
{thought} Think what you need to solve, do you need to use tools?
{action} The tool name, should be one of [{action_names}].
{action_input} The input to the tool that you want to use.
```
I will give you response after utilizing tools should using the following format:
The tool will give you response after your response using the following format:
```
{{response}} the results after call the tool.
``
If you already know the answer, or you do not need to use tools,
please using the following format to reply:
{response} the results after call the tool.
```
{{thought}} the thought process to get the final answer
{{finish}} final answer
```
{example}
Therefore DO NOT generate tool response by yourself.
Also please follow the guidelines:
1. Always use code interpreter to solve the problem.
2. The generated codes should always in a markdown code block format.
3. The generated codes will be executed in an ipython manner and the results will be cached.
4. Your responded code should always be simple and only solves the problem in current step.
Begin!
""" # noqa
@@ -127,16 +123,13 @@ def solution():
```""" # noqa
class CodeAgent:
class CodeAgent(LagentAgent):
"""Code Agent wrapper for Lagent."""
def __new__(self, llm, **kwargs):
def __init__(self, llm, **kwargs):
from lagent import PythonInterpreter, ReAct
from lagent.agents.react import ReActProtocol
from opencompass.lagent.actions.python_interpreter import \
PythonInterpreter
mutli_rounds = kwargs.pop('mutli_rounds', False)
agent_type = kwargs.pop('agent_type', ReAct)
max_turn = kwargs.pop('max_turn', 3)
actions = kwargs.pop(
@@ -155,10 +148,9 @@ class CodeAgent:
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
),
)
return LagentAgent(agent_type=agent_type,
super().__init__(agent_type=agent_type,
llm=llm,
max_turn=max_turn,
actions=actions,
protocol=protocol,
mutli_rounds=mutli_rounds,
max_turn=max_turn,
**kwargs)


@@ -1,6 +1,7 @@
from .icl_agent_inferencer import AgentInferencer # noqa
from .icl_attack_inferencer import AttackInferencer # noqa
from .icl_base_inferencer import BaseInferencer # noqa
from .icl_chat_inferencer import ChatInferencer # noqa
from .icl_clp_inferencer import CLPInferencer # noqa
from .icl_gen_inferencer import GenInferencer # noqa
from .icl_ppl_inferencer import PPLInferencer # noqa


@@ -1,124 +1,16 @@
"""Agent Inferencer."""
import os
import os.path as osp
from typing import List, Optional
import mmengine
from mmengine.registry import Registry
from tqdm import tqdm
import types
from typing import List
from opencompass.models.lagent import LagentAgent
from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict
from .icl_base_inferencer import dump_results_dict
from .icl_chat_inferencer import ChatInferencer
logger = get_logger(__name__)
REGISTRY = Registry('helper')
@ICL_INFERENCERS.register_module()
class AgentInferencer(BaseInferencer):
def __init__(
self,
model,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = 1,
example: Optional[str] = None,
**kwargs) -> None:
super().__init__(
model=model,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.save_every = save_every
# example in agent usage for protocol illustration
self.example = example
if example:
self.agent.add_example(example)
@property
def agent(self):
return self.model
def inference(self,
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = AgentInferencerOutputHandler()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
ice_idx_list = retriever.retrieve()
# Create tmp json file for saving intermediate results and future
# resuming
start = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
tmp_result_dict = mmengine.load(tmp_json_filepath)
output_handler.results_dict = tmp_result_dict
start = len(tmp_result_dict)
# 3. Inference sample by sample
logger.info('Starting inference process...')
for idx, ice_indices in tqdm(enumerate(ice_idx_list[start:], start),
disable=not self.is_main_process):
# TODO: This will break the Prompt template
# get user input directly without formatting prompt
#
# user_input = retriever.generate_prompt_for_generate_task(
# idx, ice='', prompt_template=prompt_template)
user_input = retriever.dataset_reader.dataset['test'][
retriever.dataset_reader.input_columns[0]][idx]
gold = retriever.dataset_reader.dataset['test'][
retriever.dataset_reader.output_column][idx]
if len(ice_indices) > 0:
assert ice_template is not None
ice = [
ice_template.generate_ice_item(ice_idx)
for ice_idx in ice_indices
]
else:
ice = None
answer, steps = self.agent.chat(user_input=user_input, ice=ice)
# Save current output
output_handler.save_results(user_input, answer, steps, idx, gold)
# Save intermediate results
if (self.save_every is not None and start % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 4. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return [
sample['prediction']
for sample in output_handler.results_dict.values()
]
class AgentInferencerOutputHandler:
@@ -130,10 +22,115 @@ class AgentInferencerOutputHandler:
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, osp.join(save_dir, filename))
def save_results(self, user_input, answer, steps, idx, gold):
self.results_dict[str(idx)] = {
'origin_prompt': user_input,
'prediction': answer,
def save_results(self,
origin_prompt: list,
prediction: str,
steps: list,
idx: int,
gold: str = None):
result_dict = {}
if gold:
result_dict['gold'] = gold
result_dict.update({
'prediction': prediction,
'origin_prompt': origin_prompt,
'steps': steps,
'gold': gold,
}
})
self.results_dict[str(idx)] = result_dict
def save_multiround_results(self,
origin_prompt: list,
prediction: str,
steps: list,
idx: int,
gold: str = None):
result_dict = self.results_dict.get(str(idx), {
'gold': [],
'prediction': [],
'origin_prompt': [],
'steps': [],
})
result_dict['gold'].append(gold)
result_dict['prediction'].append(prediction)
result_dict['origin_prompt'].append(origin_prompt)
result_dict['steps'].append(steps)
self.results_dict[str(idx)] = result_dict
def model_adapter(model):
"""Modify the generate method to accept and return single item."""
if getattr(model, '_generate_is_wrapped', False):
# Avoid wrap twice.
return model
origin_generate = model.generate
def generate(self, inputs, *args, **kwargs):
return origin_generate([inputs], *args, **kwargs)[0]
model.generate = types.MethodType(generate, model)
setattr(model, '_generate_is_wrapped', True)
return model
@ICL_INFERENCERS.register_module()
class AgentInferencer(ChatInferencer):
HandlerType = AgentInferencerOutputHandler
def __init__(self, model, **kwargs) -> None:
model.agent._llm = model_adapter(model.agent._llm)
super().__init__(model, **kwargs)
self.model: LagentAgent
def infer_last(self, chat: List[dict], index: int, output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
user_idx = assistant_indices[-1] - 1
self.model.set_history(chat[:user_idx])
answer, steps = self.model.chat(chat[user_idx]['content'])
output_handler.save_results(
origin_prompt=chat[user_idx]['content'],
prediction=answer,
steps=steps,
idx=index,
gold=chat[assistant_indices[-1]]['content'],
)
self.model.reset()
def infer_every(self, chat: List[dict], index: int, output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
self.model.set_history(chat[:assistant_indices[0] - 1])
for i in assistant_indices:
answer, steps = self.model.chat(chat[i - 1]['content'])
output_handler.save_multiround_results(
origin_prompt=chat[i - 1]['content'],
prediction=answer,
steps=steps,
idx=index,
gold=chat[i]['content'],
)
self.model.reset()
def infer_every_with_gt(self, chat: List[dict], index: int,
output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
for i in assistant_indices:
self.model.set_history(chat[:i - 1])
answer, steps = self.model.chat(chat[i - 1]['content'])
output_handler.save_multiround_results(
origin_prompt=chat[i - 1]['content'],
prediction=answer,
steps=steps,
idx=index,
gold=chat[i]['content'],
)
self.model.reset()
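The model_adapter helper above exists because Lagent calls llm.generate(prompt) with a single string, while OpenCompass models expect a batch. A toy sketch of its effect (DummyModel is hypothetical and not part of this commit):

class DummyModel:
    def generate(self, inputs, *args, **kwargs):
        return [f'reply to: {item}' for item in inputs]  # batch in, batch out

wrapped = model_adapter(DummyModel())
print(wrapped.generate('2 + 2 = ?'))  # single string in, single string out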


@@ -0,0 +1,368 @@
"""Chat Inferencer."""
import os
import os.path as osp
from typing import List, Optional, Union
import mmengine
from mmengine import is_list_of
from tqdm import tqdm
from opencompass.models import APITemplateParser as _APITemplateParser
from opencompass.models import BaseModel
from opencompass.models import LMTemplateParser as _LMTemplateParser
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils.prompt import PromptList
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict
logger = get_logger(__name__)
def promptlist_to_openai(prompt: Union[str, PromptList]):
output = []
if isinstance(prompt, str):
return [dict(role='user', content=prompt)]
for item in prompt:
if 'section' in item:
continue
if isinstance(item, str) and item:
output.append(dict(role='user', content=item))
elif item['role'] == 'SYSTEM':
output.append(dict(role='system', content=item['prompt']))
elif item['role'] == 'HUMAN':
output.append(dict(role='user', content=item['prompt']))
elif item['role'] == 'BOT':
output.append(dict(role='assistant', content=item['prompt']))
return output
class LMTemplateParser:
"""LMTemplateParser accepts OpenAI format dialog inputs."""
def __init__(self, meta_template: Optional[dict] = None):
self.meta_template = meta_template
self.roles = {}
role_mapping = {
'SYSTEM': 'system',
'HUMAN': 'user',
'BOT': 'assistant',
}
if meta_template:
for item in meta_template.get('round', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
for item in meta_template.get('reserved_roles', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
def parse_template(self, chat: List[dict], mode='gen') -> str:
if is_list_of(chat, list):
# Handle batch inputs
return [self.parse_template(item) for item in chat]
assert is_list_of(chat, dict)
prompt = ''
if self.roles:
for dialog in chat:
role_cfg = self.roles.get(dialog['role'])
prompt += role_cfg['begin']
prompt += (dialog.get('content') or '')
prompt += role_cfg['end']
prompt += self.roles['assistant']['begin']
else:
# in case the model does not have any meta template
last_sep = ''
for item in chat:
prompt += last_sep + (item.get('content') or '')
last_sep = '\n'
return prompt
class APITemplateParser:
"""APITemplateParser accepts OpenAI format dialog inputs."""
def __init__(self, meta_template: Optional[dict] = None):
self.meta_template = meta_template
self.roles = {}
role_mapping = {
'SYSTEM': 'system',
'HUMAN': 'user',
'BOT': 'assistant',
}
if meta_template:
for item in meta_template.get('round', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
for item in meta_template.get('reserved_roles', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
else:
self.roles = dict(
system=dict(api_role='SYSTEM'),
user=dict(api_role='HUMAN'),
assistant=dict(api_role='BOT', generate=True),
)
def parse_template(self, chat: List[dict], mode='gen') -> str:
if is_list_of(chat, list):
# Handle batch inputs
return [self.parse_template(item) for item in chat]
assert is_list_of(chat, dict)
prompt = []
for dialog in chat:
if dialog['role'] in self.roles:
role = self.roles[dialog['role']]['api_role']
else:
role = dialog['role']
prompt.append(dict(role=role, prompt=dialog.get('content') or ''))
return PromptList(prompt)
class ChatOutputHandler:
def __init__(self) -> None:
self.results_dict = {}
def write_to_json(self, save_dir: str, filename: str):
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, osp.join(save_dir, filename))
def save_results(self,
origin_prompt: list,
prediction: str,
idx: int,
gold: str = None):
result_dict = {}
if gold:
result_dict['gold'] = gold
result_dict.update({
'prediction': prediction,
'origin_prompt': origin_prompt,
})
self.results_dict[str(idx)] = result_dict
def save_multiround_results(self,
origin_prompt: list,
prediction: str,
idx: int,
gold: str = None):
result_dict = self.results_dict.get(str(idx), {
'gold': [],
'prediction': [],
'origin_prompt': [],
})
result_dict['gold'].append(gold)
result_dict['prediction'].append(prediction)
result_dict['origin_prompt'].append(origin_prompt)
self.results_dict[str(idx)] = result_dict
@ICL_INFERENCERS.register_module()
class ChatInferencer(BaseInferencer):
HandlerType = ChatOutputHandler
def __init__(
self,
model,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = 1,
infer_mode: str = 'last',
**kwargs) -> None:
super().__init__(
model=model,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
assert infer_mode in ['last', 'every', 'every_with_gt']
self.infer_mode = infer_mode
self.model: BaseModel
self._set_meta_template(self.model)
if self.model.is_api and save_every is None:
save_every = 1
self.save_every = save_every
def _set_meta_template(self, model):
origin = model.template_parser
if isinstance(origin, _APITemplateParser):
model.template_parser = APITemplateParser(origin.meta_template)
if isinstance(origin, _LMTemplateParser):
model.template_parser = LMTemplateParser(origin.meta_template)
def inference(self,
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> dict:
# 1. Preparation for output logs
output_handler = self.HandlerType()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
chat_list = self.get_chat_list(
ice_idx_list,
retriever,
prompt_template=prompt_template,
)
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
tmp_result_dict = mmengine.load(tmp_json_filepath)
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
dataloader = self.get_dataloader(chat_list[index:], batch_size=1)
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for datum in tqdm(dataloader, disable=not self.is_main_process):
chat = datum[0]
if self.infer_mode == 'last':
self.infer_last(chat, index, output_handler)
elif self.infer_mode == 'every':
self.infer_every(chat, index, output_handler)
elif self.infer_mode == 'every_with_gt':
self.infer_every_with_gt(chat, index, output_handler)
index += 1
# Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return output_handler.results_dict
def get_chat_list(self,
ice_idx_list: List[List[int]],
retriever: BaseRetriever,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
input_columns = retriever.dataset_reader.input_columns
output_column = retriever.dataset_reader.output_column
def chat_from_entry(entry):
if prompt_template is None and len(input_columns) == 1:
# Directly use the input column as the user input
user = entry.get(input_columns[0])
assistant = entry.get(output_column, '')
return [
dict(role='user', content=user),
dict(role='assistant', content=assistant),
]
elif prompt_template is not None:
# Use prompt template to generate chat history
chat = promptlist_to_openai(
prompt_template.generate_item(entry))
gold = entry.get(output_column, '')
if chat[-1]['role'] != 'assistant':
chat.append(dict(role='assistant', content=gold))
return chat
else:
raise ValueError()
for idx, ice_idx in enumerate(ice_idx_list):
# NOTE: The in-context examples are not used for now.
item = {
k: v
for k, v in retriever.test_ds[idx].items()
if k in input_columns or k == output_column
}
if all(isinstance(value, str) for value in item.values()):
# Every column is a single string
chat = chat_from_entry(item)
elif all(is_list_of(value, str) for value in item.values()):
# Every column is a list of string for multi-round chat
entries = [dict(zip(item, v)) for v in zip(*item.values())]
chat = sum((chat_from_entry(entry) for entry in entries), [])
elif len(input_columns) == 1 and is_list_of(
item[input_columns[0]], dict):
# Single input column and it's already a chat.
chat = item[input_columns[0]]
else:
raise ValueError('Cannot construct chat from the dataset.')
prompt_list.append(chat)
return prompt_list
def infer_last(self, chat: List[dict], index: int, output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
history = chat[:assistant_indices[-1]]
output = self.model.generate_from_template([history],
max_out_len=512)[0]
output_handler.save_results(
origin_prompt=history,
prediction=output,
idx=index,
gold=chat[assistant_indices[-1]]['content'],
)
def infer_every(self, chat: List[dict], index: int, output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
for i in assistant_indices:
history = chat[:i]
output = self.model.generate_from_template([history],
max_out_len=512)[0]
output_handler.save_multiround_results(
origin_prompt=history[-1]['content'],
prediction=output,
idx=index,
gold=chat[i]['content'],
)
chat[i]['content'] = output
index += 1
def infer_every_with_gt(self, chat: List[dict], index: int,
output_handler):
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
for i in assistant_indices:
history = chat[:i]
output = self.model.generate_from_template([history],
max_out_len=512)[0]
output_handler.save_multiround_results(
origin_prompt=history[-1]['content'],
prediction=output,
idx=index,
gold=chat[i]['content'],
)
index += 1
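To summarize the three infer modes implemented above, here is a toy chat (illustrative only, not from the commit) and where each mode generates:

chat = [
    dict(role='user', content='What is 2 + 3?'),
    dict(role='assistant', content='5'),                  # gold turn 1
    dict(role='user', content='Multiply that by 4.'),
    dict(role='assistant', content='20'),                 # gold turn 2
]
# infer_mode='last'          -> generate only turn 2, with everything before it as context.
# infer_mode='every'         -> generate turns 1 and 2, feeding the model's own turn-1
#                               reply back into the history (chat[i]['content'] = output).
# infer_mode='every_with_gt' -> generate both turns, but keep the gold replies as history.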