diff --git a/opencompass/runners/local.py b/opencompass/runners/local.py index 81eede41..7959f6df 100644 --- a/opencompass/runners/local.py +++ b/opencompass/runners/local.py @@ -146,9 +146,17 @@ class LocalRunner(BaseRunner): try: task.cfg.dump(param_file) - # Build up slurm command - tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids) - tmpl += ' {task_cmd}' + # Build up local command + import sys + if sys.platform == 'win32': # Always return win32 for Windows + # use command in Windows format + tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join( + str(i) for i in gpu_ids) + tmpl += ' & {task_cmd}' + else: + tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join( + str(i) for i in gpu_ids) + tmpl += ' {task_cmd}' get_cmd = partial(task.get_command, cfg_path=param_file, template=tmpl) diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index 3fa3f67e..5538a517 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -4,6 +4,7 @@ import os.path as osp import time from collections import Counter from inspect import signature +from shutil import which from typing import Optional import mmengine @@ -37,7 +38,8 @@ class OpenICLEvalTask(BaseTask): def get_command(self, cfg_path, template): script_path = __file__ - command = f'python3 {script_path} {cfg_path}' + python = 'python3' if which('python3') else 'python' + command = f'{python} {script_path} {cfg_path}' return template.format(task_cmd=command) def run(self): diff --git a/opencompass/tasks/openicl_infer.py b/opencompass/tasks/openicl_infer.py index 195a0bd9..c2398433 100644 --- a/opencompass/tasks/openicl_infer.py +++ b/opencompass/tasks/openicl_infer.py @@ -2,6 +2,7 @@ import argparse import os.path as osp import random import time +from shutil import which from typing import Any from mmengine.config import Config, ConfigDict @@ -48,7 +49,8 @@ class OpenICLInferTask(BaseTask): f'--nproc_per_node {self.num_procs} ' f'{script_path} {cfg_path}') else: - command = f'python {script_path} {cfg_path}' + python = 'python3' if which('python3') else 'python' + command = f'{python} {script_path} {cfg_path}' return template.format(task_cmd=command) diff --git a/opencompass/utils/menu.py b/opencompass/utils/menu.py index 1d5007d3..3ff1fe25 100644 --- a/opencompass/utils/menu.py +++ b/opencompass/utils/menu.py @@ -1,4 +1,12 @@ -import curses +import sys + +if sys.platform == 'win32': # Always return win32 for Windows + # curses is not supported on Windows + # If you want to use this function in Windows platform + # you can try `windows_curses` module by yourself + curses = None +else: + import curses class Menu: