Support prompt template for LightllmApi. Update LightllmApi token bucket. (#945)

Yang Yong 2024-03-06 15:33:53 +08:00 committed by GitHub
parent c54a5d3b0f
commit 107e022cf4
4 changed files with 51 additions and 18 deletions


@@ -5,18 +5,35 @@ from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .summarizers.leaderboard import summarizer
    from .datasets.humaneval.humaneval_gen import humaneval_datasets

datasets = [*humaneval_datasets]

+'''
+# Prompt template for InternLM2-Chat
+# https://github.com/InternLM/InternLM/blob/main/chat/chat_format.md
+_meta_template = dict(
+    begin='<|im_start|>system\nYou are InternLM2-Chat, a harmless AI assistant<|im_end|>\n',
+    round=[
+        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
+        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
+    ]
+)
+'''
+_meta_template = None
+
models = [
    dict(
        abbr='LightllmAPI',
        type=LightllmAPI,
-        url='http://localhost:8080/generate',
-        input_format='<input_text_to_replace>',
-        max_seq_len=2048,
+        url='http://localhost:1030/generate',
+        meta_template=_meta_template,
        batch_size=32,
+        rate_per_worker=32,
        retry=4,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
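The commented-out block above is the commit's example of an OpenCompass meta template. The sketch below (not part of the diff) illustrates roughly the prompt string such a template is meant to produce for a single user turn, using only the `begin`/`end` tokens shown above; the actual rendering is done by OpenCompass's prompt machinery, and the example question is made up.

```python
# Rough illustration of the InternLM2-Chat format encoded by _meta_template above.
system = ('<|im_start|>system\n'
          'You are InternLM2-Chat, a harmless AI assistant<|im_end|>\n')
user = '<|im_start|>user\nWrite hello world in Python.<|im_end|>\n'  # example turn (made up)
assistant_prefix = '<|im_start|>assistant\n'  # generation continues from here

prompt = system + user + assistant_prefix
print(prompt)
```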


@@ -19,16 +19,23 @@ We use the evaluation of Humaneval with the llama2-7B model as an example.

### Step-1: Deploy the model locally as a service using Lightllm.

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
    --host 0.0.0.0 \
-   --port 8080 \
+   --port 1030 \
+   --nccl_port 2066 \
+   --max_req_input_len 4096 \
+   --max_req_total_len 6144 \
    --tp 1 \
+   --trust_remote_code \
    --max_total_token_num 120000
```

**Note:** `tp` can be increased to run TensorParallel inference across several GPUs, which is suitable for very large models.

**Note:** The `max_total_token_num` in the above command affects throughput during testing. It can be configured according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of GPU memory, higher is usually better.

**Note:** If you want to start multiple Lightllm services on the same machine, you need to reassign the `port` and `nccl_port` above.
You can use the following Python script to quickly test whether the current service has been successfully started.
```python
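# The doc's original quick-test script is not reproduced in this diff excerpt.
# The following is a minimal sketch of such a check, assuming the default
# /generate endpoint and the inputs/parameters payload used by LightllmAPI;
# the prompt text is made up.
import json

import requests

url = 'http://localhost:1030/generate'
headers = {'Content-Type': 'application/json'}
data = {
    'inputs': 'What is AI?',
    'parameters': {
        'do_sample': False,
        'ignore_eos': False,
        'max_new_tokens': 128,
    },
}

response = requests.post(url, headers=headers, data=json.dumps(data))
print(response.status_code, response.text)
```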


@@ -19,16 +19,23 @@

### Step-1: Deploy the model locally as a service using Lightllm.

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
    --host 0.0.0.0 \
-   --port 8080 \
+   --port 1030 \
+   --nccl_port 2066 \
+   --max_req_input_len 4096 \
+   --max_req_total_len 6144 \
    --tp 1 \
+   --trust_remote_code \
    --max_total_token_num 120000
```

**Note:** `tp` in the above command can be set to run TensorParallel inference on that many GPUs, which is suitable for larger models.

**Note:** The `max_total_token_num` in the above command affects throughput during testing and can be set according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of GPU memory, larger is usually better.

**Note:** If you want to start multiple Lightllm services on the same machine, you need to reassign the `port` and `nccl_port` above.

You can use the following Python script to quickly check whether the service has started successfully.
```python
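# The original script is omitted from this diff excerpt; a minimal sketch,
# mirroring the quick check in the English doc above (assumed endpoint and payload):
import requests

resp = requests.post(
    'http://localhost:1030/generate',
    json={'inputs': 'What is AI?',
          'parameters': {'do_sample': False, 'ignore_eos': False, 'max_new_tokens': 16}},
)
print(resp.status_code, resp.text)
```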


@@ -8,11 +8,12 @@ import requests
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

-from .base_api import BaseAPIModel
+from .base import BaseModel
+from .base_api import TokenBucket


@MODELS.register_module()
-class LightllmAPI(BaseAPIModel):
+class LightllmAPI(BaseModel):

    is_api: bool = True
@@ -20,23 +21,21 @@ class LightllmAPI(BaseAPIModel):
        self,
        path: str = 'LightllmAPI',
        url: str = 'http://localhost:8080/generate',
-        input_format: str = '<input_text_to_replace>',
-        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
+        rate_per_worker: int = 2,
        retry: int = 2,
        generation_kwargs: Optional[Dict] = dict(),
    ):
        super().__init__(path=path,
-                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         retry=retry,
                         generation_kwargs=generation_kwargs)
        self.logger = get_logger()
        self.url = url
-        self.input_format = input_format
        self.retry = retry
        self.generation_kwargs = generation_kwargs
        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
+        self.token_bucket = TokenBucket(rate_per_worker, False)

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
@@ -64,8 +63,6 @@ class LightllmAPI(BaseAPIModel):
            self.wait()
            header = {'content-type': 'application/json'}
            try:
-                input = self.input_format.replace('<input_text_to_replace>',
-                                                  input)
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
@@ -118,8 +115,6 @@ class LightllmAPI(BaseAPIModel):
            self.wait()
            header = {'content-type': 'application/json'}
            try:
-                input = self.input_format.replace('<input_text_to_replace>',
-                                                  input)
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
@@ -156,3 +151,10 @@ class LightllmAPI(BaseAPIModel):
        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')
+
+    def wait(self):
+        """Wait till the next query can be sent.
+
+        Applicable in both single-thread and multi-thread environments.
+        """
+        return self.token_bucket.get_token()
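The `TokenBucket` used above is imported from `opencompass.models.base_api` and is not shown in this diff. As a rough illustration of the per-worker rate limiting that `wait()` relies on, here is a minimal, simplified sketch of a blocking limiter with the same `get_token()` interface (assumptions: one token per `1 / rate_per_worker` seconds, capacity of one; the real class may differ).

```python
import threading
import time


class SimpleTokenBucket:
    """Illustrative stand-in for opencompass.models.base_api.TokenBucket.

    A capacity-one token bucket, i.e. a fixed-interval limiter: each call to
    get_token() blocks until at least 1 / rate_per_worker seconds have passed
    since the previously granted token.
    """

    def __init__(self, rate_per_worker: float):
        self._interval = 1.0 / rate_per_worker
        self._lock = threading.Lock()
        self._next_allowed = time.monotonic()

    def get_token(self) -> None:
        # Reserve the next slot under the lock, then sleep outside it so that
        # concurrent workers queue up behind each other's reservations.
        with self._lock:
            now = time.monotonic()
            wait_time = max(0.0, self._next_allowed - now)
            self._next_allowed = max(now, self._next_allowed) + self._interval
        if wait_time > 0:
            time.sleep(wait_time)
```

Under this sketch, `rate_per_worker=32` as in the config above would cap each worker at roughly 32 requests per second; the diff shows `LightllmAPI` calling `self.wait()` before every HTTP request, which forwards to `get_token()`.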