From 4fde41036f7c94a55d3a7457359ed28b35f41089 Mon Sep 17 00:00:00 2001
From: Lyu Han
Date: Mon, 14 Oct 2024 15:33:40 +0800
Subject: [PATCH] [Feature] Update TurboMindModel by integrating lmdeploy
 pipeline API (#1556)

* integrate lmdeploy's pipeline api
* fix linting
* update user guide
* rename
* update
* update
* update
* rollback class name
* update
* remove unused code
* update
* update
* use pipeline
* fix ci check
* compatibility
* compatibility
* remove concurrency
* update
* fix table content
* update
---
 docs/en/index.rst                             |   2 +-
 docs/en/notes/news.md                         |   2 +-
 .../advanced_guides/evaluation_lmdeploy.md    |   2 +-
 docs/zh_cn/index.rst                          |   2 +-
 docs/zh_cn/notes/news.md                      |   2 +-
 opencompass/models/turbomind.py               | 203 ++++++++----------
 6 files changed, 100 insertions(+), 113 deletions(-)

diff --git a/docs/en/index.rst b/docs/en/index.rst
index 2f04aaee..fdad9c9e 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -61,7 +61,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
    advanced_guides/new_dataset.md
    advanced_guides/custom_dataset.md
    advanced_guides/new_model.md
-   advanced_guides/evaluation_turbomind.md
+   advanced_guides/evaluation_lmdeploy.md
    advanced_guides/evaluation_lightllm.md
    advanced_guides/accelerator_intro.md
    advanced_guides/code_eval.md
diff --git a/docs/en/notes/news.md b/docs/en/notes/news.md
index dd57bba6..b4a56fa5 100644
--- a/docs/en/notes/news.md
+++ b/docs/en/notes/news.md
@@ -32,7 +32,7 @@
 - **\[2023.08.18\]** [Dataset card](https://opencompass.org.cn/dataset-detail/MMLU) is now online. Welcome new evaluation benchmark OpenCompass !
 - **\[2023.08.11\]** [Model comparison](https://opencompass.org.cn/model-compare/GPT-4,ChatGPT,LLaMA-2-70B,LLaMA-65B) is now online. We hope this feature offers deeper insights!
 - **\[2023.08.11\]** We have supported [LEval](https://github.com/OpenLMLab/LEval).
-- **\[2023.08.10\]** OpenCompass is compatible with [LMDeploy](https://github.com/InternLM/lmdeploy). Now you can follow this [instruction](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_turbomind.html#) to evaluate the accelerated models provide by the **Turbomind**.
+- **\[2023.08.10\]** OpenCompass is compatible with [LMDeploy](https://github.com/InternLM/lmdeploy). Now you can follow this [instruction](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lmdeploy.html#) to evaluate the accelerated models provided by **TurboMind**.
 - **\[2023.08.10\]** We have supported [Qwen-7B](https://github.com/QwenLM/Qwen-7B) and [XVERSE-13B](https://github.com/xverse-ai/XVERSE-13B) ! Go to our [leaderboard](https://opencompass.org.cn/leaderboard-llm) for more results! More models are welcome to join OpenCompass.
 - **\[2023.08.09\]** Several new datasets(**CMMLU, TydiQA, SQuAD2.0, DROP**) are updated on our [leaderboard](https://opencompass.org.cn/leaderboard-llm)! More datasets are welcomed to join OpenCompass.
 - **\[2023.08.07\]** We have added a [script](tools/eval_mmbench.py) for users to evaluate the inference results of [MMBench](https://opencompass.org.cn/MMBench)-dev.
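For readers unfamiliar with the API these docs now point to, the lmdeploy pipeline interface adopted by this patch looks roughly like the sketch below. The model path and prompt are illustrative placeholders, not part of the patch:

```python
# Minimal sketch of the lmdeploy pipeline API this patch builds on.
# The model path and prompt below are placeholders for illustration.
from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat',
                backend_config=TurbomindEngineConfig(session_len=2048))
outputs = pipe(['Briefly introduce OpenCompass.'],
               gen_config=GenerationConfig(max_new_tokens=64))
print(outputs[0].text)  # each response also carries .token_ids
```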
diff --git a/docs/zh_cn/advanced_guides/evaluation_lmdeploy.md b/docs/zh_cn/advanced_guides/evaluation_lmdeploy.md
index 15839964..14bcbc6b 100644
--- a/docs/zh_cn/advanced_guides/evaluation_lmdeploy.md
+++ b/docs/zh_cn/advanced_guides/evaluation_lmdeploy.md
@@ -73,7 +73,7 @@ models = [
         max_out_len=1024,
         # the max number of prompts that LMDeploy receives
         # in `generate` function
-        batch_size=32,
+        batch_size=5000,
         run_cfg=dict(num_gpus=1),
     )
 ]
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index 44f22c1a..37a3bc0c 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -61,7 +61,7 @@ OpenCompass 上手路线
    advanced_guides/new_dataset.md
    advanced_guides/custom_dataset.md
    advanced_guides/new_model.md
-   advanced_guides/evaluation_turbomind.md
+   advanced_guides/evaluation_lmdeploy.md
    advanced_guides/evaluation_lightllm.md
    advanced_guides/accelerator_intro.md
    advanced_guides/code_eval.md
diff --git a/docs/zh_cn/notes/news.md b/docs/zh_cn/notes/news.md
index 29305359..9f1d15c6 100644
--- a/docs/zh_cn/notes/news.md
+++ b/docs/zh_cn/notes/news.md
@@ -32,7 +32,7 @@
 - **\[2023.08.18\]** [数据集页面](https://opencompass.org.cn/dataset-detail/MMLU) 现已在OpenCompass官网上线,欢迎更多社区评测数据集加入OpenCompass !
 - **\[2023.08.11\]** 官网榜单上新增了[模型对比](https://opencompass.org.cn/model-compare/GPT-4,ChatGPT,LLaMA-2-70B,LLaMA-65B)功能,希望该功能可以协助提供更多发现!
 - **\[2023.08.11\]** 新增了 [LEval](https://github.com/OpenLMLab/LEval) 评测支持.
-- **\[2023.08.10\]** OpenCompass 现已适配 [LMDeploy](https://github.com/InternLM/lmdeploy). 请参考 [评测指南](https://opencompass.readthedocs.io/zh_CN/latest/advanced_guides/evaluation_turbomind.html) 对 **Turbomind** 加速后的模型进行评估.
+- **\[2023.08.10\]** OpenCompass 现已适配 [LMDeploy](https://github.com/InternLM/lmdeploy). 请参考 [评测指南](https://opencompass.readthedocs.io/zh_CN/latest/advanced_guides/evaluation_lmdeploy.html) 对 **TurboMind** 加速后的模型进行评估.
 - **\[2023.08.10\]** [Qwen-7B](https://github.com/QwenLM/Qwen-7B) 和 [XVERSE-13B](https://github.com/xverse-ai/XVERSE-13B)的评测结果已更新在 OpenCompass [大语言模型评测榜单](https://opencompass.org.cn/leaderboard-llm)!
 - **\[2023.08.09\]** 更新更多评测数据集(**CMMLU, TydiQA, SQuAD2.0, DROP**) ,请登录 [大语言模型评测榜单](https://opencompass.org.cn/leaderboard-llm) 查看更多结果! 欢迎添加你的评测数据集到OpenCompass.
 - **\[2023.08.07\]** 新增了 [MMBench 评测脚本](tools/eval_mmbench.py) 以支持用户自行获取 [MMBench](https://opencompass.org.cn/MMBench)-dev 的测试结果.
diff --git a/opencompass/models/turbomind.py b/opencompass/models/turbomind.py
index e6cfebd2..687fef0d 100644
--- a/opencompass/models/turbomind.py
+++ b/opencompass/models/turbomind.py
@@ -1,6 +1,4 @@
 import copy
-import os
-from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Union
 
 import numpy as np
@@ -9,6 +7,8 @@ from opencompass.models.base import BaseModel
 from opencompass.utils.logging import get_logger
 from opencompass.utils.prompt import PromptList
 
+from .huggingface_above_v4_33 import _get_possible_max_seq_len
+
 PromptType = Union[PromptList, str]
 
 
@@ -27,7 +27,9 @@ class TurboMindModel(BaseModel):
 
     Args:
         path (str): path of the turbomind model
-        concurrency (int): the maximum allowed concurrency of turbomind.
+        backend (str): The inference backend, which can be either 'turbomind'
+            or 'pytorch'. It falls back to 'pytorch' if the model is not
+            supported by 'turbomind'.
         max_seq_len (int): The maximum allowed sequence length of a model.
             Note that the length of prompt + generated tokens shall not exceed
             this value. Defaults to 2048.
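The `_get_possible_max_seq_len` helper imported above resolves the usable sequence length when the caller does not pin one. Roughly, and only as a sketch of the idea rather than the verbatim OpenCompass implementation, it behaves like:

```python
# Sketch (not verbatim): return the user-provided max_seq_len, or infer it
# from the HuggingFace config of the model at `path`.
from transformers import AutoConfig

def _get_possible_max_seq_len(max_seq_len, path):
    if max_seq_len is not None:
        return max_seq_len
    config = AutoConfig.from_pretrained(path, trust_remote_code=True)
    # Attribute names vary across model families.
    for attr in ('max_position_embeddings', 'seq_length', 'model_max_length'):
        if hasattr(config, attr):
            return getattr(config, attr)
    raise ValueError('max_seq_len is not provided and cannot be inferred '
                     f'from the config of {path}')
```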
@@ -45,32 +47,30 @@ class TurboMindModel(BaseModel):
 
     def __init__(self,
                  path: str,
-                 concurrency: int = 8,
+                 backend: str = 'turbomind',
                  max_seq_len: int = 2048,
                  meta_template: Optional[Dict] = None,
                  engine_config: Dict = {},
                  gen_config: Dict = {},
+                 batch_padding: bool = False,
                  end_str: Optional[str] = None):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
-        from lmdeploy.turbomind import TurboMind
-        from lmdeploy.version import version_info
-
-        if engine_config is not None:
-            from lmdeploy.messages import TurbomindEngineConfig
-            engine_config = TurbomindEngineConfig(**engine_config)
         self.logger = get_logger()
-        if path.startswith('/') or path.startswith('.'):
-            assert os.path.exists(path), '{} is not existist'.format(path)
-        tm_model = TurboMind.from_pretrained(path, engine_config=engine_config)
-        self.tokenizer = tm_model.tokenizer
-        self.generators = [
-            tm_model.create_instance() for i in range(concurrency)
-        ]
-        self.generator_ids = [i + 1 for i in range(concurrency)]
+        self.max_seq_len = _get_possible_max_seq_len(max_seq_len, path)
+        from lmdeploy import version_info
+        from transformers import AutoTokenizer
+        self.version_info = version_info
+        self.tokenizer = AutoTokenizer.from_pretrained(path,
+                                                       trust_remote_code=True)
+
+        DEFAULT_ENGINE_CONFIG = {'session_len': self.max_seq_len}
+        _engine_config = DEFAULT_ENGINE_CONFIG.copy()
+        _engine_config.update(engine_config)
+        self.pipe = self._build_pipe(path, backend, _engine_config)
         self.gen_config = gen_config
-        self.major_version, self.minor_version, _ = version_info
+        self.batch_padding = batch_padding
         self.end_str = end_str
 
     def generate(self,
@@ -92,47 +92,39 @@ class TurboMindModel(BaseModel):
         assert isinstance(
             inputs, List), f'List(str) is expected, but got {type(inputs)}'
 
-        # split inputs into batches
-        batch_size = len(self.generators)
-        batch_inputs = [
-            inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)
-        ]
+        stop_words = list(set(stopping_criteria))
 
-        gen_config = copy.deepcopy(self.gen_config)
-        if do_sample is not None:
-            if do_sample:
-                gen_config['top_k'] = 1000
-                gen_config['temperature'] = temperature
+        DEFAULT_GEN_CONFIG = {
+            'max_new_tokens': max_out_len,
+            'min_new_tokens': 1,
+            'stop_words': stop_words,
+        }
+
+        gen_config = copy.deepcopy(DEFAULT_GEN_CONFIG)
+        gen_config.update(self.gen_config)
+        if do_sample:
+            gen_config['top_k'] = 40
+            gen_config['temperature'] = temperature
+        else:
+            if self.version_info >= (0, 6, 0):
+                gen_config['do_sample'] = False
             else:
                 gen_config['top_k'] = 1
 
-        if stopping_criteria:
-            stop_words = gen_config.get('stop_words', [])
-            for t in stopping_criteria:
-                t = self.tokenizer.encode(t, add_bos=False)
-                stop_words.append(t[0])
-            gen_config['stop_words'] = list(set(stop_words))
-        gen_config.setdefault('min_new_tokens', 1)
-
-        from lmdeploy.messages import GenerationConfig
+        from lmdeploy import GenerationConfig
+        gen_config = {
+            k: v
+            for k, v in gen_config.items() if hasattr(GenerationConfig, k)
+        }
         gen_config = GenerationConfig(**gen_config)
 
         results = []
-        for batch_input in batch_inputs:
-            with ThreadPoolExecutor() as executor:
-                _results = list(
-                    executor.map(
-                        self._generate,
-                        self.generators[:len(batch_input)],
-                        self.generator_ids[:len(batch_input)],
-                        batch_input,
-                        [max_out_len] * len(batch_input),
-                        [gen_config] * len(batch_input),
-                        [self.end_str] * len(batch_input),
-                    ))
-                results += _results
-        if stopping_criteria:
-            for s in stopping_criteria:
-                results = [r.split(s)[0] for r in results]
+        outputs = self.pipe(inputs,
+                            gen_config=gen_config,
+                            do_preprocess=False)
+        for output in outputs:
+            text = self.tokenizer.decode(output.token_ids)
+            results.append(text)
+        for s in stop_words:
+            results = [r.split(s)[0] for r in results]
         return results
 
     def get_token_len(self, prompt: str) -> int:
@@ -146,56 +138,9 @@ class TurboMindModel(BaseModel):
         """
         return self.token_bucket.get_token()
 
-    def _generate(self,
-                  generator,
-                  session_id,
-                  prompt: PromptType,
-                  max_out_len: int,
-                  gen_config=None,
-                  end_str: Optional[str] = None) -> str:
-        """Generate results given a list of inputs.
-
-        Args:
-            prompt (PromptType): A string or PromptDict.
-                The PromptDict should be organized in OpenCompass'
-                API format.
-            max_out_len (int): The maximum length of the output.
-            gen_config (GenerationConfig, optional): Generation
-                config to set arguments like top_k, top_p, temperature.
-            end_str (str, optional): Whether to trim generated strings
-                with end_str if the model has special ending strings
-                that are not handled well.
-                Defaults to None.
-        Returns:
-            str: The generated string.
-        """
-        assert type(
-            prompt) is str, 'We only support string for TurboMind Python API'
-
-        input_ids = self.tokenizer.encode(prompt)
-
-        for outputs in generator.stream_infer(session_id=session_id,
-                                              input_ids=[input_ids],
-                                              gen_config=gen_config,
-                                              request_output_len=max_out_len,
-                                              sequence_start=True,
-                                              sequence_end=True,
-                                              step=0,
-                                              stream_output=False):
-            if self.major_version >= 0 and self.minor_version >= 4:
-                output_ids = outputs.token_ids
-            else:
-                _, output_ids, _ = outputs
-            response = self.tokenizer.decode(output_ids)
-            response = valid_str(response)
-            # used to trim
-            if end_str:
-                response = response.split(end_str)[0]
-        return response
-
     def get_ppl(self,
                 inputs: List[str],
-                mask_length: Optional[List[int]] = None) -> List[float]:
+                mask_length: Optional[List[int]] = None) -> np.ndarray:
         """Get perplexity scores given a list of inputs.
         Args:
 
         assert isinstance(
             inputs, List), f'List(str) is expected, but got {type(inputs)}'
         results = []
-        for text in inputs:
-            input_ids = self.tokenizer.encode(text)
-            res = self.generators[0].get_ppl(input_ids)
-            results.append(res)
-        results = np.concatenate(results)
+        if self.version_info <= (0, 6, 0):
+            for text in inputs:
+                input_ids = self.tokenizer.encode(text)
+                res = self.pipe.get_ppl(input_ids)
+                results.append(res)
+            results = np.concatenate(results)
+        else:
+            if self.batch_padding and len(inputs) > 1:
+                assert self.tokenizer.pad_token
+                input_ids = self.tokenizer(
+                    inputs,
+                    padding=True,
+                    truncation=True,
+                    max_length=self.max_seq_len)['input_ids']
+            else:
+                input_ids = [
+                    self.tokenizer(text)['input_ids'] for text in inputs
+                ]
+            for i in range(0, len(input_ids), 128):
+                results.append(self.pipe.get_ppl(input_ids[i:i + 128]))
+            results = np.concatenate(results)
+
         return results
 
     def get_loglikelihood(
@@ -229,11 +191,36 @@
         results = []
         for text, cont in zip(inputs, conts):
             input_ids = self.tokenizer.encode(text)
-            res = self.generators[0].get_ppl(input_ids)
+            res = self.pipe.get_ppl(input_ids)
             logit_sum = res * len(input_ids)
             input_ids = self.tokenizer.encode(text.replace(cont, ''))
-            res = self.generators[0].get_ppl(input_ids)
+            res = self.pipe.get_ppl(input_ids)
             logit_part = res * len(input_ids)
             results.append(-(logit_sum - logit_part))
         results = np.concatenate(results)
         return results
+
+    def _build_pipe(self, model_path, backend, engine_config):
+        assert backend in ['pytorch', 'turbomind'], \
+            f'unsupported backend type: {backend}'
+
+        from lmdeploy import (PytorchEngineConfig, TurbomindEngineConfig,
+                              pipeline)
+        if backend == 'turbomind':
+            filtered = {
+                k: v
+                for k, v in engine_config.items()
+                if hasattr(TurbomindEngineConfig, k)
+            }
+            backend_config = TurbomindEngineConfig(**filtered)
+        else:
+            filtered = {
+                k: v
+                for k, v in engine_config.items()
+                if hasattr(PytorchEngineConfig, k)
+            }
+            backend_config = PytorchEngineConfig(**filtered)
+        return pipeline(model_path,
+                        backend_config=backend_config,
+                        log_level='INFO',
+                        max_log_len=10)
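With the pipeline integration in place, a model config consuming the updated class stays declarative. Below is a minimal sketch mirroring the documentation hunk above; the `abbr`, model path, and engine/generation values are illustrative placeholders, while `max_out_len=1024`, `batch_size=5000`, and `run_cfg` follow the updated guide:

```python
# Sketch of an OpenCompass model config using the updated TurboMindModel.
# `abbr`, `path`, and the engine/generation values are placeholders.
from opencompass.models import TurboMindModel

models = [
    dict(
        type=TurboMindModel,
        abbr='internlm2_5-7b-chat-turbomind',
        path='internlm/internlm2_5-7b-chat',
        backend='turbomind',  # or 'pytorch'
        engine_config=dict(session_len=8192, tp=1),
        gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9),
        max_seq_len=8192,
        max_out_len=1024,
        # the max number of prompts that LMDeploy receives
        # in `generate` function
        batch_size=5000,
        run_cfg=dict(num_gpus=1),
    )
]
```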