Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feat] local api speed up with fixed concurrent users (#497)

* [Feat] local api speed up
* fix lint
* fix lint
* minor fix
* add example api
parent 44c8d6cc60
commit ac3a2c4501
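The commit message mentions an example api config. As a rough sketch of how the new pieces fit together, a minimal config might look like the following; this is illustrative, not the committed example: the model name, API key, dataset list, and the use of NaivePartitioner are all assumptions.

    # Hypothetical OpenCompass config sketch, not the committed example config.
    # 'chatglm_pro' and the API key are placeholders.
    from opencompass.models import ZhiPuAI
    from opencompass.partitioners import NaivePartitioner
    from opencompass.runners import LocalAPIRunner

    datasets = []  # fill in dataset configs as usual

    models = [
        dict(abbr='chatglm_pro',
             type=ZhiPuAI,
             path='chatglm_pro',   # assumed ZhiPuAI model name
             key='YOUR_API_KEY',   # placeholder
             query_per_second=1,
             max_out_len=2048,
             max_seq_len=2048,
             batch_size=8),
    ]

    infer = dict(
        partitioner=dict(type=NaivePartitioner),
        runner=dict(
            type=LocalAPIRunner,
            concurrent_users=2,    # hard cap on in-flight API calls
            max_num_workers=4,
            # the runner asserts the task type against the string form
            task=dict(type='OpenICLInferTask')),
    )

Note that `concurrent_users` bounds the number of simultaneous API requests across all tasks, while `max_num_workers` only bounds the number of task processes.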
@@ -7,3 +7,4 @@ from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
 from .intern_model import InternLM  # noqa: F401, F403
 from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
 from .openai_api import OpenAI  # noqa: F401
+from .zhipuai import ZhiPuAI  # noqa: F401
opencompass/models/zhipuai.py (new file, 159 lines)
@@ -0,0 +1,159 @@

import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


@MODELS.register_module()
class ZhiPuAI(BaseAPIModel):
    """Model wrapper around ZhiPuAI.

    Args:
        path (str): The name of ZhiPuAI's model.
        key (str): Authorization key.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 2.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case meta instructions need to be
            injected or wrapped.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        key: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        import zhipuai
        self.zhipuai = zhipuai
        self.zhipuai.api_key = key
        self.model = path

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def flush(self):
        """Flush stdout and stderr when concurrent resources exist.

        When using multiprocessing with standard I/O redirected to files,
        the internal buffers need to be flushed so that the logs can be
        examined and are not lost if the system breaks.
        """
        if hasattr(self, 'tokens'):
            sys.stdout.flush()
            sys.stderr.flush()

    def acquire(self):
        """Acquire concurrent resources if they exist.

        This behavior falls back to waiting with query_per_second if
        there are no concurrent resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.acquire()
        else:
            self.wait()

    def release(self):
        """Release concurrent resources if acquired.

        This behavior falls back to doing nothing if there are no
        concurrent resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.release()

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            input (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                messages.append(msg)

        data = {'model': self.model, 'prompt': messages}

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            response = self.zhipuai.model_api.invoke(**data)
            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # on connection errors, frequent requests will cause a
                # continuously unstable network, therefore wait here
                # to slow down the requests
                self.wait()
                continue
            if response['code'] == 200 and response['success']:
                msg = response['data']['choices'][0]['content']
                return msg
            # sensitive content, prompt overlength, network error
            # or illegal prompt
            if (response['code'] == 1301 or response['code'] == 1261
                    or response['code'] == 1234 or response['code'] == 1214):
                print(response['msg'])
                return ''
            print(response)
            max_num_retries += 1

        raise RuntimeError(response['msg'])
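For a quick smoke test outside the runner, the wrapper can be driven directly. A minimal sketch, assuming the zhipuai package is installed and a valid key is at hand; 'chatglm_pro' is an assumed model name, not taken from this commit:

    from opencompass.models import ZhiPuAI

    model = ZhiPuAI(path='chatglm_pro', key='YOUR_API_KEY', query_per_second=1)
    # generate() fans the prompts out over a thread pool; each worker runs
    # _generate(), which wraps the API call in acquire()/release()
    print(model.generate(['Hello, who are you?'], max_out_len=128))

Without the runner attaching a `tokens` semaphore, acquire() falls back to the plain query_per_second wait.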
opencompass/runners/local_api.py (new file, 242 lines)
@@ -0,0 +1,242 @@

import logging
import os
import os.path as osp
import subprocess
import sys
import time
from multiprocessing import Manager, Pool
from multiprocessing.managers import SyncManager
from typing import Any, Dict, List, Tuple

import mmengine
from mmengine.config import ConfigDict
from tqdm import tqdm

from opencompass.registry import RUNNERS, TASKS
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
                               get_infer_output_path, get_logger,
                               task_abbr_from_cfg)

from .base import BaseRunner


def monkey_run(self, tokens: SyncManager.Semaphore):
    """Hack for the infer task's run, adding tokens for multiprocessing."""
    self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
    for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
        self.max_out_len = model_cfg.get('max_out_len', None)
        self.batch_size = model_cfg.get('batch_size', None)
        self.model = build_model_from_cfg(model_cfg)
        # attach the global semaphore for concurrency control
        assert self.model.is_api, 'Only API model is supported.'
        self.model.tokens = tokens

        for dataset_cfg in dataset_cfgs:
            self.model_cfg = model_cfg
            self.dataset_cfg = dataset_cfg
            self.infer_cfg = self.dataset_cfg['infer_cfg']
            self.dataset = build_dataset_from_cfg(self.dataset_cfg)
            self.sub_cfg = {
                'models': [self.model_cfg],
                'datasets': [[self.dataset_cfg]],
            }
            out_path = get_infer_output_path(
                self.model_cfg, self.dataset_cfg,
                osp.join(self.work_dir, 'predictions'))
            if osp.exists(out_path):
                continue
            self._inference()


old_stdout = sys.stdout
old_stderr = sys.stderr


def redirect_std_to_file(filename: str):
    """Redirect stdout and stderr, and update logger stream handlers too."""
    f = open(filename, 'w', encoding='utf-8')
    sys.stdout = f
    sys.stderr = f
    # change logger stream handlers as well
    logger = get_logger()
    for h in logger.handlers:
        if isinstance(h, logging.StreamHandler):
            h.stream = sys.stdout
    # special treatment for the icl_gen_inferencer logger
    gen_logger = logging.getLogger(
        'opencompass.openicl.icl_inferencer.icl_gen_inferencer')
    for h in gen_logger.handlers:
        if isinstance(h, logging.StreamHandler):
            h.stream = sys.stdout


def reset_std():
    """Reset stdout and stderr, and update logger stream handlers too."""
    sys.stdout.close()
    sys.stdout = old_stdout
    sys.stderr = old_stderr
    # change logger stream handlers as well
    logger = get_logger()
    for h in logger.handlers:
        if isinstance(h, logging.StreamHandler):
            h.stream = sys.stdout
    # special treatment for the icl_gen_inferencer logger
    gen_logger = logging.getLogger(
        'opencompass.openicl.icl_inferencer.icl_gen_inferencer')
    for h in gen_logger.handlers:
        if isinstance(h, logging.StreamHandler):
            h.stream = sys.stdout


def launch(task: BaseTask, tokens: SyncManager.Semaphore):
    """Launch a single task.

    Args:
        task (BaseTask): Task to launch.
        tokens (SyncManager.Semaphore): Multiprocessing semaphore
            for every subprocess to follow.

    Returns:
        tuple[str, int]: Task name and exit code.
    """
    task_name = task.name
    returncode = 0
    logger = get_logger()

    # get the log file path before the try block so that it is always
    # available for the error message below
    out_path = task.get_log_path(file_extension='out')
    mmengine.mkdir_or_exist(osp.split(out_path)[0])

    try:
        # redirect stdout and stderr to the log file
        redirect_std_to_file(out_path)

        # start inference with monkey_run
        start_time = time.time()
        inferencer = OpenICLInferTask(task.cfg)
        origin_run = inferencer.run
        inferencer.run = monkey_run
        inferencer.run(inferencer, tokens)
        inferencer.run = origin_run
        end_time = time.time()
        logger.info(f'time elapsed: {end_time - start_time:.2f}s')
    except Exception:
        logger.warning(f'task {task_name} fail, see\n{out_path}')
        returncode = 1
    finally:
        # reset stdout and stderr
        reset_std()
    return task_name, returncode
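A note on the monkey-patching in launch() above: assigning a plain function to an instance attribute does not bind self, which is why monkey_run receives the inferencer explicitly. A minimal sketch of the pattern, with illustrative names only:

    class Task:
        def run(self):
            return 'original'


    def patched_run(self, extra):
        # instance attributes are not descriptors, so no automatic binding:
        # 'self' has to be passed by hand
        return f'patched with {extra}'


    t = Task()
    original = t.run            # bound method, kept for restoration
    t.run = patched_run         # shadows the class method on this instance
    print(t.run(t, 'tokens'))   # -> 'patched with tokens'
    t.run = original            # restore the original behavior

The file then continues with the helper and the runner class: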
def submit(task, type, tokens):
    """Helper to launch the task."""
    task = TASKS.build(dict(cfg=task, type=type))
    tqdm.write(f'Launch {task.name} on CPU ')

    res = launch(task, tokens)
    return res


@RUNNERS.register_module()
class LocalAPIRunner(BaseRunner):
    """Local API Runner. Starts tasks with local python.

    Querying per second alone cannot guarantee the number of concurrent
    requests, so this runner supports a fixed number of concurrent users
    across multiple tasks. It is intended for APIs that restrict the
    number of concurrent requests.

    Args:
        task (ConfigDict): Task type config.
        concurrent_users (int): Max number of concurrent workers to request
            the resources.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 16.
        debug (bool): Whether to run in debug mode.
        lark_bot_url (str): Lark bot url.
    """

    def __init__(self,
                 task: ConfigDict,
                 concurrent_users: int,
                 max_num_workers: int = 16,
                 debug: bool = False,
                 lark_bot_url: str = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.concurrent_users = concurrent_users
        assert task['type'] in [
            'OpenICLInferTask', 'opencompass.tasks.OpenICLInferTask'
        ], 'Only supported for api infer task.'

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch multiple tasks.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """
        status = []
        if self.debug:
            # fall back to LocalRunner debug mode
            for task in tasks:
                task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
                task_name = task.name
                # get cmd
                mmengine.mkdir_or_exist('tmp/')
                param_file = f'tmp/{os.getpid()}_params.py'
                try:
                    task.cfg.dump(param_file)
                    cmd = task.get_command(cfg_path=param_file,
                                           template='{task_cmd}')
                    # run in a subprocess if the command starts with
                    # torchrun etc.
                    if cmd.startswith('python'):
                        task.run()
                    else:
                        subprocess.run(cmd, shell=True, text=True)
                finally:
                    os.remove(param_file)
                status.append((task_name, 0))
        else:

            pbar = tqdm(total=len(tasks))

            get_logger().info('All the logs and processes for each task'
                              ' should be checked in each infer/.out file.')
            with Manager() as manager:
                tokens = manager.Semaphore(self.concurrent_users)
                # updating pbar directly in the callback causes
                # visualization issues, so an extra counter is needed
                pbar_counter = manager.Value('i', 0)
                status = []

                def update(args):
                    """Update the pbar counter in the callback."""
                    pbar_counter.value += 1
                    status.append(args)

                with Pool(processes=self.max_num_workers) as pool:
                    for task in tasks:
                        pool.apply_async(submit,
                                         (task, self.task_cfg['type'], tokens),
                                         callback=update)
                    pool.close()

                    # update progress bar
                    while True:
                        cur_count = pbar_counter.value
                        if cur_count > pbar.n:
                            pbar.update(cur_count - pbar.n)
                        # break when all the tasks have finished
                        if cur_count >= pbar.total:
                            pbar.close()
                            break
                        # sleep to lower the usage
                        time.sleep(1)

                    pool.join()
        return status
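Stripped of OpenCompass specifics, the core trick in this runner is a Manager semaphore shared across a process pool: the number of in-flight API calls is capped independently of the number of worker processes. A self-contained sketch of just that pattern, with all names illustrative:

    import time
    from multiprocessing import Manager, Pool


    def call_api(job, tokens):
        # block until one of the shared "user" slots is free
        tokens.acquire()
        try:
            time.sleep(0.5)  # stand-in for the actual API request
            return job, 0
        finally:
            tokens.release()


    if __name__ == '__main__':
        with Manager() as manager:
            tokens = manager.Semaphore(2)    # at most 2 concurrent API calls
            with Pool(processes=8) as pool:  # even with 8 worker processes
                results = [pool.apply_async(call_api, (i, tokens))
                           for i in range(10)]
                print([r.get() for r in results])

Inside OpenCompass, the semaphore proxy is handed to each model via `self.model.tokens = tokens` in monkey_run, and the model's acquire()/release() calls in _generate do the blocking.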