mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feat] Support local runner for windows (#515)
This commit is contained in:
parent
df07391ed8
commit
6f07af3039
@ -146,9 +146,17 @@ class LocalRunner(BaseRunner):
|
|||||||
try:
|
try:
|
||||||
task.cfg.dump(param_file)
|
task.cfg.dump(param_file)
|
||||||
|
|
||||||
# Build up slurm command
|
# Build up local command
|
||||||
tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(str(i) for i in gpu_ids)
|
import sys
|
||||||
tmpl += ' {task_cmd}'
|
if sys.platform == 'win32': # Always return win32 for Windows
|
||||||
|
# use command in Windows format
|
||||||
|
tmpl = 'set CUDA_VISIBLE_DEVICES=' + ','.join(
|
||||||
|
str(i) for i in gpu_ids)
|
||||||
|
tmpl += ' & {task_cmd}'
|
||||||
|
else:
|
||||||
|
tmpl = 'CUDA_VISIBLE_DEVICES=' + ','.join(
|
||||||
|
str(i) for i in gpu_ids)
|
||||||
|
tmpl += ' {task_cmd}'
|
||||||
get_cmd = partial(task.get_command,
|
get_cmd = partial(task.get_command,
|
||||||
cfg_path=param_file,
|
cfg_path=param_file,
|
||||||
template=tmpl)
|
template=tmpl)
|
||||||
|
@ -4,6 +4,7 @@ import os.path as osp
|
|||||||
import time
|
import time
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
from shutil import which
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import mmengine
|
import mmengine
|
||||||
@ -37,7 +38,8 @@ class OpenICLEvalTask(BaseTask):
|
|||||||
|
|
||||||
def get_command(self, cfg_path, template):
|
def get_command(self, cfg_path, template):
|
||||||
script_path = __file__
|
script_path = __file__
|
||||||
command = f'python3 {script_path} {cfg_path}'
|
python = 'python3' if which('python3') else 'python'
|
||||||
|
command = f'{python} {script_path} {cfg_path}'
|
||||||
return template.format(task_cmd=command)
|
return template.format(task_cmd=command)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
@ -2,6 +2,7 @@ import argparse
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from shutil import which
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from mmengine.config import Config, ConfigDict
|
from mmengine.config import Config, ConfigDict
|
||||||
@ -48,7 +49,8 @@ class OpenICLInferTask(BaseTask):
|
|||||||
f'--nproc_per_node {self.num_procs} '
|
f'--nproc_per_node {self.num_procs} '
|
||||||
f'{script_path} {cfg_path}')
|
f'{script_path} {cfg_path}')
|
||||||
else:
|
else:
|
||||||
command = f'python {script_path} {cfg_path}'
|
python = 'python3' if which('python3') else 'python'
|
||||||
|
command = f'{python} {script_path} {cfg_path}'
|
||||||
|
|
||||||
return template.format(task_cmd=command)
|
return template.format(task_cmd=command)
|
||||||
|
|
||||||
|
@ -1,4 +1,12 @@
|
|||||||
import curses
|
import sys
|
||||||
|
|
||||||
|
if sys.platform == 'win32': # Always return win32 for Windows
|
||||||
|
# curses is not supported on Windows
|
||||||
|
# If you want to use this function in Windows platform
|
||||||
|
# you can try `windows_curses` module by yourself
|
||||||
|
curses = None
|
||||||
|
else:
|
||||||
|
import curses
|
||||||
|
|
||||||
|
|
||||||
class Menu:
|
class Menu:
|
||||||
|
Loading…
Reference in New Issue
Block a user