mirror of https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00

commit ceffb11a87
Merge branch 'open-compass:main' into main
@@ -123,7 +123,7 @@ class TurboMindModelwithChatTemplate(BaseModel):
         gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
         gen_config.update(self.gen_config)
-        if do_sample:
+        if do_sample or self.gen_config['do_sample']:
             gen_config['top_k'] = 40
             gen_config['temperature'] = temperature
         else:
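The effect of this change: a user can now enable sampling through gen_config alone, without the caller passing do_sample. A minimal sketch of the merged-config logic; the DEFAULT_GEN_CONFIG values here are placeholders, not OpenCompass's actual defaults:

import copy

# Placeholder defaults; OpenCompass's real DEFAULT_GEN_CONFIG may differ.
DEFAULT_GEN_CONFIG = {'do_sample': False, 'top_k': 1, 'temperature': 1.0}

def build_gen_config(user_gen_config, do_sample, temperature):
    gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
    gen_config.update(user_gen_config)
    # Patched condition: sampling is on if either the call-site flag or the
    # user-supplied gen_config asks for it. (.get() is used defensively here;
    # the patch indexes self.gen_config['do_sample'] directly.)
    if do_sample or user_gen_config.get('do_sample', False):
        gen_config['top_k'] = 40
        gen_config['temperature'] = temperature
    return gen_config

# A user config can now opt in even when the caller passes do_sample=False:
print(build_gen_config({'do_sample': True}, do_sample=False, temperature=0.7))
# {'do_sample': True, 'top_k': 40, 'temperature': 0.7}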
@@ -7,6 +7,7 @@ from opencompass.utils import get_logger

 try:
     from vllm import LLM, SamplingParams
+    from vllm.lora.request import LoRARequest
 except ImportError:
     LLM, SamplingParams = None, None
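Note that when vllm is not installed, the new import fails with the others, leaving LoRARequest undefined rather than None. A defensive variant of the guard (a sketch, not what this patch does) would reset it alongside the rest:

try:
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
except ImportError:
    # Keep all names defined so later "is None" checks work uniformly.
    LLM, SamplingParams, LoRARequest = None, None, None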
@@ -25,6 +26,7 @@ class VLLM(BaseModel):
                  meta_template: Optional[Dict] = None,
                  mode: str = 'none',
                  use_fastchat_template: bool = False,
+                 lora_path: str = None,
                  stop_words: List[str] = [],
                  ):
         super().__init__(path=path,
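For reference, a model entry passing the new argument might look like the following OpenCompass-style config sketch; the base model path and adapter location are placeholders:

from opencompass.models import VLLM

models = [
    dict(
        type=VLLM,
        abbr='llama-2-7b-vllm-lora',
        path='meta-llama/Llama-2-7b-hf',   # placeholder base model
        lora_path='/path/to/sql_adapter',  # new parameter from this patch
        max_out_len=256,
        batch_size=16,
        run_cfg=dict(num_gpus=1),
    )
]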
@@ -38,7 +40,8 @@ class VLLM(BaseModel):
         self.tokenizer = self.model.get_tokenizer()
         self.generation_kwargs = generation_kwargs
         self.generation_kwargs.pop('do_sample', None)
+        self.lora_path = lora_path
         assert mode in ['none', 'mid']
         self.mode = mode
         self.use_fastchat_template = use_fastchat_template
@@ -96,7 +98,10 @@ class VLLM(BaseModel):
         _stop = list(set(self.stop_words + stopping_criteria))
         generation_kwargs.update({'stop': _stop})
         sampling_kwargs = SamplingParams(**generation_kwargs)
-        outputs = self.model.generate(inputs, sampling_kwargs)
+        if not self.lora_path:
+            outputs = self.model.generate(inputs, sampling_kwargs)
+        else:
+            outputs = self.model.generate(inputs, sampling_kwargs, lora_request=LoRARequest("sql_adapter", 1, self.lora_path))

         prompt_list, output_strs = [], []
         for output in outputs:
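The adapter name "sql_adapter" and integer ID 1 are hard-coded by the patch. As a standalone illustration of the same vLLM call pattern (model name and adapter path are placeholders), note that vLLM only honors lora_request when the engine was constructed with enable_lora=True:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# The engine must be built with LoRA support for lora_request to apply.
llm = LLM(model='meta-llama/Llama-2-7b-hf', enable_lora=True)
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(
    ['Write a SQL query that counts users by country.'],
    params,
    lora_request=LoRARequest('sql_adapter', 1, '/path/to/sql_adapter'),
)
print(outputs[0].outputs[0].text)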