From 719ba34d1b28f1166c648a7205c055b69bee08b1 Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Wed, 5 Jul 2023 18:29:07 +0800 Subject: [PATCH] [Enhancement] Update prompt hash computation (#2) --- opencompass/utils/prompt.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/opencompass/utils/prompt.py b/opencompass/utils/prompt.py index a8ea5bf8..d07126f4 100644 --- a/opencompass/utils/prompt.py +++ b/opencompass/utils/prompt.py @@ -3,7 +3,7 @@ from __future__ import annotations import hashlib import json from copy import deepcopy -from typing import Dict, Union +from typing import Dict, List, Union from mmengine.config import ConfigDict @@ -24,15 +24,23 @@ def safe_format(input_str: str, **kwargs) -> str: return input_str -def get_prompt_hash(dataset_cfg: ConfigDict) -> str: +def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str: """Get the hash of the prompt configuration. Args: - dataset_cfg (ConfigDict): The dataset configuration. + dataset_cfg (ConfigDict or list[ConfigDict]): The dataset + configuration. Returns: str: The hash of the prompt configuration. """ + if isinstance(dataset_cfg, list): + if len(dataset_cfg) == 1: + dataset_cfg = dataset_cfg[0] + else: + hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg]) + hash_object = hashlib.sha256(hashes.encode()) + return hash_object.hexdigest() if 'reader_cfg' in dataset_cfg.infer_cfg: # new config reader_cfg = dict(type='DatasetReader', @@ -48,7 +56,7 @@ def get_prompt_hash(dataset_cfg: ConfigDict) -> str: 'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split for k, v in dataset_cfg.infer_cfg.items(): dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1] - d_json = json.dumps(dataset_cfg.infer_cfg, sort_keys=True) + d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True) hash_object = hashlib.sha256(d_json.encode()) return hash_object.hexdigest()