mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
add single lora adapter support for vLLM inference. (#1679)
This commit is contained in:
parent
17b5e52f6c
commit
3ec178f4a9
@ -7,6 +7,7 @@ from opencompass.utils import get_logger
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LLM, SamplingParams = None, None
|
LLM, SamplingParams = None, None
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ class VLLM(BaseModel):
|
|||||||
meta_template: Optional[Dict] = None,
|
meta_template: Optional[Dict] = None,
|
||||||
mode: str = 'none',
|
mode: str = 'none',
|
||||||
use_fastchat_template: bool = False,
|
use_fastchat_template: bool = False,
|
||||||
|
lora_path: str = None,
|
||||||
stop_words: List[str] = [],
|
stop_words: List[str] = [],
|
||||||
):
|
):
|
||||||
super().__init__(path=path,
|
super().__init__(path=path,
|
||||||
@ -38,7 +40,7 @@ class VLLM(BaseModel):
|
|||||||
self.tokenizer = self.model.get_tokenizer()
|
self.tokenizer = self.model.get_tokenizer()
|
||||||
self.generation_kwargs = generation_kwargs
|
self.generation_kwargs = generation_kwargs
|
||||||
self.generation_kwargs.pop('do_sample', None)
|
self.generation_kwargs.pop('do_sample', None)
|
||||||
|
self.lora_path = lora_path
|
||||||
assert mode in ['none', 'mid']
|
assert mode in ['none', 'mid']
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.use_fastchat_template = use_fastchat_template
|
self.use_fastchat_template = use_fastchat_template
|
||||||
@ -96,7 +98,10 @@ class VLLM(BaseModel):
|
|||||||
_stop = list(set(self.stop_words + stopping_criteria))
|
_stop = list(set(self.stop_words + stopping_criteria))
|
||||||
generation_kwargs.update({'stop': _stop})
|
generation_kwargs.update({'stop': _stop})
|
||||||
sampling_kwargs = SamplingParams(**generation_kwargs)
|
sampling_kwargs = SamplingParams(**generation_kwargs)
|
||||||
|
if not self.lora_path:
|
||||||
outputs = self.model.generate(inputs, sampling_kwargs)
|
outputs = self.model.generate(inputs, sampling_kwargs)
|
||||||
|
else:
|
||||||
|
outputs = self.model.generate(inputs, sampling_kwargs, lora_request=LoRARequest("sql_adapter", 1, self.lora_path))
|
||||||
|
|
||||||
prompt_list, output_strs = [], []
|
prompt_list, output_strs = [], []
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
|
Loading…
Reference in New Issue
Block a user