import argparse
import asyncio
import inspect
import os
import os.path as osp
import random
import resource
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist
from more_itertools import consume
from tqdm.asyncio import tqdm

from opencompass.openicl.icl_inferencer import ChatInferencer, GenInferencer
from opencompass.openicl.icl_inferencer.icl_chat_async_inferencer import \
    AsyncChatInferencer
from opencompass.openicl.icl_inferencer.icl_gen_async_inferencer import \
    AsyncGenInferencer
from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
                                  ICL_RETRIEVERS, TASKS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
                               get_infer_output_path, get_logger,
                               task_abbr_from_cfg)

# Async inference can hold many sockets and prediction files open at once, so
# raise the soft file-descriptor limit up front. Cap the request at the hard
# limit: asking for more than `hard` makes setrlimit raise ValueError.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
new_soft = 8192 if hard == resource.RLIM_INFINITY else min(8192, hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))


@TASKS.register_module()
class OpenICLAsyncInferTask(BaseTask):
    """OpenICL asynchronous inference task.

    This task runs the inference process for every (model, dataset) pair
    concurrently on a single asyncio event loop.
    """

    name_prefix = 'OpenICLInfer'
    log_subdir = 'logs/infer'
    output_subdir = 'predictions'

    def __init__(self, cfg: ConfigDict):
        super().__init__(cfg)
        run_cfg = self.model_cfgs[0].get('run_cfg', {})
        self.nproc = run_cfg.get('nproc_per_worker', 16)
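        # NOTE: `self.nproc` is read from the first model's `run_cfg` but is
        # not consumed anywhere in this file yet; presumably it will bound
        # per-worker concurrency once the TODOs below are implemented (an
        # assumption, not something the current code enforces).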

    def get_command(self, cfg_path, template) -> str:
        # TODO:
        raise NotImplementedError()

    async def run(self):  # type: ignore
        # Collect one inference job per (model, dataset) pair, skipping any
        # pair whose prediction file already exists on disk.
        _dataset_cfgs = []
        infer_cfgs = []
        sub_cfgs = []
        datasets = []
        model_cfgs = []
        for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
            self.max_out_len = model_cfg.get('max_out_len', None)
            self.batch_size = model_cfg.get('batch_size', None)
            self.min_out_len = model_cfg.get('min_out_len', None)

            for dataset_cfg in dataset_cfgs:
                self.dataset_cfg = dataset_cfg
                out_path = get_infer_output_path(
                    model_cfg, dataset_cfg,
                    osp.join(self.work_dir, 'predictions'))

                if osp.exists(out_path):
                    continue
                _dataset_cfgs.append(dataset_cfg)
                datasets.append(build_dataset_from_cfg(dataset_cfg))
                infer_cfgs.append(dataset_cfg['infer_cfg'])
                model_cfgs.append(model_cfg)
                sub_cfg = {
                    'models': [model_cfg],
                    'datasets': [[dataset_cfg]],
                }
                sub_cfgs.append(sub_cfg)

        tasks = []
        args = list(
            zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs, sub_cfgs))
        for arg in tqdm(args, total=len(args), desc='Building tasks...'):
            tasks.append(asyncio.create_task(self._inference(*arg)))

        bar = tqdm(desc='Inferencing...', total=len(tasks))
        bar.refresh()

        # Await tasks in completion order so the bar advances as each job
        # finishes; `tasks` is rebound to the still-pending set each round.
        while tasks:
            done, tasks = await asyncio.wait(
                tasks, return_when=asyncio.FIRST_COMPLETED)
            for _ in done:
                bar.update()
            bar.refresh()

        # TODO: Needs a debug mode
        # for arg in zip(_dataset_cfgs, infer_cfgs, datasets, model_cfgs,
        #                sub_cfgs):
        #     await self._inference(*arg)
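
        # A sketch of how a debug mode could reuse the sequential loop above
        # (hypothetical: no such flag exists in the current config schema):
        #
        #   if getattr(self, 'debug', False):
        #       for arg in args:
        #           await self._inference(*arg)
        #       return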

    async def _inference(self, dataset_cfg, infer_cfg, dataset, model_cfg,
                         sub_cfg):
        """Build the ICL components for one (model, dataset) pair and run
        the async inferencer on it."""
        model = build_model_from_cfg(model_cfg)
        assert (hasattr(infer_cfg, 'ice_template')
                or hasattr(infer_cfg, 'prompt_template')), \
            'ice_template and prompt_template cannot both be None.'

        infer_kwargs: dict = {}
        if hasattr(infer_cfg, 'ice_template'):
            ice_template = ICL_PROMPT_TEMPLATES.build(
                infer_cfg['ice_template'])
            infer_kwargs['ice_template'] = ice_template

        if hasattr(infer_cfg, 'prompt_template'):
            prompt_template = ICL_PROMPT_TEMPLATES.build(
                infer_cfg['prompt_template'])
            infer_kwargs['prompt_template'] = prompt_template

        retriever_cfg = infer_cfg['retriever'].copy()
        retriever_cfg['dataset'] = dataset
        retriever = ICL_RETRIEVERS.build(retriever_cfg)

        # Set the inferencer's default values according to the model's config.
        inferencer_cfg: dict = infer_cfg['inferencer']
        inferencer_cfg['model'] = model
        inferencer_cfg['max_seq_len'] = model_cfg.get('max_seq_len')

        # Swap the configured synchronous inferencer for its async
        # counterpart; `type` may be either a class or a dotted-path string.
        infer_type = inferencer_cfg['type']
        if inspect.isclass(infer_type):
            infer_name = infer_type.__name__
        else:
            infer_name = infer_type

        if infer_name.split('.')[-1] == 'ChatInferencer':
            inferencer_cfg['type'] = AsyncChatInferencer
        elif infer_name.split('.')[-1] == 'GenInferencer':
            inferencer_cfg['type'] = AsyncGenInferencer

        inferencer_cfg.setdefault('max_out_len', self.max_out_len)
        inferencer_cfg.setdefault('min_out_len', self.min_out_len)
        inferencer_cfg.setdefault('batch_size', self.batch_size)
        inferencer = ICL_INFERENCERS.build(inferencer_cfg)

        out_path = get_infer_output_path(
            model_cfg, dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        out_dir, out_file = osp.split(out_path)
        mkdir_or_exist(out_dir)

        infer_kwargs['output_json_filepath'] = out_dir
        infer_kwargs['output_json_filename'] = out_file

        await inferencer.inference(retriever, **infer_kwargs)
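
        # For reference, a minimal sketch of the `infer_cfg` shape this
        # method consumes. The field names mirror the lookups above; the
        # values are illustrative stand-ins, not a definitive schema:
        #
        #   infer_cfg = dict(
        #       prompt_template=dict(type='PromptTemplate',
        #                            template='{question}'),
        #       retriever=dict(type='ZeroRetriever'),
        #       inferencer=dict(type='GenInferencer'),
        #   )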


def parse_args():
    parser = argparse.ArgumentParser(description='Model Inferencer')
    parser.add_argument('config', help='Config file path')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # TODO:
    raise NotImplementedError()
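    # A minimal sketch of what this entry point might eventually do, modeled
    # on other opencompass task scripts (hypothetical wiring; the real
    # implementation is still TODO):
    #
    #   args = parse_args()
    #   cfg = Config.fromfile(args.config)
    #   task = OpenICLAsyncInferTask(cfg)
    #   asyncio.run(task.run())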