Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Sync] update evaluator (#1175)
Commit 2b3d4150f3 (parent 296ea59931)
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import CMMLUDataset
 from opencompass.utils.text_postprocessors import first_capital_postprocess
 
@@ -101,7 +101,7 @@ for _name in cmmlu_all_sets:
     )
 
     cmmlu_eval_cfg = dict(
-        evaluator=dict(type=AccEvaluator),
+        evaluator=dict(type=AccwithDetailsEvaluator),
         pred_postprocessor=dict(type=first_capital_postprocess))
 
     cmmlu_datasets.append(
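Note: the only functional change in each dataset config in this commit is the evaluator type, swapped from `AccEvaluator` to `AccwithDetailsEvaluator`. A hedged sketch of the result shape the new evaluator returns (field names taken from the `AccwithDetailsEvaluator` class added at the end of this commit; the concrete values are invented):

```python
# Illustration only -- not part of the diff. Shape of the dict returned by
# AccwithDetailsEvaluator.score(); the sample values below are made up.
result = {
    'accuracy': 66.67,  # correct / total * 100
    'details': {
        '0': {'prompt': '...', 'pred': 'A', 'refr': 'A', 'is_correct': True},
        '1': {'prompt': '...', 'pred': 'B', 'refr': 'C', 'is_correct': False},
    },
}
```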
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import CMMLUDataset
 from opencompass.utils.text_postprocessors import first_capital_postprocess
 
@@ -97,7 +97,7 @@ for _name in cmmlu_all_sets:
         inferencer=dict(type=PPLInferencer),
     )
 
-    cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+    cmmlu_eval_cfg = dict(evaluator=dict(type=AccwithDetailsEvaluator))
 
     cmmlu_datasets.append(
         dict(
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import hellaswagDatasetwithICE
 from opencompass.utils.text_postprocessors import first_option_postprocess
 
@@ -41,7 +41,7 @@ hellaswag_infer_cfg = dict(
 )
 
 hellaswag_eval_cfg = dict(
-    evaluator=dict(type=AccEvaluator),
+    evaluator=dict(type=AccwithDetailsEvaluator),
     pred_role='BOT',
     pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
 )
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import hellaswagDatasetwithICE
 from opencompass.utils.text_postprocessors import first_capital_postprocess
 
@@ -29,7 +29,7 @@ hellaswag_infer_cfg = dict(
 )
 
 hellaswag_eval_cfg = dict(
-    evaluator=dict(type=AccEvaluator),
+    evaluator=dict(type=AccwithDetailsEvaluator),
     pred_postprocessor=dict(type=first_capital_postprocess),
 )
 
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import MMLUDataset
 from opencompass.utils.text_postprocessors import first_option_postprocess
 
@@ -106,7 +106,7 @@ for _name in mmlu_all_sets:
     )
 
     mmlu_eval_cfg = dict(
-        evaluator=dict(type=AccEvaluator),
+        evaluator=dict(type=AccwithDetailsEvaluator),
         pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'))
 
     mmlu_datasets.append(
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import MMLUDataset
 
 # None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
@@ -90,7 +90,7 @@ for _name in mmlu_all_sets:
         inferencer=dict(type=PPLInferencer),
     )
 
-    mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
+    mmlu_eval_cfg = dict(evaluator=dict(type=AccwithDetailsEvaluator), )
 
     mmlu_datasets.append(
         dict(
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import RaceDataset
 from opencompass.utils.text_postprocessors import first_option_postprocess
 
@@ -26,7 +26,7 @@ race_infer_cfg = dict(
     inferencer=dict(type=GenInferencer))
 
 race_eval_cfg = dict(
-    evaluator=dict(type=AccEvaluator),
+    evaluator=dict(type=AccwithDetailsEvaluator),
     pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
     pred_role='BOT')
 
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import ZeroRetriever
 from opencompass.openicl.icl_inferencer import PPLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import RaceDataset
 
 race_reader_cfg = dict(
@@ -20,7 +20,7 @@ race_infer_cfg = dict(
     retriever=dict(type=ZeroRetriever),
     inferencer=dict(type=PPLInferencer))
 
-race_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+race_eval_cfg = dict(evaluator=dict(type=AccwithDetailsEvaluator))
 
 race_datasets = [
     dict(
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import GenInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import winograndeDataset_V3
 from opencompass.utils.text_postprocessors import first_option_postprocess
 
@@ -29,7 +29,7 @@ winogrande_infer_cfg = dict(
 )
 
 winogrande_eval_cfg = dict(
-    evaluator=dict(type=AccEvaluator),
+    evaluator=dict(type=AccwithDetailsEvaluator),
     pred_role='BOT',
     pred_postprocessor=dict(type=first_option_postprocess, options='AB'),
 )
@@ -1,7 +1,7 @@
 from opencompass.openicl.icl_prompt_template import PromptTemplate
 from opencompass.openicl.icl_retriever import FixKRetriever
 from opencompass.openicl.icl_inferencer import LLInferencer
-from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
 from opencompass.datasets import winograndeDataset_V3
 
 winogrande_reader_cfg = dict(
@@ -25,7 +25,7 @@ winogrande_infer_cfg = dict(
     retriever=dict(type=FixKRetriever, fix_id_list=[0, 2, 4, 6, 8]),
     inferencer=dict(type=LLInferencer),
 )
-winogrande_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
+winogrande_eval_cfg = dict(evaluator=dict(type=AccwithDetailsEvaluator))
 
 winogrande_datasets = [
     dict(
@@ -7,6 +7,6 @@ models = [
         path='Qwen/Qwen1.5-110B',
         max_out_len=1024,
         batch_size=8,
-        run_cfg=dict(num_gpus=4),
+        run_cfg=dict(num_gpus=8),
     )
 ]
@@ -7,6 +7,6 @@ models = [
         path='Qwen/Qwen1.5-110B-Chat',
         max_out_len=1024,
         batch_size=8,
-        run_cfg=dict(num_gpus=4),
+        run_cfg=dict(num_gpus=8),
     )
 ]
@@ -7,6 +7,6 @@ models = [
         path='Qwen/Qwen1.5-14B',
         max_out_len=1024,
         batch_size=8,
-        run_cfg=dict(num_gpus=1),
+        run_cfg=dict(num_gpus=2),
     )
 ]
@@ -7,6 +7,6 @@ models = [
         path='Qwen/Qwen1.5-14B-Chat',
         max_out_len=1024,
         batch_size=8,
-        run_cfg=dict(num_gpus=1),
+        run_cfg=dict(num_gpus=2),
     )
 ]
@@ -7,6 +7,6 @@ models = [
         path='Qwen/Qwen1.5-72B',
         max_out_len=1024,
         batch_size=8,
-        run_cfg=dict(num_gpus=4),
+        run_cfg=dict(num_gpus=8),
     )
 ]
@@ -7,6 +7,6 @@ models = [
         path='Qwen/Qwen1.5-72B-Chat',
         max_out_len=1024,
         batch_size=8,
-        run_cfg=dict(num_gpus=4),
+        run_cfg=dict(num_gpus=8),
     )
 ]
@@ -8,6 +8,7 @@ settings = [
     ('qwen1.5-14b-pytorch', 'Qwen/Qwen1.5-14B', 1),
     ('qwen1.5-32b-pytorch', 'Qwen/Qwen1.5-32B', 2),
     ('qwen1.5-72b-pytorch', 'Qwen/Qwen1.5-72B', 4),
+    ('qwen1.5-110b-pytorch', 'Qwen/Qwen1.5-110B', 4),
     ('qwen1.5-moe-a2.7b-pytorch', 'Qwen/Qwen1.5-MoE-A2.7B', 1),
 ]
 
@@ -115,6 +115,14 @@ sanitized_mbpp_dataset_abbrs = [
     ['sanitized_mbpp', 'timeout'],
 ]
 
+IFEval_dataset_abbrs = [
+    ['IFEval', 'Prompt-level-strict-accuracy'],
+    ['IFEval', 'Inst-level-strict-accuracy'],
+    ['IFEval', 'Prompt-level-loose-accuracy'],
+    ['IFEval', 'Inst-level-loose-accuracy'],
+]
+
+
 summarizer = dict(
     type=MultiFacetedSummarizer,
     dataset_abbrs_list=[
@@ -124,6 +132,17 @@ summarizer = dict(
         {'name': 'bbh', 'dataset_abbrs': bbh_dataset_abbrs},
         {'name': 'GaokaoBench', 'dataset_abbrs': GaokaoBench_dataset_abbrs},
         {'name': 'sanitized_mbpp', 'dataset_abbrs': sanitized_mbpp_dataset_abbrs},
+        {'name': 'triviaqa', 'dataset_abbrs': [['triviaqa_wiki_1shot', 'score']]},
+        {'name': 'nq', 'dataset_abbrs': [['nq_open_1shot', 'score']]},
+        {'name': 'race', 'dataset_abbrs': [['race-high', 'accuracy']]},
+        {'name': 'winogrande', 'dataset_abbrs': [['winogrande', 'accuracy']]},
+        {'name': 'hellaswag', 'dataset_abbrs': [['hellaswag', 'accuracy']]},
+        {'name': 'gsm8k', 'dataset_abbrs': [['gsm8k', 'accuracy']]},
+        {'name': 'math', 'dataset_abbrs': [['math', 'accuracy']]},
+        {'name': 'TheoremQA', 'dataset_abbrs': [['TheoremQA', 'score']]},
+        {'name': 'humaneval', 'dataset_abbrs': [['openai_humaneval', 'humaneval_pass@1']]},
+        {'name': 'GPQA', 'dataset_abbrs': [['GPQA_diamond', 'accuracy']]},
+        {'name': 'IFEval', 'dataset_abbrs': IFEval_dataset_abbrs},
         {'name': 'overall', 'dataset_abbrs': overall_dataset_abbrs},
     ],
     summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
@@ -91,34 +91,51 @@ class GaokaoBenchEvaluator(BaseEvaluator):
         ]:
             return {'score': 0}
         elif self.question_type == 'multi_choice':
+            details = {}
             correct_score, total_score = 0, 0
-            for pred, refr in zip(predictions, references):
+            for index, (pred, refr) in enumerate(zip(predictions, references)):
                 pred = self.do_predictions_postprocess(pred)
                 pred = self.ensure_same_length(pred, refr)
+                is_corrects = []
                 for p, r in zip(pred, refr):
                     if p == r:
                         correct_score += 2
+                        is_corrects.append(True)
                     else:
                         for i in p:
                             if i not in r:
                                 break
                         else:
                             correct_score += 1
+                        is_corrects.append(False)
                     total_score += 2
-            return {'score': correct_score / total_score * 100}
+                details[str(index)] = {
+                    'pred': pred,
+                    'refr': refr,
+                    'is_correct': all(is_corrects),
+                }
+
         else:
+            details = {}
             correct_score, total_score = 0, 0
-            for pred, refr in zip(predictions, references):
+            for index, (pred, refr) in enumerate(zip(predictions, references)):
                 if self.question_type == 'multi_question_choice':
                     pred = self.do_predictions_postprocess(pred, len(refr))
                 else:
                     pred = self.do_predictions_postprocess(pred)
                 pred = self.ensure_same_length(pred, refr)
+                is_corrects = []
                 for p, r in zip(pred, refr):
-                    if p == r:
-                        correct_score += 1
+                    is_correct = p == r
+                    correct_score += is_correct
                     total_score += 1
-            return {'score': correct_score / total_score * 100}
+                    is_corrects.append(is_correct)
+                details[str(index)] = {
+                    'pred': pred,
+                    'refr': refr,
+                    'is_correct': all(is_corrects),
+                }
+        return {'score': correct_score / total_score * 100, 'details': details}
 
 
 for question_type in valid_gaokao_bench_question_types:
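The multi_choice scoring rule itself is unchanged by this hunk: an exactly matching sub-answer earns 2 of 2 points, a prediction whose options are all contained in the reference earns 1 of 2, and the new per-sample `details` entry is marked correct only when every sub-answer matched exactly. A toy calculation under those rules (illustration, not part of the commit):

```python
# Toy walk-through of the multi_choice scoring kept by the change above.
pred = ['AB', 'C']   # postprocessed answers for two sub-questions
refr = ['ABD', 'C']  # references
# 'AB' vs 'ABD': no exact match, but every predicted option is in the reference -> +1 of 2
# 'C'  vs 'C'  : exact match                                                    -> +2 of 2
score = (1 + 2) / (2 + 2) * 100  # 75.0
is_correct = False  # recorded in details: not every sub-answer matched exactly
```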
@@ -26,11 +26,13 @@ class IFEvalDataset(BaseDataset):
 
 class IFEvaluator(BaseEvaluator):
 
-    def score(self, predictions, references):
-        results = dict()
-        for metric in ('strict', 'loose'):
-            results[metric] = []
-        for pred, refer in zip(predictions, references):
+    def score(self, predictions, references, origin_prompt):
+        prompt_strict_correct, prompt_strict_total = 0, 0
+        inst_strict_correct, inst_strict_total = 0, 0
+        prompt_loose_correct, prompt_loose_total = 0, 0
+        inst_loose_correct, inst_loose_total = 0, 0
+        details = {}
+        for index, (pred, refer) in enumerate(zip(predictions, references)):
             input = InputExample(
                 key=refer['key'],
                 instruction_id_list=refer['instruction_id_list'],
@@ -40,29 +42,54 @@ class IFEvaluator(BaseEvaluator):
                 for k in list(kwarg.keys()):
                     if kwarg[k] is None:
                         kwarg.pop(k, None)
-            results['strict'].append(
-                test_instruction_following_strict(input, pred))
-            results['loose'].append(
-                test_instruction_following_loose(input, pred))
-        final_scores = dict()
-        for metric in ('strict', 'loose'):
-            prompt_total = 0
-            prompt_correct = 0
-            inst_total = 0
-            inst_correct = 0
 
-            for example in results[metric]:
+            # strict
+            example = test_instruction_following_strict(input, pred)
             follow_instruction_list = example.follow_instruction_list
             instruction_id_list = example.instruction_id_list
+            prompt_strict_total += 1
+            is_strict_correct = all(follow_instruction_list)
+            prompt_strict_correct += is_strict_correct
+            inst_strict_total += len(instruction_id_list)
+            inst_strict_correct += sum(follow_instruction_list)
 
-                prompt_total += 1
-                if all(follow_instruction_list):
-                    prompt_correct += 1
+            # loose
+            example = test_instruction_following_loose(input, pred)
+            follow_instruction_list = example.follow_instruction_list
+            instruction_id_list = example.instruction_id_list
+            prompt_loose_total += 1
+            is_loose_correct = all(follow_instruction_list)
+            prompt_loose_correct += is_loose_correct
+            inst_loose_total += len(instruction_id_list)
+            inst_loose_correct += sum(follow_instruction_list)
 
-                inst_total += len(instruction_id_list)
-                inst_correct += sum(follow_instruction_list)
-            prompt_score = f'Prompt-level-{metric}-accuracy'
-            inst_score = f'Inst-level-{metric}-accuracy'
-            final_scores[prompt_score] = prompt_correct / prompt_total * 100
-            final_scores[inst_score] = inst_correct / inst_total * 100
-        return final_scores
+            if is_strict_correct:
+                grade = 'strict'
+            elif is_loose_correct:
+                grade = 'loose'
+            else:
+                grade = 'none'
+
+            details[str(index)] = {
+                'prompt': origin_prompt[index],
+                'pred': pred,
+                'refer': refer,
+                'is_strict_correct': is_strict_correct,
+                'is_loose_correct': is_loose_correct,
+                'is_correct': is_strict_correct,
+                'grade': grade
+            }
+
+        results = {
+            'Prompt-level-strict-accuracy':
+            prompt_strict_correct / prompt_strict_total * 100,
+            'Inst-level-strict-accuracy':
+            inst_strict_correct / inst_strict_total * 100,
+            'Prompt-level-loose-accuracy':
+            prompt_loose_correct / prompt_loose_total * 100,
+            'Inst-level-loose-accuracy':
+            inst_loose_correct / inst_loose_total * 100,
+            'details':
+            details
+        }
+        return results
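The rewritten `IFEvaluator.score` computes the four IFEval metrics in one pass: prompt-level accuracy counts a sample as correct only if every instruction in it is followed, while inst-level accuracy counts instructions individually, each under both the strict and the loose check. A toy illustration of that bookkeeping (standalone sketch, not OpenCompass code):

```python
# Toy example: two prompts carrying 2 and 3 instructions (strict check only).
follow_lists = [[True, True], [True, False, True]]

prompt_correct = sum(all(flags) for flags in follow_lists)  # 1 prompt fully followed
prompt_total = len(follow_lists)                            # 2 prompts
inst_correct = sum(sum(flags) for flags in follow_lists)    # 4 instructions followed
inst_total = sum(len(flags) for flags in follow_lists)      # 5 instructions overall

print(prompt_correct / prompt_total * 100)  # 50.0 -> Prompt-level-strict-accuracy
print(inst_correct / inst_total * 100)      # 80.0 -> Inst-level-strict-accuracy
```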
@@ -227,9 +227,10 @@ class MBPPEvaluator(BaseEvaluator):
 
             from tqdm import tqdm
             for future in tqdm(as_completed(futures), total=len(futures)):
-                index, key = future.result()
-                result[key] += 1
-                details[str(index)]['result'] = key
+                index, ret = future.result()
+                result[ret] += 1
+                details[str(index)]['result'] = ret
+                details[str(index)]['is_correct'] = (ret == 'pass')
 
         result['score'] = result['pass'] / len(predictions) * 100
         result['details'] = details
@@ -59,7 +59,7 @@ def _get_possible_max_seq_len(max_seq_len, path):
         raise ValueError('max_seq_len is not provided and cannot be inferred from the model config.')
 
 
-def _convert_chat_messages(inputs):
+def _convert_chat_messages(inputs, merge_role=True):
     outputs = []
     for _input in inputs:
         messages = []
@@ -73,7 +73,18 @@ def _convert_chat_messages(inputs):
                     'SYSTEM': 'system',
                 }[item['role']]
                 messages.append({'role': role, 'content': item['prompt']})
+
+        if merge_role:
+            merged_messages = []
+            for item in messages:
+                if merged_messages and merged_messages[-1]['role'] == item['role']:
+                    merged_messages[-1]['content'] += '\n' + item['content']
+                else:
+                    merged_messages.append(item)
+            messages = merged_messages
+
         outputs.append(messages)
+    print(messages)
     return outputs
 
 
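`_convert_chat_messages` now optionally merges consecutive messages that share a role, joining their contents with a newline; presumably this keeps one message per turn for chat templates that expect alternating roles. A standalone sketch of the merge step added above (the example messages are invented):

```python
# Standalone sketch of the merge_role behaviour added above; messages are invented.
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'First part of the question.'},
    {'role': 'user', 'content': 'Second part of the question.'},
]

merged_messages = []
for item in messages:
    if merged_messages and merged_messages[-1]['role'] == item['role']:
        merged_messages[-1]['content'] += '\n' + item['content']
    else:
        merged_messages.append(item)

# merged_messages == [
#     {'role': 'system', 'content': 'You are a helpful assistant.'},
#     {'role': 'user', 'content': 'First part of the question.\nSecond part of the question.'},
# ]
```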
@@ -104,6 +115,8 @@ def _get_meta_template(meta_template):
     default_meta_template = dict(
         round=[
             dict(role='HUMAN', api_role='HUMAN'),
+            # XXX: all system roles are mapped to human in purpose
+            dict(role='SYSTEM', api_role='HUMAN'),
             dict(role='BOT', api_role='BOT', generate=True),
         ]
     )
@@ -37,6 +37,9 @@ class TurboMindModel(BaseModel):
             arguments like session_len, max_batch_size for TurboMind.
         gen_config (Dict, optional): Generation config to set
             arguments like top_k, top_p, temperature.
+        end_str (str, optional): Whether to trim generated strings with end_str
+            if the model has special ending strings that are not handled well.
+            Defaults to None.
     """
 
     def __init__(self,
@@ -45,7 +48,8 @@ class TurboMindModel(BaseModel):
                  max_seq_len: int = 2048,
                  meta_template: Optional[Dict] = None,
                  engine_config: Dict = {},
-                 gen_config: Dict = {}):
+                 gen_config: Dict = {},
+                 end_str: Optional[str] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
@@ -64,6 +68,7 @@ class TurboMindModel(BaseModel):
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.major_version, self.minor_version, _ = version_info
+        self.end_str = end_str
 
     def generate(self,
                  inputs: List[str],
@@ -119,6 +124,7 @@ class TurboMindModel(BaseModel):
                         batch_input,
                         [max_out_len] * len(batch_input),
                         [gen_config] * len(batch_input),
+                        [self.end_str] * len(batch_input),
                     ))
                 results += _results
             if stopping_criteria:
@@ -142,7 +148,8 @@ class TurboMindModel(BaseModel):
                   session_id,
                   prompt: PromptType,
                   max_out_len: int,
-                  gen_config=None) -> str:
+                  gen_config=None,
+                  end_str: Optional[str] = None) -> str:
         """Generate results given a list of inputs.
 
         Args:
@@ -152,6 +159,10 @@ class TurboMindModel(BaseModel):
             max_out_len (int): The maximum length of the output.
             gen_config (EngineGenerationConfig, optional): Generation
                 config to set arguments like top_k, top_p, temperature.
+            end_str (str, optional): Whether to trim generated strings
+                with end_str if the model has special ending strings
+                that are not handled well.
+                Defaults to None.
         Returns:
             str: The generated string.
         """
@@ -174,6 +185,9 @@ class TurboMindModel(BaseModel):
             _, output_ids, _ = outputs
             response = self.tokenizer.decode(output_ids)
             response = valid_str(response)
+        # used to trim
+        if end_str:
+            response = response.split(end_str)[0]
         return response
 
     def get_ppl(self,
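The new `end_str` argument is a plain string cut-off applied to the decoded response. A minimal illustration of the trimming added above (the end token and text are invented):

```python
# Minimal illustration of end_str trimming; values are made up.
response = 'The answer is B.<|im_end|>stray continuation'
end_str = '<|im_end|>'
if end_str:
    response = response.split(end_str)[0]
# response == 'The answer is B.'
```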
@@ -342,3 +342,29 @@ class EDAccEvaluator(AccEvaluator):
             'predictions': preds,
             'references': golds,
         }
+
+
+@ICL_EVALUATORS.register_module()
+class AccwithDetailsEvaluator(BaseEvaluator):
+
+    def score(self, predictions, references, origin_prompt) -> dict:
+
+        if len(predictions) != len(references):
+            return {'error': 'preds and refrs have different length.'}
+
+        details = {}
+        correct, total = 0, 0
+        for index, (pred, ref) in enumerate(zip(predictions, references)):
+            is_correct = pred == ref
+            correct += is_correct
+            details[str(index)] = {
+                'prompt': origin_prompt[index],
+                'pred': pred,
+                'refr': ref,
+                'is_correct': is_correct,
+            }
+            total += 1
+
+        results = {'accuracy': correct / total * 100, 'details': details}
+
+        return results
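A minimal usage sketch of the `AccwithDetailsEvaluator` added above (the inputs are invented; the import path matches the config changes at the top of this commit):

```python
# Usage sketch for the new evaluator; predictions/references are invented.
from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator

evaluator = AccwithDetailsEvaluator()
result = evaluator.score(
    predictions=['A', 'C', 'B'],
    references=['A', 'B', 'B'],
    origin_prompt=['Q1 ...', 'Q2 ...', 'Q3 ...'],
)
print(result['accuracy'])      # 66.66...
print(result['details']['1'])  # {'prompt': 'Q2 ...', 'pred': 'C', 'refr': 'B', 'is_correct': False}
```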