mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
npu适配 (#1250)
* npu适配 * Add suport for Ascend NPU * format --------- Co-authored-by: baymax591 <14428251+baymax591@user.noreply.gitee.com> Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
fc2c9dea8c
commit
28eba6fe34
@ -12,6 +12,7 @@ from typing import Any, Dict, List, Tuple
|
||||
import mmengine
|
||||
import numpy as np
|
||||
from mmengine.config import ConfigDict
|
||||
from mmengine.device import is_npu_available
|
||||
from tqdm import tqdm
|
||||
|
||||
from opencompass.registry import RUNNERS, TASKS
|
||||
@ -22,7 +23,10 @@ from .base import BaseRunner
|
||||
|
||||
def get_command_template(gpu_ids: List[int]) -> str:
|
||||
"""Format command template given available gpu ids."""
|
||||
if sys.platform == 'win32': # Always return win32 for Windows
|
||||
if is_npu_available():
|
||||
tmpl = 'ASCEND_RT_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
|
||||
tmpl += ' & {task_cmd}'
|
||||
elif sys.platform == 'win32': # Always return win32 for Windows
|
||||
# use command in Windows format
|
||||
tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
|
||||
tmpl += ' & {task_cmd}'
|
||||
@ -74,13 +78,19 @@ class LocalRunner(BaseRunner):
|
||||
status = []
|
||||
import torch
|
||||
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ:
|
||||
if is_npu_available():
|
||||
visible_devices = 'ASCEND_RT_VISIBLE_DEVICES'
|
||||
device_nums = torch.npu.device_count()
|
||||
else:
|
||||
visible_devices = 'CUDA_VISIBLE_DEVICES'
|
||||
device_nums = torch.cuda.device_count()
|
||||
if visible_devices in os.environ:
|
||||
all_gpu_ids = [
|
||||
int(i) for i in re.findall(r'(?<!-)\d+',
|
||||
os.getenv('CUDA_VISIBLE_DEVICES'))
|
||||
int(i)
|
||||
for i in re.findall(r'(?<!-)\d+', os.getenv(visible_devices))
|
||||
]
|
||||
else:
|
||||
all_gpu_ids = list(range(torch.cuda.device_count()))
|
||||
all_gpu_ids = list(range(device_nums))
|
||||
|
||||
if self.debug:
|
||||
for task in tasks:
|
||||
|
Loading…
Reference in New Issue
Block a user