From ac3a2c4501eae2c9ef98fe88244944b70a322eaa Mon Sep 17 00:00:00 2001
From: Hubert <42952108+yingfhu@users.noreply.github.com>
Date: Wed, 25 Oct 2023 21:12:20 +0800
Subject: [PATCH] [Feat] local api speed up with fixed concurrent users (#497)

* [Feat] local api speed up

* fix lint

* fix lint

* minor fix

* add example api
---
 opencompass/models/__init__.py   |   1 +
 opencompass/models/zhipuai.py    | 159 ++++++++++++++++++++
 opencompass/runners/local_api.py | 242 +++++++++++++++++++++++++++++++
 3 files changed, 402 insertions(+)
 create mode 100644 opencompass/models/zhipuai.py
 create mode 100644 opencompass/runners/local_api.py

diff --git a/opencompass/models/__init__.py b/opencompass/models/__init__.py
index 2a9455b4..3d66b324 100644
--- a/opencompass/models/__init__.py
+++ b/opencompass/models/__init__.py
@@ -7,3 +7,4 @@ from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
 from .intern_model import InternLM  # noqa: F401, F403
 from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
 from .openai_api import OpenAI  # noqa: F401
+from .zhipuai import ZhiPuAI  # noqa: F401
diff --git a/opencompass/models/zhipuai.py b/opencompass/models/zhipuai.py
new file mode 100644
index 00000000..2e35d0ce
--- /dev/null
+++ b/opencompass/models/zhipuai.py
@@ -0,0 +1,159 @@
+import sys
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from opencompass.registry import MODELS
+from opencompass.utils.prompt import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+@MODELS.register_module()
+class ZhiPuAI(BaseAPIModel):
+    """Model wrapper around ZhiPuAI.
+
+    Args:
+        path (str): The name of the ZhiPuAI model.
+        key (str): Authorization key.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 2.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case meta instructions need to be
+            injected or wrapped.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        key: str,
+        query_per_second: int = 2,
+        max_seq_len: int = 2048,
+        meta_template: Optional[Dict] = None,
+        retry: int = 2,
+    ):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         query_per_second=query_per_second,
+                         meta_template=meta_template,
+                         retry=retry)
+        import zhipuai
+        self.zhipuai = zhipuai
+        self.zhipuai.api_key = key
+        self.model = path
+
+    def generate(
+        self,
+        inputs: List[PromptType],
+        max_out_len: int = 512,
+    ) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[PromptType]): A list of strings or PromptDicts.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        with ThreadPoolExecutor() as executor:
+            results = list(
+                executor.map(self._generate, inputs,
+                             [max_out_len] * len(inputs)))
+        self.flush()
+        return results
+
+    def flush(self):
+        """Flush stdout and stderr when concurrent resources exist.
+
+        When using multiprocessing with standard I/O redirected to files,
+        the buffers need to be flushed for examination, otherwise the logs
+        may be lost when the system breaks.
+        """
+        if hasattr(self, 'tokens'):
+            sys.stdout.flush()
+            sys.stderr.flush()
+
+    def acquire(self):
+        """Acquire concurrent resources if they exist.
+
+        This behavior will fall back to waiting according to
+        query_per_second if there are no concurrent resources.
+        """
+        if hasattr(self, 'tokens'):
+            self.tokens.acquire()
+        else:
+            self.wait()
+
+    def release(self):
+        """Release concurrent resources if acquired.
+
+        This behavior will fall back to doing nothing if there are no
+        concurrent resources.
+        """
+        if hasattr(self, 'tokens'):
+            self.tokens.release()
+
+    def _generate(
+        self,
+        input: PromptType,
+        max_out_len: int = 512,
+    ) -> str:
+        """Generate results given an input.
+
+        Args:
+            input (PromptType): A string or PromptDict.
+                The PromptDict should be organized in OpenCompass'
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        assert isinstance(input, (str, PromptList))
+
+        if isinstance(input, str):
+            messages = [{'role': 'user', 'content': input}]
+        else:
+            messages = []
+            for item in input:
+                msg = {'content': item['prompt']}
+                if item['role'] == 'HUMAN':
+                    msg['role'] = 'user'
+                elif item['role'] == 'BOT':
+                    msg['role'] = 'assistant'
+                messages.append(msg)
+
+        data = {'model': self.model, 'prompt': messages}
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            self.acquire()
+            response = self.zhipuai.model_api.invoke(**data)
+            self.release()
+
+            if response is None:
+                print('Connection error, reconnect.')
+                # if a connection error occurs, frequent requests will
+                # cause continuously unstable networking, so wait here
+                # to slow down the requests
+                self.wait()
+                continue
+            if response['code'] == 200 and response['success']:
+                msg = response['data']['choices'][0]['content']
+                return msg
+            # sensitive content, prompt overlength, network error
+            # or illegal prompt
+            if (response['code'] == 1301 or response['code'] == 1261
+                    or response['code'] == 1234 or response['code'] == 1214):
+                print(response['msg'])
+                return ''
+            print(response)
+            max_num_retries += 1
+
+        raise RuntimeError(response['msg'])
diff --git a/opencompass/runners/local_api.py b/opencompass/runners/local_api.py
new file mode 100644
index 00000000..c5346c24
--- /dev/null
+++ b/opencompass/runners/local_api.py
@@ -0,0 +1,242 @@
+import logging
+import os
+import os.path as osp
+import subprocess
+import sys
+import time
+from multiprocessing import Manager, Pool
+from multiprocessing.managers import SyncManager
+from typing import Any, Dict, List, Optional, Tuple
+
+import mmengine
+from mmengine.config import ConfigDict
+from tqdm import tqdm
+
+from opencompass.registry import RUNNERS, TASKS
+from opencompass.tasks import OpenICLInferTask
+from opencompass.tasks.base import BaseTask
+from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
+                               get_infer_output_path, get_logger,
+                               task_abbr_from_cfg)
+
+from .base import BaseRunner
+
+
+def monkey_run(self, tokens: SyncManager.Semaphore):
+    """Hacked infer task run, adding semaphore tokens for multiprocess."""
+    self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
+    for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
+        self.max_out_len = model_cfg.get('max_out_len', None)
+        self.batch_size = model_cfg.get('batch_size', None)
+        self.model = build_model_from_cfg(model_cfg)
+        # add global tokens for concurrency control
+        assert self.model.is_api, 'Only API models are supported.'
+        self.model.tokens = tokens
+
+        for dataset_cfg in dataset_cfgs:
+            self.model_cfg = model_cfg
+            self.dataset_cfg = dataset_cfg
+            self.infer_cfg = self.dataset_cfg['infer_cfg']
+            self.dataset = build_dataset_from_cfg(self.dataset_cfg)
+            self.sub_cfg = {
+                'models': [self.model_cfg],
+                'datasets': [[self.dataset_cfg]],
+            }
+            out_path = get_infer_output_path(
+                self.model_cfg, self.dataset_cfg,
+                osp.join(self.work_dir, 'predictions'))
+            if osp.exists(out_path):
+                continue
+            self._inference()
+
+
+old_stdout = sys.stdout
+old_stderr = sys.stderr
+
+
+def redirect_std_to_file(filename: str):
+    """Redirect stdout and stderr to a file, updating logger handlers."""
+    f = open(filename, 'w', encoding='utf-8')
+    sys.stdout = f
+    sys.stderr = f
+    # change the logger stream handlers as well
+    logger = get_logger()
+    for h in logger.handlers:
+        if isinstance(h, logging.StreamHandler):
+            h.stream = sys.stdout
+    # special treatment for the icl_gen_inferencer logger
+    gen_logger = logging.getLogger(
+        'opencompass.openicl.icl_inferencer.icl_gen_inferencer')
+    for h in gen_logger.handlers:
+        if isinstance(h, logging.StreamHandler):
+            h.stream = sys.stdout
+
+
+def reset_std():
+    """Reset stdout and stderr, updating logger handlers as well."""
+    sys.stdout.close()
+    sys.stdout = old_stdout
+    sys.stderr = old_stderr
+    # change the logger stream handlers as well
+    logger = get_logger()
+    for h in logger.handlers:
+        if isinstance(h, logging.StreamHandler):
+            h.stream = sys.stdout
+    # special treatment for the icl_gen_inferencer logger
+    gen_logger = logging.getLogger(
+        'opencompass.openicl.icl_inferencer.icl_gen_inferencer')
+    for h in gen_logger.handlers:
+        if isinstance(h, logging.StreamHandler):
+            h.stream = sys.stdout
+
+
+def launch(task: BaseTask, tokens: SyncManager.Semaphore):
+    """Launch a single task.
+
+    Args:
+        task (BaseTask): Task to launch.
+        tokens (SyncManager.Semaphore): Multiprocessing semaphore
+            shared by every subprocess to limit concurrency.
+
+    Returns:
+        tuple[str, int]: Task name and exit code.
+    """
+
+    task_name = task.name
+    returncode = 0
+    logger = get_logger()
+
+    # get the log file and redirect stdout and stderr
+    out_path = task.get_log_path(file_extension='out')
+    mmengine.mkdir_or_exist(osp.split(out_path)[0])
+    try:
+        redirect_std_to_file(out_path)
+
+        # start inference with monkey_run
+        start_time = time.time()
+        inferencer = OpenICLInferTask(task.cfg)
+        origin_run = inferencer.run
+        inferencer.run = monkey_run
+        inferencer.run(inferencer, tokens)
+        inferencer.run = origin_run
+        end_time = time.time()
+        logger.info(f'time elapsed: {end_time - start_time:.2f}s')
+    except Exception:
+        logger.warning(f'task {task_name} failed, see\n{out_path}')
+        returncode = 1
+    finally:
+        # reset stdout and stderr
+        reset_std()
+    return task_name, returncode
+
+
+def submit(task, task_type, tokens):
+    """Helper to build and launch the task."""
+    task = TASKS.build(dict(cfg=task, type=task_type))
+    tqdm.write(f'Launch {task.name} on CPU ')
+
+    res = launch(task, tokens)
+    return res
+
+
+@RUNNERS.register_module()
+class LocalAPIRunner(BaseRunner):
+    """Local API Runner. Starts tasks with local python.
+
+    A query-per-second limit alone cannot bound the number of concurrent
+    requests, so this runner supports a fixed number of concurrent users
+    across multiple tasks, for APIs that restrict concurrency.
+
+    Args:
+        task (ConfigDict): Task type config.
+        concurrent_users (int): Max number of concurrent users allowed to
+            request the API resources.
+        max_num_workers (int): Max number of workers to run in parallel.
+            Defaults to 16.
+        debug (bool): Whether to run in debug mode.
+        lark_bot_url (str, optional): Lark bot URL. Defaults to None.
+    """
+
+    def __init__(self,
+                 task: ConfigDict,
+                 concurrent_users: int,
+                 max_num_workers: int = 16,
+                 debug: bool = False,
+                 lark_bot_url: Optional[str] = None):
+        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
+        self.max_num_workers = max_num_workers
+        self.concurrent_users = concurrent_users
+        assert task['type'] in [
+            'OpenICLInferTask', 'opencompass.tasks.OpenICLInferTask'
+        ], 'Only API infer tasks are supported.'
+
+    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
+        """Launch multiple tasks.
+
+        Args:
+            tasks (list[dict]): A list of task configs, usually generated by
+                Partitioner.
+
+        Returns:
+            list[tuple[str, int]]: A list of (task name, exit code).
+        """
+        status = []
+        if self.debug:
+            # fall back to LocalRunner debug mode
+            for task in tasks:
+                task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
+                task_name = task.name
+                # get cmd
+                mmengine.mkdir_or_exist('tmp/')
+                param_file = f'tmp/{os.getpid()}_params.py'
+                try:
+                    task.cfg.dump(param_file)
+                    cmd = task.get_command(cfg_path=param_file,
+                                           template='{task_cmd}')
+                    # run in a subprocess for commands like torchrun etc.
+                    if cmd.startswith('python'):
+                        task.run()
+                    else:
+                        subprocess.run(cmd, shell=True, text=True)
+                finally:
+                    os.remove(param_file)
+                status.append((task_name, 0))
+        else:
+
+            pbar = tqdm(total=len(tasks))
+
+            get_logger().info('All the logs and outputs for each task should'
+                              ' be checked in each infer/*.out file.')
+            with Manager() as manager:
+                tokens = manager.Semaphore(self.concurrent_users)
+                # updating pbar directly in the callback has visualization
+                # issues, so an extra counter is needed
+                pbar_counter = manager.Value('i', 0)
+                status = []
+
+                def update(args):
+                    """Update the pbar counter in the callback."""
+                    pbar_counter.value += 1
+                    status.append(args)
+
+                with Pool(processes=self.max_num_workers) as pool:
+                    for task in tasks:
+                        pool.apply_async(submit,
+                                         (task, self.task_cfg['type'], tokens),
+                                         callback=update)
+                    pool.close()
+
+                    # update the progress bar
+                    while True:
+                        cur_count = pbar_counter.value
+                        if cur_count > pbar.n:
+                            pbar.update(cur_count - pbar.n)
+                        # break when all the tasks have finished
+                        if cur_count >= pbar.total:
+                            pbar.close()
+                            break
+                        # sleep to lower CPU usage
+                        time.sleep(1)
+
+                    pool.join()
+        return status
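
A note on the concurrency mechanism this patch relies on: manager.Semaphore returns a proxy object that can be pickled and handed to pool workers, so every subprocess draws from one shared token pool and at most `concurrent_users` API calls are in flight at any moment, no matter how many tasks run in parallel (a plain multiprocessing.Semaphore cannot be passed through Pool arguments like this). Below is a minimal, self-contained sketch of the pattern; call_api and its time.sleep stand-in are illustrative only, not part of the patch.

import time
from multiprocessing import Manager, Pool


def call_api(args):
    """Stand-in for one API request, gated by the shared semaphore."""
    idx, tokens = args
    tokens.acquire()  # take a token before the "API call"
    try:
        time.sleep(0.1)  # placeholder for the real request
        return idx
    finally:
        tokens.release()  # give the token back


if __name__ == '__main__':
    with Manager() as manager:
        # at most 2 of the 8 workers may "call the API" at once
        tokens = manager.Semaphore(2)
        with Pool(processes=8) as pool:
            print(pool.map(call_api, [(i, tokens) for i in range(16)]))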
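For usage, the new model and runner might be wired together in an OpenCompass config along the following lines. This is a sketch under assumptions: the dataset import path, model abbreviation, key placeholder, and worker counts are illustrative, not part of this patch.

from mmengine.config import read_base

from opencompass.models import ZhiPuAI
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    # assumed dataset config shipped with OpenCompass
    from .datasets.ceval.ceval_gen import ceval_datasets

datasets = [*ceval_datasets]

models = [
    dict(
        abbr='chatglm_pro',
        type=ZhiPuAI,
        path='chatglm_pro',
        key='YOUR-ZHIPU-API-KEY',  # placeholder, replace with a real key
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=4,
        concurrent_users=2,  # size of the shared semaphore
        task=dict(type=OpenICLInferTask)),
)

With a config like this, the runner spawns up to max_num_workers task processes, while the shared semaphore keeps the number of simultaneous API requests at concurrent_users.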