[Fix] Fixed repeated loading of VLLM (#1051)

* [fix]Fixed the issue caused by the repeated loading of VLLM model during task segmentation.

* [fix] avoid TypeError: VLLM.__init__() got an unexpected keyword argument 'tokenizer_only'

* restore .pre-commit-config.yaml

* restore opencompass/tasks/openicl_infer.py

---------

Co-authored-by: IcyFeather <mengzhuo.happy@gmail.com>
Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
Robin Chen 2024-04-17 20:36:08 +08:00 committed by GitHub
parent 629836146a
commit c172401323
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 34 additions and 9 deletions

View File

@ -49,6 +49,12 @@ class VLLM(BaseModel):
model_kwargs = DEFAULT_MODEL_KWARGS.copy()
if add_model_kwargs is not None:
model_kwargs.update(add_model_kwargs)
import ray
if ray.is_initialized():
self.logger.info('shutdown ray instance to avoid '
'"Calling ray.init() again" error.')
ray.shutdown()
self.model = LLM(path, **model_kwargs)
def generate(self, inputs: List[str], max_out_len: int,

View File

@ -46,12 +46,14 @@ class LocalRunner(BaseRunner):
lark_bot_url (str): Lark bot url.
"""
def __init__(self,
def __init__(
self,
task: ConfigDict,
max_num_workers: int = 16,
debug: bool = False,
max_workers_per_gpu: int = 1,
lark_bot_url: str = None):
lark_bot_url: str = None,
):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers
self.max_workers_per_gpu = max_workers_per_gpu
@ -69,6 +71,7 @@ class LocalRunner(BaseRunner):
status = []
import torch
if 'CUDA_VISIBLE_DEVICES' in os.environ:
all_gpu_ids = [
int(i) for i in re.findall(r'(?<!-)\d+',
@ -100,6 +103,17 @@ class LocalRunner(BaseRunner):
cmd = task.get_command(cfg_path=param_file, template=tmpl)
# run in subprocess if starts with torchrun etc.
if 'python3 ' in cmd or 'python ' in cmd:
# If it is an infer type task do not reload if
# the current model has already been loaded.
if 'infer' in self.task_cfg.type.lower():
# If a model instance already exists,
# do not reload it.
if hasattr(self, 'cur_model'):
task.run(self.cur_model)
else:
task.run()
self.cur_model = task.model
else:
task.run()
else:
subprocess.run(cmd, shell=True, text=True)

View File

@ -59,13 +59,17 @@ class OpenICLInferTask(BaseTask):
return template.format(task_cmd=command)
def run(self):
def run(self, cur_model=None):
self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
self.max_out_len = model_cfg.get('max_out_len', None)
self.batch_size = model_cfg.get('batch_size', None)
self.min_out_len = model_cfg.get('min_out_len', None)
if cur_model:
self.model = cur_model
else:
self.model = build_model_from_cfg(model_cfg)
cur_model = self.model
for dataset_cfg in dataset_cfgs:
self.model_cfg = model_cfg

View File

@ -22,4 +22,5 @@ def build_model_from_cfg(model_cfg: ConfigDict):
model_cfg.pop('summarizer_abbr', None)
model_cfg.pop('pred_postprocessor', None)
model_cfg.pop('min_out_len', None)
model_cfg.pop('tokenizer_only', None)
return MODELS.build(model_cfg)