OpenCompass/opencompass/runners/slurm.py
Yuan Liu 191a3f6f9d
[Feature]: Use multimodal (#73)
* [Feature]: Add minigpt-4

* [Feature]: Add mm local runner

* [Feature]: Add instructblip

* [Feature]: Delete redundant file

* [Feature]: Delete redundant file

* [Feature]: Add README to InstructBLIP

* [Feature]: Update MiniGPT-4

* [Fix]: Fix lint

* [Feature]add omnibenchmark readme (#49)

* add omnibenchmark readme

* fix

* Update OmniMMBench.md

* Update OmniMMBench.md

* Update OmniMMBench.md

* [Fix]: Refine name (#54)

* [Feature]: Unify out and err

* [Fix]: Fix lint

* [Feature]: Rename to mmbench and change weight path

* [Feature]: Delete Omni in instructblip

* [Feature]: Check the availability of lavis

* [Fix]: Fix lint

* [Feature]: Refactor MM

* [Refactor]: Refactor path

* [Feature]: Delete redundant files

* [Refactor]: Delete redundant files

---------

Co-authored-by: Wangbo Zhao(黑色枷锁) <56866854+wangbo-zhao@users.noreply.github.com>
2023-08-03 11:07:50 +08:00

153 lines
5.3 KiB
Python

import os
import os.path as osp
import random
import shlex
import subprocess
import time
from functools import partial
from typing import Any, Dict, List, Tuple

import mmengine
from mmengine.config import ConfigDict
from mmengine.utils import track_parallel_progress

from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger

from .base import BaseRunner
@RUNNERS.register_module()
class SlurmRunner(BaseRunner):
    """Distributed runner based on Slurm. It will launch tasks in parallel
    using `srun` command.

    Each task is built from the runner's task-type config, its per-task
    config is dumped to a temporary file under ``tmp/``, and the task's
    command is wrapped in an ``srun`` invocation. A task is re-submitted
    (up to ``retry`` times) when it exits non-zero or when any of its
    expected output files is missing.

    Args:
        task (ConfigDict): Task type config.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 32.
        retry (int): Number of retries if the job failed. Defaults to 2.
        partition (str): Slurm partition name (``srun -p``).
            Defaults to None.
        quotatype (str): Slurm quota type (``srun --quotatype``).
            Defaults to None.
        qos (str): Slurm quality of service (``srun --qos``).
            Defaults to None.
        debug (bool): Whether to run in debug mode. In debug mode tasks run
            sequentially in the current process, without the random
            start-up delay, and their output is not redirected to a log
            file. Defaults to False.
        lark_bot_url (str): Lark bot url. Defaults to None.
    """

    def __init__(self,
                 task: ConfigDict,
                 max_num_workers: int = 32,
                 retry: int = 2,
                 partition: str = None,
                 quotatype: str = None,
                 qos: str = None,
                 debug: bool = False,
                 lark_bot_url: str = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.retry = retry
        self.partition = partition
        self.quotatype = quotatype
        self.qos = qos

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch multiple tasks.

        Outside debug mode the tasks are dispatched through
        ``track_parallel_progress`` with up to ``max_num_workers`` workers
        and ``keep_order=False``, so results may not follow the input
        order. In debug mode they run one after another in this process.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """
        if not self.debug:
            status = track_parallel_progress(self._launch,
                                             tasks,
                                             nproc=self.max_num_workers,
                                             keep_order=False)
        else:
            status = [self._launch(task, random_sleep=False) for task in tasks]
        return status

    def _launch(self, task_cfg: ConfigDict, random_sleep: bool = True):
        """Launch a single task.

        Args:
            task_cfg (ConfigDict): Task config.
            random_sleep (bool): Whether to sleep for a random time before
                running the command. This avoids cluster error when launching
                multiple tasks at the same time. Default: True.

        Returns:
            tuple[str, int]: Task name and exit code of the last attempt.
        """
        # ``self.task_cfg`` (presumably populated by BaseRunner from the
        # ``task`` argument) describes the task *type*; the ``task_cfg``
        # argument is the config of this particular task.
        task_type = self.task_cfg.type
        if isinstance(self.task_cfg.type, str):
            task_type = TASKS.get(task_type)
        task = task_type(task_cfg)
        num_gpus = task.num_gpus
        task_name = task.name

        # Dump task config to file. The file is named after the current
        # process ID; presumably safe because each parallel worker process
        # runs only one task at a time.
        mmengine.mkdir_or_exist('tmp/')
        param_file = f'tmp/{os.getpid()}_params.py'
        task_cfg.dump(param_file)

        stdout = None
        try:
            # Build up slurm command
            tmpl = 'srun'
            if self.partition:
                tmpl += f' -p {self.partition}'
            if self.quotatype:
                tmpl += f' --quotatype={self.quotatype}'
            if self.qos:
                tmpl += f' --qos={self.qos}'
            if num_gpus > 0:
                tmpl += f' --gres=gpu:{num_gpus}'
            # The command runs with ``shell=True``, so quote the job name
            # properly; a name containing a single quote would otherwise
            # break the command line. ``{task_cmd}`` is left as a literal
            # placeholder for ``task.get_command`` to fill in.
            tmpl += (f' -N1 -J {shlex.quote(task_name[:512])}' +
                     ' {task_cmd}')
            get_cmd = partial(task.get_command,
                              cfg_path=param_file,
                              template=tmpl)
            cmd = get_cmd()

            logger = get_logger()
            logger.debug(f'Running command: {cmd}')

            # Run command with retry
            if not self.debug:
                out_path = task.get_log_path(file_extension='out')
                mmengine.mkdir_or_exist(osp.split(out_path)[0])
                # stdout and stderr of every attempt go to the same log file.
                stdout = open(out_path, 'w', encoding='utf-8')
            if random_sleep:
                time.sleep(random.randint(0, 10))
            result = subprocess.run(cmd,
                                    shell=True,
                                    text=True,
                                    stdout=stdout,
                                    stderr=stdout)

            retry = self.retry
            output_paths = task.get_output_paths()
            while self._job_failed(result.returncode,
                                   output_paths) and retry > 0:
                retry -= 1
                if random_sleep:
                    time.sleep(random.randint(0, 10))
                # Re-generate command to refresh ports (get_command presumably
                # picks fresh values on every call).
                cmd = get_cmd()
                result = subprocess.run(cmd,
                                        shell=True,
                                        text=True,
                                        stdout=stdout,
                                        stderr=stdout)

            # Warn on any final failure, including a zero exit code with
            # missing output files, which would otherwise pass silently.
            if not self.debug and self._job_failed(result.returncode,
                                                   output_paths):
                logger.warning(f'task {task_name} fail, see\n{out_path}')
        finally:
            # Clean up even if building or running the command raised.
            if stdout is not None:
                stdout.close()
            os.remove(param_file)
        return task_name, result.returncode

    def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
        """Return True if the job exited non-zero or any expected output
        file in ``output_paths`` does not exist."""
        return return_code != 0 or not all(
            osp.exists(output_path) for output_path in output_paths)