Compare commits
2 Commits

Author | SHA1 | Date
---|---|---
| 0a6da985dd |
| 0655f98c70 |
data_generate/zcs/fenbu/cluster_kmeans_train_from_torch_bin.py (new file, 100 lines added)
@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-

import json
import sys
import os
import torch
import numpy as np
from sklearnex import patch_sklearn, unpatch_sklearn
from matplotlib import pyplot as plt
import time
import pandas as pd


def train_kmeans_model(train_data, n_clusters, data_index_list):
    # Note: patch_sklearn() must run before the sklearn tools are imported.
    patch_sklearn()
    # unpatch_sklearn()
    from sklearn.cluster import KMeans

    start_time = time.time()

    # The training data is expected to be an np.ndarray here.
    data_set = train_data
    print("train_data_len:" + str(len(data_set)))
    print("n_clusters:" + str(n_clusters))

    # init defaults to k-means++ anyway
    model = KMeans(n_clusters=n_clusters, init='k-means++', max_iter=300, n_init="auto", random_state=0)

    model.fit(data_set)
    # the fitted KMeans model already contains the cluster centers
    center = model.cluster_centers_
    print(center)

    cluster_label = model.predict(data_set)
    print("cluster_label_size:" + str(len(cluster_label)))
    print("type:" + str(type(cluster_label)))

    train_dfs = pd.DataFrame(data_set)

    train_dfs["predict_label"] = cluster_label
    train_dfs["data_index"] = data_index_list

    print(train_dfs.columns.values)

    end_time = time.time()
    avg_time_cost = (end_time - start_time) * 1.0
    print("train_kmeans_time:" + str(avg_time_cost) + " s")
    return train_dfs


# step1: load the saved embedding binary file
embed_file = "/data/zhaochsh01/buquan/12345/embedding/123451wfilter_embedding.pth"
embed_dict_list = torch.load(embed_file)
print("len_embed_dict_list:", len(embed_dict_list))

# step2: parse the data
raw_embeding_list = []
raw_index_list = []
for idx, cur_dict in enumerate(embed_dict_list):
    # each loaded element is already a dict object
    cur_embeding = cur_dict["embedding"]
    cur_data_idx = cur_dict["data_index"]
    raw_embeding_list.append(cur_embeding)
    raw_index_list.append(cur_data_idx)

train_array = np.array(raw_embeding_list)
print("train_array_shape:", train_array.shape)

# Target is 1000 clusters overall: first 50 coarse classes, then pick 20 sub-classes per coarse class.
# Alternatively go straight to 1000 coarse classes and keep selecting later if the result is not good enough.
num_cluster = 50
# pandas aligns rows by index automatically here
train_dfs = train_kmeans_model(train_array, num_cluster, raw_index_list)
predict_label = train_dfs['predict_label'].tolist()

print("len_predict_label:", len(predict_label))


# Do not save a CSV here: pandas CSV export can fail in unexpected ways when the data contains special characters!
# data_to_save = {'embeding': raw_embeding_list, 'cluster_center': predict_label, "data_idx": raw_index_list}
data_to_save = [
    {
        "embedding": raw_embeding_list[i],
        "cluster_center": predict_label[i],
        "data_idx": raw_index_list[i]
    }
    for i in range(len(raw_embeding_list))
]

output_file = "./12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl"
with open(output_file, 'w', encoding='utf-8') as f:
    for record in data_to_save:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')

print(f"Results saved to {output_file}")

data_to_save_df = pd.DataFrame(data_to_save)
data_to_save_df.to_pickle("./12345/kmeans/123451wfilter_cluster_kmeans_result.pkl")
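Note (not part of the commit): a minimal sketch for inspecting the JSONL written above, e.g. counting how many records landed in each cluster; the path and the "cluster_center" key are taken from the script.

import json
from collections import Counter

with open("./12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl", encoding="utf-8") as f:
    labels = [json.loads(line)["cluster_center"] for line in f]
print(Counter(labels).most_common(10))  # (cluster id, record count) pairs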
data_generate/zcs/fenbu/select_final_data.py (new file, 123 lines added)
@@ -0,0 +1,123 @@

import json
import pandas as pd


def process_jsonl_file(file_path, top_n_per_group, result, seen_uids):
    """
    Read a jsonl file and, for each group line, try to select top_n_per_group unique records.
    If the required number cannot be reached, skip that group and print a warning.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)

                # get the group key and its record list
                key = next(k for k in entry if k != "count")
                records = entry[key]

                selected = []
                added_count = 0

                # walk through records until top_n_per_group unique items are selected
                for item in records:
                    uid = item.get('工单编号')

                    if added_count >= top_n_per_group:
                        break  # enough items selected, stop

                    if uid and uid not in seen_uids:
                        seen_uids.add(uid)
                        selected.append(item)
                        result.append(item)
                        added_count += 1

                # if top_n_per_group could not be reached, warn and roll back
                if added_count < top_n_per_group:
                    print(f"[Group {key}] Only found {added_count}/{top_n_per_group} unique items. Skipping this group.")
                    # remove the records that were already added for this group
                    for item in selected:
                        result.remove(item)
                        seen_uids.discard(item.get('工单编号'))  # same key as used when adding above
                else:
                    print(f"[Group {key}] Successfully added {added_count} items.")

            except Exception as e:
                print(f"Error processing line {line_idx + 1}: {e}")
                continue


def load_data_to_memory(input_file):
    """
    Read a jsonl file, store the records in the result list and the 工单编号 values in the seen_uids set.

    :param input_file: path of the input jsonl file
    :return: tuple(result list, seen_uids set)
    """
    result = []
    seen_uids = set()

    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    # parse the JSON line
                    data = json.loads(line.strip())

                    # check that the 工单编号 field is present
                    if '工单编号' in data:
                        uid = data['工单编号']

                        # add to the result list and the seen set
                        result.append(data)
                        seen_uids.add(uid)
                    else:
                        print(f"警告:跳过缺少工单编号的记录: {data}")

                except json.JSONDecodeError:
                    print(f"警告:跳过无法解析的行: {line}")
                    continue

    except FileNotFoundError:
        print(f"错误:文件 {input_file} 不存在")
        return [], set()
    except Exception as e:
        print(f"读取文件时发生错误: {str(e)}")
        return [], set()

    print(f"成功加载 {len(result)} 条数据,发现 {len(seen_uids)} 个唯一工单编号")
    return result, seen_uids


def main():
    result = []
    seen_uids = set()
    result, seen_uids = load_data_to_memory("/data/zhaochsh01/buquan/12345/tool/output_prompt.jsonl")

    # Step 3: process the instag count file, top 3 unique items per group
    print("Step 3: Processing dhbq_count_instag.jsonl...")
    process_jsonl_file("/data/zhaochsh01/buquan/12345/count_result/12345_count_instag.jsonl", 3, result, seen_uids)

    # Step 4: process the cluster count file, top 3 unique items per group
    print("Step 4: Processing dhbq_count_cluster.jsonl...")
    process_jsonl_file("/data/zhaochsh01/buquan/12345/count_result/12345_count_cluster.jsonl", 3, result, seen_uids)

    # finally, write the merged output file
    output_file = "/data/zhaochsh01/buquan/12345/merge_result/merged_result_for_more_test.jsonl"
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in result:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(f"Total merged items: {len(result)}")
    print(f"Saved to: {output_file}")


if __name__ == "__main__":
    main()
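Note (not part of the commit): process_jsonl_file above expects each line of the count files to carry one group key holding the record list plus a "count" field, roughly like this made-up example (all values hypothetical):

{"交通运输": [{"工单编号": "A0001", "主要内容": "..."}, {"工单编号": "A0002", "主要内容": "..."}], "count": 2}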
@@ -0,0 +1,233 @@
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
from loguru import logger
import random
import argparse
import json
import pandas
from typing import List


def set_seed(seed=128):
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


class MyTorchDataset(torch.utils.data.Dataset):
    # def __init__(self, input_file, tokenizer, max_seq_length):
    #     self.tokenizer = tokenizer
    #     self.max_seq_length = max_seq_length
    #     logger.info('Loading data: {}'.format(input_file))
    #     with open(input_file, 'r', encoding='utf8') as f:
    #         data_list = f.readlines()
    #     logger.info("There are {} data in dataset".format(len(data_list)))
    #     # for testing, only take the top 100 rows here
    #     self.raw_data_list = data_list[:]
    #
    #     # tokenize everything at init time; tokenizing one row at a time while iterating slows down embedding extraction
    #     self.query_list = []
    #     self.index_list = []
    #     for line in self.raw_data_list:
    #         json_line = json.loads(line)
    #         print(json_line["data"][0])
    #         # only the query is embedded; content[0] is the query
    #         query = json_line["data"][0]['content'].strip()
    #         # data_index of the single-turn dialogue in the raw data
    #         data_index = json_line["uid"]
    #         self.query_list.append(query)
    #         self.index_list.append(data_index)
    #     assert len(self.query_list) == len(self.index_list)
    #     logger.info(f"final len_query_list:{len(self.query_list)}")
    #     logger.info(f"final len_index_list:{len(self.index_list)}")
    #
    #     # batch-tokenize everything in one go: roughly 2 hours and ~200 GB RAM for 2.5M rows (the process crashes with 100 GB)
    #     logger.info(f" start batch-tokenizing all data ... ")
    #     self.all_data_token = self.tokenizer(self.query_list,
    #                                          padding='max_length',
    #                                          truncation=True,
    #                                          max_length=self.max_seq_length,
    #                                          return_tensors='pt')
    #     logger.info(f" batch tokenization of all data finished !!!")
    def __init__(self, input_file, tokenizer, max_seq_length):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        logger.info('Loading data: {}'.format(input_file))
        with open(input_file, 'r', encoding='utf8') as f:
            data_list = f.readlines()
        logger.info("There are {} data in dataset".format(len(data_list)))

        # for testing, only take the top 100 rows here
        self.raw_data_list = data_list  # [:100]

        # extract the dialogue content
        self.query_list = []
        self.index_list = []
        for line in self.raw_data_list:
            json_line = json.loads(line)
            # only the query is embedded; here the query is the "主要内容" field
            query = json_line["主要内容"].strip()
            # data_index of the single-turn record in the raw data
            data_index = json_line["工单编号"]
            data_index = str(data_index)
            self.query_list.append(query)
            self.index_list.append(data_index)

        assert len(self.query_list) == len(self.index_list)
        logger.info(f"final len_query_list:{len(self.query_list)}")
        logger.info(f"final len_index_list:{len(self.index_list)}")

        # batch-tokenize all data
        logger.info(f"开始批量 tokenize 所有数据 ...")
        self.all_data_token = self.tokenizer(
            self.query_list,
            padding='max_length',
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors='pt'
        )
        logger.info(f"批量 tokenize 所有数据 完成 !!!")

    def __getitem__(self, idx):
        # item contains input_ids, attention_mask, etc., already converted to tensors
        item = {key: torch.as_tensor(value[idx]) for key, value in self.all_data_token.items()}
        item['data_index'] = self.index_list[idx]
        return item

    def __len__(self):
        return len(self.query_list)


class TextEmbedder:
    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda:0")
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path,
                                                       model_max_length=args.max_length,
                                                       # inference: use left padding, otherwise reading the outputs is problematic
                                                       padding_side="left",
                                                       use_fast=False)
        self.model = AutoModel.from_pretrained(self.args.model_name_or_path).to(self.device)
        logger.info('tokenizer,model 加载完成')
        self.all_dataset = MyTorchDataset(self.args.raw_data_path, self.tokenizer, self.args.max_length)
        self.data_loader = DataLoader(dataset=self.all_dataset, batch_size=self.args.batch_size,
                                      # pin_memory=True, num_workers=1, prefetch_factor=8
                                      )

    def save_data(self, results: List) -> None:
        logger.info(f'需要保存的数据长度: {len(results)}')

        df = pandas.DataFrame(results)
        df.sort_values(by="data_index", inplace=True)
        df.reset_index(drop=True, inplace=True)

        df.to_pickle(self.args.output_path)
        logger.info(f"Saved pickle to {self.args.output_path}")

        '''
        # saving to .csv fails here
        csv_file_name = "./rl_demo_emb_240624.csv"
        df.to_csv(csv_file_name, index=False)
        logger.info(f"Saved csv to {csv_file_name}")
        '''

        # file name for the saved binary
        torch_py_bin_name = './12345/embedding/123451wfilter_embedding.pth'
        # Note that torch.save()/torch.load() are not limited to tensors: they can save and load
        # any Python object, including dicts, sets, custom class instances, etc.
        torch.save(results, torch_py_bin_name)
        logger.info(f'List has been saved to {torch_py_bin_name}')

    def encode_samples(self):
        total_samples = len(self.all_dataset)
        logger.info(f"total_samples: {total_samples}")
        total_batches = len(self.data_loader)
        logger.info(f"total_batches: {total_batches}")

        all_embeddings_list = []

        for b_idx, batch in enumerate(tqdm(self.data_loader, total=total_batches)):
            self.model.eval()
            batch_idx = batch["data_index"]
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            output = None
            with torch.no_grad():
                output = self.model(input_ids=input_ids, attention_mask=attention_mask)

            cls_embedding = output.pooler_output.detach().cpu().numpy().tolist()

            # record the embeddings and indices for this batch
            cur_sample_idx = batch_idx  # .tolist() would be needed if batch_idx were a tensor
            cur_sample_dict_list = [
                {"embedding": cls_emb, "data_index": s_id}
                for cls_emb, s_id in zip(cls_embedding, cur_sample_idx)
            ]
            all_embeddings_list.extend(cur_sample_dict_list)

        return all_embeddings_list

    # def encode_samples(self):
    #     total_samples = len(self.all_dataset)
    #     logger.info(f"total_samples:{total_samples}")
    #     total_batches = len(self.data_loader)
    #     logger.info(f"total_batches:{total_batches}")
    #
    #     all_embeddings_list = []
    #
    #     for b_idx, batch in enumerate(tqdm(self.data_loader, total=total_samples // self.args.batch_size)):
    #         self.model.eval()
    #
    #         batch_idx = batch["data_index"]
    #         input_ids = batch['input_ids'].to(self.device)
    #         attention_mask = batch['attention_mask'].to(self.device)
    #         output = None
    #         with torch.no_grad():
    #             output = self.model(input_ids=input_ids, attention_mask=attention_mask)
    #
    #         cls_embedding = output.pooler_output.detach().cpu().numpy().tolist()
    #         print(batch_idx)
    #         # record the batch of embeddings and indices
    #         cur_sample_idx = batch_idx.tolist()
    #         cur_sample_dict_list = [{"embedding": cls_emb, "data_index": s_id} for cls_emb, s_id in
    #                                 zip(cls_embedding, cur_sample_idx)]
    #         # list.extend flattens the per-batch lists into one list
    #         all_embeddings_list.extend(cur_sample_dict_list)
    #
    #     return all_embeddings_list


if __name__ == '__main__':
    torch.cuda.empty_cache()
    # parameter setup
    set_seed()
    # a few argument settings
    parser = argparse.ArgumentParser()
    # do not change this argument; it is assigned automatically
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    # number of processes (not threads); no need to set it, it follows nproc_per_node
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    args = parser.parse_args()

    args.max_length = 512
    args.batch_size = 2048
    args.model_name_or_path = "/model-pvc/suojiayi/bge-base-zh-v1.5/"
    # args.raw_data_path = "/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_COIG_filtered_2504212014.jsonl"
    args.raw_data_path = '/data/zhaochsh01/buquan/12345/tool/output.jsonl'
    args.output_path = "./12345/embedding/123451wffilter_embedding.bin"

    logger.info(f"\nargs:{args}\n")

    text_embeder = TextEmbedder(args)
    all_embeddings_list = text_embeder.encode_samples()

    logger.info(f"len_all_embeddings_list:{len(all_embeddings_list)}")
    logger.info("Finished embedding")

    text_embeder.save_data(all_embeddings_list)
    logger.info(f"Pipeline run complete.")
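Note (not part of the commit): a minimal sketch for sanity-checking the two artifacts written by the script above; both paths are taken from the script itself.

import torch
import pandas as pd

records = torch.load('./12345/embedding/123451wfilter_embedding.pth')
print(len(records), list(records[0].keys()))  # expect keys: embedding, data_index

df = pd.read_pickle("./12345/embedding/123451wffilter_embedding.bin")
print(len(df), df.columns.tolist())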
data_generate/zcs/fenbu/tool/count_readjsonl.py (new file, 68 lines added)
@@ -0,0 +1,68 @@
import json
from collections import defaultdict
import openpyxl


def count_categories_from_jsonl(input_file, output_xlsx):
    """
    Read a jsonl file, count the categories and write the statistics to an xlsx file.

    :param input_file: path of the input jsonl file
    :param output_xlsx: path of the output xlsx file
    """
    # number of records per category
    category_counts = defaultdict(int)
    total_categories = 0
    total_samples = 0

    # 1. read the jsonl file and count
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line.strip())
                answer = data.get('answer')
                if answer:
                    category_counts[answer] += 1
                    total_samples += 1
            except json.JSONDecodeError:
                print(f"警告:跳过无法解析的行: {line}")
                continue

    total_categories = len(category_counts)

    # 2. create the Excel workbook and worksheet
    wb = openpyxl.Workbook()
    ws = wb.active
    ws.title = "类别统计"

    # 3. write the header row
    headers = ["类别名称", "数据量", "占比(%)"]
    ws.append(headers)

    # 4. write the per-category statistics
    for category, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / total_samples) * 100
        ws.append([category, count, f"{percentage:.2f}%"])

    # 5. write the summary rows
    ws.append([])  # blank separator row
    ws.append(["总类别数", total_categories])
    ws.append(["总数据量", total_samples])

    # 6. cell styling
    # column widths
    ws.column_dimensions['A'].width = 40  # category name column
    ws.column_dimensions['B'].width = 15  # count column
    ws.column_dimensions['C'].width = 15  # percentage column

    # bold the header row
    for cell in ws[1]:
        cell.font = openpyxl.styles.Font(bold=True)

    # 7. save the Excel file
    wb.save(output_xlsx)
    print(f"统计完成!共{total_categories}个类别,{total_samples}条数据。结果已保存到{output_xlsx}")


# usage example
input_jsonl = "/data/zhaochsh01/buquan/12345/tool/output_prompt.jsonl"  # replace with your jsonl file path
output_excel = "12345category_stats.xlsx"  # output Excel file path
count_categories_from_jsonl(input_jsonl, output_excel)
data_generate/zcs/fenbu/tool/count_tag.py (new file, 18 lines added)
@@ -0,0 +1,18 @@
import pandas as pd

# Reads an .xlsx file that is assumed to contain an 'answer' column
# and counts how many rows fall into each category.

df = pd.read_excel('/data/zhaochsh01/buquan/12345/count_result/12345_prompt.xlsx')

# count the number of rows per value of the 'answer' column
category_counts = df['answer'].value_counts().reset_index()
category_counts.columns = ['Category', 'Count']

# write the counts to an xlsx file
output_file = '/data/zhaochsh01/buquan/12345/count_result/category_counts.xlsx'
category_counts.to_excel(output_file, index=False)

print(output_file)
data_generate/zcs/fenbu/tool/format.py (new file, 52 lines added)
@@ -0,0 +1,52 @@
import json


def process_jsonl_file(input_file, output_file):
    """
    Process a jsonl file containing two kinds of records:
    1. records with a "主要内容" field: keep only the selected fields and rename them
    2. records in any other format: keep them unchanged

    :param input_file: input file path
    :param output_file: output file path
    """
    processed_count = 0
    skipped_count = 0

    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:

        for line in infile:
            try:
                data = json.loads(line.strip())

                # check whether the "主要内容" field is present
                if '主要内容' in data:
                    # build a new record that keeps only the selected fields
                    new_data = {
                        '工单编号': data.get('工单编号'),
                        'data': data.get('主要内容'),  # renamed field
                        'ins_tag_label': data.get('ins_tag_label'),
                        'cluster_center': data.get('cluster_center')
                    }

                    # write the transformed record
                    outfile.write(json.dumps(new_data, ensure_ascii=False) + '\n')
                    processed_count += 1
                else:
                    # records without "主要内容" are written unchanged
                    outfile.write(line)
                    skipped_count += 1

            except json.JSONDecodeError:
                print(f"警告:跳过无法解析的行: {line.strip()}")
                skipped_count += 1
                continue

    print(f"处理完成!共处理 {processed_count} 条有'主要内容'的数据")
    print(f"保留原样 {skipped_count} 条其他格式数据")
    print(f"结果已保存到 {output_file}")


# usage example
input_file = "./merged_result_for_more_test.jsonl"  # replace with your input file path
output_file = "./12345_merged_resul.jsonl"  # output file path
process_jsonl_file(input_file, output_file)
data_generate/zcs/fenbu/tool/merge.py (new file, 40 lines added)
@@ -0,0 +1,40 @@
import json

# 1. read the K-Means clustering result file (JSONL) and build a {data_idx: cluster_center} dict
kmeans_data = {}
with open('/data/zhaochsh01/buquan/12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            entry = json.loads(line)
            data_idx = entry.get("data_idx")  # get data_idx
            cluster_center = entry.get("cluster_center")  # get cluster_center (an integer)
            if data_idx and cluster_center is not None:
                kmeans_data[data_idx] = cluster_center
        except Exception as e:
            print(f"Error parsing K-Means line: {line}, error: {e}")

# 2. read the instag JSONL file, match on 工单编号 and merge in cluster_center
output_lines = []
with open('/data/zhaochsh01/buquan/12345/instag/123451wfilter_instag.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
            key_id = str(data.get("工单编号"))  # cast to string so the keys match
            if key_id in kmeans_data:
                data["cluster_center"] = kmeans_data[key_id]  # add cluster_center
                output_lines.append(json.dumps(data, ensure_ascii=False))  # re-serialize to a JSON string
        except Exception as e:
            print(f"Error processing JSONL line: {line}, error: {e}")

# 3. write the result to a new file
with open('merged_result.jsonl', 'w', encoding='utf-8') as f:
    for line in output_lines:
        f.write(line + '\n')

print("数据处理完成,结果已保存到 merged_result.jsonl")
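Note (not part of the commit): illustrative only, with made-up values. Given a K-Means line {"data_idx": "A0001", "cluster_center": 7, "embedding": [...]} and an instag line {"工单编号": "A0001", "主要内容": "...", "ins_tag_label": "..."}, the merged output line becomes {"工单编号": "A0001", "主要内容": "...", "ins_tag_label": "...", "cluster_center": 7}; instag lines whose 工单编号 has no match in the K-Means file are not written to the output.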
data_generate/zcs/fenbu/tool/prompt_change.py (new file, 67 lines added)
@@ -0,0 +1,67 @@
import json
from collections import defaultdict


def select_and_transform_data(input_file, output_file, samples_per_category=4, total_samples=250):
    """
    Read a jsonl file, group the records by answer, select a fixed number per category,
    transform the fields and save the result.

    :param input_file: input jsonl file path
    :param output_file: output jsonl file path
    :param samples_per_category: number of samples to take from each category
    :param total_samples: total number of samples to select
    """
    # records grouped by answer
    categorized_data = defaultdict(list)

    # 1. read and categorize the data
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line.strip())
                answer = data.get('answer')
                if answer:
                    categorized_data[answer].append(data)
            except json.JSONDecodeError:
                print(f"警告:跳过无法解析的行: {line}")
                continue

    # 2. take the requested number of samples from each category
    selected_data = []
    selected_count = 0

    for category, items in categorized_data.items():
        # how many samples this category can still contribute (respect the overall limit)
        remaining = total_samples - selected_count
        take = min(samples_per_category, len(items), remaining)

        if take <= 0:
            break  # overall target reached

        # selection (simply the first `take` items; switch to random.sample for random selection)
        selected = items[:take]

        # 3. transform the fields and add to the result
        for item in selected:
            transformed = {
                "工单编号": item["uid"],
                "data": item["data"],
                "answer": item["answer"]
            }
            selected_data.append(transformed)
            selected_count += 1

        if selected_count >= total_samples:
            break

    # 4. save the result to a jsonl file
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in selected_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(f"处理完成!共选取{selected_count}条数据,保存到{output_file}")
    print(f"分类统计:{ {k: len([d for d in selected_data if d['answer'] == k]) for k in categorized_data} }")


# usage example
input_file = "/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl"  # replace with your input file path
output_file = "output_prompt.jsonl"  # output file path
select_and_transform_data(input_file, output_file, samples_per_category=4, total_samples=250)
data_generate/zcs/fenbu/tool/prompt_choose250.py (new file, 85 lines added)
@@ -0,0 +1,85 @@
import json
from collections import defaultdict


def select_and_transform_data(input_file, output_file, small_category_max=2, normal_samples=4, total_samples=250):
    """
    Read a jsonl file and group the records by answer. Small categories (<= small_category_max records)
    are taken in full; every other category contributes a fixed number of records. The fields are
    transformed and the result is saved.

    :param input_file: input jsonl file path
    :param output_file: output jsonl file path
    :param small_category_max: size threshold at or below which a category counts as small
    :param normal_samples: number of samples to take from each normal category
    :param total_samples: total number of samples to select
    """
    # records grouped by answer
    categorized_data = defaultdict(list)

    # 1. read and categorize the data
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line.strip())
                answer = data.get('answer')
                if answer:
                    categorized_data[answer].append(data)
            except json.JSONDecodeError:
                print(f"警告:跳过无法解析的行: {line}")
                continue

    # split into small and normal categories
    small_categories = {k: v for k, v in categorized_data.items() if len(v) <= small_category_max}
    normal_categories = {k: v for k, v in categorized_data.items() if len(v) > small_category_max}

    # 2. first take all records from the small categories
    selected_data = []
    selected_count = 0

    for category, items in small_categories.items():
        # take every record of the small category
        for item in items:
            transformed = {
                "工单编号": item["uid"],
                "data": item["data"],
                "answer": item["answer"]
            }
            selected_data.append(transformed)
            selected_count += 1

    # 3. take records from the normal categories until the total is reached
    for category, items in normal_categories.items():
        if selected_count >= total_samples:
            break

        # how many more records can still be taken
        remaining = total_samples - selected_count
        if remaining <= 0:
            break

        # take min(normal_samples, category size, remaining)
        take = min(normal_samples, len(items), remaining)

        # selection (simply the first `take` items; switch to random.sample for random selection)
        for item in items[:take]:
            transformed = {
                "工单编号": item["uid"],
                "data": item["data"],
                "answer": item["answer"]
            }
            selected_data.append(transformed)
            selected_count += 1

    # 4. save the result to a jsonl file
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in selected_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(f"处理完成!共选取{selected_count}条数据,保存到{output_file}")
    print("分类统计:")
    stat = {k: len([d for d in selected_data if d['answer'] == k]) for k in categorized_data}
    for k, v in stat.items():
        print(f"{k}: {v}条 (共{len(categorized_data[k])}条)")  # fixed: report the category size, not the raw list


# usage example
input_file = "/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl"  # replace with your input file path
output_file = "output_prompt.jsonl"  # output file path
select_and_transform_data(input_file, output_file, small_category_max=2, normal_samples=4, total_samples=250)
data_generate/zcs/fenbu/tool/prompt_label2.py (new file, 144 lines added)
@@ -0,0 +1,144 @@
import json
import requests
import pandas as pd
import re
import logging
logging.basicConfig(level=logging.INFO)
from concurrent.futures import ThreadPoolExecutor, as_completed


def read_jsonl_lines_in_batches(file_path, batch_size=10000):
    """Read a JSONL file in batches."""
    batch = []
    with open(file_path, mode="r", encoding="utf-8") as f:
        for line in f:
            try:
                batch.append(json.loads(line.strip()))
                if len(batch) == batch_size:
                    yield batch
                    batch = []
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON: {e}")
        if batch:
            yield batch


def process_data_concurrently(data_list, api_url, headers, max_workers=10):
    """Process the records concurrently by calling the API."""
    result_data = []
    list_category = [
        "农业生产", "农村工作", "农村生活", "农民生活", "巩固脱贫攻坚成果", "精准扶贫", "人才交流",
        "企业工资福利", "企业离休退", "企业离休退休", "劳动保护", "劳动关系", "劳动纠纷",
        "就业创业", "招录辞退", "社会保险政策", "职务职称", "行政效能", "信息查询", "历史遗留",
        "表扬感谢", "人防工作", "住房与房地产", "园林绿化", "国有土地上房屋征收补偿", "城乡规划",
        "工程质量", "建筑市场管理", "道路修建", "交通运输", "城市管理", "城管执法", "居民生活",
        "环境保护", "旅游信息", "人口计生", "体育工作", "医疗保障", "卫生", "执法监督", "救灾救济",
        "教育", "文化工作", "文明建设", "残疾人服务管理", "民政", "民族宗教", "法律服务", "社会治安",
        "退役军人", "信息通信产业", "商业贸易", "国资监管", "安全生产", "市场秩序", "招商引资",
        "旅游管理", "综合审批", "财税金融", "质量监管", "自然资产:海洋渔业", "国土资源", "林业",
        "水利", "自然资源:海洋渔业"
    ]

    def process_single_data(data):
        try:
            query = data.get('主要内容')
            if query:
                input_content = f'''
                您是一位文本分类专家,请依据专业的眼光判断下当前 query 属于以下65个领域类别中的哪一个并使用简体中文给出分析过程。
                领域类别:{list_category}
                需要判断的内容:{query}
                **注意:严格仅输出领域类别中的其中一个类别。

                "输出格式如下:结果:<插入返回的结果>\n分析过程:"
                '''
                # initialize here so the return below also works when the API call fails
                answer = ""
                final_result = ""
                response = requests.post(
                    api_url,
                    headers=headers,
                    json={
                        "model": "Qwen2.5-72B-Instruct",
                        "stream": False,
                        "temperature": 0.01,
                        "messages": [{"role": "user", "content": input_content}]
                    }
                )
                if response.status_code == 200:
                    try:
                        content = response.json()["choices"][0]["message"]["content"]
                        print(f"返回结果为:{content}")
                        # extract the category that follows "结果:" (the "返回结果为:" prefix is optional)
                        match = re.search(r'(?:返回结果为[::])?\s*结果[::]\s*(\S+)', content)
                        if match:
                            final_result = match.group(1).strip("。")  # strip a possible trailing full stop
                            print(f"截取的结果为:{final_result}")  # e.g. 交通运输
                        else:
                            print("未找到有效分类")
                        # answer = json.loads(content)
                        # answer = answer.get("返回结果")
                        # print(f"解析的结果为:{answer}")
                    except (KeyError, IndexError, json.JSONDecodeError):
                        content = "无法解析返回内容"
                else:
                    content = f"API请求失败,状态码:{response.status_code}"
                return {
                    "uid": data.get('工单编号'),
                    "data": query,
                    "answer": final_result
                }
        except Exception as e:
            print(e)
        return None

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_single_data, data) for data in data_list]
        for future in as_completed(futures):
            result = future.result()
            if result:
                result_data.append(result)
                logging.info(f"已完成 {result.get('uid')} ({len(result_data)}/{len(data_list)})")

    return result_data


def save_to_excel_in_batches(data_list, output_file, batch_size=10000):
    """Save the records to an Excel file in batches."""
    df = pd.DataFrame(data_list)
    writer = pd.ExcelWriter(output_file, engine='openpyxl')
    for i in range(0, len(df), batch_size):
        batch_df = df.iloc[i:i + batch_size]
        batch_df.to_excel(writer, index=False, startrow=i)
    writer.close()
    print(f"数据已成功保存到 {output_file}")


def save_to_jsonl_in_batches(data_list, output_file, batch_size=10000):
    """Save the records to a JSONL file in batches."""
    with open(output_file, 'w', encoding='utf-8') as f:
        for i in range(0, len(data_list), batch_size):
            # take the current batch
            batch_data = data_list[i:i + batch_size]
            # write each record as one JSON object per line
            for item in batch_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
    print(f"数据已成功保存到 {output_file}")


if __name__ == "__main__":
    # output_excel_file = 'result-taoli-5.xlsx'
    # api_url = "http://100.105.149.39:8000/v1/chat/completions"
    api_url = "http://100.105.246.130:8000/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c"
    }

    # file_path = '/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_BELLE_Multiturn_Chat_filtered_2504232014.jsonl'
    file_path = '/data/zhaochsh01/buquan/12345/tool/output.jsonl'
    output_file = '/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl'
    output_excel_file = '/data/zhaochsh01/buquan/12345/count_result/12345_prompt.xlsx'

    # file_path = '/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_COIG_filtered_2504212014.jsonl'
    all_results = []
    for batch in read_jsonl_lines_in_batches(file_path, batch_size=10000):
        processed_batch = process_data_concurrently(batch, api_url, headers, max_workers=20)
        all_results.extend(processed_batch)
    save_to_excel_in_batches(all_results, output_excel_file, batch_size=10000)
    save_to_jsonl_in_batches(all_results, output_file, batch_size=10000)
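Note (not part of the commit): a quick standalone check of the extraction regex used above, run against a made-up model reply that follows the prompt's required output format.

import re

sample = "结果:交通运输\n分析过程:该工单主要涉及道路与公交问题。"
match = re.search(r'(?:返回结果为[::])?\s*结果[::]\s*(\S+)', sample)
print(match.group(1).strip("。"))  # -> 交通运输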
data_generate/zcs/fenbu/tool/xslx2jsonl.py (new file, 67 lines added)
@@ -0,0 +1,67 @@
import pandas as pd
import json
from datetime import datetime
import numpy as np


def convert_value(obj):
    """Convert values that are not JSON-serializable."""
    # null-like values
    if pd.isna(obj) or obj is None:
        return None

    # datetime-like values
    if isinstance(obj, (pd.Timestamp, datetime)):
        return obj.strftime('%Y-%m-%d %H:%M:%S')

    # NaT values
    if isinstance(obj, pd._libs.tslibs.nattype.NaTType):
        return None

    # numpy numeric types
    if isinstance(obj, (np.integer, np.floating)):
        return int(obj) if isinstance(obj, np.integer) else float(obj)

    # numpy arrays and pandas Series
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, pd.Series):
        return obj.to_dict()

    # everything else is returned unchanged
    return obj


def xlsx_to_jsonl(input_file, output_file):
    """
    Convert an XLSX file to JSONL format.

    Parameters:
    input_file (str): path of the input XLSX file
    output_file (str): path of the output JSONL file
    """
    try:
        # read the Excel file
        df = pd.read_excel(input_file)

        # write the rows to the JSONL file
        with open(output_file, 'w', encoding='utf-8') as f:
            for _, row in df.iterrows():
                # convert the row to a dict and normalize every value
                record = {k: convert_value(v) for k, v in row.items()}

                # write one JSON line
                json.dump(record, f, ensure_ascii=False)
                f.write('\n')

        print(f"转换成功,结果已保存到 {output_file}")

    except Exception as e:
        print(f"转换过程中发生错误: {str(e)}")


# usage example
if __name__ == "__main__":
    input_xlsx = "/data/zhaochsh01/buquan/12345/1w_fillter.xlsx"  # replace with your input file path
    output_jsonl = "output.jsonl"  # replace with your desired output file path

    xlsx_to_jsonl(input_xlsx, output_jsonl)
data_generate/zcs/fenbu/vllm_infer_qw_instag_format_json.py (new file, 109 lines added)
@@ -0,0 +1,109 @@
import json
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from tqdm import tqdm, trange
from copy import deepcopy

# need vllm==0.4.2
# used with firefly_copy

# step1: read the jsonl file to be tagged
input_jsonl_file = "/data/zhaochsh01/buquan/12345/tool/output.jsonl"
with open(input_jsonl_file, 'r', encoding='utf8') as f:
    raw_data_list = f.readlines()
print("len_raw_data_list:", len(raw_data_list))

# for testing, run only the first 100 rows
# raw_data_list = raw_data_list[:100]

# step2: load the instag model path and configure the tokenizer
model_path = '/model-pvc/instagger_qwen1_8B/'
tokenizer = get_tokenizer(model_path, trust_remote_code=True)
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{ '### ' + message['role'].capitalize() + ':\\n' }}"
    "{{ message['content'] | string + '\\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '### Assistant:\\n' }}"
    "{% endif %}"
)

# step3: memory is sufficient, so read everything at once and build the prompts in one pass
prompt_text_list = []
for index, line in enumerate(raw_data_list):
    # print("process:", index)
    json_obj = json.loads(line)
    query = json_obj["主要内容"]
    prompt = query

    messages = [
        {"role": "system", "content": "You are an intelligent assistant, please identify the category based on the user’s query and return the corresponding category."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    prompt_text_list.append(text)

print("len_prompt_text_list:", len(prompt_text_list))
assert len(raw_data_list) == len(prompt_text_list)


def extract_raw_qwen_res(response):
    if "<|im_end|>" in response:
        response = response.split("<|im_end|>")[0]
    return response


# step4: vllm sampling parameters; temperature=0 is the officially recommended setting here
sampling_params = SamplingParams(max_tokens=64,
                                 temperature=0,
                                 top_p=0.8,
                                 repetition_penalty=1.05)

# step5: load the model last, so data problems surface before the slow model load
# set gpu_memory_utilization=0.95 for OOM error
vllm_model = LLM(model=model_path,
                 tensor_parallel_size=1,
                 gpu_memory_utilization=0.85,
                 max_num_seqs=64,  # reduce the maximum number of concurrent sequences (default 256)
                 max_model_len=2048,  # reduce the maximum context length
                 enforce_eager=True,
                 trust_remote_code=True)
vllm_model.set_tokenizer(tokenizer)

# step6: run inference in batches and collect the results
tag_result_list = []
batch_size = 1024
for start in trange(0, len(prompt_text_list), batch_size):
    print("index:", start)
    batch_list = prompt_text_list[start: start + batch_size]
    vllm_outputs = vllm_model.generate(batch_list, sampling_params)
    for output in vllm_outputs:
        prompt = output.prompt
        cur_tag_result = extract_raw_qwen_res(output.outputs[0].text)
        tag_result_list.append(cur_tag_result)

print("len_tag_result_list:", len(tag_result_list))

assert len(tag_result_list) == len(prompt_text_list)

# step7: write the results
print("开始写 instag 结果 数据")
single_turn_instag_path = "./12345/instag/123451wfilter_instag.jsonl"
print("single_turn_instag_path:", single_turn_instag_path)
single_f_sample_save = open(single_turn_instag_path, "w")
write_single_sample_count = 0
for line, tag in zip(raw_data_list, tag_result_list):
    json_line = json.loads(line)
    json_line["ins_tag_label"] = tag
    cur_final_line = json.dumps(deepcopy(json_line), ensure_ascii=False)
    write_single_sample_count = write_single_sample_count + 1
    single_f_sample_save.write(cur_final_line + "\n")
    if write_single_sample_count % 100000 == 0:
        print("write_single_count:", write_single_sample_count)

single_f_sample_save.close()  # make sure the buffered lines are flushed to disk
print("all finish success !")
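Note (not part of the commit): for reference, a sketch of what one rendered prompt looks like with the chat template defined above (the query text is made up):

# tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) returns:
#   ### System:
#   You are an intelligent assistant, please identify the category based on the user’s query and return the corresponding category.
#   ### User:
#   <主要内容 of one record>
#   ### Assistant: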
智数员工/zhaopin_zaoshu.py (new file, 205 lines added)
@@ -0,0 +1,205 @@
import json
import random
import asyncio
from typing import List, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
from openpyxl import Workbook
import requests
import logging
logging.basicConfig(level=logging.INFO)


# configuration
class Config:
    OUTPUT_FILE = "recruitment_data.xlsx"
    FIXED_QUESTIONS = [
        "上个月面试了多少人",
        "本周安排了几个面试",
        "招聘进度如何",
        "有多少候选人进入二面",
        "销售岗位的招聘情况",
        "技术岗位的简历筛选数量",
        "最近一周的offer发放数量",
        "哪个部门的招聘完成率最高",
        "招聘成本是否超出预算",
        "候选人平均面试周期是多长"
    ]
    LOCATIONS = ["北京", "上海", "广州", "深圳", "杭州", "", "成都"]
    INTENTS = ["招聘数据", "招聘进度", "其他", "成本分析", "效率统计"]
    COMMISSIONER_TYPES = ["yxz", "hrbp", "recruiter", "manager"]
    USER_NAMES = ["张招聘", "李HR", "王人事", "赵经理", "刘专员"]


async def chat(input_content):
    result = ""  # fallback so the return below works even if the request fails
    response = requests.post(
        "http://100.105.1.227:8000/v1/chat/completions",
        headers={
            "Content-Type": "application/json",
            "Authorization": "7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c"
        },
        json={
            "model": "Qwen3-72B",
            "stream": False,
            "temperature": 0.6,
            "TopP": 0.95,
            "TopK": 20,
            "MinP": 0,
            "messages": [{"role": "user", "content": input_content}]
        },
        timeout=180
    )

    if response.status_code == 200:
        try:
            result = response.json()["choices"][0]["message"]["content"]
        except Exception as e:
            logging.error(f"Error processing API response: {e}")
    else:
        logging.error(f"API request failed with status code: {response.status_code}")
    await asyncio.sleep(0.1)
    return result


# ask the model to generate a diversified question
async def generate_diverse_questions() -> List[str]:
    # this is where the model is actually called to generate question variants

    input_content = """你是一个资深HR分析师。请生成一个招聘数据分析的查询请求,要求:
- 聚焦在以下至少一个方面:面试、offer、入职、渠道效果、成本、周期时间
- 包含具体的时间范围(如最近一周/上月/本季度)
- 可选项包含部门/岗位/地域等维度
- 直接返回问题,不要任何解释

例如:
对比北京和上海地区过去两个月销售岗位的offer接受率"""
    gen_question = await chat(input_content)  # chat is a coroutine, so it must be awaited
    await asyncio.sleep(0.1)

    return gen_question


# generate a recruitment-related input record
async def generate_input_data(use_fixed: bool = True) -> Dict[str, Any]:
    if random.random() > 0.3:
        base_question = random.choice(Config.FIXED_QUESTIONS)
    else:
        base_question = await generate_diverse_questions()

    return {
        "messages": [{
            "role": "user",
            "content": base_question
        }],
        "location": random.choice(Config.LOCATIONS),
        "uuid": str(random.randint(10**18, 10**19 - 1)),  # random.randint requires integer bounds
        "intent": random.choice(Config.INTENTS),
        "loginUserName": random.choice(Config.USER_NAMES),
        "loginUserId": "hr_" + str(random.randint(1000, 9999)),
        "commissioner_type": random.choice(Config.COMMISSIONER_TYPES)
    }


# process a single request
async def process_request(input_data: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    try:
        input_content = f"""
你是一个专业招聘数据分析助手。请按以下规则处理问题:
1. 如果问题已包含明确且可清晰回答,直接返回原问题
2. 如果问题模糊或不完整,按标准改写:
 - 补充时间范围(最近/上月/本季度等)
 - 明确量化指标(数量/比率/趋势等)
 - 指定具体对象(岗位/部门/渠道等)
3. 直接返回最终问题,不要任何解释

待处理问题:{input_data}
"""
        user_content = input_data["messages"][0]["content"]
        rewritten_question = await chat(input_content)

        output_data = {
            "code": "0",
            "message": "",
            "result": rewritten_question
        }
        return input_data, output_data
    except Exception as e:
        output_data = {
            "code": "1",
            "message": str(e),
            "result": ""
        }
        return input_data, output_data


# save the data to Excel
def save_to_excel(data: List[Dict[str, Any]], filename: str):
    rows = []
    for item in data:
        input_data = item["input"]
        output_data = item["output"]

        row = {
            "输入问题": input_data["messages"][0]["content"],
            "输出问题": output_data["result"],
            "地点": input_data["location"],
            "UUID": input_data["uuid"],
            "意图": input_data["intent"],
            "用户名": input_data["loginUserName"],
            "用户ID": input_data["loginUserId"],
            "专员类型": input_data["commissioner_type"],
            "状态码": output_data["code"],
            "消息": output_data["message"]
        }
        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_excel(filename, index=False, engine='openpyxl')
    print(f"数据已保存到 {filename}")


# generate the data concurrently
async def generate_data(num_samples: int) -> List[Dict[str, Any]]:
    # first generate all input records
    input_tasks = [generate_input_data() for _ in range(num_samples)]
    input_data_list = await asyncio.gather(*input_tasks)

    # then process all requests concurrently
    process_tasks = [process_request(input_data) for input_data in input_data_list]
    results = await asyncio.gather(*process_tasks)

    # combine the results
    output = []
    for input_data, output_data in results:
        output.append({
            "input": input_data,
            "output": output_data
        })

    return output


# main entry point
async def main():
    try:
        num_samples = 2000
        print(f"开始生成 {num_samples} 条招聘数据...")
        data_pairs = await generate_data(num_samples)

        save_to_excel(data_pairs, Config.OUTPUT_FILE)

        # print the first 3 samples
        print("\n样本示例:")
        for i, pair in enumerate(data_pairs[:3], 1):
            print(f"样本 {i}:")
            print("输入问题:", pair["input"]["messages"][0]["content"])
            print("输出问题:", pair["output"]["result"])
            print("-" * 50)

    except Exception as e:
        print(f"发生错误: {e}")


if __name__ == "__main__":
    asyncio.run(main())