[Feat] enable HuggingFacewithChatTemplate with --accelerator via cli (#1163)

* enable HuggingFacewithChatTemplate with --accelerator via cli

* rm vllm_internlm2_chat_7b
Fengzhe Zhou 2024-05-15 21:51:07 +08:00 committed by GitHub
parent e3c0448bbc
commit 8ea2c404d7
10 changed files with 418 additions and 86 deletions

View File

@@ -1,36 +1,23 @@
from opencompass.models.turbomind import TurboMindModel
_meta_template = dict(
round=[
dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
],
)
from opencompass.models import TurboMindModelwithChatTemplate
models = [
dict(
type=TurboMindModel,
type=TurboMindModelwithChatTemplate,
abbr='internlm2-chat-7b-turbomind',
path='internlm/internlm2-chat-7b',
meta_template=_meta_template,
engine_config=dict(
session_len=32768,
max_batch_size=32,
model_name='internlm2-chat-7b',
max_batch_size=16,
tp=1,
stop_words=[2, 92542],
),
gen_config=dict(
top_k=1,
top_p=0.8,
temperature=1.0,
max_new_tokens=2000,
temperature=1e-6,
top_p=0.9,
),
max_out_len=2000,
max_seq_len=32768,
batch_size=32,
concurrency=8,
run_cfg=dict(num_gpus=1, num_procs=1),
max_seq_len=2048,
max_out_len=1024,
batch_size=32768,
run_cfg=dict(num_gpus=1),
stop_words=['</s>', '<|im_end|>'],
)
]

View File

@@ -0,0 +1,13 @@
from opencompass.models import VLLMwithChatTemplate
models = [
dict(
type=VLLMwithChatTemplate,
abbr='internlm2-chat-7b-vllm',
path='internlm/internlm2-chat-7b',
model_kwargs=dict(tensor_parallel_size=1),
max_out_len=1024,
batch_size=32768,
run_cfg=dict(num_gpus=1),
)
]

View File

@@ -0,0 +1,13 @@
from opencompass.models import VLLMwithChatTemplate
models = [
dict(
type=VLLMwithChatTemplate,
abbr='qwen1.5-7b-chat-vllm',
path='Qwen/Qwen1.5-7B-Chat',
model_kwargs=dict(tensor_parallel_size=1),
max_out_len=1024,
batch_size=32768,
run_cfg=dict(num_gpus=1),
)
]

View File

@@ -55,8 +55,8 @@ def parse_args():
parser.add_argument(
'-a', '--accelerator',
help='Infer accelerator, support vllm and lmdeploy now.',
choices=['vllm', 'lmdeploy', 'hf'],
default='hf',
choices=['vllm', 'lmdeploy', None],
default=None,
type=str)
parser.add_argument('-m',
'--mode',
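As context (not shown in the diff): get_config_from_arg only rewrites model configs when the flag value is 'vllm' or 'lmdeploy', so None is a more honest default than 'hf', which was never acted on. A minimal argparse sketch of the new flag, for illustration only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-a', '--accelerator', choices=['vllm', 'lmdeploy', None], default=None, type=str)
print(parser.parse_args([]).accelerator)                  # None: models are used as configured
print(parser.parse_args(['-a', 'lmdeploy']).accelerator)  # 'lmdeploy': change_accelerator() rewrites them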

View File

@@ -35,8 +35,11 @@ from .sensetime_api import SenseTime # noqa: F401
from .stepfun_api import StepFun # noqa: F401
from .turbomind import TurboMindModel # noqa: F401
from .turbomind_tis import TurboMindTisModel # noqa: F401
from .turbomind_with_tf_above_v4_33 import \
TurboMindModelwithChatTemplate # noqa: F401
from .unigpt_api import UniGPT # noqa: F401
from .vllm import VLLM # noqa: F401
from .vllm_with_tf_above_v4_33 import VLLMwithChatTemplate # noqa: F401
from .xunfei_api import XunFei, XunFeiSpark # noqa: F401
from .yayi_api import Yayi # noqa: F401
from .zhipuai_api import ZhiPuAI # noqa: F401

View File

@@ -64,7 +64,7 @@ def _convert_chat_messages(inputs):
for _input in inputs:
messages = []
if isinstance(_input, str):
messages.append({'role': 'HUMAN', 'prompt': _input})
messages.append({'role': 'HUMAN', 'content': _input})
else:
for item in _input:
role = {

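A note on this key rename (illustration, not part of the commit): the converted messages are eventually rendered with the tokenizer's chat template, which expects OpenAI-style dicts keyed by 'content'; the surrounding function maps OpenCompass roles such as HUMAN and BOT onto the template's role names. A minimal sketch with an arbitrary chat model:

from transformers import AutoTokenizer

# Any HF model that ships a chat_template behaves the same way; this one is illustrative.
tok = AutoTokenizer.from_pretrained('internlm/internlm2-chat-7b', trust_remote_code=True)
messages = [{'role': 'user', 'content': 'What is 2 + 2?'}]
# The template reads message['content']; a 'prompt' key would not be picked up.
prompt = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)  # the prompt rendered by the model's chat template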
View File

@@ -0,0 +1,195 @@
# flake8: noqa
# yapf: disable
import copy
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
from opencompass.models.base import BaseModel
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList
from .huggingface_above_v4_33 import (_convert_chat_messages,
_format_with_fast_chat_template,
_get_meta_template,
_get_possible_max_seq_len)
PromptType = Union[PromptList, str]
def valid_str(string, coding='utf-8'):
"""decode text according to its encoding type."""
invalid_chars = [b'\xef\xbf\xbd']
bstr = bytes(string, coding)
for invalid_char in invalid_chars:
bstr = bstr.replace(invalid_char, b'')
ret = bstr.decode(encoding=coding, errors='ignore')
return ret
class TurboMindModelwithChatTemplate(BaseModel):
def __init__(
self,
path: str,
tokenizer_only: bool = False,
engine_config: Dict = {},
gen_config: Dict = {},
concurrency: int = 8,
max_seq_len: int = None,
meta_template: Optional[Dict] = None,
fastchat_template: Optional[str] = None,
stop_words: List[str] = [],
):
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.turbomind import TurboMind
from lmdeploy.version import version_info
from transformers import AutoTokenizer
self.logger = get_logger()
self.path = path
self.tokenizer_only = tokenizer_only
self.template_parser = _get_meta_template(meta_template)
self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
self.origin_tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
if not tokenizer_only:
DEFAULT_ENGINE_CONFIG = {'session_len': self.max_seq_len}
_engine_config = DEFAULT_ENGINE_CONFIG.copy()
_engine_config.update(engine_config)
engine_config = TurbomindEngineConfig(**_engine_config)
tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
self.tokenizer = tm_model.tokenizer
self.generators = [tm_model.create_instance() for i in range(concurrency)]
self.generator_ids = [i + 1 for i in range(concurrency)]
self.concurrency = concurrency
self.gen_config = gen_config
self.version_info = version_info
self.fastchat_template = fastchat_template
self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
self.logger.info(f'using stop words: {self.stop_words}')
def _get_potential_stop_words(self, path: Optional[str]):
from transformers import GenerationConfig
potential_stop_words = []
try:
generation_config = GenerationConfig.from_pretrained(path)
for token_id in generation_config.eos_token_id:
potential_stop_words.append(self.origin_tokenizer.decode(token_id))
except Exception:
pass
potential_stop_words.append(self.origin_tokenizer.eos_token)
potential_stop_words = list(set(potential_stop_words))
return potential_stop_words
def generate(self,
inputs: List[str],
max_out_len: int = 512,
stopping_criteria: List[str] = [],
do_sample: Optional[bool] = None,
temperature: float = 1.0,
**kwargs) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str]): A list of prompts
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
assert isinstance(inputs, List), f'List(str) is expected, but got {type(inputs)}'
messages = _convert_chat_messages(inputs)
if self.fastchat_template:
messages = _format_with_fast_chat_template(messages, self.fastchat_template)
else:
messages = [self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
# split messages into batches
batch_messages = [messages[i:i + self.concurrency] for i in range(0, len(messages), self.concurrency)]
stop_words = list(set(self.stop_words + stopping_criteria))
DEFAULT_GEN_CONFIG = {
'max_new_tokens': max_out_len,
'min_new_tokens': 1,
'top_k': 1,
'stop_words': stop_words,
}
gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
gen_config.update(self.gen_config)
if do_sample:
gen_config['top_k'] = 1000
gen_config['temperature'] = temperature
# if stopping_criteria:
# stop_words = gen_config.get('stop_words', [])
# for t in stopping_criteria:
# t = self.tokenizer.encode(t, add_bos=False)
# stop_words.append(t[0])
# gen_config['stop_words'] = list(set(stop_words))
from lmdeploy.messages import EngineGenerationConfig, GenerationConfig
gen_config = GenerationConfig(**gen_config)
gen_config = EngineGenerationConfig.From(gen_config, self.tokenizer)
results = []
for batch_message in batch_messages:
n = len(batch_message)
with ThreadPoolExecutor() as executor:
_results = list(
executor.map(
self._generate,
self.generators[:n],
self.generator_ids[:n],
batch_message,
[gen_config] * n,
))
results += _results
for s in stop_words:
results = [r.split(s)[0] for r in results]
return results
def _generate(self,
generator,
session_id,
prompt: PromptType,
gen_config=None) -> str:
"""Generate results given a list of inputs.
Args:
prompt (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
gen_config (EngineGenerationConfig, optional): Generation
config to set arguments like top_k, top_p, temperature.
Returns:
str: The generated string.
"""
assert type(prompt) is str, 'We only support string for TurboMind Python API'
input_ids = self.tokenizer.encode(prompt)
for outputs in generator.stream_infer(session_id=session_id,
input_ids=[input_ids],
gen_config=gen_config,
sequence_start=True,
sequence_end=True,
step=0,
stream_output=False):
if self.version_info >= (0, 4, 0):
output_ids = outputs.token_ids
else:
_, output_ids, _ = outputs
response = self.tokenizer.decode(output_ids)
response = valid_str(response)
return response
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized strings.
Args:
prompt (str): Input string.
Returns:
int: Length of the input tokens
"""
m = _convert_chat_messages([prompt])[0]
t = self.origin_tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True)
return len(t['input_ids'])
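Two easy-to-miss details of this new wrapper: stop words are gathered from both the config and the model's generation_config, and generate() trims them from the decoded text at the end. A self-contained sketch of that trimming loop (the strings are made up; the logic mirrors the loop above):

# Keep only the text before the first occurrence of each stop word.
stop_words = ['<|im_end|>', '</s>']
results = ['The answer is 4.<|im_end|>\n', 'Paris</s> extra tokens']
for s in stop_words:
    results = [r.split(s)[0] for r in results]
print(results)  # ['The answer is 4.', 'Paris']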

View File

@@ -130,21 +130,6 @@ class VLLM(BaseModel):
ce_loss.append(loss)
return np.array(ce_loss)
def prompts_preproccess(self, inputs: List[str]):
if self.use_fastchat_template:
try:
from fastchat.model import get_conversation_template
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Fastchat is not implemented. You can use '
"'pip install \"fschat[model_worker,webui]\"' "
'to implement fastchat.')
conv = get_conversation_template('vicuna')
conv.append_message(conv.roles[0], inputs[0])
conv.append_message(conv.roles[1], None)
inputs = [conv.get_prompt()]
return inputs
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized strings.

View File

@@ -0,0 +1,127 @@
# flake8: noqa
# yapf: disable
from typing import Dict, List, Optional
import numpy as np
from opencompass.models.base import BaseModel
from opencompass.utils import get_logger
from .huggingface_above_v4_33 import (_convert_chat_messages,
_format_with_fast_chat_template,
_get_meta_template,
_get_possible_max_seq_len)
try:
from vllm import LLM, SamplingParams
except ImportError:
LLM, SamplingParams = None, None
class VLLMwithChatTemplate(BaseModel):
def __init__(
self,
path: str,
model_kwargs: dict = dict(),
tokenizer_only: bool = False,
generation_kwargs: dict = dict(),
max_seq_len: int = None,
meta_template: Optional[Dict] = None,
fastchat_template: Optional[str] = None,
stop_words: List[str] = [],
):
assert LLM, ('Please install VLLM with `pip install vllm`. note: torch==2.1.2 is required.')
self.logger = get_logger()
self.path = path
self.tokenizer_only = tokenizer_only
self.template_parser = _get_meta_template(meta_template)
self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
if tokenizer_only:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
else:
self._load_model(path, model_kwargs)
self.tokenizer = self.model.get_tokenizer()
self.generation_kwargs = generation_kwargs
self.generation_kwargs.pop('do_sample', None)
self.fastchat_template = fastchat_template
self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
def _load_model(self, path: str, added_model_kwargs: dict = dict()):
import ray
if ray.is_initialized():
self.logger.info('shutdown ray instance to avoid "Calling ray.init() again" error.')
ray.shutdown()
DEFAULT_MODEL_KWARGS = dict(trust_remote_code=True)
model_kwargs = DEFAULT_MODEL_KWARGS.copy()
model_kwargs.update(added_model_kwargs)
self.model = LLM(path, **model_kwargs)
def _get_potential_stop_words(self, path: Optional[str]):
from transformers import GenerationConfig
potential_stop_words = []
try:
generation_config = GenerationConfig.from_pretrained(path)
for token_id in generation_config.eos_token_id:
potential_stop_words.append(self.tokenizer.decode(token_id))
except Exception:
pass
potential_stop_words.append(self.tokenizer.eos_token)
potential_stop_words = list(set(potential_stop_words))
return potential_stop_words
def generate(self, inputs: List[str], max_out_len: int, stopping_criteria: List[str] = [], **kwargs) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str]): A list of strings.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
messages = _convert_chat_messages(inputs)
if self.fastchat_template:
messages = _format_with_fast_chat_template(messages, self.fastchat_template)
else:
messages = [self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) for m in messages]
DEFAULT_GENERATION_KWARGS = {
'temperature': 0,
'max_tokens': max_out_len,
'stop': list(set(self.stop_words + stopping_criteria))
}
sampling_kwargs = DEFAULT_GENERATION_KWARGS.copy()
sampling_kwargs.update(self.generation_kwargs)
sampling_kwargs.update(kwargs)
sampling_kwargs = SamplingParams(**sampling_kwargs)
outputs = self.model.generate(messages, sampling_kwargs)
prompt_list, output_strs = [], []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
prompt_list.append(prompt)
output_strs.append(generated_text)
return output_strs
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized strings.
Args:
prompt (str): Input string.
Returns:
int: Length of the input tokens
"""
m = _convert_chat_messages([prompt])[0]
t = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, return_dict=True)
return len(t['input_ids'])
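For reference (not part of the commit), the wrapper can also be driven directly from Python rather than through a config file. A hedged sketch, assuming vLLM is installed and a GPU plus the model weights are available:

from opencompass.models import VLLMwithChatTemplate

# Arguments mirror the config dicts above; the prompt is illustrative.
model = VLLMwithChatTemplate(path='Qwen/Qwen1.5-7B-Chat', model_kwargs=dict(tensor_parallel_size=1))
print(model.generate(['What is 2 + 2?'], max_out_len=64)[0])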

View File

@@ -1,3 +1,5 @@
# flake8: noqa
# yapf: disable
import os
from typing import List, Tuple, Union
@@ -7,7 +9,9 @@ from mmengine.config import Config
from opencompass.datasets.custom import make_custom_dataset_config
from opencompass.models import (VLLM, HuggingFace, HuggingFaceBaseModel,
HuggingFaceCausalLM, HuggingFaceChatGLM3,
HuggingFacewithChatTemplate, TurboMindModel)
HuggingFacewithChatTemplate, TurboMindModel,
TurboMindModelwithChatTemplate,
VLLMwithChatTemplate)
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
from opencompass.runners import DLCRunner, LocalRunner, SlurmRunner
from opencompass.tasks import OpenICLEvalTask, OpenICLInferTask
@@ -79,25 +83,16 @@ def get_config_from_arg(args) -> Config:
config = try_fill_in_custom_cfgs(config)
# set infer accelerator if needed
if args.accelerator in ['vllm', 'lmdeploy']:
config['models'] = change_accelerator(config['models'],
args.accelerator)
config['models'] = change_accelerator(config['models'], args.accelerator)
if 'eval' in config and 'partitioner' in config['eval']:
if 'models' in config['eval']['partitioner']:
config['eval']['partitioner'][
'models'] = change_accelerator(
config['eval']['partitioner']['models'],
args.accelerator)
if 'judge_models' in config['eval']['partitioner']:
config['eval']['partitioner'][
'judge_models'] = change_accelerator(
config['eval']['partitioner']['judge_models'],
args.accelerator)
config['eval']['partitioner']['models'] = change_accelerator(config['eval']['partitioner']['models'], args.accelerator)
if config.get('eval', {}).get('partitioner', {}).get('judge_models') is not None:
config['eval']['partitioner']['judge_models'] = change_accelerator(config['eval']['partitioner']['judge_models'], args.accelerator)
return config
# parse dataset args
if not args.datasets and not args.custom_dataset_path:
raise ValueError('You must specify "--datasets" or '
'"--custom-dataset-path" if you do not specify a '
'config file path.')
raise ValueError('You must specify "--datasets" or "--custom-dataset-path" if you do not specify a config file path.')
datasets = []
if args.datasets:
datasets_dir = os.path.join(args.config_dir, 'datasets')
@@ -110,7 +105,7 @@ def get_config_from_arg(args) -> Config:
dataset_key_suffix = '_datasets'
for dataset in match_cfg_file(datasets_dir, [dataset_name]):
get_logger().info(f'Loading {dataset[0]}: {dataset[1]}')
logger.info(f'Loading {dataset[0]}: {dataset[1]}')
cfg = Config.fromfile(dataset[1])
for k in cfg.keys():
if k.endswith(dataset_key_suffix):
@@ -128,19 +123,15 @@ def get_config_from_arg(args) -> Config:
# parse model args
if not args.models and not args.hf_path:
raise ValueError('You must specify a config file path, '
'or specify --models and --datasets, or '
'specify HuggingFace model parameters and '
'--datasets.')
raise ValueError('You must specify a config file path, or specify --models and --datasets, or specify HuggingFace model parameters and --datasets.')
models = []
if args.models:
model_dir = os.path.join(args.config_dir, 'models')
for model in match_cfg_file(model_dir, args.models):
get_logger().info(f'Loading {model[0]}: {model[1]}')
logger.info(f'Loading {model[0]}: {model[1]}')
cfg = Config.fromfile(model[1])
if 'models' not in cfg:
raise ValueError(
f'Config file {model[1]} does not contain "models" field')
raise ValueError(f'Config file {model[1]} does not contain "models" field')
models += cfg['models']
else:
if args.hf_type == 'chat':
@@ -167,8 +158,7 @@ def get_config_from_arg(args) -> Config:
if args.accelerator in ['vllm', 'lmdeploy']:
models = change_accelerator(models, args.accelerator)
# parse summarizer args
summarizer_arg = args.summarizer if args.summarizer is not None \
else 'example'
summarizer_arg = args.summarizer if args.summarizer is not None else 'example'
summarizers_dir = os.path.join(args.config_dir, 'summarizers')
# Check if summarizer_arg contains '/'
@@ -188,9 +178,7 @@ def get_config_from_arg(args) -> Config:
# from the configuration file
summarizer = cfg[summarizer_key]
return Config(dict(models=models, datasets=datasets,
summarizer=summarizer),
format_python_code=False)
return Config(dict(models=models, datasets=datasets, summarizer=summarizer), format_python_code=False)
def change_accelerator(models, accelerator):
@@ -200,20 +188,15 @@ def change_accelerator(models, accelerator):
for model in models:
logger.info(f'Transforming {model["abbr"]} to {accelerator}')
# change HuggingFace model to VLLM or TurboMindModel
if model['type'] in [
HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3
]:
if model['type'] in [HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3]:
gen_args = dict()
if model.get('generation_kwargs') is not None:
generation_kwargs = model['generation_kwargs'].copy()
gen_args['temperature'] = generation_kwargs.get(
'temperature', 0.001)
gen_args['temperature'] = generation_kwargs.get('temperature', 0.001)
gen_args['top_k'] = generation_kwargs.get('top_k', 1)
gen_args['top_p'] = generation_kwargs.get('top_p', 0.9)
gen_args['stop_token_ids'] = generation_kwargs.get(
'eos_token_id', None)
generation_kwargs['stop_token_ids'] = generation_kwargs.get(
'eos_token_id', None)
gen_args['stop_token_ids'] = generation_kwargs.get('eos_token_id', None)
generation_kwargs['stop_token_ids'] = generation_kwargs.get('eos_token_id', None)
generation_kwargs.pop('eos_token_id')
else:
# if generation_kwargs is not provided, set default values
@@ -228,8 +211,7 @@ def change_accelerator(models, accelerator):
mod = TurboMindModel
acc_model = dict(
type=f'{mod.__module__}.{mod.__name__}',
abbr=model['abbr'].replace('hf', 'lmdeploy')
if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy',
abbr=model['abbr'].replace('hf', 'lmdeploy') if '-hf' in model['abbr'] else model['abbr'] + '-lmdeploy',
path=model['path'],
engine_config=dict(session_len=model['max_seq_len'],
max_batch_size=model['batch_size'],
@@ -253,11 +235,9 @@ def change_accelerator(models, accelerator):
acc_model = dict(
type=f'{VLLM.__module__}.{VLLM.__name__}',
abbr=model['abbr'].replace('hf', 'vllm')
if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
abbr=model['abbr'].replace('hf', 'vllm') if '-hf' in model['abbr'] else model['abbr'] + '-vllm',
path=model['path'],
model_kwargs=dict(
tensor_parallel_size=model['run_cfg']['num_gpus']),
model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus']),
max_out_len=model['max_out_len'],
max_seq_len=model['max_seq_len'],
batch_size=model['batch_size'],
@@ -268,7 +248,36 @@ def change_accelerator(models, accelerator):
if model.get(item) is not None:
acc_model[item] = model[item]
else:
raise ValueError(f'Unsupported accelerator {accelerator}')
raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}')
elif model['type'] in [HuggingFacewithChatTemplate]:
if accelerator == 'vllm':
mod = VLLMwithChatTemplate
acc_model = dict(
type=f'{mod.__module__}.{mod.__name__}',
abbr='-hf'.join(model['abbr'].split('-hf')[:-1]) + '-vllm',
path=model['path'],
model_kwargs=dict(tensor_parallel_size=model['run_cfg']['num_gpus']),
max_out_len=model['max_out_len'],
batch_size=32768,
run_cfg=model['run_cfg'],
stop_words=model.get('stop_words', []),
)
elif accelerator == 'lmdeploy':
mod = TurboMindModelwithChatTemplate
acc_model = dict(
type=f'{mod.__module__}.{mod.__name__}',
abbr='-hf'.join(model['abbr'].split('-hf')[:-1]) + '-turbomind',
path=model['path'],
engine_config=dict(max_batch_size=model.get('batch_size', 16), tp=model['run_cfg']['num_gpus']),
gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),
max_seq_len=model.get('max_seq_len', 2048),
max_out_len=model['max_out_len'],
batch_size=32768,
run_cfg=model['run_cfg'],
stop_words=model.get('stop_words', []),
)
else:
raise ValueError(f'Unsupported accelerator {accelerator} for model type {model["type"]}')
else:
raise ValueError(f'Unsupported model type {model["type"]}')
model_accels.append(acc_model)
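The two new branches derive the accelerated config mechanically from the HuggingFacewithChatTemplate entry; the only non-obvious step is the abbr rewrite, which assumes the abbr ends in '-hf'. A self-contained sketch of just that rename (the helper name and example abbrs are mine, not from the diff):

def to_accel_abbr(abbr: str, suffix: str) -> str:
    # Mirrors the expression in change_accelerator: strip a trailing '-hf', append the suffix.
    return '-hf'.join(abbr.split('-hf')[:-1]) + suffix

print(to_accel_abbr('internlm2-chat-7b-hf', '-turbomind'))  # internlm2-chat-7b-turbomind
print(to_accel_abbr('qwen1.5-7b-chat-hf', '-vllm'))         # qwen1.5-7b-chat-vllm
# An abbr without '-hf' would collapse to just the suffix, so model configs are
# expected to follow the '<model>-hf' naming convention.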