diff --git a/opencompass/utils/prompt.py b/opencompass/utils/prompt.py index a8ea5bf8..d07126f4 100644 --- a/opencompass/utils/prompt.py +++ b/opencompass/utils/prompt.py @@ -3,7 +3,7 @@ from __future__ import annotations import hashlib import json from copy import deepcopy -from typing import Dict, Union +from typing import Dict, List, Union from mmengine.config import ConfigDict @@ -24,15 +24,23 @@ def safe_format(input_str: str, **kwargs) -> str: return input_str -def get_prompt_hash(dataset_cfg: ConfigDict) -> str: +def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str: """Get the hash of the prompt configuration. Args: - dataset_cfg (ConfigDict): The dataset configuration. + dataset_cfg (ConfigDict or list[ConfigDict]): The dataset + configuration. Returns: str: The hash of the prompt configuration. """ + if isinstance(dataset_cfg, list): + if len(dataset_cfg) == 1: + dataset_cfg = dataset_cfg[0] + else: + hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg]) + hash_object = hashlib.sha256(hashes.encode()) + return hash_object.hexdigest() if 'reader_cfg' in dataset_cfg.infer_cfg: # new config reader_cfg = dict(type='DatasetReader', @@ -48,7 +56,7 @@ def get_prompt_hash(dataset_cfg: ConfigDict) -> str: 'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split for k, v in dataset_cfg.infer_cfg.items(): dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1] - d_json = json.dumps(dataset_cfg.infer_cfg, sort_keys=True) + d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True) hash_object = hashlib.sha256(d_json.encode()) return hash_object.hexdigest()