[FIX] fix interntrain get_loglikelihood (#1584)

This commit is contained in:
x54-729 2024-10-08 11:34:04 +08:00 committed by GitHub
parent 89abcba486
commit 4d6349dfe1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -288,7 +288,7 @@ class InternTrain(BaseModel):
else: else:
raise NotImplementedError(f'Unknown model dtype {model_dtype}') raise NotImplementedError(f'Unknown model dtype {model_dtype}')
def get_token_len(self, prompt: str) -> int: def get_token_len(self, prompt: str, use_bos=None, use_eos=None) -> int:
"""Get lengths of the tokenized strings. """Get lengths of the tokenized strings.
Args: Args:
@@ -297,7 +297,7 @@ class InternTrain(BaseModel):
Returns: Returns:
int: Length of the input tokens int: Length of the input tokens
""" """
tokens = self.tokenizer(prompt, use_bos=True, use_eos=True) tokens = self.tokenizer(prompt, use_bos=use_bos, use_eos=use_eos)
return len(tokens) return len(tokens)
def generate(self, def generate(self,
@@ -391,7 +391,7 @@ class InternTrain(BaseModel):
for input_text, cont in zip(input_texts, conts) for input_text, cont in zip(input_texts, conts)
] ]
replaced_lens = [ replaced_lens = [
len(self.encode(input_text)[0]) for input_text in replaced_texts self.get_token_len(input_text) for input_text in replaced_texts
] ]
loglikelihoods = [] loglikelihoods = []
for nloss, nlen, rlen in zip(loss, lens, replaced_lens): for nloss, nlen, rlen in zip(loss, lens, replaced_lens):