mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
add vllm get_ppl (#1003)
* add vllm get_ppl * add vllm get_ppl * format --------- Co-authored-by: xingjin.wang <xingjin.wang@mihoyo.com> Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
3a232db471
commit
048d41a1c4
@ -1,5 +1,7 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from opencompass.models.base import BaseModel
|
from opencompass.models.base import BaseModel
|
||||||
from opencompass.utils import get_logger
|
from opencompass.utils import get_logger
|
||||||
|
|
||||||
@ -103,6 +105,29 @@ class VLLM(BaseModel):
|
|||||||
|
|
||||||
return output_strs
|
return output_strs
|
||||||
|
|
||||||
|
def get_ppl(self,
            inputs: List[str],
            mask_length: Optional[List[int]] = None) -> List[float]:
    """Score each prompt by the mean negative log-probability of its tokens.

    Every prompt is sent through ``self.model.generate`` with prompt
    log-probabilities enabled. For each returned request the first prompt
    position is dropped, the log-probability of every remaining prompt
    token is looked up, and the negated sum is divided by the number of
    remaining tokens.

    Args:
        inputs (List[str]): Prompts to score.
        mask_length (Optional[List[int]]): Accepted for interface
            compatibility only; this implementation does not read it.

    Returns:
        A 1-D ``numpy`` array with one loss value per returned request
        (despite the ``List[float]`` annotation, which is kept unchanged
        for callers).
    """
    # Nothing to score: skip the engine call entirely.
    if len(inputs) == 0:
        return np.array([])

    # Merge into a single mapping so that a ``prompt_logprobs`` entry in
    # ``generation_kwargs`` cannot collide with the value required here
    # (passing it twice as a keyword would raise ``TypeError``).
    sampling_options = dict(self.generation_kwargs)
    sampling_options['prompt_logprobs'] = 0
    sampling_kwargs = SamplingParams(**sampling_options)

    # forward
    # NOTE(review): assumes ``generate`` returns one result per prompt, in
    # input order -- confirm against the vLLM version in use.
    outputs = self.model.generate(inputs, sampling_kwargs)

    # compute ppl
    ce_loss = []
    for output in outputs:
        # The first prompt position is skipped; presumably it carries no
        # log-probability because nothing precedes it -- TODO confirm.
        token_logprobs = output.prompt_logprobs[1:]
        token_ids = output.prompt_token_ids[1:]
        # Each entry maps candidate token ids to objects exposing
        # ``.logprob``; pick the entry of the token actually in the prompt.
        chosen = np.array([
            candidates[token_id].logprob
            for candidates, token_id in zip(token_logprobs, token_ids)
        ])
        ce_loss.append(-chosen.sum(axis=-1) / len(token_ids))
    return np.array(ce_loss)
|
||||||
|
|
||||||
def prompts_preproccess(self, inputs: List[str]):
|
def prompts_preproccess(self, inputs: List[str]):
|
||||||
if self.use_fastchat_template:
|
if self.use_fastchat_template:
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user