# offline_data_model_pipline/data_generate/query_completion/count_label.py
#
# NOTE: this file intentionally contains Chinese (CJK) text in label values
# and comments; the mixed-script characters are expected.

import json
from collections import defaultdict
import pandas as pd
def classify_data_by_labels(input_file, output_file, output_excel, min_count=19):
    """Group JSONL records by the type encoded in their ``prompt_label`` labels.

    Each record's ``prompt_label`` field normally holds a JSON string (an
    object, or a list whose first object is used) with a ``labels`` list.
    For every label of the form ``"...::<type>"`` the record is appended —
    minus its ``embedding`` field — to that type's bucket. Buckets are
    sorted by ``score`` descending, buckets smaller than ``min_count`` are
    dropped, and the survivors are written to ``output_file`` as one
    ``{type: records}`` JSON object per line, largest bucket first. A
    Type/Count summary is exported to ``output_excel`` unless it is falsy.

    Args:
        input_file: path of the input JSONL file.
        output_file: path of the JSONL file to write grouped records to.
        output_excel: path of the Excel summary file; pass None/"" to skip.
        min_count: minimum bucket size to keep (default 19, the original
            hard-coded threshold).

    Returns:
        Number of types remaining after filtering.
    """
    classified_data = defaultdict(list)
    # Type names that mean "no coreference / no context dependency"; these
    # must not form buckets of their own.
    excluded_labels = {
        "null", "", "无指代", "无指代消解", "无明显指代消解",
        "无上下文依赖", "无明显指代消解需求", "无明确指代",
        "无明显上下文依赖", "无依赖", "无上下文", "无明显指代",
    }

    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            record = json.loads(line.strip())
            prompt_label = record.get("prompt_label", "{}")
            # prompt_label is usually a JSON-encoded string; tolerate records
            # where it was already decoded upstream (original raised TypeError).
            if isinstance(prompt_label, str):
                try:
                    prompt_label = json.loads(prompt_label)
                except json.JSONDecodeError:
                    prompt_label = {}
            if isinstance(prompt_label, list):
                # Guard against an empty list (original IndexErrored).
                labels = prompt_label[0].get("labels", []) if prompt_label else []
            else:
                labels = prompt_label.get("labels", [])
            for label in labels or []:
                if "::" not in label:
                    continue
                type_name = label.split("::")[-1]  # text after the last "::"
                # BUG FIX: the original tested substring containment of every
                # excluded label against the full label text; since "" is a
                # substring of any string, *every* label was excluded and all
                # buckets stayed empty. Exact-match the extracted type instead.
                if type_name in excluded_labels:
                    continue
                record_without_embedding = {
                    k: v for k, v in record.items() if k != "embedding"
                }
                classified_data[type_name].append(record_without_embedding)

    # Order each bucket best-score-first; records without a score sort as 0.
    for data_list in classified_data.values():
        data_list.sort(key=lambda x: x.get('score', 0), reverse=True)

    # Drop sparse types, then order the rest by bucket size, largest first.
    filtered = {k: v for k, v in classified_data.items() if len(v) >= min_count}
    sorted_classified_data = sorted(
        filtered.items(), key=lambda x: len(x[1]), reverse=True
    )

    total_types = len(sorted_classified_data)
    with open(output_file, 'w', encoding='utf-8') as out_f:
        for type_name, data_list in sorted_classified_data:
            out_f.write(json.dumps({type_name: data_list}, ensure_ascii=False) + '\n')
    print(f"Total types after filtering: {total_types}")

    if output_excel:
        # Per-type counts; requires an Excel writer engine (e.g. openpyxl).
        summary = [
            {"Type": type_name, "Count": len(data_list)}
            for type_name, data_list in sorted_classified_data
        ]
        pd.DataFrame(summary).to_excel(output_excel, index=False)

    return total_types
# Script entry point: classify the scored, merged query-completion data and
# emit the per-type JSONL plus an Excel count summary. Guarded so importing
# this module no longer triggers file I/O as a side effect.
if __name__ == "__main__":
    input_file = './dhbq/dhbq_merged_with_score.jsonl'
    output_file = './dhbq/dhbq_count_prompt_label.jsonl'
    output_excel = './dhbq/dhbq_count_prompt_label.xlsx'
    total_types = classify_data_by_labels(input_file, output_file, output_excel)
    print(f"Total types found: {total_types}")