From 107e022cf4729a8c2a9b8d80a3e78959fc5fa75d Mon Sep 17 00:00:00 2001
From: Yang Yong
Date: Wed, 6 Mar 2024 15:33:53 +0800
Subject: [PATCH] Support prompt template for LightllmApi. Update LightllmApi
 token bucket. (#945)

---
 configs/eval_lightllm.py                          | 23 +++++++++++++++---
 .../en/advanced_guides/evaluation_lightllm.md     | 11 +++++++--
 .../advanced_guides/evaluation_lightllm.md        | 11 +++++++--
 opencompass/models/lightllm_api.py                | 24 ++++++++++---------
 4 files changed, 51 insertions(+), 18 deletions(-)

diff --git a/configs/eval_lightllm.py b/configs/eval_lightllm.py
index 92abf36d..d84cd2c2 100644
--- a/configs/eval_lightllm.py
+++ b/configs/eval_lightllm.py
@@ -5,18 +5,35 @@ from opencompass.runners import LocalRunner
 from opencompass.tasks import OpenICLInferTask
 
 with read_base():
+    from .summarizers.leaderboard import summarizer
     from .datasets.humaneval.humaneval_gen import humaneval_datasets
 
 datasets = [*humaneval_datasets]
 
+'''
+# Prompt template for InternLM2-Chat
+# https://github.com/InternLM/InternLM/blob/main/chat/chat_format.md
+
+_meta_template = dict(
+    begin='<|im_start|>system\nYou are InternLM2-Chat, a harmless AI assistant<|im_end|>\n',
+    round=[
+        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
+        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
+    ]
+)
+'''
+
+_meta_template = None
+
 models = [
     dict(
         abbr='LightllmAPI',
         type=LightllmAPI,
-        url='http://localhost:8080/generate',
-        input_format='<input_text_to_replace>',
-        max_seq_len=2048,
+        url='http://localhost:1030/generate',
+        meta_template=_meta_template,
         batch_size=32,
+        rate_per_worker=32,
+        retry=4,
         generation_kwargs=dict(
             do_sample=False,
             ignore_eos=False,
diff --git a/docs/en/advanced_guides/evaluation_lightllm.md b/docs/en/advanced_guides/evaluation_lightllm.md
index 8584bf75..7064c7f0 100644
--- a/docs/en/advanced_guides/evaluation_lightllm.md
+++ b/docs/en/advanced_guides/evaluation_lightllm.md
@@ -19,16 +19,23 @@ We use the evaluation of Humaneval with the llama2-7B model as an example.
 ### Step-1: Deploy the model locally as a service using Lightllm.
 
 ```shell
-python -m lightllm.server.api_server --model_dir /path/llama2-7B \
+python -m lightllm.server.api_server --model_dir /path/llama2-7B    \
     --host 0.0.0.0                 \
-    --port 8080                    \
+    --port 1030                    \
+    --nccl_port 2066               \
+    --max_req_input_len 4096       \
+    --max_req_total_len 6144       \
     --tp 1                         \
+    --trust_remote_code            \
     --max_total_token_num 120000
 ```
 
 \*\*Note: \*\* tp can be configured to enable TensorParallel inference on several gpus, suitable for the inference of very large models.
+
 \*\*Note: \*\* The max_total_token_num in the above command will affect the throughput performance during testing. It can be configured according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of memory, it is often better to set it as high as possible.
 
+\*\*Note: \*\* If you want to start multiple LightLLM services on the same machine, you need to reconfigure the above port and nccl_port.
+
 You can use the following Python script to quickly test whether the current service has been successfully started.
 
 ```python
diff --git a/docs/zh_cn/advanced_guides/evaluation_lightllm.md b/docs/zh_cn/advanced_guides/evaluation_lightllm.md
index b5a2489e..d1ba264b 100644
--- a/docs/zh_cn/advanced_guides/evaluation_lightllm.md
+++ b/docs/zh_cn/advanced_guides/evaluation_lightllm.md
@@ -19,16 +19,23 @@
 ### 第一步: 将模型通过 Lightllm 在本地以服务的形式起起来
 
 ```shell
-python -m lightllm.server.api_server --model_dir /path/llama2-7B \
+python -m lightllm.server.api_server --model_dir /path/llama2-7B    \
     --host 0.0.0.0                 \
-    --port 8080                    \
+    --port 1030                    \
+    --nccl_port 2066               \
+    --max_req_input_len 4096       \
+    --max_req_total_len 6144       \
     --tp 1                         \
+    --trust_remote_code            \
     --max_total_token_num 120000
 ```
 
 **注：** 上述命令可以通过 tp 的数量设置，在 tp 张卡上进行 TensorParallel 推理，适用于较大的模型的推理。
+
 **注：** 上述命令中的 max_total_token_num，会影响测试过程中的吞吐性能，可以根据 [Lightllm 主页](https://github.com/ModelTC/lightllm) 上的文档，进行设置。只要不爆显存，往往设置越大越好。
 
+**注：** 如果要在同一个机器上起多个 Lightllm 服务，需要重新设定上面的 port 和 nccl_port。
+
 可以使用下面的 Python 脚本简单测试一下当前服务是否已经起成功
 
 ```python
diff --git a/opencompass/models/lightllm_api.py b/opencompass/models/lightllm_api.py
index 51705f43..3bb02229 100644
--- a/opencompass/models/lightllm_api.py
+++ b/opencompass/models/lightllm_api.py
@@ -8,11 +8,12 @@ import requests
 from opencompass.registry import MODELS
 from opencompass.utils.logging import get_logger
 
-from .base_api import BaseAPIModel
+from .base import BaseModel
+from .base_api import TokenBucket
 
 
 @MODELS.register_module()
-class LightllmAPI(BaseAPIModel):
+class LightllmAPI(BaseModel):
 
     is_api: bool = True
 
@@ -20,23 +21,21 @@ class LightllmAPI(BaseAPIModel):
             self,
             path: str = 'LightllmAPI',
             url: str = 'http://localhost:8080/generate',
-            input_format: str = '<input_text_to_replace>',
-            max_seq_len: int = 2048,
             meta_template: Optional[Dict] = None,
+            rate_per_worker: int = 2,
             retry: int = 2,
             generation_kwargs: Optional[Dict] = dict(),
     ):
         super().__init__(path=path,
-                         max_seq_len=max_seq_len,
                          meta_template=meta_template,
-                         retry=retry,
                          generation_kwargs=generation_kwargs)
         self.logger = get_logger()
         self.url = url
-        self.input_format = input_format
+        self.retry = retry
         self.generation_kwargs = generation_kwargs
         self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
+        self.token_bucket = TokenBucket(rate_per_worker, False)
 
     def generate(self, inputs: List[str], max_out_len: int,
                  **kwargs) -> List[str]:
@@ -64,8 +63,6 @@ class LightllmAPI(BaseAPIModel):
             self.wait()
             header = {'content-type': 'application/json'}
             try:
-                input = self.input_format.replace('<input_text_to_replace>',
-                                                  input)
                 data = dict(inputs=input, parameters=self.generation_kwargs)
                 raw_response = requests.post(self.url,
                                              headers=header,
@@ -118,8 +115,6 @@ class LightllmAPI(BaseAPIModel):
             self.wait()
             header = {'content-type': 'application/json'}
             try:
-                input = self.input_format.replace('<input_text_to_replace>',
-                                                  input)
                 data = dict(inputs=input, parameters=self.generation_kwargs)
                 raw_response = requests.post(self.url,
                                              headers=header,
@@ -156,3 +151,10 @@ class LightllmAPI(BaseAPIModel):
         raise RuntimeError('Calling LightllmAPI failed after retrying for '
                            f'{max_num_retries} times. Check the logs for '
                            'details.')
+
+    def wait(self):
+        """Wait till the next query can be sent.
+
+        Applicable in both single-thread and multi-thread environments.
+        """
+        return self.token_bucket.get_token()
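The English and Chinese guides above both point to a short Python script for checking that the deployed service is up, but the script itself sits outside the diff context. Below is a minimal sketch of such a check, assuming the server started in the examples is listening at `http://localhost:1030/generate` and accepts the same `inputs`/`parameters` JSON payload that `LightllmAPI` posts; adjust the URL and generation parameters to match your own deployment.

```python
import requests

# Assumed endpoint: matches the --host/--port flags used in the example above.
url = 'http://localhost:1030/generate'
header = {'content-type': 'application/json'}

# Same payload shape as LightllmAPI sends: a prompt plus generation parameters.
payload = {
    'inputs': 'What is AI?',
    'parameters': {
        'do_sample': False,
        'ignore_eos': False,
        'max_new_tokens': 64,
    },
}

response = requests.post(url, headers=header, json=payload)
print(response.status_code)
print(response.text)
```

A 200 response with generated text in the body indicates the service is ready for the OpenCompass run.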
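For reference, the commented-out `_meta_template` in `configs/eval_lightllm.py` mirrors the InternLM2-Chat format linked in the config. The sketch below is illustrative only: it hand-rolls the resulting string instead of going through OpenCompass's template machinery, to show roughly what a single-turn prompt looks like when that template is enabled.

```python
def build_internlm2_chat_prompt(user_message: str) -> str:
    """Approximate single-turn prompt under the InternLM2-Chat meta_template.

    Illustrative only; OpenCompass assembles this via meta_template rather
    than through a helper like this one.
    """
    begin = ('<|im_start|>system\n'
             'You are InternLM2-Chat, a harmless AI assistant<|im_end|>\n')
    human = f'<|im_start|>user\n{user_message}<|im_end|>\n'
    # The BOT round has generate=True, so the prompt ends with the assistant
    # header and the model continues from there.
    bot_begin = '<|im_start|>assistant\n'
    return begin + human + bot_begin


print(build_internlm2_chat_prompt('What is AI?'))
```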