mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feat] update python action and slurm (#694)
This commit is contained in:
parent
6130394165
commit
a94598d921
@ -1,19 +1,13 @@
|
|||||||
import copy
|
import copy
|
||||||
import io
|
import io
|
||||||
import signal
|
import multiprocessing
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from lagent.actions.base_action import BaseAction
|
from lagent.actions.base_action import BaseAction
|
||||||
from lagent.schema import ActionReturn, ActionStatusCode
|
from lagent.schema import ActionReturn, ActionStatusCode
|
||||||
|
|
||||||
|
from opencompass.datasets.mbpp import TimeOutException, swallow_io, time_limit
|
||||||
class TimeoutError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def handler(signum, frame):
|
|
||||||
raise TimeoutError()
|
|
||||||
|
|
||||||
|
|
||||||
class GenericRuntime:
|
class GenericRuntime:
|
||||||
@ -90,30 +84,60 @@ class PythonInterpreter(BaseAction):
|
|||||||
self.answer_from_stdout = answer_from_stdout
|
self.answer_from_stdout = answer_from_stdout
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
def __call__(self, command: str) -> ActionReturn:
|
@staticmethod
|
||||||
self.runtime = GenericRuntime()
|
def extract_code(command: str) -> str:
|
||||||
signal.signal(signal.SIGALRM, handler)
|
|
||||||
signal.alarm(self.timeout)
|
|
||||||
try:
|
|
||||||
tool_return = self._call(command)
|
|
||||||
except TimeoutError as e:
|
|
||||||
tool_return = ActionReturn(url=None, args=None, type=self.name)
|
|
||||||
tool_return.errmsg = repr(e)
|
|
||||||
tool_return.state = ActionStatusCode.API_ERROR
|
|
||||||
finally:
|
|
||||||
signal.alarm(0)
|
|
||||||
return tool_return
|
|
||||||
|
|
||||||
def _call(self, command: str) -> ActionReturn:
|
|
||||||
tool_return = ActionReturn(url=None, args=None, type=self.name)
|
|
||||||
try:
|
|
||||||
if '```python' in command:
|
if '```python' in command:
|
||||||
command = command.split('```python')[1].split('```')[0]
|
command = command.split('```python')[1].split('```')[0]
|
||||||
elif '```' in command:
|
elif '```' in command:
|
||||||
command = command.split('```')[1].split('```')[0]
|
command = command.split('```')[1].split('```')[0]
|
||||||
tool_return.args = dict(text='```python\n' + command + '\n```')
|
|
||||||
command = command.split('\n')
|
command = command.split('\n')
|
||||||
|
return command
|
||||||
|
|
||||||
|
def __call__(self, command: str) -> ActionReturn:
|
||||||
|
"""Execution function for running generation code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command(str): Python code to be executed.
|
||||||
|
"""
|
||||||
|
extracted_command = self.extract_code(command)
|
||||||
|
tool_return = ActionReturn(url=None,
|
||||||
|
args=dict(text=command,
|
||||||
|
extract_code=extracted_command),
|
||||||
|
type=self.name)
|
||||||
|
|
||||||
|
def _execution(q, command, tool_return):
|
||||||
|
try:
|
||||||
|
with swallow_io():
|
||||||
|
# leave 1s for multiprocess
|
||||||
|
with time_limit(self.timeout - 1):
|
||||||
|
res = self._call(command)
|
||||||
|
tool_return.result = dict(text=str(res))
|
||||||
|
tool_return.state = ActionStatusCode.SUCCESS
|
||||||
|
except TimeOutException:
|
||||||
|
tool_return.errmsg = f'Time out after {self.timeout} seconds.'
|
||||||
|
tool_return.state = ActionStatusCode.API_ERROR
|
||||||
|
except BaseException as e:
|
||||||
|
tool_return.errmsg = f'Failed. {e}.'
|
||||||
|
tool_return.state = ActionStatusCode.API_ERROR
|
||||||
|
q.put(tool_return)
|
||||||
|
|
||||||
|
# `signal` cannot be used in child thread, therefore, we
|
||||||
|
# need to create a process.
|
||||||
|
q = multiprocessing.Queue()
|
||||||
|
p = multiprocessing.Process(target=_execution,
|
||||||
|
args=(q, extracted_command, tool_return))
|
||||||
|
p.start()
|
||||||
|
p.join(timeout=self.timeout)
|
||||||
|
if p.is_alive():
|
||||||
|
p.kill()
|
||||||
|
# return timeout due to some unknown error
|
||||||
|
tool_return.errmsg = f'Time out after {self.timeout} seconds.'
|
||||||
|
tool_return.state = ActionStatusCode.API_ERROR
|
||||||
|
return tool_return
|
||||||
|
return q.get()
|
||||||
|
|
||||||
|
def _call(self, command: str) -> ActionReturn:
|
||||||
|
self.runtime = GenericRuntime()
|
||||||
if self.answer_from_stdout:
|
if self.answer_from_stdout:
|
||||||
program_io = io.StringIO()
|
program_io = io.StringIO()
|
||||||
with redirect_stdout(program_io):
|
with redirect_stdout(program_io):
|
||||||
@ -129,16 +153,4 @@ class PythonInterpreter(BaseAction):
|
|||||||
else:
|
else:
|
||||||
self.runtime.exec_code('\n'.join(command[:-1]))
|
self.runtime.exec_code('\n'.join(command[:-1]))
|
||||||
res = True
|
res = True
|
||||||
except Exception as e:
|
return res
|
||||||
tool_return.errmsg = repr(e)
|
|
||||||
tool_return.type = self.name
|
|
||||||
tool_return.state = ActionStatusCode.API_ERROR
|
|
||||||
return tool_return
|
|
||||||
try:
|
|
||||||
tool_return.result = dict(text=str(res))
|
|
||||||
tool_return.state = ActionStatusCode.SUCCESS
|
|
||||||
except Exception as e:
|
|
||||||
tool_return.errmsg = repr(e)
|
|
||||||
tool_return.type = self.name
|
|
||||||
tool_return.state = ActionStatusCode.API_ERROR
|
|
||||||
return tool_return
|
|
||||||
|
@ -4,7 +4,7 @@ import random
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import mmengine
|
import mmengine
|
||||||
from mmengine.config import ConfigDict
|
from mmengine.config import ConfigDict
|
||||||
@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
|
|||||||
qos (str): Slurm quality of service. Defaults to None.
|
qos (str): Slurm quality of service. Defaults to None.
|
||||||
debug (bool): Whether to run in debug mode. Defaults to False.
|
debug (bool): Whether to run in debug mode. Defaults to False.
|
||||||
lark_bot_url (str): Lark bot url. Defaults to None.
|
lark_bot_url (str): Lark bot url. Defaults to None.
|
||||||
|
extra_command (List, optional): Extra slurm command.
|
||||||
|
For example ['-c 12', '-w node1']. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
|
|||||||
quotatype: str = None,
|
quotatype: str = None,
|
||||||
qos: str = None,
|
qos: str = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
lark_bot_url: str = None):
|
lark_bot_url: str = None,
|
||||||
|
extra_command: Optional[List[str]] = None):
|
||||||
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
|
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
|
||||||
self.max_num_workers = max_num_workers
|
self.max_num_workers = max_num_workers
|
||||||
self.retry = retry
|
self.retry = retry
|
||||||
self.partition = partition
|
self.partition = partition
|
||||||
self.quotatype = quotatype
|
self.quotatype = quotatype
|
||||||
self.qos = qos
|
self.qos = qos
|
||||||
|
if not extra_command:
|
||||||
|
extra_command = []
|
||||||
|
assert isinstance(extra_command, list)
|
||||||
|
self.extra_command = extra_command
|
||||||
|
|
||||||
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
|
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
|
||||||
"""Launch multiple tasks.
|
"""Launch multiple tasks.
|
||||||
@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
|
|||||||
tmpl += f' --qos={self.qos}'
|
tmpl += f' --qos={self.qos}'
|
||||||
if num_gpus > 0:
|
if num_gpus > 0:
|
||||||
tmpl += f' --gres=gpu:{num_gpus}'
|
tmpl += f' --gres=gpu:{num_gpus}'
|
||||||
|
for extra_cmd in self.extra_command:
|
||||||
|
tmpl += f' {extra_cmd}'
|
||||||
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
|
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
|
||||||
get_cmd = partial(task.get_command,
|
get_cmd = partial(task.get_command,
|
||||||
cfg_path=param_file,
|
cfg_path=param_file,
|
||||||
|
@ -6,7 +6,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from multiprocessing import Pipe, Pool
|
from multiprocessing import Pipe, Pool
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import mmengine
|
import mmengine
|
||||||
from mmengine.config import ConfigDict
|
from mmengine.config import ConfigDict
|
||||||
@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
|
|||||||
qos (str): Slurm quality of service. Defaults to None.
|
qos (str): Slurm quality of service. Defaults to None.
|
||||||
debug (bool): Whether to run in debug mode. Defaults to False.
|
debug (bool): Whether to run in debug mode. Defaults to False.
|
||||||
lark_bot_url (str): Lark bot url. Defaults to None.
|
lark_bot_url (str): Lark bot url. Defaults to None.
|
||||||
|
extra_command (List, optional): Extra slurm command.
|
||||||
|
For example ['-c 12', '-w node1']. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
|
|||||||
quotatype: str = None,
|
quotatype: str = None,
|
||||||
qos: str = None,
|
qos: str = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
lark_bot_url: str = None):
|
lark_bot_url: str = None,
|
||||||
|
extra_command: Optional[List[str]] = None):
|
||||||
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
|
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
|
||||||
self.max_num_workers = max_num_workers
|
self.max_num_workers = max_num_workers
|
||||||
self.retry = retry
|
self.retry = retry
|
||||||
@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
|
|||||||
self.quotatype = quotatype
|
self.quotatype = quotatype
|
||||||
self.qos = qos
|
self.qos = qos
|
||||||
self.task_prefix = task_prefix
|
self.task_prefix = task_prefix
|
||||||
|
if not extra_command:
|
||||||
|
extra_command = []
|
||||||
|
assert isinstance(extra_command, list)
|
||||||
|
self.extra_command = extra_command
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
if self.quotatype in ['spot', 'auto']:
|
if self.quotatype in ['spot', 'auto']:
|
||||||
@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
|
|||||||
tmpl += f' --qos={self.qos}'
|
tmpl += f' --qos={self.qos}'
|
||||||
if num_gpus > 0:
|
if num_gpus > 0:
|
||||||
tmpl += f' --gres=gpu:{num_gpus}'
|
tmpl += f' --gres=gpu:{num_gpus}'
|
||||||
|
for extra_cmd in self.extra_command:
|
||||||
|
tmpl += f' {extra_cmd}'
|
||||||
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
|
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
|
||||||
get_cmd = partial(task.get_command,
|
get_cmd = partial(task.get_command,
|
||||||
cfg_path=param_file,
|
cfg_path=param_file,
|
||||||
|
Loading…
Reference in New Issue
Block a user