From 28eba6fe34fd41d906712d40f076ab0d0273f0e1 Mon Sep 17 00:00:00 2001 From: baymax591 Date: Wed, 3 Jul 2024 18:55:19 +0800 Subject: [PATCH] =?UTF-8?q?npu=E9=80=82=E9=85=8D=20(#1250)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * npu适配 * Add suport for Ascend NPU * format --------- Co-authored-by: baymax591 <14428251+baymax591@user.noreply.gitee.com> Co-authored-by: Leymore --- opencompass/runners/local.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/opencompass/runners/local.py b/opencompass/runners/local.py index 297fba25..ea336d5b 100644 --- a/opencompass/runners/local.py +++ b/opencompass/runners/local.py @@ -12,6 +12,7 @@ from typing import Any, Dict, List, Tuple import mmengine import numpy as np from mmengine.config import ConfigDict +from mmengine.device import is_npu_available from tqdm import tqdm from opencompass.registry import RUNNERS, TASKS @@ -22,7 +23,10 @@ from .base import BaseRunner def get_command_template(gpu_ids: List[int]) -> str: """Format command template given available gpu ids.""" - if sys.platform == 'win32': # Always return win32 for Windows + if is_npu_available(): + tmpl = 'ASCEND_RT_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids) + tmpl += ' & {task_cmd}' + elif sys.platform == 'win32': # Always return win32 for Windows # use command in Windows format tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids) tmpl += ' & {task_cmd}' @@ -74,13 +78,19 @@ class LocalRunner(BaseRunner): status = [] import torch - if 'CUDA_VISIBLE_DEVICES' in os.environ: + if is_npu_available(): + visible_devices = 'ASCEND_RT_VISIBLE_DEVICES' + device_nums = torch.npu.device_count() + else: + visible_devices = 'CUDA_VISIBLE_DEVICES' + device_nums = torch.cuda.device_count() + if visible_devices in os.environ: all_gpu_ids = [ - int(i) for i in re.findall(r'(?