[Feature] Update llama2 (#372)

Author: Leymore
Date: 2023-09-08 12:47:56 +08:00 (committed by GitHub)
Parent: 3871188c89
Commit: 49c467458f

@@ -59,12 +59,18 @@ class Llama2(BaseModel):
         self.tokenizer = Tokenizer(tokenizer_path)
 
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
-        out = self.generator.text_completion(
-            inputs,
-            temperature=0,
+        prompt_tokens = []
+        for input in inputs:
+            tokens = self.tokenizer.encode(input, True, False)
+            num_token = min(self.model.params.max_seq_len, len(tokens))
+            prompt_tokens.append(tokens[-num_token:])
+        generation_tokens, _ = self.generator.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_out_len,
+            temperature=0,
         )
-        return [i['generation'] for i in out]
+        results = [self.tokenizer.decode(t) for t in generation_tokens]
+        return results
 
     def get_ppl(self,
                 inputs: List[str],
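
Note: the core of this hunk is the prompt-truncation step. Each input is tokenized, only the last max_seq_len tokens are kept, and the token-level generate API is called with max_out_len now passed through as max_gen_len. A minimal standalone sketch of the truncation logic (plain Python; token lists stand in for the real llama Tokenizer output, so names here are illustrative, not the project's API):

from typing import List

def truncate_prompts(tokenized: List[List[int]], max_seq_len: int) -> List[List[int]]:
    """Keep only the trailing `max_seq_len` tokens of each prompt, mirroring
    the loop added in Llama2.generate: over-long prompts are cut from the
    left so the most recent context is preserved."""
    truncated = []
    for tokens in tokenized:
        num_token = min(max_seq_len, len(tokens))
        truncated.append(tokens[-num_token:])
    return truncated

# Example: with a 4-token context window, a 6-token prompt keeps its last 4 tokens.
assert truncate_prompts([[1, 2, 3, 4, 5, 6]], max_seq_len=4) == [[3, 4, 5, 6]]
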
@@ -183,8 +189,8 @@ class Llama2Chat(BaseModel):
             )
             return [r['generation']['content'] for r in results]
         except AssertionError:
-            self.warning('Batched data max token limit exceeded, '
-                         'try to run one by one...')
+            self.logger.warning('Batched data max token limit exceeded, '
+                                'try to run one by one...')
         results = []
         for dialog in dialogs:
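
The second hunk only fixes the logging call on the fallback path (presumably self.warning is not defined on the wrapper, hence the switch to self.logger.warning): when the batched chat_completion call trips the max-token assertion, a warning is logged and the dialogs are retried one by one. A sketch of that try-batch-then-fall-back pattern, with run_batch/run_single as hypothetical stand-ins for the real generator calls:

import logging
from typing import Callable, List

logger = logging.getLogger(__name__)

def generate_with_fallback(dialogs: List[dict],
                           run_batch: Callable[[List[dict]], List[str]],
                           run_single: Callable[[dict], str]) -> List[str]:
    """Try one batched call first; if it raises AssertionError (token limit),
    log a warning and retry each dialog individually."""
    try:
        return run_batch(dialogs)
    except AssertionError:
        logger.warning('Batched data max token limit exceeded, '
                       'try to run one by one...')
    return [run_single(d) for d in dialogs]
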