mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
npu适配 (#1250)
* npu适配 * Add suport for Ascend NPU * format --------- Co-authored-by: baymax591 <14428251+baymax591@user.noreply.gitee.com> Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
fc2c9dea8c
commit
28eba6fe34
@ -12,6 +12,7 @@ from typing import Any, Dict, List, Tuple
|
|||||||
import mmengine
|
import mmengine
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmengine.config import ConfigDict
|
from mmengine.config import ConfigDict
|
||||||
|
from mmengine.device import is_npu_available
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from opencompass.registry import RUNNERS, TASKS
|
from opencompass.registry import RUNNERS, TASKS
|
||||||
@ -22,7 +23,10 @@ from .base import BaseRunner
|
|||||||
|
|
||||||
def get_command_template(gpu_ids: List[int]) -> str:
|
def get_command_template(gpu_ids: List[int]) -> str:
|
||||||
"""Format command template given available gpu ids."""
|
"""Format command template given available gpu ids."""
|
||||||
if sys.platform == 'win32': # Always return win32 for Windows
|
if is_npu_available():
|
||||||
|
tmpl = 'ASCEND_RT_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
|
||||||
|
tmpl += ' & {task_cmd}'
|
||||||
|
elif sys.platform == 'win32': # Always return win32 for Windows
|
||||||
# use command in Windows format
|
# use command in Windows format
|
||||||
tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
|
tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
|
||||||
tmpl += ' & {task_cmd}'
|
tmpl += ' & {task_cmd}'
|
||||||
@ -74,13 +78,19 @@ class LocalRunner(BaseRunner):
|
|||||||
status = []
|
status = []
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if 'CUDA_VISIBLE_DEVICES' in os.environ:
|
if is_npu_available():
|
||||||
|
visible_devices = 'ASCEND_RT_VISIBLE_DEVICES'
|
||||||
|
device_nums = torch.npu.device_count()
|
||||||
|
else:
|
||||||
|
visible_devices = 'CUDA_VISIBLE_DEVICES'
|
||||||
|
device_nums = torch.cuda.device_count()
|
||||||
|
if visible_devices in os.environ:
|
||||||
all_gpu_ids = [
|
all_gpu_ids = [
|
||||||
int(i) for i in re.findall(r'(?<!-)\d+',
|
int(i)
|
||||||
os.getenv('CUDA_VISIBLE_DEVICES'))
|
for i in re.findall(r'(?<!-)\d+', os.getenv(visible_devices))
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
all_gpu_ids = list(range(torch.cuda.device_count()))
|
all_gpu_ids = list(range(device_nums))
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
|
Loading…
Reference in New Issue
Block a user