mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Bug] Fix-NPU-Support (#1618)
* bugfix NPU support * formatting --------- Co-authored-by: noemotiovon <noemotiovon@gmail.com>
This commit is contained in:
parent
500b44ba2d
commit
5868d5afa4
@ -2,6 +2,7 @@ import evaluate
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from mmengine.device import is_npu_available
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||||
@ -9,7 +10,13 @@ from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
|
|||||||
|
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
if is_npu_available():
|
||||||
|
backend = 'npu'
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
backend = 'cuda'
|
||||||
|
else:
|
||||||
|
backend = 'cpu'
|
||||||
|
device = torch.device(backend)
|
||||||
|
|
||||||
|
|
||||||
@LOAD_DATASET.register_module()
|
@LOAD_DATASET.register_module()
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from mmengine.device import is_npu_available
|
||||||
|
|
||||||
from opencompass.models.base import BaseModel, LMTemplateParser
|
from opencompass.models.base import BaseModel, LMTemplateParser
|
||||||
from opencompass.models.base_api import APITemplateParser
|
from opencompass.models.base_api import APITemplateParser
|
||||||
@ -224,6 +225,8 @@ class HuggingFacewithChatTemplate(BaseModel):
|
|||||||
model_kwargs.update(kwargs)
|
model_kwargs.update(kwargs)
|
||||||
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
|
model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
|
||||||
self.logger.debug(f'using model_kwargs: {model_kwargs}')
|
self.logger.debug(f'using model_kwargs: {model_kwargs}')
|
||||||
|
if is_npu_available():
|
||||||
|
model_kwargs['device_map'] = 'npu'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user