# -*- coding: utf-8 -*-
"""Cluster precomputed sentence embeddings with (Intel-accelerated) KMeans.

Pipeline:
  1. Load a list of dicts (keys: "embedding", "data_index") from a torch file.
  2. Fit KMeans (k-means++ init) on the stacked embedding matrix.
  3. Save per-sample cluster assignments to a JSONL file and a pickle.
"""
import json
import os
import sys
import time

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from sklearnex import patch_sklearn, unpatch_sklearn


def train_kmeans_model(train_data, n_clusters, data_index_list):
    """Fit a KMeans model and return a DataFrame of inputs + predictions.

    Args:
        train_data: np.ndarray of shape (n_samples, n_features).
        n_clusters: number of clusters for KMeans.
        data_index_list: per-row identifiers, same length as ``train_data``.

    Returns:
        pd.DataFrame with one column per feature plus ``predict_label``
        (cluster id per sample) and ``data_index``.
    """
    # NOTE: patch_sklearn() must be called BEFORE importing sklearn so the
    # Intel extension can swap in its accelerated KMeans implementation.
    patch_sklearn()
    # unpatch_sklearn()
    from sklearn.cluster import KMeans

    start_time = time.time()
    # KMeans expects the training data as an np.ndarray.
    data_set = train_data
    print("train_data_len:" + str(len(data_set)))
    print("n_clusters:" + str(n_clusters))
    # init defaults to k-means++ anyway; spelled out for clarity.
    model = KMeans(
        n_clusters=n_clusters,
        init='k-means++',
        max_iter=300,
        n_init="auto",
        random_state=0,
    )
    model.fit(data_set)
    # The fitted model carries the cluster centers.
    center = model.cluster_centers_
    print(center)
    cluster_label = model.predict(data_set)
    print("cluster_label_size:" + str(len(cluster_label)))
    print("type:" + str(type(cluster_label)))
    train_dfs = pd.DataFrame(data_set)
    train_dfs["predict_label"] = cluster_label
    train_dfs["data_index"] = data_index_list
    print(train_dfs.columns.values)
    end_time = time.time()
    avg_time_cost = (end_time - start_time) * 1.0
    print("train_kmeans_time:" + str(avg_time_cost) + " s")
    return train_dfs


def _to_jsonable(value):
    """Return *value* as a JSON-serializable object.

    Embeddings loaded via torch.load may be tensors or ndarrays, which
    json.dumps cannot handle; convert those to plain Python lists.
    """
    if isinstance(value, (torch.Tensor, np.ndarray)):
        return value.tolist()
    return value


def main():
    # step1: load the saved embedding binary file.
    embed_file = "./dhbq/dhbq_embedding.pth"
    embed_dict_list = torch.load(embed_file)
    print("len_embed_dict_list:", len(embed_dict_list))

    # step2: parse the data — each element is a dict with the embedding
    # vector and its original data index.
    raw_embeding_list = []
    raw_index_list = []
    for cur_dict in embed_dict_list:
        raw_embeding_list.append(cur_dict["embedding"])
        raw_index_list.append(cur_dict["data_index"])

    train_array = np.array(raw_embeding_list)
    print("train_array_shape:", train_array.shape)

    # Originally planned: 1000 clusters total via 50 coarse x 20 fine;
    # instead cluster directly into num_cluster groups and refine later
    # if the quality is poor.
    num_cluster = 300
    # pandas aligns row indices automatically here.
    train_dfs = train_kmeans_model(train_array, num_cluster, raw_index_list)
    predict_label = train_dfs['predict_label'].tolist()
    print("len_predict_label:", len(predict_label))

    # Deliberately NOT saved as CSV: pandas CSV output misbehaves on
    # special characters in unexpected ways.
    data_to_save = [
        {
            "embedding": _to_jsonable(embedding),
            "cluster_center": label,
            "data_idx": data_idx,
        }
        for embedding, label, data_idx in zip(
            raw_embeding_list, predict_label, raw_index_list
        )
    ]

    output_file = "./dhbq/dhbq_cluster_kmeans_result.jsonl"
    with open(output_file, 'w', encoding='utf-8') as f:
        for record in data_to_save:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')
    print(f"Results saved to {output_file}")

    data_to_save_df = pd.DataFrame(data_to_save)
    data_to_save_df.to_pickle("./dhbq/dhbq_cluster_kmeans_result.pkl")


if __name__ == "__main__":
    main()