mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
add vllm get_ppl (#1003)
* add vllm get_ppl * add vllm get_ppl * format --------- Co-authored-by: xingjin.wang <xingjin.wang@mihoyo.com> Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
3a232db471
commit
048d41a1c4
@ -1,5 +1,7 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from opencompass.models.base import BaseModel
|
||||
from opencompass.utils import get_logger
|
||||
|
||||
@ -103,6 +105,29 @@ class VLLM(BaseModel):
|
||||
|
||||
return output_strs
|
||||
|
||||
def get_ppl(self,
|
||||
inputs: List[str],
|
||||
mask_length: Optional[List[int]] = None) -> List[float]:
|
||||
batch_size = len(inputs)
|
||||
sampling_kwargs = SamplingParams(prompt_logprobs=0,
|
||||
**self.generation_kwargs)
|
||||
# forward
|
||||
outputs = self.model.generate(inputs, sampling_kwargs)
|
||||
# compute ppl
|
||||
ce_loss = []
|
||||
for i in range(batch_size):
|
||||
prompt_logprobs = outputs[i].prompt_logprobs[1:]
|
||||
prompt_token_ids = outputs[i].prompt_token_ids[1:]
|
||||
prompt_logprobs_list = [
|
||||
prompt_logprobs[i][prompt_token_ids[i]]
|
||||
for i in range(len(prompt_logprobs))
|
||||
]
|
||||
prompt_logprobs_list = [i.logprob for i in prompt_logprobs_list]
|
||||
prompt_logprobs_list = np.array(prompt_logprobs_list)
|
||||
loss = -prompt_logprobs_list.sum(axis=-1) / len(prompt_token_ids)
|
||||
ce_loss.append(loss)
|
||||
return np.array(ce_loss)
|
||||
|
||||
def prompts_preproccess(self, inputs: List[str]):
|
||||
if self.use_fastchat_template:
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user