[Feature] *_batch_generate* function, add the MultiTokenEOSCriteria (#772)

* jiangjin1999: in the _batch_generate function, add the MultiTokenEOSCriteria feature to speed up inference.

* jiangjin1999: in the _batch_generate function, add the MultiTokenEOSCriteria feature to speed up inference.

---------

Co-authored-by: jiangjin08 <jiangjin08@MBP-2F32S5MD6P-0029.local>
Co-authored-by: jiangjin08 <jiangjin08@a.sh.vip.dianping.com>
This commit is contained in:
jiangjin1999 2024-01-08 16:40:02 +08:00 committed by GitHub
parent f78fcf6eeb
commit 8194199d79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -241,6 +241,7 @@ class HuggingFace(BaseModel):
if self.batch_padding and len(inputs) > 1:
return self._batch_generate(inputs=inputs,
max_out_len=max_out_len,
stopping_criteria=stopping_criteria,
**generation_kwargs)
else:
return sum(
@ -250,7 +251,9 @@ class HuggingFace(BaseModel):
**generation_kwargs)
for input_ in inputs), [])
def _batch_generate(self, inputs: List[str], max_out_len: int,
def _batch_generate(self, inputs: List[str],
max_out_len: int,
stopping_criteria: List[str] = [],
**kwargs) -> List[str]:
"""Support for batch prompts inference.
@ -289,6 +292,19 @@ class HuggingFace(BaseModel):
for k in tokens if k in ['input_ids', 'attention_mask']
}
if stopping_criteria:
# Construct huggingface stopping criteria
if self.tokenizer.eos_token is not None:
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
stopping_criteria = transformers.StoppingCriteriaList([
*[
MultiTokenEOSCriteria(sequence, self.tokenizer,
tokens['input_ids'].shape[0])
for sequence in stopping_criteria
],
])
kwargs['stopping_criteria'] = stopping_criteria
# step-2: conduct model forward to generate output
outputs = self.model.generate(**tokens,
max_new_tokens=max_out_len,
@ -359,7 +375,8 @@ class HuggingFace(BaseModel):
if stopping_criteria:
# Construct huggingface stopping criteria
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
if self.tokenizer.eos_token is not None:
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
stopping_criteria = transformers.StoppingCriteriaList([
*[
MultiTokenEOSCriteria(sequence, self.tokenizer,