[Fix] Fix bugs for PeftModel generate (#252)

* fix bugs

* fix typo
This commit is contained in:
LZHgrla 2023-08-24 14:07:33 +08:00 committed by GitHub
parent 2a5cef2914
commit 77745a84ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -203,7 +203,9 @@ class HuggingFace(BaseModel):
max_length=self.max_seq_len -
max_out_len)['input_ids']
input_ids = torch.tensor(input_ids, device=self.model.device)
outputs = self.model.generate(input_ids,
# To accommodate the PeftModel, parameters should be passed in
# key-value format for generate.
outputs = self.model.generate(input_ids=input_ids,
max_new_tokens=max_out_len,
**kwargs)