mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

* [Feature] Support chat style inferencer.
* [Fix] use new prompt

Co-authored-by: yingfhu <yingfhu@gmail.com>
137 lines · 4.4 KiB · Python
"""Agent Inferencer."""
|
|
import os.path as osp
|
|
import types
|
|
from typing import List
|
|
|
|
from opencompass.models.lagent import LagentAgent
|
|
from opencompass.registry import ICL_INFERENCERS
|
|
|
|
from ..utils.logging import get_logger
|
|
from .icl_base_inferencer import dump_results_dict
|
|
from .icl_chat_inferencer import ChatInferencer
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class AgentInferencerOutputHandler:
    """Collect per-sample agent results and dump them to a JSON file."""

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, osp.join(save_dir, filename))

    def save_results(self,
                     origin_prompt: list,
                     prediction: str,
                     steps: list,
                     idx: int,
                     gold: str = None):
        """Record a single-turn result, keyed by the sample index."""
        result_dict = {}
        if gold:
            result_dict['gold'] = gold
        result_dict.update({
            'prediction': prediction,
            'origin_prompt': origin_prompt,
            'steps': steps,
        })
        self.results_dict[str(idx)] = result_dict

    def save_multiround_results(self,
                                origin_prompt: list,
                                prediction: str,
                                steps: list,
                                idx: int,
                                gold: str = None):
        """Append one round of a multi-round result for the sample index."""
        result_dict = self.results_dict.get(str(idx), {
            'gold': [],
            'prediction': [],
            'origin_prompt': [],
            'steps': [],
        })
        result_dict['gold'].append(gold)
        result_dict['prediction'].append(prediction)
        result_dict['origin_prompt'].append(origin_prompt)
        result_dict['steps'].append(steps)
        self.results_dict[str(idx)] = result_dict

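# Illustrative note (not part of the original source): `save_results` stores
# one single-turn record per sample index, roughly
#     {'0': {'gold': '...', 'prediction': '...',
#            'origin_prompt': '...', 'steps': [...]}}
# while `save_multiround_results` accumulates parallel per-round lists:
#     {'0': {'gold': [...], 'prediction': [...],
#            'origin_prompt': [...], 'steps': [...]}}
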
def model_adapter(model):
    """Modify the generate method to accept and return single item."""
    if getattr(model, '_generate_is_wrapped', False):
        # Avoid wrap twice.
        return model

    origin_generate = model.generate

    def generate(self, inputs, *args, **kwargs):
        return origin_generate([inputs], *args, **kwargs)[0]

    model.generate = types.MethodType(generate, model)
    setattr(model, '_generate_is_wrapped', True)
    return model

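# Illustrative sketch (not part of the original source): after wrapping,
#
#     llm = model_adapter(llm)
#     answer = llm.generate('Hi')  # forwarded as origin_generate(['Hi'])[0]
#
# i.e. the batch-oriented `generate` of an OpenCompass model is exposed as a
# single-item call, presumably to match what the Lagent agent expects. The
# `_generate_is_wrapped` flag keeps repeated adapter calls from wrapping
# `generate` twice.
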
@ICL_INFERENCERS.register_module()
class AgentInferencer(ChatInferencer):
    HandlerType = AgentInferencerOutputHandler

    def __init__(self, model, **kwargs) -> None:
        # The agent's underlying LLM is called one prompt at a time.
        model.agent._llm = model_adapter(model.agent._llm)
        super().__init__(model, **kwargs)
        self.model: LagentAgent

    def infer_last(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        # The user turn immediately precedes the last assistant turn.
        user_idx = assistant_indices[-1] - 1
        self.model.set_history(chat[:user_idx])
        answer, steps = self.model.chat(chat[user_idx]['content'])
        output_handler.save_results(
            origin_prompt=chat[user_idx]['content'],
            prediction=answer,
            steps=steps,
            idx=index,
            gold=chat[assistant_indices[-1]]['content'],
        )
        self.model.reset()

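    # Worked example (illustrative, not in the original source): for
    # chat = [system, user, assistant], the last assistant turn is at
    # index 2, so user_idx = 1; the history is chat[:1], the query is
    # chat[1]['content'], and chat[2]['content'] serves as the gold answer.
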
    def infer_every(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        # Seed the history once, before the first queried user turn; after
        # that the agent's own replies form the running history.
        self.model.set_history(chat[:assistant_indices[0] - 1])

        for i in assistant_indices:
            answer, steps = self.model.chat(chat[i - 1]['content'])
            output_handler.save_multiround_results(
                origin_prompt=chat[i - 1]['content'],
                prediction=answer,
                steps=steps,
                idx=index,
                gold=chat[i]['content'],
            )
        self.model.reset()

    def infer_every_with_gt(self, chat: List[dict], index: int,
                            output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        for i in assistant_indices:
            # Unlike infer_every, reset the history to the ground-truth
            # transcript before each round, so errors do not accumulate.
            self.model.set_history(chat[:i - 1])
            answer, steps = self.model.chat(chat[i - 1]['content'])
            output_handler.save_multiround_results(
                origin_prompt=chat[i - 1]['content'],
                prediction=answer,
                steps=steps,
                idx=index,
                gold=chat[i]['content'],
            )
        self.model.reset()
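
# Illustrative usage sketch (not part of this file), assuming the usual
# OpenCompass dataset-config layout in which an inferencer is selected by
# type:
#
#     infer_cfg = dict(
#         ...,  # prompt_template / retriever as in other dataset configs
#         inferencer=dict(type=AgentInferencer),
#     )
#
# Which of infer_last / infer_every / infer_every_with_gt runs is decided
# by the base ChatInferencer according to its configured inference mode.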