add vllm get_ppl (#1003)

* add vllm get_ppl

* add vllm get_ppl

* format

---------

Co-authored-by: xingjin.wang <xingjin.wang@mihoyo.com>
Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
Wang Xingjin 2024-04-26 21:31:56 +08:00 committed by GitHub
parent 3a232db471
commit 048d41a1c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,7 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
import numpy as np
from opencompass.models.base import BaseModel from opencompass.models.base import BaseModel
from opencompass.utils import get_logger from opencompass.utils import get_logger
@ -103,6 +105,29 @@ class VLLM(BaseModel):
return output_strs return output_strs
def get_ppl(self,
inputs: List[str],
mask_length: Optional[List[int]] = None) -> List[float]:
batch_size = len(inputs)
sampling_kwargs = SamplingParams(prompt_logprobs=0,
**self.generation_kwargs)
# forward
outputs = self.model.generate(inputs, sampling_kwargs)
# compute ppl
ce_loss = []
for i in range(batch_size):
prompt_logprobs = outputs[i].prompt_logprobs[1:]
prompt_token_ids = outputs[i].prompt_token_ids[1:]
prompt_logprobs_list = [
prompt_logprobs[i][prompt_token_ids[i]]
for i in range(len(prompt_logprobs))
]
prompt_logprobs_list = [i.logprob for i in prompt_logprobs_list]
prompt_logprobs_list = np.array(prompt_logprobs_list)
loss = -prompt_logprobs_list.sum(axis=-1) / len(prompt_token_ids)
ce_loss.append(loss)
return np.array(ce_loss)
def prompts_preproccess(self, inputs: List[str]): def prompts_preproccess(self, inputs: List[str]):
if self.use_fastchat_template: if self.use_fastchat_template:
try: try: