219 lines
7.3 KiB
Python
219 lines
7.3 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
#
|
||
# Copyright @2024 INSPUR Inc. (inspur.com)
|
||
#
|
||
# @author: J.G. Chen <chenjianguo@inspur.com>
|
||
# @date: 2024/02/17
|
||
#
|
||
"""
|
||
对配置文件的一些操作
|
||
|
||
Notes
|
||
- pyyaml:
|
||
* 对科学计数法的支持有特殊要求,对于 a[eE][+-]b:a 必须有小数点,指数必须含正负号
|
||
`pyyaml issues#173 <https://github.com/yaml/pyyaml/issues/173#issuecomment-507918276>`_;
|
||
* 或者使用 ruamel.yaml 代替;
|
||
- conf(hocon):
|
||
* 变量替换:不能添加引号,`{ b = "hello"\n a = ${b} world }`;
|
||
* 环境变量:'{ a = ${HOME} }';
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
from pathlib import Path
|
||
from platform import python_version
|
||
from pprint import pformat
|
||
from typing import Dict
|
||
from typing import Union
|
||
|
||
import yaml
|
||
from omegaconf import DictConfig
|
||
from omegaconf import ListConfig
|
||
from omegaconf import OmegaConf
|
||
from packaging import version
|
||
from pyhocon import ConfigFactory
|
||
from pyhocon import HOCONConverter
|
||
|
||
from common import HAIRUO_ENV
|
||
from common import HairuoEnv
|
||
|
||
# `torch` is an optional dependency: when it cannot be imported the name is
# bound to None so later code can test `torch is not None` before using it.
try:
    import torch
except ImportError:
    torch = None

# `logger` is a plain callable taking one message string: the `.info` method of
# the logger returned by `init_logger` when `common.utils.logger` is importable,
# otherwise the builtin `print`.
try:
    from common.utils.logger import init_logger

    logger = init_logger(__name__).info
except ImportError:
    init_logger = None
    logger = print
|
||
|
||
|
||
class ConfigBase:
    """Load, display and save configuration files.

    Supported file types are YAML (`.yaml` / `.yml`), HOCON (`.conf`), JSON
    (`.json`) and torch-serialized objects (`.bin`). Loaded content is wrapped
    in an OmegaConf container.
    """

    # Parsed configuration. Only annotated here; the value is assigned per
    # instance in `__init__`.
    config: Union[DictConfig, ListConfig]

    def __init__(self, config_file: str = ""):
        """Load config from `config_file`.

        Args:
            config_file(str): config file name, could be `.yaml`, `.yml`, `.conf`,
                `.json` or `.bin`. When empty, the most recently modified
                config file in the current work dir is used.

        Notes:
            A relative `config_file` is resolved against the current work
            dir (cwd); pass an absolute path to avoid that.

        Raises:
            FileNotFoundError: if no usable config file can be located.
        """
        self._file_path = self._resolve(config_file)
        self.config = self.load(self._file_path)

    @classmethod
    def _resolve(cls, config_file: str = ""):
        """Return the path of the config file to load.

        Args:
            config_file(str): explicit path; when empty the current work dir
                is scanned and the most recently modified file with a
                supported extension is chosen.

        Returns:
            Path: path to an existing file with a supported extension.

        Raises:
            FileNotFoundError: if the file does not exist, is not a regular
                file, or has an unsupported extension.
        """
        cwd = Path.cwd()
        valid_extension = (".yaml", ".yml", ".conf", ".json", ".bin")

        if config_file:
            # parse config file specified.
            config_file = Path(config_file)
        else:
            # check config file in current work dir
            logger(f"loading config file from dir: {cwd.resolve()}")
            candidates = [p for p in cwd.iterdir() if p.is_file() and p.suffix in valid_extension]
            newest = max(candidates, key=os.path.getmtime, default=None)
            if newest is not None:
                config_file = newest

        # check config file; `config_file` is still "" when the scan found nothing
        if not (config_file and config_file.is_file() and config_file.suffix in valid_extension):
            raise FileNotFoundError(f"Config '{config_file}' not find!")

        logger(f"loading config from {config_file.resolve()}")
        return config_file

    @classmethod
    def show(cls, config: Union[Dict, DictConfig, ListConfig] = None):
        """Log `config` in pretty-printed form.

        Args:
            config: config to display. When omitted, a class-level `config`
                attribute is used if a subclass defines one.

        Raises:
            AttributeError: if `config` is omitted and the class has no
                `config` attribute to fall back to.
        """
        # `sort_dicts` is only accepted by `pformat` from Python 3.8 on.
        kwargs = {} if version.parse(python_version()) < version.parse("3.8") else {"sort_dicts": False}

        # Test for None rather than falsiness so that an empty config can be
        # shown. `config` is only an annotation on this class, so the fallback
        # must not assume the attribute exists.
        if config is None:
            config = getattr(cls, "config", None)
            if config is None:
                raise AttributeError(
                    f"no config given and `{cls.__name__}.config` is not set; pass `config` explicitly."
                )

        if isinstance(config, (DictConfig, ListConfig)):
            config = OmegaConf.to_object(config)

        # NOTE(review): `.encode('utf-8')` makes the log line contain the repr of a
        # bytes object (escaped newlines / non-ASCII); kept as-is because it may be
        # deliberate for ASCII-only log sinks -- confirm.
        logger(f"======= resolved config ======>\n{pformat(config, **kwargs).encode('utf-8')}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> Union[DictConfig, ListConfig]:
        """Load a config file and keep only the settings of the active env.

        For a mapping root, the section keyed by `HAIRUO_ENV` (if any) is
        merged into the top level, then every key listed in `HairuoEnv` is
        removed. The merge is shallow.

        Args:
            path: config file; the parser is chosen from the file suffix.

        Returns:
            The parsed config as an OmegaConf container.

        Raises:
            RuntimeError: if the suffix is unsupported, or a `.bin` file is
                given while `torch` is not installed.
        """
        path = Path(path)
        logger(f"loading config from: {path}")

        if path.suffix in [".yaml", ".yml"]:
            with path.open(encoding="utf-8") as f:
                # SECURITY NOTE: `yaml.FullLoader` is not meant for untrusted input;
                # only load config files from trusted sources.
                config = OmegaConf.create(yaml.load(f, Loader=yaml.FullLoader), flags={"allow_objects": True})
        elif path.suffix == ".conf":
            config = ConfigFactory.parse_file(path.as_posix())
            config = OmegaConf.create(HOCONConverter.to_json(config))
        elif path.suffix == ".json":
            # JSON text is UTF-8; do not depend on the platform default encoding.
            with path.open(encoding="utf-8") as f:
                config = OmegaConf.create(json.load(f))
        elif path.suffix == ".bin":
            if torch is None:
                raise RuntimeError("`torch` required to load .bin config file.")
            # SECURITY NOTE: `torch.load` unpickles data; only load trusted files.
            config = OmegaConf.create(torch.load(path))
        else:
            raise RuntimeError("unsupported file format to load.")

        logger(f"parse configs for ENV: {HAIRUO_ENV}")
        # keep configs for given env only; a list root has no env sections.
        # TODO: update fields recursively
        if isinstance(config, DictConfig):
            # `or {}` also covers an env section that is present but null.
            config.update(config.get(HAIRUO_ENV) or {})
            for env in HairuoEnv:
                config.pop(env, default=None)

        ConfigBase.show(config)
        return config

    @classmethod
    def _to_config_tree(cls, _config):
        """Recursively convert mappings to pyhocon ConfigTree objects."""
        if isinstance(_config, (dict, DictConfig)):
            return ConfigFactory.from_dict({k: cls._to_config_tree(v) for k, v in _config.items()})
        if isinstance(_config, (list, tuple, ListConfig)):
            return [cls._to_config_tree(item) for item in _config]
        return _config

    @classmethod
    def save_config(cls, config: Union[dict, DictConfig, ListConfig], path: Union[str, Path]):
        """Save `config` to given `path`; the format is chosen from the suffix.

        Text formats are serialized before the target file is opened, so a
        serialization failure does not leave a truncated file behind.

        Args:
            config: config to save.
            path: target file (`.yaml`, `.yml`, `.conf`, `.json` or `.bin`).

        Raises:
            RuntimeError: if the suffix is unsupported, or `.bin` is requested
                while `torch` is not installed.
        """
        path = Path(path)

        if path.suffix in [".yaml", ".yml"]:
            OmegaConf.save(config, path)
        elif path.suffix == ".conf":
            text = HOCONConverter.to_hocon(cls._to_config_tree(config))
            with path.open("w", encoding="utf-8") as writer:
                writer.write(text)
                writer.write("\n")
        elif path.suffix == ".json":
            # OmegaConf containers are not JSON serializable; convert them to
            # plain dict / list first.
            if isinstance(config, (DictConfig, ListConfig)):
                config = OmegaConf.to_container(config, resolve=True)
            text = json.dumps(config, ensure_ascii=False, indent=2)
            # `ensure_ascii=False` emits raw non-ASCII characters, so the file
            # encoding must be explicit.
            with path.open("w", encoding="utf-8") as writer:
                writer.write(text)
        elif path.suffix == ".bin":
            if torch is None:
                raise RuntimeError("`torch` required to save .bin config file.")
            torch.save(config, path)
        else:
            raise RuntimeError("unsupported file format.")
        logger(f"saving config to: {path}")

    def save(self, path: Union[str, Path] = "", key: str = ""):
        """Save the loaded config to `path`, or to the path stored under `key`.

        Args:
            path: target file or existing directory. For a directory, the
                name of the originally loaded config file is appended.
            key: config key whose value is used as the target when `path`
                is empty.

        Raises:
            RuntimeError: if neither `path` nor a non-empty value under `key`
                provides a target.
        """
        if not path:
            path = self.config.get(key, "") if key else ""
            if not path:
                raise RuntimeError("target dir not found")

        path = Path(path)
        if path.is_dir():
            path = path / self._file_path.name

        self.save_config(self.config, path)
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: load a config file, then either display the
    # resolved config or write it out in another format.
    import argparse

    cli = argparse.ArgumentParser(
        description="tool to convert config file format",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument("path", metavar="config_file", help="path to config file.")
    cli.add_argument(
        "-f",
        "--format",
        type=str,
        default="conf",
        choices=["conf", "yaml", "json"],
        help="save conf in format.",
    )
    cli.add_argument("-o", "--output", type=str, help="output config to file.")
    cli.add_argument("-s", "--show", action="store_true", help="just show the resolved config.")
    options = cli.parse_args()

    loaded = ConfigBase.load(options.path)

    # Without an explicit output, reuse the input path with its suffix swapped
    # for the requested format.
    target = options.output or Path(options.path).with_suffix(f".{options.format.strip('.')}")

    if options.show:
        ConfigBase.show(loaded)
    else:
        ConfigBase.save_config(loaded, target)
|