mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[FIX] fix interntrain get_loglikelihood (#1584)
This commit is contained in:
parent
89abcba486
commit
4d6349dfe1
@ -288,7 +288,7 @@ class InternTrain(BaseModel):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown model dtype {model_dtype}')
|
raise NotImplementedError(f'Unknown model dtype {model_dtype}')
|
||||||
|
|
||||||
def get_token_len(self, prompt: str) -> int:
|
def get_token_len(self, prompt: str, use_bos=None, use_eos=None) -> int:
|
||||||
"""Get lengths of the tokenized strings.
|
"""Get lengths of the tokenized strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -297,7 +297,7 @@ class InternTrain(BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
int: Length of the input tokens
|
int: Length of the input tokens
|
||||||
"""
|
"""
|
||||||
tokens = self.tokenizer(prompt, use_bos=True, use_eos=True)
|
tokens = self.tokenizer(prompt, use_bos=use_bos, use_eos=use_eos)
|
||||||
return len(tokens)
|
return len(tokens)
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
@ -391,7 +391,7 @@ class InternTrain(BaseModel):
|
|||||||
for input_text, cont in zip(input_texts, conts)
|
for input_text, cont in zip(input_texts, conts)
|
||||||
]
|
]
|
||||||
replaced_lens = [
|
replaced_lens = [
|
||||||
len(self.encode(input_text)[0]) for input_text in replaced_texts
|
self.get_token_len(input_text) for input_text in replaced_texts
|
||||||
]
|
]
|
||||||
loglikelihoods = []
|
loglikelihoods = []
|
||||||
for nloss, nlen, rlen in zip(loss, lens, replaced_lens):
|
for nloss, nlen, rlen in zip(loss, lens, replaced_lens):
|
||||||
|
Loading…
Reference in New Issue
Block a user