From aa48a2843d2f0a95797585d0e0fd5899b55b9387 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sun, 15 Dec 2024 18:49:03 +0800 Subject: [PATCH] async --- opencompass/models/async_openai_api.py | 369 ++++++++++++++++ opencompass/models/base_api.py | 72 +++- .../icl_chat_async_inferencer.py | 397 ++++++++++++++++++ .../icl_gen_async_inferencer.py | 239 +++++++++++ opencompass/runners/local_async.py | 112 +++++ opencompass/tasks/__init__.py | 1 + opencompass/tasks/openicl_async_task.py | 168 ++++++++ 7 files changed, 1356 insertions(+), 2 deletions(-) create mode 100644 opencompass/models/async_openai_api.py create mode 100644 opencompass/openicl/icl_inferencer/icl_chat_async_inferencer.py create mode 100644 opencompass/openicl/icl_inferencer/icl_gen_async_inferencer.py create mode 100644 opencompass/runners/local_async.py create mode 100644 opencompass/tasks/openicl_async_task.py diff --git a/opencompass/models/async_openai_api.py b/opencompass/models/async_openai_api.py new file mode 100644 index 00000000..fc247bfc --- /dev/null +++ b/opencompass/models/async_openai_api.py @@ -0,0 +1,369 @@ +import contextlib +import multiprocessing +import os +import re +from typing import Dict, List, Optional, Union + +import jieba +import weakref +from typing import Literal, Tuple, Iterable + +from opencompass.utils.prompt import PromptList +from opencompass.models.base_api import AsyncTokenBucket, BaseAPIModel + +import threading +import asyncio +from typing import cast +from contextlib import contextmanager + + +PromptType = Union[PromptList, str] +OPENAI_API_BASE = os.path.join( + os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/'), + 'chat/completions') + + +class _APIModelState: + _instance: Dict[str, weakref.ReferenceType["_APIModelState"]] = {} + _count: int + _concurrency: int + _locks = [threading.Lock(), multiprocessing.Lock()] + + def __init__(self, *, name: str, concurrency: int, query_per_second=1) -> None: + self._name = name + self._count = 0 + self._concurrency = concurrency + self._token_bucket = AsyncTokenBucket(rate=query_per_second) + + self._count += 1 + self._concurrency = max(1, self._concurrency // self._count) + + @property + def concurrency(self) -> int: + # If update and concurrency are called simultaneously, the values + # returned here may be inaccurate, but the impact is likely minimal + return self._concurrency + + async def acquire(self): + return await self._token_bucket.acquire() + + @property + def rpm(self): + return self._token_bucket.rpm + + @property + def name(self) -> str: + return self._name + + @property + def count(self): + return self._count + + @classmethod + def _cleanup(cls, ref: weakref.ReferenceType["_APIModelState"]): + with cls._lock(): + self: _APIModelState = ref() # type: ignore + cls._instance.pop(self._name) + + def __new__(cls, name: str, *args, **kwargs) -> "_APIModelState": + with cls._lock(): + if name not in cls._instance: + self = super().__new__(cls) + cls._instance[name] = weakref.ref(self, cls._cleanup) + return cls._instance[name]() # type: ignore + + @classmethod + @contextmanager + def _lock(cls): + with contextlib.ExitStack() as stack: + [stack.enter_context(lock) for lock in cls._locks] + yield + + + +class AsyncOpenAISDK(BaseAPIModel): + states: Dict[str, _APIModelState] = {} + + def __init__( + self, + path: str = 'gpt-3.5-turbo', + max_seq_len: int | None = None, # type: ignore + query_per_second: int = 1, + retry: int = 2, + key: str = 'ENV', + org: str | List[str] | None = None, + meta_template: 
Dict | None = None, + openai_api_base: str = OPENAI_API_BASE, + openai_proxy_url: Optional[str] = None, + mode: Literal['none', 'front', 'mid', 'rear'] = 'none', + logprobs: bool | None = False, + top_logprobs: int | None = None, + temperature: float | None = None, + tokenizer_path: str | None = None, + extra_body: Dict | None = None, + max_completion_tokens: int = 16384, + verbose: bool = False, + concurrency: int = 64, + status_code_mappings: dict = {}, + ): + from openai import AsyncOpenAI + + assert mode in ['none', 'front', 'mid', 'rear'] + self.mode = mode + state_key = self._get_state_key(api_base=openai_api_base, model_name=path) + if state_key not in AsyncOpenAISDK.states: + AsyncOpenAISDK.states[path] = _APIModelState( + name=state_key, + concurrency=concurrency, + query_per_second=query_per_second, + ) + self.state = AsyncOpenAISDK.states[path] + self.openai_client = AsyncOpenAI(base_url=openai_api_base, api_key=key) + + if max_seq_len is None: + if '16k' in path: + max_seq_len = 16384 + elif 'gpt-4' in path: + max_seq_len = 8192 + elif 'gpt-3.5' in path: + max_seq_len = 4097 + else: + max_seq_len = 32768 + else: + max_seq_len = max_seq_len + + super().__init__(path=path, max_seq_len=max_seq_len, meta_template=meta_template, retry=retry) + + self.logprobs = logprobs + self.top_logprobs = top_logprobs + self.tokenizer_path = tokenizer_path + self.hf_tokenizer = None + self.extra_body = extra_body + self.max_completion_tokens = max_completion_tokens + self.temperature = temperature + self.openai_api_base = openai_api_base + self.concurrency = concurrency + + self.status_code_mappings = status_code_mappings + + if openai_proxy_url == 'ENV': + if 'OPENAI_PROXY_URL' not in os.environ: + raise ValueError('OPENAI_PROXY_URL is not set.') + self.proxy_url = os.getenv('OPENAI_PROXY_URL') + else: + self.proxy_url = openai_proxy_url + + async def generate(self, # type: ignore + inputs: Iterable[PromptType], + max_out_len: int = 512, + temperature: float = 0.7, + **kwargs) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[PromptType]): A list of strings or PromptDicts. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + temperature (float): What sampling temperature to use, + between 0 and 2. Higher values like 0.8 will make the output + more random, while lower values like 0.2 will make it more + focused and deterministic. Defaults to 0.7. + + Returns: + List[str]: A list of generated strings. 
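+
+        Note:
+            Requests are dispatched concurrently: the number of in-flight
+            requests is capped by the shared ``_APIModelState.concurrency``
+            value, and every request first awaits the async token bucket
+            configured via ``query_per_second``.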
+ """ + if self.temperature is not None: + temperature = self.temperature + + # TODO: This should be an AsyncGenerator if an real `AsyncInference` has been implemented + tasks_queue: List[asyncio.Future] = [] + results_queue: List[Tuple[int, str]] = [] + inputs_iter = enumerate(inputs) + + data_stop = False + while not (data_stop and not tasks_queue): + concurrency = self.state.concurrency + + if tasks_queue: + done, pending = await asyncio.wait(tasks_queue, return_when=asyncio.FIRST_COMPLETED) + tasks_queue = list(pending) + for queue in done: + result: Tuple[int, str] = queue.result() + results_queue.append(result) + + while not data_stop and len(tasks_queue) < concurrency: + try: + index, _input = next(inputs_iter) + except StopIteration: + data_stop = True + break + tasks_queue.append( + asyncio.create_task( + self._generate( + input=_input, + max_out_len=self.max_completion_tokens or max_out_len, + temperature=temperature, + index=index, + ) + ) + ) + results_queue.sort() + return [item[1] for item in results_queue] + + async def generate_from_template(self, templates: List[PromptType], # type: ignore + max_out_len: int, **kwargs): + """Generate completion from a list of templates. + + Args: + templates (List[PromptType]): A list of templates. + max_out_len (int): The maximum length of the output. + """ + inputs = self.parse_template(templates, mode='gen') # type: ignore + return await self.generate(inputs, max_out_len=max_out_len, **kwargs) + + async def _generate(self, input: PromptList | str, max_out_len: int, + temperature: float, index: int) -> Tuple[int, str]: + from openai import APIStatusError, BadRequestError + assert isinstance(input, (str, PromptList)) + + # max num token for gpt-3.5-turbo is 4097 + # Most models' token limits are above 32k + + # will leave 100 tokens as prompt buffer, triggered if input is str + if isinstance(input, str) and self.mode != 'none': + context_window = self.max_seq_len + input = self.bin_trim( + input, + context_window - 100 - max_out_len, + cast(Literal['front', 'mid', 'rear'], self.mode), + ) + + if isinstance(input, str): + messages = [{'role': 'user', 'content': input}] + else: + messages = [] + for item in input: + msg = {'content': item['prompt']} + if item['role'] == 'HUMAN': + msg['role'] = 'user' + elif item['role'] == 'BOT': + msg['role'] = 'assistant' + elif item['role'] == 'SYSTEM': + msg['role'] = 'system' + messages.append(msg) + + + # Hold out 100 tokens due to potential errors in tiktoken calculation + # try: + # max_out_len = min( + # max_out_len, + # context_window - self.get_token_len(str(input)) - 100) + # except KeyError: + # max_out_len = max_out_len + # if max_out_len <= 0: + # return '' + + num_retries = 0 + while num_retries < self.retry: + await self.state.acquire() + + query_data = dict( + model=self.path, + max_tokens=max_out_len, + n=1, + temperature=self.temperature, + messages=messages, + extra_body=self.extra_body, + timeout=600, + ) + + try: + if self.verbose: + self.logger.info('Start calling OpenAI API') + responses = await self.openai_client.chat.completions.create(**query_data) + + if self.verbose: + self.logger.info( + 'Successfully get response from OpenAI API') + try: + self.logger.info(responses) + except Exception as e: # noqa F841 + pass + if not responses.choices: + self.logger.error( + 'Response is empty, it is an internal server error \ + from the API provider.') + return index, responses.choices[0].message.content + + except (BadRequestError, APIStatusError) as e: + # Handle BadRequest status + # 
You can specify self.status_code_mappings to bypass \ + # API sensitivity blocks + # For example: status_code_mappings={400: 'Input data \ + # may contain inappropriate content.'} + status_code = e.status_code + if (status_code is not None + and status_code in self.status_code_mappings): + error_message = self.status_code_mappings[status_code] + self.logger.info(f'Status Code: {status_code},\n' + f'Original Error Message: {e},\n' + f'Return Message: {error_message} ') + return index, error_message + else: + self.logger.warning(f"Failed to get response for {e}, retry {num_retries}/{self.retry}") + except Exception as e: + self.logger.warning(f"Failed to get response for {e}, retry {num_retries}/{self.retry}") + num_retries += 1 + raise RuntimeError('Calling OpenAI API failed after retrying for ' + f'{self.retry} times. Check the logs for details.') + + def _get_state_key(self, api_base: str, model_name: str): + return api_base + model_name + + def bin_trim(self, prompt: str, num_token: int, mode: Literal['front', 'mid', 'rear']) -> str: + """Get a suffix of prompt which is no longer than num_token tokens. + + Args: + prompt (str): Input string. + num_token (int): The upper bound of token numbers. + + Returns: + str: The trimmed prompt. + """ + token_len = self.get_token_len(prompt) + if token_len <= num_token: + return prompt + pattern = re.compile(r'[\u4e00-\u9fa5]') + if pattern.search(prompt): + words = list(jieba.cut(prompt, cut_all=False)) + sep = '' + else: + words = prompt.split(' ') + sep = ' ' + + l, r = 1, len(words) + while l + 2 < r: + # mode: Literal['front', 'mid', 'rear'] = self.mode + mid = (l + r) // 2 + if mode == 'front': + cur_prompt = sep.join(words[-mid:]) + elif mode == 'mid': + cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:]) + elif mode == 'rear': + cur_prompt = sep.join(words[:mid]) + + if self.get_token_len(cur_prompt) <= num_token: + l = mid # noqa: E741 + else: + r = mid + + if self.mode == 'front': + prompt = sep.join(words[-l:]) + elif self.mode == 'mid': + prompt = sep.join(words[:l]) + sep.join(words[-l:]) + elif self.mode == 'rear': + prompt = sep.join(words[:l]) + return prompt + + diff --git a/opencompass/models/base_api.py b/opencompass/models/base_api.py index 13a8e956..ce823f87 100644 --- a/opencompass/models/base_api.py +++ b/opencompass/models/base_api.py @@ -4,6 +4,7 @@ import threading import time import warnings from abc import abstractmethod +from collections import deque from copy import deepcopy from queue import Queue from time import sleep @@ -12,6 +13,8 @@ from typing import Dict, List, Optional, Tuple, Union from opencompass.utils import get_logger from opencompass.utils.prompt import PromptList +import asyncio + from .base import BaseModel PromptType = Union[PromptList, str] @@ -51,7 +54,7 @@ class BaseAPIModel(BaseModel): self.retry = retry self.query_per_second = query_per_second self.token_bucket = TokenBucket(query_per_second, rpm_verbose) - self.template_parser = APITemplateParser(meta_template) + self.template_parser = APITemplateParser(meta_template) # type: ignore self.logger = get_logger() self.generation_kwargs = generation_kwargs self.verbose = verbose @@ -459,4 +462,69 @@ class TokenBucket: else: break self._request_queue.put(cur_time) - self.logger.info(f'Current RPM {self._request_queue.qsize()}.') + self.logger.info(f"Current RPM {self._request_queue.qsize()}.") + + + +class AsyncTokenBucket: + def __init__(self, rate: int = 1): + self._rate = rate + self._max_tokens = rate * 60 + self._tokens: float = 
float(self._max_tokens) + self._last_refill_time: float | None = None + + self._request_timestamps: deque[float] = deque() + self._max_window_size = 60 + + self._token_available = asyncio.Event() + + async def _release(self) -> None: + if self._last_refill_time is None: + self._last_refill_time: float = time.monotonic() + + now = time.monotonic() + elapsed = now - self._last_refill_time + tokens_to_add = elapsed * self._rate + + self._tokens = min(self._max_tokens, self._tokens + tokens_to_add) + self._last_refill_time = now + + async def acquire(self) -> bool: + while True: + await self._release() + + if self._tokens >= 1: + self._tokens -= 1 + + now = time.monotonic() + self._request_timestamps.append(now) + + while ( + self._request_timestamps + and now - self._request_timestamps[0] > self._max_window_size + ): + self._request_timestamps.popleft() + + self._token_available.set() + return True + + self._token_available.clear() + + await self._token_available.wait() + + @property + def rpm(self) -> int: + now = time.monotonic() + + while ( + self._request_timestamps + and now - self._request_timestamps[0] > self._max_window_size + ): + self._request_timestamps.popleft() + + return len(self._request_timestamps) + + @property + def available_tokens(self) -> float: + return self._tokens + diff --git a/opencompass/openicl/icl_inferencer/icl_chat_async_inferencer.py b/opencompass/openicl/icl_inferencer/icl_chat_async_inferencer.py new file mode 100644 index 00000000..16c5ac04 --- /dev/null +++ b/opencompass/openicl/icl_inferencer/icl_chat_async_inferencer.py @@ -0,0 +1,397 @@ +"""Chat Inferencer.""" +import os +import os.path as osp +from typing import List, Optional, Union + +import mmengine +from mmengine import is_list_of +from tqdm import tqdm + +from opencompass.models import APITemplateParser as _APITemplateParser +from opencompass.models import BaseModel +from opencompass.models import LMTemplateParser as _LMTemplateParser +from opencompass.registry import ICL_INFERENCERS +from opencompass.utils.prompt import PromptList + +from ..icl_prompt_template import PromptTemplate +from ..icl_retriever import BaseRetriever +from ..utils.logging import get_logger +from .icl_base_inferencer import BaseInferencer, dump_results_dict + +logger = get_logger(__name__) + + +def promptlist_to_openai(prompt: Union[str, PromptList]): + output = [] + if isinstance(prompt, str): + return [dict(role='user', content=prompt)] + + for item in prompt: + if 'section' in item: + continue + if isinstance(item, str) and item: + output.append(dict(role='user', content=item)) + elif item['role'] == 'SYSTEM': + output.append(dict(role='system', content=item['prompt'])) + elif item['role'] == 'HUMAN': + output.append(dict(role='user', content=item['prompt'])) + elif item['role'] == 'BOT': + output.append(dict(role='assistant', content=item['prompt'])) + return output + + +class LMTemplateParser: + """LMTemplateParser accepts OpenAI format dialog inputs.""" + + def __init__(self, meta_template: Optional[dict] = None): + self.meta_template = meta_template + self.roles = {} + role_mapping = { + 'SYSTEM': 'system', + 'HUMAN': 'user', + 'BOT': 'assistant', + } + if meta_template: + for item in meta_template.get('round', []): + role = role_mapping.get(item['role'], item['role']) + self.roles[role] = item.copy() + for item in meta_template.get('reserved_roles', []): + role = role_mapping.get(item['role'], item['role']) + self.roles[role] = item.copy() + + def parse_template(self, chat: List[dict], mode='gen') -> str: + if 
is_list_of(chat, list): + # Handle batch inputs + return [self.parse_template(item) for item in chat] + + assert is_list_of(chat, dict) + prompt = '' + if self.roles: + for dialog in chat: + role_cfg = self.roles.get(dialog['role'], {}) + prompt += (role_cfg.get('begin') or '') + prompt += (dialog.get('content') or '') + prompt += (role_cfg.get('end') or '') + prompt += (self.roles['assistant'].get('begin') or '') + else: + # in case the model does not have any meta template + last_sep = '' + for item in chat: + prompt += last_sep + (item.get('content') or '') + last_sep = '\n' + return prompt + + +class APITemplateParser: + """APITemplateParser accepts OpenAI format dialog inputs.""" + + def __init__(self, meta_template: Optional[dict] = None): + self.meta_template = meta_template + self.roles = {} + role_mapping = { + 'SYSTEM': 'system', + 'HUMAN': 'user', + 'BOT': 'assistant', + } + if meta_template: + for item in meta_template.get('round', []): + role = role_mapping.get(item['role'], item['role']) + self.roles[role] = item.copy() + for item in meta_template.get('reserved_roles', []): + role = role_mapping.get(item['role'], item['role']) + self.roles[role] = item.copy() + else: + self.roles = dict( + system=dict(api_role='SYSTEM'), + user=dict(api_role='HUMAN'), + assistant=dict(api_role='BOT', generate=True), + ) + + def parse_template(self, chat: List[dict], mode='gen') -> str: + if is_list_of(chat, list): + # Handle batch inputs + return [self.parse_template(item) for item in chat] + + assert is_list_of(chat, dict) + prompt = [] + for dialog in chat: + if dialog['role'] in self.roles: + role = self.roles[dialog['role']]['api_role'] + else: + role = dialog['role'] + prompt.append(dict(role=role, prompt=dialog.get('content') or '')) + return PromptList(prompt) + + +class ChatOutputHandler: + + def __init__(self) -> None: + self.results_dict = {} + + def write_to_json(self, save_dir: str, filename: str): + """Dump the result to a json file.""" + dump_results_dict(self.results_dict, osp.join(save_dir, filename)) + + def save_results(self, + origin_prompt: list, + prediction: str, + idx: int, + gold: str = None): + result_dict = {} + if gold: + result_dict['gold'] = gold + result_dict.update({ + 'prediction': prediction, + 'origin_prompt': origin_prompt, + }) + self.results_dict[str(idx)] = result_dict + + def save_multiround_results(self, + origin_prompt: list, + prediction: str, + idx: int, + gold: str = None): + result_dict = self.results_dict.get(str(idx), { + 'gold': [], + 'prediction': [], + 'origin_prompt': [], + }) + result_dict['gold'].append(gold) + result_dict['prediction'].append(prediction) + result_dict['origin_prompt'].append(origin_prompt) + self.results_dict[str(idx)] = result_dict + + +@ICL_INFERENCERS.register_module() +class AsyncChatInferencer(BaseInferencer): + HandlerType = ChatOutputHandler + + def __init__( + self, + model, + output_json_filepath: Optional[str] = './icl_inference_output', + output_json_filename: Optional[str] = 'predictions', + save_every: Optional[int] = 1, + infer_mode: str = 'last', + max_out_len: int = 512, + **kwargs) -> None: + super().__init__( + model=model, + output_json_filename=output_json_filename, + output_json_filepath=output_json_filepath, + **kwargs, + ) + assert infer_mode in ['last', 'every', 'every_with_gt'] + self.infer_mode = infer_mode + self.model: BaseModel + self._set_meta_template(self.model) + + if self.model.is_api and save_every is None: + save_every = 1 + self.save_every = save_every + self.dialogue_mode = False + 
self.max_out_len = max_out_len + + def _set_meta_template(self, model): + origin = model.template_parser + if isinstance(origin, _APITemplateParser): + model.template_parser = APITemplateParser(origin.meta_template) + if isinstance(origin, _LMTemplateParser): + model.template_parser = LMTemplateParser(origin.meta_template) + + async def inference(self, # type: ignore + retriever: BaseRetriever, + ice_template: Optional[PromptTemplate] = None, + prompt_template: Optional[PromptTemplate] = None, + output_json_filepath: Optional[str] = None, + output_json_filename: Optional[str] = None) -> dict: + # 1. Preparation for output logs + output_handler = self.HandlerType() + + if output_json_filepath is None: + output_json_filepath = self.output_json_filepath + if output_json_filename is None: + output_json_filename = self.output_json_filename + + # 2. Get results of retrieval process + ice_idx_list = retriever.retrieve() + + # 3. Generate prompts for testing input + chat_list = self.get_chat_list( + ice_idx_list, + retriever, + prompt_template=prompt_template, + ) + + # Create tmp json file for saving intermediate results and future + # resuming + index = 0 + tmp_json_filepath = os.path.join(output_json_filepath, + 'tmp_' + output_json_filename) + if osp.exists(tmp_json_filepath): + # TODO: move resume to output handler + try: + tmp_result_dict = mmengine.load(tmp_json_filepath) + except Exception: + pass + else: + output_handler.results_dict = tmp_result_dict + index = len(tmp_result_dict) + + # 4. Wrap prompts with Dataloader + dataloader = self.get_dataloader(chat_list[index:], batch_size=1) + + # 5. Inference for prompts in each batch + logger.debug('Starting inference process...') + for datum in tqdm(dataloader, disable=not self.is_main_process): + chat = datum[0] + if self.infer_mode == 'last': + await self.infer_last(chat, index, output_handler) + elif self.infer_mode == 'every': + await self.infer_every(chat, index, output_handler) + elif self.infer_mode == 'every_with_gt': + await self.infer_every_with_gt(chat, index, output_handler) + index += 1 + + # Save intermediate results + if (self.save_every is not None and index % self.save_every == 0 + and self.is_main_process): + output_handler.write_to_json(output_json_filepath, + 'tmp_' + output_json_filename) + + # 4. 
Output + if self.is_main_process: + os.makedirs(output_json_filepath, exist_ok=True) + output_handler.write_to_json(output_json_filepath, + output_json_filename) + if osp.exists(tmp_json_filepath): + os.remove(tmp_json_filepath) + + return output_handler.results_dict + + def get_chat_list(self, + ice_idx_list: List[List[int]], + retriever: BaseRetriever, + prompt_template: Optional[PromptTemplate] = None): + prompt_list = [] + input_columns = retriever.dataset_reader.input_columns + output_column = retriever.dataset_reader.output_column + + def chat_from_entry(entry): + if prompt_template is None and len(input_columns) == 1: + # Directly use the input column as the user input + user = entry.get(input_columns[0]) + assistant = entry.get(output_column, '') + return [ + dict(role='user', content=user), + dict(role='assistant', content=assistant), + ] + elif prompt_template is not None: + # Use prompt template to generate chat history + chat = promptlist_to_openai( + prompt_template.generate_item(entry)) + gold = entry.get(output_column, '') + if chat[-1]['role'] != 'assistant': + chat.append(dict(role='assistant', content=gold)) + return chat + else: + raise ValueError() + + for idx, ice_idx in enumerate(ice_idx_list): + # NOTE: The in-context examples won't be used by now. + + item = { + k: v + for k, v in retriever.test_ds[idx].items() + if k in input_columns or k == output_column + } + if all(isinstance(value, str) for value in item.values()): + # Every column is a single string + chat = chat_from_entry(item) + elif all(is_list_of(value, str) for value in item.values()): + # Every column is a list of string for multi-round chat + entries = [dict(zip(item, v)) for v in zip(*item.values())] + chat = sum((chat_from_entry(entry) for entry in entries), []) + elif len(input_columns) == 1 and is_list_of( + item[input_columns[0]], dict): + # Single input column and it's already a chat. 
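+                # e.g. item[input_columns[0]] is already a list of
+                # OpenAI-style messages: [{'role': 'user', 'content': ...}]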
+ chat = item[input_columns[0]] + elif 'dialogue' in input_columns: + chat = item['dialogue'] + self.dialogue_mode = True + else: + raise ValueError('Cannot construct chat from the dataset.') + + prompt_list.append(chat) + return prompt_list + + async def infer_last(self, chat: List[dict], index: int, output_handler): + assistant_indices = [ + i for i, item in enumerate(chat) if item['role'] == 'assistant' + ] + + history = chat[:assistant_indices[-1]] + output = await self.model.generate_from_template( + [history], max_out_len=self.max_out_len)[0] + output_handler.save_results( + origin_prompt=history, + prediction=output, + idx=index, + gold=chat[assistant_indices[-1]]['content'], + ) + + async def infer_every(self, chat: List[dict], index: int, output_handler): + assistant_indices = [ + i for i, item in enumerate(chat) if item['role'] == 'assistant' + ] + index_copy = index + + for i in assistant_indices: + history = chat[:i] + output = await self.model.generate_from_template( + [history], max_out_len=self.max_out_len)[0] + chat[i]['content'] = output + if not self.dialogue_mode: + output_handler.save_multiround_results( + origin_prompt=history[-1]['content'], + prediction=output, + idx=index, + gold=chat[i]['content'], + ) + # index += 1 + if self.dialogue_mode: + # dialogue mode for subjective evaluation + assert len(chat) % 2 == 0 + round_num = int(len(chat) / 2) + preds_list = [] + for i in range(round_num): + temp_dict = { + 'round': i + 1, + 'user': chat[i * 2]['content'], + 'assistant': chat[i * 2 + 1]['content'] + } + preds_list.append(temp_dict) + output_handler.save_results( + origin_prompt=None, + prediction=preds_list, + idx=index_copy, + gold=None, + ) + + async def infer_every_with_gt(self, chat: List[dict], index: int, + output_handler): + assistant_indices = [ + i for i, item in enumerate(chat) if item['role'] == 'assistant' + ] + + for i in assistant_indices: + history = chat[:i] + output = await self.model.generate_from_template( + [history], max_out_len=self.max_out_len)[0] + output_handler.save_multiround_results( + origin_prompt=history[-1]['content'], + prediction=output, + idx=index, + gold=chat[i]['content'], + ) + index += 1 diff --git a/opencompass/openicl/icl_inferencer/icl_gen_async_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_async_inferencer.py new file mode 100644 index 00000000..8dba352a --- /dev/null +++ b/opencompass/openicl/icl_inferencer/icl_gen_async_inferencer.py @@ -0,0 +1,239 @@ +"""Direct Generation Inferencer.""" + +import inspect +import json +import os +import os.path as osp +import time +from typing import List, Optional + +import mmengine +import torch +from tqdm import tqdm + +from opencompass.models.base import BaseModel +from opencompass.registry import ICL_INFERENCERS +from opencompass.utils import batched + +from ..icl_prompt_template import PromptTemplate +from ..icl_retriever import BaseRetriever +from ..utils.logging import get_logger +from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler + +logger = get_logger(__name__) + + +@ICL_INFERENCERS.register_module() +class AsyncGenInferencer(BaseInferencer): + """Generation Inferencer class to directly evaluate by generation. + + Attributes: + model (:obj:`BaseModelWrapper`, optional): The module to inference. + max_seq_len (:obj:`int`, optional): Maximum number of tokenized words + allowed by the LM. 
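+        max_out_len (:obj:`int`): Maximum number of generated tokens
+            by the LM.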
+ min_out_len (:obj:`int`, optional): Minimum number of generated tokens + by the LM + batch_size (:obj:`int`, optional): Batch size for the + :obj:`DataLoader`. + output_json_filepath (:obj:`str`, optional): File path for output + `JSON` file. + output_json_filename (:obj:`str`, optional): File name for output + `JSON` file. + gen_field_replace_token (:obj:`str`, optional): Used to replace the + generation field token when generating prompts. + save_every (:obj:`int`, optional): Save intermediate results every + `save_every` iters. Defaults to 1. + generation_kwargs (:obj:`Dict`, optional): Parameters for the + :obj:`model.generate()` method. + """ + + def __init__( + self, + model: BaseModel, + max_out_len: int, + stopping_criteria: List[str] = [], + max_seq_len: Optional[int] = None, + min_out_len: Optional[int] = None, + batch_size: Optional[int] = 1, + gen_field_replace_token: Optional[str] = '', + output_json_filepath: Optional[str] = './icl_inference_output', + output_json_filename: Optional[str] = 'predictions', + save_every: Optional[int] = 1, + **kwargs) -> None: + super().__init__( + model=model, + max_seq_len=max_seq_len, + batch_size=batch_size, + output_json_filename=output_json_filename, + output_json_filepath=output_json_filepath, + **kwargs, + ) + + self.gen_field_replace_token = gen_field_replace_token + self.max_out_len = max_out_len + self.min_out_len = min_out_len + self.stopping_criteria = stopping_criteria + self.dump_timer = kwargs.get('dump_timer', False) + + if self.model.is_api and save_every is None: + save_every = 1 + self.save_every = save_every + + async def inference(self, # type: ignore + retriever: BaseRetriever, + ice_template: Optional[PromptTemplate] = None, + prompt_template: Optional[PromptTemplate] = None, + output_json_filepath: Optional[str] = None, + output_json_filename: Optional[str] = None) -> List: + # 1. Preparation for output logs + output_handler = GenInferencerOutputHandler() + + if output_json_filepath is None: + output_json_filepath = self.output_json_filepath + if output_json_filename is None: + output_json_filename = self.output_json_filename + + # 2. Get results of retrieval process + ice_idx_list = retriever.retrieve() + + # 3. Generate prompts for testing input + prompt_list = self.get_generation_prompt_list_from_retriever_indices( + ice_idx_list, + retriever, + self.gen_field_replace_token, + max_seq_len=self.max_seq_len, + ice_template=ice_template, + prompt_template=prompt_template) + + # 3.1 Fetch and zip prompt & gold answer if output column exists + ds_reader = retriever.dataset_reader + if ds_reader.output_column: + gold_ans = ds_reader.dataset['test'][ds_reader.output_column] + prompt_list = list(zip(prompt_list, gold_ans)) + + # Create tmp json file for saving intermediate results and future + # resuming + index = 0 + tmp_json_filepath = os.path.join(output_json_filepath, + 'tmp_' + output_json_filename) + if osp.exists(tmp_json_filepath): + # TODO: move resume to output handler + try: + tmp_result_dict = mmengine.load(tmp_json_filepath) + except Exception: + pass + else: + output_handler.results_dict = tmp_result_dict + index = len(tmp_result_dict) + + # 4. Wrap prompts with Dataloader + logger.debug('Starting build dataloader') + dataloader = self.get_dataloader(prompt_list[index:], self.batch_size) + + # 5. 
Inference for prompts in each batch + logger.debug('Starting inference process...') + + start_time_stamp = time.time() + num_sample = 0 + # TODO: batched dataloader shoule be replaced with async fetching + for datum in dataloader: + if ds_reader.output_column: + entry, golds = list(zip(*datum)) + else: + entry = datum + golds = [None for _ in range(len(entry))] + # 5-1. Inference with local model + extra_gen_kwargs = {} + sig = inspect.signature(self.model.generate) + if 'stopping_criteria' in sig.parameters: + extra_gen_kwargs['stopping_criteria'] = self.stopping_criteria + if 'min_out_len' in sig.parameters: + extra_gen_kwargs['min_out_len'] = self.min_out_len + with torch.no_grad(): + parsed_entries = self.model.parse_template(entry, mode='gen') + results = await self.model.generate_from_template( + entry, max_out_len=self.max_out_len, **extra_gen_kwargs) + generated = results + + num_return_sequences = getattr(self.model, 'generation_kwargs', + {}).get('num_return_sequences', 1) + # 5-3. Save current output + for prompt, prediction, gold in zip( + parsed_entries, batched(generated, num_return_sequences), + golds): + if num_return_sequences == 1: + prediction = prediction[0] + output_handler.save_results(prompt, + prediction, + index, + gold=gold) + index = index + 1 + + # 5-4. Save intermediate results + if (self.save_every is not None and index % self.save_every == 0 + and self.is_main_process): + output_handler.write_to_json(output_json_filepath, + 'tmp_' + output_json_filename) + num_sample += len(datum) + + end_time_stamp = time.time() + + # 6. Output + if self.is_main_process: + os.makedirs(output_json_filepath, exist_ok=True) + output_handler.write_to_json(output_json_filepath, + output_json_filename) + if osp.exists(tmp_json_filepath): + os.remove(tmp_json_filepath) + + if self.dump_timer and self.is_main_process: + timer_filepath = os.path.join(output_json_filepath, 'timer', + 'time.jsonl') + os.makedirs(os.path.dirname(timer_filepath), exist_ok=True) + time_dict = { + 'dataset_name': output_json_filename.removesuffix('.json'), + 'time': end_time_stamp - start_time_stamp, + 'num_sample': num_sample + } + with open(timer_filepath, 'a') as f: + f.write(json.dumps(time_dict) + '\n') + + return [ + sample['prediction'] + for sample in output_handler.results_dict.values() + ] + + def get_generation_prompt_list_from_retriever_indices( + self, + ice_idx_list: List[List[int]], + retriever: BaseRetriever, + gen_field_replace_token: str, + max_seq_len: Optional[int] = None, + ice_template: Optional[PromptTemplate] = None, + prompt_template: Optional[PromptTemplate] = None): + prompt_list = [] + for idx, ice_idx in enumerate(ice_idx_list): + ice = retriever.generate_ice(ice_idx, ice_template=ice_template) + prompt = retriever.generate_prompt_for_generate_task( + idx, + ice, + gen_field_replace_token=gen_field_replace_token, + ice_template=ice_template, + prompt_template=prompt_template) + if max_seq_len is not None: + prompt_token_num = self.model.get_token_len_from_template( + prompt, mode='gen') + while len(ice_idx) > 0 and prompt_token_num > max_seq_len: + ice_idx = ice_idx[:-1] + ice = retriever.generate_ice(ice_idx, + ice_template=ice_template) + prompt = retriever.generate_prompt_for_generate_task( + idx, + ice, + gen_field_replace_token=gen_field_replace_token, + ice_template=ice_template, + prompt_template=prompt_template) + prompt_token_num = self.model.get_token_len_from_template( + prompt, mode='gen') + prompt_list.append(prompt) + return prompt_list diff --git 
a/opencompass/runners/local_async.py b/opencompass/runners/local_async.py new file mode 100644 index 00000000..2d95d087 --- /dev/null +++ b/opencompass/runners/local_async.py @@ -0,0 +1,112 @@ +from math import prod +import os +import os.path as osp +import re +import subprocess +import sys +import time +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from threading import Lock +from typing import Any, Dict, List, Tuple + +import mmengine +import numpy as np +from mmengine.config import ConfigDict +from mmengine.device import is_npu_available +from tqdm import tqdm + +from opencompass.registry import RUNNERS, TASKS +from opencompass.utils import get_logger, model_abbr_from_cfg + +from .base import BaseRunner +from typing import TypedDict, Optional +from multiprocessing.managers import Namespace +import threading +import uuid +import enum +import signal +from enum import IntEnum +import asyncio +import traceback + + +class Status(IntEnum): + SUCCESS = 0 + FAILED = -1 + INTERRUPT = signal.SIGINT + + +@RUNNERS.register_module() +class AsyncRunner(BaseRunner): + """Local runner. Start tasks by local python. + + Args: + task (ConfigDict): Task type config. + max_num_workers (int): Max number of workers to run in parallel. + Defaults to 16. + max_workers_per_gpu (int): Max number of workers to run for one GPU. + Defaults to 1. + debug (bool): Whether to run in debug mode. + lark_bot_url (str): Lark bot url. + """ + + # These is a fake typehint + + def __init__(self, + task: ConfigDict, + debug: bool = False, + *, + max_num_workers: int = 16, + keep_tmp_file: bool = False, + **kwargs): + super().__init__(task=task, debug=debug) + self.max_num_workers = max_num_workers + self.keep_tmp_file = keep_tmp_file + logger = get_logger() + for k, v in kwargs.items(): + logger.warning(f'Ignored argument in `AsyncRunner`: {k}={v}') + + def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, Status]]: # type: ignore + """Launch multiple tasks. + + Args: + tasks (list[dict]): A list of task configs, usually generated by + Partitioner. + Returns: + + list[tuple[str, int]]: A list of (task name, exit code). + """ + from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask + + if not tasks: + return [("", Status.SUCCESS)] + + assert len(tasks) == 1, f"Task num must be 1 for `AsyncRunner`" + task_cfg = tasks[0] + + task: OpenICLAsyncInferTask = TASKS.build(dict(cfg=task_cfg, type=self.task_cfg['type'])) + task_name = task.name + # get cmd + mmengine.mkdir_or_exist('tmp/') + + try: + asyncio.run(task.run()) + except KeyboardInterrupt: + return [(task_name, Status.INTERRUPT)] + except: + print(traceback.print_exc()) + return [(task_name, Status.FAILED)] + else: + return [(task_name, Status.SUCCESS)] + + def __call__(self, tasks: List[Dict[str, Any]]): + """Launch multiple tasks and summarize the results. + + Args: + tasks (list[dict]): A list of task configs, usually generated by + Partitioner. 
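+
+        The (task name, status) pairs returned by ``launch`` are passed
+        to ``self.summarize``.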
+ """ + status = self.launch(tasks) + status_list = list(status) # change into list format + self.summarize(status_list) diff --git a/opencompass/tasks/__init__.py b/opencompass/tasks/__init__.py index 035662f7..498e408e 100644 --- a/opencompass/tasks/__init__.py +++ b/opencompass/tasks/__init__.py @@ -1,3 +1,4 @@ from .openicl_attack import * # noqa: F401, F403 from .openicl_eval import * # noqa: F401, F403 from .openicl_infer import * # noqa: F401, F403 +from .openicl_async_task import * # noqa: F401, F403 diff --git a/opencompass/tasks/openicl_async_task.py b/opencompass/tasks/openicl_async_task.py new file mode 100644 index 00000000..7fde87a9 --- /dev/null +++ b/opencompass/tasks/openicl_async_task.py @@ -0,0 +1,168 @@ +import argparse +import os +import os.path as osp +import random +import sys +import time +from typing import Any +from tqdm.asyncio import tqdm + +from mmengine.config import Config, ConfigDict +import inspect +from mmengine.utils import mkdir_or_exist + +from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES, + ICL_RETRIEVERS, TASKS) +from opencompass.tasks.base import BaseTask +from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg, + get_infer_output_path, get_logger, + task_abbr_from_cfg) +from opencompass.openicl.icl_inferencer.icl_gen_async_inferencer import AsyncGenInferencer +from opencompass.openicl.icl_inferencer.icl_chat_async_inferencer import AsyncChatInferencer +from opencompass.openicl.icl_inferencer import GenInferencer, ChatInferencer +from concurrent.futures import ThreadPoolExecutor +import asyncio +import resource +from more_itertools import consume + + +soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard)) + + +@TASKS.register_module() +class OpenICLAsyncInferTask(BaseTask): + """OpenICL Inference Task. + + This task is used to run the inference process. + """ + + name_prefix = 'OpenICLInfer' + log_subdir = 'logs/infer' + output_subdir = 'predictions' + + def __init__(self, cfg: ConfigDict): + super().__init__(cfg) + run_cfg = self.model_cfgs[0].get('run_cfg', {}) + self.nproc = run_cfg.get('nproc_per_worker', 16) + + def get_command(self, cfg_path, template) -> str: + # TODO: + raise NotImplementedError() + return "" + + async def run(self): # type: ignore + _dataset_cfgs = [] + infer_cfgs = [] + sub_cfgs = [] + datasets = [] + model_cfgs = [] + for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs): + self.max_out_len = model_cfg.get('max_out_len', None) + self.batch_size = model_cfg.get('batch_size', None) + self.min_out_len = model_cfg.get('min_out_len', None) + + for dataset_cfg in dataset_cfgs: + self.dataset_cfg = dataset_cfg + out_path = get_infer_output_path( + model_cfg, dataset_cfg, + osp.join(self.work_dir, 'predictions')) + + if osp.exists(out_path): + continue + _dataset_cfgs.append(dataset_cfg) + datasets.append(build_dataset_from_cfg(dataset_cfg)) + infer_cfgs.append(dataset_cfg['infer_cfg']) + model_cfgs.append(model_cfg) + sub_cfg = { + 'models': [model_cfg], + 'datasets': [[dataset_cfg]], + } + sub_cfgs.append(sub_cfg) + + tasks = [] + args = list(zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs)) + for arg in tqdm( + args, + total=len(args), + desc=f"Starting building tasks..." 
+ ): + tasks.append(asyncio.create_task(self._inference(*arg))) + + bar = tqdm(desc="Inferencing...", total=len(tasks)) + bar.refresh() + + while tasks: + done, tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + for _ in done: + bar.update() + bar.refresh() + + # TODO: Needs a debug mode + # for arg in zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs): + # await self._inference(*arg) + + async def _inference(self, dataset_cfg, infer_cfg, dataset, model_cfg, sub_cfg): + model = build_model_from_cfg(model_cfg) + assert hasattr(infer_cfg, 'ice_template') or hasattr(infer_cfg, 'prompt_template'), \ + 'Both ice_template and prompt_template cannot be None simultaneously.' # noqa: E501 + + infer_kwargs: dict = {} + if hasattr(infer_cfg, 'ice_template'): + ice_template = ICL_PROMPT_TEMPLATES.build( + infer_cfg['ice_template']) + infer_kwargs['ice_template'] = ice_template + + if hasattr(infer_cfg, 'prompt_template'): + prompt_template = ICL_PROMPT_TEMPLATES.build( + infer_cfg['prompt_template']) + infer_kwargs['prompt_template'] = prompt_template + + retriever_cfg = infer_cfg['retriever'].copy() + retriever_cfg['dataset'] = dataset + retriever = ICL_RETRIEVERS.build(retriever_cfg) + + # set inferencer's default value according to model's config' + inferencer_cfg: dict = infer_cfg['inferencer'] + inferencer_cfg['model'] = model + inferencer_cfg['max_seq_len'] = model_cfg.get('max_seq_len') + + infer_type = inferencer_cfg["type"] + if inspect.isclass(infer_type): + infer_name = infer_type.__name__ + else: + infer_name = infer_type + + if infer_name.split(".")[-1] == "ChatInferencer": + inferencer_cfg["type"] = AsyncChatInferencer + + elif infer_name.split(".")[-1] == "GenInferencer": + inferencer_cfg["type"] = AsyncGenInferencer + + inferencer_cfg.setdefault('max_out_len', self.max_out_len) + inferencer_cfg.setdefault('min_out_len', self.min_out_len) + inferencer_cfg.setdefault('batch_size', self.batch_size) + inferencer = ICL_INFERENCERS.build(inferencer_cfg) + + out_path = get_infer_output_path( + model_cfg, dataset_cfg, + osp.join(self.work_dir, 'predictions')) + out_dir, out_file = osp.split(out_path) + mkdir_or_exist(out_dir) + + infer_kwargs['output_json_filepath'] = out_dir + infer_kwargs['output_json_filename'] = out_file + + await inferencer.inference(retriever, **infer_kwargs) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Model Inferencer') + parser.add_argument('config', help='Config file path') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + # TODO: + raise NotImplementedError()
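Usage sketch appended for illustration (not part of the diff): a minimal, hypothetical config showing how the new model, runner and task could be wired together. AsyncRunner.launch asserts that exactly one task is passed in, so the partitioner has to emit a single task; the dataset list, API key, model name and all numeric values below are placeholders, and NaivePartitioner is only an assumed choice.

# Hypothetical config sketch -- not part of this patch; values are placeholders.
from opencompass.models.async_openai_api import AsyncOpenAISDK
from opencompass.partitioners import NaivePartitioner  # assumed partitioner choice
from opencompass.runners.local_async import AsyncRunner
from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask

datasets = [...]  # any OpenCompass dataset config list; must fit into a single task

models = [
    dict(
        type=AsyncOpenAISDK,
        abbr='gpt-4o-async',
        path='gpt-4o',                    # model name forwarded to the OpenAI SDK
        key='sk-...',                     # forwarded as-is to AsyncOpenAI(api_key=...)
        openai_api_base='https://api.openai.com/v1/chat/completions',
        query_per_second=2,               # rate fed into AsyncTokenBucket
        concurrency=64,                   # cap on in-flight requests per model state
        retry=2,
        max_completion_tokens=2048,
        max_out_len=1024,                 # read by OpenICLAsyncInferTask
        batch_size=1,
    ),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=AsyncRunner,
        max_num_workers=1,
        task=dict(type=OpenICLAsyncInferTask),
    ),
)

Under such a config, OpenICLAsyncInferTask.run schedules one coroutine per model/dataset pair on a single event loop and swaps GenInferencer/ChatInferencer for their async counterparts, so throughput is governed by concurrency and query_per_second rather than by the number of worker processes.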
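Also appended for illustration (not part of the diff): a self-contained sketch of the AsyncTokenBucket added to base_api.py. The import path follows this patch; everything else is illustrative. The bucket is created full (rate * 60 tokens), so a short burst is admitted immediately, and rpm reports how many acquisitions landed in the trailing 60-second window.

# Hypothetical standalone demo of the new AsyncTokenBucket (illustrative only).
import asyncio

from opencompass.models.base_api import AsyncTokenBucket


async def main() -> None:
    bucket = AsyncTokenBucket(rate=2)  # nominal 2 queries per second

    async def worker(i: int) -> None:
        await bucket.acquire()          # waits until a token is available
        print(f'request {i} admitted, rpm so far: {bucket.rpm}')

    # The bucket starts full (rate * 60 tokens), so these all pass quickly;
    # sustained load beyond the refill rate would start to block in acquire().
    await asyncio.gather(*(worker(i) for i in range(5)))


if __name__ == '__main__':
    asyncio.run(main())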