Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] *_batch_generate* function, add the MultiTokenEOSCriteria (#772)
jiangjin1999: add the MultiTokenEOSCriteria feature to the _batch_generate function to speed up inference.
Co-authored-by: jiangjin08 <jiangjin08@MBP-2F32S5MD6P-0029.local>
Co-authored-by: jiangjin08 <jiangjin08@a.sh.vip.dianping.com>
Parent: f78fcf6eeb
Commit: 8194199d79
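The diff below relies on a MultiTokenEOSCriteria class whose definition is not part of this excerpt. As background, here is a minimal sketch of what such a criterion typically looks like, following the common lm-evaluation-harness pattern: decode the last few generated tokens of every sequence in the batch and stop once all sequences contain one of the stop strings. The class actually defined in opencompass/models/huggingface.py may differ in detail.

# Sketch only: a batch-aware stop-string criterion in the style of
# lm-evaluation-harness. Names and details are assumptions, not the exact
# opencompass implementation.
import torch
import transformers


class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Stop generation once every sequence in the batch contains `sequence`."""

    def __init__(self, sequence: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 batch_size: int):
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids)
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        # Look back over the last len(sequence_ids) tokens of each sequence
        # and check whether the decoded text contains the stop string.
        lookback_ids_batch = input_ids[:, -self.sequence_id_len:]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        # Stop only when every sequence in the batch has hit a stop string.
        return False not in self.done_tracker

Passing the batch size (tokens['input_ids'].shape[0] in the third hunk below) lets the criterion track completion per sequence, so one finished sequence does not end generation for the whole batch prematurely.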
@@ -241,6 +241,7 @@ class HuggingFace(BaseModel):
         if self.batch_padding and len(inputs) > 1:
             return self._batch_generate(inputs=inputs,
                                         max_out_len=max_out_len,
+                                        stopping_criteria=stopping_criteria,
                                         **generation_kwargs)
         else:
             return sum(
@@ -250,7 +251,9 @@ class HuggingFace(BaseModel):
                     **generation_kwargs)
                 for input_ in inputs), [])
 
-    def _batch_generate(self, inputs: List[str], max_out_len: int,
+    def _batch_generate(self, inputs: List[str],
+                        max_out_len: int,
+                        stopping_criteria: List[str] = [],
                         **kwargs) -> List[str]:
         """Support for batch prompts inference.
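For context, a sketch of how the threaded-through parameter might be used from the public generate() path shown in the first hunk. This is an assumed call site: the full generate() signature and constructor arguments are not visible in this excerpt, but the hunks indicate that generate() accepts a list of stop strings, and the model name, prompts, and stop strings here are illustrative only.

# Hypothetical caller; model path, prompts, and stop strings are examples.
from opencompass.models import HuggingFace

model = HuggingFace(path='facebook/opt-125m', batch_padding=True)
outputs = model.generate(
    inputs=['Question: 1+1=?\nAnswer:', 'Question: 2+2=?\nAnswer:'],
    max_out_len=32,
    stopping_criteria=['\n\n', 'Question:'],  # stop strings, not token ids
)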
@@ -289,6 +292,19 @@ class HuggingFace(BaseModel):
             for k in tokens if k in ['input_ids', 'attention_mask']
         }
 
+        if stopping_criteria:
+            # Construct huggingface stopping criteria
+            if self.tokenizer.eos_token is not None:
+                stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
+            stopping_criteria = transformers.StoppingCriteriaList([
+                *[
+                    MultiTokenEOSCriteria(sequence, self.tokenizer,
+                                          tokens['input_ids'].shape[0])
+                    for sequence in stopping_criteria
+                ],
+            ])
+            kwargs['stopping_criteria'] = stopping_criteria
+
         # step-2: conduct model forward to generate output
         outputs = self.model.generate(**tokens,
                                       max_new_tokens=max_out_len,
@@ -359,7 +375,8 @@ class HuggingFace(BaseModel):
 
         if stopping_criteria:
             # Construct huggingface stopping criteria
-            stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
+            if self.tokenizer.eos_token is not None:
+                stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
             stopping_criteria = transformers.StoppingCriteriaList([
                 *[
                     MultiTokenEOSCriteria(sequence, self.tokenizer,
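The last hunk also guards the eos_token append in the existing non-batch path: some tokenizers are configured without an eos_token, and appending None to the stop-string list would break criterion construction when the string is encoded. Putting the pieces together, here is a standalone sketch of the mechanism the diff wires into self.model.generate: build a transformers.StoppingCriteriaList from stop strings (with the eos_token guard this commit adds) and pass it to generate(). It reuses the MultiTokenEOSCriteria sketch above; the model name, prompt, and stop strings are illustrative assumptions.

# Standalone sketch of the underlying mechanism; 'gpt2' and the prompt are
# placeholders, and MultiTokenEOSCriteria refers to the sketch shown earlier.
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

stop_strings = ['\n\n']
if tokenizer.eos_token is not None:  # the guard added by this commit
    stop_strings = stop_strings + [tokenizer.eos_token]

tokens = tokenizer(['Q: 1+1=?\nA:'], return_tensors='pt')
criteria = StoppingCriteriaList([
    MultiTokenEOSCriteria(s, tokenizer, tokens['input_ids'].shape[0])
    for s in stop_strings
])
outputs = model.generate(**tokens,
                         max_new_tokens=32,
                         stopping_criteria=criteria)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))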