import json from collections import defaultdict import pandas as pd def classify_data_by_labels(input_file, output_file, output_excel): """ 根据 prompt_label 中的 labels 提取类型(::后面的内容)。 每个类型对应一个列表,存储去掉了 embedding 的数据。 输出文件按类型的数据量从多到少排序,并记录每种类型的数量。 每个类型的 data_list 会根据 score 字段降序排序。 并且会过滤掉 data_list 长度小于 19 的类型和特定的 label 值。 """ # 初始化分类字典 classified_data = defaultdict(list) excluded_labels = {"null", "无", "无指代", "无指代消解", "无明显指代消解", "无上下文依赖", "无明显指代消解需求", "无明确指代", "无明显上下文依赖", "无依赖", "无上下文", "无明显指代", } # 读取输入文件 with open(input_file, 'r', encoding='utf-8') as f: for line in f: record = json.loads(line.strip()) # 提取 prompt_label 中的 labels 部分 prompt_label = record.get("prompt_label", "{}") try: prompt_label = json.loads(prompt_label) except json.JSONDecodeError: prompt_label = {} if isinstance(prompt_label, list): labels = prompt_label[0].get("labels", []) else: labels = prompt_label.get("labels", []) # 如果 labels 存在,则根据类型分类 if labels: for label in labels: if "::" in label: type_name = label.split("::")[-1] # 提取 :: 后面的内容 # 排除特定的 label 值 if any(excluded_label in label for excluded_label in excluded_labels): continue record_without_embedding = {k: v for k, v in record.items() if k != "embedding"} classified_data[type_name].append(record_without_embedding) # else: # # 如果没有 labels,则归类到 "null" # record_without_embedding = {k: v for k, v in record.items() if k != "embedding"} # classified_data["null"].append(record_without_embedding) # 对每个类型的 data_list 按照 score 字段降序排序 for type_name in classified_data.keys(): classified_data[type_name].sort(key=lambda x: x.get('score', 0), reverse=True) # 过滤掉 data_list 长度小于 19 的类型 filtered_classified_data = {k: v for k, v in classified_data.items() if len(v) >= 19} # 按类型的数据量从多到少排序 sorted_classified_data = sorted(filtered_classified_data.items(), key=lambda x: len(x[1]), reverse=True) # 写入输出文件 total_types = len(sorted_classified_data) with open(output_file, 'w', encoding='utf-8') as out_f: for type_name, data_list in sorted_classified_data: entry = { type_name: data_list, #"count": len(data_list) } out_f.write(json.dumps(entry, ensure_ascii=False) + '\n') print(f"Total types after filtering: {total_types}") # 准备 Excel 数据 excel_data = [] for type_name, data_list in sorted_classified_data: excel_data.append({"Type": type_name, "Count": len(data_list)}) # 导出到 Excel 文件 df = pd.DataFrame(excel_data) df.to_excel(output_excel, index=False) # 将类型为 null 的数据单独保存到一个 JSONL 文件中 # null_data = classified_data.get("null", []) # if len(null_data) >= 19: # 只有当 null 类型的数据长度大于等于19时才保存 # with open('./dhbq/prompt_null.jsonl', 'w', encoding='utf-8') as null_f: # for record in null_data: # null_f.write(json.dumps(record, ensure_ascii=False) + '\n') return total_types # 示例用法 input_file = './dhbq/dhbq_merged_with_score.jsonl' output_file = './dhbq/dhbq_count_prompt_label.jsonl' output_excel = './dhbq/dhbq_count_prompt_label.xlsx' total_types = classify_data_by_labels(input_file, output_file, output_excel) print(f"Total types found: {total_types}")