Support prompt template for LightllmApi. Update LightllmApi token bucket. (#945)
This commit is contained in:
parent: c54a5d3b0f
commit: 107e022cf4
@@ -5,18 +5,35 @@ from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from .summarizers.leaderboard import summarizer
    from .datasets.humaneval.humaneval_gen import humaneval_datasets

datasets = [*humaneval_datasets]

+'''
+# Prompt template for InternLM2-Chat
+# https://github.com/InternLM/InternLM/blob/main/chat/chat_format.md
+
+_meta_template = dict(
+    begin='<|im_start|>system\nYou are InternLM2-Chat, a harmless AI assistant<|im_end|>\n',
+    round=[
+        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
+        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
+    ]
+)
+'''
+
+_meta_template = None
+
models = [
    dict(
        abbr='LightllmAPI',
        type=LightllmAPI,
-        url='http://localhost:8080/generate',
        input_format='<input_text_to_replace>',
        max_seq_len=2048,
+        url='http://localhost:1030/generate',
+        meta_template=_meta_template,
        batch_size=32,
+        rate_per_worker=32,
        retry=4,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
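The commented-out `_meta_template` above follows the InternLM2-Chat format. As a rough illustration of what such a meta template expands to at inference time, here is a minimal sketch; the `render` helper below is hypothetical and only mimics how the `begin` string, the per-role `begin`/`end` markers, and the generating `BOT` role fit together. The real templating is done by OpenCompass's prompt-building machinery when `meta_template` is set.

```python
# Hypothetical helper, for illustration only: shows the prompt string that the
# InternLM2-Chat meta template above would produce for a single-turn query.
def render(meta_template: dict, user_prompt: str) -> str:
    human, bot = meta_template['round']
    return (meta_template['begin']
            + human['begin'] + user_prompt + human['end']
            + bot['begin'])  # generation continues after the assistant tag


_meta_template = dict(
    begin='<|im_start|>system\nYou are InternLM2-Chat, a harmless AI assistant<|im_end|>\n',
    round=[
        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
    ],
)

print(render(_meta_template, 'Write a quicksort in Python.'))
```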
@@ -19,16 +19,23 @@ We use the evaluation of Humaneval with the llama2-7B model as an example.

### Step-1: Deploy the model locally as a service using Lightllm.

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
    --host 0.0.0.0 \
-   --port 8080 \
+   --port 1030 \
    --nccl_port 2066 \
    --max_req_input_len 4096 \
    --max_req_total_len 6144 \
    --tp 1 \
    --trust_remote_code \
    --max_total_token_num 120000
```

**Note:** `tp` can be set to run TensorParallel inference across multiple GPUs, which is suitable for very large models.

**Note:** `max_total_token_num` in the command above affects throughput during testing. It can be configured according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm); as long as it does not run out of GPU memory, larger values are usually better.

**Note:** If you want to start multiple LightLLM services on the same machine, reassign the `port` and `nccl_port` above.

You can use the following Python script to quickly test whether the service has started successfully.
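The script body itself is cut off in this diff view. A minimal sketch of such a health check, assuming the server started with the command above is listening on `http://localhost:1030/generate` and accepts the same `inputs`/`parameters` JSON payload that `LightllmAPI` sends:

```python
import json

import requests

# Assumed endpoint: must match the --host/--port used when launching the server.
url = 'http://localhost:1030/generate'
payload = {
    'inputs': 'What is AI?',
    'parameters': {'do_sample': False, 'ignore_eos': False, 'max_new_tokens': 64},
}
response = requests.post(url,
                         headers={'content-type': 'application/json'},
                         data=json.dumps(payload))
print(response.status_code)
print(response.json())
```

If the service is up, this should return HTTP 200 together with the generated text.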
@@ -19,16 +19,23 @@

### Step 1: Deploy the model locally as a service using Lightllm

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
    --host 0.0.0.0 \
-   --port 8080 \
+   --port 1030 \
    --nccl_port 2066 \
    --max_req_input_len 4096 \
    --max_req_total_len 6144 \
    --tp 1 \
    --trust_remote_code \
    --max_total_token_num 120000
```

**Note:** In the command above, `tp` sets the number of GPUs used for TensorParallel inference, which is suitable for larger models.

**Note:** `max_total_token_num` in the command above affects throughput during testing; it can be set according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of GPU memory, larger is usually better.

**Note:** To start multiple Lightllm services on the same machine, reassign the `port` and `nccl_port` above.

You can use the Python script shown above to quickly check whether the service has started successfully.
@@ -8,11 +8,12 @@ import requests
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

-from .base_api import BaseAPIModel
+from .base import BaseModel
+from .base_api import TokenBucket


@MODELS.register_module()
-class LightllmAPI(BaseAPIModel):
+class LightllmAPI(BaseModel):

    is_api: bool = True

@@ -20,23 +21,21 @@ class LightllmAPI(BaseAPIModel):
        self,
        path: str = 'LightllmAPI',
        url: str = 'http://localhost:8080/generate',
        input_format: str = '<input_text_to_replace>',
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
+        rate_per_worker: int = 2,
        retry: int = 2,
        generation_kwargs: Optional[Dict] = dict(),
    ):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         retry=retry,
                         generation_kwargs=generation_kwargs)
        self.logger = get_logger()
        self.url = url
        self.input_format = input_format
        self.retry = retry
        self.generation_kwargs = generation_kwargs
        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
+        self.token_bucket = TokenBucket(rate_per_worker, False)

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
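For reference, a rough usage sketch (not from the repo) of driving the class above directly. It assumes a Lightllm server is already serving at the given URL and that `LightllmAPI` is importable from `opencompass.models`; within an evaluation run, OpenCompass constructs the model from the `models` config instead.

```python
from opencompass.models import LightllmAPI

# Mirrors the constructor arguments shown above; the values are illustrative.
model = LightllmAPI(
    url='http://localhost:1030/generate',
    input_format='<input_text_to_replace>',
    max_seq_len=2048,
    rate_per_worker=32,   # queries per second fed into the token bucket
    retry=4,
    generation_kwargs=dict(do_sample=False, ignore_eos=False, max_new_tokens=64),
)

print(model.generate(['def fibonacci(n):'], max_out_len=64))
```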
@@ -64,8 +63,6 @@ class LightllmAPI(BaseAPIModel):
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                input = self.input_format.replace('<input_text_to_replace>',
                                                  input)
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
@@ -118,8 +115,6 @@ class LightllmAPI(BaseAPIModel):
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                input = self.input_format.replace('<input_text_to_replace>',
                                                  input)
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
@@ -156,3 +151,10 @@ class LightllmAPI(BaseAPIModel):
        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')
+
+    def wait(self):
+        """Wait till the next query can be sent.
+
+        Applicable in both single-thread and multi-thread environments.
+        """
+        return self.token_bucket.get_token()
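The `TokenBucket` used above is imported from `.base_api`; its implementation is not shown in this diff. As a sketch of the rate-limiting idea behind `wait()` (assumed behaviour, not the actual OpenCompass class): tokens refill at `rate_per_worker` per second, every request consumes one, and callers block once the bucket is empty.

```python
import threading
import time


class SimpleTokenBucket:
    """Illustrative token bucket; not the OpenCompass TokenBucket implementation."""

    def __init__(self, rate_per_second: float):
        self.rate = rate_per_second          # refill speed and burst capacity
        self.tokens = rate_per_second
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def get_token(self) -> None:
        """Block until a token is available, then consume it."""
        while True:
            with self.lock:
                now = time.monotonic()
                elapsed = now - self.last
                # Refill proportionally to elapsed time, capped at the burst size.
                self.tokens = min(self.rate, self.tokens + elapsed * self.rate)
                self.last = now
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
            time.sleep(1.0 / self.rate)
```

With `rate_per_worker=32`, each inference worker is limited to roughly 32 requests per second under this scheme.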