OpenCompass/opencompass/runners/local_async.py
2024-12-15 18:51:04 +08:00

113 lines
3.2 KiB
Python

from math import prod
import os
import os.path as osp
import re
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from threading import Lock
from typing import Any, Dict, List, Tuple
import mmengine
import numpy as np
from mmengine.config import ConfigDict
from mmengine.device import is_npu_available
from tqdm import tqdm
from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger, model_abbr_from_cfg
from .base import BaseRunner
from typing import TypedDict, Optional
from multiprocessing.managers import Namespace
import threading
import uuid
import enum
import signal
from enum import IntEnum
import asyncio
import traceback
class Status(enum.IntEnum):
    """Outcome codes that ``AsyncRunner.launch`` pairs with each task name.

    Because this is an ``IntEnum``, every member compares equal to its
    integer value and can be used wherever a plain integer exit code is
    expected (e.g. ``code != 0`` checks).
    """

    # The task's coroutine completed without raising.
    SUCCESS = 0
    # The task ended with an error.
    FAILED = -1
    # The run was interrupted (KeyboardInterrupt); reuses SIGINT's number.
    INTERRUPT = signal.SIGINT
@RUNNERS.register_module()
class AsyncRunner(BaseRunner):
    """Runner that drives a single task's coroutine in the current process.

    The task is built through the ``TASKS`` registry and its ``run()``
    coroutine is executed with :func:`asyncio.run`; no subprocess is spawned.

    Args:
        task (ConfigDict): Task type config. Its ``type`` entry is used to
            build the task (read back as ``self.task_cfg['type']``, which is
            presumably populated by ``BaseRunner`` -- TODO confirm).
        debug (bool): Whether to run in debug mode; forwarded to
            ``BaseRunner``. Defaults to False.
        max_num_workers (int): Stored as ``self.max_num_workers``; not read
            anywhere in this class. Defaults to 16.
        keep_tmp_file (bool): Stored as ``self.keep_tmp_file``; not read
            anywhere in this class. Defaults to False.
        **kwargs: Accepted so configs written for other runners still load;
            each one is ignored and a warning is logged.
    """

    def __init__(self,
                 task: ConfigDict,
                 debug: bool = False,
                 *,
                 max_num_workers: int = 16,
                 keep_tmp_file: bool = False,
                 **kwargs):
        super().__init__(task=task, debug=debug)
        self.max_num_workers = max_num_workers
        self.keep_tmp_file = keep_tmp_file
        logger = get_logger()
        for k, v in kwargs.items():
            logger.warning(f'Ignored argument in `AsyncRunner`: {k}={v}')

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, Status]]:  # type: ignore
        """Build and run the single task contained in ``tasks``.

        Args:
            tasks (list[dict]): Task configs, usually generated by a
                Partitioner. At most one task is supported.

        Returns:
            list[tuple[str, Status]]: ``[(task name, status)]``. An empty
            ``tasks`` yields ``[("", Status.SUCCESS)]``.

        Raises:
            AssertionError: If ``tasks`` holds more than one config.
        """
        # Deferred import; presumably kept local to avoid an import cycle or
        # heavy imports at module load -- TODO confirm.
        from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask

        if not tasks:
            return [("", Status.SUCCESS)]
        assert len(tasks) == 1, (
            f'Task num must be 1 for `AsyncRunner`, got {len(tasks)}')

        task_cfg = tasks[0]
        task: OpenICLAsyncInferTask = TASKS.build(
            dict(cfg=task_cfg, type=self.task_cfg['type']))
        task_name = task.name

        # Make sure the relative 'tmp/' directory exists before the task runs.
        mmengine.mkdir_or_exist('tmp/')

        try:
            asyncio.run(task.run())
        except KeyboardInterrupt:
            return [(task_name, Status.INTERRUPT)]
        except Exception:
            # Record the failure with its traceback instead of propagating,
            # so that summarize() can report it alongside the task name.
            get_logger().exception(f'Task {task_name} failed')
            return [(task_name, Status.FAILED)]
        return [(task_name, Status.SUCCESS)]

    def __call__(self, tasks: List[Dict[str, Any]]):
        """Launch ``tasks`` and summarize the results.

        Args:
            tasks (list[dict]): Task configs, usually generated by a
                Partitioner.
        """
        status = self.launch(tasks)
        status_list = list(status)  # change into list format
        self.summarize(status_list)