From 3f36db3b0688fd1d510d94750514195a9a51cfff Mon Sep 17 00:00:00 2001
From: Songyang Zhang
Date: Thu, 10 Aug 2023 16:25:11 +0800
Subject: [PATCH] [Feature] Support turbomind (#166)

* support turbomind

* update doc

* Update docs/en/advanced_guides/evaluation_turbomind.md

Co-authored-by: Tong Gao

* Update docs/zh_cn/advanced_guides/evaluation_turbomind.md

Co-authored-by: Tong Gao

* Update docs/zh_cn/advanced_guides/evaluation_turbomind.md

Co-authored-by: Tong Gao

* Update docs/en/advanced_guides/evaluation_turbomind.md

Co-authored-by: Tong Gao

* update

---------

Co-authored-by: Tong Gao
---
 configs/eval_internlm_chat_7b_turbomind.py  |  32 ++++
 configs/models/hf_internlm_chat_7b.py       |   6 +-
 .../advanced_guides/evaluation_turbomind.md |  55 ++++++
 docs/en/index.rst                           |   1 +
 .../advanced_guides/evaluation_turbomind.md |  55 ++++++
 docs/zh_cn/index.rst                        |   1 +
 opencompass/models/turbomind.py             | 161 ++++++++++++++++++
 7 files changed, 310 insertions(+), 1 deletion(-)
 create mode 100644 configs/eval_internlm_chat_7b_turbomind.py
 create mode 100644 docs/en/advanced_guides/evaluation_turbomind.md
 create mode 100644 docs/zh_cn/advanced_guides/evaluation_turbomind.md
 create mode 100644 opencompass/models/turbomind.py

diff --git a/configs/eval_internlm_chat_7b_turbomind.py b/configs/eval_internlm_chat_7b_turbomind.py
new file mode 100644
index 00000000..aaddc4a5
--- /dev/null
+++ b/configs/eval_internlm_chat_7b_turbomind.py
@@ -0,0 +1,32 @@
+from mmengine.config import read_base
+from opencompass.models.turbomind import TurboMindModel
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.SuperGLUE_CB.SuperGLUE_CB_gen import CB_datasets
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
+
+datasets = [*CB_datasets]
+
+
+_meta_template = dict(
+    round=[
+        dict(role='HUMAN', begin='<|User|>:', end='\n'),
+        dict(role='BOT', begin='<|Bot|>:', end='\n', generate=True),
+    ],
+)
+
+models = [
+    dict(
+        type=TurboMindModel,
+        abbr='internlm-chat-7b-tb',
+        path='internlm-chat-7b',
+        model_path='./workspace',
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=16,
+        meta_template=_meta_template,
+        run_cfg=dict(num_gpus=1, num_procs=1),
+    )
+]
diff --git a/configs/models/hf_internlm_chat_7b.py b/configs/models/hf_internlm_chat_7b.py
index 2d526f60..0d0fc61f 100644
--- a/configs/models/hf_internlm_chat_7b.py
+++ b/configs/models/hf_internlm_chat_7b.py
@@ -19,12 +19,16 @@ models = [
         truncation_side='left',
         use_fast=False,
         trust_remote_code=True,
+        revision='1a6328795c6e207904e1eb58177e03ad24ae06f3'
     ),
     max_out_len=100,
     max_seq_len=2048,
     batch_size=8,
     meta_template=_meta_template,
-    model_kwargs=dict(trust_remote_code=True, device_map='auto'),
+    model_kwargs=dict(
+        trust_remote_code=True,
+        device_map='auto',
+        revision='1a6328795c6e207904e1eb58177e03ad24ae06f3'),
     run_cfg=dict(num_gpus=1, num_procs=1),
 )
 ]
diff --git a/docs/en/advanced_guides/evaluation_turbomind.md b/docs/en/advanced_guides/evaluation_turbomind.md
new file mode 100644
index 00000000..a2f63b8a
--- /dev/null
+++ b/docs/en/advanced_guides/evaluation_turbomind.md
@@ -0,0 +1,55 @@
+# Evaluation with LMDeploy
+
+We now support the evaluation of models accelerated by [LMDeploy](https://github.com/InternLM/lmdeploy), a toolkit designed for compressing, deploying, and serving LLMs. **TurboMind** is the efficient inference engine proposed by LMDeploy, and OpenCompass is compatible with it. This guide illustrates how to evaluate a model with TurboMind support in OpenCompass.
+
+# Setup
+
+## Install OpenCompass
+
+Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started.html) to install OpenCompass and prepare the evaluation datasets.
+
+## Install LMDeploy
+
+Install LMDeploy via pip (Python 3.8+):
+
+```shell
+pip install lmdeploy
+```
+
+# Evaluation
+
+We use InternLM as an example.
+
+## Step-1: Get the InternLM model
+
+```shell
+# 1. Download the InternLM model (or use a cached checkpoint)
+
+# Make sure you have git-lfs installed (https://git-lfs.com)
+git lfs install
+git clone https://huggingface.co/internlm/internlm-chat-7b /path/to/internlm-chat-7b
+
+# To clone without the large files (just their pointers), prepend the
+# following env var to the git clone command above:
+# GIT_LFS_SKIP_SMUDGE=1
+
+# 2. Convert the InternLM model to TurboMind's format, which is written
+# to "./workspace" by default
+python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-chat-7b
+```
+
+## Step-2: Verify the Converted Model
+
+```shell
+python -m lmdeploy.turbomind.chat ./workspace
+```
+
+## Step-3: Evaluate the Converted Model
+
+In the root folder of OpenCompass, run:
+
+```shell
+python run.py configs/eval_internlm_chat_7b_turbomind.py -w outputs/turbomind
+```
+
+You are expected to get the evaluation results after inference and evaluation complete.
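+
+For reference, this is roughly what the config used above (added in this PR as `configs/eval_internlm_chat_7b_turbomind.py`) looks like; the excerpt below omits the `meta_template` definition for brevity:
+
+```python
+from opencompass.models.turbomind import TurboMindModel
+
+models = [
+    dict(
+        type=TurboMindModel,
+        abbr='internlm-chat-7b-tb',
+        path='internlm-chat-7b',
+        # folder produced by the deploy step in Step-1
+        model_path='./workspace',
+        max_out_len=100,
+        max_seq_len=2048,
+        batch_size=16,
+        run_cfg=dict(num_gpus=1, num_procs=1),
+    )
+]
+```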
diff --git a/docs/en/index.rst b/docs/en/index.rst
index b49c836c..4293c59b 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -44,6 +44,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
    advanced_guides/new_dataset.md
    advanced_guides/new_model.md
+   advanced_guides/evaluation_turbomind.md
 
 .. _Prompt:
 .. toctree::
diff --git a/docs/zh_cn/advanced_guides/evaluation_turbomind.md b/docs/zh_cn/advanced_guides/evaluation_turbomind.md
new file mode 100644
index 00000000..e3901b69
--- /dev/null
+++ b/docs/zh_cn/advanced_guides/evaluation_turbomind.md
@@ -0,0 +1,55 @@
+# Evaluating LMDeploy Models
+
+We support evaluating large language models accelerated by [LMDeploy](https://github.com/InternLM/lmdeploy). LMDeploy, jointly developed by the MMDeploy and MMRazor teams, is a full-stack solution for compressing, deploying, and serving LLMs. **TurboMind** is the efficient inference engine introduced by LMDeploy. OpenCompass has been adapted to TurboMind, and this tutorial shows how to use OpenCompass to evaluate models accelerated by TurboMind.
+
+# Setup
+
+## Install OpenCompass
+
+Please follow the OpenCompass [installation guide](https://opencompass.readthedocs.io/en/latest/get_started.html) to install the library and prepare the datasets.
+
+## Install LMDeploy
+
+Install LMDeploy via pip (Python 3.8+):
+
+```shell
+pip install lmdeploy
+```
+
+# Evaluation
+
+We use InternLM as an example to walk through the evaluation.
+
+## Step-1: Get the InternLM model
+
+```shell
+# 1. Download the InternLM model (or use a cached checkpoint)
+
+# Make sure you have git-lfs installed (https://git-lfs.com)
+git lfs install
+git clone https://huggingface.co/internlm/internlm-chat-7b /path/to/internlm-chat-7b
+
+# To clone without the large files (just their pointers), prepend the
+# following env var to the git clone command above:
+# GIT_LFS_SKIP_SMUDGE=1
+
+# 2. Convert the InternLM model to TurboMind's format, which is written
+# to "./workspace" by default
+python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-chat-7b
+```
+
+## Step-2: Verify the converted model
+
+```shell
+python -m lmdeploy.turbomind.chat ./workspace
+```
+
+## Step-3: Evaluate the converted model
+
+In the root folder of the OpenCompass project, run:
+
+```shell
+python run.py configs/eval_internlm_chat_7b_turbomind.py -w outputs/turbomind
+```
+
+Once the model finishes inference and metric computation, you will get its evaluation results.
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index 746fd360..ae46fca8 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -54,6 +54,7 @@ OpenCompass 上手路线
    advanced_guides/new_dataset.md
    advanced_guides/new_model.md
+   advanced_guides/evaluation_turbomind.md
 
 .. _工具:
 .. toctree::
diff --git a/opencompass/models/turbomind.py b/opencompass/models/turbomind.py
new file mode 100644
index 00000000..8d87906e
--- /dev/null
+++ b/opencompass/models/turbomind.py
@@ -0,0 +1,161 @@
+import os.path as osp
+import random
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from opencompass.models.base import BaseModel
+from opencompass.models.base_api import TokenBucket
+from opencompass.utils.logging import get_logger
+from opencompass.utils.prompt import PromptList
+
+PromptType = Union[PromptList, str]
+
+
+def valid_str(string, coding='utf-8'):
+    """Strip invalid byte sequences (e.g. the UTF-8 replacement
+    character) from a string."""
+    invalid_chars = [b'\xef\xbf\xbd']
+    bstr = bytes(string, coding)
+    for invalid_char in invalid_chars:
+        bstr = bstr.replace(invalid_char, b'')
+    ret = bstr.decode(encoding=coding, errors='ignore')
+    return ret
+
+
+class TurboMindModel(BaseModel):
+    """Model wrapper for the TurboMind Python API.
+
+    Args:
+        path (str): The name of the model.
+        model_path (str): Path to the folder that holds the converted
+            TurboMind model.
+        max_seq_len (int): The maximum allowed sequence length of a model.
+            Note that the length of prompt + generated tokens shall not exceed
+            this value. Defaults to 2048.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case meta instructions have to be
+            injected or wrapped.
+    """
+
+    is_api: bool = True
+
+    def __init__(
+        self,
+        path: str,
+        model_path: str,
+        max_seq_len: int = 2048,
+        query_per_second: int = 1,
+        retry: int = 2,
+        meta_template: Optional[Dict] = None,
+    ):
+
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         meta_template=meta_template)
+        self.logger = get_logger()
+
+        from lmdeploy import turbomind as tm
+        from lmdeploy.model import MODELS as LMMODELS
+        from lmdeploy.turbomind.tokenizer import Tokenizer as LMTokenizer
+
+        self.retry = retry
+
+        tokenizer_model_path = osp.join(model_path, 'triton_models',
+                                        'tokenizer')
+        self.tokenizer = LMTokenizer(tokenizer_model_path)
+        tm_model = tm.TurboMind(model_path,
+                                eos_id=self.tokenizer.eos_token_id)
+        self.model_name = tm_model.model_name
+        self.model = LMMODELS.get(self.model_name)()
+        self.generator = tm_model.create_instance()
+        self.token_bucket = TokenBucket(query_per_second)
+
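+    # NOTE: `generate` fans the prompts out to `_generate` via a thread
+    # pool; every worker first calls `wait()`, so the overall request
+    # rate is throttled by the TokenBucket created above.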
+    def generate(
+        self,
+        inputs: List[Union[str, PromptList]],
+        max_out_len: int = 512,
+        temperature: float = 0.0,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[Union[str, PromptList]]): A list of strings or
+                PromptDicts. The PromptDict should be organized in
+                OpenCompass' API format.
+            max_out_len (int): The maximum length of the output.
+            temperature (float): What sampling temperature to use,
+                between 0 and 2. Higher values like 0.8 will make the output
+                more random, while lower values like 0.2 will make it more
+                focused and deterministic. Defaults to 0.0.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        prompts = inputs
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, prompts,
+                             [max_out_len] * len(inputs),
+                             [temperature] * len(inputs)))
+        return results
+
+    def wait(self):
+        """Wait till the next query can be sent.
+
+        Applicable in both single-thread and multi-thread environments.
+        """
+        return self.token_bucket.get_token()
+
+    def _generate(self, input: Union[str, PromptList], max_out_len: int,
+                  temperature: float) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (Union[str, PromptList]): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+            temperature (float): What sampling temperature to use,
+                between 0 and 2. Higher values like 0.8 will make the output
+                more random, while lower values like 0.2 will make it more
+                focused and deterministic.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(
+            input, str), 'We only support string for TurboMind Python API now'
+
+        input_token_ids = self.tokenizer.encode(input)
+
+        for _ in range(self.retry):
+            self.wait()  # respect the query-per-second rate limit
+            session_id = random.randint(1, 100000)
+            # Each evaluation request is a single-round conversation.
+            nth_round = 0
+            # Drain the generator; with stream_output=False the last
+            # yielded `outputs` holds the final result.
+            for outputs in self.generator.stream_infer(
+                    session_id=session_id,
+                    input_ids=[input_token_ids],
+                    stream_output=False,
+                    request_output_len=max_out_len,
+                    sequence_start=(nth_round == 0),
+                    sequence_end=False,
+                    step=0,
+                    stop=False,
+                    top_k=40,
+                    top_p=0.8,
+                    temperature=temperature,
+                    repetition_penalty=1.0,
+                    ignore_eos=False,
+                    random_seed=random.getrandbits(64)
+                    if nth_round == 0 else None):
+                pass
+
+            output_token_ids, _ = outputs[0]
+            # Decode the generated token ids and strip invalid characters.
+            response = self.tokenizer.decode(output_token_ids)
+            response = valid_str(response)
+
+            return response
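
For reference, a minimal usage sketch of the new wrapper outside the `run.py` flow. This is a sketch only: it assumes `lmdeploy` is installed and that a model has already been converted into `./workspace` as in Step-1 of the docs; the prompt string is illustrative.

```python
from opencompass.models.turbomind import TurboMindModel

# `path` is just a display name; `model_path` must point at the
# TurboMind workspace produced by the deploy step.
model = TurboMindModel(path='internlm-chat-7b', model_path='./workspace')

# Only plain-string prompts are supported by the wrapper for now
# (see the assert in `_generate`).
outputs = model.generate(['Hello! Please introduce yourself.'],
                         max_out_len=100)
print(outputs[0])
```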