offline_data_model_pipline/data_generate/query_completion/count_label.py

97 lines
4.3 KiB
Python
Raw Normal View History

2025-05-13 13:00:51 +08:00
import json
from collections import defaultdict
import pandas as pd
def classify_data_by_labels(input_file, output_file, output_excel):
"""
根据 prompt_label 中的 labels 提取类型::后面的内容
每个类型对应一个列表存储去掉了 embedding 的数据
输出文件按类型的数据量从多到少排序并记录每种类型的数量
每个类型的 data_list 会根据 score 字段降序排序
并且会过滤掉 data_list 长度小于 19 的类型和特定的 label
"""
# 初始化分类字典
classified_data = defaultdict(list)
excluded_labels = {"null", "", "无指代", "无指代消解", "无明显指代消解",
"无上下文依赖", "无明显指代消解需求", "无明确指代",
"无明显上下文依赖", "无依赖", "无上下文", "无明显指代",
}
# 读取输入文件
with open(input_file, 'r', encoding='utf-8') as f:
for line in f:
record = json.loads(line.strip())
# 提取 prompt_label 中的 labels 部分
prompt_label = record.get("prompt_label", "{}")
try:
prompt_label = json.loads(prompt_label)
except json.JSONDecodeError:
prompt_label = {}
if isinstance(prompt_label, list):
labels = prompt_label[0].get("labels", [])
else:
labels = prompt_label.get("labels", [])
# 如果 labels 存在,则根据类型分类
if labels:
for label in labels:
if "::" in label:
type_name = label.split("::")[-1] # 提取 :: 后面的内容
# 排除特定的 label 值
if any(excluded_label in label for excluded_label in excluded_labels):
continue
record_without_embedding = {k: v for k, v in record.items() if k != "embedding"}
classified_data[type_name].append(record_without_embedding)
# else:
# # 如果没有 labels则归类到 "null"
# record_without_embedding = {k: v for k, v in record.items() if k != "embedding"}
# classified_data["null"].append(record_without_embedding)
# 对每个类型的 data_list 按照 score 字段降序排序
for type_name in classified_data.keys():
classified_data[type_name].sort(key=lambda x: x.get('score', 0), reverse=True)
# 过滤掉 data_list 长度小于 19 的类型
filtered_classified_data = {k: v for k, v in classified_data.items() if len(v) >= 19}
# 按类型的数据量从多到少排序
sorted_classified_data = sorted(filtered_classified_data.items(), key=lambda x: len(x[1]), reverse=True)
# 写入输出文件
total_types = len(sorted_classified_data)
with open(output_file, 'w', encoding='utf-8') as out_f:
for type_name, data_list in sorted_classified_data:
entry = {
type_name: data_list,
#"count": len(data_list)
}
out_f.write(json.dumps(entry, ensure_ascii=False) + '\n')
print(f"Total types after filtering: {total_types}")
# 准备 Excel 数据
excel_data = []
for type_name, data_list in sorted_classified_data:
excel_data.append({"Type": type_name, "Count": len(data_list)})
# 导出到 Excel 文件
df = pd.DataFrame(excel_data)
df.to_excel(output_excel, index=False)
# 将类型为 null 的数据单独保存到一个 JSONL 文件中
# null_data = classified_data.get("null", [])
# if len(null_data) >= 19: # 只有当 null 类型的数据长度大于等于19时才保存
# with open('./dhbq/prompt_null.jsonl', 'w', encoding='utf-8') as null_f:
# for record in null_data:
# null_f.write(json.dumps(record, ensure_ascii=False) + '\n')
return total_types
# 示例用法
input_file = './dhbq/dhbq_merged_with_score.jsonl'
output_file = './dhbq/dhbq_count_prompt_label.jsonl'
output_excel = './dhbq/dhbq_count_prompt_label.xlsx'
total_types = classify_data_by_labels(input_file, output_file, output_excel)
print(f"Total types found: {total_types}")