Merge branch 'open-compass:main' into main

commit ceffb11a87
bittersweet1999 authored 2024-11-12 17:34:11 +08:00 (committed by GitHub)
2 changed files with 8 additions and 3 deletions


@@ -123,7 +123,7 @@ class TurboMindModelwithChatTemplate(BaseModel):
         gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
         gen_config.update(self.gen_config)
-        if do_sample:
+        if do_sample or self.gen_config['do_sample']:
             gen_config['top_k'] = 40
             gen_config['temperature'] = temperature
         else:
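The hunk above lets a model-level gen_config turn sampling on even when the caller does not pass do_sample=True. A minimal, self-contained sketch of the merged logic; the default values and the helper name are illustrative rather than taken from the repository, and .get() is used so the sketch also runs when the key is absent (the diff indexes self.gen_config['do_sample'] directly):

import copy

# Illustrative defaults; the real DEFAULT_GEN_CONFIG lives in the TurboMind wrapper.
DEFAULT_GEN_CONFIG = {'top_k': 1, 'temperature': 1e-6, 'do_sample': False}

def build_gen_config(model_gen_config, do_sample, temperature=0.7):
    """Hypothetical helper mirroring the lines in the hunk above."""
    gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
    gen_config.update(model_gen_config)
    # New behaviour: either the per-call flag or the model config can enable sampling.
    if do_sample or model_gen_config.get('do_sample', False):
        gen_config['top_k'] = 40
        gen_config['temperature'] = temperature
    return gen_config

# Sampling is enabled by the model config alone, even with do_sample=False.
print(build_gen_config({'do_sample': True}, do_sample=False))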


@@ -7,6 +7,7 @@ from opencompass.utils import get_logger
 try:
     from vllm import LLM, SamplingParams
+    from vllm.lora.request import LoRARequest
 except ImportError:
     LLM, SamplingParams = None, None
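The guard keeps vLLM an optional dependency, and LoRARequest is only needed when an adapter is actually used. A sketch of the same pattern; nulling out LoRARequest in the except branch is an added assumption for this sketch, not something the commit does:

try:
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
except ImportError:
    # Assumption: also reset LoRARequest so callers can check for None
    # before touching any vLLM symbol.
    LLM, SamplingParams, LoRARequest = None, None, None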
@@ -25,6 +26,7 @@ class VLLM(BaseModel):
                  meta_template: Optional[Dict] = None,
                  mode: str = 'none',
                  use_fastchat_template: bool = False,
+                 lora_path: str = None,
                  stop_words: List[str] = [],
                  ):
         super().__init__(path=path,
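With the new constructor argument, a LoRA adapter can be selected from the model config. A hypothetical OpenCompass config entry; the model path, adapter path, and abbr are placeholders, and passing enable_lora=True through model_kwargs is an assumption about how the wrapper forwards arguments to vllm.LLM, which requires that flag before it accepts LoRA requests:

from opencompass.models import VLLM

models = [
    dict(
        type=VLLM,
        abbr='llama-2-7b-vllm-lora',          # placeholder
        path='meta-llama/Llama-2-7b-hf',      # placeholder base model
        lora_path='/path/to/lora/adapter',    # new argument from this commit
        model_kwargs=dict(enable_lora=True),  # assumption: forwarded to vllm.LLM
        max_out_len=256,
        batch_size=32,
        generation_kwargs=dict(temperature=0),
        run_cfg=dict(num_gpus=1),
    ),
]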
@@ -38,7 +40,7 @@ class VLLM(BaseModel):
         self.tokenizer = self.model.get_tokenizer()
         self.generation_kwargs = generation_kwargs
         self.generation_kwargs.pop('do_sample', None)
+        self.lora_path = lora_path
         assert mode in ['none', 'mid']
         self.mode = mode
         self.use_fastchat_template = use_fastchat_template
@@ -96,7 +98,10 @@ class VLLM(BaseModel):
             _stop = list(set(self.stop_words + stopping_criteria))
             generation_kwargs.update({'stop': _stop})
             sampling_kwargs = SamplingParams(**generation_kwargs)
-            outputs = self.model.generate(inputs, sampling_kwargs)
+            if not self.lora_path:
+                outputs = self.model.generate(inputs, sampling_kwargs)
+            else:
+                outputs = self.model.generate(inputs, sampling_kwargs, lora_request=LoRARequest("sql_adapter", 1, self.lora_path))
             prompt_list, output_strs = [], []
             for output in outputs:
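For reference, the LoRA branch above maps onto vLLM's public API roughly as follows. This standalone sketch uses placeholder model and adapter paths, and the engine must be constructed with enable_lora=True for lora_request to be accepted:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder base model; enable_lora=True is required before passing lora_request.
llm = LLM(model='meta-llama/Llama-2-7b-hf', enable_lora=True)
params = SamplingParams(temperature=0.0, max_tokens=256)

# LoRARequest(name, int_id, path): the name and integer id only need to be unique
# per adapter; the diff hard-codes them as "sql_adapter" and 1.
lora = LoRARequest('sql_adapter', 1, '/path/to/lora/adapter')

outputs = llm.generate(['List three SQL aggregate functions.'], params, lora_request=lora)
print(outputs[0].outputs[0].text)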