#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 INSPUR Inc. (inspur.com)
#
# @author: J.G. Chen
# @date: 2024/02/17
#
"""Utilities for loading, showing and saving configuration files.

Notes
- pyyaml:
    * scientific notation is only parsed as a float when written as
      ``a[eE][+-]b`` where ``a`` contains a decimal point and the exponent
      carries an explicit sign, see
      `pyyaml issues#173 <https://github.com/yaml/pyyaml/issues/173>`_;
    * alternatively, use ruamel.yaml instead;
- conf (HOCON):
    * variable substitution: the reference must not be quoted, e.g.
      ``{ b = "hello", a = ${b} world }``;
    * environment variables: ``{ a = ${HOME} }``.
"""
import json
import os
from pathlib import Path
from platform import python_version
from pprint import pformat
from typing import Dict
from typing import Union

import yaml
from omegaconf import DictConfig
from omegaconf import ListConfig
from omegaconf import OmegaConf
from packaging import version
from pyhocon import ConfigFactory
from pyhocon import HOCONConverter

from common import HAIRUO_ENV
from common import HairuoEnv

try:
    import torch
except ImportError:
    torch = None

try:
    from common.utils.logger import init_logger

    logger = init_logger(__name__).info
except ImportError:
    init_logger = None
    logger = print


