[Fix] Fix gen inferencer (#615)

This commit is contained in:
Fengzhe Zhou 2023-11-22 12:04:31 +08:00 committed by GitHub
parent 721a45c68f
commit fb30b7c7a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -130,8 +130,8 @@ class GenInferencer(BaseInferencer):
entry, max_out_len=self.max_out_len)
generated = results
num_return_sequences = self.model.get('generation_kwargs', {}).get(
'num_return_sequences', 1)
num_return_sequences = getattr(self.model, 'generation_kwargs',
{}).get('num_return_sequences', 1)
# 5-3. Save current output
for prompt, prediction, gold in zip(
parsed_entries, batched(generated, num_return_sequences),