diff --git a/opencompass/datasets/truthfulqa.py b/opencompass/datasets/truthfulqa.py
index 35ae28de..9a4d20c5 100644
--- a/opencompass/datasets/truthfulqa.py
+++ b/opencompass/datasets/truthfulqa.py
@@ -2,6 +2,7 @@ import evaluate
 import numpy as np
 import torch
 from datasets import load_dataset
+from mmengine.device import is_npu_available
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from opencompass.openicl.icl_evaluator import BaseEvaluator
@@ -9,7 +10,13 @@ from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
 
 from .base import BaseDataset
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if is_npu_available():
+    backend = 'npu'
+elif torch.cuda.is_available():
+    backend = 'cuda'
+else:
+    backend = 'cpu'
+device = torch.device(backend)
 
 
 @LOAD_DATASET.register_module()
diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py
index 0276ebba..d1d2b3f3 100644
--- a/opencompass/models/huggingface_above_v4_33.py
+++ b/opencompass/models/huggingface_above_v4_33.py
@@ -3,6 +3,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+from mmengine.device import is_npu_available
 
 from opencompass.models.base import BaseModel, LMTemplateParser
 from opencompass.models.base_api import APITemplateParser
@@ -224,6 +225,8 @@ class HuggingFacewithChatTemplate(BaseModel):
         model_kwargs.update(kwargs)
         model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
         self.logger.debug(f'using model_kwargs: {model_kwargs}')
+        if is_npu_available():
+            model_kwargs['device_map'] = 'npu'
 
         try:
             self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)