offline_data_model_pipline/data_generate/query_completion/cluster_kmeans.py

101 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import json
import sys
import os
import torch
import numpy as np
from sklearnex import patch_sklearn, unpatch_sklearn
from matplotlib import pyplot as plt
import time
import pandas as pd
def train_kmeans_model(train_data, n_clusters, data_index_list):
    """Fit a KMeans model on an embedding matrix and return a labeled DataFrame.

    Args:
        train_data: np.ndarray of shape (n_samples, n_features) holding the
            training embeddings.
        n_clusters: number of clusters to fit.
        data_index_list: per-row identifiers, carried through to the output.

    Returns:
        pandas.DataFrame with the input features as columns plus two extra
        columns: "predict_label" (assigned cluster id) and "data_index".
    """
    # patch_sklearn() must run BEFORE importing sklearn so the Intel-accelerated
    # implementation is the one that gets loaded.
    patch_sklearn()
    # unpatch_sklearn()
    from sklearn.cluster import KMeans

    started = time.time()
    # Training data is expected to already be an np.ndarray.
    features = train_data
    print("train_data_len:" + str(len(features)))
    print("n_clusters:" + str(n_clusters))
    # init defaults to k-means++ anyway; stated explicitly for clarity.
    estimator = KMeans(
        n_clusters=n_clusters,
        init='k-means++',
        max_iter=300,
        n_init="auto",
        random_state=0,
    )
    estimator.fit(features)
    # The fitted model itself already holds the cluster centers.
    print(estimator.cluster_centers_)
    labels = estimator.predict(features)
    print("cluster_label_size:" + str(len(labels)))
    print("type:" + str(type(labels)))
    result_df = pd.DataFrame(features)
    result_df["predict_label"] = labels
    result_df["data_index"] = data_index_list
    print(result_df.columns.values)
    elapsed = (time.time() - started) * 1.0
    print("train_kmeans_time:" + str(elapsed) + " s")
    return result_df
# step1: load the saved embedding binary (a list of dicts, one per sample).
embed_file = "./dhbq/dhbq_embedding.pth"
embed_dict_list = torch.load(embed_file)
print("len_embed_dict_list:", len(embed_dict_list))
# step2: parse the records into parallel lists of embeddings and source indices.
raw_embeding_list = []
raw_index_list = []
# Iterate the records directly (the original enumerate() index was unused).
for cur_dict in embed_dict_list:
    # Each record deserializes straight to a dict with the embedding
    # and the index of the originating data row.
    raw_embeding_list.append(cur_dict["embedding"])
    raw_index_list.append(cur_dict["data_index"])
train_array = np.array(raw_embeding_list)
print("train_array_shape:", train_array.shape)
# Target is ~1000 clusters overall; originally planned as 50 coarse clusters
# with 20 sub-clusters each. Going with 300 coarse clusters for now and will
# refine later if the quality is not good enough.
num_cluster = 300
# pandas aligns the extra columns by positional row index here.
train_dfs = train_kmeans_model(train_array, num_cluster, raw_index_list)
predict_label = train_dfs['predict_label'].tolist()
print("len_predict_label:", len(predict_label))
# Deliberately NOT saving as CSV: pandas CSV output can behave unexpectedly
# on special characters, so the results go to JSONL + pickle instead.
#data_to_save = {'embeding': raw_embeding_list, 'cluster_center': predict_label, "data_idx": raw_index_list}
data_to_save = [
    {
        "embedding": emb,
        "cluster_center": label,
        "data_idx": idx,
    }
    # zip the three parallel lists instead of indexing by range(len(...)).
    for emb, label, idx in zip(raw_embeding_list, predict_label, raw_index_list)
]
output_file = "./dhbq/dhbq_cluster_kmeans_result.jsonl"
with open(output_file, 'w', encoding='utf-8') as f:
    for record in data_to_save:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')
print(f"Results saved to {output_file}")
data_to_save_df = pd.DataFrame(data_to_save)
data_to_save_df.to_pickle("./dhbq/dhbq_cluster_kmeans_result.pkl")