diff --git a/data_generate/query_completion/merge.py b/data_generate/query_completion/merge.py index 27290df..d7981f7 100644 --- a/data_generate/query_completion/merge.py +++ b/data_generate/query_completion/merge.py @@ -152,12 +152,12 @@ def merge_jsonl_files(file1, file2, output_file): record = json.loads(line.strip()) index = record.get('data_idx') cluster_center = record.get('cluster_center') - #embedding = record.get('embedding') + embedding = record.get('embedding') # 如果'index'存在于第一个文件的'uid'中,则合并数据 if index in data_dict: data_dict[index]['cluster_center'] = cluster_center - #data_dict[index]['embedding'] = embedding + data_dict[index]['embedding'] = embedding # 将合并后的数据写入输出文件 with open(output_file, 'w', encoding='utf-8') as out_f: @@ -195,12 +195,9 @@ def merge_jsonl_files(file1, file2, output_file): record = json.loads(line.strip()) index = record.get('uid') score = record.get('answer') - embedding = record.get('embedding') - # 如果'index'存在于第一个文件的'uid'中,则合并数据 if index in data_dict: data_dict[index]['score'] = score - data_dict[index]['embedding'] = embedding # 将合并后的数据写入输出文件 with open(output_file, 'w', encoding='utf-8') as out_f: