mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Update llama2 (#372)
This commit is contained in:
parent
3871188c89
commit
49c467458f
@ -59,12 +59,18 @@ class Llama2(BaseModel):
|
||||
self.tokenizer = Tokenizer(tokenizer_path)
|
||||
|
||||
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
|
||||
out = self.generator.text_completion(
|
||||
inputs,
|
||||
temperature=0,
|
||||
prompt_tokens = []
|
||||
for input in inputs:
|
||||
tokens = self.tokenizer.encode(input, True, False)
|
||||
num_token = min(self.model.params.max_seq_len, len(tokens))
|
||||
prompt_tokens.append(tokens[-num_token:])
|
||||
generation_tokens, _ = self.generator.generate(
|
||||
prompt_tokens=prompt_tokens,
|
||||
max_gen_len=max_out_len,
|
||||
temperature=0,
|
||||
)
|
||||
return [i['generation'] for i in out]
|
||||
results = [self.tokenizer.decode(t) for t in generation_tokens]
|
||||
return results
|
||||
|
||||
def get_ppl(self,
|
||||
inputs: List[str],
|
||||
@ -183,8 +189,8 @@ class Llama2Chat(BaseModel):
|
||||
)
|
||||
return [r['generation']['content'] for r in results]
|
||||
except AssertionError:
|
||||
self.warning('Batched data max token limit exceeded, '
|
||||
'try to run one by one...')
|
||||
self.logger.warning('Batched data max token limit exceeded, '
|
||||
'try to run one by one...')
|
||||
|
||||
results = []
|
||||
for dialog in dialogs:
|
||||
|
Loading…
Reference in New Issue
Block a user