From fcab30f82ea4daec22dfd1db0dd55e7e06440629 Mon Sep 17 00:00:00 2001 From: Hubert <42952108+yingfhu@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:00:25 +0800 Subject: [PATCH] [Fix] change save_every defaults to 1 (#592) --- opencompass/openicl/icl_inferencer/icl_attack_inferencer.py | 4 ++-- opencompass/openicl/icl_inferencer/icl_gen_inferencer.py | 4 ++-- opencompass/openicl/icl_inferencer/icl_sc_inferencer.py | 4 ++-- opencompass/openicl/icl_inferencer/icl_tot_inferencer.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py b/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py index f8d8ea04..39b84560 100644 --- a/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py @@ -42,7 +42,7 @@ class AttackInferencer(BaseInferencer): gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts. save_every (:obj:`int`, optional): Save intermediate results every - `save_every` epochs. + `save_every` iters. Defaults to 1. generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method. """ @@ -58,7 +58,7 @@ class AttackInferencer(BaseInferencer): gen_field_replace_token: Optional[str] = '', output_json_filepath: Optional[str] = './icl_inference_output', output_json_filename: Optional[str] = 'predictions', - save_every: Optional[int] = None, + save_every: Optional[int] = 1, dataset_cfg: Optional[List[int]] = None, **kwargs) -> None: super().__init__( diff --git a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py index 0398e2c6..5cbd637f 100644 --- a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py @@ -36,7 +36,7 @@ class GenInferencer(BaseInferencer): gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts. save_every (:obj:`int`, optional): Save intermediate results every - `save_every` epochs. + `save_every` iters. Defaults to 1. generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method. """ @@ -50,7 +50,7 @@ class GenInferencer(BaseInferencer): gen_field_replace_token: Optional[str] = '', output_json_filepath: Optional[str] = './icl_inference_output', output_json_filename: Optional[str] = 'predictions', - save_every: Optional[int] = None, + save_every: Optional[int] = 1, **kwargs) -> None: super().__init__( model=model, diff --git a/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py b/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py index dbbd41c9..0544c9b1 100644 --- a/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py @@ -34,7 +34,7 @@ class SCInferencer(BaseInferencer): gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts. save_every (:obj:`int`, optional): Save intermediate results every - `save_every` epochs. + `save_every` iters. Defaults to 1. generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method. sc_size (:obj:`int`, optional): Sample size for Self-Consistency @@ -51,7 +51,7 @@ class SCInferencer(BaseInferencer): gen_field_replace_token: Optional[str] = '', output_json_filepath: Optional[str] = './icl_inference_output', output_json_filename: Optional[str] = 'predictions', - save_every: Optional[int] = None, + save_every: Optional[int] = 1, sc_size: Optional[int] = 1, infer_type: Optional[str] = '', generation_kwargs: dict = {}, diff --git a/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py b/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py index 22a2298e..939a2066 100644 --- a/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py @@ -43,7 +43,7 @@ class ToTInferencer(GenInferencer): gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts. save_every (:obj:`int`, optional): Save intermediate results every - `save_every` epochs. + `save_every` iters. Defaults to 1. generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method. naive_run (:obj:`bool`): if True, run naive IO/CoT sampling instead of @@ -74,7 +74,7 @@ class ToTInferencer(GenInferencer): gen_field_replace_token: Optional[str] = '', output_json_filepath: Optional[str] = './icl_inference_output', output_json_filename: Optional[str] = 'predictions', - save_every: Optional[int] = None, + save_every: Optional[int] = 1, naive_run: bool = False, prompt_wrapper: dict = {}, prompt_sample: str = 'standard',