From 0f2c3882800404f9ffa3ec7d4c0212eec98074ac Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Fri, 22 Sep 2023 15:28:22 +0800 Subject: [PATCH] Support GSM8k evaluation with tools by Lagent and LangChain (#277) * Support GSM8k evaluation with tools by Lagent and LangChain * Avoid to use MMEngine new feature * update document --------- Co-authored-by: Leymore --- configs/eval_openai_agent.py | 148 ++++++++ opencompass/models/lagent.py | 48 +++ opencompass/models/langchain.py | 53 +++ opencompass/openicl/icl_evaluator/__init__.py | 1 + .../icl_evaluator/icl_agent_evaluator.py | 332 ++++++++++++++++++ .../openicl/icl_evaluator/icl_hf_evaluator.py | 2 +- .../openicl/icl_inferencer/__init__.py | 1 + .../icl_inferencer/icl_agent_inferencer.py | 132 +++++++ opencompass/tasks/openicl_eval.py | 36 +- opencompass/tasks/openicl_infer.py | 3 +- 10 files changed, 740 insertions(+), 16 deletions(-) create mode 100644 configs/eval_openai_agent.py create mode 100644 opencompass/models/lagent.py create mode 100644 opencompass/models/langchain.py create mode 100644 opencompass/openicl/icl_evaluator/icl_agent_evaluator.py create mode 100644 opencompass/openicl/icl_inferencer/icl_agent_inferencer.py diff --git a/configs/eval_openai_agent.py b/configs/eval_openai_agent.py new file mode 100644 index 00000000..a66c0b81 --- /dev/null +++ b/configs/eval_openai_agent.py @@ -0,0 +1,148 @@ +from mmengine.config import read_base +from opencompass.partitioners import SizePartitioner +from opencompass.runners import LocalRunner +from opencompass.tasks import OpenICLInferTask +from opencompass.openicl import AgentInferencer + +with read_base(): + from .summarizers.medium import summarizer + from .datasets.gsm8k.gsm8k_gen import gsm8k_datasets as datasets + +from opencompass.models.lagent import LagentAgent +from lagent.llms import GPTAPI +from lagent.agents.react import ReAct, ReActProtocol +from lagent.actions import PythonInterpreter + +FORCE_STOP_PROMPT_EN = """You should directly give results based on history information.""" + +FEWSHOT_INSTRUCTION = """\ +You are a assistant who can utilize external tools. +{tool_description} +To use a tool, please use the following format: +``` +{thought} Think what you need to solve, do you need to use tools? +{action} the tool name, should be one of [{action_names}] +{action_input} the input to the action +``` +I will give you response after utilizing tools should using the following format: +``` +{response} the results after call the tool. +`` +If you already know the answer, or you do not need to use tools, +please using the following format to reply: +``` +{thought} the thought process to get the final answer +{finish} final answer +``` + +Examples: + +A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there? +{thought} We need to calculate the total number of fruits. The total number of fruits in the first three baskets is given, while for the fourth basket, we need to subtract 2 from each fruit category. We can solve this problem using simple arithmetic. +{action} PythonInterpreter +{action_input} +```python +def solution(): + # Fruits in the first three baskets + apples_first_three = 9 + oranges_first_three = 15 + bananas_first_three = 14 + + # Fruits in the fourth basket + apples_fourth = apples_first_three - 2 + oranges_fourth = oranges_first_three - 2 + bananas_fourth = bananas_first_three - 2 + + # Total fruits + total_fruits = ((apples_first_three + oranges_first_three + bananas_first_three) * 3 + + apples_fourth + oranges_fourth + bananas_fourth) + + return {{"total_fruits": total_fruits}} +``` +{response}{{'total_fruits': 146}} + {thought} By adding the given numbers of apples, oranges, and bananas in the first three baskets, then subtracting 2 from each category for the fourth basket, we have found the total number of fruits. +{finish} 146 + +Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles? +{thought} This is a problem that requires solving equations. We know the relationship between the number of marbles, frisbees, and deck cards. Bella has twice as many marbles as frisbees, and 20 more frisbees than deck cards. Finally, we are told Bella buys 2/5 times more of each item. This purchasing will increase the number of each type of item. +{action} PythonInterpreter +{action_input} +```python +def solution(): + # Given number of marbles + marbles_now = 60 + + # Calculate number of frisbees and deck cards now + frisbees_now = marbles_now / 2 + cards_now = frisbees_now - 20 + + # Calculate number of each item after buying more + marbles_then = marbles_now + (2/5) * marbles_now + frisbees_then = frisbees_now + (2/5) * frisbees_now + cards_then = cards_now + (2/5)*cards_now + + # Total number of items then + total_items = marbles_then + frisbees_then + cards_then + + return {{"total_items": total_items}} +``` +{response}{{'total_items': 140.0}} +{thought} By establishing the relationships between the numbers of marbles, frisbees, and deck cards that Bella currently has, we can calculate how many of each item she will have after buying 2/5 more of each. Adding these quantities together gives us the total number of items. +{finish} 140 + +Begin! +""" + +PYTHON_INTERPRETER_DESCRIPTION = '''\ +It can run a Python code. The code must be a valid code that contains only python method, and the method' name must be 'solution' and returns a dict, which key is variable name. The libraries I recommend are sympy and scipy. the format is: +```python +# import packages +import xxx +def solution(): + # initialize some variables + variable_names_with_real_meaning = xxx + # middle steps + mid_variable = func(mid_variable) + # final answer + final_answer = func(mid_variable) + return final_answer +```''' + +models = [ + dict(abbr='gpt-3.5-react', + type=LagentAgent, + agent_type=ReAct, + max_turn=3, + llm=dict( + type=GPTAPI, + model_type='gpt-3.5-turbo', + key='ENV', + query_per_second=1, + max_seq_len=4096, + ), + actions=[ + dict(type=PythonInterpreter, + description=PYTHON_INTERPRETER_DESCRIPTION), + ], + protocol=dict( + type=ReActProtocol, + call_protocol=FEWSHOT_INSTRUCTION, + force_stop=FORCE_STOP_PROMPT_EN, + finish=dict(role='FINISH', begin='Final Answer:', end='\n'), + ), + batch_size=8), +] + +for dataset in datasets: + # Use AgentInferencer instead of GenInferencer + dataset['infer_cfg']['inferencer'] = dict(type=AgentInferencer) + # Use the question as agent input directly. + dataset['infer_cfg']['prompt_template']['template'] = "{question}" + +infer = dict( + partitioner=dict(type=SizePartitioner, max_task_size=1000), + runner=dict( + type=LocalRunner, + max_num_workers=16, + task=dict(type=OpenICLInferTask)), +) diff --git a/opencompass/models/lagent.py b/opencompass/models/lagent.py new file mode 100644 index 00000000..646c6451 --- /dev/null +++ b/opencompass/models/lagent.py @@ -0,0 +1,48 @@ +from typing import List, Tuple + +from mmengine.registry import Registry + +REGISTRY = Registry('helper') + + +class LagentAgent: + """Agent wrapper for Lagent. + + https://github.com/InternLM/lagent. + """ + + def __init__(self, agent_type, llm, actions=None, protocol=None, **kwargs): + llm = REGISTRY.build(llm) + agent_cfg = {'type': agent_type, 'llm': llm, **kwargs} + + if actions is not None: + from lagent.actions import ActionExecutor + executor = ActionExecutor( + [REGISTRY.build(action) for action in actions]) + agent_cfg['action_executor'] = executor + if protocol is not None: + protocol = REGISTRY.build(protocol) + agent_cfg['protocol'] = protocol + + self.agent = REGISTRY.build(agent_cfg) + + def chat(self, user_input, ice=None) -> Tuple[str, List[dict]]: + from lagent.schema import ActionReturn, AgentReturn + generation: AgentReturn = self.agent.chat(user_input) + self.agent._session_history = [] # clear agent history + answer = generation.response + steps = [] + + for step in generation.actions: + step: ActionReturn + steps.append( + dict( + type=step.type, + args=step.args, + result=step.result, + thought=step.thought, + state=int(step.state), + errmsg=step.errmsg, + valid=int(step.valid), + )) + return answer, steps diff --git a/opencompass/models/langchain.py b/opencompass/models/langchain.py new file mode 100644 index 00000000..ab7d7fd2 --- /dev/null +++ b/opencompass/models/langchain.py @@ -0,0 +1,53 @@ +from typing import List, Tuple + +from mmengine.registry import Registry + +REGISTRY = Registry('helper') + + +class LangchainAgent: + """Agent wrapper for Langchain. + + https://github.com/langchain-ai/langchain. + """ + + def __init__(self, agent_type, llm, tools) -> None: + from langchain.agents import initialize_agent, load_tools + + llm = REGISTRY.build(llm) + tools = load_tools(tools, llm=llm) + self.agent = initialize_agent(tools, + llm, + agent=agent_type, + return_intermediate_steps=True) + + def chat(self, user_input, ice=None) -> Tuple[str, List[dict]]: + from langchain.schema import AgentAction + try: + generation = self.agent(user_input) + answer = generation['output'] + steps = [] + for step in generation['intermediate_steps']: + action: AgentAction = step[0] + steps.append( + dict( + type=action.tool, + args=action.tool_input, + result=step[1], + thought=action.log, + state=0, + errmsg=None, + )) + except Exception as e: + answer = None + steps = [ + dict( + type='InvalidAction', + args={}, + result=None, + thought=None, + state=-1002, + errmsg=str(e), + ) + ] + return answer, steps diff --git a/opencompass/openicl/icl_evaluator/__init__.py b/opencompass/openicl/icl_evaluator/__init__.py index b81dbc15..e24f8eca 100644 --- a/opencompass/openicl/icl_evaluator/__init__.py +++ b/opencompass/openicl/icl_evaluator/__init__.py @@ -1,3 +1,4 @@ +from .icl_agent_evaluator import * # noqa from .icl_aucroc_evaluator import AUCROCEvaluator # noqa from .icl_base_evaluator import BaseEvaluator # noqa from .icl_em_evaluator import EMEvaluator # noqa diff --git a/opencompass/openicl/icl_evaluator/icl_agent_evaluator.py b/opencompass/openicl/icl_evaluator/icl_agent_evaluator.py new file mode 100644 index 00000000..7b2ffb05 --- /dev/null +++ b/opencompass/openicl/icl_evaluator/icl_agent_evaluator.py @@ -0,0 +1,332 @@ +import json +import math +import random +import re +import time +from typing import List + +import numpy as np +import requests + +from opencompass.models import OpenAI + +from .icl_base_evaluator import BaseEvaluator + +DEFAULT_FAIL_WORDS = ('sorry', 'apologize', 'apology', 'unfortunately', + "couldn't") + +CHECK_SOLVE_QUERY_PROMPT = '''\ +Please check whether the answer solve the query or not. +Query: +{query} + +Answer: +{answer} + +Now give your judgment of JSON to `{func_name}`, remember do not be too strict. +''' + +SELECT_BEST_ANSWER_PROMPT = '''\ +For query {query}, you have the following answers in JSON format: +{answers} + +I want you to select the best answer from the above answers and give the index of the answer of JSON to `{func_name}`. Now select the best answer.''' # noqa: E501 + + +def extract_answer(result: dict): + """Extract answer from toolbench format.""" + final_answer = result['final_answer'] + try: + final_answer = json.loads(final_answer)['final_answer'] + except Exception: + pass + + next_step = result['answer_details'] + steps = [] + + while len(next_step) > 0: + step = next_step[-1] + next_step = step['next'] + if step['role'] == 'tool': + tool_type = re.findall(r"'name': '(.*?)'", step['message']) + error = re.findall(r"{\"error\": \"([^\"]+)", step['message']) + if len(tool_type) > 0: + tool_type = tool_type[0] + valid = 0 + else: + tool_type = None + valid = -2 + if tool_type == 'Finish': + valid = 1 + if len(error) > 0: + valid = -2 + elif step['role'] == 'assistant': + tool_type = None + valid = -2 + else: + continue + steps.append( + dict( + type=tool_type, + args=None, + result=None, + thought=None, + state=0, + valid=valid, + )) + return final_answer, steps + + +class PassRateEvaluator(BaseEvaluator): + """This Evaluator can determine whether pred refuses to execute the + task.""" + + def __init__(self, fail_words=DEFAULT_FAIL_WORDS) -> None: + super().__init__() + self.fail_words = fail_words + + def score(self, predictions: List, references: List = None) -> dict: + results = [] + for pred in predictions: + if pred and self.check_real_valid(pred): + results.append(1) + else: + results.append(0) + pass_rate = sum(results) / len(results) * 100 + return dict(pass_rate=pass_rate) + + def check_real_valid(self, answer): + """Exclude response without real answer.""" + return not any(word in answer.lower() for word in self.fail_words) + + +class WinRateEvaluator(BaseEvaluator): + # https://github.com/OpenBMB/ToolBench/blob/e18a30ed8f9afc131a7e313d0522c4371f030f31/toolbench/tooleval/evaluators/registered_cls/tooleval.py#L50 + """Follow `OpenAINormalizedEvaluator` in the `ToolBench`. + + The Evaluator will compare which call-tool process between `pred` and + `reference` is better. + + 1. Compare whether an answer can be extracted. The one that can extract an + answer wins. + 2. If both can, then compare whether the answer is correct. The correct one + wins. + 3. If both answers are correct, then compare the number of tool calls; the + one with fewer calls wins. If the number of steps is the same, the one + with the better-looking final answer wins. + 4. If both answers are incorrect, then consider factors such as whether the + tool was successfully called and the variety of tools used. + """ + + def __init__(self, + model='gpt-3.5-turbo-16k', + temperature=0, + **kwargs) -> None: + super().__init__() + self.openai = OpenAI(path=model, temperature=temperature, **kwargs) + + def score(self, predictions: List, references: List, origin_prompt: List, + steps: List): + compare_list = [] + for query, ref, pred_answer, pred_steps in zip(origin_prompt, + references, predictions, + steps): + ref_answer, ref_steps = extract_answer(ref) + + if bool(pred_answer) ^ bool(ref_answer): + # Empty vs non-empty + win = int(bool(pred_answer)) + else: + pred_valid = bool(pred_answer) and self.check_solve_query( + query, pred_answer) + ref_valid = bool(ref_answer) and self.check_solve_query( + query, ref_answer) + + if pred_valid and ref_valid: + # both answer success + if len(pred_steps) != len(ref_steps): + win = 1 if len(pred_steps) < len(ref_steps) else 0 + else: + win = self.select_best_final_answer( + query, [ref_answer, pred_answer]) + elif not pred_valid and not ref_valid: + # both answer failed + win = self.compare_steps([ref_steps, pred_steps]) + else: + win = int(pred_valid) + + compare_list.append(win) + + pred_answer = pred_answer.replace('\n', '') + ref_answer = ref_answer.replace('\n', '') + return {'win_rate': sum(compare_list) / len(compare_list) * 100.} + + def check_solve_query(self, query: str, answer: str) -> bool: + """Check whether the answer solved the query.""" + func_name = 'check_solve_query' + return_key = 'is_solved' + + prompt = CHECK_SOLVE_QUERY_PROMPT.format(query=query, + answer=answer, + func_name=func_name) + + function = dict( + name=func_name, + description=('Check whether the given answer solve the given ' + 'query, return true or false'), + parameters={ + 'type': 'object', + 'properties': { + return_key: { + 'type': 'boolean', + 'description': 'true if solved and false if not' + } + }, + 'required': [return_key] + }) + + result = self._openai_function( + prompt, + max_out_len=100, + functions=[function], + function_call={'name': function['name']}, + ) + return bool(result[return_key]) + + def select_best_final_answer(self, query: str, answers: list) -> int: + """Select the best final answer from candidates.""" + func_name = 'select_best_final_answer' + return_key = 'best_answer_index' + + is_reversed = random.random() > 0.5 + if is_reversed: + answers = list(reversed(answers)) + prompt = SELECT_BEST_ANSWER_PROMPT.format(query=query, + answers=answers, + func_name=func_name) + + function = dict( + name=func_name, + description=('For given query, select the best answer in answers ' + 'list and return the index of the best answer'), + parameters={ + 'type': 'object', + 'properties': { + return_key: { + 'type': + 'number', + 'description': ('The index of the best answer in the ' + 'answer list, start from 0') + } + }, + 'required': [return_key] + }) + + result = self._openai_function( + prompt, + max_out_len=100, + functions=[function], + function_call={'name': function['name']}, + ) + if not is_reversed: + return int(result[return_key]) + else: + return len(answers) - int(result[return_key]) - 1 + + def compare_steps(self, steps_list: list) -> int: + """Compare results according to score when both answers are failed.""" + # calculate socre and return one with highest score + scores = [] + for steps in steps_list: + succeed_tool_calling = sum(step['valid'] == 0 for step in steps) + used_tool_types = len(set(step['type'] for step in steps)) + score = succeed_tool_calling * 10 + used_tool_types * 5 + if len(steps) <= 0: + score -= int(1e5) + else: + score += -5 * math.log(len(steps)) + scores.append(score) + + # return index of highest score + scores = np.array(scores) + highest_idx = np.where(scores == scores.max())[0].tolist() + return random.choice(highest_idx) + + def _openai_function(self, msg: str, max_out_len: int, functions: dict, + function_call: dict, **kwargs) -> dict: + openai = self.openai + + messages = [{'role': 'user', 'content': msg}] + + max_num_retries = 0 + while max_num_retries < openai.retry: + openai.wait() + + if len(openai.invalid_keys) == len(openai.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + openai.key_ctr += 1 + if openai.key_ctr == len(openai.keys): + openai.key_ctr = 0 + + if openai.keys[openai.key_ctr] not in openai.invalid_keys: + break + + key = openai.keys[openai.key_ctr] + + header = { + 'Authorization': f'Bearer {key}', + 'content-type': 'application/json', + } + + if openai.orgs: + openai.org_ctr += 1 + if openai.org_ctr == len(openai.orgs): + openai.org_ctr = 0 + header['OpenAI-Organization'] = openai.orgs[openai.org_ctr] + + try: + data = dict(model=openai.path, + messages=messages, + max_tokens=max_out_len, + n=1, + stop=None, + temperature=openai.temperature, + functions=functions, + function_call=function_call, + **kwargs) + raw_response = requests.post(openai.url, + headers=header, + data=json.dumps(data)) + except requests.ConnectionError: + openai.logger.error('Got connection error, retrying...') + continue + try: + response = raw_response.json() + except requests.JSONDecodeError: + openai.logger.error('JsonDecode error, got', + str(raw_response.content)) + continue + try: + result = response['choices'][0]['message']['function_call'][ + 'arguments'] + return json.loads(result) + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + openai.invalid_keys.add(key) + openai.logger.warn(f'insufficient_quota key: {key}') + continue + + openai.logger.error('Find error message in response: ', + str(response['error'])) + max_num_retries += 1 + + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + 'details.') diff --git a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py index 92e6797f..004480f7 100644 --- a/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_hf_evaluator.py @@ -74,7 +74,7 @@ class HuggingfaceEvaluator(BaseEvaluator): f'len(references): {len(references)}' } # use codes pre-downloaded to opencompass repo, avoid downloading - local_path = os.path.join(os.dirname(os.path.abspath(__file__)), + local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_metrics', self.metric + '.py') if os.path.exists(local_path): metric = evaluate.load(local_path) diff --git a/opencompass/openicl/icl_inferencer/__init__.py b/opencompass/openicl/icl_inferencer/__init__.py index 89e2e443..29171d60 100644 --- a/opencompass/openicl/icl_inferencer/__init__.py +++ b/opencompass/openicl/icl_inferencer/__init__.py @@ -1,3 +1,4 @@ +from .icl_agent_inferencer import AgentInferencer # noqa from .icl_attack_inferencer import AttackInferencer # noqa from .icl_base_inferencer import BaseInferencer # noqa from .icl_clp_inferencer import CLPInferencer # noqa diff --git a/opencompass/openicl/icl_inferencer/icl_agent_inferencer.py b/opencompass/openicl/icl_inferencer/icl_agent_inferencer.py new file mode 100644 index 00000000..781ce9dc --- /dev/null +++ b/opencompass/openicl/icl_inferencer/icl_agent_inferencer.py @@ -0,0 +1,132 @@ +"""Agent Inferencer.""" +import os +import os.path as osp +from typing import List, Optional + +import mmengine +from mmengine.registry import Registry +from tqdm import tqdm + +from opencompass.registry import ICL_INFERENCERS + +from ..icl_prompt_template import PromptTemplate +from ..icl_retriever import BaseRetriever +from ..utils.logging import get_logger +from .icl_base_inferencer import BaseInferencer, dump_results_dict + +logger = get_logger(__name__) +REGISTRY = Registry('helper') + + +@ICL_INFERENCERS.register_module() +class AgentInferencer(BaseInferencer): + + def __init__( + self, + model, + output_json_filepath: Optional[str] = './icl_inference_output', + output_json_filename: Optional[str] = 'predictions', + save_every: Optional[int] = 1, + **kwargs) -> None: + super().__init__( + model=model, + output_json_filename=output_json_filename, + output_json_filepath=output_json_filepath, + **kwargs, + ) + self.save_every = save_every + + @property + def agent(self): + return self.model + + def inference(self, + retriever: BaseRetriever, + ice_template: Optional[PromptTemplate] = None, + prompt_template: Optional[PromptTemplate] = None, + output_json_filepath: Optional[str] = None, + output_json_filename: Optional[str] = None) -> List: + # 1. Preparation for output logs + output_handler = AgentInferencerOutputHandler() + + if output_json_filepath is None: + output_json_filepath = self.output_json_filepath + if output_json_filename is None: + output_json_filename = self.output_json_filename + + # 2. Get results of retrieval process + if 'Fix' in retriever.__class__.__name__: + ice_idx_list = retriever.retrieve(self.fix_id_list) + else: + ice_idx_list = retriever.retrieve() + + # Create tmp json file for saving intermediate results and future + # resuming + start = 0 + tmp_json_filepath = os.path.join(output_json_filepath, + 'tmp_' + output_json_filename) + if osp.exists(tmp_json_filepath): + # TODO: move resume to output handler + tmp_result_dict = mmengine.load(tmp_json_filepath) + output_handler.results_dict = tmp_result_dict + start = len(tmp_result_dict) + + # 3. Inference sample by sample + logger.info('Starting inference process...') + for idx, ice_indices in tqdm(enumerate(ice_idx_list[start:], start), + disable=not self.is_main_process): + user_input = retriever.generate_prompt_for_generate_task( + idx, ice='', prompt_template=prompt_template) + gold = retriever.dataset_reader.dataset['test'][ + retriever.dataset_reader.output_column][idx] + + if len(ice_indices) > 0: + assert ice_template is not None + ice = [ + ice_template.generate_ice_item(ice_idx) + for ice_idx in ice_indices + ] + else: + ice = None + + answer, steps = self.agent.chat(user_input=user_input, ice=ice) + + # Save current output + output_handler.save_results(user_input, answer, steps, idx, gold) + + # Save intermediate results + if (self.save_every is not None and start % self.save_every == 0 + and self.is_main_process): + output_handler.write_to_json(output_json_filepath, + 'tmp_' + output_json_filename) + + # 4. Output + if self.is_main_process: + os.makedirs(output_json_filepath, exist_ok=True) + output_handler.write_to_json(output_json_filepath, + output_json_filename) + if osp.exists(tmp_json_filepath): + os.remove(tmp_json_filepath) + + return [ + sample['prediction'] + for sample in output_handler.results_dict.values() + ] + + +class AgentInferencerOutputHandler: + + def __init__(self) -> None: + self.results_dict = {} + + def write_to_json(self, save_dir: str, filename: str): + """Dump the result to a json file.""" + dump_results_dict(self.results_dict, osp.join(save_dir, filename)) + + def save_results(self, user_input, answer, steps, idx, gold): + self.results_dict[str(idx)] = { + 'origin_prompt': user_input, + 'prediction': answer, + 'steps': steps, + 'gold': gold, + } diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index 1db2d84b..807a75d0 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -3,6 +3,7 @@ import fnmatch import os.path as osp import time from collections import Counter +from inspect import signature from typing import Optional import mmengine @@ -71,8 +72,9 @@ class OpenICLEvalTask(BaseTask): test_set = build_dataset_from_cfg(self.dataset_cfg).test # Postprocess dataset if necessary if 'dataset_postprocessor' in self.eval_cfg: - proc = TEXT_POSTPROCESSORS.get( - self.eval_cfg['dataset_postprocessor']['type']) + proc = self.eval_cfg['dataset_postprocessor']['type'] + if isinstance(proc, str): + proc = TEXT_POSTPROCESSORS.get(proc) def postprocess(sample): s = sample[self.output_column] @@ -98,20 +100,21 @@ class OpenICLEvalTask(BaseTask): else: if osp.exists(osp.realpath(filename)): preds = mmengine.load(filename) - pred_strs = [ - preds[str(i)]['prediction'] for i in range(len(preds)) - ] + preds = [preds[str(i)] for i in range(len(preds))] else: filename = partial_filename - pred_strs = [] + preds = [] i = 1 while osp.exists(osp.realpath(filename)): - preds = mmengine.load(filename) + sub_preds = mmengine.load(filename) + preds.extend( + [sub_preds[str(i)] for i in range(len(sub_preds))]) filename = root + f'_{i}' + ext i += 1 - pred_strs += [ - preds[str(i)]['prediction'] for i in range(len(preds)) - ] + + preds = {k: [pred[k] for pred in preds] for k in preds[0]} + + pred_strs = preds.pop('prediction') if ('pred_role' in self.eval_cfg and 'meta_template' in self.model_cfg @@ -142,7 +145,9 @@ class OpenICLEvalTask(BaseTask): # Postprocess predictions if necessary if 'pred_postprocessor' in self.eval_cfg: kwargs = self.eval_cfg['pred_postprocessor'] - proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type')) + proc = kwargs.pop('type') + if isinstance(proc, str): + proc = TEXT_POSTPROCESSORS.get(proc) if sc_size is not None: pred_strs = [[proc(s, **kwargs) for s in preds] for preds in pred_strs] @@ -156,8 +161,13 @@ class OpenICLEvalTask(BaseTask): ] icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator']) - result = icl_evaluator.score( - predictions=pred_strs, references=test_set[self.output_column]) + preds['predictions'] = pred_strs + preds['references'] = test_set[self.output_column] + preds = { + k: preds[k] + for k in signature(icl_evaluator.score).parameters + } + result = icl_evaluator.score(**preds) if 'error' in result: self.logger.error( diff --git a/opencompass/tasks/openicl_infer.py b/opencompass/tasks/openicl_infer.py index 3e98d939..195a0bd9 100644 --- a/opencompass/tasks/openicl_infer.py +++ b/opencompass/tasks/openicl_infer.py @@ -99,7 +99,7 @@ class OpenICLInferTask(BaseTask): self._set_default_value(inferencer_cfg, 'max_out_len', self.max_out_len) self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size) - inferencer_cfg['max_seq_len'] = self.model_cfg['max_seq_len'] + inferencer_cfg['max_seq_len'] = self.model_cfg.get('max_seq_len') inferencer = ICL_INFERENCERS.build(inferencer_cfg) out_path = get_infer_output_path( @@ -128,7 +128,6 @@ class OpenICLInferTask(BaseTask): def _set_default_value(self, cfg: ConfigDict, key: str, value: Any): if key not in cfg: - assert value, (f'{key} must be specified!') cfg[key] = value