scene-digit-human/common/config.py
2024-12-09 08:51:09 +08:00

219 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 INSPUR Inc. (inspur.com)
#
# @author: J.G. Chen <chenjianguo@inspur.com>
# @date: 2024/02/17
#
"""
对配置文件的一些操作
Notes
- pyyaml:
* 对科学计数法的支持有特殊要求:对于 a[eE][+-]b 形式,a 必须含小数点,指数必须带正负号(如 1.0e-5 解析为浮点数,而 1e-5、1.0e5 会被解析为字符串);
`pyyaml issues#173 <https://github.com/yaml/pyyaml/issues/173#issuecomment-507918276>`_;
* 或者使用 ruamel.yaml 代替;
- conf(hocon):
* 变量替换:不能添加引号,`{ b = "hello"\n a = ${b} world }`;
* 环境变量:'{ a = ${HOME} }';
"""
import json
import os
from pathlib import Path
from platform import python_version
from pprint import pformat
from typing import Dict
from typing import Union
import yaml
from omegaconf import DictConfig
from omegaconf import ListConfig
from omegaconf import OmegaConf
from packaging import version
from pyhocon import ConfigFactory
from pyhocon import HOCONConverter
from common import HAIRUO_ENV
from common import HairuoEnv
# Optional dependency: `torch` is only needed to read/write `.bin` configs.
# It is set to None when unavailable and checked at the point of use.
try:
    import torch
except ImportError:
    torch = None
# Prefer the project logger; fall back to plain `print` when it cannot be imported.
# Either way `logger` is used below as a callable taking a single message string.
try:
    from common.utils.logger import init_logger
    logger = init_logger(__name__).info
except ImportError:
    init_logger = None
    logger = print
