[Sync] update evaluator (#1175)

Fengzhe Zhou 2024-05-21 14:22:46 +08:00 committed by GitHub
parent 296ea59931
commit 2b3d4150f3
24 changed files with 184 additions and 66 deletions

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import CMMLUDataset
 from opencompass.utils.text_postprocessors import first_capital_postprocess
@@ -101,7 +101,7 @@ for _name in cmmlu_all_sets:
     )
     cmmlu_eval_cfg = dict(
-        evaluator=dict(type=AccEvaluator),
+        evaluator=dict(type=AccwithDetailsEvaluator),
         pred_postprocessor=dict(type=first_capital_postprocess))
     cmmlu_datasets.append(

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import CMMLUDataset
 from opencompass.utils.text_postprocessors import first_capital_postprocess
@@ -97,7 +97,7 @@ for _name in cmmlu_all_sets:
         inferencer=dict(type=PPLInferencer),
     )
-    cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+    cmmlu_eval_cfg = dict(evaluator=dict(type=AccwithDetailsEvaluator))
     cmmlu_datasets.append(
         dict(

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import hellaswagDatasetwithICE
 from opencompass.utils.text_postprocessors import first_option_postprocess
@@ -41,7 +41,7 @@ hellaswag_infer_cfg = dict(
 )
 hellaswag_eval_cfg = dict(
-    evaluator=dict(type=AccEvaluator),
+    evaluator=dict(type=AccwithDetailsEvaluator),
     pred_role='BOT',
     pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
 )

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import hellaswagDatasetwithICE
 from opencompass.utils.text_postprocessors import first_capital_postprocess
@@ -29,7 +29,7 @@ hellaswag_infer_cfg = dict(
 )
 hellaswag_eval_cfg = dict(
-    evaluator=dict(type=AccEvaluator),
+    evaluator=dict(type=AccwithDetailsEvaluator),
     pred_postprocessor=dict(type=first_capital_postprocess),
 )

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import MMLUDataset
 from opencompass.utils.text_postprocessors import first_option_postprocess
@@ -106,7 +106,7 @@ for _name in mmlu_all_sets:
     )
     mmlu_eval_cfg = dict(
-        evaluator=dict(type=AccEvaluator),
+        evaluator=dict(type=AccwithDetailsEvaluator),
         pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'))
     mmlu_datasets.append(

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import MMLUDataset
 # None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
@@ -90,7 +90,7 @@ for _name in mmlu_all_sets:
         inferencer=dict(type=PPLInferencer),
     )
-    mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
+    mmlu_eval_cfg = dict(evaluator=dict(type=AccwithDetailsEvaluator), )
     mmlu_datasets.append(
         dict(

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import RaceDataset
 from opencompass.utils.text_postprocessors import first_option_postprocess
@@ -26,7 +26,7 @@ race_infer_cfg = dict(
     inferencer=dict(type=GenInferencer))
 race_eval_cfg = dict(
-    evaluator=dict(type=AccEvaluator),
+    evaluator=dict(type=AccwithDetailsEvaluator),
     pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
     pred_role='BOT')

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import RaceDataset
 race_reader_cfg = dict(
@@ -20,7 +20,7 @@ race_infer_cfg = dict(
     retriever=dict(type=ZeroRetriever),
     inferencer=dict(type=PPLInferencer))
-race_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+race_eval_cfg = dict(evaluator=dict(type=AccwithDetailsEvaluator))
 race_datasets = [
     dict(

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import winograndeDataset_V3
 from opencompass.utils.text_postprocessors import first_option_postprocess
@@ -29,7 +29,7 @@ winogrande_infer_cfg = dict(
 )
 winogrande_eval_cfg = dict(
-    evaluator=dict(type=AccEvaluator),
+    evaluator=dict(type=AccwithDetailsEvaluator),
     pred_role='BOT',
     pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
 )

View File

@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import LLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import winograndeDataset_V3
 winogrande_reader_cfg = dict(
@@ -25,7 +25,7 @@ winogrande_infer_cfg = dict(
     retriever=dict(type=FixKRetriever, fix_id_list=[0, 2, 4, 6, 8]),
     inferencer=dict(type=LLInferencer),
 )
-winogrande_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+winogrande_eval_cfg = dict(evaluator=dict(type=AccwithDetailsEvaluator))
 winogrande_datasets = [
     dict(

View File

@@ -7,6 +7,6 @@ models = [
         path='Qwen/Qwen1.5-110B',
         max_out_len=1024,
         batch_size=8,
-        run_cfg=dict(num_gpus=4),
+        run_cfg=dict(num_gpus=8),
     )
 ]

View File

@@ -7,6 +7,6 @@ models = [
         path='Qwen/Qwen1.5-110B-Chat',
         max_out_len=1024,
         batch_size=8,
-        run_cfg=dict(num_gpus=4),
+        run_cfg=dict(num_gpus=8),
     )
 ]

View File

@@ -7,6 +7,6 @@ models = [
        path='Qwen/Qwen1.5-14B',
        max_out_len=1024,
        batch_size=8,
-        run_cfg=dict(num_gpus=1),
+        run_cfg=dict(num_gpus=2),
     )
 ]

View File

@@ -7,6 +7,6 @@ models = [
        path='Qwen/Qwen1.5-14B-Chat',
        max_out_len=1024,
        batch_size=8,
-        run_cfg=dict(num_gpus=1),
+        run_cfg=dict(num_gpus=2),
     )
 ]

View File

@@ -7,6 +7,6 @@ models = [
        path='Qwen/Qwen1.5-72B',
        max_out_len=1024,
        batch_size=8,
-        run_cfg=dict(num_gpus=4),
+        run_cfg=dict(num_gpus=8),
     )
 ]

View File

@@ -7,6 +7,6 @@ models = [
        path='Qwen/Qwen1.5-72B-Chat',
        max_out_len=1024,
        batch_size=8,
-        run_cfg=dict(num_gpus=4),
+        run_cfg=dict(num_gpus=8),
     )
 ]

View File

@@ -8,6 +8,7 @@ settings = [
     ('qwen1.5-14b-pytorch', 'Qwen/Qwen1.5-14B', 1),
     ('qwen1.5-32b-pytorch', 'Qwen/Qwen1.5-32B', 2),
     ('qwen1.5-72b-pytorch', 'Qwen/Qwen1.5-72B', 4),
+    ('qwen1.5-110b-pytorch', 'Qwen/Qwen1.5-110B', 4),
     ('qwen1.5-moe-a2.7b-pytorch', 'Qwen/Qwen1.5-MoE-A2.7B', 1),
 ]

View File

@@ -115,6 +115,14 @@ sanitized_mbpp_dataset_abbrs = [
     ['sanitized_mbpp', 'timeout'],
 ]
 
+IFEval_dataset_abbrs = [
+    ['IFEval', 'Prompt-level-strict-accuracy'],
+    ['IFEval', 'Inst-level-strict-accuracy'],
+    ['IFEval', 'Prompt-level-loose-accuracy'],
+    ['IFEval', 'Inst-level-loose-accuracy'],
+]
+
 summarizer = dict(
     type=MultiFacetedSummarizer,
     dataset_abbrs_list=[
@@ -124,6 +132,17 @@ summarizer = dict(
         {'name': 'bbh', 'dataset_abbrs': bbh_dataset_abbrs},
         {'name': 'GaokaoBench', 'dataset_abbrs': GaokaoBench_dataset_abbrs},
         {'name': 'sanitized_mbpp', 'dataset_abbrs': sanitized_mbpp_dataset_abbrs},
+        {'name': 'triviaqa', 'dataset_abbrs': [['triviaqa_wiki_1shot', 'score']]},
+        {'name': 'nq', 'dataset_abbrs': [['nq_open_1shot', 'score']]},
+        {'name': 'race', 'dataset_abbrs': [['race-high', 'accuracy']]},
+        {'name': 'winogrande', 'dataset_abbrs': [['winogrande', 'accuracy']]},
+        {'name': 'hellaswag', 'dataset_abbrs': [['hellaswag', 'accuracy']]},
+        {'name': 'gsm8k', 'dataset_abbrs': [['gsm8k', 'accuracy']]},
+        {'name': 'math', 'dataset_abbrs': [['math', 'accuracy']]},
+        {'name': 'TheoremQA', 'dataset_abbrs': [['TheoremQA', 'score']]},
+        {'name': 'humaneval', 'dataset_abbrs': [['openai_humaneval', 'humaneval_pass@1']]},
+        {'name': 'GPQA', 'dataset_abbrs': [['GPQA_diamond', 'accuracy']]},
+        {'name': 'IFEval', 'dataset_abbrs': IFEval_dataset_abbrs},
         {'name': 'overall', 'dataset_abbrs': overall_dataset_abbrs},
     ],
     summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),

View File

@@ -91,34 +91,51 @@ class GaokaoBenchEvaluator(BaseEvaluator):
         ]:
             return {'score': 0}
         elif self.question_type == 'multi_choice':
+            details = {}
             correct_score, total_score = 0, 0
-            for pred, refr in zip(predictions, references):
+            for index, (pred, refr) in enumerate(zip(predictions, references)):
                 pred = self.do_predictions_postprocess(pred)
                 pred = self.ensure_same_length(pred, refr)
+                is_corrects = []
                 for p, r in zip(pred, refr):
                     if p == r:
                         correct_score += 2
+                        is_corrects.append(True)
                     else:
                         for i in p:
                             if i not in r:
                                 break
                         else:
                             correct_score += 1
+                        is_corrects.append(False)
                     total_score += 2
-            return {'score': correct_score / total_score * 100}
+                details[str(index)] = {
+                    'pred': pred,
+                    'refr': refr,
+                    'is_correct': all(is_corrects),
+                }
         else:
+            details = {}
             correct_score, total_score = 0, 0
-            for pred, refr in zip(predictions, references):
+            for index, (pred, refr) in enumerate(zip(predictions, references)):
                 if self.question_type == 'multi_question_choice':
                     pred = self.do_predictions_postprocess(pred, len(refr))
                 else:
                     pred = self.do_predictions_postprocess(pred)
                 pred = self.ensure_same_length(pred, refr)
+                is_corrects = []
                 for p, r in zip(pred, refr):
-                    if p == r:
-                        correct_score += 1
+                    is_correct = p == r
+                    correct_score += is_correct
                     total_score += 1
-            return {'score': correct_score / total_score * 100}
+                    is_corrects.append(is_correct)
+                details[str(index)] = {
+                    'pred': pred,
+                    'refr': refr,
+                    'is_correct': all(is_corrects),
+                }
+        return {'score': correct_score / total_score * 100, 'details': details}
 
 
 for question_type in valid_gaokao_bench_question_types:

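Side note: with this change the GaokaoBench evaluator exposes a per-item `details` dict alongside the overall score. A minimal sketch of how that dict could be consumed downstream, using made-up values in the shape shown above:

# Hypothetical result in the shape the updated score() returns.
result = {
    'score': 50.0,
    'details': {
        '0': {'pred': ['A'], 'refr': ['A'], 'is_correct': True},
        '1': {'pred': ['B'], 'refr': ['C'], 'is_correct': False},
    },
}

# Collect the indexes the model got wrong, e.g. for error analysis.
wrong_indexes = [i for i, d in result['details'].items() if not d['is_correct']]
print(wrong_indexes)  # ['1']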
View File

@@ -26,11 +26,13 @@ class IFEvalDataset(BaseDataset):
 class IFEvaluator(BaseEvaluator):
 
-    def score(self, predictions, references):
-        results = dict()
-        for metric in ('strict', 'loose'):
-            results[metric] = []
-        for pred, refer in zip(predictions, references):
+    def score(self, predictions, references, origin_prompt):
+        prompt_strict_correct, prompt_strict_total = 0, 0
+        inst_strict_correct, inst_strict_total = 0, 0
+        prompt_loose_correct, prompt_loose_total = 0, 0
+        inst_loose_correct, inst_loose_total = 0, 0
+        details = {}
+        for index, (pred, refer) in enumerate(zip(predictions, references)):
             input = InputExample(
                 key=refer['key'],
                 instruction_id_list=refer['instruction_id_list'],
@@ -40,29 +42,54 @@ class IFEvaluator(BaseEvaluator):
                 for k in list(kwarg.keys()):
                     if kwarg[k] is None:
                         kwarg.pop(k, None)
-            results['strict'].append(
-                test_instruction_following_strict(input, pred))
-            results['loose'].append(
-                test_instruction_following_loose(input, pred))
-        final_scores = dict()
-        for metric in ('strict', 'loose'):
-            prompt_total = 0
-            prompt_correct = 0
-            inst_total = 0
-            inst_correct = 0
-            for example in results[metric]:
-                follow_instruction_list = example.follow_instruction_list
-                instruction_id_list = example.instruction_id_list
-                prompt_total += 1
-                if all(follow_instruction_list):
-                    prompt_correct += 1
-                inst_total += len(instruction_id_list)
-                inst_correct += sum(follow_instruction_list)
-            prompt_score = f'Prompt-level-{metric}-accuracy'
-            inst_score = f'Inst-level-{metric}-accuracy'
-            final_scores[prompt_score] = prompt_correct / prompt_total * 100
-            final_scores[inst_score] = inst_correct / inst_total * 100
-        return final_scores
+
+            # strict
+            example = test_instruction_following_strict(input, pred)
+            follow_instruction_list = example.follow_instruction_list
+            instruction_id_list = example.instruction_id_list
+            prompt_strict_total += 1
+            is_strict_correct = all(follow_instruction_list)
+            prompt_strict_correct += is_strict_correct
+            inst_strict_total += len(instruction_id_list)
+            inst_strict_correct += sum(follow_instruction_list)
+
+            # loose
+            example = test_instruction_following_loose(input, pred)
+            follow_instruction_list = example.follow_instruction_list
+            instruction_id_list = example.instruction_id_list
+            prompt_loose_total += 1
+            is_loose_correct = all(follow_instruction_list)
+            prompt_loose_correct += is_loose_correct
+            inst_loose_total += len(instruction_id_list)
+            inst_loose_correct += sum(follow_instruction_list)
+
+            if is_strict_correct:
+                grade = 'strict'
+            elif is_loose_correct:
+                grade = 'loose'
+            else:
+                grade = 'none'
+
+            details[str(index)] = {
+                'prompt': origin_prompt[index],
+                'pred': pred,
+                'refer': refer,
+                'is_strict_correct': is_strict_correct,
+                'is_loose_correct': is_loose_correct,
+                'is_correct': is_strict_correct,
+                'grade': grade
+            }
+
+        results = {
+            'Prompt-level-strict-accuracy':
+            prompt_strict_correct / prompt_strict_total * 100,
+            'Inst-level-strict-accuracy':
+            inst_strict_correct / inst_strict_total * 100,
+            'Prompt-level-loose-accuracy':
+            prompt_loose_correct / prompt_loose_total * 100,
+            'Inst-level-loose-accuracy':
+            inst_loose_correct / inst_loose_total * 100,
+            'details':
+            details
+        }
+        return results

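Side note: the rewritten IFEvaluator reports accuracy at two granularities. A small self-contained illustration with made-up follow/violate flags (not taken from any real run):

# One inner list per prompt; one boolean per instruction in that prompt.
follow_lists = [
    [True, True],   # every instruction followed -> prompt counts as correct
    [True, False],  # one instruction violated  -> prompt counts as wrong
]

# Prompt-level: a prompt is correct only if all of its instructions are followed.
prompt_acc = sum(all(flags) for flags in follow_lists) / len(follow_lists) * 100

# Inst-level: every instruction is scored on its own.
total_inst = sum(len(flags) for flags in follow_lists)
inst_acc = sum(sum(flags) for flags in follow_lists) / total_inst * 100

print(prompt_acc, inst_acc)  # 50.0 75.0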
View File

@@ -227,9 +227,10 @@ class MBPPEvaluator(BaseEvaluator):
             from tqdm import tqdm
             for future in tqdm(as_completed(futures), total=len(futures)):
-                index, key = future.result()
-                result[key] += 1
-                details[str(index)]['result'] = key
+                index, ret = future.result()
+                result[ret] += 1
+                details[str(index)]['result'] = ret
+                details[str(index)]['is_correct'] = (ret == 'pass')
 
         result['score'] = result['pass'] / len(predictions) * 100
         result['details'] = details

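Side note: the new `is_correct` flag makes it straightforward to pull failing MBPP samples out of `details`; a sketch with made-up entries in the same shape:

# Made-up details entries in the shape the evaluator now writes.
details = {
    '0': {'result': 'pass', 'is_correct': True},
    '1': {'result': 'timeout', 'is_correct': False},
}

failed = {i: d['result'] for i, d in details.items() if not d['is_correct']}
print(failed)  # {'1': 'timeout'}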
View File

@@ -59,7 +59,7 @@ def _get_possible_max_seq_len(max_seq_len, path):
     raise ValueError('max_seq_len is not provided and cannot be inferred from the model config.')
 
 
-def _convert_chat_messages(inputs):
+def _convert_chat_messages(inputs, merge_role=True):
     outputs = []
     for _input in inputs:
         messages = []
@@ -73,7 +73,18 @@ def _convert_chat_messages(inputs):
                     'SYSTEM': 'system',
                 }[item['role']]
                 messages.append({'role': role, 'content': item['prompt']})
 
+        if merge_role:
+            merged_messages = []
+            for item in messages:
+                if merged_messages and merged_messages[-1]['role'] == item['role']:
+                    merged_messages[-1]['content'] += '\n' + item['content']
+                else:
+                    merged_messages.append(item)
+            messages = merged_messages
+
         outputs.append(messages)
+        print(messages)
     return outputs
@@ -104,6 +115,8 @@ def _get_meta_template(meta_template):
     default_meta_template = dict(
         round=[
             dict(role='HUMAN', api_role='HUMAN'),
+            # XXX: all system roles are mapped to human in purpose
+            dict(role='SYSTEM', api_role='HUMAN'),
             dict(role='BOT', api_role='BOT', generate=True),
         ]
     )

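Side note: a standalone sketch of what `merge_role=True` does, repeating the same merging rule on made-up messages:

messages = [
    {'role': 'user', 'content': 'system-style instruction'},
    {'role': 'user', 'content': 'the actual question'},
    {'role': 'assistant', 'content': 'an answer'},
]

# Consecutive messages with the same role are joined with a newline.
merged = []
for item in messages:
    if merged and merged[-1]['role'] == item['role']:
        merged[-1]['content'] += '\n' + item['content']
    else:
        merged.append(item)

print(merged)
# [{'role': 'user', 'content': 'system-style instruction\nthe actual question'},
#  {'role': 'assistant', 'content': 'an answer'}]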
View File

@@ -37,6 +37,9 @@ class TurboMindModel(BaseModel):
             arguments like session_len, max_batch_size for TurboMind.
         gen_config (Dict, optional): Generation config to set
             arguments like top_k, top_p, temperature.
+        end_str (str, optional): Whether to trim generated strings with end_str
+            if the model has special ending strings that are not handled well.
+            Defaults to None.
     """
 
     def __init__(self,
@@ -45,7 +48,8 @@ class TurboMindModel(BaseModel):
                  max_seq_len: int = 2048,
                  meta_template: Optional[Dict] = None,
                  engine_config: Dict = {},
-                 gen_config: Dict = {}):
+                 gen_config: Dict = {},
+                 end_str: Optional[str] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
@@ -64,6 +68,7 @@ class TurboMindModel(BaseModel):
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.major_version, self.minor_version, _ = version_info
+        self.end_str = end_str
 
     def generate(self,
                  inputs: List[str],
@@ -119,6 +124,7 @@ class TurboMindModel(BaseModel):
                     batch_input,
                     [max_out_len] * len(batch_input),
                     [gen_config] * len(batch_input),
+                    [self.end_str] * len(batch_input),
                 ))
             results += _results
         if stopping_criteria:
@@ -142,7 +148,8 @@ class TurboMindModel(BaseModel):
                   session_id,
                   prompt: PromptType,
                   max_out_len: int,
-                  gen_config=None) -> str:
+                  gen_config=None,
+                  end_str: Optional[str] = None) -> str:
         """Generate results given a list of inputs.
 
         Args:
@@ -152,6 +159,10 @@ class TurboMindModel(BaseModel):
             max_out_len (int): The maximum length of the output.
             gen_config (EngineGenerationConfig, optional): Generation
                 config to set arguments like top_k, top_p, temperature.
+            end_str (str, optional): Whether to trim generated strings
+                with end_str if the model has special ending strings
+                that are not handled well.
+                Defaults to None.
 
         Returns:
             str: The generated string.
         """
@@ -174,6 +185,9 @@ class TurboMindModel(BaseModel):
             _, output_ids, _ = outputs
             response = self.tokenizer.decode(output_ids)
             response = valid_str(response)
+            # used to trim
+            if end_str:
+                response = response.split(end_str)[0]
         return response
 
     def get_ppl(self,

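Side note: the `end_str` trimming is a plain string split; the marker below is only an example value, pick whatever end token your model actually emits:

end_str = '<|im_end|>'  # example marker, model dependent
response = 'The answer is 42.<|im_end|>\nextra tokens after the end marker'

# Same rule as in _generate(): keep everything before the first end_str.
if end_str:
    response = response.split(end_str)[0]

print(response)  # The answer is 42.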
View File

@@ -342,3 +342,29 @@ class EDAccEvaluator(AccEvaluator):
             'predictions': preds,
             'references': golds,
         }
+
+
+@ICL_EVALUATORS.register_module()
+class AccwithDetailsEvaluator(BaseEvaluator):
+
+    def score(self, predictions, references, origin_prompt) -> dict:
+        if len(predictions) != len(references):
+            return {'error': 'preds and refrs have different length.'}
+
+        details = {}
+        correct, total = 0, 0
+        for index, (pred, ref) in enumerate(zip(predictions, references)):
+            is_correct = pred == ref
+            correct += is_correct
+            details[str(index)] = {
+                'prompt': origin_prompt[index],
+                'pred': pred,
+                'refr': ref,
+                'is_correct': is_correct,
+            }
+            total += 1
+
+        results = {'accuracy': correct / total * 100, 'details': details}
+        return results
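
Side note: a minimal sketch of calling the new evaluator directly, outside an OpenCompass run, with made-up inputs:

from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator

evaluator = AccwithDetailsEvaluator()
result = evaluator.score(
    predictions=['A', 'C'],
    references=['A', 'B'],
    origin_prompt=['Question 1 ...', 'Question 2 ...'],
)
print(result['accuracy'])      # 50.0
print(result['details']['1'])  # {'prompt': 'Question 2 ...', 'pred': 'C', 'refr': 'B', 'is_correct': False}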