From ea517e3377615aa48fb550ce73d070f799901ee0 Mon Sep 17 00:00:00 2001 From: xusong28 Date: Fri, 12 Jul 2024 16:07:24 +0800 Subject: [PATCH] fix yapf --- opencompass/openicl/icl_inferencer/icl_base_inferencer.py | 7 ++++++- opencompass/openicl/icl_inferencer/icl_gen_inferencer.py | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py index 2ac4423c..087c28a9 100644 --- a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py @@ -111,7 +111,12 @@ class GenInferencerOutputHandler: """Dump the result to a json file.""" dump_results_dict(self.results_dict, Path(save_dir) / filename) - def save_results(self, origin_prompt, prediction, idx, api_prompts=None, gold=None): + def save_results(self, + origin_prompt, + prediction, + idx, + api_prompts=None, + gold=None): results = {} if api_prompts: results['api_prompts'] = api_prompts diff --git a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py index 63e10ad8..8ef6edcc 100644 --- a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py @@ -156,15 +156,15 @@ class GenInferencer(BaseInferencer): api_prompts_list = parsed_entries prompts, generated = generated else: - api_prompts_list = [None]*len(generated) + api_prompts_list = [None] * len(generated) prompts = parsed_entries num_return_sequences = getattr(self.model, 'generation_kwargs', {}).get('num_return_sequences', 1) # 5-3. Save current output for api_prompts, prompt, prediction, gold in zip( - api_prompts_list, prompts, batched(generated, num_return_sequences), - golds): + api_prompts_list, prompts, + batched(generated, num_return_sequences), golds): if num_return_sequences == 1: prediction = prediction[0] output_handler.save_results(prompt,