* NPU adaptation (npu适配)

* Add support for Ascend NPU

* format

---------

Co-authored-by: baymax591 <14428251+baymax591@user.noreply.gitee.com>
Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
baymax591 2024-07-03 18:55:19 +08:00 committed by GitHub
parent fc2c9dea8c
commit 28eba6fe34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Tuple
 import mmengine
 import numpy as np
 from mmengine.config import ConfigDict
+from mmengine.device import is_npu_available
 from tqdm import tqdm
 from opencompass.registry import RUNNERS, TASKS
@@ -22,7 +23,10 @@ from .base import BaseRunner
 def get_command_template(gpu_ids: List[int]) -> str:
     """Format command template given available gpu ids."""
-    if sys.platform == 'win32':  # Always return win32 for Windows
+    if is_npu_available():
+        tmpl = 'ASCEND_RT_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
+        tmpl += ' & {task_cmd}'
+    elif sys.platform == 'win32':  # Always return win32 for Windows
         # use command in Windows format
         tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
         tmpl += ' & {task_cmd}'
@@ -74,13 +78,19 @@ class LocalRunner(BaseRunner):
         status = []
         import torch
-        if 'CUDA_VISIBLE_DEVICES' in os.environ:
+        if is_npu_available():
+            visible_devices = 'ASCEND_RT_VISIBLE_DEVICES'
+            device_nums = torch.npu.device_count()
+        else:
+            visible_devices = 'CUDA_VISIBLE_DEVICES'
+            device_nums = torch.cuda.device_count()
+        if visible_devices in os.environ:
             all_gpu_ids = [
-                int(i) for i in re.findall(r'(?<!-)\d+',
-                                           os.getenv('CUDA_VISIBLE_DEVICES'))
+                int(i)
+                for i in re.findall(r'(?<!-)\d+', os.getenv(visible_devices))
             ]
         else:
-            all_gpu_ids = list(range(torch.cuda.device_count()))
+            all_gpu_ids = list(range(device_nums))
         if self.debug:
             for task in tasks: