From 048d41a1c4a42bf489129f195f72b4739facd365 Mon Sep 17 00:00:00 2001
From: Wang Xingjin <72117044+VVVenus1212@users.noreply.github.com>
Date: Fri, 26 Apr 2024 21:31:56 +0800
Subject: [PATCH] add vllm get_ppl (#1003)

* add vllm get_ppl

* add vllm get_ppl

* format

---------

Co-authored-by: xingjin.wang
Co-authored-by: Leymore
---
 opencompass/models/vllm.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/opencompass/models/vllm.py b/opencompass/models/vllm.py
index c4d836f1..63da7b3f 100644
--- a/opencompass/models/vllm.py
+++ b/opencompass/models/vllm.py
@@ -1,5 +1,7 @@
 from typing import Dict, List, Optional
 
+import numpy as np
+
 from opencompass.models.base import BaseModel
 from opencompass.utils import get_logger
 
@@ -103,6 +105,29 @@ class VLLM(BaseModel):
 
         return output_strs
 
+    def get_ppl(self,
+                inputs: List[str],
+                mask_length: Optional[List[int]] = None) -> List[float]:
+        batch_size = len(inputs)
+        sampling_kwargs = SamplingParams(prompt_logprobs=0,
+                                         **self.generation_kwargs)
+        # forward: one generate call scores every prompt token
+        outputs = self.model.generate(inputs, sampling_kwargs)
+        # compute ppl: mean negative log-likelihood over prompt tokens
+        ce_loss = []
+        for i in range(batch_size):
+            prompt_logprobs = outputs[i].prompt_logprobs[1:]
+            prompt_token_ids = outputs[i].prompt_token_ids[1:]
+            prompt_logprobs_list = [
+                prompt_logprobs[j][prompt_token_ids[j]]
+                for j in range(len(prompt_logprobs))
+            ]
+            prompt_logprobs_list = [lp.logprob for lp in prompt_logprobs_list]
+            prompt_logprobs_list = np.array(prompt_logprobs_list)
+            loss = -prompt_logprobs_list.sum(axis=-1) / len(prompt_token_ids)
+            ce_loss.append(loss)
+        return np.array(ce_loss)
+
     def prompts_preproccess(self, inputs: List[str]):
         if self.use_fastchat_template:
             try:
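
Note (editorial, not part of the patch): the new get_ppl leans on the shape of
vLLM's RequestOutput when prompt_logprobs is requested. A minimal standalone
sketch of that structure follows; the model name is only an example of a small
checkpoint, and the printed values are placeholders:

    # Inspect the fields that get_ppl unpacks. Requires a vLLM install;
    # 'facebook/opt-125m' is an illustrative assumption, not from the patch.
    from vllm import LLM, SamplingParams

    llm = LLM(model='facebook/opt-125m')
    params = SamplingParams(max_tokens=1, prompt_logprobs=0)
    out = llm.generate(['Hello world'], params)[0]

    print(out.prompt_token_ids)    # e.g. [2, 31414, 232]
    print(out.prompt_logprobs[0])  # None: no context precedes the first token
    print(out.prompt_logprobs[1])  # {31414: Logprob(logprob=..., ...)}

The leading None entry is why the patch slices both prompt_logprobs and
prompt_token_ids with [1:]. With prompt_logprobs=0 each remaining dict still
contains the actual prompt token, so the prompt_logprobs[j][prompt_token_ids[j]]
lookup in the loop always succeeds.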
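Also worth noting: despite the name, get_ppl returns the mean negative
log-likelihood (cross-entropy) per prompt rather than the exponentiated
perplexity, consistent with OpenCompass's HuggingFace backend, whose losses
the PPL inferencer compares across answer candidates. A toy check of the
arithmetic in the loop, with assumed logprob values:

    import numpy as np

    # Three scored tokens with assumed logprobs; the loop computes
    # loss = -(sum of token logprobs) / (number of scored tokens).
    token_logprobs = np.array([-1.0, -2.0, -3.0])
    ce_loss = -token_logprobs.sum() / len(token_logprobs)  # 2.0
    ppl = np.exp(ce_loss)  # ~7.39, the conventional perplexity if needed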
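Finally, a hedged usage sketch. The constructor arguments below are assumptions
about the VLLM wrapper's interface, not values taken from this patch; in
practice the method is normally driven by OpenCompass's PPL inferencer rather
than called directly:

    import numpy as np

    from opencompass.models import VLLM

    # Illustrative settings; the path and kwargs are assumptions.
    model = VLLM(path='facebook/opt-125m',
                 max_seq_len=2048,
                 generation_kwargs=dict())

    prompts = ['The capital of France is Paris.',
               'The capital of France is London.']
    ce_loss = model.get_ppl(prompts)  # one mean-NLL value per prompt
    print(np.exp(ce_loss))            # lower perplexity = more plausible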