From 49c467458fffff0f7016954cf684865ad32b725a Mon Sep 17 00:00:00 2001
From: Leymore
Date: Fri, 8 Sep 2023 12:47:56 +0800
Subject: [PATCH] [Feature] Update llama2 (#372)

---
 opencompass/models/llama2.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/opencompass/models/llama2.py b/opencompass/models/llama2.py
index b452e684..9971ece3 100644
--- a/opencompass/models/llama2.py
+++ b/opencompass/models/llama2.py
@@ -59,12 +59,18 @@ class Llama2(BaseModel):
         self.tokenizer = Tokenizer(tokenizer_path)
 
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
-        out = self.generator.text_completion(
-            inputs,
-            temperature=0,
+        prompt_tokens = []
+        for input in inputs:
+            tokens = self.tokenizer.encode(input, True, False)
+            num_token = min(self.model.params.max_seq_len, len(tokens))
+            prompt_tokens.append(tokens[-num_token:])
+        generation_tokens, _ = self.generator.generate(
+            prompt_tokens=prompt_tokens,
             max_gen_len=max_out_len,
+            temperature=0,
         )
-        return [i['generation'] for i in out]
+        results = [self.tokenizer.decode(t) for t in generation_tokens]
+        return results
 
     def get_ppl(self,
                 inputs: List[str],
@@ -183,8 +189,8 @@ class Llama2Chat(BaseModel):
             )
             return [r['generation']['content'] for r in results]
         except AssertionError:
-            self.warning('Batched data max token limit exceeded, '
-                         'try to run one by one...')
+            self.logger.warning('Batched data max token limit exceeded, '
+                                'try to run one by one...')
 
         results = []
         for dialog in dialogs:
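
Note (not part of the patch): below is a minimal standalone sketch of the left-truncation step the rewritten `generate` introduces, for anyone who wants to sanity-check the behaviour outside OpenCompass. `truncate_prompts` is a hypothetical helper written for this note; it assumes prompts have already been encoded to token-id lists (in the patch, via `self.tokenizer.encode(input, True, False)`, whose flags appear to be Llama's `bos`/`eos` arguments) and that over-long prompts should keep their most recent tokens, as the loop above does with `tokens[-num_token:]`.

    from typing import List


    def truncate_prompts(token_batches: List[List[int]],
                         max_seq_len: int) -> List[List[int]]:
        """Keep at most the last max_seq_len tokens of each prompt.

        Mirrors the loop added to Llama2.generate: over-long prompts are
        left-truncated so the tokens nearest the generation point survive.
        """
        truncated = []
        for tokens in token_batches:
            num_token = min(max_seq_len, len(tokens))
            truncated.append(tokens[-num_token:])
        return truncated


    if __name__ == '__main__':
        batch = [[1, 2, 3, 4, 5], [7, 8]]
        # The first prompt loses its two oldest tokens; the second fits as-is.
        print(truncate_prompts(batch, max_seq_len=3))  # [[3, 4, 5], [7, 8]]

Keeping the tail rather than the head matters for few-shot prompts: the question to be completed sits at the end of the prompt, so truncating from the right would cut off exactly the text the model is asked to continue.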