[Feat] enable HuggingFacewithChatTemplate with --accelerator via cli (#1163)

* enable HuggingFacewithChatTemplate with --accelerator via cli
* rm vllm_internlm2_chat_7b

parent e3c0448bbc
commit 8ea2c404d7
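As context for the change, here is a minimal sketch of what the new conversion does. The input dict below is an illustrative `HuggingFacewithChatTemplate` entry (the `abbr`, `path`, and numeric values are placeholders, not taken from this commit); the output follows the field mapping of the new `HuggingFacewithChatTemplate` branch in `change_accelerator` shown near the bottom of this diff, for the lmdeploy case:

    # Illustrative input: a chat-template HF entry in a config's `models` list
    hf_model = dict(
        type=HuggingFacewithChatTemplate,
        abbr='internlm2-chat-7b-hf',        # placeholder
        path='internlm/internlm2-chat-7b',  # placeholder
        max_out_len=1024,
        batch_size=8,
        run_cfg=dict(num_gpus=1),
    )

    # Roughly what change_accelerator() emits when `-a lmdeploy` is passed
    lmdeploy_model = dict(
        type='opencompass.models.turbomind_with_tf_above_v4_33.TurboMindModelwithChatTemplate',
        abbr='internlm2-chat-7b-turbomind',  # '-hf' suffix swapped for '-turbomind'
        path='internlm/internlm2-chat-7b',
        engine_config=dict(max_batch_size=8, tp=1),  # batch_size / num_gpus of the source entry
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),
        max_seq_len=2048,                            # default when the source entry sets none
        max_out_len=1024,
        batch_size=32768,
        run_cfg=dict(num_gpus=1),
        stop_words=[],                               # plus whatever the source entry declares
    )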
@@ -1,36 +1,23 @@
-from opencompass.models.turbomind import TurboMindModel
-
-_meta_template = dict(
-    round=[
-        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
-        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
-    ],
-)
+from opencompass.models import TurboMindModelwithChatTemplate
 
 models = [
     dict(
-        type=TurboMindModel,
+        type=TurboMindModelwithChatTemplate,
         abbr='internlm2-chat-7b-turbomind',
         path='internlm/internlm2-chat-7b',
-        meta_template=_meta_template,
         engine_config=dict(
-            session_len=32768,
-            max_batch_size=32,
-            model_name='internlm2-chat-7b',
+            max_batch_size=16,
             tp=1,
-            stop_words=[2, 92542],
         ),
         gen_config=dict(
             top_k=1,
-            top_p=0.8,
-            temperature=1.0,
-            max_new_tokens=2000,
+            temperature=1e-6,
+            top_p=0.9,
         ),
-        max_out_len=2000,
-        max_seq_len=32768,
-        batch_size=32,
-        concurrency=8,
-        run_cfg=dict(num_gpus=1, num_procs=1),
+        max_seq_len=2048,
+        max_out_len=1024,
+        batch_size=32768,
+        run_cfg=dict(num_gpus=1),
+        stop_words=['</s>', '<|im_end|>'],
     )
 ]
configs/models/hf_internlm/vllm_internlm2_chat_7b.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+from opencompass.models import VLLMwithChatTemplate
+
+models = [
+    dict(
+        type=VLLMwithChatTemplate,
+        abbr='internlm2-chat-7b-vllm',
+        path='internlm/internlm2-chat-7b',
+        model_kwargs=dict(tensor_parallel_size=1),
+        max_out_len=1024,
+        batch_size=32768,
+        run_cfg=dict(num_gpus=1),
+    )
+]
configs/models/qwen/vllm_qwen1_5_7b_chat.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+from opencompass.models import VLLMwithChatTemplate
+
+models = [
+    dict(
+        type=VLLMwithChatTemplate,
+        abbr='qwen1.5-7b-chat-vllm',
+        path='Qwen/Qwen1.5-7B-Chat',
+        model_kwargs=dict(tensor_parallel_size=1),
+        max_out_len=1024,
+        batch_size=32768,
+        run_cfg=dict(num_gpus=1),
+    )
+]
@@ -55,8 +55,8 @@ def parse_args():
     parser.add_argument(
         '-a', '--accelerator',
         help='Infer accelerator, support vllm and lmdeploy now.',
-        choices=['vllm', 'lmdeploy', 'hf'],
-        default='hf',
+        choices=['vllm', 'lmdeploy', None],
+        default=None,
         type=str)
     parser.add_argument('-m',
                         '--mode',
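Note on the flag change above: with `default=None`, model configs are left untouched unless an accelerator is requested explicitly; only `-a vllm` or `-a lmdeploy` routes the model list through `change_accelerator` (see the hunks below). An illustrative invocation, with placeholder config names, would be `python run.py --models hf_internlm2_chat_7b --datasets gsm8k_gen -a lmdeploy`.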
@@ -35,8 +35,11 @@ from .sensetime_api import SenseTime  # noqa: F401
 from .stepfun_api import StepFun  # noqa: F401
 from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
+from .turbomind_with_tf_above_v4_33 import \
+    TurboMindModelwithChatTemplate  # noqa: F401
 from .unigpt_api import UniGPT  # noqa: F401
 from .vllm import VLLM  # noqa: F401
+from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate  # noqa: F401
 from .xunfei_api import XunFei, XunFeiSpark  # noqa: F401
 from .yayi_api import Yayi  # noqa: F401
 from .zhipuai_api import ZhiPuAI  # noqa: F401
@@ -64,7 +64,7 @@ def _convert_chat_messages(inputs):
     for _input in inputs:
         messages = []
         if isinstance(_input, str):
-            messages.append({'role': 'HUMAN', 'prompt': _input})
+            messages.append({'role': 'HUMAN', 'content': _input})
         else:
             for item in _input:
                 role = {
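A quick illustration of the key rename above (the input string is illustrative, and this hunk truncates any further role normalization the helper may do): a bare prompt is now wrapped with a `content` key, which the new chat-template wrappers then hand to the tokenizer's chat template.

    # Illustrative only
    messages = _convert_chat_messages(['Hello, who are you?'])
    # roughly -> [[{'role': 'HUMAN', 'content': 'Hello, who are you?'}]]
    # Downstream, TurboMindModelwithChatTemplate / VLLMwithChatTemplate feed each
    # message list to tokenizer.apply_chat_template(
    #     m, add_generation_prompt=True, tokenize=False)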
opencompass/models/turbomind_with_tf_above_v4_33.py (new file, 195 lines)
@@ -0,0 +1,195 @@
+# flake8: noqa
+# yapf: disable
+import copy
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from opencompass.models.base import BaseModel
+from opencompass.utils.logging import get_logger
+from opencompass.utils.prompt import PromptList
+
+from .huggingface_above_v4_33 import (_convert_chat_messages,
+                                      _format_with_fast_chat_template,
+                                      _get_meta_template,
+                                      _get_possible_max_seq_len)
+
+PromptType = Union[PromptList, str]
+
+
+def valid_str(string, coding='utf-8'):
+    """decode text according to its encoding type."""
+    invalid_chars = [b'\xef\xbf\xbd']
+    bstr = bytes(string, coding)
+    for invalid_char in invalid_chars:
+        bstr = bstr.replace(invalid_char, b'')
+    ret = bstr.decode(encoding=coding, errors='ignore')
+    return ret
+
+
+class TurboMindModelwithChatTemplate(BaseModel):
+    def __init__(
+        self,
+        path: str,
+        tokenizer_only: bool = False,
+        engine_config: Dict = {},
+        gen_config: Dict = {},
+        concurrency: int = 8,
+        max_seq_len: int = None,
+        meta_template: Optional[Dict] = None,
+        fastchat_template: Optional[str] = None,
+        stop_words: List[str] = [],
+    ):
+        from lmdeploy.messages import TurbomindEngineConfig
+        from lmdeploy.turbomind import TurboMind
+        from lmdeploy.version import version_info
+        from transformers import AutoTokenizer
+
+        self.logger = get_logger()
+        self.path = path
+        self.tokenizer_only = tokenizer_only
+        self.template_parser = _get_meta_template(meta_template)
+        self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
+
+        self.origin_tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+        if not tokenizer_only:
+            DEFAULT_ENGING_CONFIG = {'session_len': self.max_seq_len}
+            _engine_config = DEFAULT_ENGING_CONFIG.copy()
+            _engine_config.update(engine_config)
+            engine_config = TurbomindEngineConfig(**_engine_config)
+            tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
+            self.tokenizer = tm_model.tokenizer
+            self.generators = [tm_model.create_instance() for i in range(concurrency)]
+            self.generator_ids = [i + 1 for i in range(concurrency)]
+        self.concurrency = concurrency
+        self.gen_config = gen_config
+        self.version_info = version_info
+        self.fastchat_template = fastchat_template
+        self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
+        self.logger.info(f'using stop words: {self.stop_words}')
+
+    def _get_potential_stop_words(self, path: Optional[str]):
+        from transformers import GenerationConfig
+        potential_stop_words = []
+        try:
+            generation_config = GenerationConfig.from_pretrained(path)
+            for token_id in generation_config.eos_token_id:
+                potential_stop_words.append(self.origin_tokenizer.decode(token_id))
+        except:
+            pass
+        potential_stop_words.append(self.origin_tokenizer.eos_token)
+        potential_stop_words = list(set(potential_stop_words))
+        return potential_stop_words
+
+    def generate(self,
+                 inputs: List[str],
+                 max_out_len: int = 512,
+                 stopping_criteria: List[str] = [],
+                 do_sample: Optional[bool] = None,
+                 temperature: int = 1,
+                 **kwargs) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of prompts
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        assert isinstance(inputs, List), f'List(str) is expected, but got {type(inputs)}'
+
+        messages = _convert_chat_messages(inputs)
+        if self.fastchat_template:
+            messages = _format_with_fast_chat_template(messages, self.fastchat_template)
+        else:
+            messages = [self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
+
+        # split messages into batches
+        batch_messages = [messages[i:i + self.concurrency] for i in range(0, len(messages), self.concurrency)]
+
+        stop_words = list(set(self.stop_words + stopping_criteria))
+        DEFAULT_GEN_CONFIG = {
+            'max_new_tokens': max_out_len,
+            'min_new_tokens': 1,
+            'top_k': 1,
+            'stop_words': stop_words,
+        }
+        gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
+        gen_config.update(self.gen_config)
+        if do_sample:
+            gen_config['top_k'] = 1000
+            gen_config['temperature'] = temperature
+        # if stopping_criteria:
+        #     stop_words = gen_config.get('stop_words', [])
+        #     for t in stopping_criteria:
+        #         t = self.tokenizer.encode(t, add_bos=False)
+        #         stop_words.append(t[0])
+        #     gen_config['stop_words'] = list(set(stop_words))
+        from lmdeploy.messages import EngineGenerationConfig, GenerationConfig
+        gen_config = GenerationConfig(**gen_config)
+        gen_config = EngineGenerationConfig.From(gen_config, self.tokenizer)
+
+        results = []
+        for batch_message in batch_messages:
+            n = len(batch_message)
+            with ThreadPoolExecutor() as executor:
+                _results = list(
+                    executor.map(
+                        self._generate,
+                        self.generators[:n],
+                        self.generator_ids[:n],
+                        batch_message,
+                        [gen_config] * n,
+                    ))
+                results += _results
+
+        for s in stop_words:
+            results = [r.split(s)[0] for r in results]
+        return results
+
+    def _generate(self,
+                  generator,
+                  session_id,
+                  prompt: PromptType,
+                  gen_config=None) -> str:
+        """Generate results given a list of inputs.
+
+        Args:
+            prompt (PromptType): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            gen_config (EngineGenerationConfig, optional): Generation
+                config to set arguments like top_k, top_p, temperature.
+        Returns:
+            str: The generated string.
+        """
+        assert type(prompt) is str, 'We only support string for TurboMind Python API'
+
+        input_ids = self.tokenizer.encode(prompt)
+        for outputs in generator.stream_infer(session_id=session_id,
+                                              input_ids=[input_ids],
+                                              gen_config=gen_config,
+                                              sequence_start=True,
+                                              sequence_end=True,
+                                              step=0,
+                                              stream_output=False):
+            if self.version_info >= (0, 4, 0):
+                output_ids = outputs.token_ids
+            else:
+                _, output_ids, _ = outputs
+            response = self.tokenizer.decode(output_ids)
+            response = valid_str(response)
+        return response
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get lengths of the tokenized strings.
+
+        Args:
+            prompt (str): Input string.
+
+        Returns:
+            int: Length of the input tokens
+        """
+        m = _convert_chat_messages([prompt])[0]
+        t = self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True)
+        return len(t['input_ids'])
@@ -130,21 +130,6 @@ class VLLM(BaseModel):
             ce_loss.append(loss)
         return np.array(ce_loss)
 
-    def prompts_preproccess(self, inputs: List[str]):
-        if self.use_fastchat_template:
-            try:
-                from fastchat.model import get_conversation_template
-            except ModuleNotFoundError:
-                raise ModuleNotFoundError(
-                    'Fastchat is not implemented. You can use '
-                    "'pip install \"fschat[model_worker,webui]\"' "
-                    'to implement fastchat.')
-            conv = get_conversation_template('vicuna')
-            conv.append_message(conv.roles[0], inputs[0])
-            conv.append_message(conv.roles[1], None)
-            inputs = [conv.get_prompt()]
-        return inputs
-
     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.
 
opencompass/models/vllm_with_tf_above_v4_33.py (new file, 127 lines)
@@ -0,0 +1,127 @@
+# flake8: noqa
+# yapf: disable
+from typing import Dict, List, Optional
+
+import numpy as np
+
+from opencompass.models.base import BaseModel
+from opencompass.utils import get_logger
+
+from .huggingface_above_v4_33 import (_convert_chat_messages,
+                                      _format_with_fast_chat_template,
+                                      _get_meta_template,
+                                      _get_possible_max_seq_len)
+
+try:
+    from vllm import LLM, SamplingParams
+except ImportError:
+    LLM, SamplingParams = None, None
+
+
+class VLLMwithChatTemplate(BaseModel):
+
+    def __init__(
+        self,
+        path: str,
+        model_kwargs: dict = dict(),
+        tokenizer_only: bool = False,
+        generation_kwargs: dict = dict(),
+        max_seq_len: int = None,
+        meta_template: Optional[Dict] = None,
+        fastchat_template: Optional[str] = None,
+        stop_words: List[str] = [],
+    ):
+        assert LLM, ('Please install VLLM with `pip install vllm`. note: torch==2.1.2 is required.')
+
+        self.logger = get_logger()
+        self.path = path
+        self.tokenizer_only = tokenizer_only
+        self.template_parser = _get_meta_template(meta_template)
+        self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
+        if tokenizer_only:
+            from transformers import AutoTokenizer
+
+            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+        else:
+            self._load_model(path, model_kwargs)
+            self.tokenizer = self.model.get_tokenizer()
+
+        self.generation_kwargs = generation_kwargs
+        self.generation_kwargs.pop('do_sample', None)
+        self.fastchat_template = fastchat_template
+        self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
+
+    def _load_model(self, path: str, added_model_kwargs: dict = dict()):
+        import ray
+
+        if ray.is_initialized():
+            self.logger.info('shutdown ray instance to avoid "Calling ray.init() again" error.')
+            ray.shutdown()
+
+        DEFAULT_MODEL_KWARGS = dict(trust_remote_code=True)
+        model_kwargs = DEFAULT_MODEL_KWARGS.copy()
+        model_kwargs.update(added_model_kwargs)
+        self.model = LLM(path, **model_kwargs)
+
+    def _get_potential_stop_words(self, path: Optional[str]):
+        from transformers import GenerationConfig
+        potential_stop_words = []
+        try:
+            generation_config = GenerationConfig.from_pretrained(path)
+            for token_id in generation_config.eos_token_id:
+                potential_stop_words.append(self.tokenizer.decode(token_id))
+        except:
+            pass
+        potential_stop_words.append(self.tokenizer.eos_token)
+        potential_stop_words = list(set(potential_stop_words))
+        return potential_stop_words
+
+    def generate(self, inputs: List[str], max_out_len: int, stopping_criteria: List[str] = [], **kwargs) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of strings.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        messages = _convert_chat_messages(inputs)
+        if self.fastchat_template:
+            messages = _format_with_fast_chat_template(messages, self.fastchat_template)
+        else:
+            messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
+
+        DEFAULT_GENERATION_KWARGS = {
+            'temperature': 0,
+            'max_tokens': max_out_len,
+            'stop': list(set(self.stop_words + stopping_criteria))
+        }
+        sampling_kwargs = DEFAULT_GENERATION_KWARGS.copy()
+        sampling_kwargs.update(self.generation_kwargs)
+        sampling_kwargs.update(kwargs)
+        sampling_kwargs = SamplingParams(**sampling_kwargs)
+
+        outputs = self.model.generate(messages, sampling_kwargs)
+
+        prompt_list, output_strs = [], []
+        for output in outputs:
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            prompt_list.append(prompt)
+            output_strs.append(generated_text)
+
+        return output_strs
+
+    def get_token_len(self, prompt: str) -> int:
+        """Get lengths of the tokenized strings.
+
+        Args:
+            prompt (str): Input string.
+
+        Returns:
+            int: Length of the input tokens
+        """
+        m = _convert_chat_messages([prompt])[0]
+        t = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True)
+        return len(t['input_ids'])
@@ -1,3 +1,5 @@
+# flake8: noqa
+# yapf: disable
 import os
 from typing import List, Tuple, Union
 
@@ -7,7 +9,9 @@ from mmengine.config import Config
 from opencompass.datasets.custom import make_custom_dataset_config
 from opencompass.models import (VLLM, HuggingFace, HuggingFaceBaseModel,
                                 HuggingFaceCausalLM, HuggingFaceChatGLM3,
-                                HuggingFacewithChatTemplate, TurboMindModel)
+                                HuggingFacewithChatTemplate, TurboMindModel,
+                                TurboMindModelwithChatTemplate,
+                                VLLMwithChatTemplate)
 from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
 from opencompass.runners import DLCRunner, LocalRunner, SlurmRunner
 from opencompass.tasks import OpenICLEvalTask, OpenICLInferTask
@@ -79,25 +83,16 @@ def get_config_from_arg(args) -> Config:
         config = try_fill_in_custom_cfgs(config)
         # set infer accelerator if needed
         if args.accelerator in ['vllm', 'lmdeploy']:
-            config['models'] = change_accelerator(config['models'],
-                                                  args.accelerator)
+            config['models'] = change_accelerator(config['models'], args.accelerator)
             if 'eval' in config and 'partitioner' in config['eval']:
-                if 'models' in config['eval']['partitioner']:
-                    config['eval']['partitioner'][
-                        'models'] = change_accelerator(
-                            config['eval']['partitioner']['models'],
-                            args.accelerator)
-                if 'judge_models' in config['eval']['partitioner']:
-                    config['eval']['partitioner'][
-                        'judge_models'] = change_accelerator(
-                            config['eval']['partitioner']['judge_models'],
-                            args.accelerator)
+                config['eval']['partitioner']['models'] = change_accelerator(config['eval']['partitioner']['models'], args.accelerator)
+                if config.get('eval', {}).get('partitioner', {}).get('judge_models') is not None:
+                    config['eval']['partitioner']['judge_models'] = change_accelerator(config['eval']['partitioner']['judge_models'], args.accelerator)
         return config
 
     # parse dataset args
     if not args.datasets and not args.custom_dataset_path:
-        raise ValueError('You must specify "--datasets" or '
-                         '"--custom-dataset-path" if you do not specify a '
-                         'config file path.')
+        raise ValueError('You must specify "--datasets" or "--custom-dataset-path" if you do not specify a config file path.')
     datasets = []
     if args.datasets:
         datasets_dir = os.path.join(args.config_dir, 'datasets')
@@ -110,7 +105,7 @@ def get_config_from_arg(args) -> Config:
                 dataset_key_suffix = '_datasets'
 
             for dataset in match_cfg_file(datasets_dir, [dataset_name]):
-                get_logger().info(f'Loading {dataset[0]}: {dataset[1]}')
+                logger.info(f'Loading {dataset[0]}: {dataset[1]}')
                 cfg = Config.fromfile(dataset[1])
                 for k in cfg.keys():
                     if k.endswith(dataset_key_suffix):
@@ -128,19 +123,15 @@ def get_config_from_arg(args) -> Config:
 
     # parse model args
     if not args.models and not args.hf_path:
-        raise ValueError('You must specify a config file path, '
-                         'or specify --models and --datasets, or '
-                         'specify HuggingFace model parameters and '
-                         '--datasets.')
+        raise ValueError('You must specify a config file path, or specify --models and --datasets, or specify HuggingFace model parameters and --datasets.')
     models = []
     if args.models:
         model_dir = os.path.join(args.config_dir, 'models')
         for model in match_cfg_file(model_dir, args.models):
-            get_logger().info(f'Loading {model[0]}: {model[1]}')
+            logger.info(f'Loading {model[0]}: {model[1]}')
             cfg = Config.fromfile(model[1])
             if 'models' not in cfg:
-                raise ValueError(
-                    f'Config file {model[1]} does not contain "models" field')
+                raise ValueError(f'Config file {model[1]} does not contain "models" field')
             models += cfg['models']
     else:
         if args.hf_type == 'chat':
@@ -167,8 +158,7 @@ def get_config_from_arg(args) -> Config:
     if args.accelerator in ['vllm', 'lmdeploy']:
         models = change_accelerator(models, args.accelerator)
     # parse summarizer args
-    summarizer_arg = args.summarizer if args.summarizer is not None \
-        else 'example'
+    summarizer_arg = args.summarizer if args.summarizer is not None else 'example'
     summarizers_dir = os.path.join(args.config_dir, 'summarizers')
 
     # Check if summarizer_arg contains '/'
@@ -188,9 +178,7 @@ def get_config_from_arg(args) -> Config:
             # from the configuration file
             summarizer = cfg[summarizer_key]
 
-    return Config(dict(models=models, datasets=datasets,
-                       summarizer=summarizer),
-                  format_python_code=False)
+    return Config(dict(models=models, datasets=datasets, summarizer=summarizer), format_python_code=False)
 
 
 def change_accelerator(models, accelerator):
@@ -200,20 +188,15 @@ def change_accelerator(models, accelerator):
     for model in models:
         logger.info(f'Transforming {model["abbr"]} to {accelerator}')
         # change HuggingFace model to VLLM or TurboMindModel
-        if model['type'] in [
-                HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3
-        ]:
+        if model['type'] in [HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3]:
             gen_args = dict()
             if model.get('generation_kwargs') is not None:
                 generation_kwargs = model['generation_kwargs'].copy()
-                gen_args['temperature'] = generation_kwargs.get(
-                    'temperature', 0.001)
+                gen_args['temperature'] = generation_kwargs.get('temperature', 0.001)
                 gen_args['top_k'] = generation_kwargs.get('top_k', 1)
                 gen_args['top_p'] = generation_kwargs.get('top_p', 0.9)
-                gen_args['stop_token_ids'] = generation_kwargs.get(
-                    'eos_token_id', None)
-                generation_kwargs['stop_token_ids'] = generation_kwargs.get(
-                    'eos_token_id', None)
+                gen_args['stop_token_ids'] = generation_kwargs.get('eos_token_id', None)
+                generation_kwargs['stop_token_ids'] = generation_kwargs.get('eos_token_id', None)
                 generation_kwargs.pop('eos_token_id')
             else:
                 # if generation_kwargs is not provided, set default values
@@ -228,8 +211,7 @@ def change_accelerator(models, accelerator):
             mod = TurboMindModel
             acc_model = dict(
                 type=f'{mod.__module__}.{mod.__name__}',
-                abbr=model['abbr'].replace('hf', 'lmdeploy')
-                if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy',
+                abbr=model['abbr'].replace('hf', 'lmdeploy') if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy',
                 path=model['path'],
                 engine_config=dict(session_len=model['max_seq_len'],
                                    max_batch_size=model['batch_size'],
@@ -253,11 +235,9 @@ def change_accelerator(models, accelerator):
 
             acc_model = dict(
                 type=f'{VLLM.__module__}.{VLLM.__name__}',
-                abbr=model['abbr'].replace('hf', 'vllm')
-                if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
+                abbr=model['abbr'].replace('hf', 'vllm') if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
                 path=model['path'],
-                model_kwargs=dict(
-                    tensor_parallel_size=model['run_cfg']['num_gpus']),
+                model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus']),
                 max_out_len=model['max_out_len'],
                 max_seq_len=model['max_seq_len'],
                 batch_size=model['batch_size'],
@@ -268,7 +248,36 @@ def change_accelerator(models, accelerator):
                 if model.get(item) is not None:
                     acc_model[item] = model[item]
             else:
-                raise ValueError(f'Unsupported accelerator {accelerator}')
+                raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}')
+        elif model['type'] in [HuggingFacewithChatTemplate]:
+            if accelerator == 'vllm':
+                mod = VLLMwithChatTemplate
+                acc_model = dict(
+                    type=f'{mod.__module__}.{mod.__name__}',
+                    abbr='-hf'.join(model['abbr'].split('-hf')[:-1]) + '-vllm',
+                    path=model['path'],
+                    model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus']),
+                    max_out_len=model['max_out_len'],
+                    batch_size=32768,
+                    run_cfg=model['run_cfg'],
+                    stop_words=model.get('stop_words', []),
+                )
+            elif accelerator == 'lmdeploy':
+                mod = TurboMindModelwithChatTemplate
+                acc_model = dict(
+                    type=f'{mod.__module__}.{mod.__name__}',
+                    abbr='-hf'.join(model['abbr'].split('-hf')[:-1]) + '-turbomind',
+                    path=model['path'],
+                    engine_config=dict(max_batch_size=model.get('batch_size', 16), tp=model['run_cfg']['num_gpus']),
+                    gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),
+                    max_seq_len=model.get('max_seq_len', 2048),
+                    max_out_len=model['max_out_len'],
+                    batch_size=32768,
+                    run_cfg=model['run_cfg'],
+                    stop_words=model.get('stop_words', []),
+                )
+            else:
+                raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}')
+        else:
+            raise ValueError(f'Unsupported model type {model["type"]}')
         model_accels.append(acc_model)