[Enhancement] Update prompt hash computation (#2)

This commit is contained in:
Tong Gao 2023-07-05 18:29:07 +08:00 committed by GitHub
parent 16e759b996
commit 719ba34d1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import hashlib
import json
from copy import deepcopy
from typing import Dict, Union
from typing import Dict, List, Union
from mmengine.config import ConfigDict
@@ -24,15 +24,23 @@ def safe_format(input_str: str, **kwargs) -> str:
return input_str
def get_prompt_hash(dataset_cfg: ConfigDict) -> str:
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
"""Get the hash of the prompt configuration.
Args:
dataset_cfg (ConfigDict): The dataset configuration.
dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
configuration.
Returns:
str: The hash of the prompt configuration.
"""
if isinstance(dataset_cfg, list):
if len(dataset_cfg) == 1:
dataset_cfg = dataset_cfg[0]
else:
hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
hash_object = hashlib.sha256(hashes.encode())
return hash_object.hexdigest()
if 'reader_cfg' in dataset_cfg.infer_cfg:
# new config
reader_cfg = dict(type='DatasetReader',
@@ -48,7 +56,7 @@ def get_prompt_hash(dataset_cfg: ConfigDict) -> str:
'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
for k, v in dataset_cfg.infer_cfg.items():
dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
d_json = json.dumps(dataset_cfg.infer_cfg, sort_keys=True)
d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
hash_object = hashlib.sha256(d_json.encode())
return hash_object.hexdigest()