12345 data clustering distribution
This commit is contained in:
parent a116fa2bbe
commit 0655f98c70
100  data_generate/zcs/fenbu/cluster_kmeans_train_from_torch_bin.py  Normal file
@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
import json
import sys
import os
import torch
import numpy as np
from sklearnex import patch_sklearn, unpatch_sklearn
from matplotlib import pyplot as plt
import time
import pandas as pd


def train_kmeans_model(train_data, n_clusters, data_index_list):
    # Note: patch_sklearn() must be called before the sklearn modules are imported.
    patch_sklearn()
    # unpatch_sklearn()
    from sklearn.cluster import KMeans

    start_time = time.time()

    # The training data is expected to be an np.ndarray.
    data_set = train_data
    print("train_data_len:" + str(len(data_set)))
    print("n_clusters:" + str(n_clusters))

    # init defaults to k-means++ anyway.
    model = KMeans(n_clusters=n_clusters, init='k-means++', max_iter=300, n_init="auto", random_state=0)

    model.fit(data_set)
    # The fitted KMeans model already contains the cluster centers.
    center = model.cluster_centers_
    print(center)

    cluster_label = model.predict(data_set)
    print("cluster_label_size:" + str(len(cluster_label)))
    print("type:" + str(type(cluster_label)))

    train_dfs = pd.DataFrame(data_set)

    train_dfs["predict_label"] = cluster_label
    train_dfs["data_index"] = data_index_list

    print(train_dfs.columns.values)

    end_time = time.time()
    avg_time_cost = (end_time - start_time) * 1.0
    print("train_kmeans_time:" + str(avg_time_cost) + " s")
    return train_dfs


# step1: load the saved embedding binary file
embed_file = "/data/zhaochsh01/buquan/12345/embedding/123451wfilter_embedding.pth"
embed_dict_list = torch.load(embed_file)
print("len_embed_dict_list:", len(embed_dict_list))

# step2: parse the data
raw_embeding_list = []
raw_index_list = []
for idx, cur_dict in enumerate(embed_dict_list):
    # idx is the position; each element loads directly as the saved dict.
    cur_embeding = cur_dict["embedding"]
    cur_data_idx = cur_dict["data_index"]
    raw_embeding_list.append(cur_embeding)
    raw_index_list.append(cur_data_idx)

train_array = np.array(raw_embeding_list)
print("train_array_shape:", train_array.shape)

# Original plan: 1000 clusters in total, done as 50 coarse clusters with 20 sub-clusters each.
# Alternatively, go straight to 1000 coarse clusters and refine later if the result is poor.
num_cluster = 50
# pandas row-index alignment happens automatically here.
train_dfs = train_kmeans_model(train_array, num_cluster, raw_index_list)
predict_label = train_dfs['predict_label'].tolist()

print("len_predict_label:", len(predict_label))


# Do not save to CSV here: pandas CSV export can fail in unexpected ways when the data contains special characters.
#data_to_save = {'embeding': raw_embeding_list, 'cluster_center': predict_label, "data_idx": raw_index_list}
data_to_save = [
    {
        "embedding": raw_embeding_list[i],
        "cluster_center": predict_label[i],
        "data_idx": raw_index_list[i]
    }
    for i in range(len(raw_embeding_list))
]

output_file = "./12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl"
with open(output_file, 'w', encoding='utf-8') as f:
    for record in data_to_save:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')

print(f"Results saved to {output_file}")

data_to_save_df = pd.DataFrame(data_to_save)
data_to_save_df.to_pickle("./12345/kmeans/123451wfilter_cluster_kmeans_result.pkl")
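For a quick sanity check of the clustering output, a minimal sketch (assuming the JSONL path written above) that counts how many records landed in each cluster:

import json
from collections import Counter

cluster_sizes = Counter()
with open("./12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        cluster_sizes[record["cluster_center"]] += 1

# Print cluster ids from the largest group to the smallest.
for cluster_id, size in cluster_sizes.most_common():
    print(cluster_id, size)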
123  data_generate/zcs/fenbu/select_final_data.py  Normal file
@@ -0,0 +1,123 @@
import json
import pandas as pd


def process_jsonl_file(file_path, top_n_per_group, result, seen_uids):
    """
    Read a jsonl file; for each group line, try to select top_n_per_group unique records.
    If not enough unique records can be found, skip the group and print a warning.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)

                # Get the group key and its list of records.
                key = next(k for k in entry if k != "count")
                records = entry[key]

                selected = []
                added_count = 0

                # Walk through the records until top_n_per_group unique items are collected.
                for item in records:
                    uid = item.get('工单编号')

                    if added_count >= top_n_per_group:
                        break  # enough items collected

                    if uid and uid not in seen_uids:
                        seen_uids.add(uid)
                        selected.append(item)
                        result.append(item)
                        added_count += 1

                # If the group cannot fill top_n_per_group items, warn and roll back.
                if added_count < top_n_per_group:
                    print(f"[Group {key}] Only found {added_count}/{top_n_per_group} unique items. Skipping this group.")
                    # Remove the records that were already added for this group.
                    for item in selected:
                        result.remove(item)
                        seen_uids.discard(item.get('工单编号'))
                else:
                    print(f"[Group {key}] Successfully added {added_count} items.")

            except Exception as e:
                print(f"Error processing line {line_idx + 1}: {e}")
                continue


def load_data_to_memory(input_file):
    """
    Read a jsonl file, store the records in the result list and the 工单编号 values in seen_uids.

    :param input_file: path of the input jsonl file
    :return: tuple(result list, seen_uids set)
    """
    result = []
    seen_uids = set()

    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    # Parse the JSON line.
                    data = json.loads(line.strip())

                    # Check that the record carries a 工单编号 field.
                    if '工单编号' in data:
                        uid = data['工单编号']

                        # Add to the result list and the seen set.
                        result.append(data)
                        seen_uids.add(uid)
                    else:
                        print(f"Warning: skipping record without 工单编号: {data}")

                except json.JSONDecodeError:
                    print(f"Warning: skipping unparsable line: {line}")
                    continue

    except FileNotFoundError:
        print(f"Error: file {input_file} does not exist")
        return [], set()
    except Exception as e:
        print(f"Error while reading the file: {str(e)}")
        return [], set()

    print(f"Loaded {len(result)} records, found {len(seen_uids)} unique 工单编号 values")
    return result, seen_uids


def main():
    result = []
    seen_uids = set()
    result, seen_uids = load_data_to_memory("/data/zhaochsh01/buquan/12345/tool/output_prompt.jsonl")

    # Step 3: process 12345_count_instag.jsonl, up to 3 records per group.
    print("Step 3: Processing 12345_count_instag.jsonl...")
    process_jsonl_file("/data/zhaochsh01/buquan/12345/count_result/12345_count_instag.jsonl", 3, result, seen_uids)

    # Step 4: process 12345_count_cluster.jsonl, up to 3 records per group.
    print("Step 4: Processing 12345_count_cluster.jsonl...")
    process_jsonl_file("/data/zhaochsh01/buquan/12345/count_result/12345_count_cluster.jsonl", 3, result, seen_uids)

    # Finally write the merged output file.
    output_file = "/data/zhaochsh01/buquan/12345/merge_result/merged_result_for_more_test.jsonl"
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in result:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(f"Total merged items: {len(result)}")
    print(f"Saved to: {output_file}")


if __name__ == "__main__":
    main()
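A minimal sketch, assuming the merged_result_for_more_test.jsonl written by main() above, for verifying that no 工单编号 appears twice in the merged output:

import json

seen = set()
duplicates = 0
with open("/data/zhaochsh01/buquan/12345/merge_result/merged_result_for_more_test.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        uid = json.loads(line).get("工单编号")
        if uid in seen:
            duplicates += 1
        seen.add(uid)

print(f"unique ids: {len(seen)}, duplicate lines: {duplicates}")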
@@ -0,0 +1,233 @@
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
from loguru import logger
import random
import argparse
import json
from typing import List


def set_seed(seed=128):
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


class MyTorchDataset(torch.utils.data.Dataset):
    # def __init__(self, input_file, tokenizer, max_seq_length):
    #     self.tokenizer = tokenizer
    #     self.max_seq_length = max_seq_length
    #     logger.info('Loading data: {}'.format(input_file))
    #     with open(input_file, 'r', encoding='utf8') as f:
    #         data_list = f.readlines()
    #     logger.info("There are {} data in dataset".format(len(data_list)))
    #     # For testing, only take the top 100 here.
    #     self.raw_data_list = data_list[:]
    #
    #     # Tokenize everything at init time; tokenizing one record at a time during iteration slows down embedding.
    #     self.query_list = []
    #     self.index_list = []
    #     for line in self.raw_data_list:
    #         json_line = json.loads(line)
    #         print(json_line["data"][0])
    #         # Only the query gets embedded; content[0] is the query.
    #         query = json_line["data"][0]['content'].strip()
    #         # data_index of this single-turn sample in the raw data.
    #         data_index = json_line["uid"]
    #         self.query_list.append(query)
    #         self.index_list.append(data_index)
    #     assert len(self.query_list) == len(self.index_list)
    #     logger.info(f"final len_query_list:{len(self.query_list)}")
    #     logger.info(f"final len_index_list:{len(self.index_list)}")
    #
    #     # Batch-tokenize everything here: about 2 hours and ~200 GB RAM for 2.5M rows; 100 GB RAM crashes the process.
    #     logger.info(f"Start batch-tokenizing all data ...")
    #     self.all_data_token = self.tokenizer(self.query_list,
    #                                          padding='max_length',
    #                                          truncation=True,
    #                                          max_length=self.max_seq_length,
    #                                          return_tensors='pt')
    #     logger.info(f"Batch tokenization of all data finished!")
    def __init__(self, input_file, tokenizer, max_seq_length):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        logger.info('Loading data: {}'.format(input_file))
        with open(input_file, 'r', encoding='utf8') as f:
            data_list = f.readlines()
        logger.info("There are {} data in dataset".format(len(data_list)))

        # For testing, only take the top 100 lines here.
        self.raw_data_list = data_list  #[:100]

        # Extract the conversation content.
        self.query_list = []
        self.index_list = []
        for line in self.raw_data_list:
            json_line = json.loads(line)
            # Only the query ("主要内容") gets embedded.
            query = json_line["主要内容"].strip()
            # data_index of this single-turn sample in the raw data.
            data_index = json_line["工单编号"]
            data_index = str(data_index)
            self.query_list.append(query)
            self.index_list.append(data_index)

        assert len(self.query_list) == len(self.index_list)
        logger.info(f"final len_query_list:{len(self.query_list)}")
        logger.info(f"final len_index_list:{len(self.index_list)}")

        # Tokenize all data in one batch.
        logger.info(f"Start batch-tokenizing all data ...")
        self.all_data_token = self.tokenizer(
            self.query_list,
            padding='max_length',
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors='pt'
        )
        logger.info(f"Batch tokenization of all data finished!")

    def __getitem__(self, idx):
        # item contains input_ids, attention_mask, etc., already converted to tensors.
        item = {key: torch.as_tensor(value[idx]) for key, value in self.all_data_token.items()}
        item['data_index'] = self.index_list[idx]
        return item

    def __len__(self):
        return len(self.query_list)


class TextEmbedder:
    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda:0")
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path,
                                                       model_max_length=args.max_length,
                                                       # Inference: use left padding, otherwise reading the results is problematic.
                                                       padding_side="left",
                                                       use_fast=False)
        self.model = AutoModel.from_pretrained(self.args.model_name_or_path).to(self.device)
        logger.info('tokenizer and model loaded')
        self.all_dataset = MyTorchDataset(self.args.raw_data_path, self.tokenizer, self.args.max_length)
        self.data_loader = DataLoader(dataset=self.all_dataset, batch_size=self.args.batch_size,
                                      # pin_memory=True, num_workers=1, prefetch_factor=8
                                      )

    def save_data(self, results: List) -> None:
        logger.info(f'Number of records to save: {len(results)}')

        df = pd.DataFrame(results)
        df.sort_values(by="data_index", inplace=True)
        df.reset_index(drop=True, inplace=True)

        df.to_pickle(self.args.output_path)
        logger.info(f"Saved pickle to {self.args.output_path}")

        '''
        # Saving to .csv fails here.
        csv_file_name = "./rl_demo_emb_240624.csv"
        df.to_csv(csv_file_name, index=False)
        logger.info(f"Saved csv to {csv_file_name}")
        '''

        # File name for the saved binary.
        torch_py_bin_name = './12345/embedding/123451wfilter_embedding.pth'
        # Note: torch.save()/torch.load() are not limited to tensors; they can persist any Python object,
        # including dicts, sets, custom class instances, etc.
        torch.save(results, torch_py_bin_name)
        logger.info(f'List has been saved to {torch_py_bin_name}')

    def encode_samples(self):
        total_samples = len(self.all_dataset)
        logger.info(f"total_samples: {total_samples}")
        total_batches = len(self.data_loader)
        logger.info(f"total_batches: {total_batches}")

        all_embeddings_list = []

        for b_idx, batch in enumerate(tqdm(self.data_loader, total=total_batches)):
            self.model.eval()
            batch_idx = batch["data_index"]
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            output = None
            with torch.no_grad():
                output = self.model(input_ids=input_ids, attention_mask=attention_mask)

            cls_embedding = output.pooler_output.detach().cpu().numpy().tolist()

            # Record the embeddings and indices for this batch.
            cur_sample_idx = batch_idx  #.tolist()  # if batch_idx is a tensor
            cur_sample_dict_list = [
                {"embedding": cls_emb, "data_index": s_id}
                for cls_emb, s_id in zip(cls_embedding, cur_sample_idx)
            ]
            all_embeddings_list.extend(cur_sample_dict_list)

        return all_embeddings_list

    # def encode_samples(self):
    #     total_samples = len(self.all_dataset)
    #     logger.info(f"total_samples:{total_samples}")
    #     total_batches = len(self.data_loader)
    #     logger.info(f"total_batches:{total_batches}")
    #
    #     all_embeddings_list = []
    #
    #     for b_idx, batch in enumerate(tqdm(self.data_loader, total=total_samples // self.args.batch_size)):
    #         self.model.eval()
    #
    #         batch_idx = batch["data_index"]
    #         input_ids = batch['input_ids'].to(self.device)
    #         attention_mask = batch['attention_mask'].to(self.device)
    #         output = None
    #         with torch.no_grad():
    #             output = self.model(input_ids=input_ids, attention_mask=attention_mask)
    #
    #         cls_embedding = output.pooler_output.detach().cpu().numpy().tolist()
    #         print(batch_idx)
    #         # Record the embeddings and indices for this batch.
    #         cur_sample_idx = batch_idx.tolist()
    #         cur_sample_dict_list = [{"embedding": cls_emb, "data_index": s_id} for cls_emb, s_id in
    #                                 zip(cls_embedding, cur_sample_idx)]
    #         # list.extend flattens the per-batch lists into one list.
    #         all_embeddings_list.extend(cur_sample_dict_list)
    #
    #     return all_embeddings_list


if __name__ == '__main__':
    torch.cuda.empty_cache()
    # Parameter setup.
    set_seed()
    # Argument parsing.
    parser = argparse.ArgumentParser()
    # Do not change this argument; it is assigned automatically.
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    # Number of processes (not threads); no need to set it, it follows nproc_per_node automatically.
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    args = parser.parse_args()

    args.max_length = 512
    args.batch_size = 2048
    args.model_name_or_path = "/model-pvc/suojiayi/bge-base-zh-v1.5/"
    #args.raw_data_path = "/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_COIG_filtered_2504212014.jsonl"
    args.raw_data_path = '/data/zhaochsh01/buquan/12345/tool/output.jsonl'
    args.output_path = "./12345/embedding/123451wffilter_embedding.bin"

    logger.info(f"\nargs:{args}\n")

    text_embeder = TextEmbedder(args)
    all_embeddings_list = text_embeder.encode_samples()

    logger.info(f"len_all_embeddings_list:{len(all_embeddings_list)}")
    logger.info("Finished embedding")

    text_embeder.save_data(all_embeddings_list)
    logger.info(f"Pipeline run complete.")
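A minimal sketch, assuming the .pth list written by save_data() above, for loading the saved embeddings and checking their dimensionality before clustering:

import torch
import numpy as np

embed_dict_list = torch.load("./12345/embedding/123451wfilter_embedding.pth")
print("records:", len(embed_dict_list))

# Each element is {"embedding": [...], "data_index": "..."}; stack a few to inspect the shape.
sample = np.array([d["embedding"] for d in embed_dict_list[:8]])
print("sample shape:", sample.shape)  # expected (8, hidden_size), e.g. 768 for bge-base-zh-v1.5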
68  data_generate/zcs/fenbu/tool/count_readjsonl.py  Normal file
@@ -0,0 +1,68 @@
import json
from collections import defaultdict
import openpyxl
from openpyxl.styles import Font

def count_categories_from_jsonl(input_file, output_xlsx):
    """
    Read a jsonl file, count the records per category, and write the statistics to an xlsx file.

    :param input_file: path of the input jsonl file
    :param output_xlsx: path of the output xlsx file
    """
    # Number of records per category.
    category_counts = defaultdict(int)
    total_categories = 0
    total_samples = 0

    # 1. Read the jsonl file and count.
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line.strip())
                answer = data.get('answer')
                if answer:
                    category_counts[answer] += 1
                    total_samples += 1
            except json.JSONDecodeError:
                print(f"Warning: skipping unparsable line: {line}")
                continue

    total_categories = len(category_counts)

    # 2. Create the Excel workbook and worksheet.
    wb = openpyxl.Workbook()
    ws = wb.active
    ws.title = "类别统计"

    # 3. Write the header row.
    headers = ["类别名称", "数据量", "占比(%)"]
    ws.append(headers)

    # 4. Write the per-category statistics.
    for category, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / total_samples) * 100
        ws.append([category, count, f"{percentage:.2f}%"])

    # 5. Write the summary rows.
    ws.append([])  # blank separator row
    ws.append(["总类别数", total_categories])
    ws.append(["总数据量", total_samples])

    # 6. Cell styling.
    # Column widths.
    ws.column_dimensions['A'].width = 40  # category name column
    ws.column_dimensions['B'].width = 15  # count column
    ws.column_dimensions['C'].width = 15  # percentage column

    # Bold the header row.
    for cell in ws[1]:
        cell.font = Font(bold=True)

    # 7. Save the Excel file.
    wb.save(output_xlsx)
    print(f"Done: {total_categories} categories, {total_samples} records. Results saved to {output_xlsx}")

# Usage example
input_jsonl = "/data/zhaochsh01/buquan/12345/tool/output_prompt.jsonl"  # replace with your jsonl file path
output_excel = "12345category_stats.xlsx"  # output Excel file path
count_categories_from_jsonl(input_jsonl, output_excel)
18  data_generate/zcs/fenbu/tool/count_tag.py  Normal file
@@ -0,0 +1,18 @@
import pandas as pd

# Example: assumes a file such as 'data.xlsx' containing an 'answer' column.
# The actual input file is loaded below.

# Load the data.
df = pd.read_excel('/data/zhaochsh01/buquan/12345/count_result/12345_prompt.xlsx')


# Count the number of rows per value of the 'answer' column.
category_counts = df['answer'].value_counts().reset_index()
category_counts.columns = ['Category', 'Count']

# Write the result to an xlsx file.
output_file = '/data/zhaochsh01/buquan/12345/count_result/category_counts.xlsx'
category_counts.to_excel(output_file, index=False)

print(output_file)
52  data_generate/zcs/fenbu/tool/format.py  Normal file
@@ -0,0 +1,52 @@
import json

def process_jsonl_file(input_file, output_file):
    """
    Process a jsonl file that mixes two record formats:
    1. Records with a "主要内容" field: keep only the selected fields and rename them.
    2. Records in any other format: pass through unchanged.

    :param input_file: input file path
    :param output_file: output file path
    """
    processed_count = 0
    skipped_count = 0

    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:

        for line in infile:
            try:
                data = json.loads(line.strip())

                # Check for the "主要内容" field.
                if '主要内容' in data:
                    # Build a new record keeping only the selected fields.
                    new_data = {
                        '工单编号': data.get('工单编号'),
                        'data': data.get('主要内容'),  # renamed field
                        'ins_tag_label': data.get('ins_tag_label'),
                        'cluster_center': data.get('cluster_center')
                    }

                    # Write the processed record.
                    outfile.write(json.dumps(new_data, ensure_ascii=False) + '\n')
                    processed_count += 1
                else:
                    # Records without "主要内容" are written through unchanged.
                    outfile.write(line)
                    skipped_count += 1

            except json.JSONDecodeError:
                print(f"Warning: skipping unparsable line: {line.strip()}")
                skipped_count += 1
                continue

    print(f"Done: processed {processed_count} records containing '主要内容'")
    print(f"Passed through {skipped_count} records in other formats")
    print(f"Results saved to {output_file}")

# Usage example
input_file = "./merged_result_for_more_test.jsonl"  # replace with your input file path
output_file = "./12345_merged_resul.jsonl"  # output file path
process_jsonl_file(input_file, output_file)
40  data_generate/zcs/fenbu/tool/merge.py  Normal file
@@ -0,0 +1,40 @@
import json

# 1. Read the K-Means clustering result file (JSONL) and build a {data_idx: cluster_center} dict.
kmeans_data = {}
with open('/data/zhaochsh01/buquan/12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            entry = json.loads(line)
            data_idx = entry.get("data_idx")  # get data_idx
            cluster_center = entry.get("cluster_center")  # get cluster_center (an integer)
            if data_idx and cluster_center is not None:
                kmeans_data[data_idx] = cluster_center
        except Exception as e:
            print(f"Error parsing K-Means line: {line}, error: {e}")

# 2. Read the JSONL file, match records, and merge in cluster_center.
output_lines = []
with open('/data/zhaochsh01/buquan/12345/instag/123451wfilter_instag.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
            key_id = str(data.get("工单编号"))  # cast to string so the keys match
            if key_id in kmeans_data:
                data["cluster_center"] = kmeans_data[key_id]  # attach cluster_center
                output_lines.append(json.dumps(data, ensure_ascii=False))  # re-serialize to a JSON string
        except Exception as e:
            print(f"Error processing JSONL line: {line}, error: {e}")

# 3. Write the result to a new file.
with open('merged_result.jsonl', 'w', encoding='utf-8') as f:
    for line in output_lines:
        f.write(line + '\n')

print("Done. Results saved to merged_result.jsonl")
67  data_generate/zcs/fenbu/tool/prompt_change.py  Normal file
@@ -0,0 +1,67 @@
import json
from collections import defaultdict

def select_and_transform_data(input_file, output_file, samples_per_category=4, total_samples=250):
    """
    Read a jsonl file, group records by answer, pick a fixed number of records per category,
    transform the fields, and save the result.

    :param input_file: input jsonl file path
    :param output_file: output jsonl file path
    :param samples_per_category: number of samples to take per category
    :param total_samples: total number of samples to select
    """
    # Records grouped by answer.
    categorized_data = defaultdict(list)

    # 1. Read and group the data.
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line.strip())
                answer = data.get('answer')
                if answer:
                    categorized_data[answer].append(data)
            except json.JSONDecodeError:
                print(f"Warning: skipping unparsable line: {line}")
                continue

    # 2. Take the requested number of samples from each category.
    selected_data = []
    selected_count = 0

    for category, items in categorized_data.items():
        # How many samples this category may still contribute (respect the overall limit).
        remaining = total_samples - selected_count
        take = min(samples_per_category, len(items), remaining)

        if take <= 0:
            break  # overall quota reached

        # Take the first `take` items (switch to random.sample for random selection).
        selected = items[:take]

        # 3. Transform the fields and append to the result.
        for item in selected:
            transformed = {
                "工单编号": item["uid"],
                "data": item["data"],
                "answer": item["answer"]
            }
            selected_data.append(transformed)
            selected_count += 1

        if selected_count >= total_samples:
            break

    # 4. Save the result to a jsonl file.
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in selected_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(f"Done: selected {selected_count} records, saved to {output_file}")
    print(f"Per-category counts: { {k: len([d for d in selected_data if d['answer'] == k]) for k in categorized_data} }")

# Usage example
input_file = "/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl"  # replace with your input file path
output_file = "output_prompt.jsonl"  # output file path
select_and_transform_data(input_file, output_file, samples_per_category=4, total_samples=250)
85  data_generate/zcs/fenbu/tool/prompt_choose250.py  Normal file
@@ -0,0 +1,85 @@
import json
from collections import defaultdict

def select_and_transform_data(input_file, output_file, small_category_max=2, normal_samples=4, total_samples=250):
    """
    Read a jsonl file and group records by answer. Small categories (<= small_category_max records)
    are taken in full; every other category contributes a fixed number of records. Fields are
    transformed and the result is saved.

    :param input_file: input jsonl file path
    :param output_file: output jsonl file path
    :param small_category_max: size threshold at or below which a category counts as small
    :param normal_samples: number of samples to take from each normal category
    :param total_samples: total number of samples to select
    """
    # Records grouped by answer.
    categorized_data = defaultdict(list)

    # 1. Read and group the data.
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line.strip())
                answer = data.get('answer')
                if answer:
                    categorized_data[answer].append(data)
            except json.JSONDecodeError:
                print(f"Warning: skipping unparsable line: {line}")
                continue

    # Split small categories from normal ones.
    small_categories = {k: v for k, v in categorized_data.items() if len(v) <= small_category_max}
    normal_categories = {k: v for k, v in categorized_data.items() if len(v) > small_category_max}

    # 2. Take all records from the small categories first.
    selected_data = []
    selected_count = 0

    for category, items in small_categories.items():
        # Take every record in the small category.
        for item in items:
            transformed = {
                "工单编号": item["uid"],
                "data": item["data"],
                "answer": item["answer"]
            }
            selected_data.append(transformed)
            selected_count += 1

    # 3. Take records from the normal categories until the total is reached.
    for category, items in normal_categories.items():
        if selected_count >= total_samples:
            break

        # How many records are still needed.
        remaining = total_samples - selected_count
        if remaining <= 0:
            break

        # take = min(per-category quota, category size, remaining quota)
        take = min(normal_samples, len(items), remaining)

        # Take the first `take` items (switch to random.sample for random selection).
        for item in items[:take]:
            transformed = {
                "工单编号": item["uid"],
                "data": item["data"],
                "answer": item["answer"]
            }
            selected_data.append(transformed)
            selected_count += 1

    # 4. Save the result to a jsonl file.
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in selected_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(f"Done: selected {selected_count} records, saved to {output_file}")
    print("Per-category counts:")
    stat = {k: len([d for d in selected_data if d['answer'] == k]) for k in categorized_data}
    for k, v in stat.items():
        print(f"{k}: {v} selected (out of {len(categorized_data[k])})")

# Usage example
input_file = "/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl"  # replace with your input file path
output_file = "output_prompt.jsonl"  # output file path
select_and_transform_data(input_file, output_file, small_category_max=2, normal_samples=4, total_samples=250)
144  data_generate/zcs/fenbu/tool/prompt_label2.py  Normal file
@@ -0,0 +1,144 @@
import json
import requests
import pandas as pd
import re
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from concurrent.futures import ThreadPoolExecutor, as_completed


def read_jsonl_lines_in_batches(file_path, batch_size=10000):
    """Read a JSONL file in batches."""
    batch = []
    with open(file_path, mode="r", encoding="utf-8") as f:
        for line in f:
            try:
                batch.append(json.loads(line.strip()))
                if len(batch) == batch_size:
                    yield batch
                    batch = []
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON: {e}")
        if batch:
            yield batch


def process_data_concurrently(data_list, api_url, headers, max_workers=10):
    """Process the data concurrently by calling the API."""
    result_data = []
    list_category = [
        "农业生产", "农村工作", "农村生活", "农民生活", "巩固脱贫攻坚成果", "精准扶贫", "人才交流",
        "企业工资福利", "企业离休退", "企业离休退休", "劳动保护", "劳动关系", "劳动纠纷",
        "就业创业", "招录辞退", "社会保险政策", "职务职称", "行政效能", "信息查询", "历史遗留",
        "表扬感谢", "人防工作", "住房与房地产", "园林绿化", "国有土地上房屋征收补偿", "城乡规划",
        "工程质量", "建筑市场管理", "道路修建", "交通运输", "城市管理", "城管执法", "居民生活",
        "环境保护", "旅游信息", "人口计生", "体育工作", "医疗保障", "卫生", "执法监督", "救灾救济",
        "教育", "文化工作", "文明建设", "残疾人服务管理", "民政", "民族宗教", "法律服务", "社会治安",
        "退役军人", "信息通信产业", "商业贸易", "国资监管", "安全生产", "市场秩序", "招商引资",
        "旅游管理", "综合审批", "财税金融", "质量监管", "自然资产:海洋渔业", "国土资源", "林业",
        "水利", "自然资源:海洋渔业"
    ]

    def process_single_data(data):
        try:
            query = data.get('主要内容')
            final_result = ""
            if query:
                input_content = f'''
                您是一位文本分类专家,请依据专业的眼光判断下当前 query 属于以下65个领域类别中的哪一个并使用简体中文给出分析过程。
                领域类别:{list_category}
                需要判断的内容:{query}
                **注意:严格仅输出领域类别中的其中一个类别。

                "输出格式如下:结果:<插入返回的结果>\n分析过程:"
                '''
                response = requests.post(
                    api_url,
                    headers=headers,
                    json={
                        "model": "Qwen2.5-72B-Instruct",
                        "stream": False,
                        "temperature": 0.01,
                        "messages": [{"role": "user", "content": input_content}]
                    }
                )
                if response.status_code == 200:
                    try:
                        content = response.json()["choices"][0]["message"]["content"]
                        print(f"Model response: {content}")
                        # Extract the category that follows "结果:" in the response.
                        match = re.search(r'结果[::]\s*(\S+)', content)
                        if match:
                            final_result = match.group(1).strip("。")  # drop a possible trailing period
                            print(f"Extracted category: {final_result}")  # e.g. 交通运输
                        else:
                            print("No valid category found")
                        # answer = json.loads(content)
                        # answer = answer.get("返回结果")
                        # print(f"Parsed result: {answer}")
                    except (KeyError, IndexError, json.JSONDecodeError):
                        content = "无法解析返回内容"
                else:
                    content = f"API请求失败,状态码:{response.status_code}"
                return {
                    "uid": data.get('工单编号'),
                    "data": query,
                    "answer": final_result
                }
        except Exception as e:
            print(e)
            return None

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_single_data, data) for data in data_list]
        for i, future in enumerate(as_completed(futures), start=1):
            result = future.result()
            if result:
                result_data.append(result)
                logger.info(f"Finished {result.get('uid')} ({i}/{len(data_list)})")

    return result_data


def save_to_excel_in_batches(data_list, output_file, batch_size=10000):
    """Save the data to an Excel file in batches."""
    df = pd.DataFrame(data_list)
    writer = pd.ExcelWriter(output_file, engine='openpyxl')
    for i in range(0, len(df), batch_size):
        batch_df = df.iloc[i:i + batch_size]
        batch_df.to_excel(writer, index=False, startrow=i)
    writer.close()
    print(f"Data successfully saved to {output_file}")


def save_to_jsonl_in_batches(data_list, output_file, batch_size=10000):
    """Save the data to a JSONL file in batches."""
    with open(output_file, 'w', encoding='utf-8') as f:
        for i in range(0, len(data_list), batch_size):
            # Current batch of data.
            batch_data = data_list[i:i + batch_size]
            # Write each record as one JSON object per line.
            for item in batch_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
    print(f"Data successfully saved to {output_file}")


if __name__ == "__main__":
    #output_excel_file = 'result-taoli-5.xlsx'
    # api_url = "http://100.105.149.39:8000/v1/chat/completions"
    api_url = "http://100.105.246.130:8000/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c"
    }

    #file_path = '/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_BELLE_Multiturn_Chat_filtered_2504232014.jsonl'
    file_path = '/data/zhaochsh01/buquan/12345/tool/output.jsonl'
    output_file = '/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl'
    output_excel_file = '/data/zhaochsh01/buquan/12345/count_result/12345_prompt.xlsx'

    #file_path = '/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_COIG_filtered_2504212014.jsonl'
    all_results = []
    for batch in read_jsonl_lines_in_batches(file_path, batch_size=10000):
        processed_batch = process_data_concurrently(batch, api_url, headers, max_workers=20)
        all_results.extend(processed_batch)
    save_to_excel_in_batches(all_results, output_excel_file, batch_size=10000)
    save_to_jsonl_in_batches(all_results, output_file, batch_size=10000)
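Since process_single_data() above treats any non-200 response as a final failure, a small retry wrapper around requests.post could make the labeling run more robust; a sketch assuming simple linear backoff is acceptable (the post_with_retry helper is a hypothetical addition, not part of the commit):

import time
import requests

def post_with_retry(api_url, headers, payload, retries=3, backoff=2.0):
    """Hypothetical helper: retry the chat-completions call on exceptions or non-200 responses."""
    for attempt in range(retries):
        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            if response.status_code == 200:
                return response
        except requests.RequestException as e:
            print(f"request failed on attempt {attempt + 1}: {e}")
        time.sleep(backoff * (attempt + 1))
    return None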
67  data_generate/zcs/fenbu/tool/xslx2jsonl.py  Normal file
@@ -0,0 +1,67 @@
import pandas as pd
import json
from datetime import datetime
import numpy as np

def convert_value(obj):
    """Convert values that are not JSON-serializable."""
    # Handle missing values.
    if pd.isna(obj) or obj is None:
        return None

    # Handle timestamps.
    if isinstance(obj, (pd.Timestamp, datetime)):
        return obj.strftime('%Y-%m-%d %H:%M:%S')

    # Handle NaT.
    if isinstance(obj, pd._libs.tslibs.nattype.NaTType):
        return None

    # Handle numpy numeric types.
    if isinstance(obj, (np.integer, np.floating)):
        return int(obj) if isinstance(obj, np.integer) else float(obj)

    # Handle numpy arrays and pandas Series.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, pd.Series):
        return obj.to_dict()

    # Everything else passes through unchanged.
    return obj

def xlsx_to_jsonl(input_file, output_file):
    """
    Convert an XLSX file to JSONL format.

    Parameters:
        input_file (str): path of the input XLSX file
        output_file (str): path of the output JSONL file
    """
    try:
        # Read the Excel file.
        df = pd.read_excel(input_file)

        # Write the rows to the JSONL file.
        with open(output_file, 'w', encoding='utf-8') as f:
            for _, row in df.iterrows():
                # Convert the row to a dict and normalize every value.
                record = {k: convert_value(v) for k, v in row.items()}

                # Write one JSON object per line.
                json.dump(record, f, ensure_ascii=False)
                f.write('\n')

        print(f"Conversion succeeded; results saved to {output_file}")

    except Exception as e:
        print(f"Error during conversion: {str(e)}")


# Usage example
if __name__ == "__main__":
    input_xlsx = "/data/zhaochsh01/buquan/12345/1w_fillter.xlsx"  # replace with your input file path
    output_jsonl = "output.jsonl"  # replace with your desired output file path

    xlsx_to_jsonl(input_xlsx, output_jsonl)
109  data_generate/zcs/fenbu/vllm_infer_qw_instag_format_json.py  Normal file
@@ -0,0 +1,109 @@
import json
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from tqdm import tqdm, trange
from copy import deepcopy

# needs vllm==0.4.2
# used with firefly_copy

# step1: read the jsonl file to be tagged
input_jsonl_file = "/data/zhaochsh01/buquan/12345/tool/output.jsonl"
with open(input_jsonl_file, 'r', encoding='utf8') as f:
    raw_data_list = f.readlines()
print("len_raw_data_list:", len(raw_data_list))

# For testing, run only the first 100 records.
# raw_data_list = raw_data_list[:100]

# step2: load the instag model and configure the tokenizer
model_path = '/model-pvc/instagger_qwen1_8B/'
tokenizer = get_tokenizer(model_path, trust_remote_code=True)
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{ '### ' + message['role'].capitalize() + ':\\n' }}"
    "{{ message['content'] | string + '\\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '### Assistant:\\n' }}"
    "{% endif %}"
)

# step3: memory is sufficient, so read everything at once and build the prompts
prompt_text_list = []
for index, line in enumerate(raw_data_list):
    # print("process:", index)
    json_obj = json.loads(line)
    query = json_obj["主要内容"]
    prompt = query

    messages = [
        {"role": "system", "content": "You are an intelligent assistant, please identify the category based on the user’s query and return the corresponding category."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    prompt_text_list.append(text)

print("len_prompt_text_list:", len(prompt_text_list))
assert len(raw_data_list) == len(prompt_text_list)


def extract_raw_qwen_res(response):
    if "<|im_end|>" in response:
        response = response.split("<|im_end|>")[0]
    return response


# step4: vLLM sampling parameters; temperature=0 is recommended here
sampling_params = SamplingParams(max_tokens=64,
                                 temperature=0,
                                 top_p=0.8,
                                 repetition_penalty=1.05)

# step5: load the model; loading it late lets data problems surface before the model load and saves time
# set gpu_memory_utilization=0.95 for OOM error
vllm_model = LLM(model=model_path,
                 tensor_parallel_size=1,
                 gpu_memory_utilization=0.85,
                 max_num_seqs=64,  # lower the maximum number of concurrent sequences (default 256)
                 max_model_len=2048,  # lower the maximum context length
                 enforce_eager=True,
                 trust_remote_code=True)
vllm_model.set_tokenizer(tokenizer)

# step6: batched inference and result collection
tag_result_list = []
batch_size = 1024
for start in trange(0, len(prompt_text_list), batch_size):
    print("index:", start)
    batch_list = prompt_text_list[start: start + batch_size]
    vllm_outputs = vllm_model.generate(batch_list, sampling_params)
    for output in vllm_outputs:
        prompt = output.prompt
        cur_tag_result = extract_raw_qwen_res(output.outputs[0].text)
        tag_result_list.append(cur_tag_result)

print("len_tag_result_list:", len(tag_result_list))

assert len(tag_result_list) == len(prompt_text_list)

# step7: write the results
print("Writing instag results")
single_turn_instag_path = "./12345/instag/123451wfilter_instag.jsonl"
print("single_turn_instag_path:", single_turn_instag_path)
single_f_sample_save = open(single_turn_instag_path, "w")
write_single_sample_count = 0
for line, tag in zip(raw_data_list, tag_result_list):
    json_line = json.loads(line)
    json_line["ins_tag_label"] = tag
    cur_final_line = json.dumps(deepcopy(json_line), ensure_ascii=False)
    write_single_sample_count = write_single_sample_count + 1
    single_f_sample_save.write(cur_final_line + "\n")
    if write_single_sample_count % 100000 == 0:
        print("write_single_count:", write_single_sample_count)

# Close the output file so the last buffered lines are flushed.
single_f_sample_save.close()
print("all finished successfully!")
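A minimal sketch, assuming the instag JSONL written above, for spot-checking that every line carries an ins_tag_label:

import json

missing = 0
with open("./12345/instag/123451wfilter_instag.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if not json.loads(line).get("ins_tag_label"):
            missing += 1

print("lines without ins_tag_label:", missing)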