offline_data_model_pipline/data_generate/query_completion/count_cluster.py

71 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from collections import defaultdict
import pandas as pd
def classify_data_by_cluster_center(input_file, output_jsonl, output_excel):
    """Group JSONL records by ``cluster_center`` and report the group sizes.

    Every non-blank line of ``input_file`` is parsed as one JSON object. The
    ``embedding`` key is dropped from each record, and the record is appended
    to the group named by its ``cluster_center`` value. Records whose
    ``cluster_center`` is missing or null go into the group keyed by the
    literal string ``"null"`` (so they share a group with any record whose
    ``cluster_center`` is the string ``"null"``).

    Inside each group the records are sorted by ``score`` in descending order;
    a missing or null ``score`` counts as 0. Groups themselves are ordered by
    size, largest first.

    Two files are written:

    * ``output_jsonl`` -- one JSON object per line, of the form
      ``{str(cluster_center): [records...]}``, in group-size order.
    * ``output_excel`` -- a sheet with the columns ``Cluster Center`` and
      ``Count``, in the same order. The header row is written even when the
      input contains no records.

    Args:
        input_file: Path of the JSONL file to read (UTF-8).
        output_jsonl: Path of the grouped JSONL file to write (UTF-8).
        output_excel: Path of the Excel summary to write.

    Returns:
        The number of distinct groups found.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        TypeError: If a ``cluster_center`` value is unhashable (a JSON array
            or object), because the value is used as a dictionary key.
    """
    classified_data = defaultdict(list)

    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no record and
                # would otherwise make json.loads fail.
                continue
            record = json.loads(line)

            cluster_center = record.get("cluster_center")
            # Records without a cluster center share the "null" group.
            group_key = "null" if cluster_center is None else cluster_center
            classified_data[group_key].append(
                {k: v for k, v in record.items() if k != "embedding"}
            )

    # Highest score first inside every group; the sort is stable, so records
    # with equal scores keep their input order.
    for data_list in classified_data.values():
        data_list.sort(key=_score_or_zero, reverse=True)

    # Largest group first.
    sorted_classified_data = sorted(
        classified_data.items(), key=lambda item: len(item[1]), reverse=True
    )
    total_types = len(sorted_classified_data)

    with open(output_jsonl, 'w', encoding='utf-8') as out_f:
        for cluster_center, data_list in sorted_classified_data:
            # JSON object keys must be strings, hence the str() conversion.
            entry = {str(cluster_center): data_list}
            out_f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    excel_data = [
        {"Cluster Center": cluster_center, "Count": len(data_list)}
        for cluster_center, data_list in sorted_classified_data
    ]
    # Naming the columns explicitly keeps the header row for empty input.
    df = pd.DataFrame(excel_data, columns=["Cluster Center", "Count"])
    df.to_excel(output_excel, index=False)

    print(f"Total types: {total_types}")
    return total_types


def _score_or_zero(record):
    """Return the record's ``score`` as a sort key, mapping missing/null to 0."""
    score = record.get("score", 0)
    return 0 if score is None else score
# Example usage: group one merged-and-scored JSONL file and write both reports.
if __name__ == "__main__":
    example_paths = {
        "input_file": './dhbq/dhbq_merged_with_score_0513.jsonl',
        "output_jsonl": './dhbq/dhbq_count_cluster_0513.jsonl',
        "output_excel": './dhbq/dhbq_count_cluster_0513.xlsx',
    }
    type_count = classify_data_by_cluster_center(**example_paths)
    print(f"Total types found: {type_count}")