mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[FIX] fix interntrain get_loglikelihood (#1584)
This commit is contained in:
parent
89abcba486
commit
4d6349dfe1
@ -288,7 +288,7 @@ class InternTrain(BaseModel):
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown model dtype {model_dtype}')
|
||||
|
||||
def get_token_len(self, prompt: str) -> int:
|
||||
def get_token_len(self, prompt: str, use_bos=None, use_eos=None) -> int:
|
||||
"""Get lengths of the tokenized strings.
|
||||
|
||||
Args:
|
||||
@ -297,7 +297,7 @@ class InternTrain(BaseModel):
|
||||
Returns:
|
||||
int: Length of the input tokens
|
||||
"""
|
||||
tokens = self.tokenizer(prompt, use_bos=True, use_eos=True)
|
||||
tokens = self.tokenizer(prompt, use_bos=use_bos, use_eos=use_eos)
|
||||
return len(tokens)
|
||||
|
||||
def generate(self,
|
||||
@ -391,7 +391,7 @@ class InternTrain(BaseModel):
|
||||
for input_text, cont in zip(input_texts, conts)
|
||||
]
|
||||
replaced_lens = [
|
||||
len(self.encode(input_text)[0]) for input_text in replaced_texts
|
||||
self.get_token_len(input_text) for input_text in replaced_texts
|
||||
]
|
||||
loglikelihoods = []
|
||||
for nloss, nlen, rlen in zip(loss, lens, replaced_lens):
|
||||
|
Loading…
Reference in New Issue
Block a user