Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
Merge aa48a2843d into 408f5caff4
This commit is contained in: commit eb9eb625ba

opencompass/models/async_openai_api.py (new file, 369 lines)
@@ -0,0 +1,369 @@
import contextlib
import multiprocessing
import os
import re
from typing import Dict, List, Optional, Union

import jieba
import weakref
from typing import Literal, Tuple, Iterable

from opencompass.utils.prompt import PromptList
from opencompass.models.base_api import AsyncTokenBucket, BaseAPIModel

import threading
import asyncio
from typing import cast
from contextlib import contextmanager


PromptType = Union[PromptList, str]
OPENAI_API_BASE = os.path.join(
    os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/'),
    'chat/completions')


class _APIModelState:
    """Shared per-endpoint state, keyed by api_base + model name."""

    _instance: Dict[str, weakref.ReferenceType["_APIModelState"]] = {}
    _count: int
    _concurrency: int
    _locks = [threading.Lock(), multiprocessing.Lock()]

    def __init__(self, *, name: str, concurrency: int, query_per_second=1) -> None:
        self._name = name
        self._count = 0
        self._concurrency = concurrency
        self._token_bucket = AsyncTokenBucket(rate=query_per_second)

        self._count += 1
        self._concurrency = max(1, self._concurrency // self._count)

    @property
    def concurrency(self) -> int:
        # If update and concurrency are called simultaneously, the values
        # returned here may be inaccurate, but the impact is likely minimal
        return self._concurrency

    async def acquire(self):
        return await self._token_bucket.acquire()

    @property
    def rpm(self):
        return self._token_bucket.rpm

    @property
    def name(self) -> str:
        return self._name

    @property
    def count(self):
        return self._count

    @classmethod
    def _cleanup(cls, ref: weakref.ReferenceType["_APIModelState"]):
        with cls._lock():
            # The referent is already gone when the weakref callback fires,
            # so drop whichever registry entries point at this dead reference.
            for name, stored in list(cls._instance.items()):
                if stored is ref:
                    cls._instance.pop(name)

    def __new__(cls, name: str, *args, **kwargs) -> "_APIModelState":
        with cls._lock():
            if name not in cls._instance:
                self = super().__new__(cls)
                cls._instance[name] = weakref.ref(self, cls._cleanup)
            return cls._instance[name]()  # type: ignore

    @classmethod
    @contextmanager
    def _lock(cls):
        with contextlib.ExitStack() as stack:
            [stack.enter_context(lock) for lock in cls._locks]
            yield


class AsyncOpenAISDK(BaseAPIModel):
    states: Dict[str, _APIModelState] = {}

    def __init__(
        self,
        path: str = 'gpt-3.5-turbo',
        max_seq_len: int | None = None,  # type: ignore
        query_per_second: int = 1,
        retry: int = 2,
        key: str = 'ENV',
        org: str | List[str] | None = None,
        meta_template: Dict | None = None,
        openai_api_base: str = OPENAI_API_BASE,
        openai_proxy_url: Optional[str] = None,
        mode: Literal['none', 'front', 'mid', 'rear'] = 'none',
        logprobs: bool | None = False,
        top_logprobs: int | None = None,
        temperature: float | None = None,
        tokenizer_path: str | None = None,
        extra_body: Dict | None = None,
        max_completion_tokens: int = 16384,
        verbose: bool = False,
        concurrency: int = 64,
        status_code_mappings: dict = {},
    ):
        from openai import AsyncOpenAI

        assert mode in ['none', 'front', 'mid', 'rear']
        self.mode = mode
        # Models that talk to the same endpoint share one _APIModelState, so
        # the shared state is stored and looked up under `state_key`.
        state_key = self._get_state_key(api_base=openai_api_base, model_name=path)
        if state_key not in AsyncOpenAISDK.states:
            AsyncOpenAISDK.states[state_key] = _APIModelState(
                name=state_key,
                concurrency=concurrency,
                query_per_second=query_per_second,
            )
        self.state = AsyncOpenAISDK.states[state_key]
        self.openai_client = AsyncOpenAI(base_url=openai_api_base, api_key=key)

        if max_seq_len is None:
            if '16k' in path:
                max_seq_len = 16384
            elif 'gpt-4' in path:
                max_seq_len = 8192
            elif 'gpt-3.5' in path:
                max_seq_len = 4097
            else:
                max_seq_len = 32768

        super().__init__(path=path, max_seq_len=max_seq_len, meta_template=meta_template, retry=retry)

        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.tokenizer_path = tokenizer_path
        self.hf_tokenizer = None
        self.extra_body = extra_body
        self.max_completion_tokens = max_completion_tokens
        self.temperature = temperature
        self.openai_api_base = openai_api_base
        self.concurrency = concurrency

        self.status_code_mappings = status_code_mappings

        if openai_proxy_url == 'ENV':
            if 'OPENAI_PROXY_URL' not in os.environ:
                raise ValueError('OPENAI_PROXY_URL is not set.')
            self.proxy_url = os.getenv('OPENAI_PROXY_URL')
        else:
            self.proxy_url = openai_proxy_url

    async def generate(self,  # type: ignore
                       inputs: Iterable[PromptType],
                       max_out_len: int = 512,
                       temperature: float = 0.7,
                       **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 0.7.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.temperature is not None:
            temperature = self.temperature

        # TODO: This should be an AsyncGenerator if a real `AsyncInference`
        # has been implemented
        tasks_queue: List[asyncio.Future] = []
        results_queue: List[Tuple[int, str]] = []
        inputs_iter = enumerate(inputs)

        # Keep at most `concurrency` requests in flight; each finished task
        # carries its original index so results can be re-ordered at the end.
        data_stop = False
        while not (data_stop and not tasks_queue):
            concurrency = self.state.concurrency

            if tasks_queue:
                done, pending = await asyncio.wait(tasks_queue, return_when=asyncio.FIRST_COMPLETED)
                tasks_queue = list(pending)
                for queue in done:
                    result: Tuple[int, str] = queue.result()
                    results_queue.append(result)

            while not data_stop and len(tasks_queue) < concurrency:
                try:
                    index, _input = next(inputs_iter)
                except StopIteration:
                    data_stop = True
                    break
                tasks_queue.append(
                    asyncio.create_task(
                        self._generate(
                            input=_input,
                            max_out_len=self.max_completion_tokens or max_out_len,
                            temperature=temperature,
                            index=index,
                        )
                    )
                )
        results_queue.sort()
        return [item[1] for item in results_queue]

    async def generate_from_template(self, templates: List[PromptType],  # type: ignore
                                     max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
        """
        inputs = self.parse_template(templates, mode='gen')  # type: ignore
        return await self.generate(inputs, max_out_len=max_out_len, **kwargs)

    async def _generate(self, input: PromptList | str, max_out_len: int,
                        temperature: float, index: int) -> Tuple[int, str]:
        from openai import APIStatusError, BadRequestError
        assert isinstance(input, (str, PromptList))

        # max num token for gpt-3.5-turbo is 4097
        # Most models' token limits are above 32k

        # will leave 100 tokens as prompt buffer, triggered if input is str
        if isinstance(input, str) and self.mode != 'none':
            context_window = self.max_seq_len
            input = self.bin_trim(
                input,
                context_window - 100 - max_out_len,
                cast(Literal['front', 'mid', 'rear'], self.mode),
            )

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        # Hold out 100 tokens due to potential errors in tiktoken calculation
        # try:
        #     max_out_len = min(
        #         max_out_len,
        #         context_window - self.get_token_len(str(input)) - 100)
        # except KeyError:
        #     max_out_len = max_out_len
        # if max_out_len <= 0:
        #     return ''

        num_retries = 0
        while num_retries < self.retry:
            await self.state.acquire()

            query_data = dict(
                model=self.path,
                max_tokens=max_out_len,
                n=1,
                temperature=temperature,
                messages=messages,
                extra_body=self.extra_body,
                timeout=600,
            )

            try:
                if self.verbose:
                    self.logger.info('Start calling OpenAI API')
                responses = await self.openai_client.chat.completions.create(**query_data)

                if self.verbose:
                    self.logger.info(
                        'Successfully get response from OpenAI API')
                    try:
                        self.logger.info(responses)
                    except Exception as e:  # noqa F841
                        pass
                if not responses.choices:
                    self.logger.error(
                        'Response is empty, it is an internal server error \
                            from the API provider.')
                return index, responses.choices[0].message.content

            except (BadRequestError, APIStatusError) as e:
                # Handle BadRequest status
                # You can specify self.status_code_mappings to bypass
                # API sensitivity blocks
                # For example: status_code_mappings={400: 'Input data
                # may contain inappropriate content.'}
                status_code = e.status_code
                if (status_code is not None
                        and status_code in self.status_code_mappings):
                    error_message = self.status_code_mappings[status_code]
                    self.logger.info(f'Status Code: {status_code},\n'
                                     f'Original Error Message: {e},\n'
                                     f'Return Message: {error_message} ')
                    return index, error_message
                else:
                    self.logger.warning(f"Failed to get response for {e}, retry {num_retries}/{self.retry}")
            except Exception as e:
                self.logger.warning(f"Failed to get response for {e}, retry {num_retries}/{self.retry}")
            num_retries += 1
        raise RuntimeError('Calling OpenAI API failed after retrying for '
                           f'{self.retry} times. Check the logs for details.')

    def _get_state_key(self, api_base: str, model_name: str):
        return api_base + model_name

    def bin_trim(self, prompt: str, num_token: int, mode: Literal['front', 'mid', 'rear']) -> str:
        """Get a suffix of prompt which is no longer than num_token tokens.

        Args:
            prompt (str): Input string.
            num_token (int): The upper bound of token numbers.

        Returns:
            str: The trimmed prompt.
        """
        token_len = self.get_token_len(prompt)
        if token_len <= num_token:
            return prompt
        pattern = re.compile(r'[\u4e00-\u9fa5]')
        if pattern.search(prompt):
            words = list(jieba.cut(prompt, cut_all=False))
            sep = ''
        else:
            words = prompt.split(' ')
            sep = ' '

        # Binary search for the largest chunk that still fits in num_token.
        l, r = 1, len(words)
        while l + 2 < r:
            mid = (l + r) // 2
            if mode == 'front':
                cur_prompt = sep.join(words[-mid:])
            elif mode == 'mid':
                cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
            elif mode == 'rear':
                cur_prompt = sep.join(words[:mid])

            if self.get_token_len(cur_prompt) <= num_token:
                l = mid  # noqa: E741
            else:
                r = mid

        if self.mode == 'front':
            prompt = sep.join(words[-l:])
        elif self.mode == 'mid':
            prompt = sep.join(words[:l]) + sep.join(words[-l:])
        elif self.mode == 'rear':
            prompt = sep.join(words[:l])
        return prompt
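For reference, a minimal standalone driver for the new class might look like the sketch below. It is illustrative only and not part of this commit: the API key and model name are placeholders, and it assumes an OpenAI-compatible endpoint reachable through OPENAI_BASE_URL.

# Illustrative sketch (not part of this commit): drive AsyncOpenAISDK.generate
# directly under asyncio. The key and model name below are placeholders.
import asyncio

from opencompass.models.async_openai_api import AsyncOpenAISDK


async def main():
    model = AsyncOpenAISDK(
        path='gpt-3.5-turbo',   # placeholder model name
        key='sk-placeholder',   # placeholder API key
        query_per_second=2,
        concurrency=8,
    )
    # generate() keeps up to `concurrency` requests in flight and returns
    # the outputs in the same order as the inputs.
    outputs = await model.generate(
        ['Say hello.', 'Name three prime numbers.'],
        max_out_len=64,
    )
    for out in outputs:
        print(out)


if __name__ == '__main__':
    asyncio.run(main())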
opencompass/models/base_api.py
@@ -4,6 +4,7 @@ import threading
 import time
 import warnings
 from abc import abstractmethod
+from collections import deque
 from copy import deepcopy
 from queue import Queue
 from time import sleep
@@ -12,6 +13,8 @@ from typing import Dict, List, Optional, Tuple, Union
 from opencompass.utils import get_logger
 from opencompass.utils.prompt import PromptList
 
+import asyncio
+
 from .base import BaseModel
 
 PromptType = Union[PromptList, str]
@@ -51,7 +54,7 @@ class BaseAPIModel(BaseModel):
         self.retry = retry
         self.query_per_second = query_per_second
         self.token_bucket = TokenBucket(query_per_second, rpm_verbose)
-        self.template_parser = APITemplateParser(meta_template)
+        self.template_parser = APITemplateParser(meta_template)  # type: ignore
         self.logger = get_logger()
         self.generation_kwargs = generation_kwargs
         self.verbose = verbose
@@ -459,4 +462,69 @@ class TokenBucket:
                 else:
                     break
             self._request_queue.put(cur_time)
-            self.logger.info(f'Current RPM {self._request_queue.qsize()}.')
+            self.logger.info(f"Current RPM {self._request_queue.qsize()}.")
+
+
+class AsyncTokenBucket:
+    """Coroutine-friendly token bucket: tokens accrue at `rate` per second,
+    with a burst capacity of one minute's worth of tokens."""
+
+    def __init__(self, rate: int = 1):
+        self._rate = rate
+        self._max_tokens = rate * 60
+        self._tokens: float = float(self._max_tokens)
+        self._last_refill_time: float | None = None
+
+        # Timestamps of recent acquisitions, kept only for RPM reporting.
+        self._request_timestamps: deque[float] = deque()
+        self._max_window_size = 60
+
+        self._token_available = asyncio.Event()
+
+    async def _release(self) -> None:
+        if self._last_refill_time is None:
+            self._last_refill_time: float = time.monotonic()
+
+        now = time.monotonic()
+        elapsed = now - self._last_refill_time
+        tokens_to_add = elapsed * self._rate
+
+        self._tokens = min(self._max_tokens, self._tokens + tokens_to_add)
+        self._last_refill_time = now
+
+    async def acquire(self) -> bool:
+        while True:
+            await self._release()
+
+            if self._tokens >= 1:
+                self._tokens -= 1
+
+                now = time.monotonic()
+                self._request_timestamps.append(now)
+
+                while (
+                    self._request_timestamps
+                    and now - self._request_timestamps[0] > self._max_window_size
+                ):
+                    self._request_timestamps.popleft()
+
+                self._token_available.set()
+                return True
+
+            self._token_available.clear()
+            # Wait for another acquirer to signal, but wake up after roughly
+            # one refill interval anyway so waiters are not stranded when
+            # every consumer is blocked here at the same time.
+            try:
+                await asyncio.wait_for(self._token_available.wait(),
+                                       timeout=1 / self._rate)
+            except asyncio.TimeoutError:
+                pass
+
+    @property
+    def rpm(self) -> int:
+        now = time.monotonic()
+
+        while (
+            self._request_timestamps
+            and now - self._request_timestamps[0] > self._max_window_size
+        ):
+            self._request_timestamps.popleft()
+
+        return len(self._request_timestamps)
+
+    @property
+    def available_tokens(self) -> float:
+        return self._tokens
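A minimal sketch of how the new AsyncTokenBucket behaves under asyncio; the rate and number of workers below are arbitrary example values.

# Illustrative sketch (not part of this commit): exercise AsyncTokenBucket.
import asyncio

from opencompass.models.base_api import AsyncTokenBucket


async def worker(bucket: AsyncTokenBucket, i: int):
    await bucket.acquire()   # suspends until a token has accrued
    print(f'request {i} admitted, rpm so far: {bucket.rpm}')


async def main():
    bucket = AsyncTokenBucket(rate=2)   # ~2 tokens/second, burst of up to 120
    await asyncio.gather(*(worker(bucket, i) for i in range(5)))


asyncio.run(main())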
opencompass/openicl/icl_inferencer/icl_chat_async_inferencer.py (new file, 397 lines)
@@ -0,0 +1,397 @@
"""Chat Inferencer."""
import os
import os.path as osp
from typing import List, Optional, Union

import mmengine
from mmengine import is_list_of
from tqdm import tqdm

from opencompass.models import APITemplateParser as _APITemplateParser
from opencompass.models import BaseModel
from opencompass.models import LMTemplateParser as _LMTemplateParser
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils.prompt import PromptList

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict

logger = get_logger(__name__)


def promptlist_to_openai(prompt: Union[str, PromptList]):
    output = []
    if isinstance(prompt, str):
        return [dict(role='user', content=prompt)]

    for item in prompt:
        if 'section' in item:
            continue
        if isinstance(item, str) and item:
            output.append(dict(role='user', content=item))
        elif item['role'] == 'SYSTEM':
            output.append(dict(role='system', content=item['prompt']))
        elif item['role'] == 'HUMAN':
            output.append(dict(role='user', content=item['prompt']))
        elif item['role'] == 'BOT':
            output.append(dict(role='assistant', content=item['prompt']))
    return output


class LMTemplateParser:
    """LMTemplateParser accepts OpenAI format dialog inputs."""

    def __init__(self, meta_template: Optional[dict] = None):
        self.meta_template = meta_template
        self.roles = {}
        role_mapping = {
            'SYSTEM': 'system',
            'HUMAN': 'user',
            'BOT': 'assistant',
        }
        if meta_template:
            for item in meta_template.get('round', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
            for item in meta_template.get('reserved_roles', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()

    def parse_template(self, chat: List[dict], mode='gen') -> str:
        if is_list_of(chat, list):
            # Handle batch inputs
            return [self.parse_template(item) for item in chat]

        assert is_list_of(chat, dict)
        prompt = ''
        if self.roles:
            for dialog in chat:
                role_cfg = self.roles.get(dialog['role'], {})
                prompt += (role_cfg.get('begin') or '')
                prompt += (dialog.get('content') or '')
                prompt += (role_cfg.get('end') or '')
            prompt += (self.roles['assistant'].get('begin') or '')
        else:
            # in case the model does not have any meta template
            last_sep = ''
            for item in chat:
                prompt += last_sep + (item.get('content') or '')
                last_sep = '\n'
        return prompt


class APITemplateParser:
    """APITemplateParser accepts OpenAI format dialog inputs."""

    def __init__(self, meta_template: Optional[dict] = None):
        self.meta_template = meta_template
        self.roles = {}
        role_mapping = {
            'SYSTEM': 'system',
            'HUMAN': 'user',
            'BOT': 'assistant',
        }
        if meta_template:
            for item in meta_template.get('round', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
            for item in meta_template.get('reserved_roles', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
        else:
            self.roles = dict(
                system=dict(api_role='SYSTEM'),
                user=dict(api_role='HUMAN'),
                assistant=dict(api_role='BOT', generate=True),
            )

    def parse_template(self, chat: List[dict], mode='gen') -> str:
        if is_list_of(chat, list):
            # Handle batch inputs
            return [self.parse_template(item) for item in chat]

        assert is_list_of(chat, dict)
        prompt = []
        for dialog in chat:
            if dialog['role'] in self.roles:
                role = self.roles[dialog['role']]['api_role']
            else:
                role = dialog['role']
            prompt.append(dict(role=role, prompt=dialog.get('content') or ''))
        return PromptList(prompt)


class ChatOutputHandler:

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, osp.join(save_dir, filename))

    def save_results(self,
                     origin_prompt: list,
                     prediction: str,
                     idx: int,
                     gold: str = None):
        result_dict = {}
        if gold:
            result_dict['gold'] = gold
        result_dict.update({
            'prediction': prediction,
            'origin_prompt': origin_prompt,
        })
        self.results_dict[str(idx)] = result_dict

    def save_multiround_results(self,
                                origin_prompt: list,
                                prediction: str,
                                idx: int,
                                gold: str = None):
        result_dict = self.results_dict.get(str(idx), {
            'gold': [],
            'prediction': [],
            'origin_prompt': [],
        })
        result_dict['gold'].append(gold)
        result_dict['prediction'].append(prediction)
        result_dict['origin_prompt'].append(origin_prompt)
        self.results_dict[str(idx)] = result_dict


@ICL_INFERENCERS.register_module()
class AsyncChatInferencer(BaseInferencer):
    HandlerType = ChatOutputHandler

    def __init__(
            self,
            model,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            infer_mode: str = 'last',
            max_out_len: int = 512,
            **kwargs) -> None:
        super().__init__(
            model=model,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )
        assert infer_mode in ['last', 'every', 'every_with_gt']
        self.infer_mode = infer_mode
        self.model: BaseModel
        self._set_meta_template(self.model)

        if self.model.is_api and save_every is None:
            save_every = 1
        self.save_every = save_every
        self.dialogue_mode = False
        self.max_out_len = max_out_len

    def _set_meta_template(self, model):
        origin = model.template_parser
        if isinstance(origin, _APITemplateParser):
            model.template_parser = APITemplateParser(origin.meta_template)
        if isinstance(origin, _LMTemplateParser):
            model.template_parser = LMTemplateParser(origin.meta_template)

    async def inference(self,  # type: ignore
                        retriever: BaseRetriever,
                        ice_template: Optional[PromptTemplate] = None,
                        prompt_template: Optional[PromptTemplate] = None,
                        output_json_filepath: Optional[str] = None,
                        output_json_filename: Optional[str] = None) -> dict:
        # 1. Preparation for output logs
        output_handler = self.HandlerType()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        chat_list = self.get_chat_list(
            ice_idx_list,
            retriever,
            prompt_template=prompt_template,
        )

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if osp.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        dataloader = self.get_dataloader(chat_list[index:], batch_size=1)

        # 5. Inference for prompts in each batch
        logger.debug('Starting inference process...')
        for datum in tqdm(dataloader, disable=not self.is_main_process):
            chat = datum[0]
            if self.infer_mode == 'last':
                await self.infer_last(chat, index, output_handler)
            elif self.infer_mode == 'every':
                await self.infer_every(chat, index, output_handler)
            elif self.infer_mode == 'every_with_gt':
                await self.infer_every_with_gt(chat, index, output_handler)
            index += 1

            # Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if osp.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return output_handler.results_dict

    def get_chat_list(self,
                      ice_idx_list: List[List[int]],
                      retriever: BaseRetriever,
                      prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        input_columns = retriever.dataset_reader.input_columns
        output_column = retriever.dataset_reader.output_column

        def chat_from_entry(entry):
            if prompt_template is None and len(input_columns) == 1:
                # Directly use the input column as the user input
                user = entry.get(input_columns[0])
                assistant = entry.get(output_column, '')
                return [
                    dict(role='user', content=user),
                    dict(role='assistant', content=assistant),
                ]
            elif prompt_template is not None:
                # Use prompt template to generate chat history
                chat = promptlist_to_openai(
                    prompt_template.generate_item(entry))
                gold = entry.get(output_column, '')
                if chat[-1]['role'] != 'assistant':
                    chat.append(dict(role='assistant', content=gold))
                return chat
            else:
                raise ValueError()

        for idx, ice_idx in enumerate(ice_idx_list):
            # NOTE: The in-context examples won't be used by now.

            item = {
                k: v
                for k, v in retriever.test_ds[idx].items()
                if k in input_columns or k == output_column
            }
            if all(isinstance(value, str) for value in item.values()):
                # Every column is a single string
                chat = chat_from_entry(item)
            elif all(is_list_of(value, str) for value in item.values()):
                # Every column is a list of string for multi-round chat
                entries = [dict(zip(item, v)) for v in zip(*item.values())]
                chat = sum((chat_from_entry(entry) for entry in entries), [])
            elif len(input_columns) == 1 and is_list_of(
                    item[input_columns[0]], dict):
                # Single input column and it's already a chat.
                chat = item[input_columns[0]]
            elif 'dialogue' in input_columns:
                chat = item['dialogue']
                self.dialogue_mode = True
            else:
                raise ValueError('Cannot construct chat from the dataset.')

            prompt_list.append(chat)
        return prompt_list

    async def infer_last(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        history = chat[:assistant_indices[-1]]
        output = (await self.model.generate_from_template(
            [history], max_out_len=self.max_out_len))[0]
        output_handler.save_results(
            origin_prompt=history,
            prediction=output,
            idx=index,
            gold=chat[assistant_indices[-1]]['content'],
        )

    async def infer_every(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]
        index_copy = index

        for i in assistant_indices:
            history = chat[:i]
            output = (await self.model.generate_from_template(
                [history], max_out_len=self.max_out_len))[0]
            chat[i]['content'] = output
            if not self.dialogue_mode:
                output_handler.save_multiround_results(
                    origin_prompt=history[-1]['content'],
                    prediction=output,
                    idx=index,
                    gold=chat[i]['content'],
                )
                # index += 1
        if self.dialogue_mode:
            # dialogue mode for subjective evaluation
            assert len(chat) % 2 == 0
            round_num = int(len(chat) / 2)
            preds_list = []
            for i in range(round_num):
                temp_dict = {
                    'round': i + 1,
                    'user': chat[i * 2]['content'],
                    'assistant': chat[i * 2 + 1]['content']
                }
                preds_list.append(temp_dict)
            output_handler.save_results(
                origin_prompt=None,
                prediction=preds_list,
                idx=index_copy,
                gold=None,
            )

    async def infer_every_with_gt(self, chat: List[dict], index: int,
                                  output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        for i in assistant_indices:
            history = chat[:i]
            output = (await self.model.generate_from_template(
                [history], max_out_len=self.max_out_len))[0]
            output_handler.save_multiround_results(
                origin_prompt=history[-1]['content'],
                prediction=output,
                idx=index,
                gold=chat[i]['content'],
            )
            index += 1
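As a quick illustration of the role mapping performed by promptlist_to_openai (OpenCompass SYSTEM/HUMAN/BOT tags to OpenAI-style messages):

# Illustrative sketch (not part of this commit).
from opencompass.openicl.icl_inferencer.icl_chat_async_inferencer import (
    promptlist_to_openai)

prompt = [
    {'role': 'SYSTEM', 'prompt': 'You are a helpful assistant.'},
    {'role': 'HUMAN', 'prompt': 'What is 2 + 2?'},
    {'role': 'BOT', 'prompt': ''},
]
print(promptlist_to_openai(prompt))
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'What is 2 + 2?'},
#  {'role': 'assistant', 'content': ''}]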
opencompass/openicl/icl_inferencer/icl_gen_async_inferencer.py (new file, 239 lines)
@@ -0,0 +1,239 @@
"""Direct Generation Inferencer."""

import inspect
import json
import os
import os.path as osp
import time
from typing import List, Optional

import mmengine
import torch
from tqdm import tqdm

from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils import batched

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class AsyncGenInferencer(BaseInferencer):
    """Generation Inferencer class to directly evaluate by generation.

    Attributes:
        model (:obj:`BaseModelWrapper`, optional): The module to inference.
        max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
            allowed by the LM.
        min_out_len (:obj:`int`, optional): Minimum number of generated tokens
            by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        gen_field_replace_token (:obj:`str`, optional): Used to replace the
            generation field token when generating prompts.
        save_every (:obj:`int`, optional): Save intermediate results every
            `save_every` iters. Defaults to 1.
        generation_kwargs (:obj:`Dict`, optional): Parameters for the
            :obj:`model.generate()` method.
    """

    def __init__(
            self,
            model: BaseModel,
            max_out_len: int,
            stopping_criteria: List[str] = [],
            max_seq_len: Optional[int] = None,
            min_out_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            gen_field_replace_token: Optional[str] = '',
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        self.gen_field_replace_token = gen_field_replace_token
        self.max_out_len = max_out_len
        self.min_out_len = min_out_len
        self.stopping_criteria = stopping_criteria
        self.dump_timer = kwargs.get('dump_timer', False)

        if self.model.is_api and save_every is None:
            save_every = 1
        self.save_every = save_every

    async def inference(self,  # type: ignore
                        retriever: BaseRetriever,
                        ice_template: Optional[PromptTemplate] = None,
                        prompt_template: Optional[PromptTemplate] = None,
                        output_json_filepath: Optional[str] = None,
                        output_json_filename: Optional[str] = None) -> List:
        # 1. Preparation for output logs
        output_handler = GenInferencerOutputHandler()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        prompt_list = self.get_generation_prompt_list_from_retriever_indices(
            ice_idx_list,
            retriever,
            self.gen_field_replace_token,
            max_seq_len=self.max_seq_len,
            ice_template=ice_template,
            prompt_template=prompt_template)

        # 3.1 Fetch and zip prompt & gold answer if output column exists
        ds_reader = retriever.dataset_reader
        if ds_reader.output_column:
            gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
            prompt_list = list(zip(prompt_list, gold_ans))

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if osp.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        logger.debug('Starting build dataloader')
        dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)

        # 5. Inference for prompts in each batch
        logger.debug('Starting inference process...')

        start_time_stamp = time.time()
        num_sample = 0
        # TODO: batched dataloader should be replaced with async fetching
        for datum in dataloader:
            if ds_reader.output_column:
                entry, golds = list(zip(*datum))
            else:
                entry = datum
                golds = [None for _ in range(len(entry))]
            # 5-1. Inference with local model
            extra_gen_kwargs = {}
            sig = inspect.signature(self.model.generate)
            if 'stopping_criteria' in sig.parameters:
                extra_gen_kwargs['stopping_criteria'] = self.stopping_criteria
            if 'min_out_len' in sig.parameters:
                extra_gen_kwargs['min_out_len'] = self.min_out_len
            with torch.no_grad():
                parsed_entries = self.model.parse_template(entry, mode='gen')
                results = await self.model.generate_from_template(
                    entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
                generated = results

            num_return_sequences = getattr(self.model, 'generation_kwargs',
                                           {}).get('num_return_sequences', 1)
            # 5-3. Save current output
            for prompt, prediction, gold in zip(
                    parsed_entries, batched(generated, num_return_sequences),
                    golds):
                if num_return_sequences == 1:
                    prediction = prediction[0]
                output_handler.save_results(prompt,
                                            prediction,
                                            index,
                                            gold=gold)
                index = index + 1

            # 5-4. Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)
            num_sample += len(datum)

        end_time_stamp = time.time()

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if osp.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        if self.dump_timer and self.is_main_process:
            timer_filepath = os.path.join(output_json_filepath, 'timer',
                                          'time.jsonl')
            os.makedirs(os.path.dirname(timer_filepath), exist_ok=True)
            time_dict = {
                'dataset_name': output_json_filename.removesuffix('.json'),
                'time': end_time_stamp - start_time_stamp,
                'num_sample': num_sample
            }
            with open(timer_filepath, 'a') as f:
                f.write(json.dumps(time_dict) + '\n')

        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]

    def get_generation_prompt_list_from_retriever_indices(
            self,
            ice_idx_list: List[List[int]],
            retriever: BaseRetriever,
            gen_field_replace_token: str,
            max_seq_len: Optional[int] = None,
            ice_template: Optional[PromptTemplate] = None,
            prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=gen_field_replace_token,
                ice_template=ice_template,
                prompt_template=prompt_template)
            if max_seq_len is not None:
                prompt_token_num = self.model.get_token_len_from_template(
                    prompt, mode='gen')
                # Drop in-context examples one by one until the prompt fits.
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt = retriever.generate_prompt_for_generate_task(
                        idx,
                        ice,
                        gen_field_replace_token=gen_field_replace_token,
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = self.model.get_token_len_from_template(
                        prompt, mode='gen')
            prompt_list.append(prompt)
        return prompt_list
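The num_return_sequences handling above regroups a flat list of generations into per-prompt chunks before saving. The sketch below shows the regrouping idea with a local helper that is assumed to behave like opencompass.utils.batched (chunking an iterable into fixed-size tuples):

# Illustrative sketch (not part of this commit): regroup flat generations
# per prompt, mirroring the num_return_sequences handling above.
from itertools import islice


def batched(iterable, n):
    # Assumed to match opencompass.utils.batched: yield successive n-sized tuples.
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk


generated = ['a1', 'a2', 'b1', 'b2']   # two prompts, num_return_sequences=2
prompts = ['prompt A', 'prompt B']

for prompt, prediction in zip(prompts, batched(generated, 2)):
    print(prompt, '->', list(prediction))
# prompt A -> ['a1', 'a2']
# prompt B -> ['b1', 'b2']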
opencompass/runners/local_async.py (new file, 112 lines)
@@ -0,0 +1,112 @@
from math import prod
import os
import os.path as osp
import re
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from threading import Lock
from typing import Any, Dict, List, Tuple

import mmengine
import numpy as np
from mmengine.config import ConfigDict
from mmengine.device import is_npu_available
from tqdm import tqdm

from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger, model_abbr_from_cfg

from .base import BaseRunner
from typing import TypedDict, Optional
from multiprocessing.managers import Namespace
import threading
import uuid
import enum
import signal
from enum import IntEnum
import asyncio
import traceback


class Status(IntEnum):
    SUCCESS = 0
    FAILED = -1
    INTERRUPT = signal.SIGINT


@RUNNERS.register_module()
class AsyncRunner(BaseRunner):
    """Local runner. Start tasks by local python.

    Args:
        task (ConfigDict): Task type config.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 16.
        max_workers_per_gpu (int): Max number of workers to run for one GPU.
            Defaults to 1.
        debug (bool): Whether to run in debug mode.
        lark_bot_url (str): Lark bot url.
    """

    # This is a fake type hint

    def __init__(self,
                 task: ConfigDict,
                 debug: bool = False,
                 *,
                 max_num_workers: int = 16,
                 keep_tmp_file: bool = False,
                 **kwargs):
        super().__init__(task=task, debug=debug)
        self.max_num_workers = max_num_workers
        self.keep_tmp_file = keep_tmp_file
        logger = get_logger()
        for k, v in kwargs.items():
            logger.warning(f'Ignored argument in `AsyncRunner`: {k}={v}')

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, Status]]:  # type: ignore
        """Launch multiple tasks.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """
        from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask

        if not tasks:
            return [("", Status.SUCCESS)]

        assert len(tasks) == 1, "Task num must be 1 for `AsyncRunner`"
        task_cfg = tasks[0]

        task: OpenICLAsyncInferTask = TASKS.build(dict(cfg=task_cfg, type=self.task_cfg['type']))
        task_name = task.name
        # get cmd
        mmengine.mkdir_or_exist('tmp/')

        try:
            asyncio.run(task.run())
        except KeyboardInterrupt:
            return [(task_name, Status.INTERRUPT)]
        except Exception:
            traceback.print_exc()
            return [(task_name, Status.FAILED)]
        else:
            return [(task_name, Status.SUCCESS)]

    def __call__(self, tasks: List[Dict[str, Any]]):
        """Launch multiple tasks and summarize the results.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.
        """
        status = self.launch(tasks)
        status_list = list(status)  # change into list format
        self.summarize(status_list)
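A hypothetical config fragment showing how the new runner and task could be wired together; the infer-block layout follows the usual OpenCompass convention and is an assumption rather than something defined by this commit. Since AsyncRunner asserts that it receives exactly one task, the partitioner must emit a single task per launch.

# Illustrative sketch (hypothetical config fragment, not part of this commit).
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_async import AsyncRunner
from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=AsyncRunner,
        max_num_workers=16,
        task=dict(type=OpenICLAsyncInferTask),
    ),
)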
opencompass/tasks/__init__.py
@@ -1,3 +1,4 @@
 from .openicl_attack import *  # noqa: F401, F403
 from .openicl_eval import *  # noqa: F401, F403
 from .openicl_infer import *  # noqa: F401, F403
+from .openicl_async_task import *  # noqa: F401, F403
opencompass/tasks/openicl_async_task.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import argparse
import os
import os.path as osp
import random
import sys
import time
from typing import Any
from tqdm.asyncio import tqdm

from mmengine.config import Config, ConfigDict
import inspect
from mmengine.utils import mkdir_or_exist

from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
                                  ICL_RETRIEVERS, TASKS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
                               get_infer_output_path, get_logger,
                               task_abbr_from_cfg)
from opencompass.openicl.icl_inferencer.icl_gen_async_inferencer import AsyncGenInferencer
from opencompass.openicl.icl_inferencer.icl_chat_async_inferencer import AsyncChatInferencer
from opencompass.openicl.icl_inferencer import GenInferencer, ChatInferencer
from concurrent.futures import ThreadPoolExecutor
import asyncio
import resource
from more_itertools import consume


# Raise the soft open-file limit, since many API connections may be open at once.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))


@TASKS.register_module()
class OpenICLAsyncInferTask(BaseTask):
    """OpenICL Inference Task.

    This task is used to run the inference process.
    """

    name_prefix = 'OpenICLInfer'
    log_subdir = 'logs/infer'
    output_subdir = 'predictions'

    def __init__(self, cfg: ConfigDict):
        super().__init__(cfg)
        run_cfg = self.model_cfgs[0].get('run_cfg', {})
        self.nproc = run_cfg.get('nproc_per_worker', 16)

    def get_command(self, cfg_path, template) -> str:
        # TODO:
        raise NotImplementedError()
        return ""

    async def run(self):  # type: ignore
        _dataset_cfgs = []
        infer_cfgs = []
        sub_cfgs = []
        datasets = []
        model_cfgs = []
        for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
            self.max_out_len = model_cfg.get('max_out_len', None)
            self.batch_size = model_cfg.get('batch_size', None)
            self.min_out_len = model_cfg.get('min_out_len', None)

            for dataset_cfg in dataset_cfgs:
                self.dataset_cfg = dataset_cfg
                out_path = get_infer_output_path(
                    model_cfg, dataset_cfg,
                    osp.join(self.work_dir, 'predictions'))

                if osp.exists(out_path):
                    continue
                _dataset_cfgs.append(dataset_cfg)
                datasets.append(build_dataset_from_cfg(dataset_cfg))
                infer_cfgs.append(dataset_cfg['infer_cfg'])
                model_cfgs.append(model_cfg)
                sub_cfg = {
                    'models': [model_cfg],
                    'datasets': [[dataset_cfg]],
                }
                sub_cfgs.append(sub_cfg)

        tasks = []
        args = list(zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs))
        for arg in tqdm(
                args,
                total=len(args),
                desc='Starting to build tasks...'
        ):
            tasks.append(asyncio.create_task(self._inference(*arg)))

        bar = tqdm(desc="Inferencing...", total=len(tasks))
        bar.refresh()

        while tasks:
            done, tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for _ in done:
                bar.update()
            bar.refresh()

        # TODO: Needs a debug mode
        # for arg in zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs):
        #     await self._inference(*arg)

    async def _inference(self, dataset_cfg, infer_cfg, dataset, model_cfg, sub_cfg):
        model = build_model_from_cfg(model_cfg)
        assert hasattr(infer_cfg, 'ice_template') or hasattr(infer_cfg, 'prompt_template'), \
            'Both ice_template and prompt_template cannot be None simultaneously.'  # noqa: E501

        infer_kwargs: dict = {}
        if hasattr(infer_cfg, 'ice_template'):
            ice_template = ICL_PROMPT_TEMPLATES.build(
                infer_cfg['ice_template'])
            infer_kwargs['ice_template'] = ice_template

        if hasattr(infer_cfg, 'prompt_template'):
            prompt_template = ICL_PROMPT_TEMPLATES.build(
                infer_cfg['prompt_template'])
            infer_kwargs['prompt_template'] = prompt_template

        retriever_cfg = infer_cfg['retriever'].copy()
        retriever_cfg['dataset'] = dataset
        retriever = ICL_RETRIEVERS.build(retriever_cfg)

        # set inferencer's default value according to model's config
        inferencer_cfg: dict = infer_cfg['inferencer']
        inferencer_cfg['model'] = model
        inferencer_cfg['max_seq_len'] = model_cfg.get('max_seq_len')

        # Swap the synchronous inferencer types for their async counterparts.
        infer_type = inferencer_cfg["type"]
        if inspect.isclass(infer_type):
            infer_name = infer_type.__name__
        else:
            infer_name = infer_type

        if infer_name.split(".")[-1] == "ChatInferencer":
            inferencer_cfg["type"] = AsyncChatInferencer

        elif infer_name.split(".")[-1] == "GenInferencer":
            inferencer_cfg["type"] = AsyncGenInferencer

        inferencer_cfg.setdefault('max_out_len', self.max_out_len)
        inferencer_cfg.setdefault('min_out_len', self.min_out_len)
        inferencer_cfg.setdefault('batch_size', self.batch_size)
        inferencer = ICL_INFERENCERS.build(inferencer_cfg)

        out_path = get_infer_output_path(
            model_cfg, dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        out_dir, out_file = osp.split(out_path)
        mkdir_or_exist(out_dir)

        infer_kwargs['output_json_filepath'] = out_dir
        infer_kwargs['output_json_filename'] = out_file

        await inferencer.inference(retriever, **infer_kwargs)


def parse_args():
    parser = argparse.ArgumentParser(description='Model Inferencer')
    parser.add_argument('config', help='Config file path')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # TODO:
    raise NotImplementedError()