mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
113 lines
3.2 KiB
Python
113 lines
3.2 KiB
Python
from math import prod
|
|
import os
|
|
import os.path as osp
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from functools import partial
|
|
from threading import Lock
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
import mmengine
|
|
import numpy as np
|
|
from mmengine.config import ConfigDict
|
|
from mmengine.device import is_npu_available
|
|
from tqdm import tqdm
|
|
|
|
from opencompass.registry import RUNNERS, TASKS
|
|
from opencompass.utils import get_logger, model_abbr_from_cfg
|
|
|
|
from .base import BaseRunner
|
|
from typing import TypedDict, Optional
|
|
from multiprocessing.managers import Namespace
|
|
import threading
|
|
import uuid
|
|
import enum
|
|
import signal
|
|
from enum import IntEnum
|
|
import asyncio
|
|
import traceback
|
|
|
|
|
|
class Status(IntEnum):
    """Exit status that ``AsyncRunner.launch`` pairs with each task name.

    ``INTERRUPT`` borrows the numeric value of ``SIGINT``. It therefore never
    collides with ``SUCCESS`` (0) or ``FAILED`` (-1).
    """

    SUCCESS = 0
    FAILED = -1
    INTERRUPT = int(signal.SIGINT)
|
|
|
|
|
|
@RUNNERS.register_module()
class AsyncRunner(BaseRunner):
    """Runner that executes a single task in-process with ``asyncio``.

    The task config is built through the ``TASKS`` registry. Its ``run()``
    coroutine is then driven to completion with :func:`asyncio.run` in the
    current process.

    Args:
        task (ConfigDict): Task type config, forwarded to ``BaseRunner``.
        debug (bool): Debug flag, forwarded to ``BaseRunner``.
            Defaults to False.
        max_num_workers (int): Stored as ``self.max_num_workers``; not read
            within this class. Defaults to 16.
        keep_tmp_file (bool): Stored as ``self.keep_tmp_file``; not read
            within this class. Defaults to False.
        **kwargs: Any other keyword argument is accepted but ignored, and a
            warning is logged for each one.
    """

    def __init__(self,
                 task: ConfigDict,
                 debug: bool = False,
                 *,
                 max_num_workers: int = 16,
                 keep_tmp_file: bool = False,
                 **kwargs):
        super().__init__(task=task, debug=debug)
        self.max_num_workers = max_num_workers
        self.keep_tmp_file = keep_tmp_file
        logger = get_logger()
        # Unknown options are tolerated rather than rejected, but each one is
        # surfaced in the log so that misspelled options are noticed.
        for k, v in kwargs.items():
            logger.warning(f'Ignored argument in `AsyncRunner`: {k}={v}')

    # NOTE(review): the ``type: ignore`` presumably silences an override
    # mismatch with ``BaseRunner.launch``'s annotation. Confirm against base.
    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, Status]]:  # type: ignore
        """Run the (single) task and report its status.

        Args:
            tasks (list[dict]): Task configs, usually generated by a
                Partitioner. ``AsyncRunner`` accepts at most one.

        Returns:
            list[tuple[str, Status]]: A one-element list holding
            ``(task name, status)``. ``Status.INTERRUPT`` is returned on
            ``KeyboardInterrupt``, ``Status.FAILED`` on any other error
            (traceback printed to stderr), and ``Status.SUCCESS`` otherwise.
            An empty ``tasks`` list yields ``[("", Status.SUCCESS)]``.

        Raises:
            AssertionError: If more than one task config is given.
        """
        # Local import, as in the original; only used to annotate ``task``.
        from opencompass.tasks.openicl_async_task import OpenICLAsyncInferTask

        if not tasks:
            return [("", Status.SUCCESS)]

        assert len(tasks) == 1, (
            f"Task num must be 1 for `AsyncRunner`, got {len(tasks)}")
        task_cfg = tasks[0]

        task: OpenICLAsyncInferTask = TASKS.build(
            dict(cfg=task_cfg, type=self.task_cfg['type']))
        task_name = task.name

        # Ensure the relative ``tmp/`` directory exists before the task runs.
        mmengine.mkdir_or_exist('tmp/')

        try:
            asyncio.run(task.run())
        except KeyboardInterrupt:
            return [(task_name, Status.INTERRUPT)]
        except (Exception, asyncio.CancelledError):
            # CancelledError is a BaseException (3.8+), so it is listed
            # explicitly to keep reporting cancellation as a failure.
            # print_exc() writes to stderr itself and returns None, so it
            # must not be wrapped in print().
            traceback.print_exc()
            return [(task_name, Status.FAILED)]
        else:
            return [(task_name, Status.SUCCESS)]

    def __call__(self, tasks: List[Dict[str, Any]]):
        """Launch the tasks and summarize the results.

        Args:
            tasks (list[dict]): Task configs, usually generated by a
                Partitioner.
        """
        status = self.launch(tasks)
        # ``summarize`` receives a list copy of the (name, status) pairs.
        status_list = list(status)
        self.summarize(status_list)
|