mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Enhancement] Update prompt hash computation (#2)
This commit is contained in:
parent
16e759b996
commit
719ba34d1b
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from mmengine.config import ConfigDict
|
||||
|
||||
@ -24,15 +24,23 @@ def safe_format(input_str: str, **kwargs) -> str:
|
||||
return input_str
|
||||
|
||||
|
||||
def get_prompt_hash(dataset_cfg: ConfigDict) -> str:
|
||||
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
|
||||
"""Get the hash of the prompt configuration.
|
||||
|
||||
Args:
|
||||
dataset_cfg (ConfigDict): The dataset configuration.
|
||||
dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
|
||||
configuration.
|
||||
|
||||
Returns:
|
||||
str: The hash of the prompt configuration.
|
||||
"""
|
||||
if isinstance(dataset_cfg, list):
|
||||
if len(dataset_cfg) == 1:
|
||||
dataset_cfg = dataset_cfg[0]
|
||||
else:
|
||||
hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
|
||||
hash_object = hashlib.sha256(hashes.encode())
|
||||
return hash_object.hexdigest()
|
||||
if 'reader_cfg' in dataset_cfg.infer_cfg:
|
||||
# new config
|
||||
reader_cfg = dict(type='DatasetReader',
|
||||
@ -48,7 +56,7 @@ def get_prompt_hash(dataset_cfg: ConfigDict) -> str:
|
||||
'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
|
||||
for k, v in dataset_cfg.infer_cfg.items():
|
||||
dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
|
||||
d_json = json.dumps(dataset_cfg.infer_cfg, sort_keys=True)
|
||||
d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
|
||||
hash_object = hashlib.sha256(d_json.encode())
|
||||
return hash_object.hexdigest()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user