[Feat] update python action and slurm (#694)

This commit is contained in:
Hubert 2023-12-13 10:41:10 +08:00 committed by GitHub
parent 6130394165
commit a94598d921
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 59 deletions

View File

@ -1,19 +1,13 @@
import copy import copy
import io import io
import signal import multiprocessing
from contextlib import redirect_stdout from contextlib import redirect_stdout
from typing import Any, Optional from typing import Any, Optional
from lagent.actions.base_action import BaseAction from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode from lagent.schema import ActionReturn, ActionStatusCode
from opencompass.datasets.mbpp import TimeOutException, swallow_io, time_limit
class TimeoutError(Exception):
pass
def handler(signum, frame):
raise TimeoutError()
class GenericRuntime: class GenericRuntime:
@ -90,30 +84,60 @@ class PythonInterpreter(BaseAction):
self.answer_from_stdout = answer_from_stdout self.answer_from_stdout = answer_from_stdout
self.timeout = timeout self.timeout = timeout
def __call__(self, command: str) -> ActionReturn: @staticmethod
self.runtime = GenericRuntime() def extract_code(command: str) -> str:
signal.signal(signal.SIGALRM, handler)
signal.alarm(self.timeout)
try:
tool_return = self._call(command)
except TimeoutError as e:
tool_return = ActionReturn(url=None, args=None, type=self.name)
tool_return.errmsg = repr(e)
tool_return.state = ActionStatusCode.API_ERROR
finally:
signal.alarm(0)
return tool_return
def _call(self, command: str) -> ActionReturn:
tool_return = ActionReturn(url=None, args=None, type=self.name)
try:
if '```python' in command: if '```python' in command:
command = command.split('```python')[1].split('```')[0] command = command.split('```python')[1].split('```')[0]
elif '```' in command: elif '```' in command:
command = command.split('```')[1].split('```')[0] command = command.split('```')[1].split('```')[0]
tool_return.args = dict(text='```python\n' + command + '\n```')
command = command.split('\n') command = command.split('\n')
return command
def __call__(self, command: str) -> ActionReturn:
    """Run the code contained in ``command`` inside a child process.

    The code is extracted with ``self.extract_code`` and executed through
    ``self._call`` in a separate process, so that a run which does not
    finish within ``self.timeout`` seconds can be killed.

    Args:
        command (str): Python code to be executed.

    Returns:
        ActionReturn: On success, ``result`` is ``dict(text=str(res))`` and
        ``state`` is ``ActionStatusCode.SUCCESS``. On timeout or any
        failure, ``errmsg`` is set and ``state`` is
        ``ActionStatusCode.API_ERROR``.
    """
    # Imported locally so the method does not depend on a file-level import.
    from queue import Empty

    extracted_command = self.extract_code(command)
    tool_return = ActionReturn(url=None,
                               args=dict(text=command,
                                         extract_code=extracted_command),
                               type=self.name)

    def _execution(q, command, tool_return):
        # Runs in the child process: execute the code and ship the filled
        # ``tool_return`` back to the parent through the queue.
        try:
            with swallow_io():
                # leave 1s for multiprocess
                with time_limit(self.timeout - 1):
                    res = self._call(command)
                    tool_return.result = dict(text=str(res))
                    tool_return.state = ActionStatusCode.SUCCESS
        except TimeOutException:
            tool_return.errmsg = f'Time out after {self.timeout} seconds.'
            tool_return.state = ActionStatusCode.API_ERROR
        except BaseException as e:
            tool_return.errmsg = f'Failed. {e}.'
            tool_return.state = ActionStatusCode.API_ERROR
        q.put(tool_return)

    # `signal` cannot be used in child thread, therefore, we
    # need to create a process.
    # NOTE(review): a nested function cannot be pickled, so this target
    # only works when processes are created with the `fork` start method.
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=_execution,
                                args=(q, extracted_command, tool_return))
    p.start()

    try:
        # Read the result BEFORE joining: a child that has put a large
        # item on the queue cannot exit until the parent drains it, so
        # joining first could turn a successful run into a timeout.
        # The timeout also stops us from blocking forever when the child
        # dies without ever putting a result.
        finished = q.get(timeout=self.timeout)
    except Empty:
        if p.is_alive():
            p.kill()
        p.join()
        # return timeout due to some unknown error
        tool_return.errmsg = f'Time out after {self.timeout} seconds.'
        tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    # The result is already in hand; just reap the child and make sure a
    # straggler does not outlive this call.
    p.join(timeout=1)
    if p.is_alive():
        p.kill()
        p.join()
    return finished
def _call(self, command: str) -> ActionReturn:
self.runtime = GenericRuntime()
if self.answer_from_stdout: if self.answer_from_stdout:
program_io = io.StringIO() program_io = io.StringIO()
with redirect_stdout(program_io): with redirect_stdout(program_io):
@ -129,16 +153,4 @@ class PythonInterpreter(BaseAction):
else: else:
self.runtime.exec_code('\n'.join(command[:-1])) self.runtime.exec_code('\n'.join(command[:-1]))
res = True res = True
except Exception as e: return res
tool_return.errmsg = repr(e)
tool_return.type = self.name
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
try:
tool_return.result = dict(text=str(res))
tool_return.state = ActionStatusCode.SUCCESS
except Exception as e:
tool_return.errmsg = repr(e)
tool_return.type = self.name
tool_return.state = ActionStatusCode.API_ERROR
return tool_return

View File

@ -4,7 +4,7 @@ import random
import subprocess import subprocess
import time import time
from functools import partial from functools import partial
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
import mmengine import mmengine
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None. qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False. debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None. lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
""" """
def __init__(self, def __init__(self,
@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
quotatype: str = None, quotatype: str = None,
qos: str = None, qos: str = None,
debug: bool = False, debug: bool = False,
lark_bot_url: str = None): lark_bot_url: str = None,
extra_command: Optional[List[str]] = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url) super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers self.max_num_workers = max_num_workers
self.retry = retry self.retry = retry
self.partition = partition self.partition = partition
self.quotatype = quotatype self.quotatype = quotatype
self.qos = qos self.qos = qos
if not extra_command:
extra_command = []
assert isinstance(extra_command, list)
self.extra_command = extra_command
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]: def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
"""Launch multiple tasks. """Launch multiple tasks.
@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
tmpl += f' --qos={self.qos}' tmpl += f' --qos={self.qos}'
if num_gpus > 0: if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}' tmpl += f' --gres=gpu:{num_gpus}'
for extra_cmd in self.extra_command:
tmpl += f' {extra_cmd}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}' tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
get_cmd = partial(task.get_command, get_cmd = partial(task.get_command,
cfg_path=param_file, cfg_path=param_file,

View File

@ -6,7 +6,7 @@ import time
import traceback import traceback
from functools import partial from functools import partial
from multiprocessing import Pipe, Pool from multiprocessing import Pipe, Pool
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
import mmengine import mmengine
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None. qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False. debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None. lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
""" """
def __init__(self, def __init__(self,
@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
quotatype: str = None, quotatype: str = None,
qos: str = None, qos: str = None,
debug: bool = False, debug: bool = False,
lark_bot_url: str = None): lark_bot_url: str = None,
extra_command: Optional[List[str]] = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url) super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers self.max_num_workers = max_num_workers
self.retry = retry self.retry = retry
@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
self.quotatype = quotatype self.quotatype = quotatype
self.qos = qos self.qos = qos
self.task_prefix = task_prefix self.task_prefix = task_prefix
if not extra_command:
extra_command = []
assert isinstance(extra_command, list)
self.extra_command = extra_command
logger = get_logger() logger = get_logger()
if self.quotatype in ['spot', 'auto']: if self.quotatype in ['spot', 'auto']:
@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
tmpl += f' --qos={self.qos}' tmpl += f' --qos={self.qos}'
if num_gpus > 0: if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}' tmpl += f' --gres=gpu:{num_gpus}'
for extra_cmd in self.extra_command:
tmpl += f' {extra_cmd}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}' tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
get_cmd = partial(task.get_command, get_cmd = partial(task.get_command,
cfg_path=param_file, cfg_path=param_file,