[Feature] Update llama2 (#372)

This commit is contained in:
Leymore 2023-09-08 12:47:56 +08:00 committed by GitHub
parent 3871188c89
commit 49c467458f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -59,12 +59,18 @@ class Llama2(BaseModel):
self.tokenizer = Tokenizer(tokenizer_path) self.tokenizer = Tokenizer(tokenizer_path)
def generate(self, inputs: List[str], max_out_len: int) -> List[str]: def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
out = self.generator.text_completion( prompt_tokens = []
inputs, for input in inputs:
temperature=0, tokens = self.tokenizer.encode(input, True, False)
num_token = min(self.model.params.max_seq_len, len(tokens))
prompt_tokens.append(tokens[-num_token:])
generation_tokens, _ = self.generator.generate(
prompt_tokens=prompt_tokens,
max_gen_len=max_out_len, max_gen_len=max_out_len,
temperature=0,
) )
return [i['generation'] for i in out] results = [self.tokenizer.decode(t) for t in generation_tokens]
return results
def get_ppl(self, def get_ppl(self,
inputs: List[str], inputs: List[str],
@ -183,8 +189,8 @@ class Llama2Chat(BaseModel):
) )
return [r['generation']['content'] for r in results] return [r['generation']['content'] for r in results]
except AssertionError: except AssertionError:
self.warning('Batched data max token limit exceeded, ' self.logger.warning('Batched data max token limit exceeded, '
'try to run one by one...') 'try to run one by one...')
results = [] results = []
for dialog in dialogs: for dialog in dialogs: