diff --git a/opencompass/__init__.py b/opencompass/__init__.py index 020ed73d..d93b5b24 100644 --- a/opencompass/__init__.py +++ b/opencompass/__init__.py @@ -1 +1 @@ -__version__ = '0.2.2' +__version__ = '0.2.3' diff --git a/opencompass/openicl/icl_inferencer/icl_chat_inferencer.py b/opencompass/openicl/icl_inferencer/icl_chat_inferencer.py index ee28b140..544aaf85 100644 --- a/opencompass/openicl/icl_inferencer/icl_chat_inferencer.py +++ b/opencompass/openicl/icl_inferencer/icl_chat_inferencer.py @@ -175,6 +175,7 @@ class ChatInferencer(BaseInferencer): temperature: Optional[float] = 0.0, do_sample: Optional[bool] = False, infer_mode: str = 'last', + max_out_len: int = 512, **kwargs) -> None: super().__init__( model=model, @@ -193,6 +194,7 @@ class ChatInferencer(BaseInferencer): save_every = 1 self.save_every = save_every self.dialogue_mode = False + self.max_out_len = max_out_len def _set_meta_template(self, model): origin = model.template_parser @@ -334,8 +336,8 @@ class ChatInferencer(BaseInferencer): ] history = chat[:assistant_indices[-1]] - output = self.model.generate_from_template([history], - max_out_len=512)[0] + output = self.model.generate_from_template( + [history], max_out_len=self.max_out_len)[0] output_handler.save_results( origin_prompt=history, prediction=output, @@ -356,11 +358,11 @@ class ChatInferencer(BaseInferencer): [history], do_sample=self.do_sample, temperature=self.temperature, - max_out_len=512)[0] + max_out_len=self.max_out_len)[0] else: - output = self.model.generate_from_template([history], - do_sample=False, - max_out_len=512)[0] + output = self.model.generate_from_template( + [history], do_sample=False, + max_out_len=self.max_out_len)[0] chat[i]['content'] = output if not self.dialogue_mode: output_handler.save_multiround_results( @@ -397,8 +399,8 @@ class ChatInferencer(BaseInferencer): for i in assistant_indices: history = chat[:i] - output = self.model.generate_from_template([history], - max_out_len=512)[0] + output = self.model.generate_from_template( + [history], max_out_len=self.max_out_len)[0] output_handler.save_multiround_results( origin_prompt=history[-1]['content'], prediction=output,