From 2450a22f1848554ecdc612908df70ecc0c8a357a Mon Sep 17 00:00:00 2001
From: xusong28
Date: Fri, 12 Jul 2024 15:11:34 +0800
Subject: [PATCH] add api_prompts to GenInferencerOutput

---
 opencompass/models/huggingface_above_v4_33.py      |  3 ++-
 opencompass/models/vllm_with_tf_above_v4_33.py     |  2 +-
 .../openicl/icl_inferencer/icl_base_inferencer.py  | 14 ++++++++------
 .../openicl/icl_inferencer/icl_gen_inferencer.py   | 11 +++++++++--
 4 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py
index 329ea2a3..6e7265d9 100644
--- a/opencompass/models/huggingface_above_v4_33.py
+++ b/opencompass/models/huggingface_above_v4_33.py
@@ -270,6 +270,7 @@ class HuggingFacewithChatTemplate(BaseModel):
             tokenize_kwargs['add_special_tokens'] = False
             tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)
 
+        prompt_list = messages
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}
 
         generation_kwargs = self.generation_kwargs.copy()
@@ -292,7 +293,7 @@ class HuggingFacewithChatTemplate(BaseModel):
         for stop in stopping_criteria:
             decodeds = [t.split(stop)[0] for t in decodeds]
 
-        return decodeds
+        return prompt_list, decodeds
 
     def get_token_len(self, prompt: str) -> int:
         m = _convert_chat_messages([prompt])[0]
diff --git a/opencompass/models/vllm_with_tf_above_v4_33.py b/opencompass/models/vllm_with_tf_above_v4_33.py
index cf79ea6f..8196a142 100644
--- a/opencompass/models/vllm_with_tf_above_v4_33.py
+++ b/opencompass/models/vllm_with_tf_above_v4_33.py
@@ -118,7 +118,7 @@ class VLLMwithChatTemplate(BaseModel):
             prompt_list.append(prompt)
             output_strs.append(generated_text)
 
-        return output_strs
+        return prompt_list, output_strs
 
     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.
diff --git a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
index b08a6fab..2ac4423c 100644
--- a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
@@ -111,13 +111,15 @@ class GenInferencerOutputHandler:
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)
 
-    def save_results(self, origin_prompt, prediction, idx, gold=None):
-        self.results_dict[str(idx)] = {
-            'origin_prompt': origin_prompt,
-            'prediction': prediction,
-        }
+    def save_results(self, origin_prompt, prediction, idx, api_prompts=None, gold=None):
+        results = {}
+        if api_prompts:
+            results['api_prompts'] = api_prompts
+        results['origin_prompt'] = origin_prompt
+        results['prediction'] = prediction
         if gold:
-            self.results_dict[str(idx)]['gold'] = gold
+            results['gold'] = gold
+        self.results_dict[str(idx)] = results
 
 
 class PPLInferencerOutputHandler:
diff --git a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
index 17bdf468..63e10ad8 100644
--- a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
@@ -152,18 +152,25 @@ class GenInferencer(BaseInferencer):
                 results = self.model.generate_from_template(
                     entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
                 generated = results
+                if isinstance(generated, tuple):
+                    api_prompts_list = parsed_entries
+                    prompts, generated = generated
+                else:
+                    api_prompts_list = [None]*len(generated)
+                    prompts = parsed_entries
 
             num_return_sequences = getattr(self.model, 'generation_kwargs',
                                            {}).get('num_return_sequences', 1)
             # 5-3. Save current output
-            for prompt, prediction, gold in zip(
-                    parsed_entries, batched(generated, num_return_sequences),
+            for api_prompts, prompt, prediction, gold in zip(
+                    api_prompts_list, prompts, batched(generated, num_return_sequences),
                     golds):
                 if num_return_sequences == 1:
                     prediction = prediction[0]
                 output_handler.save_results(prompt,
                                             prediction,
                                             index,
+                                            api_prompts=api_prompts,
                                             gold=gold)
                 index = index + 1
 
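A minimal sketch of the results entry that `save_results` produces after this change, calling the handler directly. The prompt and answer strings are illustrative, and the chat-template text shown for `api_prompts` is an assumption about what a `HuggingFacewithChatTemplate`-style model returns as its prompt list:

    from opencompass.openicl.icl_inferencer.icl_base_inferencer import (
        GenInferencerOutputHandler)

    handler = GenInferencerOutputHandler()

    # With models whose generate() now returns (prompt_list, output_strs),
    # GenInferencer passes the model-rendered prompt via api_prompts in
    # addition to the parsed template entry.
    handler.save_results(
        origin_prompt='What is 2 + 2?',        # parsed template entry
        prediction='4',
        idx=0,
        api_prompts='<|user|>\nWhat is 2 + 2?\n<|assistant|>\n',  # hypothetical rendered prompt
        gold='4')

    # handler.results_dict['0'] is now:
    # {'api_prompts': '<|user|>\nWhat is 2 + 2?\n<|assistant|>\n',
    #  'origin_prompt': 'What is 2 + 2?',
    #  'prediction': '4',
    #  'gold': '4'}

When `api_prompts` is not passed (models that still return a plain list of strings), the entry keeps its previous shape with only `origin_prompt`, `prediction`, and optionally `gold`.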