Mirror of https://github.com/open-compass/opencompass.git, synced 2025-05-30 16:03:24 +08:00

146 lines · 5.2 KiB · Python · Executable File

#!/usr/bin/env python3
import argparse
import glob
import hashlib
import json
import os
import re
from multiprocessing import Pool
from typing import List, Union

from mmengine.config import Config, ConfigDict

# from opencompass.utils import get_prompt_hash


# copied from opencompass.utils.get_prompt_hash, for easy use in ci
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Get the hash of the prompt configuration.

    Args:
        dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
            configuration.

    Returns:
        str: The hash of the prompt configuration.
    """
    if isinstance(dataset_cfg, list):
        if len(dataset_cfg) == 1:
            dataset_cfg = dataset_cfg[0]
        else:
            hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
            hash_object = hashlib.sha256(hashes.encode())
            return hash_object.hexdigest()
    # for custom datasets
    if 'infer_cfg' not in dataset_cfg:
        dataset_cfg.pop('abbr', '')
        dataset_cfg.pop('path', '')
        d_json = json.dumps(dataset_cfg.to_dict(), sort_keys=True)
        hash_object = hashlib.sha256(d_json.encode())
        return hash_object.hexdigest()
    # for regular datasets
    if 'reader_cfg' in dataset_cfg.infer_cfg:
        # new config
        reader_cfg = dict(type='DatasetReader',
                          input_columns=dataset_cfg.reader_cfg.input_columns,
                          output_column=dataset_cfg.reader_cfg.output_column)
        dataset_cfg.infer_cfg.reader = reader_cfg
        if 'train_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'index_split'] = dataset_cfg.infer_cfg['reader_cfg'][
                    'train_split']
        if 'test_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
        for k, v in dataset_cfg.infer_cfg.items():
            dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
    # A compromise for the hash consistency
    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
    d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
    hash_object = hashlib.sha256(d_json.encode())
    return hash_object.hexdigest()
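
# Illustrative sketch of how this helper is used (the variable name and values
# below are hypothetical): a dataset config file typically defines a list such as
#   demo_datasets = [dict(abbr='demo', path='...', reader_cfg=..., infer_cfg=...)]
# and get_prompt_hash(demo_datasets) returns a sha256 hex digest derived from the
# normalized infer_cfg (or from the whole dict, minus abbr/path, for custom
# datasets without infer_cfg), so any change to the prompt setup changes the hash.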


# Assuming get_hash is a function that computes the hash of a file
# from get_hash import get_hash
def get_hash(path):
    cfg = Config.fromfile(path)
    for k in cfg.keys():
        if k.endswith('_datasets'):
            return get_prompt_hash(cfg[k])[:6]
    print(f'Could not find *_datasets in {path}')
    return None
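
# The 6-character prefix returned above is the hash suffix that dataset config
# filenames are expected to carry, e.g. a hypothetical demo_gen_012345.py;
# check_and_rename below compares it against the hash embedded in the name.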


def check_and_rename(filepath):
    base_name = os.path.basename(filepath)
    match = re.match(r'(.*)_(gen|ppl|ll|mixed)_(.*).py', base_name)
    if match:
        dataset, mode, old_hash = match.groups()
        try:
            new_hash = get_hash(filepath)
        except Exception:
            print(f'Failed to get hash for {filepath}')
            raise ModuleNotFoundError
        if not new_hash:
            return None, None
        if old_hash != new_hash:
            new_name = f'{dataset}_{mode}_{new_hash}.py'
            new_file = os.path.join(os.path.dirname(filepath), new_name)
            print(f'Rename {filepath} to {new_file}')
            return filepath, new_file
    return None, None
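
# Sketch of the intended behaviour (file names and hashes are hypothetical):
#   check_and_rename('configs/datasets/demo/demo_gen_012345.py')
# recomputes the prompt hash; if it came back as 'abcdef', the function returns
#   ('configs/datasets/demo/demo_gen_012345.py',
#    'configs/datasets/demo/demo_gen_abcdef.py')
# and returns (None, None) when the hash is unchanged or the name does not match.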


# def update_imports(data):
#     python_file, name_pairs = data
#     for filepath, new_file in name_pairs:
#         old_name = os.path.basename(filepath)[:-3]
#         new_name = os.path.basename(new_file)[:-3]
#         if not os.path.exists(python_file):
#             return
#         with open(python_file, 'r') as file:
#             filedata = file.read()
#         # Replace the old name with new name
#         new_data = filedata.replace(old_name, new_name)
#         if filedata != new_data:
#             with open(python_file, 'w') as file:
#                 file.write(new_data)
#             # print(f"Updated imports in {python_file}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('python_files', nargs='*')
    # Could be opencompass/configs/datasets and configs/datasets
    parser.add_argument('--root_folder', default='configs/datasets')
    args = parser.parse_args()

    root_folder = args.root_folder
    if args.python_files:
        python_files = [
            i for i in args.python_files if i.startswith(root_folder)
        ]
    else:
        python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)

    # Use multiprocessing to speed up the check and rename process
    with Pool(16) as p:
        name_pairs = p.map(check_and_rename, python_files)
    name_pairs = [pair for pair in name_pairs if pair[0] is not None]
    if not name_pairs:
        return
    with Pool(16) as p:
        p.starmap(os.rename, name_pairs)
    # root_folder = 'configs'
    # python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)
    # update_data = [(python_file, name_pairs) for python_file in python_files]
    # with Pool(16) as p:
    #     p.map(update_imports, update_data)


if __name__ == '__main__':
    main()
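
# Usage sketch (script path and file names are illustrative; run from the repo
# root so the default --root_folder of configs/datasets resolves):
#   python <this_script>.py                      # scan configs/datasets recursively
#   python <this_script>.py configs/datasets/demo/demo_gen_012345.py
#   python <this_script>.py --root_folder opencompass/configs/datasets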