diff --git a/opencompass/models/vllm.py b/opencompass/models/vllm.py
index ab042437..fbfaf66e 100644
--- a/opencompass/models/vllm.py
+++ b/opencompass/models/vllm.py
@@ -7,6 +7,7 @@ from opencompass.utils import get_logger
 
 try:
     from vllm import LLM, SamplingParams
+    from vllm.lora.request import LoRARequest
 except ImportError:
     LLM, SamplingParams = None, None
 
@@ -25,6 +26,7 @@ class VLLM(BaseModel):
         meta_template: Optional[Dict] = None,
         mode: str = 'none',
         use_fastchat_template: bool = False,
+        lora_path: str = None,
         stop_words: List[str] = [],
     ):
         super().__init__(path=path,
@@ -38,7 +40,7 @@ class VLLM(BaseModel):
         self.tokenizer = self.model.get_tokenizer()
         self.generation_kwargs = generation_kwargs
         self.generation_kwargs.pop('do_sample', None)
-
+        self.lora_path = lora_path
         assert mode in ['none', 'mid']
         self.mode = mode
         self.use_fastchat_template = use_fastchat_template
@@ -96,7 +98,10 @@ class VLLM(BaseModel):
             _stop = list(set(self.stop_words + stopping_criteria))
             generation_kwargs.update({'stop': _stop})
         sampling_kwargs = SamplingParams(**generation_kwargs)
-        outputs = self.model.generate(inputs, sampling_kwargs)
+        if not self.lora_path:
+            outputs = self.model.generate(inputs, sampling_kwargs)
+        else:
+            outputs = self.model.generate(inputs, sampling_kwargs, lora_request=LoRARequest("sql_adapter", 1, self.lora_path))
         prompt_list, output_strs = [], []
         for output in outputs:
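
For context, the added branch follows vLLM's standard LoRA flow: the engine must be constructed with `enable_lora=True` before `generate()` will accept a `LoRARequest`, which in this wrapper would presumably be passed through `model_kwargs`. A minimal standalone sketch of the same call pattern (the model and adapter paths below are placeholders, not part of this patch):

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora=True is required before generate() will accept a LoRARequest;
# the base model and adapter paths here are illustrative placeholders.
llm = LLM(model='meta-llama/Llama-2-7b-hf', enable_lora=True)
sampling = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(
    ['Translate to SQL: list all active users'],
    sampling,
    # LoRARequest(lora_name, lora_int_id, lora_local_path) -- the patch
    # hardcodes the name "sql_adapter" and id 1 in the same way.
    lora_request=LoRARequest('sql_adapter', 1, '/path/to/lora_adapter'),
)
print(outputs[0].outputs[0].text)
```

One consequence of hardcoding the adapter name and integer id is that the wrapper serves exactly one adapter per engine instance, selected solely by `lora_path`.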