Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Support CUDA_VISIBLE_DEVICES and multiple tasks on one GPU (#148)
* [Feature] Support CUDA_VISIBLE_DEVICES and multiple tasks on one GPU
* Fix UT
* Update according to comments
parent 312095de9d
commit 59bf56349c
opencompass/runners/local.py
@@ -1,5 +1,6 @@
 import os
 import os.path as osp
+import re
 import subprocess
 import time
 from concurrent.futures import ThreadPoolExecutor
@@ -26,6 +27,8 @@ class LocalRunner(BaseRunner):
         task (ConfigDict): Task type config.
         max_num_workers (int): Max number of workers to run in parallel.
             Defaults to 16.
+        max_workers_per_gpu (int): Max number of workers to run for one GPU.
+            Defaults to 1.
         debug (bool): Whether to run in debug mode.
         lark_bot_url (str): Lark bot url.
     """
@@ -34,9 +37,11 @@ class LocalRunner(BaseRunner):
                  task: ConfigDict,
                  max_num_workers: int = 16,
                  debug: bool = False,
+                 max_workers_per_gpu: int = 1,
                  lark_bot_url: str = None):
         super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
         self.max_num_workers = max_num_workers
+        self.max_workers_per_gpu = max_workers_per_gpu

     def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
         """Launch multiple tasks.
@@ -58,7 +63,20 @@
                 status.append((task_name, 0))
         else:
             import torch
-            gpus = np.ones(torch.cuda.device_count(), dtype=np.bool_)
+            if 'CUDA_VISIBLE_DEVICES' in os.environ:
+                all_gpu_ids = [
+                    int(i) for i in re.findall(
+                        r'(?<!-)\d+', os.getenv('CUDA_VISIBLE_DEVICES'))
+                ]
+            else:
+                all_gpu_ids = list(range(torch.cuda.device_count()))
+
+            if len(all_gpu_ids) > 0:
+                gpus = np.zeros(max(all_gpu_ids) + 1, dtype=np.uint)
+                gpus[all_gpu_ids] = self.max_workers_per_gpu
+            else:
+                gpus = np.array([], dtype=np.uint)
+
             pbar = tqdm(total=len(tasks))
             lock = Lock()
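As a quick aside, the regex above is easy to sanity-check on its own. A standalone sketch (the helper name is illustrative, not part of OpenCompass): the negative lookbehind (?<!-) skips digit runs immediately preceded by a minus sign, so the conventional "disable all GPUs" value of -1 parses to an empty list, which the hunk above then turns into an empty slot array.

import re

def parse_visible_devices(value):
    # Extract non-negative device ids; a digit run directly preceded by '-' is skipped.
    return [int(i) for i in re.findall(r'(?<!-)\d+', value)]

print(parse_visible_devices('0,1,3'))  # [0, 1, 3]
print(parse_visible_devices('10'))     # [10]
print(parse_visible_devices('-1'))     # []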
@@ -69,9 +87,9 @@

                 while True:
                     lock.acquire()
-                    if sum(gpus) >= num_gpus:
+                    if sum(gpus > 0) >= num_gpus:
                         gpu_ids = np.where(gpus)[0][:num_gpus]
-                        gpus[gpu_ids] = False
+                        gpus[gpu_ids] -= 1
                         lock.release()
                         break
                     lock.release()
@@ -87,7 +105,7 @@
                 pbar.update()

                 with lock:
-                    gpus[gpu_ids] = True
+                    gpus[gpu_ids] += 1

         return res
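The three hunks above change the bookkeeping from a boolean occupancy mask to a per-device counter of free worker slots, which is what lets up to max_workers_per_gpu tasks share one device: acquiring a GPU decrements its counter, finishing a task increments it, and a task waits until enough distinct devices have a nonzero count. A minimal standalone sketch of that counting-semaphore scheme (function names and demo values are illustrative, not OpenCompass API):

import numpy as np
from threading import Lock

gpus = np.array([2, 2, 0], dtype=np.uint)  # free slots: two each on GPUs 0 and 1, none on GPU 2
lock = Lock()

def acquire(num_gpus):
    # Spin until enough distinct devices have a free slot, then claim one slot on each.
    while True:
        with lock:
            if (gpus > 0).sum() >= num_gpus:
                gpu_ids = np.where(gpus)[0][:num_gpus]  # first devices with a free slot
                gpus[gpu_ids] -= 1
                return gpu_ids

def release(gpu_ids):
    with lock:
        gpus[gpu_ids] += 1  # hand the slots back when the task finishes

ids = acquire(2)
print('task runs on GPUs', ids)  # task runs on GPUs [0 1]
release(ids)

Note that sum(gpus > 0) counts devices with at least one free slot, so a task that needs several GPUs is always placed on distinct devices rather than stacking its processes on one.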
run.py (+6)
@@ -98,6 +98,11 @@ def parse_args():
                         'in the config.',
                         type=int,
                         default=32)
+    parser.add_argument('--max-workers-per-gpu',
+                        help='Max task to run in parallel on one GPU. '
+                        'It will only be used in the local runner.',
+                        type=int,
+                        default=32)
     parser.add_argument(
         '--retry',
         help='Number of retries if the job failed when using slurm or dlc. '
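Combined with the runner change above, a hypothetical invocation such as CUDA_VISIBLE_DEVICES=0,1 python run.py configs/eval_demo.py --max-workers-per-gpu 2 (the config path is illustrative) would restrict scheduling to GPUs 0 and 1 and allow up to two tasks to run concurrently on each of them.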
@@ -337,6 +342,7 @@ def exec_infer_runner(tasks, args, cfg):
     else:
         runner = LocalRunner(task=dict(type='OpenICLInferTask'),
                              max_num_workers=args.max_num_workers,
+                             max_workers_per_gpu=args.max_workers_per_gpu,
                              debug=args.debug,
                              lark_bot_url=cfg['lark_bot_url'])
     runner(tasks)
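For completeness, a minimal sketch of constructing the runner directly with the new parameter, assuming the opencompass.runners import path that run.py itself uses:

from opencompass.runners import LocalRunner

runner = LocalRunner(task=dict(type='OpenICLInferTask'),
                     max_num_workers=16,
                     max_workers_per_gpu=2,
                     debug=False)
# runner(tasks) then lets at most two tasks share each visible GPU.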