2023-12-13 19:59:30 +08:00
|
|
|
import argparse
|
|
|
|
import csv
|
|
|
|
import json
|
|
|
|
import os
|
2023-12-15 15:07:25 +08:00
|
|
|
from glob import glob
|
2023-12-13 19:59:30 +08:00
|
|
|
|
2023-12-15 15:07:25 +08:00
|
|
|
from tqdm import tqdm
|
2023-12-13 19:59:30 +08:00
|
|
|
|
|
|
|
|
2023-12-15 15:07:25 +08:00
|
|
|
def _collect_prediction_files(model_folder):
    """Return one model's prediction files: split files by index, else the single file.

    Split files are named ``alignment_bench_<index>.json`` and are returned in
    ascending integer ``<index>`` order; files whose last ``_`` segment is not
    a decimal integer are ignored. When no split files exist,
    ``alignment_bench.json`` is returned if present, otherwise ``[]``.
    """
    indexed_paths = []
    for path in glob(os.path.join(model_folder, 'alignment_bench_*.json')):
        stem = os.path.splitext(os.path.basename(path))[0]
        index = stem.rsplit('_', 1)[-1]
        if index.isdecimal():
            indexed_paths.append((int(index), path))
    if indexed_paths:
        return [path for _, path in sorted(indexed_paths)]
    single_path = os.path.join(model_folder, 'alignment_bench.json')
    return [single_path] if os.path.isfile(single_path) else []


def extract_predictions_from_json(input_folder):
    """Write one submission CSV per model from its saved predictions.

    ``input_folder`` must contain a ``predictions`` directory with one
    sub-directory per model. Each model directory holds either split files
    ``alignment_bench_<index>.json`` or a single ``alignment_bench.json``.
    Every file is a JSON object whose values carry a ``prediction`` field.
    The predictions are concatenated (split files in ascending ``<index>``
    order, values in file order). They are written one per row to
    ``<input_folder>/submission/<model_name>_submission.csv``. Models without
    any prediction file are reported and skipped.

    Args:
        input_folder: Experiment output directory containing ``predictions``.

    Raises:
        FileNotFoundError: If ``<input_folder>/predictions`` does not exist.
        KeyError: If a prediction record lacks a ``prediction`` field.
    """
    sub_folder = os.path.join(input_folder, 'submission')
    pred_folder = os.path.join(input_folder, 'predictions')
    os.makedirs(sub_folder, exist_ok=True)

    for model_name in sorted(os.listdir(pred_folder)):
        model_folder = os.path.join(pred_folder, model_name)
        # Only sub-directories hold a model's predictions.
        if not os.path.isdir(model_folder):
            continue

        json_paths = _collect_prediction_files(model_folder)
        if not json_paths:
            print('No alignment_bench prediction files found in {}, '
                  'skipping'.format(model_folder))
            continue

        all_predictions = []
        for json_path in json_paths:
            with open(json_path, encoding='utf-8') as json_file:
                json_data = json.load(json_file)
            all_predictions.extend(
                value['prediction'] for value in json_data.values())

        output_path = os.path.join(sub_folder, model_name + '_submission.csv')
        # newline='' lets the csv module control line endings itself, so
        # predictions containing newlines are quoted intact and rows are not
        # terminated by a doubled '\r\r\n' on Windows. 'utf-8-sig' prepends a
        # UTF-8 byte-order mark to the file.
        with open(output_path, 'w', encoding='utf-8-sig',
                  newline='') as file:
            writer = csv.writer(file)
            for ans in tqdm(all_predictions):
                writer.writerow([str(ans)])
        print('Saved {} for submission'.format(output_path))
|
2023-12-13 19:59:30 +08:00
|
|
|
|
|
|
|
|
|
|
|
def process_jsonl(file_path):
    """Convert a JSONL question file into a list of evaluation records.

    Each non-blank line must be a JSON object with ``question``,
    ``category``, ``subcategory``, ``reference`` and ``question_id`` keys.
    It is converted to ``{'question', 'capability', 'others'}``. Here
    ``capability`` comes from ``category``, and ``subcategory``,
    ``reference`` and ``question_id`` are nested under ``others``. Records
    keep the input line order.

    Args:
        file_path: Path to a UTF-8 encoded JSONL file (a leading byte-order
            mark is tolerated).

    Returns:
        list[dict]: One converted record per non-blank input line.

    Raises:
        KeyError: If a record lacks one of the required keys.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    new_data = []
    # 'utf-8-sig' decodes plain UTF-8 identically but also strips a leading
    # BOM, which json.loads would otherwise reject.
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        for line in file:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # failing inside json.loads.
            if not line.strip():
                continue
            json_data = json.loads(line)
            new_data.append({
                'question': json_data['question'],
                'capability': json_data['category'],
                'others': {
                    'subcategory': json_data['subcategory'],
                    'reference': json_data['reference'],
                    'question_id': json_data['question_id'],
                },
            })
    return new_data
|
|
|
|
|
|
|
|
|
|
|
|
def save_as_json(data, output_file='./alignment_bench.json'):
    """Write ``data`` to ``output_file`` as pretty-printed UTF-8 JSON.

    Nested structures are indented by four spaces, and non-ASCII characters
    are written as-is rather than as ``\\u`` escapes. Any existing file at
    ``output_file`` is overwritten.
    """
    dump_options = {'indent': 4, 'ensure_ascii': False}
    with open(output_file, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, **dump_options)
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line options for the file converter.

    Modes:
        json: convert the ``--jsonl`` file into ``--json`` (see
            ``process_jsonl`` / ``save_as_json``).
        csv: build submission CSV files from ``<exp-folder>/predictions``.

    Returns:
        argparse.Namespace with ``mode``, ``jsonl``, ``json`` and
        ``exp_folder`` attributes.

    Exits (via ``parser.error``) with status 2 on an unknown ``--mode`` or
    when ``--mode csv`` is given without ``--exp-folder``.
    """
    parser = argparse.ArgumentParser(description='File Converter')
    parser.add_argument('--mode',
                        default='json',
                        choices=['json', 'csv'],
                        help='Conversion mode: "json" converts the jsonl '
                        'file to json, "csv" extracts predictions into '
                        'submission csv files')
    parser.add_argument('--jsonl',
                        default='./data_release.jsonl',
                        help='The original jsonl path')
    parser.add_argument('--json',
                        default='./alignment_bench.json',
                        help='The results json path')
    parser.add_argument('--exp-folder',
                        help='The experiment output folder containing the '
                        'predictions directory (required for --mode csv)')
    args = parser.parse_args()
    # --exp-folder has no default; reject its absence here with a usage
    # message rather than letting os.path.join(None, ...) fail later.
    if args.mode == 'csv' and not args.exp_folder:
        parser.error('--exp-folder is required when --mode is csv')
    return args
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    args = parse_args()
    mode = args.mode
    if mode == 'json':
        # Convert the original JSONL file and save the result as JSON.
        processed_data = process_jsonl(args.jsonl)
        save_as_json(processed_data, args.json)
    elif mode == 'csv':
        # extract_predictions_from_json joins paths onto this folder, so a
        # missing value must be rejected up front.
        if not args.exp_folder:
            raise SystemExit('--exp-folder is required when --mode is csv')
        extract_predictions_from_json(args.exp_folder)
    else:
        # Fail loudly instead of exiting successfully without producing output.
        raise SystemExit(
            "Unsupported --mode {!r}; expected 'json' or 'csv'".format(mode))
|