[Fix] Update loglikelihood compatibility (#1659)

This commit is contained in:
Lyu Han 2024-11-02 17:19:11 +08:00 committed by GitHub
parent f7d899823c
commit 888f1f3bef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -189,15 +189,26 @@ class TurboMindModel(BaseModel):
assert isinstance( assert isinstance(
inputs, List), f'List(str) is expected, but got {type(inputs)}' inputs, List), f'List(str) is expected, but got {type(inputs)}'
results = [] results = []
for text, cont in zip(inputs, conts): if self.version_info <= (0, 6, 0):
input_ids = self.tokenizer.encode(text) for text, cont in zip(inputs, conts):
res = self.pipe.get_ppl(input_ids) input_ids = self.tokenizer.encode(text)
logit_sum = res * len(input_ids) res = self.pipe.get_ppl(input_ids)
input_ids = self.tokenizer.encode(text.replace(cont, '')) logit_sum = res * len(input_ids)
res = self.pipe.get_ppl(input_ids) input_ids = self.tokenizer.encode(text.replace(cont, ''))
logit_part = res * len(input_ids) res = self.pipe.get_ppl(input_ids)
results.append(-(logit_sum - logit_part)) logit_part = res * len(input_ids)
results = np.concatenate(results) results.append(-(logit_sum - logit_part))
results = np.concatenate(results)
else:
for text, cont in zip(inputs, conts):
input_ids = self.tokenizer.encode(text)
res = self.pipe.get_ppl(input_ids)
logit_sum = res * len(input_ids)
input_ids = self.tokenizer.encode(text.replace(cont, ''))
res = self.pipe.get_ppl(input_ids)
logit_part = res * len(input_ids)
results.append(-(logit_sum[0] - logit_part[0]))
results = np.array(results)
return results return results
def _build_pipe(self, model_path, backend, engine_config): def _build_pipe(self, model_path, backend, engine_config):