[Enhancement] Update prompt hash computation (#2)

This commit is contained in:
Tong Gao 2023-07-05 18:29:07 +08:00 committed by GitHub
parent 16e759b996
commit 719ba34d1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import hashlib import hashlib
import json import json
from copy import deepcopy from copy import deepcopy
from typing import Dict, Union from typing import Dict, List, Union
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
@ -24,15 +24,23 @@ def safe_format(input_str: str, **kwargs) -> str:
return input_str return input_str
def get_prompt_hash(dataset_cfg: ConfigDict) -> str: def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
"""Get the hash of the prompt configuration. """Get the hash of the prompt configuration.
Args: Args:
dataset_cfg (ConfigDict): The dataset configuration. dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
configuration.
Returns: Returns:
str: The hash of the prompt configuration. str: The hash of the prompt configuration.
""" """
if isinstance(dataset_cfg, list):
if len(dataset_cfg) == 1:
dataset_cfg = dataset_cfg[0]
else:
hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
hash_object = hashlib.sha256(hashes.encode())
return hash_object.hexdigest()
if 'reader_cfg' in dataset_cfg.infer_cfg: if 'reader_cfg' in dataset_cfg.infer_cfg:
# new config # new config
reader_cfg = dict(type='DatasetReader', reader_cfg = dict(type='DatasetReader',
@ -48,7 +56,7 @@ def get_prompt_hash(dataset_cfg: ConfigDict) -> str:
'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split 'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
for k, v in dataset_cfg.infer_cfg.items(): for k, v in dataset_cfg.infer_cfg.items():
dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1] dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
d_json = json.dumps(dataset_cfg.infer_cfg, sort_keys=True) d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
hash_object = hashlib.sha256(d_json.encode()) hash_object = hashlib.sha256(d_json.encode())
return hash_object.hexdigest() return hash_object.hexdigest()