[Fix] Fix local debug mode not restricting resources (#522)

* [Fix] Fix local debug mode not restricting resources

* minor fix
This commit is contained in:
Hubert 2023-10-30 18:13:43 +08:00 committed by GitHub
parent 229a65f305
commit b9270c3a60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,7 @@ import os
import os.path as osp import os.path as osp
import re import re
import subprocess import subprocess
import sys
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
@ -19,6 +20,18 @@ from opencompass.utils import get_logger
from .base import BaseRunner from .base import BaseRunner
def get_command_template(gpu_ids: List[int]) -> str:
    """Build the command template that exposes only ``gpu_ids`` to a task.

    Args:
        gpu_ids (List[int]): GPU ids to list in ``CUDA_VISIBLE_DEVICES``.

    Returns:
        str: A template with a ``{task_cmd}`` placeholder, prefixed by a
        platform-specific ``CUDA_VISIBLE_DEVICES`` assignment.
    """
    visible_devices = ','.join(map(str, gpu_ids))
    # ``sys.platform`` is always 'win32' on Windows.
    if sys.platform == 'win32':
        # Windows format: ``set`` the variable, then chain the task with ``&``.
        return f'set CUDA_VISIBLE_DEVICES={visible_devices} & {{task_cmd}}'
    # Otherwise prefix the task command with the variable assignment.
    return f'CUDA_VISIBLE_DEVICES={visible_devices} {{task_cmd}}'
@RUNNERS.register_module() @RUNNERS.register_module()
class LocalRunner(BaseRunner): class LocalRunner(BaseRunner):
"""Local runner. Start tasks by local python. """Local runner. Start tasks by local python.
@ -55,17 +68,36 @@ class LocalRunner(BaseRunner):
""" """
status = [] status = []
import torch
if 'CUDA_VISIBLE_DEVICES' in os.environ:
all_gpu_ids = [
int(i) for i in re.findall(r'(?<!-)\d+',
os.getenv('CUDA_VISIBLE_DEVICES'))
]
else:
all_gpu_ids = list(range(torch.cuda.device_count()))
if self.debug: if self.debug:
for task in tasks: for task in tasks:
task = TASKS.build(dict(cfg=task, type=self.task_cfg['type'])) task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
task_name = task.name task_name = task.name
num_gpus = task.num_gpus
assert len(all_gpu_ids) >= num_gpus
# get cmd # get cmd
mmengine.mkdir_or_exist('tmp/') mmengine.mkdir_or_exist('tmp/')
param_file = f'tmp/{os.getpid()}_params.py' param_file = f'tmp/{os.getpid()}_params.py'
try: try:
task.cfg.dump(param_file) task.cfg.dump(param_file)
cmd = task.get_command(cfg_path=param_file, # if use torchrun, restrict it behaves the same as non
template='{task_cmd}') # debug mode, otherwise, the torchrun will use all the
# available resources which might cause inconsistent
# behavior.
if len(all_gpu_ids) > num_gpus and num_gpus > 0:
get_logger().warning(f'Only use {num_gpus} GPUs for '
f'total {len(all_gpu_ids)} '
'available GPUs in debug mode.')
tmpl = get_command_template(all_gpu_ids[:num_gpus])
cmd = task.get_command(cfg_path=param_file, template=tmpl)
# run in subprocess if starts with torchrun etc. # run in subprocess if starts with torchrun etc.
if cmd.startswith('python'): if cmd.startswith('python'):
task.run() task.run()
@ -75,15 +107,6 @@ class LocalRunner(BaseRunner):
os.remove(param_file) os.remove(param_file)
status.append((task_name, 0)) status.append((task_name, 0))
else: else:
import torch
if 'CUDA_VISIBLE_DEVICES' in os.environ:
all_gpu_ids = [
int(i) for i in re.findall(
r'(?<!-)\d+', os.getenv('CUDA_VISIBLE_DEVICES'))
]
else:
all_gpu_ids = list(range(torch.cuda.device_count()))
if len(all_gpu_ids) > 0: if len(all_gpu_ids) > 0:
gpus = np.zeros(max(all_gpu_ids) + 1, dtype=np.uint) gpus = np.zeros(max(all_gpu_ids) + 1, dtype=np.uint)
gpus[all_gpu_ids] = self.max_workers_per_gpu gpus[all_gpu_ids] = self.max_workers_per_gpu
@ -145,18 +168,7 @@ class LocalRunner(BaseRunner):
param_file = f'tmp/{os.getpid()}_{index}_params.py' param_file = f'tmp/{os.getpid()}_{index}_params.py'
try: try:
task.cfg.dump(param_file) task.cfg.dump(param_file)
tmpl = get_command_template(gpu_ids)
# Build up local command
import sys
if sys.platform == 'win32': # Always return win32 for Windows
# use command in Windows format
tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(
str(i) for i in gpu_ids)
tmpl += ' & {task_cmd}'
else:
tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(
str(i) for i in gpu_ids)
tmpl += ' {task_cmd}'
get_cmd = partial(task.get_command, get_cmd = partial(task.get_command,
cfg_path=param_file, cfg_path=param_file,
template=tmpl) template=tmpl)