Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

Commit aa48a2843d ("async"), parent aeded4c4db

opencompass/models/async_openai_api.py (new file, 369 lines)
@@ -0,0 +1,369 @@
import contextlib
import multiprocessing
import os
import re
from typing import Dict, List, Optional, Union

import jieba
import weakref
from typing import Literal, Tuple, Iterable

from opencompass.utils.prompt import PromptList
from opencompass.models.base_api import AsyncTokenBucket, BaseAPIModel

import threading
import asyncio
from typing import cast
from contextlib import contextmanager


PromptType = Union[PromptList, str]
OPENAI_API_BASE = os.path.join(
    os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/'),
    'chat/completions')


class _APIModelState:
    _instance: Dict[str, weakref.ReferenceType["_APIModelState"]] = {}
    _count: int
    _concurrency: int
    _locks = [threading.Lock(), multiprocessing.Lock()]

    def __init__(self, *, name: str, concurrency: int, query_per_second=1) -> None:
        self._name = name
        self._count = 0
        self._concurrency = concurrency
        self._token_bucket = AsyncTokenBucket(rate=query_per_second)

        self._count += 1
        self._concurrency = max(1, self._concurrency // self._count)

    @property
    def concurrency(self) -> int:
        # If update and concurrency are called simultaneously, the values
        # returned here may be inaccurate, but the impact is likely minimal
        return self._concurrency

    async def acquire(self):
        return await self._token_bucket.acquire()

    @property
    def rpm(self):
        return self._token_bucket.rpm

    @property
    def name(self) -> str:
        return self._name

    @property
    def count(self):
        return self._count

    @classmethod
    def _cleanup(cls, ref: weakref.ReferenceType["_APIModelState"]):
        # The referent is already dead when this callback fires, so look the
        # entry up by the weak reference itself instead of dereferencing it.
        with cls._lock():
            for key, value in list(cls._instance.items()):
                if value is ref:
                    cls._instance.pop(key)
                    break

    def __new__(cls, name: str, *args, **kwargs) -> "_APIModelState":
        with cls._lock():
            if name not in cls._instance:
                self = super().__new__(cls)
                cls._instance[name] = weakref.ref(self, cls._cleanup)
            return cls._instance[name]()  # type: ignore

    @classmethod
    @contextmanager
    def _lock(cls):
        with contextlib.ExitStack() as stack:
            [stack.enter_context(lock) for lock in cls._locks]
            yield


class AsyncOpenAISDK(BaseAPIModel):
    states: Dict[str, _APIModelState] = {}

    def __init__(
        self,
        path: str = 'gpt-3.5-turbo',
        max_seq_len: int | None = None,  # type: ignore
        query_per_second: int = 1,
        retry: int = 2,
        key: str = 'ENV',
        org: str | List[str] | None = None,
        meta_template: Dict | None = None,
        openai_api_base: str = OPENAI_API_BASE,
        openai_proxy_url: Optional[str] = None,
        mode: Literal['none', 'front', 'mid', 'rear'] = 'none',
        logprobs: bool | None = False,
        top_logprobs: int | None = None,
        temperature: float | None = None,
        tokenizer_path: str | None = None,
        extra_body: Dict | None = None,
        max_completion_tokens: int = 16384,
        verbose: bool = False,
        concurrency: int = 64,
        status_code_mappings: dict = {},
    ):
        from openai import AsyncOpenAI

        assert mode in ['none', 'front', 'mid', 'rear']
        self.mode = mode
        # One shared state per (api_base, model) key, so rate limits and the
        # concurrency budget apply across all wrappers of the same endpoint.
        state_key = self._get_state_key(api_base=openai_api_base, model_name=path)
        if state_key not in AsyncOpenAISDK.states:
            AsyncOpenAISDK.states[state_key] = _APIModelState(
                name=state_key,
                concurrency=concurrency,
                query_per_second=query_per_second,
            )
        self.state = AsyncOpenAISDK.states[state_key]
        self.openai_client = AsyncOpenAI(base_url=openai_api_base, api_key=key)

        if max_seq_len is None:
            if '16k' in path:
                max_seq_len = 16384
            elif 'gpt-4' in path:
                max_seq_len = 8192
            elif 'gpt-3.5' in path:
                max_seq_len = 4097
            else:
                max_seq_len = 32768

        super().__init__(path=path, max_seq_len=max_seq_len,
                         meta_template=meta_template, retry=retry,
                         verbose=verbose)

        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.tokenizer_path = tokenizer_path
        self.hf_tokenizer = None
        self.extra_body = extra_body
        self.max_completion_tokens = max_completion_tokens
        self.temperature = temperature
        self.openai_api_base = openai_api_base
        self.concurrency = concurrency

        self.status_code_mappings = status_code_mappings

        if openai_proxy_url == 'ENV':
            if 'OPENAI_PROXY_URL' not in os.environ:
                raise ValueError('OPENAI_PROXY_URL is not set.')
            self.proxy_url = os.getenv('OPENAI_PROXY_URL')
        else:
            self.proxy_url = openai_proxy_url

    async def generate(self,  # type: ignore
                       inputs: Iterable[PromptType],
                       max_out_len: int = 512,
                       temperature: float = 0.7,
                       **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 0.7.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.temperature is not None:
            temperature = self.temperature

        # TODO: This should be an AsyncGenerator if a real `AsyncInference`
        # has been implemented
        tasks_queue: List[asyncio.Future] = []
        results_queue: List[Tuple[int, str]] = []
        inputs_iter = enumerate(inputs)

        data_stop = False
        while not (data_stop and not tasks_queue):
            concurrency = self.state.concurrency

            if tasks_queue:
                done, pending = await asyncio.wait(
                    tasks_queue, return_when=asyncio.FIRST_COMPLETED)
                tasks_queue = list(pending)
                for task in done:
                    result: Tuple[int, str] = task.result()
                    results_queue.append(result)

            while not data_stop and len(tasks_queue) < concurrency:
                try:
                    index, _input = next(inputs_iter)
                except StopIteration:
                    data_stop = True
                    break
                tasks_queue.append(
                    asyncio.create_task(
                        self._generate(
                            input=_input,
                            max_out_len=self.max_completion_tokens or max_out_len,
                            temperature=temperature,
                            index=index,
                        )
                    )
                )
        results_queue.sort()
        return [item[1] for item in results_queue]

    async def generate_from_template(self, templates: List[PromptType],  # type: ignore
                                     max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
        """
        inputs = self.parse_template(templates, mode='gen')  # type: ignore
        return await self.generate(inputs, max_out_len=max_out_len, **kwargs)

    async def _generate(self, input: PromptList | str, max_out_len: int,
                        temperature: float, index: int) -> Tuple[int, str]:
        from openai import APIStatusError, BadRequestError
        assert isinstance(input, (str, PromptList))

        # max num token for gpt-3.5-turbo is 4097
        # Most models' token limits are above 32k

        # will leave 100 tokens as prompt buffer, triggered if input is str
        if isinstance(input, str) and self.mode != 'none':
            context_window = self.max_seq_len
            input = self.bin_trim(
                input,
                context_window - 100 - max_out_len,
                cast(Literal['front', 'mid', 'rear'], self.mode),
            )

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        # Hold out 100 tokens due to potential errors in tiktoken calculation
        # try:
        #     max_out_len = min(
        #         max_out_len,
        #         context_window - self.get_token_len(str(input)) - 100)
        # except KeyError:
        #     max_out_len = max_out_len
        # if max_out_len <= 0:
        #     return ''

        num_retries = 0
        while num_retries < self.retry:
            await self.state.acquire()

            query_data = dict(
                model=self.path,
                max_tokens=max_out_len,
                n=1,
                temperature=temperature,
                messages=messages,
                extra_body=self.extra_body,
                timeout=600,
            )

            try:
                if self.verbose:
                    self.logger.info('Start calling OpenAI API')
                responses = await self.openai_client.chat.completions.create(
                    **query_data)

                if self.verbose:
                    self.logger.info(
                        'Successfully get response from OpenAI API')
                    try:
                        self.logger.info(responses)
                    except Exception:  # noqa: F841
                        pass
                if not responses.choices:
                    self.logger.error(
                        'Response is empty, it is an internal server error '
                        'from the API provider.')
                    num_retries += 1
                    continue
                return index, responses.choices[0].message.content

            except (BadRequestError, APIStatusError) as e:
                # Handle BadRequest status
                # You can specify self.status_code_mappings to bypass
                # API sensitivity blocks
                # For example: status_code_mappings={400: 'Input data
                # may contain inappropriate content.'}
                status_code = e.status_code
                if (status_code is not None
                        and status_code in self.status_code_mappings):
                    error_message = self.status_code_mappings[status_code]
                    self.logger.info(f'Status Code: {status_code},\n'
                                     f'Original Error Message: {e},\n'
                                     f'Return Message: {error_message} ')
                    return index, error_message
                else:
                    self.logger.warning(
                        f'Failed to get response ({e}), '
                        f'retry {num_retries}/{self.retry}')
            except Exception as e:
                self.logger.warning(
                    f'Failed to get response ({e}), '
                    f'retry {num_retries}/{self.retry}')
            num_retries += 1
        raise RuntimeError('Calling OpenAI API failed after retrying for '
                           f'{self.retry} times. Check the logs for details.')

    def _get_state_key(self, api_base: str, model_name: str):
        return api_base + model_name

    def bin_trim(self, prompt: str, num_token: int,
                 mode: Literal['front', 'mid', 'rear']) -> str:
        """Trim the prompt so that it is no longer than num_token tokens.

        Args:
            prompt (str): Input string.
            num_token (int): The upper bound of token numbers.
            mode (str): Which part of the prompt to drop: the 'front', the
                'mid' section, or the 'rear'.

        Returns:
            str: The trimmed prompt.
        """
        token_len = self.get_token_len(prompt)
        if token_len <= num_token:
            return prompt
        pattern = re.compile(r'[\u4e00-\u9fa5]')
        if pattern.search(prompt):
            words = list(jieba.cut(prompt, cut_all=False))
            sep = ''
        else:
            words = prompt.split(' ')
            sep = ' '

        l, r = 1, len(words)
        while l + 2 < r:
            mid = (l + r) // 2
            if mode == 'front':
                cur_prompt = sep.join(words[-mid:])
            elif mode == 'mid':
                cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
            elif mode == 'rear':
                cur_prompt = sep.join(words[:mid])

            if self.get_token_len(cur_prompt) <= num_token:
                l = mid  # noqa: E741
            else:
                r = mid

        if mode == 'front':
            prompt = sep.join(words[-l:])
        elif mode == 'mid':
            prompt = sep.join(words[:l]) + sep.join(words[-l:])
        elif mode == 'rear':
            prompt = sep.join(words[:l])
        return prompt
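
The scheduling loop in AsyncOpenAISDK.generate keeps at most state.concurrency requests in flight, collects (index, result) pairs as they finish, and restores input order at the end. Below is a minimal, self-contained sketch of that pattern; fake_request and bounded_generate are illustrative stand-ins for _generate and generate, not part of this commit.

import asyncio


async def fake_request(index: int, prompt: str) -> tuple:
    await asyncio.sleep(0.01 * (index % 3))   # simulate variable latency
    return index, f'response to {prompt!r}'


async def bounded_generate(prompts, concurrency: int = 4):
    pending, results = set(), []
    inputs = enumerate(prompts)
    exhausted = False
    while not (exhausted and not pending):
        # Top up the in-flight set until the concurrency cap is reached.
        while not exhausted and len(pending) < concurrency:
            try:
                i, prompt = next(inputs)
            except StopIteration:
                exhausted = True
                break
            pending.add(asyncio.create_task(fake_request(i, prompt)))
        if pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED)
            results.extend(task.result() for task in done)
    results.sort()                      # restore the original input order
    return [text for _, text in results]


print(asyncio.run(bounded_generate([f'q{i}' for i in range(10)])))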

opencompass/models/base_api.py
@@ -4,6 +4,7 @@ import threading
import time
import warnings
from abc import abstractmethod
from collections import deque
from copy import deepcopy
from queue import Queue
from time import sleep
@@ -12,6 +13,8 @@ from typing import Dict, List, Optional, Tuple, Union
from opencompass.utils import get_logger
from opencompass.utils.prompt import PromptList

import asyncio

from .base import BaseModel

PromptType = Union[PromptList, str]
@@ -51,7 +54,7 @@ class BaseAPIModel(BaseModel):
        self.retry = retry
        self.query_per_second = query_per_second
        self.token_bucket = TokenBucket(query_per_second, rpm_verbose)
        self.template_parser = APITemplateParser(meta_template)  # type: ignore
        self.logger = get_logger()
        self.generation_kwargs = generation_kwargs
        self.verbose = verbose
@@ -459,4 +462,69 @@ class TokenBucket:
            else:
                break
        self._request_queue.put(cur_time)
        self.logger.info(f"Current RPM {self._request_queue.qsize()}.")


class AsyncTokenBucket:

    def __init__(self, rate: int = 1):
        self._rate = rate
        self._max_tokens = rate * 60
        self._tokens: float = float(self._max_tokens)
        self._last_refill_time: float | None = None

        self._request_timestamps: deque[float] = deque()
        self._max_window_size = 60

        self._token_available = asyncio.Event()

    async def _release(self) -> None:
        if self._last_refill_time is None:
            self._last_refill_time = time.monotonic()

        now = time.monotonic()
        elapsed = now - self._last_refill_time
        tokens_to_add = elapsed * self._rate

        self._tokens = min(self._max_tokens, self._tokens + tokens_to_add)
        self._last_refill_time = now

    async def acquire(self) -> bool:
        while True:
            await self._release()

            if self._tokens >= 1:
                self._tokens -= 1

                now = time.monotonic()
                self._request_timestamps.append(now)

                while (self._request_timestamps
                       and now - self._request_timestamps[0] > self._max_window_size):
                    self._request_timestamps.popleft()

                self._token_available.set()
                return True

            self._token_available.clear()

            # Wake up either when another acquire succeeds (which sets the
            # event) or after roughly one refill interval; waiting on the
            # event alone could stall forever once the bucket is empty and
            # no other coroutine is left to set it.
            try:
                await asyncio.wait_for(self._token_available.wait(),
                                       timeout=1 / max(self._rate, 1))
            except asyncio.TimeoutError:
                pass

    @property
    def rpm(self) -> int:
        now = time.monotonic()

        while (self._request_timestamps
               and now - self._request_timestamps[0] > self._max_window_size):
            self._request_timestamps.popleft()

        return len(self._request_timestamps)

    @property
    def available_tokens(self) -> float:
        return self._tokens
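
A small pacing check for the new AsyncTokenBucket (a sketch only, assuming opencompass is importable). The bucket starts with a burst budget of rate * 60 tokens and then refills at rate tokens per second, so an initial burst of calls is admitted immediately and only sustained traffic is throttled.

import asyncio
import time

from opencompass.models.base_api import AsyncTokenBucket


async def main():
    bucket = AsyncTokenBucket(rate=5)          # roughly 5 requests per second
    start = time.monotonic()
    for _ in range(10):
        await bucket.acquire()                 # drawn from the burst budget
    print(f'10 acquires in {time.monotonic() - start:.3f}s, '
          f'requests in the last minute: {bucket.rpm}')


asyncio.run(main())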

opencompass/openicl/icl_inferencer/icl_chat_async_inferencer.py (new file, 397 lines)
@@ -0,0 +1,397 @@
"""Chat Inferencer."""
import os
import os.path as osp
from typing import List, Optional, Union

import mmengine
from mmengine import is_list_of
from tqdm import tqdm

from opencompass.models import APITemplateParser as _APITemplateParser
from opencompass.models import BaseModel
from opencompass.models import LMTemplateParser as _LMTemplateParser
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils.prompt import PromptList

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict

logger = get_logger(__name__)


def promptlist_to_openai(prompt: Union[str, PromptList]):
    output = []
    if isinstance(prompt, str):
        return [dict(role='user', content=prompt)]

    for item in prompt:
        if 'section' in item:
            continue
        if isinstance(item, str) and item:
            output.append(dict(role='user', content=item))
        elif item['role'] == 'SYSTEM':
            output.append(dict(role='system', content=item['prompt']))
        elif item['role'] == 'HUMAN':
            output.append(dict(role='user', content=item['prompt']))
        elif item['role'] == 'BOT':
            output.append(dict(role='assistant', content=item['prompt']))
    return output


class LMTemplateParser:
    """LMTemplateParser accepts OpenAI format dialog inputs."""

    def __init__(self, meta_template: Optional[dict] = None):
        self.meta_template = meta_template
        self.roles = {}
        role_mapping = {
            'SYSTEM': 'system',
            'HUMAN': 'user',
            'BOT': 'assistant',
        }
        if meta_template:
            for item in meta_template.get('round', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
            for item in meta_template.get('reserved_roles', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()

    def parse_template(self, chat: List[dict], mode='gen') -> str:
        if is_list_of(chat, list):
            # Handle batch inputs
            return [self.parse_template(item) for item in chat]

        assert is_list_of(chat, dict)
        prompt = ''
        if self.roles:
            for dialog in chat:
                role_cfg = self.roles.get(dialog['role'], {})
                prompt += (role_cfg.get('begin') or '')
                prompt += (dialog.get('content') or '')
                prompt += (role_cfg.get('end') or '')
            prompt += (self.roles['assistant'].get('begin') or '')
        else:
            # in case the model does not have any meta template
            last_sep = ''
            for item in chat:
                prompt += last_sep + (item.get('content') or '')
                last_sep = '\n'
        return prompt


class APITemplateParser:
    """APITemplateParser accepts OpenAI format dialog inputs."""

    def __init__(self, meta_template: Optional[dict] = None):
        self.meta_template = meta_template
        self.roles = {}
        role_mapping = {
            'SYSTEM': 'system',
            'HUMAN': 'user',
            'BOT': 'assistant',
        }
        if meta_template:
            for item in meta_template.get('round', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
            for item in meta_template.get('reserved_roles', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
        else:
            self.roles = dict(
                system=dict(api_role='SYSTEM'),
                user=dict(api_role='HUMAN'),
                assistant=dict(api_role='BOT', generate=True),
            )

    def parse_template(self, chat: List[dict], mode='gen') -> str:
        if is_list_of(chat, list):
            # Handle batch inputs
            return [self.parse_template(item) for item in chat]

        assert is_list_of(chat, dict)
        prompt = []
        for dialog in chat:
            if dialog['role'] in self.roles:
                role = self.roles[dialog['role']]['api_role']
            else:
                role = dialog['role']
            prompt.append(dict(role=role, prompt=dialog.get('content') or ''))
        return PromptList(prompt)


class ChatOutputHandler:

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, osp.join(save_dir, filename))

    def save_results(self,
                     origin_prompt: list,
                     prediction: str,
                     idx: int,
                     gold: str = None):
        result_dict = {}
        if gold:
            result_dict['gold'] = gold
        result_dict.update({
            'prediction': prediction,
            'origin_prompt': origin_prompt,
        })
        self.results_dict[str(idx)] = result_dict

    def save_multiround_results(self,
                                origin_prompt: list,
                                prediction: str,
                                idx: int,
                                gold: str = None):
        result_dict = self.results_dict.get(str(idx), {
            'gold': [],
            'prediction': [],
            'origin_prompt': [],
        })
        result_dict['gold'].append(gold)
        result_dict['prediction'].append(prediction)
        result_dict['origin_prompt'].append(origin_prompt)
        self.results_dict[str(idx)] = result_dict


@ICL_INFERENCERS.register_module()
class AsyncChatInferencer(BaseInferencer):
    HandlerType = ChatOutputHandler

    def __init__(
            self,
            model,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            infer_mode: str = 'last',
            max_out_len: int = 512,
            **kwargs) -> None:
        super().__init__(
            model=model,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )
        assert infer_mode in ['last', 'every', 'every_with_gt']
        self.infer_mode = infer_mode
        self.model: BaseModel
        self._set_meta_template(self.model)

        if self.model.is_api and save_every is None:
            save_every = 1
        self.save_every = save_every
        self.dialogue_mode = False
        self.max_out_len = max_out_len

    def _set_meta_template(self, model):
        origin = model.template_parser
        if isinstance(origin, _APITemplateParser):
            model.template_parser = APITemplateParser(origin.meta_template)
        if isinstance(origin, _LMTemplateParser):
            model.template_parser = LMTemplateParser(origin.meta_template)

    async def inference(self,  # type: ignore
                        retriever: BaseRetriever,
                        ice_template: Optional[PromptTemplate] = None,
                        prompt_template: Optional[PromptTemplate] = None,
                        output_json_filepath: Optional[str] = None,
                        output_json_filename: Optional[str] = None) -> dict:
        # 1. Preparation for output logs
        output_handler = self.HandlerType()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        chat_list = self.get_chat_list(
            ice_idx_list,
            retriever,
            prompt_template=prompt_template,
        )

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if osp.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        dataloader = self.get_dataloader(chat_list[index:], batch_size=1)

        # 5. Inference for prompts in each batch
        logger.debug('Starting inference process...')
        for datum in tqdm(dataloader, disable=not self.is_main_process):
            chat = datum[0]
            if self.infer_mode == 'last':
                await self.infer_last(chat, index, output_handler)
            elif self.infer_mode == 'every':
                await self.infer_every(chat, index, output_handler)
            elif self.infer_mode == 'every_with_gt':
                await self.infer_every_with_gt(chat, index, output_handler)
            index += 1

            # Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if osp.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return output_handler.results_dict

    def get_chat_list(self,
                      ice_idx_list: List[List[int]],
                      retriever: BaseRetriever,
                      prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        input_columns = retriever.dataset_reader.input_columns
        output_column = retriever.dataset_reader.output_column

        def chat_from_entry(entry):
            if prompt_template is None and len(input_columns) == 1:
                # Directly use the input column as the user input
                user = entry.get(input_columns[0])
                assistant = entry.get(output_column, '')
                return [
                    dict(role='user', content=user),
                    dict(role='assistant', content=assistant),
                ]
            elif prompt_template is not None:
                # Use prompt template to generate chat history
                chat = promptlist_to_openai(
                    prompt_template.generate_item(entry))
                gold = entry.get(output_column, '')
                if chat[-1]['role'] != 'assistant':
                    chat.append(dict(role='assistant', content=gold))
                return chat
            else:
                raise ValueError()

        for idx, ice_idx in enumerate(ice_idx_list):
            # NOTE: The in-context examples won't be used by now.

            item = {
                k: v
                for k, v in retriever.test_ds[idx].items()
                if k in input_columns or k == output_column
            }
            if all(isinstance(value, str) for value in item.values()):
                # Every column is a single string
                chat = chat_from_entry(item)
            elif all(is_list_of(value, str) for value in item.values()):
                # Every column is a list of string for multi-round chat
                entries = [dict(zip(item, v)) for v in zip(*item.values())]
                chat = sum((chat_from_entry(entry) for entry in entries), [])
            elif len(input_columns) == 1 and is_list_of(
                    item[input_columns[0]], dict):
                # Single input column and it's already a chat.
                chat = item[input_columns[0]]
            elif 'dialogue' in input_columns:
                chat = item['dialogue']
                self.dialogue_mode = True
            else:
                raise ValueError('Cannot construct chat from the dataset.')

            prompt_list.append(chat)
        return prompt_list

    async def infer_last(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        history = chat[:assistant_indices[-1]]
        output = (await self.model.generate_from_template(
            [history], max_out_len=self.max_out_len))[0]
        output_handler.save_results(
            origin_prompt=history,
            prediction=output,
            idx=index,
            gold=chat[assistant_indices[-1]]['content'],
        )

    async def infer_every(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]
        index_copy = index

        for i in assistant_indices:
            history = chat[:i]
            output = (await self.model.generate_from_template(
                [history], max_out_len=self.max_out_len))[0]
            chat[i]['content'] = output
            if not self.dialogue_mode:
                output_handler.save_multiround_results(
                    origin_prompt=history[-1]['content'],
                    prediction=output,
                    idx=index,
                    gold=chat[i]['content'],
                )
                # index += 1
        if self.dialogue_mode:
            # dialogue mode for subjective evaluation
            assert len(chat) % 2 == 0
            round_num = int(len(chat) / 2)
            preds_list = []
            for i in range(round_num):
                temp_dict = {
                    'round': i + 1,
                    'user': chat[i * 2]['content'],
                    'assistant': chat[i * 2 + 1]['content']
                }
                preds_list.append(temp_dict)
            output_handler.save_results(
                origin_prompt=None,
                prediction=preds_list,
                idx=index_copy,
                gold=None,
            )

    async def infer_every_with_gt(self, chat: List[dict], index: int,
                                  output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        for i in assistant_indices:
            history = chat[:i]
            output = (await self.model.generate_from_template(
                [history], max_out_len=self.max_out_len))[0]
            output_handler.save_multiround_results(
                origin_prompt=history[-1]['content'],
                prediction=output,
                idx=index,
                gold=chat[i]['content'],
            )
            index += 1
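
For reference, a sketch of the OpenAI-style chat structure AsyncChatInferencer consumes and how the three infer_mode settings pick generation points. This is plain data only (no model call); the example conversation is made up for illustration.

chat = [
    dict(role='system', content='You are a helpful assistant.'),
    dict(role='user', content='What is 2 + 2?'),
    dict(role='assistant', content='4'),             # gold for round 1
    dict(role='user', content='And that number squared?'),
    dict(role='assistant', content='16'),            # gold for round 2
]

assistant_indices = [i for i, m in enumerate(chat) if m['role'] == 'assistant']
print(assistant_indices)   # [2, 4]

# 'last'          -> one generation, conditioned on chat[:assistant_indices[-1]]
# 'every'         -> one generation per assistant turn, with each prediction
#                    written back into the history before the next turn
# 'every_with_gt' -> one generation per assistant turn, but the history keeps
#                    the gold answers instead of the predictions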

opencompass/openicl/icl_inferencer/icl_gen_async_inferencer.py (new file, 239 lines)
@@ -0,0 +1,239 @@
"""Direct Generation Inferencer."""

import inspect
import json
import os
import os.path as osp
import time
from typing import List, Optional

import mmengine
import torch
from tqdm import tqdm

from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils import batched

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class AsyncGenInferencer(BaseInferencer):
    """Generation Inferencer class to directly evaluate by generation.

    Attributes:
        model (:obj:`BaseModelWrapper`, optional): The module to inference.
        max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
            allowed by the LM.
        min_out_len (:obj:`int`, optional): Minimum number of generated tokens
            by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        gen_field_replace_token (:obj:`str`, optional): Used to replace the
            generation field token when generating prompts.
        save_every (:obj:`int`, optional): Save intermediate results every
            `save_every` iters. Defaults to 1.
        generation_kwargs (:obj:`Dict`, optional): Parameters for the
            :obj:`model.generate()` method.
    """

    def __init__(
            self,
            model: BaseModel,
            max_out_len: int,
            stopping_criteria: List[str] = [],
            max_seq_len: Optional[int] = None,
            min_out_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            gen_field_replace_token: Optional[str] = '',
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        self.gen_field_replace_token = gen_field_replace_token
        self.max_out_len = max_out_len
        self.min_out_len = min_out_len
        self.stopping_criteria = stopping_criteria
        self.dump_timer = kwargs.get('dump_timer', False)

        if self.model.is_api and save_every is None:
            save_every = 1
        self.save_every = save_every

    async def inference(self,  # type: ignore
                        retriever: BaseRetriever,
                        ice_template: Optional[PromptTemplate] = None,
                        prompt_template: Optional[PromptTemplate] = None,
                        output_json_filepath: Optional[str] = None,
                        output_json_filename: Optional[str] = None) -> List:
        # 1. Preparation for output logs
        output_handler = GenInferencerOutputHandler()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        prompt_list = self.get_generation_prompt_list_from_retriever_indices(
            ice_idx_list,
            retriever,
            self.gen_field_replace_token,
            max_seq_len=self.max_seq_len,
            ice_template=ice_template,
            prompt_template=prompt_template)

        # 3.1 Fetch and zip prompt & gold answer if output column exists
        ds_reader = retriever.dataset_reader
        if ds_reader.output_column:
            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
            prompt_list = list(zip(prompt_list, gold_ans))

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if osp.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        logger.debug('Starting build dataloader')
        dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)

        # 5. Inference for prompts in each batch
        logger.debug('Starting inference process...')

        start_time_stamp = time.time()
        num_sample = 0
        # TODO: batched dataloader should be replaced with async fetching
        for datum in dataloader:
            if ds_reader.output_column:
                entry, golds = list(zip(*datum))
            else:
                entry = datum
                golds = [None for _ in range(len(entry))]
            # 5-1. Inference with local model
            extra_gen_kwargs = {}
            sig = inspect.signature(self.model.generate)
            if 'stopping_criteria' in sig.parameters:
                extra_gen_kwargs['stopping_criteria'] = self.stopping_criteria
            if 'min_out_len' in sig.parameters:
                extra_gen_kwargs['min_out_len'] = self.min_out_len
            with torch.no_grad():
                parsed_entries = self.model.parse_template(entry, mode='gen')
                results = await self.model.generate_from_template(
                    entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
                generated = results

            num_return_sequences = getattr(self.model, 'generation_kwargs',
                                           {}).get('num_return_sequences', 1)
            # 5-3. Save current output
            for prompt, prediction, gold in zip(
                    parsed_entries, batched(generated, num_return_sequences),
                    golds):
                if num_return_sequences == 1:
                    prediction = prediction[0]
                output_handler.save_results(prompt,
                                            prediction,
                                            index,
                                            gold=gold)
                index = index + 1

            # 5-4. Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)
            num_sample += len(datum)

        end_time_stamp = time.time()

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if osp.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        if self.dump_timer and self.is_main_process:
            timer_filepath = os.path.join(output_json_filepath, 'timer',
                                          'time.jsonl')
            os.makedirs(os.path.dirname(timer_filepath), exist_ok=True)
            time_dict = {
                'dataset_name': output_json_filename.removesuffix('.json'),
                'time': end_time_stamp - start_time_stamp,
                'num_sample': num_sample
            }
            with open(timer_filepath, 'a') as f:
                f.write(json.dumps(time_dict) + '\n')

        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]

    def get_generation_prompt_list_from_retriever_indices(
            self,
            ice_idx_list: List[List[int]],
            retriever: BaseRetriever,
            gen_field_replace_token: str,
            max_seq_len: Optional[int] = None,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=gen_field_replace_token,
                ice_template=ice_template,
                prompt_template=prompt_template)
            if max_seq_len is not None:
                prompt_token_num = self.model.get_token_len_from_template(
                    prompt, mode='gen')
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt = retriever.generate_prompt_for_generate_task(
                        idx,
                        ice,
                        gen_field_replace_token=gen_field_replace_token,
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = self.model.get_token_len_from_template(
                        prompt, mode='gen')
            prompt_list.append(prompt)
        return prompt_list
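
The save loop above regroups the flat list of generations when a model returns several sequences per prompt. A self-contained sketch of that regrouping, with a local batched helper standing in for the one imported from opencompass.utils:

from itertools import islice


def batched(iterable, n):
    # Yield successive size-n chunks of the input.
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk


generated = ['a1', 'a2', 'b1', 'b2']      # two prompts, two samples each
num_return_sequences = 2
print(list(batched(generated, num_return_sequences)))
# [('a1', 'a2'), ('b1', 'b2')] -- one tuple of predictions per prompt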

opencompass/runners/local_async.py (new file, 112 lines)
@@ -0,0 +1,112 @@
from math import prod
import os
import os.path as osp
import re
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from threading import Lock
from typing import Any, Dict, List, Tuple

import mmengine
import numpy as np
from mmengine.config import ConfigDict
from mmengine.device import is_npu_available
from tqdm import tqdm

from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger, model_abbr_from_cfg

from .base import BaseRunner
from typing import TypedDict, Optional
from multiprocessing.managers import Namespace
import threading
import uuid
import enum
import signal
from enum import IntEnum
import asyncio
import traceback


class Status(IntEnum):
    SUCCESS = 0
    FAILED = -1
    INTERRUPT = signal.SIGINT


@RUNNERS.register_module()
class AsyncRunner(BaseRunner):
    """Asynchronous local runner. Runs the task in the current process with
    asyncio instead of spawning subprocesses.

    Args:
        task (ConfigDict): Task type config.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 16.
        debug (bool): Whether to run in debug mode.
        lark_bot_url (str): Lark bot url.
    """

    def __init__(self,
                 task: ConfigDict,
                 debug: bool = False,
                 *,
                 max_num_workers: int = 16,
                 keep_tmp_file: bool = False,
                 **kwargs):
        super().__init__(task=task, debug=debug)
        self.max_num_workers = max_num_workers
        self.keep_tmp_file = keep_tmp_file
        logger = get_logger()
        for k, v in kwargs.items():
            logger.warning(f'Ignored argument in `AsyncRunner`: {k}={v}')

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, Status]]:  # type: ignore
        """Launch multiple tasks.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """
        from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask

        if not tasks:
            return [('', Status.SUCCESS)]

        assert len(tasks) == 1, 'Task num must be 1 for `AsyncRunner`'
        task_cfg = tasks[0]

        task: OpenICLAsyncInferTask = TASKS.build(
            dict(cfg=task_cfg, type=self.task_cfg['type']))
        task_name = task.name
        mmengine.mkdir_or_exist('tmp/')

        try:
            asyncio.run(task.run())
        except KeyboardInterrupt:
            return [(task_name, Status.INTERRUPT)]
        except Exception:
            traceback.print_exc()
            return [(task_name, Status.FAILED)]
        else:
            return [(task_name, Status.SUCCESS)]

    def __call__(self, tasks: List[Dict[str, Any]]):
        """Launch multiple tasks and summarize the results.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.
        """
        status = self.launch(tasks)
        status_list = list(status)  # change into list format
        self.summarize(status_list)
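
A plausible way to wire the new runner and task into an infer config, assuming the usual OpenCompass layout (the partitioner choice and worker count are illustrative, not part of this commit). Note that AsyncRunner.launch asserts it receives exactly one task, so the partitioner must emit a single task.

from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_async import AsyncRunner
from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=AsyncRunner,
        max_num_workers=16,
        task=dict(type=OpenICLAsyncInferTask),
    ),
)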

opencompass/tasks/__init__.py
@@ -1,3 +1,4 @@
from .openicl_attack import *  # noqa: F401, F403
from .openicl_eval import *  # noqa: F401, F403
from .openicl_infer import *  # noqa: F401, F403
from .openicl_async_task import *  # noqa: F401, F403

opencompass/tasks/openicl_async_task.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import argparse
import os
import os.path as osp
import random
import sys
import time
from typing import Any
from tqdm.asyncio import tqdm

from mmengine.config import Config, ConfigDict
import inspect
from mmengine.utils import mkdir_or_exist

from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
                                  ICL_RETRIEVERS, TASKS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
                               get_infer_output_path, get_logger,
                               task_abbr_from_cfg)
from opencompass.openicl.icl_inferencer.icl_gen_async_inferencer import AsyncGenInferencer
from opencompass.openicl.icl_inferencer.icl_chat_async_inferencer import AsyncChatInferencer
from opencompass.openicl.icl_inferencer import GenInferencer, ChatInferencer
from concurrent.futures import ThreadPoolExecutor
import asyncio
import resource
from more_itertools import consume


soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))


@TASKS.register_module()
class OpenICLAsyncInferTask(BaseTask):
    """OpenICL Inference Task.

    This task is used to run the inference process.
    """

    name_prefix = 'OpenICLInfer'
    log_subdir = 'logs/infer'
    output_subdir = 'predictions'

    def __init__(self, cfg: ConfigDict):
        super().__init__(cfg)
        run_cfg = self.model_cfgs[0].get('run_cfg', {})
        self.nproc = run_cfg.get('nproc_per_worker', 16)

    def get_command(self, cfg_path, template) -> str:
        # TODO: the task is run in-process by AsyncRunner, so no command is
        # generated yet.
        raise NotImplementedError()

    async def run(self):  # type: ignore
        _dataset_cfgs = []
        infer_cfgs = []
        sub_cfgs = []
        datasets = []
        model_cfgs = []
        for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
            self.max_out_len = model_cfg.get('max_out_len', None)
            self.batch_size = model_cfg.get('batch_size', None)
            self.min_out_len = model_cfg.get('min_out_len', None)

            for dataset_cfg in dataset_cfgs:
                self.dataset_cfg = dataset_cfg
                out_path = get_infer_output_path(
                    model_cfg, dataset_cfg,
                    osp.join(self.work_dir, 'predictions'))

                if osp.exists(out_path):
                    continue
                _dataset_cfgs.append(dataset_cfg)
                datasets.append(build_dataset_from_cfg(dataset_cfg))
                infer_cfgs.append(dataset_cfg['infer_cfg'])
                model_cfgs.append(model_cfg)
                sub_cfg = {
                    'models': [model_cfg],
                    'datasets': [[dataset_cfg]],
                }
                sub_cfgs.append(sub_cfg)

        tasks = []
        args = list(zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs))
        for arg in tqdm(
                args,
                total=len(args),
                desc='Building tasks...'
        ):
            tasks.append(asyncio.create_task(self._inference(*arg)))

        bar = tqdm(desc='Inferencing...', total=len(tasks))
        bar.refresh()

        while tasks:
            done, tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for _ in done:
                bar.update()
            bar.refresh()

        # TODO: Needs a debug mode
        # for arg in zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs):
        #     await self._inference(*arg)

    async def _inference(self, dataset_cfg, infer_cfg, dataset, model_cfg, sub_cfg):
        model = build_model_from_cfg(model_cfg)
        assert hasattr(infer_cfg, 'ice_template') or hasattr(infer_cfg, 'prompt_template'), \
            'Both ice_template and prompt_template cannot be None simultaneously.'  # noqa: E501

        infer_kwargs: dict = {}
        if hasattr(infer_cfg, 'ice_template'):
            ice_template = ICL_PROMPT_TEMPLATES.build(
                infer_cfg['ice_template'])
            infer_kwargs['ice_template'] = ice_template

        if hasattr(infer_cfg, 'prompt_template'):
            prompt_template = ICL_PROMPT_TEMPLATES.build(
                infer_cfg['prompt_template'])
            infer_kwargs['prompt_template'] = prompt_template

        retriever_cfg = infer_cfg['retriever'].copy()
        retriever_cfg['dataset'] = dataset
        retriever = ICL_RETRIEVERS.build(retriever_cfg)

        # set inferencer's default value according to model's config
        inferencer_cfg: dict = infer_cfg['inferencer']
        inferencer_cfg['model'] = model
        inferencer_cfg['max_seq_len'] = model_cfg.get('max_seq_len')

        infer_type = inferencer_cfg['type']
        if inspect.isclass(infer_type):
            infer_name = infer_type.__name__
        else:
            infer_name = infer_type

        # Swap the synchronous inferencers for their async counterparts
        if infer_name.split('.')[-1] == 'ChatInferencer':
            inferencer_cfg['type'] = AsyncChatInferencer
        elif infer_name.split('.')[-1] == 'GenInferencer':
            inferencer_cfg['type'] = AsyncGenInferencer

        inferencer_cfg.setdefault('max_out_len', self.max_out_len)
        inferencer_cfg.setdefault('min_out_len', self.min_out_len)
        inferencer_cfg.setdefault('batch_size', self.batch_size)
        inferencer = ICL_INFERENCERS.build(inferencer_cfg)

        out_path = get_infer_output_path(
            model_cfg, dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        out_dir, out_file = osp.split(out_path)
        mkdir_or_exist(out_dir)

        infer_kwargs['output_json_filepath'] = out_dir
        infer_kwargs['output_json_filename'] = out_file

        await inferencer.inference(retriever, **infer_kwargs)


def parse_args():
    parser = argparse.ArgumentParser(description='Model Inferencer')
    parser.add_argument('config', help='Config file path')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # TODO:
    raise NotImplementedError()
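
The module raises RLIMIT_NOFILE at import time because every in-flight request holds a socket, and sweeping a large dataset at high concurrency can exceed the default soft limit (often 1024). A guarded sketch of the same adjustment that also copes with a low or unlimited hard limit; the 8192 target simply mirrors the value used above.

import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
# Do not ask for more than the hard limit allows, and never lower the
# current soft limit.
target = 8192 if hard == resource.RLIM_INFINITY else min(8192, hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, target), hard))
print(resource.getrlimit(resource.RLIMIT_NOFILE))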