mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Bug] Fix-NPU-Support (#1618)
* bugfix NPU support * formatting --------- Co-authored-by: noemotiovon <noemotiovon@gmail.com>
This commit is contained in:
parent
500b44ba2d
commit
5868d5afa4
@ -2,6 +2,7 @@ import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from mmengine.device import is_npu_available
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
@ -9,7 +10,13 @@ from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
if is_npu_available():
|
||||
backend = 'npu'
|
||||
elif torch.cuda.is_available():
|
||||
backend = 'cuda'
|
||||
else:
|
||||
backend = 'cpu'
|
||||
device = torch.device(backend)
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
|
@ -3,6 +3,7 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from mmengine.device import is_npu_available
|
||||
|
||||
from opencompass.models.base import BaseModel, LMTemplateParser
|
||||
from opencompass.models.base_api import APITemplateParser
|
||||
@ -224,6 +225,8 @@ class HuggingFacewithChatTemplate(BaseModel):
|
||||
model_kwargs.update(kwargs)
|
||||
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
|
||||
self.logger.debug(f'using model_kwargs: {model_kwargs}')
|
||||
if is_npu_available():
|
||||
model_kwargs['device_map'] = 'npu'
|
||||
|
||||
try:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user