class ConfigBase:
    """Load a config file (YAML / HOCON / JSON / torch ``.bin``) into an OmegaConf container.

    After loading, the section named by ``HAIRUO_ENV`` is merged into the top level.
    All sections named after ``HairuoEnv`` members are then removed.
    """

    # Resolved config. ``__init__`` assigns it on the instance. It is only annotated
    # (not assigned) at class level, so ``ConfigBase.config`` raises AttributeError.
    config: Union[DictConfig, ListConfig]

    def __init__(self, config_file: str = ""):
        """Load config from ``config_file``.

        Args:
            config_file(str): config file name, could be ``.yaml``/``.yml``, ``.conf``,
                ``.json`` or ``.bin``. When empty, the most recently modified file with
                one of those suffixes in the current working directory is used.

        Notes:
            A relative ``config_file`` is resolved against the current working
            directory (cwd); pass an absolute path to avoid ambiguity.

        Raises:
            FileNotFoundError: no usable config file was found.
        """
        self._file_path = self._resolve(config_file)
        self.config = self.load(self._file_path)

    @classmethod
    def _resolve(cls, config_file: str = "") -> Path:
        """Return the path of the config file to load.

        Args:
            config_file: explicit file name. When empty, the newest file in cwd with a
                supported suffix is picked.

        Raises:
            FileNotFoundError: the file does not exist, is not a regular file, or has
                an unsupported suffix.
        """
        cwd = Path.cwd()
        valid_extension = (".yaml", ".yml", ".conf", ".json", ".bin")

        if not config_file:
            # No file given: scan the current work dir, newest modification time first.
            files = filter(lambda f: f.is_file(), cwd.iterdir())
            logger(f"loading config file from dir: {cwd.resolve()}")
            for filename in sorted(files, key=lambda x: os.path.getmtime(x), reverse=True):
                if filename.suffix in valid_extension:
                    config_file = Path(cwd, filename)
                    break
        else:
            # Parse the config file that was specified.
            config_file = Path(config_file)

        # config_file is still "" here when the directory scan found nothing.
        if not (config_file and config_file.is_file() and config_file.suffix in valid_extension):
            raise FileNotFoundError(f"Config '{config_file}' not find!")

        logger(f"loading config from {config_file.resolve()}")
        return config_file

    @classmethod
    def show(cls, config: Union[Dict, DictConfig, ListConfig] = None) -> None:
        """Log ``config`` pretty-printed.

        Args:
            config: config to show. OmegaConf containers are converted to plain Python
                objects first. When omitted (``None``), the class-level ``config``
                attribute is used. ``__init__`` sets ``config`` only on the instance, so
                that fallback raises AttributeError unless a subclass assigns it.
        """
        # ``sort_dicts`` was added to pprint in Python 3.8; keep insertion order when available.
        kwargs = {} if version.parse(python_version()) < version.parse("3.8") else {"sort_dicts": False}
        # ``is None`` (not falsiness) so that an empty config is shown rather than replaced.
        if config is None:
            config = cls.config
        if isinstance(config, (DictConfig, ListConfig)):
            config = OmegaConf.to_object(config)
        logger(f"======= resolved config ======>\n{pformat(config, **kwargs)}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> Union[DictConfig, ListConfig]:
        """Load a config file and keep only the settings for the active env.

        Args:
            path: config file; the format is chosen by its suffix.

        Returns:
            The loaded OmegaConf container. For mappings, the ``HAIRUO_ENV`` section is
            merged into the top level (shallow update) and every ``HairuoEnv`` section
            is removed.

        Raises:
            RuntimeError: unsupported suffix, or ``.bin`` without ``torch`` installed.
        """
        path = Path(path)
        logger(f"loading config from: {path}")

        if path.suffix in [".yaml", ".yml"]:
            with path.open(encoding="utf-8") as f:
                # NOTE(review): config files are assumed trusted; yaml.SafeLoader is the
                # loader intended for untrusted input.
                config = OmegaConf.create(yaml.load(f, Loader=yaml.FullLoader), flags={"allow_objects": True})
        elif path.suffix == ".conf":
            config = ConfigFactory.parse_file(path.as_posix())
            # Round-trip through JSON text so OmegaConf receives plain containers.
            config = OmegaConf.create(HOCONConverter.to_json(config))
        elif path.suffix == ".json":
            with path.open(encoding="utf-8") as f:
                config = OmegaConf.create(json.load(f))
        elif path.suffix == ".bin":
            if torch is None:
                raise RuntimeError("`torch` required to load .bin config file.")
            config = OmegaConf.create(torch.load(path))
        else:
            raise RuntimeError("unsupported file format to load.")

        logger(f"parse configs for ENV: {HAIRUO_ENV}")
        # Keep configs for the given env only. Env sections can only exist in a mapping;
        # a top-level list has no ``update``/``pop(default=...)``.
        # TODO: update fields recursively
        if isinstance(config, DictConfig):
            config.update(config.get(HAIRUO_ENV, dict()))
            for env in HairuoEnv:
                config.pop(env, default=None)

        ConfigBase.show(config)
        return config

    @classmethod
    def save_config(cls, config: Union[dict, DictConfig, ListConfig], path: Union[str, Path]):
        """Save ``config`` to ``path``; the output format is chosen by the suffix of ``path``.

        Args:
            config: plain dict or OmegaConf container to save.
            path: destination file (``.yaml``/``.yml``, ``.conf``, ``.json`` or ``.bin``).

        Raises:
            RuntimeError: unsupported suffix, or ``.bin`` without ``torch`` installed.
        """
        path = Path(path)

        def convert_to_ct(_config):
            # HOCONConverter expects ConfigTree nodes: convert mappings/sequences recursively.
            if isinstance(_config, (dict, DictConfig)):
                return ConfigFactory.from_dict({k: convert_to_ct(v) for k, v in _config.items()})
            if isinstance(_config, (list, ListConfig)):
                return [convert_to_ct(i) for i in _config]
            return _config

        if path.suffix in [".yaml", ".yml"]:
            OmegaConf.save(config, path)
        elif path.suffix == ".conf":
            with open(path, "w", encoding="utf-8") as writer:
                writer.write(HOCONConverter.to_hocon(convert_to_ct(config)))
                writer.write("\n")
        elif path.suffix == ".json":
            # json cannot serialize OmegaConf containers; convert to plain dict/list first.
            data = OmegaConf.to_container(config, resolve=True) if isinstance(config, (DictConfig, ListConfig)) else config
            # ensure_ascii=False writes non-ASCII verbatim, so the file encoding must be explicit.
            with open(path, "w", encoding="utf-8") as writer:
                json.dump(data, writer, ensure_ascii=False, indent=2)
        elif path.suffix == ".bin":
            if torch is None:
                raise RuntimeError("`torch` required to save .bin config file.")
            torch.save(config, path)
        else:
            raise RuntimeError("unsupported file format.")
        logger(f"saving config to: {path}")

    def save(self, path: Union[str, Path] = "", key: str = ""):
        """Save the resolved config to ``path``, or to the path stored under ``key`` in the config.

        Args:
            path: target file or directory. If it is an existing directory, the original
                config file name is used inside it.
            key: used only when ``path`` is empty; ``self.config[key]`` supplies the path.

        Raises:
            RuntimeError: neither ``path`` nor a non-empty ``self.config[key]`` is available.
        """
        if not path:
            if key and self.config.get(key, ""):
                path = self.config.get(key)
            else:
                raise RuntimeError("target dir not found")

        path = Path(path)
        if path.is_dir():
            path = path / self._file_path.name
        self.save_config(self.config, path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="tool to convert config file format",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    args = parser.add_argument
    args("path", metavar="config_file", help="path to config file.")
    args("-f", "--format", type=str, default="conf", choices=["conf", "yaml", "json"], help="save conf in format.")
    args("-o", "--output", type=str, help="output config to file.")
    args("-s", "--show", action="store_true", help="just show the resolved config.")
    params = parser.parse_args()

    conf = ConfigBase.load(params.path)
    output = params.output
    if not output:
        # NOTE: when --format matches the input suffix, this overwrites the input file
        # with the env-resolved config.
        output = Path(params.path).with_suffix(f".{params.format.strip('.')}")
    if params.show:
        ConfigBase.show(conf)
    else:
        ConfigBase.save_config(conf, output)