mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] support download from modelscope (#534)
* [Feature] download from modelscope * [Feature] download from modelscope * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com>
This commit is contained in:
parent
048775192b
commit
c0785e53d8
30
configs/models/ms_internlm/ms_internlm_chat_7b_8k.py
Normal file
30
configs/models/ms_internlm/ms_internlm_chat_7b_8k.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from opencompass.models import ModelScopeCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
_meta_template = dict(
|
||||||
|
round=[
|
||||||
|
dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
|
||||||
|
dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
models = [
|
||||||
|
dict(
|
||||||
|
type=ModelScopeCausalLM,
|
||||||
|
abbr='internlm-chat-7b-8k-ms',
|
||||||
|
path='Shanghai_AI_Laboratory/internlm-chat-7b-8k',
|
||||||
|
tokenizer_path='Shanghai_AI_Laboratory/internlm-chat-7b-8k',
|
||||||
|
tokenizer_kwargs=dict(
|
||||||
|
padding_side='left',
|
||||||
|
truncation_side='left',
|
||||||
|
use_fast=False,
|
||||||
|
trust_remote_code=True,
|
||||||
|
),
|
||||||
|
max_out_len=100,
|
||||||
|
max_seq_len=2048,
|
||||||
|
batch_size=8,
|
||||||
|
meta_template=_meta_template,
|
||||||
|
model_kwargs=dict(trust_remote_code=True, device_map='auto'),
|
||||||
|
run_cfg=dict(num_gpus=1, num_procs=1),
|
||||||
|
)
|
||||||
|
]
|
30
configs/models/qwen/ms_qwen_7b_chat.py
Normal file
30
configs/models/qwen/ms_qwen_7b_chat.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from opencompass.models import ModelScopeCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
_meta_template = dict(
|
||||||
|
round=[
|
||||||
|
dict(role="HUMAN", begin='\n<|im_start|>user\n', end='<|im_end|>'),
|
||||||
|
dict(role="BOT", begin="\n<|im_start|>assistant\n", end='<|im_end|>', generate=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
models = [
|
||||||
|
dict(
|
||||||
|
type=ModelScopeCausalLM,
|
||||||
|
abbr='qwen-7b-chat-ms',
|
||||||
|
path="qwen/Qwen-7B-Chat",
|
||||||
|
tokenizer_path='qwen/Qwen-7B-Chat',
|
||||||
|
tokenizer_kwargs=dict(
|
||||||
|
padding_side='left',
|
||||||
|
truncation_side='left',
|
||||||
|
trust_remote_code=True,
|
||||||
|
use_fast=False,),
|
||||||
|
pad_token_id=151643,
|
||||||
|
max_out_len=100,
|
||||||
|
max_seq_len=2048,
|
||||||
|
batch_size=8,
|
||||||
|
meta_template=_meta_template,
|
||||||
|
model_kwargs=dict(device_map='auto', trust_remote_code=True),
|
||||||
|
run_cfg=dict(num_gpus=1, num_procs=1),
|
||||||
|
)
|
||||||
|
]
|
@ -14,6 +14,7 @@ from .intern_model import InternLM # noqa: F401, F403
|
|||||||
from .lightllm_api import LightllmAPI # noqa: F401
|
from .lightllm_api import LightllmAPI # noqa: F401
|
||||||
from .llama2 import Llama2, Llama2Chat # noqa: F401, F403
|
from .llama2 import Llama2, Llama2Chat # noqa: F401, F403
|
||||||
from .minimax_api import MiniMax # noqa: F401
|
from .minimax_api import MiniMax # noqa: F401
|
||||||
|
from .modelscope import ModelScope, ModelScopeCausalLM # noqa: F401, F403
|
||||||
from .openai_api import OpenAI # noqa: F401
|
from .openai_api import OpenAI # noqa: F401
|
||||||
from .pangu_api import PanGu # noqa: F401
|
from .pangu_api import PanGu # noqa: F401
|
||||||
from .sensetime_api import SenseTime # noqa: F401
|
from .sensetime_api import SenseTime # noqa: F401
|
||||||
|
215
opencompass/models/modelscope.py
Normal file
215
opencompass/models/modelscope.py
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from opencompass.utils.prompt import PromptList
|
||||||
|
|
||||||
|
from .huggingface import HuggingFace
|
||||||
|
|
||||||
|
PromptType = Union[PromptList, str]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelScope(HuggingFace):
|
||||||
|
"""Model wrapper around ModelScope models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The name or path to ModelScope's model.
|
||||||
|
ms_cache_dir: Set the cache dir to MS model cache dir. If None, it will
|
||||||
|
use the env variable MS_MODEL_HUB. Defaults to None.
|
||||||
|
max_seq_len (int): The maximum length of the input sequence. Defaults
|
||||||
|
to 2048.
|
||||||
|
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
||||||
|
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
||||||
|
Defaults to {}.
|
||||||
|
peft_path (str, optional): The name or path to the ModelScope's PEFT
|
||||||
|
model. If None, the original model will not be converted to PEFT.
|
||||||
|
Defaults to None.
|
||||||
|
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||||
|
Defaults to False.
|
||||||
|
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
||||||
|
Defaults to dict(device_map='auto').
|
||||||
|
meta_template (Dict, optional): The model's meta prompt
|
||||||
|
template if needed, in case the requirement of injecting or
|
||||||
|
wrapping of any meta instructions.
|
||||||
|
extract_pred_after_decode (bool): Whether to extract the prediction
|
||||||
|
string from the decoded output string, instead of extract the
|
||||||
|
prediction tokens before decoding. Defaults to False.
|
||||||
|
batch_padding (bool): If False, inference with be performed in for-loop
|
||||||
|
without batch padding.
|
||||||
|
pad_token_id (int): The id of the padding token. Defaults to None. Use
|
||||||
|
(#vocab + pad_token_id) if get negative value.
|
||||||
|
mode (str, optional): The method of input truncation when input length
|
||||||
|
exceeds max_seq_len. 'mid' represents the part of input to
|
||||||
|
truncate. Defaults to 'none'.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
About ``extract_pred_after_decode``: Commonly, we should extract the
|
||||||
|
the prediction tokens before decoding. But for some tokenizers using
|
||||||
|
``sentencepiece``, like LLaMA, this behavior may change the number of
|
||||||
|
whitespaces, which is harmful for Python programming tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
path: str,
|
||||||
|
ms_cache_dir: Optional[str] = None,
|
||||||
|
max_seq_len: int = 2048,
|
||||||
|
tokenizer_path: Optional[str] = None,
|
||||||
|
tokenizer_kwargs: dict = dict(),
|
||||||
|
peft_path: Optional[str] = None,
|
||||||
|
tokenizer_only: bool = False,
|
||||||
|
model_kwargs: dict = dict(device_map='auto'),
|
||||||
|
meta_template: Optional[Dict] = None,
|
||||||
|
extract_pred_after_decode: bool = False,
|
||||||
|
batch_padding: bool = False,
|
||||||
|
pad_token_id: Optional[int] = None,
|
||||||
|
mode: str = 'none'):
|
||||||
|
super().__init__(
|
||||||
|
path=path,
|
||||||
|
hf_cache_dir=ms_cache_dir,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
tokenizer_path=tokenizer_path,
|
||||||
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
|
peft_path=peft_path,
|
||||||
|
tokenizer_only=tokenizer_only,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
meta_template=meta_template,
|
||||||
|
extract_pred_after_decode=extract_pred_after_decode,
|
||||||
|
batch_padding=batch_padding,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
|
||||||
|
tokenizer_kwargs: dict):
|
||||||
|
from modelscope import AutoTokenizer
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_path if tokenizer_path else path, **tokenizer_kwargs)
|
||||||
|
|
||||||
|
# A patch for some models without pad_token_id
|
||||||
|
if self.pad_token_id is not None:
|
||||||
|
if self.pad_token_id < 0:
|
||||||
|
self.pad_token_id += self.tokenizer.vocab_size
|
||||||
|
if self.tokenizer.pad_token_id is None:
|
||||||
|
self.logger.debug(f'Using {self.pad_token_id} as pad_token_id')
|
||||||
|
elif self.tokenizer.pad_token_id != self.pad_token_id:
|
||||||
|
self.logger.warning(
|
||||||
|
'pad_token_id is not consistent with the tokenizer. Using '
|
||||||
|
f'{self.pad_token_id} as pad_token_id')
|
||||||
|
self.tokenizer.pad_token_id = self.pad_token_id
|
||||||
|
elif self.tokenizer.pad_token_id is None:
|
||||||
|
self.logger.warning('pad_token_id is not set for the tokenizer.')
|
||||||
|
if self.tokenizer.eos_token is not None:
|
||||||
|
self.logger.warning(
|
||||||
|
f'Using eos_token_id {self.tokenizer.eos_token} '
|
||||||
|
'as pad_token_id.')
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
else:
|
||||||
|
from modelscope import GenerationConfig
|
||||||
|
gcfg = GenerationConfig.from_pretrained(path)
|
||||||
|
|
||||||
|
if gcfg.pad_token_id is not None:
|
||||||
|
self.logger.warning(
|
||||||
|
f'Using pad_token_id {gcfg.pad_token_id} '
|
||||||
|
'as pad_token_id.')
|
||||||
|
self.tokenizer.pad_token_id = gcfg.pad_token_id
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'pad_token_id is not set for this tokenizer. Try to '
|
||||||
|
'set pad_token_id via passing '
|
||||||
|
'`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
|
||||||
|
|
||||||
|
# A patch for llama when batch_padding = True
|
||||||
|
if 'decapoda-research/llama' in path or \
|
||||||
|
(tokenizer_path and
|
||||||
|
'decapoda-research/llama' in tokenizer_path):
|
||||||
|
self.logger.warning('We set new pad_token_id for LLaMA model')
|
||||||
|
# keep consistent with official LLaMA repo
|
||||||
|
# https://github.com/google/sentencepiece/blob/master/python/sentencepiece_python_module_example.ipynb # noqa
|
||||||
|
self.tokenizer.bos_token = '<s>'
|
||||||
|
self.tokenizer.eos_token = '</s>'
|
||||||
|
self.tokenizer.pad_token_id = 0
|
||||||
|
|
||||||
|
def _set_model_kwargs_torch_dtype(self, model_kwargs):
|
||||||
|
if 'torch_dtype' not in model_kwargs:
|
||||||
|
torch_dtype = torch.float16
|
||||||
|
else:
|
||||||
|
torch_dtype = {
|
||||||
|
'torch.float16': torch.float16,
|
||||||
|
'torch.bfloat16': torch.bfloat16,
|
||||||
|
'torch.float': torch.float,
|
||||||
|
'auto': 'auto',
|
||||||
|
'None': None
|
||||||
|
}.get(model_kwargs['torch_dtype'])
|
||||||
|
self.logger.debug(f'MS using torch_dtype: {torch_dtype}')
|
||||||
|
if torch_dtype is not None:
|
||||||
|
model_kwargs['torch_dtype'] = torch_dtype
|
||||||
|
|
||||||
|
def _load_model(self,
|
||||||
|
path: str,
|
||||||
|
model_kwargs: dict,
|
||||||
|
peft_path: Optional[str] = None):
|
||||||
|
from modelscope import AutoModel, AutoModelForCausalLM
|
||||||
|
|
||||||
|
self._set_model_kwargs_torch_dtype(model_kwargs)
|
||||||
|
try:
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
path, **model_kwargs)
|
||||||
|
except ValueError:
|
||||||
|
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
||||||
|
|
||||||
|
if peft_path is not None:
|
||||||
|
from peft import PeftModel
|
||||||
|
self.model = PeftModel.from_pretrained(self.model,
|
||||||
|
peft_path,
|
||||||
|
is_trainable=False)
|
||||||
|
self.model.eval()
|
||||||
|
self.model.generation_config.do_sample = False
|
||||||
|
|
||||||
|
# A patch for llama when batch_padding = True
|
||||||
|
if 'decapoda-research/llama' in path:
|
||||||
|
self.model.config.bos_token_id = 1
|
||||||
|
self.model.config.eos_token_id = 2
|
||||||
|
self.model.config.pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
|
||||||
|
|
||||||
|
class ModelScopeCausalLM(ModelScope):
|
||||||
|
"""Model wrapper around ModelScope CausalLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The name or path to ModelScope's model.
|
||||||
|
ms_cache_dir: Set the cache dir to MS model cache dir. If None, it will
|
||||||
|
use the env variable MS_MODEL_HUB. Defaults to None.
|
||||||
|
max_seq_len (int): The maximum length of the input sequence. Defaults
|
||||||
|
to 2048.
|
||||||
|
tokenizer_path (str): The path to the tokenizer. Defaults to None.
|
||||||
|
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
|
||||||
|
Defaults to {}.
|
||||||
|
peft_path (str, optional): The name or path to the ModelScope's PEFT
|
||||||
|
model. If None, the original model will not be converted to PEFT.
|
||||||
|
Defaults to None.
|
||||||
|
tokenizer_only (bool): If True, only the tokenizer will be initialized.
|
||||||
|
Defaults to False.
|
||||||
|
model_kwargs (dict): Keyword arguments for the model, used in loader.
|
||||||
|
Defaults to dict(device_map='auto').
|
||||||
|
meta_template (Dict, optional): The model's meta prompt
|
||||||
|
template if needed, in case the requirement of injecting or
|
||||||
|
wrapping of any meta instructions.
|
||||||
|
batch_padding (bool): If False, inference with be performed in for-loop
|
||||||
|
without batch padding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _load_model(self,
|
||||||
|
path: str,
|
||||||
|
model_kwargs: dict,
|
||||||
|
peft_path: Optional[str] = None):
|
||||||
|
from modelscope import AutoModelForCausalLM
|
||||||
|
|
||||||
|
self._set_model_kwargs_torch_dtype(model_kwargs)
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||||
|
if peft_path is not None:
|
||||||
|
from peft import PeftModel
|
||||||
|
self.model = PeftModel.from_pretrained(self.model,
|
||||||
|
peft_path,
|
||||||
|
is_trainable=False)
|
||||||
|
self.model.eval()
|
||||||
|
self.model.generation_config.do_sample = False
|
Loading…
Reference in New Issue
Block a user