diff --git a/opencompass/lagent/actions/python_interpreter.py b/opencompass/lagent/actions/python_interpreter.py
index cabed33b..a8c67bc8 100644
--- a/opencompass/lagent/actions/python_interpreter.py
+++ b/opencompass/lagent/actions/python_interpreter.py
@@ -1,19 +1,13 @@
 import copy
 import io
-import signal
+import multiprocessing
 from contextlib import redirect_stdout
 from typing import Any, Optional
 
 from lagent.actions.base_action import BaseAction
 from lagent.schema import ActionReturn, ActionStatusCode
 
-
-class TimeoutError(Exception):
-    pass
-
-
-def handler(signum, frame):
-    raise TimeoutError()
+from opencompass.datasets.mbpp import TimeOutException, swallow_io, time_limit
 
 
 class GenericRuntime:
@@ -90,55 +84,73 @@ class PythonInterpreter(BaseAction):
         self.answer_from_stdout = answer_from_stdout
         self.timeout = timeout
 
+    @staticmethod
+    def extract_code(command: str) -> str:
+        if '```python' in command:
+            command = command.split('```python')[1].split('```')[0]
+        elif '```' in command:
+            command = command.split('```')[1].split('```')[0]
+        command = command.split('\n')
+        return command
+
     def __call__(self, command: str) -> ActionReturn:
-        self.runtime = GenericRuntime()
-        signal.signal(signal.SIGALRM, handler)
-        signal.alarm(self.timeout)
-        try:
-            tool_return = self._call(command)
-        except TimeoutError as e:
-            tool_return = ActionReturn(url=None, args=None, type=self.name)
-            tool_return.errmsg = repr(e)
-            tool_return.state = ActionStatusCode.API_ERROR
-        finally:
-            signal.alarm(0)
-        return tool_return
+        """Execution function for running generation code.
 
-    def _call(self, command: str) -> ActionReturn:
-        tool_return = ActionReturn(url=None, args=None, type=self.name)
-        try:
-            if '```python' in command:
-                command = command.split('```python')[1].split('```')[0]
-            elif '```' in command:
-                command = command.split('```')[1].split('```')[0]
-            tool_return.args = dict(text='```python\n' + command + '\n```')
-            command = command.split('\n')
+        Args:
+            command(str): Python code to be executed.
+        """
+        extracted_command = self.extract_code(command)
+        tool_return = ActionReturn(url=None,
+                                   args=dict(text=command,
+                                             extract_code=extracted_command),
+                                   type=self.name)
 
-            if self.answer_from_stdout:
-                program_io = io.StringIO()
-                with redirect_stdout(program_io):
-                    self.runtime.exec_code('\n'.join(command))
-                program_io.seek(0)
-                res = program_io.readlines()[-1]
-            elif self.answer_symbol:
-                self.runtime.exec_code('\n'.join(command))
-                res = self.runtime._global_vars[self.answer_symbol]
-            elif self.answer_expr:
-                self.runtime.exec_code('\n'.join(command))
-                res = self.runtime.eval_code(self.answer_expr)
-            else:
-                self.runtime.exec_code('\n'.join(command[:-1]))
-                res = True
-        except Exception as e:
-            tool_return.errmsg = repr(e)
-            tool_return.type = self.name
+        def _execution(q, command, tool_return):
+            try:
+                with swallow_io():
+                    # leave 1s for multiprocess
+                    with time_limit(self.timeout - 1):
+                        res = self._call(command)
+                tool_return.result = dict(text=str(res))
+                tool_return.state = ActionStatusCode.SUCCESS
+            except TimeOutException:
+                tool_return.errmsg = f'Time out after {self.timeout} seconds.'
+                tool_return.state = ActionStatusCode.API_ERROR
+            except BaseException as e:
+                tool_return.errmsg = f'Failed. {e}.'
+                tool_return.state = ActionStatusCode.API_ERROR
+            q.put(tool_return)
+
+        # `signal` cannot be used in child thread, therefore, we
+        # need to create a process.
+        q = multiprocessing.Queue()
+        p = multiprocessing.Process(target=_execution,
+                                    args=(q, extracted_command, tool_return))
+        p.start()
+        p.join(timeout=self.timeout)
+        if p.is_alive():
+            p.kill()
+            # return timeout due to some unknown error
+            tool_return.errmsg = f'Time out after {self.timeout} seconds.'
             tool_return.state = ActionStatusCode.API_ERROR
             return tool_return
-        try:
-            tool_return.result = dict(text=str(res))
-            tool_return.state = ActionStatusCode.SUCCESS
-        except Exception as e:
-            tool_return.errmsg = repr(e)
-            tool_return.type = self.name
-            tool_return.state = ActionStatusCode.API_ERROR
-        return tool_return
+        return q.get()
+
+    def _call(self, command: str) -> ActionReturn:
+        self.runtime = GenericRuntime()
+        if self.answer_from_stdout:
+            program_io = io.StringIO()
+            with redirect_stdout(program_io):
+                self.runtime.exec_code('\n'.join(command))
+            program_io.seek(0)
+            res = program_io.readlines()[-1]
+        elif self.answer_symbol:
+            self.runtime.exec_code('\n'.join(command))
+            res = self.runtime._global_vars[self.answer_symbol]
+        elif self.answer_expr:
+            self.runtime.exec_code('\n'.join(command))
+            res = self.runtime.eval_code(self.answer_expr)
+        else:
+            self.runtime.exec_code('\n'.join(command[:-1]))
+            res = True
+        return res
diff --git a/opencompass/runners/slurm.py b/opencompass/runners/slurm.py
index 363a21a9..1873e04e 100644
--- a/opencompass/runners/slurm.py
+++ b/opencompass/runners/slurm.py
@@ -4,7 +4,7 @@ import random
 import subprocess
 import time
 from functools import partial
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import mmengine
 from mmengine.config import ConfigDict
@@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
         qos (str): Slurm quality of service. Defaults to None.
         debug (bool): Whether to run in debug mode. Defaults to False.
         lark_bot_url (str): Lark bot url. Defaults to None.
+        extra_command (List, optional): Extra slurm command.
+            For example ['-c 12', '-w node1']. Defaults to None.
     """
 
     def __init__(self,
@@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
                  quotatype: str = None,
                  qos: str = None,
                  debug: bool = False,
-                 lark_bot_url: str = None):
+                 lark_bot_url: str = None,
+                 extra_command: Optional[List[str]] = None):
         super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
         self.max_num_workers = max_num_workers
         self.retry = retry
         self.partition = partition
         self.quotatype = quotatype
         self.qos = qos
+        if not extra_command:
+            extra_command = []
+        assert isinstance(extra_command, list)
+        self.extra_command = extra_command
 
     def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
         """Launch multiple tasks.
@@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
             tmpl += f' --qos={self.qos}'
         if num_gpus > 0:
             tmpl += f' --gres=gpu:{num_gpus}'
+        for extra_cmd in self.extra_command:
+            tmpl += f' {extra_cmd}'
         tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
         get_cmd = partial(task.get_command,
                           cfg_path=param_file,
diff --git a/opencompass/runners/slurm_sequential.py b/opencompass/runners/slurm_sequential.py
index 6bc11df2..ad36a973 100644
--- a/opencompass/runners/slurm_sequential.py
+++ b/opencompass/runners/slurm_sequential.py
@@ -6,7 +6,7 @@ import time
 import traceback
 from functools import partial
 from multiprocessing import Pipe, Pool
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import mmengine
 from mmengine.config import ConfigDict
@@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
         qos (str): Slurm quality of service. Defaults to None.
         debug (bool): Whether to run in debug mode. Defaults to False.
         lark_bot_url (str): Lark bot url. Defaults to None.
+        extra_command (List, optional): Extra slurm command.
+            For example ['-c 12', '-w node1']. Defaults to None.
     """
 
     def __init__(self,
@@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
                  quotatype: str = None,
                  qos: str = None,
                  debug: bool = False,
-                 lark_bot_url: str = None):
+                 lark_bot_url: str = None,
+                 extra_command: Optional[List[str]] = None):
         super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
         self.max_num_workers = max_num_workers
         self.retry = retry
@@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
         self.quotatype = quotatype
         self.qos = qos
         self.task_prefix = task_prefix
+        if not extra_command:
+            extra_command = []
+        assert isinstance(extra_command, list)
+        self.extra_command = extra_command
 
         logger = get_logger()
         if self.quotatype in ['spot', 'auto']:
@@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
             tmpl += f' --qos={self.qos}'
         if num_gpus > 0:
             tmpl += f' --gres=gpu:{num_gpus}'
+        for extra_cmd in self.extra_command:
+            tmpl += f' {extra_cmd}'
         tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
         get_cmd = partial(task.get_command,
                           cfg_path=param_file,
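
The `python_interpreter.py` change replaces the SIGALRM-based timeout with a child process: `signal.alarm` only works in the main thread, so it breaks when the action is invoked from a worker thread. Below is a minimal, self-contained sketch of the same pattern, independent of OpenCompass; `run_with_timeout`, `_worker`, and `slow_square` are hypothetical names used only for illustration.

```python
import multiprocessing
import time


def _worker(q, func, args):
    # Child process: run the function and report a result or an error,
    # so the parent's q.get() never blocks once the child has exited cleanly.
    try:
        q.put(('ok', func(*args)))
    except BaseException as e:
        q.put(('error', repr(e)))


def run_with_timeout(func, args=(), timeout=5):
    """Run `func(*args)` in a separate process, killing it on timeout."""
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=_worker, args=(q, func, args))
    p.start()
    p.join(timeout=timeout)
    if p.is_alive():
        # The child overran its budget: kill it and report a timeout,
        # mirroring the `p.is_alive()` branch in `__call__` above.
        p.kill()
        return ('error', f'Time out after {timeout} seconds.')
    return q.get()


def slow_square(x):
    time.sleep(2)
    return x * x


if __name__ == '__main__':
    print(run_with_timeout(slow_square, (4,), timeout=5))  # ('ok', 16)
    print(run_with_timeout(slow_square, (4,), timeout=1))  # ('error', ...)
```

The queue is what lets the parent recover both successful results and exceptions from the child; `join(timeout=...)` alone only says whether the process finished in time.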
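The runner changes add an `extra_command` list whose items are appended verbatim to the `srun` template after the partition, quota, and GPU flags. A hedged usage sketch, assuming the usual OpenCompass evaluation-config layout; the partition and node names are placeholders.

```python
from opencompass.partitioners import SizePartitioner
from opencompass.runners import SlurmRunner
from opencompass.tasks import OpenICLInferTask

infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=2000),
    runner=dict(
        type=SlurmRunner,
        max_num_workers=32,
        retry=2,
        partition='my_partition',  # placeholder partition name
        quotatype='auto',
        # Each entry is appended to the srun command as-is, so a flag and
        # its value can share one string: request 12 CPUs, pin to node1.
        extra_command=['-c 12', '-w node1'],
        task=dict(type=OpenICLInferTask),
    ),
)
```

Because the loop in the diff inserts each item with a single leading space and no quoting, an option and its value can travel together in one list item, exactly as in the docstring example.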