# OpenCompass/opencompass/tasks/openicl_infer.py

import argparse
import os
import os.path as osp
import random
import sys
import time
from typing import Any

from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist
from opencompass.registry import (ICL_INFERENCERS, ICL_PROMPT_TEMPLATES,
ICL_RETRIEVERS, TASKS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
get_infer_output_path, get_logger,
                               model_abbr_from_cfg, task_abbr_from_cfg)


@TASKS.register_module()
class OpenICLInferTask(BaseTask):
"""OpenICL Inference Task.
This task is used to run the inference process.
"""
name_prefix = 'OpenICLInfer'
log_subdir = 'logs/infer'
    output_subdir = 'predictions'

    def __init__(self, cfg: ConfigDict):
super().__init__(cfg)
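        # The first model's ``run_cfg`` declares the resources this task
        # needs: the number of GPUs and the number of processes to launch.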
run_cfg = self.model_cfgs[0].get('run_cfg', {})
self.num_gpus = run_cfg.get('num_gpus', 0)
self.num_procs = run_cfg.get('num_procs', 1)
        self.logger = get_logger()

    def get_command(self, cfg_path, template):
        """Get the command template for the task.

        Args:
            cfg_path (str): The path to the config file of the task.
            template (str): The template, which must contain '{task_cmd}',
                used to format the command.
        """
sys.path.append(os.getcwd())
script_path = __file__
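        # Models served through vLLM or LMDeploy handle multi-GPU execution
        # internally, so they are launched with a plain ``python`` call;
        # anything else that needs more than one GPU goes through
        # ``torchrun`` with a random master port to avoid collisions
        # between concurrent tasks.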
backend_keys = ['VLLM', 'Lmdeploy']
use_backend = any(
key in str(self.model_cfgs[0].get('type', ''))
or key in str(self.model_cfgs[0].get('llm', {}).get('type', ''))
for key in backend_keys)
if self.num_gpus > 1 and not use_backend:
port = random.randint(12000, 32000)
command = (f'torchrun --master_port={port} '
f'--nproc_per_node {self.num_procs} '
f'{script_path} {cfg_path}')
else:
python = sys.executable
command = f'{python} {script_path} {cfg_path}'
        return template.format(task_cmd=command)

    def run(self, cur_model=None, cur_model_abbr=None):
self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
self.max_out_len = model_cfg.get('max_out_len', None)
self.batch_size = model_cfg.get('batch_size', None)
self.min_out_len = model_cfg.get('min_out_len', None)
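            # Reuse the caller-supplied model when its abbreviation matches
            # this config; otherwise build the model from its config.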
if cur_model and cur_model_abbr == model_abbr_from_cfg(model_cfg):
self.model = cur_model
else:
self.model = build_model_from_cfg(model_cfg)
for dataset_cfg in dataset_cfgs:
self.model_cfg = model_cfg
self.dataset_cfg = dataset_cfg
self.infer_cfg = self.dataset_cfg['infer_cfg']
self.dataset = build_dataset_from_cfg(self.dataset_cfg)
self.sub_cfg = {
'models': [self.model_cfg],
'datasets': [[self.dataset_cfg]],
}
out_path = get_infer_output_path(
self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'predictions'))
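                # Predictions that already exist are not recomputed, so an
                # interrupted run can be resumed.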
if osp.exists(out_path):
continue
                self._inference()

    def _inference(self):
self.logger.info(
f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}')
        assert hasattr(self.infer_cfg, 'ice_template') or hasattr(self.infer_cfg, 'prompt_template'), \
            'ice_template and prompt_template cannot both be None.'  # noqa: E501
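        # ``ice_template`` renders the in-context examples (ICE);
        # ``prompt_template`` renders the final query prompt.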
if hasattr(self.infer_cfg, 'ice_template'):
ice_template = ICL_PROMPT_TEMPLATES.build(
self.infer_cfg['ice_template'])
if hasattr(self.infer_cfg, 'prompt_template'):
prompt_template = ICL_PROMPT_TEMPLATES.build(
self.infer_cfg['prompt_template'])
retriever_cfg = self.infer_cfg['retriever'].copy()
retriever_cfg['dataset'] = self.dataset
retriever = ICL_RETRIEVERS.build(retriever_cfg)
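        # The retriever selects in-context examples from the dataset for
        # each test instance.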
        # Set the inferencer's default values according to the model's config.
inferencer_cfg = self.infer_cfg['inferencer']
inferencer_cfg['model'] = self.model
self._set_default_value(inferencer_cfg, 'max_out_len',
self.max_out_len)
self._set_default_value(inferencer_cfg, 'min_out_len',
self.min_out_len)
self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size)
inferencer_cfg['max_seq_len'] = self.model_cfg.get('max_seq_len')
inferencer = ICL_INFERENCERS.build(inferencer_cfg)
out_path = get_infer_output_path(
self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'predictions'))
out_dir, out_file = osp.split(out_path)
mkdir_or_exist(out_dir)
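        # Call the inferencer with whichever templates were configured;
        # predictions are written as JSON to ``out_dir/out_file``.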
if hasattr(self.infer_cfg, 'prompt_template') and \
hasattr(self.infer_cfg, 'ice_template'):
inferencer.inference(retriever,
ice_template=ice_template,
prompt_template=prompt_template,
output_json_filepath=out_dir,
output_json_filename=out_file)
elif hasattr(self.infer_cfg, 'prompt_template'):
inferencer.inference(retriever,
prompt_template=prompt_template,
output_json_filepath=out_dir,
output_json_filename=out_file)
else:
inferencer.inference(retriever,
ice_template=ice_template,
output_json_filepath=out_dir,
                                 output_json_filename=out_file)

    def _set_default_value(self, cfg: ConfigDict, key: str, value: Any):
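        """Fill ``cfg[key]`` with ``value`` unless the key is already set."""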
if key not in cfg:
            cfg[key] = value


def parse_args():
parser = argparse.ArgumentParser(description='Model Inferencer')
parser.add_argument('config', help='Config file path')
args = parser.parse_args()
    return args


if __name__ == '__main__':
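    # Standalone entry point: the runner invokes this script as
    # ``python openicl_infer.py <cfg_path>`` (or through ``torchrun``;
    # see ``get_command`` above), where <cfg_path> is typically a dumped
    # sub-config holding the model/dataset pairs for this task.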
args = parse_args()
cfg = Config.fromfile(args.config)
start_time = time.time()
inferencer = OpenICLInferTask(cfg)
inferencer.run()
end_time = time.time()
get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')