Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Fix] Fix MultiRound Subjective Evaluation (#1043)

* fix multiround
* fix
parent 8c85edd1cd
commit 6f98c8d9ab
@@ -30,6 +30,7 @@ for _name in subjective_all_sets:
     subjective_eval_cfg = dict(
         evaluator=dict(
             type=LMEvaluator,
+            pack_all_predictions=True,
             prompt_template=dict(
                 type=PromptTemplate,
                 template=dict(round=[
configs/eval_subjective_functional_multiround.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from opencompass.models import HuggingFaceCausalLM
from copy import deepcopy
from opencompass.models import TurboMindModel
from mmengine.config import read_base

from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3, OpenAI
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import MultiroundSummarizer

with read_base():
    from .datasets.subjective.multiround.functionalmt_zh_judgeby_gpt4 import subjective_datasets

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ]
)

_meta_template = dict(
    round=[
        dict(role="HUMAN", begin='<|im_start|>user\n', end='<|im_end|>\n'),
        dict(role="BOT", begin="<|im_start|>assistant\n", end='<|im_end|>\n', generate=True),
    ],
    eos_token_id=151645,
)

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='qwen1.5-7b-chat-hf',
        path="Qwen/Qwen1.5-7B-Chat",
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
            use_fast=False,
        ),
        generation_kwargs=dict(
            do_sample=True,
        ),
        meta_template=_meta_template,
        pad_token_id=151645,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
        end_str='<|im_end|>',
    )
]

datasets = [*subjective_datasets]

work_dir = 'outputs/multiround/'
# -------------Inference Stage ----------------------------------------


infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=1000),
    runner=dict(
        type=SlurmSequentialRunner,
        partition='your part',
        quotatype='auto',
        max_num_workers=256,
        task=dict(type=OpenICLInferTask)),
)

judge_models = [dict(
    abbr='GPT4-Turbo',
    type=OpenAI,
    path='gpt-4-1106-preview',
    key='',
    meta_template=api_meta_template,
    query_per_second=1,
    max_out_len=1024,
    max_seq_len=4096,
    batch_size=10,
    retry=10,
    temperature = 0,
)]

## ------------- Evaluation Configuration
eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner,
        max_task_size=1000,
        mode='singlescore',
        models = models,
        judge_models=judge_models
    ),
    runner=dict(
        type=SlurmSequentialRunner,
        partition='your part',
        quotatype='auto',
        max_num_workers=256,
        task=dict(type=SubjectiveEvalTask)),
)

summarizer = dict(
    type=MultiroundSummarizer
)
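A minimal local-machine variant of the infer/eval blocks above, as a sketch only (not part of the commit): it reuses the imports and the models/judge_models objects defined in this config, and swaps SlurmSequentialRunner for the already-imported LocalRunner for environments without Slurm; the worker count is an illustrative value.

infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=1000),
    runner=dict(
        type=LocalRunner,            # run inference tasks on the local machine
        max_num_workers=2,           # illustrative value
        task=dict(type=OpenICLInferTask)),
)

eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner,
        max_task_size=1000,
        mode='singlescore',
        models=models,
        judge_models=judge_models),
    runner=dict(
        type=LocalRunner,            # run judge-model evaluation locally as well
        max_num_workers=2,
        task=dict(type=SubjectiveEvalTask)),
)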
@@ -75,9 +75,11 @@ class LMEvaluator:
             keywords, ``{prediction}`` and ``{reference}``, referring to
             the prediction and optionally the reference answer.
         judge_cfg (ConfigDict): The config of language model as a judge.
+        meta_review_prompt_template (ConfigDict, optional): Prompt template for the meta judge model.
         output_path (str): The path to prediction output.
         dataset_cfg (ConfigDict, optional): The config of the dataset to be
             evaluated.
+        pack_all_predictions (bool, optional): For multiround evaluation, whether to judge all rounds at once or judge every single round.
         postprocessor (ConfigDict): The model prediction's postprocessor
             config.
     """
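A sketch of how the documented flag is toggled from a dataset config (the True case mirrors the first hunk above; the judge prompt here is a hypothetical placeholder, not the template shipped with the dataset):

from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate

subjective_eval_cfg = dict(
    evaluator=dict(
        type=LMEvaluator,
        # True: pack the whole multi-round dialogue into one judge call;
        # False (the default): judge every round separately.
        pack_all_predictions=True,
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(round=[
                dict(role='HUMAN', prompt='{dialogue}'),  # hypothetical placeholder prompt
            ]),
        ),
    ),
)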
@@ -88,6 +90,7 @@ class LMEvaluator:
         judge_cfg: ConfigDict,
         output_path: str,
         meta_review_prompt_template: Optional[ConfigDict] = None,
+        pack_all_predictions: Optional[bool] = False,
         dataset_cfg: Optional[ConfigDict] = None,
         postprocessor: ConfigDict = dict(type=first_number_postprocess)
     ) -> None:
@@ -112,6 +115,7 @@ class LMEvaluator:
         self.postprocessor = get_type_from_cfg(postprocessor)
         self.logger = get_logger()
         self.dataset_cfg = dataset_cfg
+        self.pack_all_predictions = pack_all_predictions
 
     def score(self,
               predictions,
@@ -171,12 +175,17 @@ class LMEvaluator:
         elif isinstance(
                 predictions[0][0], list
         ):  #multi round for format like [[[{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}], [{'round':1, 'user':'', 'assistant':''}, {'round':2, 'user':'', 'assistant':''}]]]
-            for i in range(len(predictions)):
-                multiround_predictions = extract_dicts(predictions[i])
-                for j in range(len(multiround_predictions)):
-                    key = 'prediction' if i == 0 else f'prediction{i}'
-                    key += '_r' + str(j + 1)
-                    pred_dict[key] = multiround_predictions[j]
+            if self.pack_all_predictions:
+                for i in range(len(predictions)):
+                    key = 'prediction' if i == 0 else f'prediction{i + 1}'
+                    pred_dict[key] = predictions[i]
+            else:
+                for i in range(len(predictions)):
+                    multiround_predictions = extract_dicts(predictions[i])
+                    for j in range(len(multiround_predictions)):
+                        key = 'prediction' if i == 0 else f'prediction{i}'
+                        key += '_r' + str(j + 1)
+                        pred_dict[key] = multiround_predictions[j]
             if judgements:
                 raise NotImplementedError(
                     'Not applied meta-reivew judge on multi-round dataset')
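An illustrative standalone sketch of how the two branches shape pred_dict for one model and a two-round dialogue. It is simplified: the real else-branch reshapes the rounds through extract_dicts, whereas here the per-round dicts are used directly.

# One model, two rounds, following the format described in the comment above.
predictions = [[
    {'round': 1, 'user': 'q1', 'assistant': 'a1'},
    {'round': 2, 'user': 'q2', 'assistant': 'a2'},
]]

pred_dict = {}
pack_all_predictions = True  # as set in the multiround dataset config above
if pack_all_predictions:
    # one entry per model, holding the whole dialogue
    for i, pred in enumerate(predictions):
        key = 'prediction' if i == 0 else f'prediction{i + 1}'
        pred_dict[key] = pred
else:
    # one entry per model per round: prediction_r1, prediction_r2, ...
    for i, pred in enumerate(predictions):
        for j, round_pred in enumerate(pred):
            key = ('prediction' if i == 0 else f'prediction{i}') + f'_r{j + 1}'
            pred_dict[key] = round_pred

print(sorted(pred_dict))  # ['prediction'] when packed; ['prediction_r1', 'prediction_r2'] otherwise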
@@ -172,8 +172,6 @@ class ChatInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = 1,
-            temperature: Optional[float] = 0.0,
-            do_sample: Optional[bool] = False,
             infer_mode: str = 'last',
             max_out_len: int = 512,
             **kwargs) -> None:
@@ -185,8 +183,6 @@ class ChatInferencer(BaseInferencer):
         )
         assert infer_mode in ['last', 'every', 'every_with_gt']
         self.infer_mode = infer_mode
-        self.temperature = temperature
-        self.do_sample = do_sample
         self.model: BaseModel
         self._set_meta_template(self.model)
 
@@ -353,16 +349,8 @@ class ChatInferencer(BaseInferencer):
 
         for i in assistant_indices:
             history = chat[:i]
-            if self.do_sample:
-                output = self.model.generate_from_template(
-                    [history],
-                    do_sample=self.do_sample,
-                    temperature=self.temperature,
-                    max_out_len=self.max_out_len)[0]
-            else:
-                output = self.model.generate_from_template(
-                    [history], do_sample=False,
-                    max_out_len=self.max_out_len)[0]
+            output = self.model.generate_from_template(
+                [history], max_out_len=self.max_out_len)[0]
             chat[i]['content'] = output
             if not self.dialogue_mode:
                 output_handler.save_multiround_results(
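With do_sample and temperature removed from ChatInferencer, sampling behaviour is no longer set at the inferencer level; a plausible place to express it instead is the model config, as a sketch only (the new config in this commit sets just do_sample=True; the temperature and top_p values below are hypothetical illustrations, not part of the commit):

from opencompass.models import HuggingFaceCausalLM

models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='qwen1.5-7b-chat-hf',
        path="Qwen/Qwen1.5-7B-Chat",
        # passed through to HuggingFace generate(); do_sample=True matches the
        # new config above, temperature/top_p are illustrative only
        generation_kwargs=dict(do_sample=True, temperature=0.7, top_p=0.9),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]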
@@ -128,7 +128,8 @@ class MultiroundSummarizer:
         self.eval_model_abbrs = [
             model_abbr_from_cfg(model) for model in self.eval_model_cfgs
         ]
-        self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_model'])
+        self.judge_abbr = model_abbr_from_cfg(
+            self.cfg['eval']['partitioner']['judge_models'][0])
 
     def summarize(self,
                   time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
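A standalone sketch of the new lookup path: the summarizer now reads the judge from the eval partitioner's judge_models list, matching the eval block in the new config, rather than a top-level judge_model entry (the abbr access below is simplified; the real code resolves it through model_abbr_from_cfg).

cfg = dict(
    eval=dict(
        partitioner=dict(
            type='SubjectiveSizePartitioner',
            judge_models=[dict(abbr='GPT4-Turbo', path='gpt-4-1106-preview')],
        ),
    ),
)

# what self.cfg['eval']['partitioner']['judge_models'][0] resolves to
judge_cfg = cfg['eval']['partitioner']['judge_models'][0]
print(judge_cfg['abbr'])  # GPT4-Turbo, reported by the summarizer as judge_abbr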