Support prompt template for LightllmApi. Update LightllmApi token bucket. (#945)
parent c54a5d3b0f · commit 107e022cf4
@@ -5,18 +5,35 @@ from opencompass.runners import LocalRunner
 from opencompass.tasks import OpenICLInferTask
 
 with read_base():
+    from .summarizers.leaderboard import summarizer
     from .datasets.humaneval.humaneval_gen import humaneval_datasets
 
 datasets = [*humaneval_datasets]
 
+'''
+# Prompt template for InternLM2-Chat
+# https://github.com/InternLM/InternLM/blob/main/chat/chat_format.md
+
+_meta_template = dict(
+    begin='<|im_start|>system\nYou are InternLM2-Chat, a harmless AI assistant<|im_end|>\n',
+    round=[
+        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
+        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
+    ]
+)
+'''
+
+_meta_template = None
+
 models = [
     dict(
         abbr='LightllmAPI',
         type=LightllmAPI,
-        url='http://localhost:8080/generate',
-        input_format='<input_text_to_replace>',
-        max_seq_len=2048,
+        url='http://localhost:1030/generate',
+        meta_template=_meta_template,
         batch_size=32,
+        rate_per_worker=32,
+        retry=4,
         generation_kwargs=dict(
             do_sample=False,
             ignore_eos=False,
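For reference, the commented-out InternLM2-Chat `_meta_template` above expands a single user turn into the chat format documented at https://github.com/InternLM/InternLM/blob/main/chat/chat_format.md. The helper below is a hypothetical sketch of that rendering, not code from this commit; OpenCompass assembles the prompt internally from `begin` and `round`.

```python
# Hypothetical illustration of what the InternLM2-Chat _meta_template renders
# for one user turn (sketch only; OpenCompass performs this expansion itself).
def render_single_turn(user_msg: str) -> str:
    begin = ('<|im_start|>system\n'
             'You are InternLM2-Chat, a harmless AI assistant<|im_end|>\n')
    human = f'<|im_start|>user\n{user_msg}<|im_end|>\n'  # role='HUMAN'
    bot_prefix = '<|im_start|>assistant\n'               # role='BOT', generate=True
    return begin + human + bot_prefix


print(render_single_turn('Write a hello-world program in Python.'))
```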
@@ -19,16 +19,23 @@ We use the evaluation of Humaneval with the llama2-7B model as an example.
 ### Step-1: Deploy the model locally as a service using Lightllm.
 
 ```shell
 python -m lightllm.server.api_server --model_dir /path/llama2-7B \
     --host 0.0.0.0 \
-    --port 8080 \
+    --port 1030 \
+    --nccl_port 2066 \
+    --max_req_input_len 4096 \
+    --max_req_total_len 6144 \
     --tp 1 \
+    --trust_remote_code \
     --max_total_token_num 120000
 ```
 
 **Note: ** tp can be configured to enable TensorParallel inference on several gpus, suitable for the inference of very large models.
 
 **Note: ** The max_total_token_num in the above command will affect the throughput performance during testing. It can be configured according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of memory, it is often better to set it as high as possible.
 
+**Note: ** If you want to start multiple LightLLM services on the same machine, you need to reconfigure the above port and nccl_port.
+
 You can use the following Python script to quickly test whether the current service has been successfully started.
 
 ```python
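The quick-test script that the paragraph above refers to is truncated in this diff. A minimal sketch of such a check is shown below, assuming the service from Step-1 is listening on port 1030; the request body mirrors what `LightllmAPI` sends (`inputs` plus `parameters`), while the exact response schema depends on the Lightllm version.

```python
import requests

# Assumes the Lightllm service started in Step-1 is reachable on localhost:1030.
url = 'http://localhost:1030/generate'
data = {
    'inputs': 'What is AI?',
    'parameters': {'do_sample': False, 'ignore_eos': False, 'max_new_tokens': 64},
}
response = requests.post(url, headers={'content-type': 'application/json'}, json=data)
print(response.status_code)
print(response.text)  # the generated text is returned in the JSON body
```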
@@ -19,16 +19,23 @@
 ### Step 1: Serve the model locally with Lightllm
 
 ```shell
 python -m lightllm.server.api_server --model_dir /path/llama2-7B \
     --host 0.0.0.0 \
-    --port 8080 \
+    --port 1030 \
+    --nccl_port 2066 \
+    --max_req_input_len 4096 \
+    --max_req_total_len 6144 \
     --tp 1 \
+    --trust_remote_code \
     --max_total_token_num 120000
 ```
 
 **Note:** The tp setting in the command above runs TensorParallel inference across tp GPUs, which is suitable for inference with larger models.
 
 **Note:** The max_total_token_num in the command above affects throughput during testing and can be set according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of GPU memory, larger values are usually better.
 
+**Note:** If you want to start multiple Lightllm services on the same machine, you need to change the port and nccl_port above.
+
 You can use the following Python script to quickly check whether the service has started successfully.
 
 ```python
@@ -8,11 +8,12 @@ import requests
 from opencompass.registry import MODELS
 from opencompass.utils.logging import get_logger
 
-from .base_api import BaseAPIModel
+from .base import BaseModel
+from .base_api import TokenBucket
 
 
 @MODELS.register_module()
-class LightllmAPI(BaseAPIModel):
+class LightllmAPI(BaseModel):
 
     is_api: bool = True
 
@@ -20,23 +21,21 @@ class LightllmAPI(BaseAPIModel):
         self,
         path: str = 'LightllmAPI',
         url: str = 'http://localhost:8080/generate',
-        input_format: str = '<input_text_to_replace>',
-        max_seq_len: int = 2048,
         meta_template: Optional[Dict] = None,
+        rate_per_worker: int = 2,
         retry: int = 2,
         generation_kwargs: Optional[Dict] = dict(),
     ):
 
         super().__init__(path=path,
-                         max_seq_len=max_seq_len,
                          meta_template=meta_template,
-                         retry=retry,
                          generation_kwargs=generation_kwargs)
         self.logger = get_logger()
         self.url = url
-        self.input_format = input_format
+        self.retry = retry
         self.generation_kwargs = generation_kwargs
         self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
+        self.token_bucket = TokenBucket(rate_per_worker, False)
 
     def generate(self, inputs: List[str], max_out_len: int,
                  **kwargs) -> List[str]:
@@ -64,8 +63,6 @@ class LightllmAPI(BaseAPIModel):
             self.wait()
             header = {'content-type': 'application/json'}
             try:
-                input = self.input_format.replace('<input_text_to_replace>',
-                                                  input)
                 data = dict(inputs=input, parameters=self.generation_kwargs)
                 raw_response = requests.post(self.url,
                                              headers=header,
@@ -118,8 +115,6 @@ class LightllmAPI(BaseAPIModel):
             self.wait()
             header = {'content-type': 'application/json'}
             try:
-                input = self.input_format.replace('<input_text_to_replace>',
-                                                  input)
                 data = dict(inputs=input, parameters=self.generation_kwargs)
                 raw_response = requests.post(self.url,
                                              headers=header,
@@ -156,3 +151,10 @@ class LightllmAPI(BaseAPIModel):
         raise RuntimeError('Calling LightllmAPI failed after retrying for '
                            f'{max_num_retries} times. Check the logs for '
                            'details.')
+
+    def wait(self):
+        """Wait till the next query can be sent.
+
+        Applicable in both single-thread and multi-thread environments.
+        """
+        return self.token_bucket.get_token()
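In the patched class, `rate_per_worker` feeds the `TokenBucket` imported from `.base_api`, and `wait()` takes a token from it before every HTTP request, throttling each worker to a fixed query rate. The bucket's internals are not part of this diff; the class below is a hypothetical stand-in that only illustrates the idea, not the actual OpenCompass implementation.

```python
import threading
import time


class IllustrativeTokenBucket:
    """Toy rate limiter: permits roughly `rate` get_token() calls per second."""

    def __init__(self, rate: float):
        self._interval = 1.0 / rate
        self._lock = threading.Lock()
        self._next_slot = time.monotonic()

    def get_token(self) -> None:
        # Reserve the next time slot under the lock, then sleep outside it.
        with self._lock:
            now = time.monotonic()
            wait_s = max(self._next_slot - now, 0.0)
            self._next_slot = max(self._next_slot, now) + self._interval
        if wait_s > 0:
            time.sleep(wait_s)


# Usage mirroring the new LightllmAPI.wait(): take a token before each request.
bucket = IllustrativeTokenBucket(rate=32)  # e.g. rate_per_worker=32 from the config above
for _ in range(3):
    bucket.get_token()
    # ... send the request to the Lightllm service here ...
```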