Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00
[Sync] Bump version 0.2.3 (#957)
This commit is contained in:
parent 64fde73b15
commit ab6cdb2be8
@@ -1 +1 @@
-__version__ = '0.2.2'
+__version__ = '0.2.3'
@@ -175,6 +175,7 @@ class ChatInferencer(BaseInferencer):
                  temperature: Optional[float] = 0.0,
                  do_sample: Optional[bool] = False,
                  infer_mode: str = 'last',
+                 max_out_len: int = 512,
                  **kwargs) -> None:
         super().__init__(
             model=model,
@@ -193,6 +194,7 @@ class ChatInferencer(BaseInferencer):
             save_every = 1
         self.save_every = save_every
         self.dialogue_mode = False
+        self.max_out_len = max_out_len

     def _set_meta_template(self, model):
         origin = model.template_parser
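Taken together, the two hunks above thread a new max_out_len argument (default 512) through the ChatInferencer constructor and store it on the instance; the hunks below then read it at every generation call site. A minimal instantiation sketch of that constructor surface, assuming the usual opencompass.openicl.icl_inferencer export; the model object is a hypothetical stand-in and not part of this commit:

# Hypothetical usage sketch; only the argument names shown in the
# hunks above are taken from the commit itself.
from opencompass.openicl.icl_inferencer import ChatInferencer

inferencer = ChatInferencer(
    model=my_model,     # stand-in for any OpenCompass model wrapper
    temperature=0.0,
    do_sample=False,
    infer_mode='last',
    max_out_len=1024,   # new in 0.2.3; replaces the hardcoded 512
)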
@@ -334,8 +336,8 @@ class ChatInferencer(BaseInferencer):
         ]

         history = chat[:assistant_indices[-1]]
-        output = self.model.generate_from_template([history],
-                                                   max_out_len=512)[0]
+        output = self.model.generate_from_template(
+            [history], max_out_len=self.max_out_len)[0]
         output_handler.save_results(
             origin_prompt=history,
             prediction=output,
@@ -356,11 +358,11 @@ class ChatInferencer(BaseInferencer):
                     [history],
                     do_sample=self.do_sample,
                     temperature=self.temperature,
-                    max_out_len=512)[0]
+                    max_out_len=self.max_out_len)[0]
             else:
-                output = self.model.generate_from_template([history],
-                                                           do_sample=False,
-                                                           max_out_len=512)[0]
+                output = self.model.generate_from_template(
+                    [history], do_sample=False,
+                    max_out_len=self.max_out_len)[0]
             chat[i]['content'] = output
             if not self.dialogue_mode:
                 output_handler.save_multiround_results(
@@ -397,8 +399,8 @@ class ChatInferencer(BaseInferencer):

         for i in assistant_indices:
             history = chat[:i]
-            output = self.model.generate_from_template([history],
-                                                       max_out_len=512)[0]
+            output = self.model.generate_from_template(
+                [history], max_out_len=self.max_out_len)[0]
             output_handler.save_multiround_results(
                 origin_prompt=history[-1]['content'],
                 prediction=output,
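The practical effect of this bump: chat inference output is no longer capped at the hardcoded 512 tokens in any of the three generation paths patched above. A sketch of raising the cap from an mmengine-style eval config, assuming OpenCompass's usual inferencer-dict convention; the dataset and model entries are omitted since they are not part of this commit:

# Hypothetical infer_cfg fragment; only `type` and the arguments shown
# in the diff are grounded in the commit.
from opencompass.openicl.icl_inferencer import ChatInferencer

infer_cfg = dict(
    inferencer=dict(
        type=ChatInferencer,
        infer_mode='every',   # existing option from the signature above
        max_out_len=2048,     # new knob: per-turn generation budget
    ),
)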