# offline_data_model_pipline/data_generate/query_completion/cluster_kmeans.py
# Last modified: 2025-05-13 13:00:51 +08:00
# -*- coding: utf-8 -*-
import json
import sys
import os
import torch
import numpy as np
from sklearnex import patch_sklearn, unpatch_sklearn
from matplotlib import pyplot as plt
import time
import pandas as pd
def train_kmeans_model(train_data, n_clusters, data_index_list):
    """Fit a KMeans model on embedding vectors and return a labeled DataFrame.

    Args:
        train_data: np.ndarray of shape (n_samples, n_features) — KMeans
            requires ndarray-style input.
        n_clusters: number of clusters to fit.
        data_index_list: per-row identifiers, aligned with `train_data` rows
            (pandas aligns them by positional row index).

    Returns:
        pd.DataFrame holding the embedding columns plus two extra columns:
        'predict_label' (cluster id per row) and 'data_index'.
    """
    # patch_sklearn() must run BEFORE importing sklearn so that the
    # Intel-accelerated implementations are the ones actually imported.
    patch_sklearn()
    # unpatch_sklearn()
    from sklearn.cluster import KMeans

    start_time = time.time()
    print(f"train_data_len:{len(train_data)}")
    print(f"n_clusters:{n_clusters}")
    # init='k-means++' is sklearn's default; kept explicit for clarity.
    # random_state=0 makes the clustering reproducible across runs.
    model = KMeans(n_clusters=n_clusters, init='k-means++', max_iter=300,
                   n_init="auto", random_state=0)
    model.fit(train_data)
    # The fitted model already stores the cluster centers.
    print(model.cluster_centers_)
    cluster_label = model.predict(train_data)
    print(f"cluster_label_size:{len(cluster_label)}")
    print(f"type:{type(cluster_label)}")
    train_dfs = pd.DataFrame(train_data)
    train_dfs["predict_label"] = cluster_label
    train_dfs["data_index"] = data_index_list
    print(train_dfs.columns.values)
    print(f"train_kmeans_time:{time.time() - start_time} s")
    return train_dfs
# step1: load the saved embedding binary file
embed_file = "./dhbq/dhbq_embedding.pth"
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# files produced by this pipeline.
embed_dict_list = torch.load(embed_file)
print("len_embed_dict_list:", len(embed_dict_list))
# step2: parse the records into parallel lists.
# Each record is already a dict; presumably it carries the embedding vector
# and the originating sample index — verify against the producer script.
raw_embeding_list = []
raw_index_list = []
for cur_dict in embed_dict_list:
    raw_embeding_list.append(cur_dict["embedding"])
    raw_index_list.append(cur_dict["data_index"])
train_array = np.array(raw_embeding_list)
print("train_array_shape:", train_array.shape)
# Original plan: 1000 clusters total (50 coarse x 20 fine per coarse);
# currently a flat 300 clusters — revisit if quality is insufficient.
num_cluster = 300
# pandas aligns the extra columns by positional row index here.
train_dfs = train_kmeans_model(train_array, num_cluster, raw_index_list)
predict_label = train_dfs['predict_label'].tolist()
print("len_predict_label:", len(predict_label))
# CSV output is deliberately avoided: pandas CSV serialization misbehaves
# on special characters, so results go to JSONL + pickle instead.
data_to_save = [
    {
        "embedding": emb,
        "cluster_center": lab,
        "data_idx": idx
    }
    for emb, lab, idx in zip(raw_embeding_list, predict_label, raw_index_list)
]
output_file = "./dhbq/dhbq_cluster_kmeans_result.jsonl"
with open(output_file, 'w', encoding='utf-8') as f:
    for record in data_to_save:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')
print(f"Results saved to {output_file}")
data_to_save_df = pd.DataFrame(data_to_save)
data_to_save_df.to_pickle("./dhbq/dhbq_cluster_kmeans_result.pkl")