mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] *_batch_generate* function, add the MultiTokenEOSCriteria (#772)
* jiangjin1999: in the _batch_generate function, add the MultiTokenEOSCriteria feature to speed up inference. * jiangjin1999: in the _batch_generate function, add the MultiTokenEOSCriteria feature to speed up inference. --------- Co-authored-by: jiangjin08 <jiangjin08@MBP-2F32S5MD6P-0029.local> Co-authored-by: jiangjin08 <jiangjin08@a.sh.vip.dianping.com>
This commit is contained in:
parent
f78fcf6eeb
commit
8194199d79
@ -241,6 +241,7 @@ class HuggingFace(BaseModel):
|
||||
if self.batch_padding and len(inputs) > 1:
|
||||
return self._batch_generate(inputs=inputs,
|
||||
max_out_len=max_out_len,
|
||||
stopping_criteria=stopping_criteria,
|
||||
**generation_kwargs)
|
||||
else:
|
||||
return sum(
|
||||
@ -250,7 +251,9 @@ class HuggingFace(BaseModel):
|
||||
**generation_kwargs)
|
||||
for input_ in inputs), [])
|
||||
|
||||
def _batch_generate(self, inputs: List[str], max_out_len: int,
|
||||
def _batch_generate(self, inputs: List[str],
|
||||
max_out_len: int,
|
||||
stopping_criteria: List[str] = [],
|
||||
**kwargs) -> List[str]:
|
||||
"""Support for batch prompts inference.
|
||||
|
||||
@ -289,6 +292,19 @@ class HuggingFace(BaseModel):
|
||||
for k in tokens if k in ['input_ids', 'attention_mask']
|
||||
}
|
||||
|
||||
if stopping_criteria:
|
||||
# Construct huggingface stopping criteria
|
||||
if self.tokenizer.eos_token is not None:
|
||||
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
|
||||
stopping_criteria = transformers.StoppingCriteriaList([
|
||||
*[
|
||||
MultiTokenEOSCriteria(sequence, self.tokenizer,
|
||||
tokens['input_ids'].shape[0])
|
||||
for sequence in stopping_criteria
|
||||
],
|
||||
])
|
||||
kwargs['stopping_criteria'] = stopping_criteria
|
||||
|
||||
# step-2: conduct model forward to generate output
|
||||
outputs = self.model.generate(**tokens,
|
||||
max_new_tokens=max_out_len,
|
||||
@ -359,7 +375,8 @@ class HuggingFace(BaseModel):
|
||||
|
||||
if stopping_criteria:
|
||||
# Construct huggingface stopping criteria
|
||||
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
|
||||
if self.tokenizer.eos_token is not None:
|
||||
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
|
||||
stopping_criteria = transformers.StoppingCriteriaList([
|
||||
*[
|
||||
MultiTokenEOSCriteria(sequence, self.tokenizer,
|
||||
|
Loading…
Reference in New Issue
Block a user