mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feat] update python action and slurm (#694)
This commit is contained in:
parent
6130394165
commit
a94598d921
@ -1,19 +1,13 @@
|
||||
import copy
|
||||
import io
|
||||
import signal
|
||||
import multiprocessing
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any, Optional
|
||||
|
||||
from lagent.actions.base_action import BaseAction
|
||||
from lagent.schema import ActionReturn, ActionStatusCode
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError()
|
||||
from opencompass.datasets.mbpp import TimeOutException, swallow_io, time_limit
|
||||
|
||||
|
||||
class GenericRuntime:
|
||||
@ -90,55 +84,73 @@ class PythonInterpreter(BaseAction):
|
||||
self.answer_from_stdout = answer_from_stdout
|
||||
self.timeout = timeout
|
||||
|
||||
@staticmethod
|
||||
def extract_code(command: str) -> str:
|
||||
if '```python' in command:
|
||||
command = command.split('```python')[1].split('```')[0]
|
||||
elif '```' in command:
|
||||
command = command.split('```')[1].split('```')[0]
|
||||
command = command.split('\n')
|
||||
return command
|
||||
|
||||
def __call__(self, command: str) -> ActionReturn:
|
||||
self.runtime = GenericRuntime()
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(self.timeout)
|
||||
try:
|
||||
tool_return = self._call(command)
|
||||
except TimeoutError as e:
|
||||
tool_return = ActionReturn(url=None, args=None, type=self.name)
|
||||
tool_return.errmsg = repr(e)
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
return tool_return
|
||||
"""Execution function for running generation code.
|
||||
|
||||
def _call(self, command: str) -> ActionReturn:
|
||||
tool_return = ActionReturn(url=None, args=None, type=self.name)
|
||||
try:
|
||||
if '```python' in command:
|
||||
command = command.split('```python')[1].split('```')[0]
|
||||
elif '```' in command:
|
||||
command = command.split('```')[1].split('```')[0]
|
||||
tool_return.args = dict(text='```python\n' + command + '\n```')
|
||||
command = command.split('\n')
|
||||
Args:
|
||||
command(str): Python code to be executed.
|
||||
"""
|
||||
extracted_command = self.extract_code(command)
|
||||
tool_return = ActionReturn(url=None,
|
||||
args=dict(text=command,
|
||||
extract_code=extracted_command),
|
||||
type=self.name)
|
||||
|
||||
if self.answer_from_stdout:
|
||||
program_io = io.StringIO()
|
||||
with redirect_stdout(program_io):
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
program_io.seek(0)
|
||||
res = program_io.readlines()[-1]
|
||||
elif self.answer_symbol:
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
res = self.runtime._global_vars[self.answer_symbol]
|
||||
elif self.answer_expr:
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
res = self.runtime.eval_code(self.answer_expr)
|
||||
else:
|
||||
self.runtime.exec_code('\n'.join(command[:-1]))
|
||||
res = True
|
||||
except Exception as e:
|
||||
tool_return.errmsg = repr(e)
|
||||
tool_return.type = self.name
|
||||
def _execution(q, command, tool_return):
|
||||
try:
|
||||
with swallow_io():
|
||||
# leave 1s for multiprocess
|
||||
with time_limit(self.timeout - 1):
|
||||
res = self._call(command)
|
||||
tool_return.result = dict(text=str(res))
|
||||
tool_return.state = ActionStatusCode.SUCCESS
|
||||
except TimeOutException:
|
||||
tool_return.errmsg = f'Time out after {self.timeout} seconds.'
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
except BaseException as e:
|
||||
tool_return.errmsg = f'Failed. {e}.'
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
q.put(tool_return)
|
||||
|
||||
# `signal` cannot be used in child thread, therefore, we
|
||||
# need to create a process.
|
||||
q = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=_execution,
|
||||
args=(q, extracted_command, tool_return))
|
||||
p.start()
|
||||
p.join(timeout=self.timeout)
|
||||
if p.is_alive():
|
||||
p.kill()
|
||||
# return timeout due to some unknown error
|
||||
tool_return.errmsg = f'Time out after {self.timeout} seconds.'
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
try:
|
||||
tool_return.result = dict(text=str(res))
|
||||
tool_return.state = ActionStatusCode.SUCCESS
|
||||
except Exception as e:
|
||||
tool_return.errmsg = repr(e)
|
||||
tool_return.type = self.name
|
||||
tool_return.state = ActionStatusCode.API_ERROR
|
||||
return tool_return
|
||||
return q.get()
|
||||
|
||||
def _call(self, command: str) -> ActionReturn:
|
||||
self.runtime = GenericRuntime()
|
||||
if self.answer_from_stdout:
|
||||
program_io = io.StringIO()
|
||||
with redirect_stdout(program_io):
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
program_io.seek(0)
|
||||
res = program_io.readlines()[-1]
|
||||
elif self.answer_symbol:
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
res = self.runtime._global_vars[self.answer_symbol]
|
||||
elif self.answer_expr:
|
||||
self.runtime.exec_code('\n'.join(command))
|
||||
res = self.runtime.eval_code(self.answer_expr)
|
||||
else:
|
||||
self.runtime.exec_code('\n'.join(command[:-1]))
|
||||
res = True
|
||||
return res
|
||||
|
@ -4,7 +4,7 @@ import random
|
||||
import subprocess
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import ConfigDict
|
||||
@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
|
||||
qos (str): Slurm quality of service. Defaults to None.
|
||||
debug (bool): Whether to run in debug mode. Defaults to False.
|
||||
lark_bot_url (str): Lark bot url. Defaults to None.
|
||||
extra_command (List, optional): Extra slurm command.
|
||||
For example ['-c 12', '-w node1']. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
|
||||
quotatype: str = None,
|
||||
qos: str = None,
|
||||
debug: bool = False,
|
||||
lark_bot_url: str = None):
|
||||
lark_bot_url: str = None,
|
||||
extra_command: Optional[List[str]] = None):
|
||||
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
|
||||
self.max_num_workers = max_num_workers
|
||||
self.retry = retry
|
||||
self.partition = partition
|
||||
self.quotatype = quotatype
|
||||
self.qos = qos
|
||||
if not extra_command:
|
||||
extra_command = []
|
||||
assert isinstance(extra_command, list)
|
||||
self.extra_command = extra_command
|
||||
|
||||
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
|
||||
"""Launch multiple tasks.
|
||||
@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
|
||||
tmpl += f' --qos={self.qos}'
|
||||
if num_gpus > 0:
|
||||
tmpl += f' --gres=gpu:{num_gpus}'
|
||||
for extra_cmd in self.extra_command:
|
||||
tmpl += f' {extra_cmd}'
|
||||
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
|
||||
get_cmd = partial(task.get_command,
|
||||
cfg_path=param_file,
|
||||
|
@ -6,7 +6,7 @@ import time
|
||||
import traceback
|
||||
from functools import partial
|
||||
from multiprocessing import Pipe, Pool
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import ConfigDict
|
||||
@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
|
||||
qos (str): Slurm quality of service. Defaults to None.
|
||||
debug (bool): Whether to run in debug mode. Defaults to False.
|
||||
lark_bot_url (str): Lark bot url. Defaults to None.
|
||||
extra_command (List, optional): Extra slurm command.
|
||||
For example ['-c 12', '-w node1']. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
|
||||
quotatype: str = None,
|
||||
qos: str = None,
|
||||
debug: bool = False,
|
||||
lark_bot_url: str = None):
|
||||
lark_bot_url: str = None,
|
||||
extra_command: Optional[List[str]] = None):
|
||||
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
|
||||
self.max_num_workers = max_num_workers
|
||||
self.retry = retry
|
||||
@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
|
||||
self.quotatype = quotatype
|
||||
self.qos = qos
|
||||
self.task_prefix = task_prefix
|
||||
if not extra_command:
|
||||
extra_command = []
|
||||
assert isinstance(extra_command, list)
|
||||
self.extra_command = extra_command
|
||||
|
||||
logger = get_logger()
|
||||
if self.quotatype in ['spot', 'auto']:
|
||||
@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
|
||||
tmpl += f' --qos={self.qos}'
|
||||
if num_gpus > 0:
|
||||
tmpl += f' --gres=gpu:{num_gpus}'
|
||||
for extra_cmd in self.extra_command:
|
||||
tmpl += f' {extra_cmd}'
|
||||
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
|
||||
get_cmd = partial(task.get_command,
|
||||
cfg_path=param_file,
|
||||
|
Loading…
Reference in New Issue
Block a user