diff --git a/opencompass/models/llama2.py b/opencompass/models/llama2.py index b452e684..9971ece3 100644 --- a/opencompass/models/llama2.py +++ b/opencompass/models/llama2.py @@ -59,12 +59,18 @@ class Llama2(BaseModel): self.tokenizer = Tokenizer(tokenizer_path) def generate(self, inputs: List[str], max_out_len: int) -> List[str]: - out = self.generator.text_completion( - inputs, - temperature=0, + prompt_tokens = [] + for input in inputs: + tokens = self.tokenizer.encode(input, True, False) + num_token = min(self.model.params.max_seq_len, len(tokens)) + prompt_tokens.append(tokens[-num_token:]) + generation_tokens, _ = self.generator.generate( + prompt_tokens=prompt_tokens, max_gen_len=max_out_len, + temperature=0, ) - return [i['generation'] for i in out] + results = [self.tokenizer.decode(t) for t in generation_tokens] + return results def get_ppl(self, inputs: List[str], @@ -183,8 +189,8 @@ class Llama2Chat(BaseModel): ) return [r['generation']['content'] for r in results] except AssertionError: - self.warning('Batched data max token limit exceeded, ' - 'try to run one by one...') + self.logger.warning('Batched data max token limit exceeded, ' + 'try to run one by one...') results = [] for dialog in dialogs: