Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feature] Update TurboMindModel by integrating lmdeploy pipeline API (#1556)
* integrate lmdeploy's pipeline api
* fix linting
* update user guide
* rename
* update
* update
* update
* rollback class name
* update
* remove unused code
* update
* update
* use pipeline
* fix ci check
* compatibility
* compatibility
* remove concurrency
* update
* fix table content
* update
Parent: 5faee929db
Commit: 4fde41036f
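For orientation before the hunks: this change swaps OpenCompass's per-instance TurboMind generators for LMDeploy's high-level pipeline API. A minimal sketch of that API, assuming lmdeploy is installed and using an illustrative model path:

```python
# Hedged sketch of the LMDeploy pipeline API this commit migrates to.
from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline

pipe = pipeline('internlm/internlm2-chat-7b',          # illustrative model path
                backend_config=TurbomindEngineConfig(session_len=4096, tp=1))
outputs = pipe(['Hello, world'],
               gen_config=GenerationConfig(max_new_tokens=64),
               do_preprocess=False)  # OpenCompass feeds already-templated prompts
print(outputs[0].text)
```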
@@ -61,7 +61,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
    advanced_guides/new_dataset.md
    advanced_guides/custom_dataset.md
    advanced_guides/new_model.md
-   advanced_guides/evaluation_turbomind.md
+   advanced_guides/evaluation_lmdeploy.md
    advanced_guides/evaluation_lightllm.md
    advanced_guides/accelerator_intro.md
    advanced_guides/code_eval.md
@@ -32,7 +32,7 @@
 - **\[2023.08.18\]** [Dataset card](https://opencompass.org.cn/dataset-detail/MMLU) is now online. Welcome new evaluation benchmark OpenCompass !
 - **\[2023.08.11\]** [Model comparison](https://opencompass.org.cn/model-compare/GPT-4,ChatGPT,LLaMA-2-70B,LLaMA-65B) is now online. We hope this feature offers deeper insights!
 - **\[2023.08.11\]** We have supported [LEval](https://github.com/OpenLMLab/LEval).
-- **\[2023.08.10\]** OpenCompass is compatible with [LMDeploy](https://github.com/InternLM/lmdeploy). Now you can follow this [instruction](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_turbomind.html#) to evaluate the accelerated models provide by the **Turbomind**.
+- **\[2023.08.10\]** OpenCompass is compatible with [LMDeploy](https://github.com/InternLM/lmdeploy). Now you can follow this [instruction](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lmdeploy.html#) to evaluate the accelerated models provided by **Turbomind**.
 - **\[2023.08.10\]** We have supported [Qwen-7B](https://github.com/QwenLM/Qwen-7B) and [XVERSE-13B](https://github.com/xverse-ai/XVERSE-13B) ! Go to our [leaderboard](https://opencompass.org.cn/leaderboard-llm) for more results! More models are welcome to join OpenCompass.
 - **\[2023.08.09\]** Several new datasets(**CMMLU, TydiQA, SQuAD2.0, DROP**) are updated on our [leaderboard](https://opencompass.org.cn/leaderboard-llm)! More datasets are welcomed to join OpenCompass.
 - **\[2023.08.07\]** We have added a [script](tools/eval_mmbench.py) for users to evaluate the inference results of [MMBench](https://opencompass.org.cn/MMBench)-dev.
@@ -73,7 +73,7 @@ models = [
         max_out_len=1024,
         # the max number of prompts that LMDeploy receives
         # in `generate` function
-        batch_size=32,
+        batch_size=5000,
         run_cfg=dict(num_gpus=1),
     )
 ]
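For context, the guide this hunk edits defines the model entry roughly as below; apart from `max_out_len`, `batch_size`, and `run_cfg` shown in the diff, the field values here are illustrative:

```python
# Hedged sketch of a TurboMindModel entry in an OpenCompass config; the model
# name, path and engine/gen settings are placeholders, not the guide's exact text.
from opencompass.models import TurboMindModel

models = [
    dict(
        type=TurboMindModel,
        abbr='internlm2-chat-7b-turbomind',
        path='internlm/internlm2-chat-7b',
        engine_config=dict(session_len=7168, max_batch_size=16, tp=1),
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),
        max_seq_len=7168,
        max_out_len=1024,
        # the max number of prompts that LMDeploy receives in `generate`
        batch_size=5000,
        run_cfg=dict(num_gpus=1),
    )
]
```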
@@ -61,7 +61,7 @@ OpenCompass getting-started roadmap
    advanced_guides/new_dataset.md
    advanced_guides/custom_dataset.md
    advanced_guides/new_model.md
-   advanced_guides/evaluation_turbomind.md
+   advanced_guides/evaluation_lmdeploy.md
    advanced_guides/evaluation_lightllm.md
    advanced_guides/accelerator_intro.md
    advanced_guides/code_eval.md
@@ -32,7 +32,7 @@
 - **\[2023.08.18\]** The [dataset page](https://opencompass.org.cn/dataset-detail/MMLU) is now live on the OpenCompass website. More community evaluation datasets are welcome to join OpenCompass!
 - **\[2023.08.11\]** A [model comparison](https://opencompass.org.cn/model-compare/GPT-4,ChatGPT,LLaMA-2-70B,LLaMA-65B) feature has been added to the official leaderboard; we hope it helps surface more insights!
 - **\[2023.08.11\]** Added evaluation support for [LEval](https://github.com/OpenLMLab/LEval).
-- **\[2023.08.10\]** OpenCompass now supports [LMDeploy](https://github.com/InternLM/lmdeploy). Please refer to the [evaluation guide](https://opencompass.readthedocs.io/zh_CN/latest/advanced_guides/evaluation_turbomind.html) to evaluate models accelerated by **Turbomind**.
+- **\[2023.08.10\]** OpenCompass now supports [LMDeploy](https://github.com/InternLM/lmdeploy). Please refer to the [evaluation guide](https://opencompass.readthedocs.io/zh_CN/latest/advanced_guides/evaluation_lmdeploy.html) to evaluate models accelerated by **Turbomind**.
 - **\[2023.08.10\]** Evaluation results for [Qwen-7B](https://github.com/QwenLM/Qwen-7B) and [XVERSE-13B](https://github.com/xverse-ai/XVERSE-13B) have been updated on the OpenCompass [LLM leaderboard](https://opencompass.org.cn/leaderboard-llm)!
 - **\[2023.08.09\]** More evaluation datasets (**CMMLU, TydiQA, SQuAD2.0, DROP**) have been added; visit the [LLM leaderboard](https://opencompass.org.cn/leaderboard-llm) for more results! You are welcome to add your own evaluation datasets to OpenCompass.
 - **\[2023.08.07\]** Added an [MMBench evaluation script](tools/eval_mmbench.py) so users can obtain [MMBench](https://opencompass.org.cn/MMBench)-dev results on their own.
@@ -1,6 +1,4 @@
 import copy
-import os
-from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Union

 import numpy as np
@@ -9,6 +7,8 @@ from opencompass.models.base import BaseModel
 from opencompass.utils.logging import get_logger
 from opencompass.utils.prompt import PromptList

+from .huggingface_above_v4_33 import _get_possible_max_seq_len
+
 PromptType = Union[PromptList, str]

@@ -27,7 +27,9 @@ class TurboMindModel(BaseModel):

     Args:
         path (str): path of the turbomind model
-        concurrency (int): the maximum allowed concurrency of turbomind.
+        backend (str): The inference backend, which can be either 'turbomind'
+            or 'pytorch'. It will fall back to 'pytorch' when the model is not
+            supported by 'turbomind'.
         max_seq_len (int): The maximum allowed sequence length of a model.
             Note that the length of prompt + generated tokens shall not exceed
             this value. Defaults to 2048.
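To make the new `backend` argument concrete, a direct instantiation could look as follows; the model path and config values are placeholders, and in practice OpenCompass usually builds the model from a config dict rather than calling the class directly:

```python
# Hedged sketch only: path and config values are illustrative.
from opencompass.models import TurboMindModel

model = TurboMindModel(
    path='internlm/internlm2-chat-7b',  # HF model id or local path (assumed)
    backend='pytorch',                  # use LMDeploy's PyTorch engine when turbomind lacks support
    max_seq_len=4096,
    engine_config=dict(tp=1),
    gen_config=dict(top_k=1),
)
```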
@@ -45,32 +47,30 @@ class TurboMindModel(BaseModel):

     def __init__(self,
                  path: str,
-                 concurrency: int = 8,
+                 backend: str = 'turbomind',
                  max_seq_len: int = 2048,
                  meta_template: Optional[Dict] = None,
                  engine_config: Dict = {},
                  gen_config: Dict = {},
                  batch_padding: bool = False,
                  end_str: Optional[str] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
-        from lmdeploy.turbomind import TurboMind
-        from lmdeploy.version import version_info
-
-        if engine_config is not None:
-            from lmdeploy.messages import TurbomindEngineConfig
-            engine_config = TurbomindEngineConfig(**engine_config)
         self.logger = get_logger()
-        if path.startswith('/') or path.startswith('.'):
-            assert os.path.exists(path), '{} is not existist'.format(path)
-        tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
-        self.tokenizer = tm_model.tokenizer
-        self.generators = [
-            tm_model.create_instance() for i in range(concurrency)
-        ]
-        self.generator_ids = [i + 1 for i in range(concurrency)]
+        self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
+        from lmdeploy import version_info
+        from transformers import AutoTokenizer
+        self.version_info = version_info
+        self.tokenizer = AutoTokenizer.from_pretrained(path,
+                                                       trust_remote_code=True)
+
+        DEFAULT_ENGING_CONFIG = {'session_len': self.max_seq_len}
+        _engine_config = DEFAULT_ENGING_CONFIG.copy()
+        _engine_config.update(engine_config)
+        self.pipe = self._build_pipe(path, backend, _engine_config)
         self.gen_config = gen_config
-        self.major_version, self.minor_version, _ = version_info
         self.batch_padding = batch_padding
         self.end_str = end_str

     def generate(self,
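The engine-config handling in the new `__init__` is plain dict layering: the resolved `max_seq_len` seeds a default `session_len`, and any user-supplied keys are applied on top. A minimal sketch (the extra key is purely illustrative):

```python
# Sketch of the merge performed above; 'rope_scaling_factor' is an assumed user key.
max_seq_len = 32768
engine_config = {'tp': 2, 'rope_scaling_factor': 2.0}

DEFAULT_ENGINE_CONFIG = {'session_len': max_seq_len}
_engine_config = DEFAULT_ENGINE_CONFIG.copy()
_engine_config.update(engine_config)
# -> {'session_len': 32768, 'tp': 2, 'rope_scaling_factor': 2.0}
# A user-supplied 'session_len' would override the default.
```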
@@ -92,47 +92,39 @@ class TurboMindModel(BaseModel):
         assert isinstance(
             inputs, List), f'List(str) is expected, but got {type(inputs)}'

-        # split inputs into batches
-        batch_size = len(self.generators)
-        batch_inputs = [
-            inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)
-        ]
+        stop_words = list(set(stopping_criteria))

-        gen_config = copy.deepcopy(self.gen_config)
-        if do_sample is not None:
-            if do_sample:
-                gen_config['top_k'] = 1000
-                gen_config['temperature'] = temperature
-        if stopping_criteria:
-            stop_words = gen_config.get('stop_words', [])
-            for t in stopping_criteria:
-                t = self.tokenizer.encode(t, add_bos=False)
-                stop_words.append(t[0])
-            gen_config['stop_words'] = list(set(stop_words))
-            gen_config.setdefault('min_new_tokens', 1)
+        DEFAULT_GEN_CONFIG = {
+            'max_new_tokens': max_out_len,
+            'min_new_tokens': 1,
+            'stop_words': stop_words,
+        }
+
+        gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
+        gen_config.update(self.gen_config)
+        if do_sample:
+            gen_config['top_k'] = 40
+            gen_config['temperature'] = temperature
+        else:
+            if self.version_info >= (0, 6, 0):
+                gen_config['do_sample'] = False
+            else:
+                gen_config['top_k'] = 1

-        from lmdeploy.messages import GenerationConfig
+        from lmdeploy import GenerationConfig
         gen_config = {
             k: v
             for k, v in gen_config.items() if hasattr(GenerationConfig, k)
         }
         gen_config = GenerationConfig(**gen_config)

         results = []
-        for batch_input in batch_inputs:
-            with ThreadPoolExecutor() as executor:
-                _results = list(
-                    executor.map(
-                        self._generate,
-                        self.generators[:len(batch_input)],
-                        self.generator_ids[:len(batch_input)],
-                        batch_input,
-                        [max_out_len] * len(batch_input),
-                        [gen_config] * len(batch_input),
-                        [self.end_str] * len(batch_input),
-                    ))
-                results += _results
-        if stopping_criteria:
-            for s in stopping_criteria:
-                results = [r.split(s)[0] for r in results]
+        outputs = self.pipe(inputs, gen_config=gen_config, do_preprocess=False)
+        for output in outputs:
+            text = self.tokenizer.decode(output.token_ids)
+            results.append(text)
+        for s in stop_words:
+            results = [r.split(s)[0] for r in results]
         return results

     def get_token_len(self, prompt: str) -> int:
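The rewritten `generate` funnels every option through LMDeploy's `GenerationConfig`, first dropping any key the dataclass does not define. A small sketch of that filtering idiom, assuming lmdeploy is installed ('foo' stands in for an unsupported key):

```python
# Unknown keys are silently discarded before constructing GenerationConfig.
from lmdeploy import GenerationConfig

raw = {'max_new_tokens': 512, 'top_k': 40, 'temperature': 0.7, 'foo': 'dropped'}
filtered = {k: v for k, v in raw.items() if hasattr(GenerationConfig, k)}
gen_config = GenerationConfig(**filtered)  # 'foo' never reaches the constructor
```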
@@ -146,56 +138,9 @@ class TurboMindModel(BaseModel):
         """
         return self.token_bucket.get_token()

-    def _generate(self,
-                  generator,
-                  session_id,
-                  prompt: PromptType,
-                  max_out_len: int,
-                  gen_config=None,
-                  end_str: Optional[str] = None) -> str:
-        """Generate results given a list of inputs.
-
-        Args:
-            prompt (PromptType): A string or PromptDict.
-                The PromptDict should be organized in OpenCompass'
-                API format.
-            max_out_len (int): The maximum length of the output.
-            gen_config (GenerationConfig, optional): Generation
-                config to set arguments like top_k, top_p, temperature.
-            end_str (str, optional): Whether to trim generated strings
-                with end_str if the model has special ending strings
-                that are not handled well.
-                Defaults to None.
-        Returns:
-            str: The generated string.
-        """
-        assert type(
-            prompt) is str, 'We only support string for TurboMind Python API'
-
-        input_ids = self.tokenizer.encode(prompt)
-
-        for outputs in generator.stream_infer(session_id=session_id,
-                                              input_ids=[input_ids],
-                                              gen_config=gen_config,
-                                              request_output_len=max_out_len,
-                                              sequence_start=True,
-                                              sequence_end=True,
-                                              step=0,
-                                              stream_output=False):
-            if self.major_version >= 0 and self.minor_version >= 4:
-                output_ids = outputs.token_ids
-            else:
-                _, output_ids, _ = outputs
-        response = self.tokenizer.decode(output_ids)
-        response = valid_str(response)
-        # used to trim
-        if end_str:
-            response = response.split(end_str)[0]
-        return response
-
     def get_ppl(self,
                 inputs: List[str],
-                mask_length: Optional[List[int]] = None) -> List[float]:
+                mask_length: Optional[List[int]] = None) -> np.ndarray:
         """Get perplexity scores given a list of inputs.

         Args:
@@ -212,11 +157,28 @@
         assert isinstance(
             inputs, List), f'List(str) is expected, but got {type(inputs)}'
         results = []
-        for text in inputs:
-            input_ids = self.tokenizer.encode(text)
-            res = self.generators[0].get_ppl(input_ids)
-            results.append(res)
-        results = np.concatenate(results)
+        if self.version_info <= (0, 6, 0):
+            for text in inputs:
+                input_ids = self.tokenizer.encode(text)
+                res = self.pipe.get_ppl(input_ids)
+                results.append(res)
+            results = np.concatenate(results)
+        else:
+            if self.batch_padding and len(inputs) > 1:
+                assert self.tokenizer.pad_token
+                input_ids = self.tokenizer(
+                    inputs,
+                    padding=True,
+                    truncation=True,
+                    max_length=self.max_seq_len)['input_ids']
+            else:
+                input_ids = [
+                    self.tokenizer(text)['input_ids'] for text in inputs
+                ]
+            for i in range(0, len(input_ids), 128):
+                results.append(self.pipe.get_ppl(input_ids[i:i + 128]))
+            results = np.concatenate(results)
+
         return results

     def get_loglikelihood(
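The new `get_ppl` path scores prompts in slices of at most 128 before concatenating, which bounds memory use per call. The pattern in isolation, with `score_fn` standing in for `self.pipe.get_ppl` (inputs and chunk size are illustrative):

```python
# Self-contained sketch of the chunked-scoring loop used above.
import numpy as np

def chunked_scores(input_ids_list, score_fn, chunk_size=128):
    parts = []
    for i in range(0, len(input_ids_list), chunk_size):
        # score one slice at a time to keep peak memory bounded
        parts.append(np.asarray(score_fn(input_ids_list[i:i + chunk_size])))
    return np.concatenate(parts)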
@@ -229,11 +191,36 @@
         results = []
         for text, cont in zip(inputs, conts):
             input_ids = self.tokenizer.encode(text)
-            res = self.generators[0].get_ppl(input_ids)
+            res = self.pipe.get_ppl(input_ids)
             logit_sum = res * len(input_ids)
             input_ids = self.tokenizer.encode(text.replace(cont, ''))
-            res = self.generators[0].get_ppl(input_ids)
+            res = self.pipe.get_ppl(input_ids)
             logit_part = res * len(input_ids)
             results.append(-(logit_sum - logit_part))
         results = np.concatenate(results)
         return results
+
+    def _build_pipe(self, model_path, backend, engine_config):
+        assert backend in ['pytorch', 'turbomind'], \
+            f'unsupported backend type: {backend}'
+
+        from lmdeploy import (PytorchEngineConfig, TurbomindEngineConfig,
+                              pipeline)
+        if backend == 'turbomind':
+            filtered = {
+                k: v
+                for k, v in engine_config.items()
+                if hasattr(TurbomindEngineConfig, k)
+            }
+            backend_config = TurbomindEngineConfig(**filtered)
+        else:
+            filtered = {
+                k: v
+                for k, v in engine_config.items()
+                if hasattr(PytorchEngineConfig, k)
+            }
+            backend_config = PytorchEngineConfig(**filtered)
+        return pipeline(model_path,
+                        backend_config=backend_config,
+                        log_level='INFO',
+                        max_log_len=10)
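`get_loglikelihood` recovers the continuation's log-probability from two perplexity calls: the mean per-token negative log-likelihood of the full text and of the prompt alone are each scaled back to totals, and their difference is negated. A worked toy example, assuming `get_ppl` returns the mean negative log-likelihood per token (the numbers below are made up):

```python
# Hedged arithmetic sketch of the subtraction performed in get_loglikelihood.
ppl_full, n_full = 2.1, 120      # mean NLL and token count of prompt + continuation
ppl_prefix, n_prefix = 2.4, 100  # mean NLL and token count of the prompt alone

nll_full = ppl_full * n_full        # total NLL of the full text (252.0)
nll_prefix = ppl_prefix * n_prefix  # total NLL of the prefix (240.0)
loglikelihood = -(nll_full - nll_prefix)  # log-probability of the continuation
print(loglikelihood)  # -12.0 in this toy example
```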