Support GSM8k evaluation with tools by Lagent and LangChain (#277)

* Support GSM8k evaluation with tools by Lagent and LangChain

* Avoid to use MMEngine new feature

* update document

---------

Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
Ma Zerun 2023-09-22 15:28:22 +08:00 committed by GitHub
parent 681d3013de
commit 0f2c388280
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 740 additions and 16 deletions

View File

@ -0,0 +1,148 @@
from mmengine.config import read_base
from opencompass.partitioners import SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.openicl import AgentInferencer
with read_base():
from .summarizers.medium import summarizer
from .datasets.gsm8k.gsm8k_gen import gsm8k_datasets as datasets
from opencompass.models.lagent import LagentAgent
from lagent.llms import GPTAPI
from lagent.agents.react import ReAct, ReActProtocol
from lagent.actions import PythonInterpreter
FORCE_STOP_PROMPT_EN = """You should directly give results based on history information."""
FEWSHOT_INSTRUCTION = """\
You are a assistant who can utilize external tools.
{tool_description}
To use a tool, please use the following format:
```
{thought} Think what you need to solve, do you need to use tools?
{action} the tool name, should be one of [{action_names}]
{action_input} the input to the action
```
I will give you response after utilizing tools should using the following format:
```
{response} the results after call the tool.
``
If you already know the answer, or you do not need to use tools,
please using the following format to reply:
```
{thought} the thought process to get the final answer
{finish} final answer
```
Examples:
<HUMAN>A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?
<ASSISTANT>{thought} We need to calculate the total number of fruits. The total number of fruits in the first three baskets is given, while for the fourth basket, we need to subtract 2 from each fruit category. We can solve this problem using simple arithmetic.
{action} PythonInterpreter
{action_input}
```python
def solution():
# Fruits in the first three baskets
apples_first_three = 9
oranges_first_three = 15
bananas_first_three = 14
# Fruits in the fourth basket
apples_fourth = apples_first_three - 2
oranges_fourth = oranges_first_three - 2
bananas_fourth = bananas_first_three - 2
# Total fruits
total_fruits = ((apples_first_three + oranges_first_three + bananas_first_three) * 3 +
apples_fourth + oranges_fourth + bananas_fourth)
return {{"total_fruits": total_fruits}}
```
<SYSTEM>{response}{{'total_fruits': 146}}
<ASSISTANT> {thought} By adding the given numbers of apples, oranges, and bananas in the first three baskets, then subtracting 2 from each category for the fourth basket, we have found the total number of fruits.
{finish} 146
<HUMAN>Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?
<ASSISTANT>{thought} This is a problem that requires solving equations. We know the relationship between the number of marbles, frisbees, and deck cards. Bella has twice as many marbles as frisbees, and 20 more frisbees than deck cards. Finally, we are told Bella buys 2/5 times more of each item. This purchasing will increase the number of each type of item.
{action} PythonInterpreter
{action_input}
```python
def solution():
# Given number of marbles
marbles_now = 60
# Calculate number of frisbees and deck cards now
frisbees_now = marbles_now / 2
cards_now = frisbees_now - 20
# Calculate number of each item after buying more
marbles_then = marbles_now + (2/5) * marbles_now
frisbees_then = frisbees_now + (2/5) * frisbees_now
cards_then = cards_now + (2/5)*cards_now
# Total number of items then
total_items = marbles_then + frisbees_then + cards_then
return {{"total_items": total_items}}
```
<SYSTEM>{response}{{'total_items': 140.0}}
<ASSISTANT>{thought} By establishing the relationships between the numbers of marbles, frisbees, and deck cards that Bella currently has, we can calculate how many of each item she will have after buying 2/5 more of each. Adding these quantities together gives us the total number of items.
{finish} 140
Begin!
"""
PYTHON_INTERPRETER_DESCRIPTION = '''\
It can run a Python code. The code must be a valid code that contains only python method, and the method' name must be 'solution' and returns a dict, which key is variable name. The libraries I recommend are sympy and scipy. the format is:
```python
# import packages
import xxx
def solution():
# initialize some variables
variable_names_with_real_meaning = xxx
# middle steps
mid_variable = func(mid_variable)
# final answer
final_answer = func(mid_variable)
return final_answer
```'''
models = [
dict(abbr='gpt-3.5-react',
type=LagentAgent,
agent_type=ReAct,
max_turn=3,
llm=dict(
type=GPTAPI,
model_type='gpt-3.5-turbo',
key='ENV',
query_per_second=1,
max_seq_len=4096,
),
actions=[
dict(type=PythonInterpreter,
description=PYTHON_INTERPRETER_DESCRIPTION),
],
protocol=dict(
type=ReActProtocol,
call_protocol=FEWSHOT_INSTRUCTION,
force_stop=FORCE_STOP_PROMPT_EN,
finish=dict(role='FINISH', begin='Final Answer:', end='\n'),
),
batch_size=8),
]
for dataset in datasets:
# Use AgentInferencer instead of GenInferencer
dataset['infer_cfg']['inferencer'] = dict(type=AgentInferencer)
# Use the question as agent input directly.
dataset['infer_cfg']['prompt_template']['template'] = "{question}"
infer = dict(
partitioner=dict(type=SizePartitioner, max_task_size=1000),
runner=dict(
type=LocalRunner,
max_num_workers=16,
task=dict(type=OpenICLInferTask)),
)

View File

@ -0,0 +1,48 @@
from typing import List, Tuple
from mmengine.registry import Registry
REGISTRY = Registry('helper')
class LagentAgent:
"""Agent wrapper for Lagent.
https://github.com/InternLM/lagent.
"""
def __init__(self, agent_type, llm, actions=None, protocol=None, **kwargs):
llm = REGISTRY.build(llm)
agent_cfg = {'type': agent_type, 'llm': llm, **kwargs}
if actions is not None:
from lagent.actions import ActionExecutor
executor = ActionExecutor(
[REGISTRY.build(action) for action in actions])
agent_cfg['action_executor'] = executor
if protocol is not None:
protocol = REGISTRY.build(protocol)
agent_cfg['protocol'] = protocol
self.agent = REGISTRY.build(agent_cfg)
def chat(self, user_input, ice=None) -> Tuple[str, List[dict]]:
from lagent.schema import ActionReturn, AgentReturn
generation: AgentReturn = self.agent.chat(user_input)
self.agent._session_history = [] # clear agent history
answer = generation.response
steps = []
for step in generation.actions:
step: ActionReturn
steps.append(
dict(
type=step.type,
args=step.args,
result=step.result,
thought=step.thought,
state=int(step.state),
errmsg=step.errmsg,
valid=int(step.valid),
))
return answer, steps

View File

@ -0,0 +1,53 @@
from typing import List, Tuple
from mmengine.registry import Registry
REGISTRY = Registry('helper')
class LangchainAgent:
"""Agent wrapper for Langchain.
https://github.com/langchain-ai/langchain.
"""
def __init__(self, agent_type, llm, tools) -> None:
from langchain.agents import initialize_agent, load_tools
llm = REGISTRY.build(llm)
tools = load_tools(tools, llm=llm)
self.agent = initialize_agent(tools,
llm,
agent=agent_type,
return_intermediate_steps=True)
def chat(self, user_input, ice=None) -> Tuple[str, List[dict]]:
from langchain.schema import AgentAction
try:
generation = self.agent(user_input)
answer = generation['output']
steps = []
for step in generation['intermediate_steps']:
action: AgentAction = step[0]
steps.append(
dict(
type=action.tool,
args=action.tool_input,
result=step[1],
thought=action.log,
state=0,
errmsg=None,
))
except Exception as e:
answer = None
steps = [
dict(
type='InvalidAction',
args={},
result=None,
thought=None,
state=-1002,
errmsg=str(e),
)
]
return answer, steps

View File

@ -1,3 +1,4 @@
from .icl_agent_evaluator import * # noqa
from .icl_aucroc_evaluator import AUCROCEvaluator # noqa from .icl_aucroc_evaluator import AUCROCEvaluator # noqa
from .icl_base_evaluator import BaseEvaluator # noqa from .icl_base_evaluator import BaseEvaluator # noqa
from .icl_em_evaluator import EMEvaluator # noqa from .icl_em_evaluator import EMEvaluator # noqa

View File

@ -0,0 +1,332 @@
import json
import math
import random
import re
import time
from typing import List
import numpy as np
import requests
from opencompass.models import OpenAI
from .icl_base_evaluator import BaseEvaluator
DEFAULT_FAIL_WORDS = ('sorry', 'apologize', 'apology', 'unfortunately',
"couldn't")
CHECK_SOLVE_QUERY_PROMPT = '''\
Please check whether the answer solve the query or not.
Query:
{query}
Answer:
{answer}
Now give your judgment of JSON to `{func_name}`, remember do not be too strict.
'''
SELECT_BEST_ANSWER_PROMPT = '''\
For query {query}, you have the following answers in JSON format:
{answers}
I want you to select the best answer from the above answers and give the index of the answer of JSON to `{func_name}`. Now select the best answer.''' # noqa: E501
def extract_answer(result: dict):
"""Extract answer from toolbench format."""
final_answer = result['final_answer']
try:
final_answer = json.loads(final_answer)['final_answer']
except Exception:
pass
next_step = result['answer_details']
steps = []
while len(next_step) > 0:
step = next_step[-1]
next_step = step['next']
if step['role'] == 'tool':
tool_type = re.findall(r"'name': '(.*?)'", step['message'])
error = re.findall(r"{\"error\": \"([^\"]+)", step['message'])
if len(tool_type) > 0:
tool_type = tool_type[0]
valid = 0
else:
tool_type = None
valid = -2
if tool_type == 'Finish':
valid = 1
if len(error) > 0:
valid = -2
elif step['role'] == 'assistant':
tool_type = None
valid = -2
else:
continue
steps.append(
dict(
type=tool_type,
args=None,
result=None,
thought=None,
state=0,
valid=valid,
))
return final_answer, steps
class PassRateEvaluator(BaseEvaluator):
"""This Evaluator can determine whether pred refuses to execute the
task."""
def __init__(self, fail_words=DEFAULT_FAIL_WORDS) -> None:
super().__init__()
self.fail_words = fail_words
def score(self, predictions: List, references: List = None) -> dict:
results = []
for pred in predictions:
if pred and self.check_real_valid(pred):
results.append(1)
else:
results.append(0)
pass_rate = sum(results) / len(results) * 100
return dict(pass_rate=pass_rate)
def check_real_valid(self, answer):
"""Exclude response without real answer."""
return not any(word in answer.lower() for word in self.fail_words)
class WinRateEvaluator(BaseEvaluator):
# https://github.com/OpenBMB/ToolBench/blob/e18a30ed8f9afc131a7e313d0522c4371f030f31/toolbench/tooleval/evaluators/registered_cls/tooleval.py#L50
"""Follow `OpenAINormalizedEvaluator` in the `ToolBench`.
The Evaluator will compare which call-tool process between `pred` and
`reference` is better.
1. Compare whether an answer can be extracted. The one that can extract an
answer wins.
2. If both can, then compare whether the answer is correct. The correct one
wins.
3. If both answers are correct, then compare the number of tool calls; the
one with fewer calls wins. If the number of steps is the same, the one
with the better-looking final answer wins.
4. If both answers are incorrect, then consider factors such as whether the
tool was successfully called and the variety of tools used.
"""
def __init__(self,
model='gpt-3.5-turbo-16k',
temperature=0,
**kwargs) -> None:
super().__init__()
self.openai = OpenAI(path=model, temperature=temperature, **kwargs)
def score(self, predictions: List, references: List, origin_prompt: List,
steps: List):
compare_list = []
for query, ref, pred_answer, pred_steps in zip(origin_prompt,
references, predictions,
steps):
ref_answer, ref_steps = extract_answer(ref)
if bool(pred_answer) ^ bool(ref_answer):
# Empty vs non-empty
win = int(bool(pred_answer))
else:
pred_valid = bool(pred_answer) and self.check_solve_query(
query, pred_answer)
ref_valid = bool(ref_answer) and self.check_solve_query(
query, ref_answer)
if pred_valid and ref_valid:
# both answer success
if len(pred_steps) != len(ref_steps):
win = 1 if len(pred_steps) < len(ref_steps) else 0
else:
win = self.select_best_final_answer(
query, [ref_answer, pred_answer])
elif not pred_valid and not ref_valid:
# both answer failed
win = self.compare_steps([ref_steps, pred_steps])
else:
win = int(pred_valid)
compare_list.append(win)
pred_answer = pred_answer.replace('\n', '')
ref_answer = ref_answer.replace('\n', '')
return {'win_rate': sum(compare_list) / len(compare_list) * 100.}
def check_solve_query(self, query: str, answer: str) -> bool:
"""Check whether the answer solved the query."""
func_name = 'check_solve_query'
return_key = 'is_solved'
prompt = CHECK_SOLVE_QUERY_PROMPT.format(query=query,
answer=answer,
func_name=func_name)
function = dict(
name=func_name,
description=('Check whether the given answer solve the given '
'query, return true or false'),
parameters={
'type': 'object',
'properties': {
return_key: {
'type': 'boolean',
'description': 'true if solved and false if not'
}
},
'required': [return_key]
})
result = self._openai_function(
prompt,
max_out_len=100,
functions=[function],
function_call={'name': function['name']},
)
return bool(result[return_key])
def select_best_final_answer(self, query: str, answers: list) -> int:
"""Select the best final answer from candidates."""
func_name = 'select_best_final_answer'
return_key = 'best_answer_index'
is_reversed = random.random() > 0.5
if is_reversed:
answers = list(reversed(answers))
prompt = SELECT_BEST_ANSWER_PROMPT.format(query=query,
answers=answers,
func_name=func_name)
function = dict(
name=func_name,
description=('For given query, select the best answer in answers '
'list and return the index of the best answer'),
parameters={
'type': 'object',
'properties': {
return_key: {
'type':
'number',
'description': ('The index of the best answer in the '
'answer list, start from 0')
}
},
'required': [return_key]
})
result = self._openai_function(
prompt,
max_out_len=100,
functions=[function],
function_call={'name': function['name']},
)
if not is_reversed:
return int(result[return_key])
else:
return len(answers) - int(result[return_key]) - 1
def compare_steps(self, steps_list: list) -> int:
"""Compare results according to score when both answers are failed."""
# calculate socre and return one with highest score
scores = []
for steps in steps_list:
succeed_tool_calling = sum(step['valid'] == 0 for step in steps)
used_tool_types = len(set(step['type'] for step in steps))
score = succeed_tool_calling * 10 + used_tool_types * 5
if len(steps) <= 0:
score -= int(1e5)
else:
score += -5 * math.log(len(steps))
scores.append(score)
# return index of highest score
scores = np.array(scores)
highest_idx = np.where(scores == scores.max())[0].tolist()
return random.choice(highest_idx)
def _openai_function(self, msg: str, max_out_len: int, functions: dict,
function_call: dict, **kwargs) -> dict:
openai = self.openai
messages = [{'role': 'user', 'content': msg}]
max_num_retries = 0
while max_num_retries < openai.retry:
openai.wait()
if len(openai.invalid_keys) == len(openai.keys):
raise RuntimeError('All keys have insufficient quota.')
# find the next valid key
while True:
openai.key_ctr += 1
if openai.key_ctr == len(openai.keys):
openai.key_ctr = 0
if openai.keys[openai.key_ctr] not in openai.invalid_keys:
break
key = openai.keys[openai.key_ctr]
header = {
'Authorization': f'Bearer {key}',
'content-type': 'application/json',
}
if openai.orgs:
openai.org_ctr += 1
if openai.org_ctr == len(openai.orgs):
openai.org_ctr = 0
header['OpenAI-Organization'] = openai.orgs[openai.org_ctr]
try:
data = dict(model=openai.path,
messages=messages,
max_tokens=max_out_len,
n=1,
stop=None,
temperature=openai.temperature,
functions=functions,
function_call=function_call,
**kwargs)
raw_response = requests.post(openai.url,
headers=header,
data=json.dumps(data))
except requests.ConnectionError:
openai.logger.error('Got connection error, retrying...')
continue
try:
response = raw_response.json()
except requests.JSONDecodeError:
openai.logger.error('JsonDecode error, got',
str(raw_response.content))
continue
try:
result = response['choices'][0]['message']['function_call'][
'arguments']
return json.loads(result)
except KeyError:
if 'error' in response:
if response['error']['code'] == 'rate_limit_exceeded':
time.sleep(1)
continue
elif response['error']['code'] == 'insufficient_quota':
openai.invalid_keys.add(key)
openai.logger.warn(f'insufficient_quota key: {key}')
continue
openai.logger.error('Find error message in response: ',
str(response['error']))
max_num_retries += 1
raise RuntimeError('Calling OpenAI failed after retrying for '
f'{max_num_retries} times. Check the logs for '
'details.')

View File

@ -74,7 +74,7 @@ class HuggingfaceEvaluator(BaseEvaluator):
f'len(references): {len(references)}' f'len(references): {len(references)}'
} }
# use codes pre-downloaded to opencompass repo, avoid downloading # use codes pre-downloaded to opencompass repo, avoid downloading
local_path = os.path.join(os.dirname(os.path.abspath(__file__)), local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'hf_metrics', self.metric + '.py') 'hf_metrics', self.metric + '.py')
if os.path.exists(local_path): if os.path.exists(local_path):
metric = evaluate.load(local_path) metric = evaluate.load(local_path)

View File

@ -1,3 +1,4 @@
from .icl_agent_inferencer import AgentInferencer # noqa
from .icl_attack_inferencer import AttackInferencer # noqa from .icl_attack_inferencer import AttackInferencer # noqa
from .icl_base_inferencer import BaseInferencer # noqa from .icl_base_inferencer import BaseInferencer # noqa
from .icl_clp_inferencer import CLPInferencer # noqa from .icl_clp_inferencer import CLPInferencer # noqa

View File

@ -0,0 +1,132 @@
"""Agent Inferencer."""
import os
import os.path as osp
from typing import List, Optional
import mmengine
from mmengine.registry import Registry
from tqdm import tqdm
from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict
logger = get_logger(__name__)
REGISTRY = Registry('helper')
@ICL_INFERENCERS.register_module()
class AgentInferencer(BaseInferencer):
def __init__(
self,
model,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = 1,
**kwargs) -> None:
super().__init__(
model=model,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.save_every = save_every
@property
def agent(self):
return self.model
def inference(self,
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = AgentInferencerOutputHandler()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
if 'Fix' in retriever.__class__.__name__:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
# Create tmp json file for saving intermediate results and future
# resuming
start = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
tmp_result_dict = mmengine.load(tmp_json_filepath)
output_handler.results_dict = tmp_result_dict
start = len(tmp_result_dict)
# 3. Inference sample by sample
logger.info('Starting inference process...')
for idx, ice_indices in tqdm(enumerate(ice_idx_list[start:], start),
disable=not self.is_main_process):
user_input = retriever.generate_prompt_for_generate_task(
idx, ice='', prompt_template=prompt_template)
gold = retriever.dataset_reader.dataset['test'][
retriever.dataset_reader.output_column][idx]
if len(ice_indices) > 0:
assert ice_template is not None
ice = [
ice_template.generate_ice_item(ice_idx)
for ice_idx in ice_indices
]
else:
ice = None
answer, steps = self.agent.chat(user_input=user_input, ice=ice)
# Save current output
output_handler.save_results(user_input, answer, steps, idx, gold)
# Save intermediate results
if (self.save_every is not None and start % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 4. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return [
sample['prediction']
for sample in output_handler.results_dict.values()
]
class AgentInferencerOutputHandler:
def __init__(self) -> None:
self.results_dict = {}
def write_to_json(self, save_dir: str, filename: str):
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, osp.join(save_dir, filename))
def save_results(self, user_input, answer, steps, idx, gold):
self.results_dict[str(idx)] = {
'origin_prompt': user_input,
'prediction': answer,
'steps': steps,
'gold': gold,
}

View File

@ -3,6 +3,7 @@ import fnmatch
import os.path as osp import os.path as osp
import time import time
from collections import Counter from collections import Counter
from inspect import signature
from typing import Optional from typing import Optional
import mmengine import mmengine
@ -71,8 +72,9 @@ class OpenICLEvalTask(BaseTask):
test_set = build_dataset_from_cfg(self.dataset_cfg).test test_set = build_dataset_from_cfg(self.dataset_cfg).test
# Postprocess dataset if necessary # Postprocess dataset if necessary
if 'dataset_postprocessor' in self.eval_cfg: if 'dataset_postprocessor' in self.eval_cfg:
proc = TEXT_POSTPROCESSORS.get( proc = self.eval_cfg['dataset_postprocessor']['type']
self.eval_cfg['dataset_postprocessor']['type']) if isinstance(proc, str):
proc = TEXT_POSTPROCESSORS.get(proc)
def postprocess(sample): def postprocess(sample):
s = sample[self.output_column] s = sample[self.output_column]
@ -98,20 +100,21 @@ class OpenICLEvalTask(BaseTask):
else: else:
if osp.exists(osp.realpath(filename)): if osp.exists(osp.realpath(filename)):
preds = mmengine.load(filename) preds = mmengine.load(filename)
pred_strs = [ preds = [preds[str(i)] for i in range(len(preds))]
preds[str(i)]['prediction'] for i in range(len(preds))
]
else: else:
filename = partial_filename filename = partial_filename
pred_strs = [] preds = []
i = 1 i = 1
while osp.exists(osp.realpath(filename)): while osp.exists(osp.realpath(filename)):
preds = mmengine.load(filename) sub_preds = mmengine.load(filename)
preds.extend(
[sub_preds[str(i)] for i in range(len(sub_preds))])
filename = root + f'_{i}' + ext filename = root + f'_{i}' + ext
i += 1 i += 1
pred_strs += [
preds[str(i)]['prediction'] for i in range(len(preds)) preds = {k: [pred[k] for pred in preds] for k in preds[0]}
]
pred_strs = preds.pop('prediction')
if ('pred_role' in self.eval_cfg if ('pred_role' in self.eval_cfg
and 'meta_template' in self.model_cfg and 'meta_template' in self.model_cfg
@ -142,7 +145,9 @@ class OpenICLEvalTask(BaseTask):
# Postprocess predictions if necessary # Postprocess predictions if necessary
if 'pred_postprocessor' in self.eval_cfg: if 'pred_postprocessor' in self.eval_cfg:
kwargs = self.eval_cfg['pred_postprocessor'] kwargs = self.eval_cfg['pred_postprocessor']
proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type')) proc = kwargs.pop('type')
if isinstance(proc, str):
proc = TEXT_POSTPROCESSORS.get(proc)
if sc_size is not None: if sc_size is not None:
pred_strs = [[proc(s, **kwargs) for s in preds] pred_strs = [[proc(s, **kwargs) for s in preds]
for preds in pred_strs] for preds in pred_strs]
@ -156,8 +161,13 @@ class OpenICLEvalTask(BaseTask):
] ]
icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator']) icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
result = icl_evaluator.score( preds['predictions'] = pred_strs
predictions=pred_strs, references=test_set[self.output_column]) preds['references'] = test_set[self.output_column]
preds = {
k: preds[k]
for k in signature(icl_evaluator.score).parameters
}
result = icl_evaluator.score(**preds)
if 'error' in result: if 'error' in result:
self.logger.error( self.logger.error(

View File

@ -99,7 +99,7 @@ class OpenICLInferTask(BaseTask):
self._set_default_value(inferencer_cfg, 'max_out_len', self._set_default_value(inferencer_cfg, 'max_out_len',
self.max_out_len) self.max_out_len)
self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size) self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size)
inferencer_cfg['max_seq_len'] = self.model_cfg['max_seq_len'] inferencer_cfg['max_seq_len'] = self.model_cfg.get('max_seq_len')
inferencer = ICL_INFERENCERS.build(inferencer_cfg) inferencer = ICL_INFERENCERS.build(inferencer_cfg)
out_path = get_infer_output_path( out_path = get_infer_output_path(
@ -128,7 +128,6 @@ class OpenICLInferTask(BaseTask):
def _set_default_value(self, cfg: ConfigDict, key: str, value: Any): def _set_default_value(self, cfg: ConfigDict, key: str, value: Any):
if key not in cfg: if key not in cfg:
assert value, (f'{key} must be specified!')
cfg[key] = value cfg[key] = value