Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Fix] Fix local debug mode not restricting the resources (#522)

* [Fix] fix local debug mode not restrict the resources
* minor fix

Parent: 229a65f305
Commit: b9270c3a60
@@ -2,6 +2,7 @@ import os
 import os.path as osp
 import re
 import subprocess
+import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
@@ -19,6 +20,18 @@ from opencompass.utils import get_logger
 from .base import BaseRunner
 
 
+def get_command_template(gpu_ids: List[int]) -> str:
+    """Format command template given available gpu ids."""
+    if sys.platform == 'win32':  # Always return win32 for Windows
+        # use command in Windows format
+        tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
+        tmpl += ' & {task_cmd}'
+    else:
+        tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
+        tmpl += ' {task_cmd}'
+    return tmpl
+
+
 @RUNNERS.register_module()
 class LocalRunner(BaseRunner):
     """Local runner. Start tasks by local python.
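For reference, a minimal self-contained sketch of how the new helper is expected to behave. The function body mirrors the diff above; the task command passed to format() is a hypothetical placeholder, not an actual OpenCompass invocation.

import sys
from typing import List


def get_command_template(gpu_ids: List[int]) -> str:
    """Format command template given available gpu ids (as in the diff above)."""
    if sys.platform == 'win32':  # Windows uses the `set VAR=... & cmd` form
        tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
        tmpl += ' & {task_cmd}'
    else:  # POSIX shells accept the inline `VAR=... cmd` form
        tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
        tmpl += ' {task_cmd}'
    return tmpl


tmpl = get_command_template([0, 1])
# On Linux/macOS: 'CUDA_VISIBLE_DEVICES=0,1 {task_cmd}'
# On Windows:     'set CUDA_VISIBLE_DEVICES=0,1 & {task_cmd}'
print(tmpl.format(task_cmd='python run_task.py tmp/123_params.py'))
# -> 'CUDA_VISIBLE_DEVICES=0,1 python run_task.py tmp/123_params.py' (Linux/macOS)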
@@ -55,17 +68,36 @@ class LocalRunner(BaseRunner):
         """
 
         status = []
+        import torch
+        if 'CUDA_VISIBLE_DEVICES' in os.environ:
+            all_gpu_ids = [
+                int(i) for i in re.findall(r'(?<!-)\d+',
+                                           os.getenv('CUDA_VISIBLE_DEVICES'))
+            ]
+        else:
+            all_gpu_ids = list(range(torch.cuda.device_count()))
+
         if self.debug:
             for task in tasks:
                 task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
                 task_name = task.name
+                num_gpus = task.num_gpus
+                assert len(all_gpu_ids) >= num_gpus
                 # get cmd
                 mmengine.mkdir_or_exist('tmp/')
                 param_file = f'tmp/{os.getpid()}_params.py'
                 try:
                     task.cfg.dump(param_file)
-                    cmd = task.get_command(cfg_path=param_file,
-                                           template='{task_cmd}')
+                    # if use torchrun, restrict it behaves the same as non
+                    # debug mode, otherwise, the torchrun will use all the
+                    # available resources which might cause inconsistent
+                    # behavior.
+                    if len(all_gpu_ids) > num_gpus and num_gpus > 0:
+                        get_logger().warning(f'Only use {num_gpus} GPUs for '
+                                             f'total {len(all_gpu_ids)} '
+                                             'available GPUs in debug mode.')
+                    tmpl = get_command_template(all_gpu_ids[:num_gpus])
+                    cmd = task.get_command(cfg_path=param_file, template=tmpl)
                     # run in subprocess if starts with torchrun etc.
                     if cmd.startswith('python'):
                         task.run()
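A short sketch (not part of the commit) of the CUDA_VISIBLE_DEVICES parsing and the debug-mode restriction shown above; the environment values and GPU counts here are illustrative.

import re

# The negative lookbehind (?<!-) skips digits preceded by a minus sign,
# so a disabled-GPU value such as '-1' yields no GPU ids.
print(re.findall(r'(?<!-)\d+', '0,1,3'))  # ['0', '1', '3']
print(re.findall(r'(?<!-)\d+', '-1'))     # []

# In debug mode the task only gets the first num_gpus visible devices,
# so a torchrun-based task can no longer grab every GPU on the host.
all_gpu_ids = [0, 1, 2, 3]   # illustrative
num_gpus = 2                 # illustrative task requirement
print(all_gpu_ids[:num_gpus])  # [0, 1] -> passed to get_command_template()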
@@ -75,15 +107,6 @@ class LocalRunner(BaseRunner):
                     os.remove(param_file)
                 status.append((task_name, 0))
         else:
-            import torch
-            if 'CUDA_VISIBLE_DEVICES' in os.environ:
-                all_gpu_ids = [
-                    int(i) for i in re.findall(
-                        r'(?<!-)\d+', os.getenv('CUDA_VISIBLE_DEVICES'))
-                ]
-            else:
-                all_gpu_ids = list(range(torch.cuda.device_count()))
-
             if len(all_gpu_ids) > 0:
                 gpus = np.zeros(max(all_gpu_ids) + 1, dtype=np.uint)
                 gpus[all_gpu_ids] = self.max_workers_per_gpu
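The non-debug branch now reuses the all_gpu_ids computed earlier for its per-GPU worker bookkeeping. A sketch of how that capacity array is laid out, with illustrative values (not taken from the commit):

import numpy as np

# Illustrative values: three visible GPUs and two workers allowed per GPU.
all_gpu_ids = [0, 1, 3]
max_workers_per_gpu = 2

gpus = np.zeros(max(all_gpu_ids) + 1, dtype=np.uint)
gpus[all_gpu_ids] = max_workers_per_gpu
print(gpus)  # [2 2 0 2] -> GPU 2 is not visible, so it gets no worker slots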
@@ -145,18 +168,7 @@ class LocalRunner(BaseRunner):
         param_file = f'tmp/{os.getpid()}_{index}_params.py'
         try:
             task.cfg.dump(param_file)
-
-            # Build up local command
-            import sys
-            if sys.platform == 'win32':  # Always return win32 for Windows
-                # use command in Windows format
-                tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(
-                    str(i) for i in gpu_ids)
-                tmpl += ' & {task_cmd}'
-            else:
-                tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(
-                    str(i) for i in gpu_ids)
-                tmpl += ' {task_cmd}'
+            tmpl = get_command_template(gpu_ids)
             get_cmd = partial(task.get_command,
                               cfg_path=param_file,
                               template=tmpl)
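A sketch of what the refactored command construction amounts to. task.get_command is task-specific, so a hypothetical stand-in with the same keyword arguments is used here; only functools.partial and the template formatting reflect the diff above.

from functools import partial


# Hypothetical stand-in for task.get_command, keeping the cfg_path/template
# keyword arguments used in the diff above.
def get_command(cfg_path: str, template: str) -> str:
    task_cmd = f'python run_task.py {cfg_path}'  # hypothetical task command
    return template.format(task_cmd=task_cmd)


tmpl = 'CUDA_VISIBLE_DEVICES=0 {task_cmd}'  # e.g. get_command_template([0])
get_cmd = partial(get_command, cfg_path='tmp/123_0_params.py', template=tmpl)
print(get_cmd())  # CUDA_VISIBLE_DEVICES=0 python run_task.py tmp/123_0_params.py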