Support lmdeploy pytorch engine (#875)
* add lmdeploy pytorch model
* fix
* speed up encoding and decoding
* fix
* change tokenizer
parent 6d04decab4
commit 32ba0b074e
configs/eval_internlm_chat_lmdeploy_pytorch.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from mmengine.config import read_base

from opencompass.models import LmdeployPytorchModel


with read_base():
    # choose a list of datasets
    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
    from .datasets.race.race_gen_69ee4f import race_datasets
    from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer


datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])


meta_template = dict(
    round=[
        dict(role='HUMAN', begin='<|User|>:', end='<eoh>\n'),
        dict(role='BOT', begin='<|Bot|>:', end='<eoa>\n', generate=True),
    ],
    eos_token_id=103028)

# config for internlm-chat-7b
internlm_chat_7b = dict(
    type=LmdeployPytorchModel,
    abbr='internlm-chat-7b-pytorch',
    path='internlm/internlm-chat-7b',
    engine_config=dict(session_len=2048,
                       max_batch_size=16),
    gen_config=dict(top_k=1,
                    top_p=0.8,
                    temperature=1.0,
                    max_new_tokens=100),
    max_out_len=100,
    max_seq_len=2048,
    batch_size=16,
    concurrency=16,
    meta_template=meta_template,
    run_cfg=dict(num_gpus=1, num_procs=1),
    end_str='<eoa>',
)

# config for internlm-chat-20b
internlm_chat_20b = dict(
    type=LmdeployPytorchModel,
    abbr='internlm-chat-20b-pytorch',
    path='internlm/internlm-chat-20b',
    engine_config=dict(session_len=2048,
                       max_batch_size=8),
    gen_config=dict(top_k=1,
                    top_p=0.8,
                    temperature=1.0,
                    max_new_tokens=100),
    max_out_len=100,
    max_seq_len=2048,
    batch_size=8,
    concurrency=8,
    meta_template=meta_template,
    run_cfg=dict(num_gpus=1, num_procs=1),
    end_str='<eoa>',
)

models = [internlm_chat_20b]
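For reference, the config above is consumed by OpenCompass's standard run.py entry point; the launch command and the two-model variant below are illustrative sketches rather than part of this commit:

# Sketch: launch this config from the OpenCompass repository root with the
# standard entry point (assumed, not part of this diff):
#
#     python run.py configs/eval_internlm_chat_lmdeploy_pytorch.py
#
# To score both InternLM chat sizes in one run, the final line of the config
# above could instead read:
models = [internlm_chat_7b, internlm_chat_20b]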
opencompass/models/__init__.py
@@ -14,6 +14,7 @@ from .huggingface import HuggingFaceChatGLM3  # noqa: F401, F403
from .intern_model import InternLM  # noqa: F401, F403
from .lightllm_api import LightllmAPI  # noqa: F401
from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
from .lmdeploy_pytorch import LmdeployPytorchModel  # noqa: F401
from .minimax_api import MiniMax  # noqa: F401
from .mixtral import Mixtral  # noqa: F401
from .modelscope import ModelScope, ModelScopeCausalLM  # noqa: F401, F403
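With the LmdeployPytorchModel export added above, the wrapper resolves from the package root, which is what the new config relies on. A small, hypothetical sanity check (assuming an installed opencompass with this commit applied):

from opencompass.models import LmdeployPytorchModel

# The class should resolve to the new module introduced in this commit.
print(LmdeployPytorchModel.__module__)  # expected: 'opencompass.models.lmdeploy_pytorch'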
opencompass/models/lmdeploy_pytorch.py (new file, 157 lines)
@@ -0,0 +1,157 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

from opencompass.models.base import BaseModel
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


def valid_str(string, coding='utf-8'):
    """Remove the UTF-8 replacement character from the given string."""
    invalid_chars = [b'\xef\xbf\xbd']
    bstr = bytes(string, coding)
    for invalid_char in invalid_chars:
        bstr = bstr.replace(invalid_char, b'')
    ret = bstr.decode(encoding=coding, errors='ignore')
    return ret


class LmdeployPytorchModel(BaseModel):
    """Model wrapper for the lmdeploy pytorch engine through its python API.

    Args:
        path (str): path of the supported pytorch model.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case there is a need to inject or wrap
            any meta instructions.
        engine_config (Dict, optional): The engine config used to set
            arguments like session_len and max_batch_size for the pytorch
            engine.
        gen_config (Dict, optional): Generation config used to set
            arguments like top_k, top_p and temperature.
        end_str (str, optional): If given, generated strings are trimmed at
            end_str; useful when the model emits special ending strings that
            are not handled well otherwise. Defaults to None.
    """

    def __init__(self,
                 path: str,
                 concurrency: int = 8,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 engine_config: Optional[Dict] = None,
                 gen_config: Optional[Dict] = None,
                 end_str: Optional[str] = None):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)
        from lmdeploy.pytorch import engine as tm

        if engine_config is not None:
            from lmdeploy.messages import PytorchEngineConfig
            engine_config = PytorchEngineConfig(**engine_config)
        if gen_config is not None:
            from lmdeploy.messages import EngineGenerationConfig
            gen_config = EngineGenerationConfig(**gen_config)

        self.logger = get_logger()
        tm_model = tm.Engine(path, engine_config)
        self.tokenizer = tm_model.tokenizer
        # one engine instance per concurrent request slot
        self.generators = [
            tm_model.create_instance() for i in range(concurrency)
        ]
        self.generator_ids = [i + 1 for i in range(concurrency)]
        self.gen_config = gen_config
        self.end_str = end_str

    def generate(
        self,
        inputs: List[str],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of prompts.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        assert isinstance(
            inputs, List), f'List(str) is expected, but got {type(inputs)}'

        # Note: output length is governed by gen_config.max_new_tokens;
        # max_out_len is accepted for interface compatibility.
        # split inputs into batches, one prompt per engine instance
        batch_size = len(self.generators)
        batch_inputs = [
            inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)
        ]

        results = []
        for batch_input in batch_inputs:
            with ThreadPoolExecutor() as executor:
                _results = list(
                    executor.map(
                        self._generate,
                        self.generators[:len(batch_input)],
                        self.generator_ids[:len(batch_input)],
                        batch_input,
                        [self.gen_config] * len(batch_input),
                        [self.end_str] * len(batch_input),
                    ))
                results += _results
        return results

    def get_token_len(self, prompt: str) -> int:
        input_ids = self.tokenizer.encode(prompt)
        return len(input_ids)

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def _generate(self,
                  generator,
                  session_id,
                  prompt: PromptType,
                  gen_config=None,
                  end_str: Optional[str] = None) -> str:
        """Generate a completion for a single prompt.

        Args:
            prompt (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            gen_config (EngineGenerationConfig, optional): Generation
                config used to set arguments like top_k, top_p and
                temperature.
            end_str (str, optional): If given, the generated string is
                trimmed at end_str; useful when the model emits special
                ending strings that are not handled well otherwise.
                Defaults to None.

        Returns:
            str: The generated string.
        """
        assert type(
            prompt) is str, 'We only support string for the pytorch engine API'
        input_ids = self.tokenizer.encode(prompt)
        _, output_ids, _ = generator.infer(session_id,
                                           input_ids,
                                           gen_config=gen_config)
        # end the session so the engine can release its resources
        if hasattr(generator, 'end'):
            generator.end(session_id)
        # decode output
        response_all = self.tokenizer.decode(output_ids)
        # trim output at the ending string, if configured
        if end_str:
            response_all = response_all.split(end_str)[0]
        # remove invalid characters
        response_all = valid_str(response_all)
        return response_all
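To make the wrapper's calling convention concrete, here is a minimal usage sketch; the model path, sampling values, and prompts are illustrative assumptions, and in practice the class is instantiated from a config such as eval_internlm_chat_lmdeploy_pytorch.py rather than by hand:

from opencompass.models import LmdeployPytorchModel

# Illustrative values only; any model supported by lmdeploy's pytorch engine
# could be used here, and lmdeploy must be installed.
model = LmdeployPytorchModel(
    path='internlm/internlm-chat-7b',
    concurrency=2,
    max_seq_len=2048,
    engine_config=dict(session_len=2048, max_batch_size=2),
    gen_config=dict(top_k=1, top_p=0.8, temperature=1.0, max_new_tokens=64),
    end_str='<eoa>',
)

# generate() splits the prompts into batches and maps each prompt onto one of
# the engine instances created in __init__.
outputs = model.generate(['Hello, who are you?', 'Name three prime numbers.'])
for out in outputs:
    print(out)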