# OpenCompass/opencompass/runners/slurm.py
import os
import os.path as osp
import random
import subprocess
import time
from functools import partial
2023-07-05 10:33:12 +08:00
from typing import Any, Dict, List, Tuple
import mmengine
from mmengine.config import ConfigDict
from mmengine.utils import track_parallel_progress
from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger
from .base import BaseRunner
@RUNNERS.register_module()
class SlurmRunner(BaseRunner):
    """Distributed runner based on Slurm. It will launch tasks in parallel
    using `srun` command.

    Args:
        task (ConfigDict): Task type config.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 32.
        retry (int): Number of retries if the job failed. Defaults to 2.
        partition (str): Slurm partition name. Defaults to None.
        quotatype (str): Slurm quota type. Defaults to None.
        qos (str): Slurm quality of service. Defaults to None.
        debug (bool): Whether to run in debug mode. Defaults to False.
        lark_bot_url (str): Lark bot url. Defaults to None.
    """

    def __init__(self,
                 task: ConfigDict,
                 max_num_workers: int = 32,
                 retry: int = 2,
                 partition: str = None,
                 quotatype: str = None,
                 qos: str = None,
                 debug: bool = False,
                 lark_bot_url: str = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.retry = retry
        self.partition = partition
        self.quotatype = quotatype
        self.qos = qos

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch multiple tasks.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """
        if not self.debug:
            status = track_parallel_progress(self._launch,
                                             tasks,
                                             nproc=self.max_num_workers,
                                             keep_order=False)
        else:
            # Debug mode runs the tasks one by one in this process and skips
            # the random start-up delay.
            status = [self._launch(task, random_sleep=False) for task in tasks]
        return status

    def _build_srun_template(self, task_name: str, num_gpus: int) -> str:
        """Build the `srun` command template for one task.

        Args:
            task_name (str): Job name passed to `-J`; only the first 512
                characters are used.
            num_gpus (int): Number of GPUs to request via `--gres`. No GPU
                request is added when it is not positive.

        Returns:
            str: The command template, ending with a literal `{task_cmd}`
            placeholder.
        """
        parts = ['srun']
        if self.partition:
            parts.append(f'-p {self.partition}')
        if self.quotatype:
            parts.append(f'--quotatype={self.quotatype}')
        if self.qos:
            parts.append(f'--qos={self.qos}')
        if num_gpus > 0:
            parts.append(f'--gres=gpu:{num_gpus}')
        parts.append(f"-N1 -J '{task_name[:512]}'")
        # Deliberately a plain string: the placeholder must survive verbatim
        # so it can be substituted later.
        parts.append('{task_cmd}')
        return ' '.join(parts)

    def _launch(self,
                task_cfg: ConfigDict,
                random_sleep: bool = True) -> Tuple[str, int]:
        """Launch a single task.

        Args:
            task_cfg (ConfigDict): Task config.
            random_sleep (bool): Whether to sleep for a random time before
                running the command. This avoids cluster error when launching
                multiple tasks at the same time. Default: True.

        Returns:
            tuple[str, int]: Task name and exit code.
        """
        task_type = self.task_cfg.type
        if isinstance(self.task_cfg.type, str):
            task_type = TASKS.get(task_type)
        task = task_type(task_cfg)
        num_gpus = task.num_gpus
        task_name = task.name

        # Dump task config to file
        mmengine.mkdir_or_exist('tmp/')
        param_file = f'tmp/{os.getpid()}_params.py'
        # Stays None in debug mode, where the child inherits this process's
        # stdout/stderr instead of writing to a log file.
        log_file = None
        try:
            task_cfg.dump(param_file)

            # Build up slurm command
            tmpl = self._build_srun_template(task_name, num_gpus)
            get_cmd = partial(task.get_command,
                              cfg_path=param_file,
                              template=tmpl)
            cmd = get_cmd()

            logger = get_logger()
            logger.debug(f'Running command: {cmd}')

            # Run command with retry
            if not self.debug:
                out_path = task.get_log_path(file_extension='out')
                mmengine.mkdir_or_exist(osp.split(out_path)[0])
                log_file = open(out_path, 'w', encoding='utf-8')

            if random_sleep:
                time.sleep(random.randint(0, 10))
            result = subprocess.run(cmd,
                                    shell=True,
                                    text=True,
                                    stdout=log_file,
                                    stderr=log_file)

            retry = self.retry
            output_paths = task.get_output_paths()
            while self._job_failed(result.returncode,
                                   output_paths) and retry > 0:
                retry -= 1
                if random_sleep:
                    time.sleep(random.randint(0, 10))
                # Re-generate command to refresh ports.
                cmd = get_cmd()
                result = subprocess.run(cmd,
                                        shell=True,
                                        text=True,
                                        stdout=log_file,
                                        stderr=log_file)

            if result.returncode != 0 and not self.debug:
                logger.warning(f'task {task_name} fail, see\n{out_path}')
        finally:
            # Clean up
            # Close the log file so descriptors do not pile up across the
            # many tasks a single worker process may launch.
            if log_file is not None:
                log_file.close()
            # The config file may be missing if the dump itself failed; do not
            # let the cleanup hide that original error.
            if osp.exists(param_file):
                os.remove(param_file)

        return task_name, result.returncode

    def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
        """Tell whether a job failed.

        A job counts as failed when its return code is non-zero or when any
        of the expected output paths does not exist.
        """
        return return_code != 0 or not all(
            osp.exists(output_path) for output_path in output_paths)