# -*- coding: utf-8 -*-

import json
import time

import numpy as np
import pandas as pd
import torch
# Intel Extension for Scikit-learn (scikit-learn-intelex); accelerates sklearn on Intel hardware.
from sklearnex import patch_sklearn, unpatch_sklearn


def train_kmeans_model(train_data, n_clusters, data_index_list):
    # Note: patch_sklearn() must be called before importing the sklearn modules,
    # otherwise the stock (unaccelerated) implementations are used.
    patch_sklearn()
    # unpatch_sklearn()  # call this to restore the stock scikit-learn implementations
    from sklearn.cluster import KMeans

    start_time = time.time()

    # The training data is expected to be an np.ndarray.
    data_set = train_data
    print("train_data_len:" + str(len(data_set)))
    print("n_clusters:" + str(n_clusters))

    # init already defaults to 'k-means++'; it is spelled out here for clarity.
    # Note: n_init="auto" requires scikit-learn >= 1.2.
    model = KMeans(n_clusters=n_clusters, init='k-means++', max_iter=300, n_init="auto", random_state=0)

    model.fit(data_set)
    # The fitted KMeans model already contains the cluster centers.
    center = model.cluster_centers_
    print(center)

    # For the training data itself, model.labels_ holds the same assignments.
    cluster_label = model.predict(data_set)
    print("cluster_label_size:" + str(len(cluster_label)))
    print("type:" + str(type(cluster_label)))

    train_dfs = pd.DataFrame(data_set)

    train_dfs["predict_label"] = cluster_label
    train_dfs["data_index"] = data_index_list

    print(train_dfs.columns.values)

    end_time = time.time()
    time_cost = end_time - start_time
    print("train_kmeans_time:" + str(time_cost) + " s")
    return train_dfs


# step 1: load the saved embedding binary file
embed_file = "./dhbq/dhbq_embedding.pth"
embed_dict_list = torch.load(embed_file)
print("len_embed_dict_list:", len(embed_dict_list))
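
# Assumption (inferred from the parsing in step 2 below): each loaded record is
# a dict like {"embedding": <vector>, "data_index": <id>}.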
# step 2: parse the data
raw_embeding_list = []
raw_index_list = []
for cur_dict in embed_dict_list:
    # Each record loads back directly as a dict object, so no extra decoding is needed.
    cur_embeding = cur_dict["embedding"]
    cur_data_idx = cur_dict["data_index"]
    raw_embeding_list.append(cur_embeding)
    raw_index_list.append(cur_data_idx)

train_array = np.array(raw_embeding_list)
print("train_array_shape:", train_array.shape)

# Total plan: 1000 clusters -- first cluster into 50 coarse groups, then pick
# 20 sub-clusters within each (see the two-stage sketch below).
# OK, just go straight to the coarse clusters in one flat pass; if the result
# is not good enough, keep refining from there.
num_cluster = 300
# pandas performs row-index alignment of the new columns automatically here.
train_dfs = train_kmeans_model(train_array, num_cluster, raw_index_list)
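

# A minimal sketch of the two-stage idea from the comment above: cluster into
# coarse groups first, then sub-cluster within each group. Not used anywhere
# below; the helper name and the 50/20 defaults are illustrative assumptions.
def two_stage_kmeans_sketch(data, n_coarse=50, n_sub=20):
    from sklearn.cluster import KMeans
    coarse = KMeans(n_clusters=n_coarse, n_init="auto", random_state=0).fit(data)
    fine_labels = np.zeros(len(data), dtype=np.int64)
    for c in range(n_coarse):
        mask = coarse.labels_ == c
        group = data[mask]
        # Guard against coarse groups smaller than the requested sub-cluster count.
        k = min(n_sub, len(group))
        sub = KMeans(n_clusters=k, n_init="auto", random_state=0).fit(group)
        # Give each (coarse, sub) pair a unique flat label.
        fine_labels[mask] = c * n_sub + sub.labels_
    return fine_labels

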
predict_label = train_dfs['predict_label'].tolist()

print("len_predict_label:", len(predict_label))


# Don't save to CSV here: pandas CSV serialization can hit unexpected problems
# when the data contains special characters!
# data_to_save = {'embeding': raw_embeding_list, 'cluster_center': predict_label, "data_idx": raw_index_list}
data_to_save = [
    {
        # Tensors/ndarrays are not JSON-serializable; convert them to plain
        # lists first (assuming the stored embeddings expose .tolist()).
        "embedding": raw_embeding_list[i].tolist()
        if hasattr(raw_embeding_list[i], "tolist")
        else raw_embeding_list[i],
        "cluster_center": predict_label[i],
        "data_idx": raw_index_list[i],
    }
    for i in range(len(raw_embeding_list))
]

output_file = "./dhbq/dhbq_cluster_kmeans_result.jsonl"
with open(output_file, 'w', encoding='utf-8') as f:
    for record in data_to_save:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')

print(f"Results saved to {output_file}")

data_to_save_df = pd.DataFrame(data_to_save)
data_to_save_df.to_pickle("./dhbq/dhbq_cluster_kmeans_result.pkl")
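

# A minimal sketch (not part of the original flow) showing how the saved JSONL
# could be read back for downstream use; load_cluster_result is a hypothetical
# helper name, not something defined elsewhere in this script.
def load_cluster_result(path):
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            records.append(json.loads(line))
    return records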
|