[Feat] Support local runner for windows (#515)

This commit is contained in:
Hubert 2023-10-27 17:16:22 +08:00 committed by GitHub
parent df07391ed8
commit 6f07af3039
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 6 deletions

View File

@ -146,9 +146,17 @@ class LocalRunner(BaseRunner):
try: try:
task.cfg.dump(param_file) task.cfg.dump(param_file)
# Build up slurm command # Build up local command
tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids) import sys
tmpl += ' {task_cmd}' if sys.platform == 'win32': # Always return win32 for Windows
# use command in Windows format
tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(
str(i) for i in gpu_ids)
tmpl += ' & {task_cmd}'
else:
tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(
str(i) for i in gpu_ids)
tmpl += ' {task_cmd}'
get_cmd = partial(task.get_command, get_cmd = partial(task.get_command,
cfg_path=param_file, cfg_path=param_file,
template=tmpl) template=tmpl)

View File

@ -4,6 +4,7 @@ import os.path as osp
import time import time
from collections import Counter from collections import Counter
from inspect import signature from inspect import signature
from shutil import which
from typing import Optional from typing import Optional
import mmengine import mmengine
@ -37,7 +38,8 @@ class OpenICLEvalTask(BaseTask):
def get_command(self, cfg_path, template): def get_command(self, cfg_path, template):
script_path = __file__ script_path = __file__
command = f'python3 {script_path} {cfg_path}' python = 'python3' if which('python3') else 'python'
command = f'{python} {script_path} {cfg_path}'
return template.format(task_cmd=command) return template.format(task_cmd=command)
def run(self): def run(self):

View File

@ -2,6 +2,7 @@ import argparse
import os.path as osp import os.path as osp
import random import random
import time import time
from shutil import which
from typing import Any from typing import Any
from mmengine.config import Config, ConfigDict from mmengine.config import Config, ConfigDict
@ -48,7 +49,8 @@ class OpenICLInferTask(BaseTask):
f'--nproc_per_node {self.num_procs} ' f'--nproc_per_node {self.num_procs} '
f'{script_path} {cfg_path}') f'{script_path} {cfg_path}')
else: else:
command = f'python {script_path} {cfg_path}' python = 'python3' if which('python3') else 'python'
command = f'{python} {script_path} {cfg_path}'
return template.format(task_cmd=command) return template.format(task_cmd=command)

View File

@ -1,4 +1,12 @@
import curses import sys
if sys.platform == 'win32': # Always return win32 for Windows
# curses is not supported on Windows
# If you want to use this function in Windows platform
# you can try `windows_curses` module by yourself
curses = None
else:
import curses
class Menu: class Menu: