[Feat] update python action and slurm (#694)

This commit is contained in:
Hubert 2023-12-13 10:41:10 +08:00 committed by GitHub
parent 6130394165
commit a94598d921
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 59 deletions

View File

@ -1,19 +1,13 @@
import copy
import io
import signal
import multiprocessing
from contextlib import redirect_stdout
from typing import Any, Optional
from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode
class TimeoutError(Exception):
    """Signals that the interpreter ran past its allotted time.

    Raised by ``handler`` when the alarm signal fires. Note that this
    module-level name takes precedence over the builtin ``TimeoutError``
    inside this module.
    """
def handler(signum, frame):
    """Abort the running code by raising ``TimeoutError``.

    Intended to be installed as a signal handler; both arguments are
    accepted to satisfy the handler calling convention and are ignored.

    Args:
        signum: Number of the signal that was delivered (unused).
        frame: Stack frame that was executing when it arrived (unused).

    Raises:
        TimeoutError: Always.
    """
    timeout_exc = TimeoutError()
    raise timeout_exc
from opencompass.datasets.mbpp import TimeOutException, swallow_io, time_limit
class GenericRuntime:
@ -90,55 +84,73 @@ class PythonInterpreter(BaseAction):
self.answer_from_stdout = answer_from_stdout
self.timeout = timeout
@staticmethod
def extract_code(command: str) -> str:
if '```python' in command:
command = command.split('```python')[1].split('```')[0]
elif '```' in command:
command = command.split('```')[1].split('```')[0]
command = command.split('\n')
return command
def __call__(self, command: str) -> ActionReturn:
self.runtime = GenericRuntime()
signal.signal(signal.SIGALRM, handler)
signal.alarm(self.timeout)
try:
tool_return = self._call(command)
except TimeoutError as e:
tool_return = ActionReturn(url=None, args=None, type=self.name)
tool_return.errmsg = repr(e)
tool_return.state = ActionStatusCode.API_ERROR
finally:
signal.alarm(0)
return tool_return
"""Execution function for running generation code.
def _call(self, command: str) -> ActionReturn:
tool_return = ActionReturn(url=None, args=None, type=self.name)
try:
if '```python' in command:
command = command.split('```python')[1].split('```')[0]
elif '```' in command:
command = command.split('```')[1].split('```')[0]
tool_return.args = dict(text='```python\n' + command + '\n```')
command = command.split('\n')
Args:
command(str): Python code to be executed.
"""
extracted_command = self.extract_code(command)
tool_return = ActionReturn(url=None,
args=dict(text=command,
extract_code=extracted_command),
type=self.name)
if self.answer_from_stdout:
program_io = io.StringIO()
with redirect_stdout(program_io):
self.runtime.exec_code('\n'.join(command))
program_io.seek(0)
res = program_io.readlines()[-1]
elif self.answer_symbol:
self.runtime.exec_code('\n'.join(command))
res = self.runtime._global_vars[self.answer_symbol]
elif self.answer_expr:
self.runtime.exec_code('\n'.join(command))
res = self.runtime.eval_code(self.answer_expr)
else:
self.runtime.exec_code('\n'.join(command[:-1]))
res = True
except Exception as e:
tool_return.errmsg = repr(e)
tool_return.type = self.name
def _execution(q, command, tool_return):
try:
with swallow_io():
# leave 1s for multiprocess
with time_limit(self.timeout - 1):
res = self._call(command)
tool_return.result = dict(text=str(res))
tool_return.state = ActionStatusCode.SUCCESS
except TimeOutException:
tool_return.errmsg = f'Time out after {self.timeout} seconds.'
tool_return.state = ActionStatusCode.API_ERROR
except BaseException as e:
tool_return.errmsg = f'Failed. {e}.'
tool_return.state = ActionStatusCode.API_ERROR
q.put(tool_return)
# `signal` cannot be used in child thread, therefore, we
# need to create a process.
q = multiprocessing.Queue()
p = multiprocessing.Process(target=_execution,
args=(q, extracted_command, tool_return))
p.start()
p.join(timeout=self.timeout)
if p.is_alive():
p.kill()
# return timeout due to some unknown error
tool_return.errmsg = f'Time out after {self.timeout} seconds.'
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
try:
tool_return.result = dict(text=str(res))
tool_return.state = ActionStatusCode.SUCCESS
except Exception as e:
tool_return.errmsg = repr(e)
tool_return.type = self.name
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
return q.get()
def _call(self, command: list) -> Any:
    """Run the extracted code in a fresh runtime and return its answer.

    How the answer is obtained depends on the instance configuration,
    checked in this order:

    * ``answer_from_stdout``: execute the code with stdout captured and
      return the last captured line.
    * ``answer_symbol``: execute the code and return the global variable
      of that name from the runtime.
    * ``answer_expr``: execute the code, then evaluate that expression in
      the runtime and return its value.
    * none of the above: execute every line except the last one and
      return ``True``.

    Args:
        command (list): Source lines (``str`` items) as produced by
            ``extract_code``; they are joined with newlines before
            execution.

    Returns:
        Any: The raw answer as described above. The caller is responsible
        for converting it to text.

    Raises:
        IndexError: With ``answer_from_stdout`` when the code printed
            nothing, because there is no last line to return.
        KeyError: With ``answer_symbol`` when the code did not define
            that global.
        Exception: Whatever the executed code itself raises.
    """
    # A new runtime per call keeps state from leaking between executions.
    self.runtime = GenericRuntime()

    if not (self.answer_from_stdout or self.answer_symbol
            or self.answer_expr):
        # No answer source configured: the final line is dropped and only
        # the preceding lines are executed.
        self.runtime.exec_code('\n'.join(command[:-1]))
        return True

    source = '\n'.join(command)

    if self.answer_from_stdout:
        program_io = io.StringIO()
        with redirect_stdout(program_io):
            self.runtime.exec_code(source)
        program_io.seek(0)
        return program_io.readlines()[-1]

    self.runtime.exec_code(source)
    if self.answer_symbol:
        return self.runtime._global_vars[self.answer_symbol]
    return self.runtime.eval_code(self.answer_expr)

View File

@ -4,7 +4,7 @@ import random
import subprocess
import time
from functools import partial
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import mmengine
from mmengine.config import ConfigDict
@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
"""
def __init__(self,
@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
quotatype: str = None,
qos: str = None,
debug: bool = False,
lark_bot_url: str = None):
lark_bot_url: str = None,
extra_command: Optional[List[str]] = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers
self.retry = retry
self.partition = partition
self.quotatype = quotatype
self.qos = qos
if not extra_command:
extra_command = []
assert isinstance(extra_command, list)
self.extra_command = extra_command
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
"""Launch multiple tasks.
@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
tmpl += f' --qos={self.qos}'
if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}'
for extra_cmd in self.extra_command:
tmpl += f' {extra_cmd}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
get_cmd = partial(task.get_command,
cfg_path=param_file,

View File

@ -6,7 +6,7 @@ import time
import traceback
from functools import partial
from multiprocessing import Pipe, Pool
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import mmengine
from mmengine.config import ConfigDict
@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
"""
def __init__(self,
@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
quotatype: str = None,
qos: str = None,
debug: bool = False,
lark_bot_url: str = None):
lark_bot_url: str = None,
extra_command: Optional[List[str]] = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers
self.retry = retry
@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
self.quotatype = quotatype
self.qos = qos
self.task_prefix = task_prefix
if not extra_command:
extra_command = []
assert isinstance(extra_command, list)
self.extra_command = extra_command
logger = get_logger()
if self.quotatype in ['spot', 'auto']:
@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
tmpl += f' --qos={self.qos}'
if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}'
for extra_cmd in self.extra_command:
tmpl += f' {extra_cmd}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
get_cmd = partial(task.get_command,
cfg_path=param_file,