From 5868d5afa4d9574441c38988aa26e44de0bdd3b2 Mon Sep 17 00:00:00 2001 From: Chenguang Li <87689256+noemotiovon@users.noreply.github.com> Date: Mon, 21 Oct 2024 17:42:53 +0800 Subject: [PATCH] [Bug] Fix-NPU-Support (#1618) * bugfix NPU support * formatting --------- Co-authored-by: noemotiovon --- opencompass/datasets/truthfulqa.py | 9 ++++++++- opencompass/models/huggingface_above_v4_33.py | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/opencompass/datasets/truthfulqa.py b/opencompass/datasets/truthfulqa.py index 35ae28de..9a4d20c5 100644 --- a/opencompass/datasets/truthfulqa.py +++ b/opencompass/datasets/truthfulqa.py @@ -2,6 +2,7 @@ import evaluate import numpy as np import torch from datasets import load_dataset +from mmengine.device import is_npu_available from transformers import AutoModelForCausalLM, AutoTokenizer from opencompass.openicl.icl_evaluator import BaseEvaluator @@ -9,7 +10,13 @@ from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET from .base import BaseDataset -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if is_npu_available(): + backend = 'npu' +elif torch.cuda.is_available(): + backend = 'cuda' +else: + backend = 'cpu' +device = torch.device(backend) @LOAD_DATASET.register_module() diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py index 0276ebba..d1d2b3f3 100644 --- a/opencompass/models/huggingface_above_v4_33.py +++ b/opencompass/models/huggingface_above_v4_33.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Union import torch +from mmengine.device import is_npu_available from opencompass.models.base import BaseModel, LMTemplateParser from opencompass.models.base_api import APITemplateParser @@ -224,6 +225,8 @@ class HuggingFacewithChatTemplate(BaseModel): model_kwargs.update(kwargs) model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs) self.logger.debug(f'using model_kwargs: {model_kwargs}') + if is_npu_available(): + model_kwargs['device_map'] = 'npu' try: self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)