[Fix] Update loglikehood compatibility (#1659)

This commit is contained in:
Lyu Han 2024-11-02 17:19:11 +08:00 committed by GitHub
parent f7d899823c
commit 888f1f3bef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -189,6 +189,7 @@ class TurboMindModel(BaseModel):
assert isinstance(
inputs, List), f'List(str) is expected, but got {type(inputs)}'
results = []
if self.version_info <= (0, 6, 0):
for text, cont in zip(inputs, conts):
input_ids = self.tokenizer.encode(text)
res = self.pipe.get_ppl(input_ids)
@ -198,6 +199,16 @@ class TurboMindModel(BaseModel):
logit_part = res * len(input_ids)
results.append(-(logit_sum - logit_part))
results = np.concatenate(results)
else:
for text, cont in zip(inputs, conts):
input_ids = self.tokenizer.encode(text)
res = self.pipe.get_ppl(input_ids)
logit_sum = res * len(input_ids)
input_ids = self.tokenizer.encode(text.replace(cont, ''))
res = self.pipe.get_ppl(input_ids)
logit_part = res * len(input_ids)
results.append(-(logit_sum[0] - logit_part[0]))
results = np.array(results)
return results
def _build_pipe(self, model_path, backend, engine_config):