class ConfigBase:
    """Load, show and save config files.

    The file format is chosen by suffix: ``.yaml``/``.yml``, ``.conf`` (HOCON), ``.json``, or ``.bin``
    (written with ``torch.save``; requires ``torch``).
    """

    # parsed config of an instance; set in ``__init__``.
    config: Union[DictConfig, ListConfig]

    # suffixes accepted when resolving a config file.
    _VALID_EXTENSIONS = (".yaml", ".yml", ".conf", ".json", ".bin")

    def __init__(self, config_file: str = ""):
        """Load config from ``config_file``.

        Args:
            config_file: config file name: ``.yaml``/``.yml``, ``.conf``, ``.json`` or ``.bin``.
                A relative path is resolved against the current work dir (cwd). When empty, the most
                recently modified config file in the cwd is used.

        Raises:
            FileNotFoundError: if no usable config file is found.
        """
        self._file_path = self._resolve(config_file)
        self.config = self.load(self._file_path)

    @classmethod
    def _resolve(cls, config_file: Union[str, Path] = "") -> Path:
        """Locate the config file to load.

        Args:
            config_file: explicit config path. When empty, the newest file (by mtime) in the cwd
                with a supported suffix is used.

        Returns:
            Path of an existing file with a supported suffix.

        Raises:
            FileNotFoundError: if no such file exists.
        """
        cwd = Path.cwd()
        if not config_file:
            # no file given: pick the most recently modified config file in the current work dir.
            logger(f"loading config file from dir: {cwd.resolve()}")
            candidates = [f for f in cwd.iterdir() if f.is_file() and f.suffix in cls._VALID_EXTENSIONS]
            config_file = max(candidates, key=os.path.getmtime, default="")
        else:
            # parse config file specified.
            config_file = Path(config_file)
        if not (config_file and config_file.is_file() and config_file.suffix in cls._VALID_EXTENSIONS):
            raise FileNotFoundError(f"Config '{config_file}' not find!")
        logger(f"loading config from {config_file.resolve()}")
        return config_file

    @classmethod
    def show(cls, config: Union[Dict, DictConfig, ListConfig, None] = None):
        """Log ``config`` pretty-printed (insertion order kept on Python >= 3.8).

        Args:
            config: config to show. Only when it is ``None`` does this fall back to ``cls.config``,
                which exists only if a subclass defines it at class level. An empty config is shown
                as-is instead of triggering that fallback.
        """
        # `sort_dicts` was added to pprint in Python 3.8.
        kwargs = {} if version.parse(python_version()) < version.parse("3.8") else {"sort_dicts": False}
        if config is None:
            config = cls.config
        if isinstance(config, (DictConfig, ListConfig)):
            config = OmegaConf.to_object(config)
        # log the str itself: interpolating `.encode(...)` bytes would print their repr, collapsing
        # the multi-line layout into literal "\n" sequences and escaping non-ASCII text.
        logger(f"======= resolved config ======>\n{pformat(config, **kwargs)}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> Union[DictConfig, ListConfig]:
        """Load a config file and keep only the settings of the active env.

        Args:
            path: config file; its suffix selects the parser.

        Returns:
            The parsed config. For a mapping config, the section named by ``HAIRUO_ENV`` (if present)
            is merged into the top level key by key, and every ``HairuoEnv`` section is removed.

        Raises:
            RuntimeError: on an unsupported suffix, or a ``.bin`` file without ``torch`` installed.
        """
        path = Path(path)
        logger(f"loading config from: {path}")
        suffix = path.suffix
        if suffix in (".yaml", ".yml"):
            with path.open(encoding="utf-8") as f:
                # NOTE(review): FullLoader accepts more tags than SafeLoader; only load trusted files.
                config = OmegaConf.create(yaml.load(f, Loader=yaml.FullLoader), flags={"allow_objects": True})
        elif suffix == ".conf":
            tree = ConfigFactory.parse_file(path.as_posix())
            config = OmegaConf.create(HOCONConverter.to_json(tree))
        elif suffix == ".json":
            # read as UTF-8 explicitly instead of relying on the platform default encoding.
            with path.open(encoding="utf-8") as f:
                config = OmegaConf.create(json.load(f))
        elif suffix == ".bin":
            if torch is None:
                raise RuntimeError("`torch` required to load .bin config file.")
            # NOTE(review): torch.load unpickles the file; only load trusted files.
            config = OmegaConf.create(torch.load(path))
        else:
            raise RuntimeError("unsupported file format to load.")

        logger(f"parse configs for ENV: {HAIRUO_ENV}")
        # env sections only make sense for a mapping; a top-level list is returned as parsed.
        if isinstance(config, DictConfig):
            # keep configs for the given env only.
            # TODO: update fields recursively (`update` replaces whole top-level values).
            config.update(config.get(HAIRUO_ENV, dict()))
            for env in HairuoEnv:
                config.pop(env, default=None)
        ConfigBase.show(config)
        return config

    @classmethod
    def save_config(cls, config: Union[dict, DictConfig, ListConfig], path: Union[str, Path]):
        """Save ``config`` to ``path``; the suffix of ``path`` selects the output format.

        Args:
            config: plain dict or OmegaConf container to save.
            path: target file (``.yaml``/``.yml``, ``.conf``, ``.json`` or ``.bin``).

        Raises:
            RuntimeError: on an unsupported suffix, or ``.bin`` without ``torch`` installed.
        """
        path = Path(path)

        def convert_to_ct(_config):
            # convert mappings to pyhocon ConfigTree recursively, descending into lists too, so
            # dicts nested inside lists are also rendered as HOCON objects.
            if isinstance(_config, (dict, DictConfig)):
                return ConfigFactory.from_dict({k: convert_to_ct(v) for k, v in _config.items()})
            if isinstance(_config, (list, tuple, ListConfig)):
                return [convert_to_ct(i) for i in _config]
            return _config

        suffix = path.suffix
        if suffix in (".yaml", ".yml"):
            OmegaConf.save(config, path)
        elif suffix == ".conf":
            with path.open("w", encoding="utf-8") as writer:
                writer.write(HOCONConverter.to_hocon(convert_to_ct(config)))
                writer.write("\n")
        elif suffix == ".json":
            # json cannot serialize OmegaConf containers: convert to plain dict/list first.
            # `${...}` interpolations are written unresolved and are re-interpreted when loaded back.
            if isinstance(config, (DictConfig, ListConfig)):
                config = OmegaConf.to_container(config)
            # ensure_ascii=False emits raw non-ASCII text, so the file must be opened as UTF-8.
            with path.open("w", encoding="utf-8") as writer:
                json.dump(config, writer, ensure_ascii=False, indent=2)
        elif suffix == ".bin":
            if torch is None:
                raise RuntimeError("`torch` required to save .bin config file.")
            torch.save(config, path)
        else:
            raise RuntimeError("unsupported file format.")
        logger(f"saving config to: {path}")

    def save(self, path: Union[str, Path] = "", key: str = ""):
        """Save the resolved config to ``path``, or to the path stored under ``key`` in the config.

        Args:
            path: target file or directory. If it is an existing directory, the config is written into
                it under the name of the originally loaded file (hence in the same format).
            key: config key whose value is used as the target when ``path`` is empty.

        Raises:
            RuntimeError: if neither ``path`` nor a non-empty ``config[key]`` is available.
        """
        if not path:
            if not (key and self.config.get(key, "")):
                raise RuntimeError("target dir not found")
            path = self.config.get(key)
        path = Path(path)
        if path.is_dir():
            path = path / self._file_path.name
        self.save_config(self.config, path)
if __name__ == "__main__":
    # CLI: convert a config file to another format (or just show it after env resolution).
    import argparse

    parser = argparse.ArgumentParser(
        description="tool to convert config file format",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    args = parser.add_argument
    args("path", metavar="config_file", help="path to config file.")
    args("-f", "--format", type=str, default="conf", choices=["conf", "yaml", "json"], help="save conf in format.")
    args("-o", "--output", type=str, help="output config to file.")
    args("-s", "--show", action="store_true", help="just show the resolved config.")
    params = parser.parse_args()

    conf = ConfigBase.load(params.path)
    if params.show:
        ConfigBase.show(conf)
    else:
        if params.output:
            output = Path(params.output)
        else:
            output = Path(params.path).with_suffix(f".{params.format.strip('.')}")
            # `load` drops every non-active env section, so writing the result back over the input
            # (e.g. the default `-f conf` on a `.conf` file) would silently lose data.
            if output.resolve() == Path(params.path).resolve():
                parser.error(
                    f"refusing to overwrite input '{params.path}'; pass -o/--output or a different -f/--format."
                )
        ConfigBase.save_config(conf, output)