Compare commits
2 Commits
Author | SHA1 | Date
---|---|---
 | 0a6da985dd |
 | 0655f98c70 |
100
data_generate/zcs/fenbu/cluster_kmeans_train_from_torch_bin.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from sklearnex import patch_sklearn, unpatch_sklearn
|
||||
from matplotlib import pyplot as plt
|
||||
import time
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def train_kmeans_model(train_data, n_clusters, data_index_list):
|
||||
# Note: patch_sklearn() has to run before the sklearn modules it accelerates are imported.
patch_sklearn()
# unpatch_sklearn()
from sklearn.cluster import KMeans
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# The training data passed in here is expected to be an np.ndarray.
data_set = train_data
|
||||
print("train_data_len:" + str(len(data_set)))
|
||||
print("n_clusters:" + str(n_clusters))
|
||||
|
||||
# init defaults to k-means++; it is spelled out here for clarity.
model = KMeans(n_clusters=n_clusters, init='k-means++', max_iter=300, n_init="auto", random_state=0)
|
||||
|
||||
model.fit(data_set)
|
||||
# The fitted KMeans model already contains the cluster centers.
center = model.cluster_centers_
|
||||
print(center)
|
||||
|
||||
cluster_label = model.predict(data_set)
|
||||
print("cluster_label_size:" + str(len(cluster_label)))
|
||||
print("type:" + str(type(cluster_label)))
|
||||
|
||||
train_dfs = pd.DataFrame(data_set)
|
||||
|
||||
train_dfs["predict_label"] = cluster_label
|
||||
train_dfs["data_index"] = data_index_list
|
||||
|
||||
print(train_dfs.columns.values)
|
||||
|
||||
end_time = time.time()
|
||||
avg_time_cost = (end_time - start_time) * 1.0
|
||||
print("train_kmeans_time:" + str(avg_time_cost) + " s")
|
||||
return train_dfs
|
||||
|
||||
|
||||
# step1: load the saved embedding binary file
embed_file = "/data/zhaochsh01/buquan/12345/embedding/123451wfilter_embedding.pth"
|
||||
embed_dict_list = torch.load(embed_file)
|
||||
print("len_embed_dict_list:", len(embed_dict_list))
|
||||
|
||||
# step2: parse the loaded records into parallel lists
raw_embeding_list = []
raw_index_list = []
for cur_dict in embed_dict_list:
# each saved element is already a dict holding the embedding and its original data index
cur_embeding = cur_dict["embedding"]
cur_data_idx = cur_dict["data_index"]
raw_embeding_list.append(cur_embeding)
raw_index_list.append(cur_data_idx)
|
||||
|
||||
train_array = np.array(raw_embeding_list)
|
||||
print("train_array_shape:", train_array.shape)
|
||||
|
||||
# Target is roughly 1000 clusters overall: first 50 coarse clusters, then about 20 sub-clusters within each.
# Alternatively, go straight to 1000 coarse clusters and refine later if the quality is poor.
num_cluster = 50
# pandas aligns the new columns on the row index automatically here
train_dfs = train_kmeans_model(train_array, num_cluster, raw_index_list)
|
||||
predict_label = train_dfs['predict_label'].tolist()
|
||||
|
||||
print("len_predict_label:", len(predict_label))
|
||||
|
||||
|
||||
# Do not save to csv here: pandas csv export can fail in unexpected ways when the text contains special characters.
#data_to_save = {'embeding': raw_embeding_list, 'cluster_center': predict_label, "data_idx": raw_index_list}
|
||||
data_to_save = [
|
||||
{
|
||||
"embedding": raw_embeding_list[i],
|
||||
"cluster_center": predict_label[i],
|
||||
"data_idx": raw_index_list[i]
|
||||
}
|
||||
for i in range(len(raw_embeding_list))
|
||||
]
|
||||
|
||||
output_file = "./12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl"
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for record in data_to_save:
|
||||
f.write(json.dumps(record, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
data_to_save_df = pd.DataFrame(data_to_save)
|
||||
data_to_save_df.to_pickle("./12345/kmeans/123451wfilter_cluster_kmeans_result.pkl")
|
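A minimal usage sketch (not part of this commit): counting how many records landed in each cluster from the JSONL written above. It assumes only the record layout produced by this script ("embedding", "cluster_center", "data_idx") and the output path shown above.

import json
from collections import Counter

cluster_sizes = Counter()
with open("./12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        cluster_sizes[record["cluster_center"]] += 1

for cluster_id, size in cluster_sizes.most_common(10):
    print(f"cluster {cluster_id}: {size} items")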
123
data_generate/zcs/fenbu/select_final_data.py
Normal file
@@ -0,0 +1,123 @@
|
||||
|
||||
import json
|
||||
import pandas as pd
|
||||
|
||||
|
||||
|
||||
def process_jsonl_file(file_path, top_n_per_group, result, seen_uids):
|
||||
"""
|
||||
读取 jsonl 文件,对每一行 group 数据,尽力从中选出 top_n_per_group 条不重复数据
|
||||
如果无法选够指定数量的数据,则跳过该 group 并打印警告信息。
|
||||
"""
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line_idx, line in enumerate(f):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
|
||||
# 获取 group key 和对应的数据列表
|
||||
key = next(k for k in entry if k != "count")
|
||||
records = entry[key]
|
||||
|
||||
selected = []
|
||||
added_count = 0
|
||||
|
||||
# 遍历 records,直到选够 top_n_per_group 个不重复的数据
|
||||
for item in records:
|
||||
uid = item.get('工单编号')
|
||||
|
||||
if added_count >= top_n_per_group:
|
||||
break # 已经选够了,退出
|
||||
|
||||
if uid and uid not in seen_uids:
|
||||
seen_uids.add(uid)
|
||||
selected.append(item)
|
||||
result.append(item)
|
||||
added_count += 1
|
||||
|
||||
# If the group cannot supply top_n_per_group unique items, warn and roll the whole group back.
if added_count < top_n_per_group:
print(f"[Group {key}] Only found {added_count}/{top_n_per_group} unique items. Skipping this group.")
# Remove the records that were already added for this group.
for item in selected:
result.remove(item)
seen_uids.discard(item.get('工单编号'))
|
||||
else:
|
||||
print(f"[Group {key}] Successfully added {added_count} items.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing line {line_idx + 1}: {e}")
|
||||
continue
|
||||
|
||||
def load_data_to_memory(input_file):
|
||||
"""
|
||||
读取jsonl文件,将数据存入result列表,工单编号存入seen_uids集合
|
||||
|
||||
:param input_file: 输入jsonl文件路径
|
||||
:return: tuple(result列表, seen_uids集合)
|
||||
"""
|
||||
result = []
|
||||
seen_uids = set()
|
||||
|
||||
try:
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
# 解析JSON行数据
|
||||
data = json.loads(line.strip())
|
||||
|
||||
# 检查是否包含工单编号字段
|
||||
if '工单编号' in data:
|
||||
uid = data['工单编号']
|
||||
|
||||
# 添加到结果列表和已见集合
|
||||
result.append(data)
|
||||
seen_uids.add(uid)
|
||||
else:
|
||||
print(f"警告:跳过缺少工单编号的记录: {data}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"警告:跳过无法解析的行: {line}")
|
||||
continue
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"错误:文件 {input_file} 不存在")
|
||||
return [], set()
|
||||
except Exception as e:
|
||||
print(f"读取文件时发生错误: {str(e)}")
|
||||
return [], set()
|
||||
|
||||
print(f"成功加载 {len(result)} 条数据,发现 {len(seen_uids)} 个唯一工单编号")
|
||||
return result, seen_uids
|
||||
|
||||
|
||||
def main():
|
||||
result = []
|
||||
seen_uids = set()
|
||||
result, seen_uids = load_data_to_memory("/data/zhaochsh01/buquan/12345/tool/output_prompt.jsonl")
|
||||
|
||||
|
||||
# Step 3: process 12345_count_instag.jsonl, taking up to 3 items per group
print("Step 3: Processing 12345_count_instag.jsonl...")
process_jsonl_file("/data/zhaochsh01/buquan/12345/count_result/12345_count_instag.jsonl", 3, result, seen_uids)

# Step 4: process 12345_count_cluster.jsonl, taking up to 3 items per group
print("Step 4: Processing 12345_count_cluster.jsonl...")
process_jsonl_file("/data/zhaochsh01/buquan/12345/count_result/12345_count_cluster.jsonl", 3, result, seen_uids)
|
||||
|
||||
# 最终写入输出文件
|
||||
output_file = "/data/zhaochsh01/buquan/12345/merge_result/merged_result_for_more_test.jsonl"
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for item in result:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f"Total merged items: {len(result)}")
|
||||
print(f"Saved to: {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
@@ -0,0 +1,233 @@
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
import random
|
||||
import argparse
|
||||
import json
|
||||
import pandas
|
||||
from typing import List
|
||||
|
||||
|
||||
def set_seed(seed=128):
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
class MyTorchDataset(torch.utils.data.Dataset):
|
||||
# def __init__(self, input_file, tokenizer, max_seq_length):
|
||||
# self.tokenizer = tokenizer
|
||||
# self.max_seq_length = max_seq_length
|
||||
# logger.info('Loading data: {}'.format(input_file))
|
||||
# with open(input_file, 'r', encoding='utf8') as f:
|
||||
# data_list = f.readlines()
|
||||
# logger.info("There are {} data in dataset".format(len(data_list)))
|
||||
# # 测试,这里仅仅取 top 100
|
||||
# self.raw_data_list = data_list[:]
|
||||
#
|
||||
# # 这个数据在初始化的时候全部 tokenizer 好,如果在训练的时候每次选择一条,会拖慢整体获取 embeding 的速度
|
||||
# self.query_list = []
|
||||
# self.index_list = []
|
||||
# for line in self.raw_data_list:
|
||||
# json_line = json.loads(line)
|
||||
# print(json_line["data"][0])
|
||||
# # 这个仅仅得到 query 的 embeding, content[0] 是 query
|
||||
# query = json_line["data"][0]['content'].strip()
|
||||
# # 这个是原始数据里,单轮对话对应的 data_index
|
||||
# data_index = json_line["uid"]
|
||||
# self.query_list.append(query)
|
||||
# self.index_list.append(data_index)
|
||||
# assert len(self.query_list) == len(self.index_list)
|
||||
# logger.info(f"final len_query_list:{len(self.query_list)}")
|
||||
# logger.info(f"final len_index_list:{len(self.index_list)}")
|
||||
#
|
||||
# # 这里批量一次性整好所有的数据 tokenizer,这里 250w 数据大概需要2个小时,需要200g左右内存,100g 内存程序会崩溃
|
||||
# logger.info(f" 开始批量 tokenizer 所有数据 ... ")
|
||||
# self.all_data_token = self.tokenizer(self.query_list,
|
||||
# padding='max_length',
|
||||
# truncation=True,
|
||||
# max_length=self.max_seq_length,
|
||||
# return_tensors='pt')
|
||||
# logger.info(f" 批量 tokenizer 所有数据 完成 !!!")
|
||||
def __init__(self, input_file, tokenizer, max_seq_length):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_length = max_seq_length
|
||||
logger.info('Loading data: {}'.format(input_file))
|
||||
with open(input_file, 'r', encoding='utf8') as f:
|
||||
data_list = f.readlines()
|
||||
logger.info("There are {} data in dataset".format(len(data_list)))
|
||||
|
||||
# For debugging you can slice the list here (e.g. data_list[:100]); currently the full dataset is used.
self.raw_data_list = data_list  # [:100]
|
||||
|
||||
# 提取对话内容
|
||||
self.query_list = []
|
||||
self.index_list = []
|
||||
for line in self.raw_data_list:
|
||||
json_line = json.loads(line)
|
||||
# 这个仅仅得到 query 的 embeding, content[0] 是 query
|
||||
query = json_line["主要内容"].strip()
|
||||
# 这个是原始数据里,单轮对话对应的 data_index
|
||||
data_index = json_line["工单编号"]
|
||||
data_index = str(data_index)
|
||||
self.query_list.append(query)
|
||||
self.index_list.append(data_index)
|
||||
|
||||
assert len(self.query_list) == len(self.index_list)
|
||||
logger.info(f"final len_query_list:{len(self.query_list)}")
|
||||
logger.info(f"final len_index_list:{len(self.index_list)}")
|
||||
|
||||
# 批量 Tokenize 所有数据
|
||||
logger.info(f"开始批量 tokenize 所有数据 ...")
|
||||
self.all_data_token = self.tokenizer(
|
||||
self.query_list,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=self.max_seq_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
logger.info(f"批量 tokenize 所有数据 完成 !!!")
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# item 包含了 input_ids,attention_mask 等参数,已经转化为 tensor
|
||||
item = {key: torch.as_tensor(value[idx]) for key, value in self.all_data_token.items()}
|
||||
item['data_index'] = self.index_list[idx]
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return len(self.query_list)
|
||||
|
||||
|
||||
class TextEmbedder:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.device = torch.device("cuda:0")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path,
model_max_length=args.max_length,
# inference only: use left padding, otherwise the pooled outputs are taken from the wrong positions
padding_side="left",
use_fast=False)
|
||||
self.model = AutoModel.from_pretrained(self.args.model_name_or_path).to(self.device)
|
||||
logger.info('tokenizer,model 加载完成')
|
||||
self.all_dataset = MyTorchDataset(self.args.raw_data_path, self.tokenizer, self.args.max_length)
|
||||
self.data_loader = DataLoader(dataset=self.all_dataset, batch_size=self.args.batch_size,
|
||||
# pin_memory=True,num_workers=1, prefetch_factor=8
|
||||
)
|
||||
|
||||
def save_data(self, results: List) -> None:
|
||||
logger.info(f'需要保存的数据长度: {len(results)}')
|
||||
|
||||
df = pandas.DataFrame(results)
|
||||
df.sort_values(by="data_index", inplace=True)
|
||||
df.reset_index(drop=True, inplace=True)
|
||||
|
||||
df.to_pickle(self.args.output_path)
|
||||
logger.info(f"Saved pickle to {self.args.output_path}")
|
||||
|
||||
'''
|
||||
# 这个地方 .csv 会保存失败
|
||||
csv_file_name = "./rl_demo_emb_240624.csv"
|
||||
df.to_csv(csv_file_name, index=False)
|
||||
logger.info(f"Saved csv to {csv_file_name}")
|
||||
'''
|
||||
|
||||
# 指定保存的文件名
|
||||
torch_py_bin_name = './12345/embedding/123451wfilter_embedding.pth'
|
||||
# 请注意,torch.save()和torch.load()函数不仅可以用于保存和加载张量,还可以用于保存和加载任何Python对象,
|
||||
# 包括但不限于字典、集合、自定义类实例等。
|
||||
torch.save(results, torch_py_bin_name)
|
||||
logger.info(f'List has been saved to {torch_py_bin_name}')
|
||||
|
||||
def encode_samples(self):
|
||||
total_samples = len(self.all_dataset)
|
||||
logger.info(f"total_samples: {total_samples}")
|
||||
total_batches = len(self.data_loader)
|
||||
logger.info(f"total_batches: {total_batches}")
|
||||
|
||||
all_embeddings_list = []
|
||||
|
||||
for b_idx, batch in enumerate(tqdm(self.data_loader, total=total_batches)):
|
||||
self.model.eval()
|
||||
batch_idx = batch["data_index"]
|
||||
input_ids = batch['input_ids'].to(self.device)
|
||||
attention_mask = batch['attention_mask'].to(self.device)
|
||||
output = None
|
||||
with torch.no_grad():
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
cls_embedding = output.pooler_output.detach().cpu().numpy().tolist()
|
||||
|
||||
# 记录 embedding 和 index
|
||||
cur_sample_idx = batch_idx#.tolist() # 如果 batch_idx 是张量
|
||||
cur_sample_dict_list = [
|
||||
{"embedding": cls_emb, "data_index": s_id}
|
||||
for cls_emb, s_id in zip(cls_embedding, cur_sample_idx)
|
||||
]
|
||||
all_embeddings_list.extend(cur_sample_dict_list)
|
||||
|
||||
return all_embeddings_list
|
||||
|
||||
# def encode_samples(self):
|
||||
# total_samples = len(self.all_dataset)
|
||||
# logger.info(f"total_samples:{total_samples}")
|
||||
# total_batches = len(self.data_loader)
|
||||
# logger.info(f"total_batches:{total_batches}")
|
||||
#
|
||||
# all_embeddings_list = []
|
||||
#
|
||||
# for b_idx, batch in enumerate(tqdm(self.data_loader, total=total_samples // self.args.batch_size)):
|
||||
# self.model.eval()
|
||||
#
|
||||
# batch_idx = batch["data_index"]
|
||||
# input_ids = batch['input_ids'].to(self.device)
|
||||
# attention_mask = batch['attention_mask'].to(self.device)
|
||||
# output = None
|
||||
# with torch.no_grad():
|
||||
# output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
#
|
||||
# cls_embedding = output.pooler_output.detach().cpu().numpy().tolist()
|
||||
# print(batch_idx)
|
||||
# # 这里批量记录下 embeding 和 index
|
||||
# cur_sample_idx = batch_idx.tolist()
|
||||
# cur_sample_dict_list = [{"embedding": cls_emb, "data_index": s_id} for cls_emb, s_id in
|
||||
# zip(cls_embedding, cur_sample_idx)]
|
||||
# # list extend 多个 list元素 flatten 合并成一个 list
|
||||
# all_embeddings_list.extend(cur_sample_dict_list)
|
||||
#
|
||||
# return all_embeddings_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.cuda.empty_cache()
|
||||
# 参数设置
|
||||
set_seed()
|
||||
# 若干参数设置
|
||||
parser = argparse.ArgumentParser()
|
||||
# 不要改该参数,系统会自动分配
|
||||
parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
|
||||
# 开启的进程数(注意不是线程),不用设置该参数,会根据nproc_per_node自动设置
|
||||
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
|
||||
args = parser.parse_args()
|
||||
|
||||
args.max_length = 512
|
||||
args.batch_size = 2048
|
||||
args.model_name_or_path = "/model-pvc/suojiayi/bge-base-zh-v1.5/"
|
||||
#args.raw_data_path = "/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_COIG_filtered_2504212014.jsonl"
|
||||
args.raw_data_path = '/data/zhaochsh01/buquan/12345/tool/output.jsonl'
|
||||
args.output_path = "./12345/embedding/123451wffilter_embedding.bin"
|
||||
|
||||
logger.info(f"\nargs:{args}\n")
|
||||
|
||||
text_embeder = TextEmbedder(args)
|
||||
all_embeddings_list = text_embeder.encode_samples()
|
||||
|
||||
logger.info(f"len_all_embeddings_list:{len(all_embeddings_list)}")
|
||||
logger.info("Finished embedding")
|
||||
|
||||
text_embeder.save_data(all_embeddings_list)
|
||||
logger.info(f"Pipeline run complete.")
|
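A minimal sketch (not part of this commit) of reading back the torch-saved embedding list that the k-means script above consumes; it assumes only the record layout produced by encode_samples() ({"embedding": [...], "data_index": "..."}) and the .pth path written by save_data().

import torch

records = torch.load('./12345/embedding/123451wfilter_embedding.pth')
print(len(records), records[0]["data_index"], len(records[0]["embedding"]))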
68
data_generate/zcs/fenbu/tool/count_readjsonl.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
import openpyxl
|
||||
|
||||
def count_categories_from_jsonl(input_file, output_xlsx):
|
||||
"""
|
||||
读取jsonl文件,统计类别信息并输出到xlsx文件
|
||||
|
||||
:param input_file: 输入的jsonl文件路径
|
||||
:param output_xlsx: 输出的xlsx文件路径
|
||||
"""
|
||||
# 用于统计每个类别的数据量
|
||||
category_counts = defaultdict(int)
|
||||
total_categories = 0
|
||||
total_samples = 0
|
||||
|
||||
# 1. 读取jsonl文件并统计
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line.strip())
|
||||
answer = data.get('answer')
|
||||
if answer:
|
||||
category_counts[answer] += 1
|
||||
total_samples += 1
|
||||
except json.JSONDecodeError:
|
||||
print(f"警告:跳过无法解析的行: {line}")
|
||||
continue
|
||||
|
||||
total_categories = len(category_counts)
|
||||
|
||||
# 2. 创建Excel工作簿和工作表
|
||||
wb = openpyxl.Workbook()
|
||||
ws = wb.active
|
||||
ws.title = "类别统计"
|
||||
|
||||
# 3. 写入表头
|
||||
headers = ["类别名称", "数据量", "占比(%)"]
|
||||
ws.append(headers)
|
||||
|
||||
# 4. 写入统计数据
|
||||
for category, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
|
||||
percentage = (count / total_samples) * 100
|
||||
ws.append([category, count, f"{percentage:.2f}%"])
|
||||
|
||||
# 5. 写入汇总信息
|
||||
ws.append([]) # 空行分隔
|
||||
ws.append(["总类别数", total_categories])
|
||||
ws.append(["总数据量", total_samples])
|
||||
|
||||
# 6. 设置单元格样式
|
||||
# 设置列宽
|
||||
ws.column_dimensions['A'].width = 40 # 类别名称列
|
||||
ws.column_dimensions['B'].width = 15 # 数据量列
|
||||
ws.column_dimensions['C'].width = 15 # 占比列
|
||||
|
||||
# 设置标题行样式
|
||||
for cell in ws[1]:
|
||||
cell.font = openpyxl.styles.Font(bold=True)
|
||||
|
||||
# 7. 保存Excel文件
|
||||
wb.save(output_xlsx)
|
||||
print(f"统计完成!共{total_categories}个类别,{total_samples}条数据。结果已保存到{output_xlsx}")
|
||||
|
||||
# 使用示例
|
||||
input_jsonl = "/data/zhaochsh01/buquan/12345/tool/output_prompt.jsonl" # 替换为你的jsonl文件路径
|
||||
output_excel = "12345category_stats.xlsx" # 输出Excel文件路径
|
||||
count_categories_from_jsonl(input_jsonl, output_excel)
|
18
data_generate/zcs/fenbu/tool/count_tag.py
Normal file
@@ -0,0 +1,18 @@
import pandas as pd

# Count how many records fall into each class of the 'answer' column
# in the labelled prompt spreadsheet.
df = pd.read_excel('/data/zhaochsh01/buquan/12345/count_result/12345_prompt.xlsx')

# Tally the number of rows per 'answer' value.
category_counts = df['answer'].value_counts().reset_index()
category_counts.columns = ['Category', 'Count']

# Write the counts to an xlsx file.
output_file = '/data/zhaochsh01/buquan/12345/count_result/category_counts.xlsx'
category_counts.to_excel(output_file, index=False)

print(f"Category counts saved to {output_file}")
52
data_generate/zcs/fenbu/tool/format.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import json
|
||||
|
||||
def process_jsonl_file(input_file, output_file):
|
||||
"""
|
||||
处理包含两种格式的jsonl文件:
|
||||
1. 有"主要内容"字段的数据:保留指定字段并重命名
|
||||
2. 其他格式数据:原样保留
|
||||
|
||||
:param input_file: 输入文件路径
|
||||
:param output_file: 输出文件路径
|
||||
"""
|
||||
processed_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as infile, \
|
||||
open(output_file, 'w', encoding='utf-8') as outfile:
|
||||
|
||||
for line in infile:
|
||||
try:
|
||||
data = json.loads(line.strip())
|
||||
|
||||
# 检查是否有"主要内容"字段
|
||||
if '主要内容' in data:
|
||||
# 创建新数据对象,只保留指定字段
|
||||
new_data = {
|
||||
'工单编号': data.get('工单编号'),
|
||||
'data': data.get('主要内容'), # 重命名字段
|
||||
'ins_tag_label': data.get('ins_tag_label'),
|
||||
'cluster_center': data.get('cluster_center')
|
||||
}
|
||||
|
||||
# 写入处理后的数据
|
||||
outfile.write(json.dumps(new_data, ensure_ascii=False) + '\n')
|
||||
processed_count += 1
|
||||
else:
|
||||
# 没有"主要内容"字段的数据原样写入
|
||||
outfile.write(line)
|
||||
skipped_count += 1
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"警告:跳过无法解析的行: {line.strip()}")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
print(f"处理完成!共处理 {processed_count} 条有'主要内容'的数据")
|
||||
print(f"保留原样 {skipped_count} 条其他格式数据")
|
||||
print(f"结果已保存到 {output_file}")
|
||||
|
||||
# 使用示例
|
||||
input_file = "./merged_result_for_more_test.jsonl" # 替换为你的输入文件路径
|
||||
output_file = "./12345_merged_resul.jsonl" # 输出文件路径
|
||||
process_jsonl_file(input_file, output_file)
|
40
data_generate/zcs/fenbu/tool/merge.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
|
||||
# 1. 读取 K-Means 聚类结果文件(JSON 格式),构建 {data_idx: cluster_center} 的字典
|
||||
kmeans_data = {}
|
||||
with open('/data/zhaochsh01/buquan/12345/kmeans/123451wfilter_cluster_kmeans_result.jsonl', 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
data_idx = entry.get("data_idx") # 获取 data_idx
|
||||
cluster_center = entry.get("cluster_center") # 获取 cluster_center(整数)
|
||||
if data_idx and cluster_center is not None:
|
||||
kmeans_data[data_idx] = cluster_center
|
||||
except Exception as e:
|
||||
print(f"Error parsing K-Means line: {line}, error: {e}")
|
||||
|
||||
# 2. 读取 JSONL 文件,匹配并合并 cluster_center
|
||||
output_lines = []
|
||||
with open('/data/zhaochsh01/buquan/12345/instag/123451wfilter_instag.jsonl', 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
key_id = str(data.get("工单编号")) # 转为字符串确保匹配
|
||||
if key_id in kmeans_data:
|
||||
data["cluster_center"] = kmeans_data[key_id] # 添加 cluster_center
|
||||
output_lines.append(json.dumps(data, ensure_ascii=False)) # 重新转为 JSON 字符串
|
||||
except Exception as e:
|
||||
print(f"Error processing JSONL line: {line}, error: {e}")
|
||||
|
||||
# 3. 将结果写入新文件
|
||||
with open('merged_result.jsonl', 'w', encoding='utf-8') as f:
|
||||
for line in output_lines:
|
||||
f.write(line + '\n')
|
||||
|
||||
print("数据处理完成,结果已保存到 merged_result.jsonl")
|
67
data_generate/zcs/fenbu/tool/prompt_change.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
def select_and_transform_data(input_file, output_file, samples_per_category=4, total_samples=250):
|
||||
"""
|
||||
读取jsonl文件,按answer分类,每类选指定条数数据,转换字段并保存
|
||||
|
||||
:param input_file: 输入jsonl文件路径
|
||||
:param output_file: 输出jsonl文件路径
|
||||
:param samples_per_category: 每类选取的样本数
|
||||
:param total_samples: 总共需要选取的样本数
|
||||
"""
|
||||
# 按answer分类存储数据
|
||||
categorized_data = defaultdict(list)
|
||||
|
||||
# 1. 读取并分类数据
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line.strip())
|
||||
answer = data.get('answer')
|
||||
if answer:
|
||||
categorized_data[answer].append(data)
|
||||
except json.JSONDecodeError:
|
||||
print(f"警告:跳过无法解析的行: {line}")
|
||||
continue
|
||||
|
||||
# 2. 从每类中随机选取指定数量的样本
|
||||
selected_data = []
|
||||
selected_count = 0
|
||||
|
||||
for category, items in categorized_data.items():
|
||||
# 计算当前类别最多能取多少样本(避免超过总数限制)
|
||||
remaining = total_samples - selected_count
|
||||
take = min(samples_per_category, len(items), remaining)
|
||||
|
||||
if take <= 0:
|
||||
break # 已达到总数要求
|
||||
|
||||
# 随机选取(这里简单取前take条,如需随机可改为random.sample)
|
||||
selected = items[:take]
|
||||
|
||||
# 3. 转换字段并添加到结果
|
||||
for item in selected:
|
||||
transformed = {
|
||||
"工单编号": item["uid"],
|
||||
"data": item["data"],
|
||||
"answer": item["answer"]
|
||||
}
|
||||
selected_data.append(transformed)
|
||||
selected_count += 1
|
||||
|
||||
if selected_count >= total_samples:
|
||||
break
|
||||
|
||||
# 4. 保存结果到jsonl文件
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for item in selected_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f"处理完成!共选取{selected_count}条数据,保存到{output_file}")
|
||||
print(f"分类统计:{ {k: len([d for d in selected_data if d['answer'] == k]) for k in categorized_data} }")
|
||||
|
||||
# 使用示例
|
||||
input_file = "/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl" # 替换为你的输入文件路径
|
||||
output_file = "output_prompt.jsonl" # 输出文件路径
|
||||
select_and_transform_data(input_file, output_file, samples_per_category=4, total_samples=250)
|
85
data_generate/zcs/fenbu/tool/prompt_choose250.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
def select_and_transform_data(input_file, output_file, small_category_max=2, normal_samples=4, total_samples=250):
|
||||
"""
|
||||
读取jsonl文件,按answer分类,小类别(<=2)全选,其他类别每类选指定条数数据,转换字段并保存
|
||||
|
||||
:param input_file: 输入jsonl文件路径
|
||||
:param output_file: 输出jsonl文件路径
|
||||
:param small_category_max: 小类别的最大数据量阈值(<=此值视为小类别)
|
||||
:param normal_samples: 普通类别每类选取的样本数
|
||||
:param total_samples: 总共需要选取的样本数
|
||||
"""
|
||||
# 按answer分类存储数据
|
||||
categorized_data = defaultdict(list)
|
||||
|
||||
# 1. 读取并分类数据
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line.strip())
|
||||
answer = data.get('answer')
|
||||
if answer:
|
||||
categorized_data[answer].append(data)
|
||||
except json.JSONDecodeError:
|
||||
print(f"警告:跳过无法解析的行: {line}")
|
||||
continue
|
||||
|
||||
# 分离小类别和普通类别
|
||||
small_categories = {k: v for k, v in categorized_data.items() if len(v) <= small_category_max}
|
||||
normal_categories = {k: v for k, v in categorized_data.items() if len(v) > small_category_max}
|
||||
|
||||
# 2. 先选取小类别的所有数据
|
||||
selected_data = []
|
||||
selected_count = 0
|
||||
|
||||
for category, items in small_categories.items():
|
||||
# 全取小类别的数据
|
||||
for item in items:
|
||||
transformed = {
|
||||
"工单编号": item["uid"],
|
||||
"data": item["data"],
|
||||
"answer": item["answer"]
|
||||
}
|
||||
selected_data.append(transformed)
|
||||
selected_count += 1
|
||||
|
||||
# 3. 从普通类别中选取数据,直到达到总数
|
||||
for category, items in normal_categories.items():
|
||||
if selected_count >= total_samples:
|
||||
break
|
||||
|
||||
# 计算还能取多少数据
|
||||
remaining = total_samples - selected_count
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
# 取min(普通类别样本数, 当前类别数据量, 剩余需要的数据量)
|
||||
take = min(normal_samples, len(items), remaining)
|
||||
|
||||
# 选取数据(这里简单取前take条,如需随机可改为random.sample)
|
||||
for item in items[:take]:
|
||||
transformed = {
|
||||
"工单编号": item["uid"],
|
||||
"data": item["data"],
|
||||
"answer": item["answer"]
|
||||
}
|
||||
selected_data.append(transformed)
|
||||
selected_count += 1
|
||||
|
||||
# 4. 保存结果到jsonl文件
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for item in selected_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f"处理完成!共选取{selected_count}条数据,保存到{output_file}")
|
||||
print("分类统计:")
|
||||
stat = {k: len([d for d in selected_data if d['answer'] == k]) for k in categorized_data}
|
||||
for k, v in stat.items():
|
||||
print(f"{k}: {v}条 (共{categorized_data[k]}条)")
|
||||
|
||||
# 使用示例
|
||||
input_file = "/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl" # 替换为你的输入文件路径
|
||||
output_file = "output_prompt.jsonl" # 输出文件路径
|
||||
select_and_transform_data(input_file, output_file, small_category_max=2, normal_samples=4, total_samples=250)
|
144
data_generate/zcs/fenbu/tool/prompt_label2.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import json
|
||||
import requests
|
||||
import pandas as pd
|
||||
import re
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
|
||||
def read_jsonl_lines_in_batches(file_path, batch_size=10000):
|
||||
"""按批次读取 JSONL 文件"""
|
||||
batch = []
|
||||
with open(file_path, mode="r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
try:
|
||||
batch.append(json.loads(line.strip()))
|
||||
if len(batch) == batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding JSON: {e}")
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
def process_data_concurrently(data_list, api_url, headers, max_workers=10):
|
||||
"""并发处理数据并调用 API"""
|
||||
result_data = []
|
||||
list_category = [
|
||||
"农业生产", "农村工作", "农村生活", "农民生活", "巩固脱贫攻坚成果", "精准扶贫", "人才交流",
|
||||
"企业工资福利", "企业离休退", "企业离休退休", "劳动保护", "劳动关系", "劳动纠纷",
|
||||
"就业创业", "招录辞退", "社会保险政策", "职务职称", "行政效能", "信息查询", "历史遗留",
|
||||
"表扬感谢", "人防工作", "住房与房地产", "园林绿化", "国有土地上房屋征收补偿", "城乡规划",
|
||||
"工程质量", "建筑市场管理", "道路修建", "交通运输", "城市管理", "城管执法", "居民生活",
|
||||
"环境保护", "旅游信息", "人口计生", "体育工作", "医疗保障", "卫生", "执法监督", "救灾救济",
|
||||
"教育", "文化工作", "文明建设", "残疾人服务管理", "民政", "民族宗教", "法律服务", "社会治安",
|
||||
"退役军人", "信息通信产业", "商业贸易", "国资监管", "安全生产", "市场秩序", "招商引资",
|
||||
"旅游管理", "综合审批", "财税金融", "质量监管", "自然资产:海洋渔业", "国土资源", "林业",
|
||||
"水利", "自然资源:海洋渔业"
|
||||
]
|
||||
def process_single_data(data):
|
||||
try:
query = data.get('主要内容')
final_result = ""  # default so the dict returned below never references an unbound name
if query:
|
||||
input_content = f'''
|
||||
您是一位文本分类专家,请依据专业的眼光判断下当前 query 属于以下65个领域类别中的哪一个并使用简体中文给出分析过程。
|
||||
领域类别:{list_category}
|
||||
需要判断的内容:{query}
|
||||
**注意:严格仅输出领域类别中的其中一个类别。
|
||||
|
||||
"输出格式如下:结果:<插入返回的结果>\n分析过程:"
|
||||
'''
|
||||
response = requests.post(
|
||||
api_url,
|
||||
headers=headers,
|
||||
json={
|
||||
"model": "Qwen2.5-72B-Instruct",
|
||||
"stream": False,
|
||||
"temperature": 0.01,
|
||||
"messages": [{"role": "user", "content": input_content}]
|
||||
}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
answer = ""
|
||||
final_result = ""
|
||||
content = response.json()["choices"][0]["message"]["content"]
|
||||
print(f"返回结果为:{content}")
|
||||
# The model is asked to answer as "结果:<category>", so extract the token right after "结果:".
match = re.search(r'结果[::]\s*(\S+)', content)
|
||||
if match:
|
||||
final_result = match.group(1).strip("。") # 去除可能结尾的标点
|
||||
print(f"截取的结果为:{final_result}") # 输出:交通运输
|
||||
else:
|
||||
print("未找到有效分类")
|
||||
# answer = json.loads(content)
|
||||
# answer = answer.get("返回结果")
|
||||
# print(f"解析的结果为:{answer}")
|
||||
except (KeyError, IndexError, json.JSONDecodeError):
|
||||
content = "无法解析返回内容"
|
||||
else:
|
||||
content = f"API请求失败,状态码:{response.status_code}"
|
||||
return {
|
||||
"uid": data.get('工单编号'),
|
||||
"data": query,
|
||||
"answer": final_result
|
||||
}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(process_single_data, data) for data in data_list]
|
||||
for i, future in enumerate(as_completed(futures), start=1):
result = future.result()
if result:
result_data.append(result)
print(f"Finished {result.get('uid')} ({i}/{len(data_list)})")
|
||||
|
||||
return result_data
|
||||
|
||||
|
||||
def save_to_excel_in_batches(data_list, output_file, batch_size=10000):
|
||||
"""按批次保存数据到 Excel 文件"""
|
||||
df = pd.DataFrame(data_list)
|
||||
writer = pd.ExcelWriter(output_file, engine='openpyxl')
|
||||
for i in range(0, len(df), batch_size):
|
||||
batch_df = df.iloc[i:i + batch_size]
|
||||
batch_df.to_excel(writer, index=False, startrow=i)
|
||||
writer.close()
|
||||
print(f"数据已成功保存到 {output_file}")
|
||||
|
||||
def save_to_jsonl_in_batches(data_list, output_file, batch_size=10000):
|
||||
"""按批次保存数据到 JSONL 文件"""
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for i in range(0, len(data_list), batch_size):
|
||||
# 获取当前批次的数据
|
||||
batch_data = data_list[i:i + batch_size]
|
||||
# 将每个数据对象写入文件,每行一个 JSON 对象
|
||||
for item in batch_data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
print(f"数据已成功保存到 {output_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
#output_excel_file = 'result-taoli-5.xlsx'
|
||||
# api_url = "http://100.105.149.39:8000/v1/chat/completions"
|
||||
api_url = "http://100.105.246.130:8000/v1/chat/completions"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c"
|
||||
}
|
||||
|
||||
#file_path = '/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_BELLE_Multiturn_Chat_filtered_2504232014.jsonl'
|
||||
file_path = '/data/zhaochsh01/buquan/12345/tool/output.jsonl'
|
||||
output_file = '/data/zhaochsh01/buquan/12345/count_result/12345_prompt.jsonl'
|
||||
output_excel_file = '/data/zhaochsh01/buquan/12345/count_result/12345_prompt.xlsx'
|
||||
|
||||
#file_path = '/dataset-pvc/suojiayi/new/train_prepare/20250423_020157/tmp_data/instruct_data_COIG_filtered_2504212014.jsonl'
|
||||
all_results = []
|
||||
for batch in read_jsonl_lines_in_batches(file_path, batch_size=10000):
|
||||
processed_batch = process_data_concurrently(batch, api_url, headers, max_workers=20)
|
||||
all_results.extend(processed_batch)
|
||||
save_to_excel_in_batches(all_results, output_excel_file, batch_size=10000)
|
||||
save_to_jsonl_in_batches(all_results, output_file, batch_size=10000)
|
67
data_generate/zcs/fenbu/tool/xslx2jsonl.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import pandas as pd
|
||||
import json
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
def convert_value(obj):
|
||||
"""处理各种不可JSON序列化的类型"""
|
||||
# 处理空值
|
||||
if pd.isna(obj) or obj is None:
|
||||
return None
|
||||
|
||||
# 处理时间类型
|
||||
if isinstance(obj, (pd.Timestamp, datetime)):
|
||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# 处理NaT类型
|
||||
if isinstance(obj, pd._libs.tslibs.nattype.NaTType):
|
||||
return None
|
||||
|
||||
# 处理numpy数值类型
|
||||
if isinstance(obj, (np.integer, np.floating)):
|
||||
return int(obj) if isinstance(obj, np.integer) else float(obj)
|
||||
|
||||
# 处理numpy数组和pandas Series
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
if isinstance(obj, pd.Series):
|
||||
return obj.to_dict()
|
||||
|
||||
# 其他类型直接返回
|
||||
return obj
|
||||
|
||||
def xlsx_to_jsonl(input_file, output_file):
|
||||
"""
|
||||
将XLSX文件转换为JSONL格式
|
||||
|
||||
参数:
|
||||
input_file (str): 输入的XLSX文件路径
|
||||
output_file (str): 输出的JSONL文件路径
|
||||
"""
|
||||
try:
|
||||
# 读取Excel文件
|
||||
df = pd.read_excel(input_file)
|
||||
|
||||
# 将数据写入JSONL文件
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for _, row in df.iterrows():
|
||||
# 将行转换为字典并处理所有值
|
||||
record = {k: convert_value(v) for k, v in row.items()}
|
||||
|
||||
# 写入JSON行
|
||||
json.dump(record, f, ensure_ascii=False)
|
||||
f.write('\n')
|
||||
|
||||
print(f"转换成功,结果已保存到 {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"转换过程中发生错误: {str(e)}")
|
||||
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
input_xlsx = "/data/zhaochsh01/buquan/12345/1w_fillter.xlsx" # 替换为你的输入文件路径
|
||||
output_jsonl = "output.jsonl" # 替换为你想要的输出文件路径
|
||||
|
||||
xlsx_to_jsonl(input_xlsx, output_jsonl)
|
109
data_generate/zcs/fenbu/vllm_infer_qw_instag_format_json.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import json
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from tqdm import tqdm, trange
|
||||
from copy import deepcopy
|
||||
|
||||
# need vllm==0.4.2
|
||||
# firefly_copy 使用
|
||||
|
||||
# step1 读入待跑 tag 的 jsonl 文件
|
||||
input_jsonl_file = "/data/zhaochsh01/buquan/12345/tool/output.jsonl"
|
||||
with open(input_jsonl_file, 'r', encoding='utf8') as f:
|
||||
raw_data_list = f.readlines()
|
||||
print("len_raw_data_list:", len(raw_data_list))
|
||||
|
||||
# 这里测试,先跑前100个数据
|
||||
# raw_data_list = raw_data_list[:100]
|
||||
|
||||
# step2 读取 instag 模型 配置 tokenizer
|
||||
model_path = '/model-pvc/instagger_qwen1_8B/'
|
||||
tokenizer = get_tokenizer(model_path, trust_remote_code=True)
|
||||
tokenizer.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{{ '### ' + message['role'].capitalize() + ':\\n' }}"
|
||||
"{{ message['content'] | string + '\\n' }}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '### Assistant:\\n' }}"
|
||||
"{% endif %}"
|
||||
|
||||
)
|
||||
# step3 这里内存够用,一次性读取所有文件进行拼接 prompt 处理
|
||||
prompt_text_list = []
|
||||
for index, line in enumerate(raw_data_list):
|
||||
# print("process:", index)
|
||||
json_obj = json.loads(line)
|
||||
query = json_obj["主要内容"]
|
||||
prompt = query
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an intelligent assistant, please identify the category based on the user’s query and return the corresponding category."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
prompt_text_list.append(text)
|
||||
|
||||
print("len_prompt_text_list:", len(prompt_text_list))
|
||||
assert len(raw_data_list) == len(prompt_text_list)
|
||||
|
||||
|
||||
def extract_raw_qwen_res(response):
|
||||
if "<|im_end|>" in response:
|
||||
response = response.split("<|im_end|>")[0]
|
||||
return response
|
||||
|
||||
|
||||
# step4: vllm 推理参数设置,官方推荐这里要求 temperature=0
|
||||
sampling_params = SamplingParams(max_tokens=64,
|
||||
temperature=0,
|
||||
top_p=0.8,
|
||||
repetition_penalty=1.05)
|
||||
|
||||
# step5: 读入模型 这里读入模型往后放,数据有问题在load 模型之前暴露出来节省时间
|
||||
# set gpu_memory_utilization=0.95 for OOM error
|
||||
vllm_model = LLM(model=model_path,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.85,
|
||||
max_num_seqs=64, # 减少最大并发序列数(默认256)
|
||||
max_model_len=2048, # 减少最大上下文长度
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True)
|
||||
vllm_model.set_tokenizer(tokenizer)
|
||||
|
||||
# step6: 分批量推理并且得到结果的过程
|
||||
tag_result_list = []
|
||||
batch_size = 1024
|
||||
for start in trange(0, len(prompt_text_list), batch_size):
|
||||
print("index:", start)
|
||||
batch_list = prompt_text_list[start: start + batch_size]
|
||||
vllm_outputs = vllm_model.generate(batch_list, sampling_params)
|
||||
for output in vllm_outputs:
|
||||
prompt = output.prompt
|
||||
cur_tag_result = extract_raw_qwen_res(output.outputs[0].text)
|
||||
tag_result_list.append(cur_tag_result)
|
||||
|
||||
print("len_tag_result_list:", len(tag_result_list))
|
||||
|
||||
assert len(tag_result_list) == len(prompt_text_list)
|
||||
|
||||
# step7: 写入结果
|
||||
print("开始写 instag 结果 数据")
|
||||
single_turn_instag_path = "./12345/instag/123451wfilter_instag.jsonl"
|
||||
print("single_turn_instag_path:", single_turn_instag_path)
|
||||
single_f_sample_save = open(single_turn_instag_path, "w")
|
||||
write_single_sample_count = 0
|
||||
for line, tag in zip(raw_data_list, tag_result_list):
|
||||
json_line = json.loads(line)
|
||||
json_line["ins_tag_label"] = tag
|
||||
cur_final_line = json.dumps(deepcopy(json_line), ensure_ascii=False)
|
||||
write_single_sample_count = write_single_sample_count + 1
|
||||
single_f_sample_save.write(cur_final_line + "\n")
|
||||
if write_single_sample_count % 100000 == 0:
|
||||
print("write_single_count:", write_single_sample_count)
|
||||
|
||||
print("all finish success !")
|
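For reference, a minimal sketch (not part of this commit) of the prompt text that the custom chat_template above renders for one record, with a placeholder standing in for the 主要内容 field; this is only an illustration of the template, not output from the script.

example_prompt = (
    "### System:\n"
    "You are an intelligent assistant, please identify the category based on the user’s query "
    "and return the corresponding category.\n"
    "### User:\n"
    "<主要内容 text here>\n"
    "### Assistant:\n"
)
print(example_prompt)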
164
data_generate/zcs/zaoshu/baogao_content_extract_zaoshu.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
政务12345全国数据生成系统
|
||||
功能:
|
||||
1. 支持全国范围地理位置生成
|
||||
2. 多层级分类扩展
|
||||
3. 数据保存至Excel
|
||||
4. 真实业务场景模拟
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import random
|
||||
import time
|
||||
import re
|
||||
import json
|
||||
import requests
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
class NationalDataGenerator:
|
||||
def __init__(self, excel_path: str, category_column: str):
|
||||
self.base_categories = self._load_excel_categories(excel_path, category_column)
|
||||
self.location_pool = self._generate_national_locations()
|
||||
self.expanded_categories = self._expand_categories_with_gpt()
|
||||
self.used_records = set()
|
||||
|
||||
|
||||
|
||||
def _chat(self, content: str) -> str:
|
||||
"""调用Qwen模型的统一接口"""
|
||||
payload = json.dumps({
|
||||
"model": "Qwen2.5-72B-Instruct",
|
||||
"stream": False,
|
||||
"temperature": 0.01,
|
||||
"top_p": 0.1,
|
||||
"repetition_penalty": 1.05,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
})
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"cache-control": "no-cache"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post("http://100.105.214.176:8000/v1/chat/completions", headers=headers, data=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
print(f"API调用失败: {str(e)}")
|
||||
return ""
|
||||
|
||||
def _load_excel_categories(self, path: str, column: str) -> List[str]:
|
||||
"""从Excel读取基础分类"""
|
||||
df = pd.read_excel(path)
|
||||
return df[column].dropna().unique().tolist()
|
||||
|
||||
def _generate_national_locations(self, num=200) -> List[str]:
|
||||
"""生成全国真实地理位置库"""
|
||||
prompt = f"生成{num}个中国各城市真实存在的地理位置,按省市区三级格式,示例:\n- 广东省广州市天河区珠江新城\n- 浙江省杭州市余杭区未来科技城"
|
||||
response = self._chat(prompt)
|
||||
print("生成的地理位置库为")
|
||||
print(response)
|
||||
print(type(response))
|
||||
locations = [
|
||||
parts[1] # 取第二部分(地址)
|
||||
for line in response.strip().split("\n")
|
||||
if line and (parts := line.split(maxsplit=1)) and len(parts) >= 2
|
||||
]
|
||||
print(locations)
|
||||
return locations
|
||||
|
||||
def _expand_categories_with_gpt(self) -> Dict[str, List[str]]:
|
||||
"""Qwen扩展分类体系"""
|
||||
category_map = {}
|
||||
for base_cat in self.base_categories:
|
||||
prompt = f"生成与【{base_cat}】相关但具有政务场景区分度的5个细分类型,示例:\n- 类型1:施工许可违规\n- 类型2:夜间施工超时"
|
||||
response = self._chat(prompt)
|
||||
print("扩展类型为")
|
||||
print(response)
|
||||
print(type(response))
|
||||
sub_cats = [
|
||||
re.sub(r"^.*类型\d+:|\s*$", "", line) # 移除 "类型X:" 和首尾空格
|
||||
for line in response.strip().split("\n")
|
||||
if "类型" in line and ":" in line # 只处理包含 "类型" 和 ":" 的行
|
||||
]
|
||||
category_map[base_cat] = sub_cats
|
||||
time.sleep(1)
|
||||
return category_map
|
||||
|
||||
def generate_dataset(self, num_records: int) -> pd.DataFrame:
|
||||
"""生成核心数据集"""
|
||||
data = []
|
||||
while len(data) < num_records:
|
||||
base_cat = random.choice(self.base_categories)
|
||||
sub_cat = random.choice(self.expanded_categories[base_cat])
|
||||
location = random.choice(self.location_pool)
|
||||
|
||||
content, keywords = self._generate_content(base_cat, sub_cat, location)
|
||||
if content and self._validate_record(content, keywords, base_cat):
|
||||
data.append({
|
||||
"ID": len(data)+1,
|
||||
"内容": content,
|
||||
"关键词": " ".join(keywords),
|
||||
"参考答案": base_cat,
|
||||
"细分类型": sub_cat,
|
||||
"地理位置": location
|
||||
})
|
||||
time.sleep(1.2)
|
||||
|
||||
return pd.DataFrame(data)
|
||||
|
||||
def _generate_content(self, base_cat: str, sub_cat: str, location: str) -> Tuple[str, List[str]]:
|
||||
"""生成政务工单内容"""
|
||||
prompt = f"""生成真实可信的12345政务工单,要求:
|
||||
1. 主分类:【{base_cat}】
|
||||
2. 细分类型:【{sub_cat}】
|
||||
3. 发生地点:【{location}】
|
||||
4. 包含要素:时间、具体问题、影响范围、市民诉求
|
||||
5. 生成5个关键词(必须包含{base_cat})
|
||||
6. 内容长度80-150字
|
||||
|
||||
示例格式:
|
||||
市民反映{location}某建筑工地违规夜间施工至凌晨,噪音严重干扰周边居民。已向环保部门投诉3次未解决,要求立即停工整顿。
|
||||
关键词:夜间施工 噪音污染 环保投诉 施工许可 居民维权"""
|
||||
|
||||
try:
|
||||
response = self._chat(prompt)
|
||||
raw_text = response.strip()
|
||||
return self._parse_generated_text(raw_text)
|
||||
except Exception as e:
|
||||
print(f"生成失败:{str(e)}")
|
||||
return None, []
|
||||
|
||||
def _parse_generated_text(self, text: str) -> Tuple[str, List[str]]:
|
||||
"""解析生成文本"""
|
||||
content = re.sub(r"关键词:.*", "", text).strip()
keyword_match = re.findall(r"关键词:(.+)", text)
keywords = keyword_match[0].split()[:5] if keyword_match else []
|
||||
return content, keywords
|
||||
|
||||
def _validate_record(self, content: str, keywords: List[str], category: str) -> bool:
|
||||
"""五重数据校验"""
|
||||
return (
|
||||
len(content) >= 80 and
|
||||
len(keywords) == 5 and
|
||||
category in keywords and
|
||||
content not in self.used_records and
|
||||
any(c.isdigit() for c in content) # 包含数字要素
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 初始化生成器
|
||||
generator = NationalDataGenerator(
|
||||
excel_path="/data/zhaochsh01/buquan/12345/zaoshu/12345政务服务大模型测试集.xlsx",
|
||||
category_column="answer"
|
||||
)
|
||||
|
||||
# 生成100条数据
|
||||
df = generator.generate_dataset(100)
|
||||
|
||||
# 保存到Excel
|
||||
with pd.ExcelWriter("./output/government_12345_data.xlsx") as writer:
|
||||
df.to_excel(writer, index=False)
|
||||
|
||||
print("生成数据示例:")
|
||||
print(df[["ID", "内容", "关键词", "参考答案"]].head(3).to_string(index=False))
|
640
data_generate/zcs/zaoshu/duihuazaoshu_piliang4.py
Normal file
@@ -0,0 +1,640 @@
|
||||
import requests
from openpyxl import Workbook
from openpyxl.styles import Font, Alignment
import os
import threading          # used for the Excel export lock
import multiprocessing    # used for the shared progress counter
from faker import Faker
import json
import random
from typing import List, Dict, Tuple
import pandas as pd
from collections import defaultdict
import concurrent.futures
from functools import partial
|
||||
|
||||
def read_categories_config(file_path):
|
||||
try:
|
||||
# 读取Excel文件(假设前两列是二级和三级分类)
|
||||
df = pd.read_excel(file_path)
|
||||
|
||||
# 检查至少有两列数据
|
||||
if len(df.columns) < 2:
|
||||
raise ValueError("Excel文件必须至少包含两列:二级分类和三级分类")
|
||||
|
||||
categories_config = defaultdict(list)
|
||||
|
||||
# 遍历每一行数据
|
||||
for _, row in df.iterrows():
|
||||
level2 = str(row.iloc[0]).strip() # 二级分类(第一列)
|
||||
level3 = str(row.iloc[1]).strip() # 三级分类(第二列)
|
||||
|
||||
# 跳过空行
|
||||
if not level2 or not level3:
|
||||
continue
|
||||
|
||||
# 确保三级分类不重复
|
||||
if level3 not in categories_config[level2]:
|
||||
categories_config[level2].append(level3)
|
||||
|
||||
return dict(categories_config)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"错误:文件 {file_path} 不存在", flush=True)
|
||||
return {}
|
||||
except Exception as e:
|
||||
print(f"处理文件时出错: {str(e)}", flush=True)
|
||||
return {}
|
||||
|
||||
def chat(content: str, models_url):
|
||||
|
||||
payload = json.dumps(
|
||||
{
|
||||
"model": "Qwen2.5-72B-Instruct",
|
||||
"stream": False,
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.5,
|
||||
"repetition_penalty": 1.05,
|
||||
"messages": [{"role": "user", "content": f"{content}"}],
|
||||
}
|
||||
)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"cache-control": "no-cache",
|
||||
"Postman-Token": "4c70efd4-6448-4318-b2a9-e404f0181b80",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.request("POST", models_url, data=payload, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
content = response_data["choices"][0]["message"]["content"]
|
||||
else:
print(f"response is: {response.json()}", flush=True)
print(f"Request failed with status code: {response.status_code}", flush=True)
print(f"Response content: {response.content}", flush=True)
content = None
except Exception as e:
print(f"request exception: {e}", flush=True)
content = None
|
||||
return content
|
||||
|
||||
class FullyDynamicGenerator:
|
||||
def __init__(self):
|
||||
self.model_url = "http://100.105.61.165:8000/v1/chat/completions"
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c"
|
||||
}
|
||||
self.model_name = "Qwen2.5-72B-Instruct"
|
||||
self.faker = Faker('zh_CN')
|
||||
self.dynamic_memory = {}
|
||||
self.special_cases = [
|
||||
"方言沟通", "老年人口齿不清", "情绪激动打断对话",
|
||||
"背景噪音干扰", "信号断续"
|
||||
]
|
||||
# 添加锁用于线程安全的Excel写入
|
||||
self._export_lock = threading.Lock()
|
||||
|
||||
def generate_dialog(self, category: str, subcategory: str, export_path: str = None) -> List[Dict]:
|
||||
"""全动态对话生成入口"""
|
||||
scene_knowledge = self.generate_scene_knowledge(category, subcategory)
|
||||
self.dynamic_memory[f"{category}_{subcategory}"] = scene_knowledge
|
||||
dialog = []
|
||||
dialog.extend(self.generate_complex_opening(category, subcategory))
|
||||
dialog.extend(self.generate_obstacle_base_phase(scene_knowledge, subcategory))
|
||||
dialog.extend(self.generate_verification_with_challenges(dialog))
|
||||
dialog.extend(self.generate_technical_extend_phase(scene_knowledge, subcategory))
|
||||
dialog.extend(self.generate_final_confirmation(scene_knowledge, subcategory))
|
||||
|
||||
formatted_dialog = self.format_output(dialog)
|
||||
|
||||
if export_path:
|
||||
with self._export_lock: # 使用锁保证线程安全
|
||||
self.export_to_excel(formatted_dialog, export_path, category, subcategory)
|
||||
|
||||
return formatted_dialog
|
||||
|
||||
def _generate_single_dialog(self, category, subcategory, export_path, num_per_subcategory, i, total_tasks, current_task_counter):
|
||||
"""生成单个对话的辅助函数,用于并发执行"""
|
||||
with current_task_counter.get_lock():
|
||||
current_task = current_task_counter.value + 1
|
||||
current_task_counter.value = current_task
|
||||
|
||||
print(f"\n进度: {current_task}/{total_tasks} "
|
||||
f"({(current_task/total_tasks)*100:.1f}%) - "
|
||||
f"分类: {category} - "
|
||||
f"子分类: {subcategory} - "
|
||||
f"第 {i+1}/{num_per_subcategory} 条", flush=True)
|
||||
|
||||
dialog = self.generate_dialog(
|
||||
category=category,
|
||||
subcategory=subcategory,
|
||||
export_path=export_path
|
||||
)
|
||||
return {
|
||||
"category": category,
|
||||
"subcategory": subcategory,
|
||||
"dialog": dialog
|
||||
}
|
||||
|
||||
def generate_dialogs_in_batch(self, categories: Dict[str, List[str]], num_per_subcategory: int, export_path: str):
|
||||
"""
|
||||
批量生成对话数据
|
||||
:param categories: 字典格式 {分类: [子分类1, 子分类2,...]}
|
||||
:param num_per_subcategory: 每个子分类生成的数量
|
||||
:param export_path: 输出文件路径
|
||||
"""
|
||||
all_dialogs = []
|
||||
|
||||
# 计算总任务量
|
||||
total_subcategories = sum(len(subcats) for subcats in categories.values())
|
||||
total_tasks = total_subcategories * num_per_subcategory
|
||||
print(f"\n总共需要生成 {total_subcategories} 个子分类的数据,每个子分类 {num_per_subcategory} 条,共计 {total_tasks} 条对话记录", flush=True)
|
||||
|
||||
# 使用ThreadPoolExecutor创建10个worker
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
# 创建共享计数器
|
||||
current_task_counter = multiprocessing.Value('i', 0)
|
||||
|
||||
# 准备任务列表
|
||||
futures = []
|
||||
for category, subcategories in categories.items():
|
||||
for subcategory in subcategories:
|
||||
for i in range(num_per_subcategory):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self._generate_single_dialog,
|
||||
category=category,
|
||||
subcategory=subcategory,
|
||||
export_path=export_path,
|
||||
num_per_subcategory=num_per_subcategory,
|
||||
i=i,
|
||||
total_tasks=total_tasks,
|
||||
current_task_counter=current_task_counter
|
||||
)
|
||||
)
|
||||
|
||||
# 获取结果
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
try:
|
||||
result = future.result()
|
||||
all_dialogs.append(result)
|
||||
except Exception as e:
|
||||
print(f"生成对话时出错: {str(e)}", flush=True)
|
||||
|
||||
print(f"\n已完成所有生成任务,共生成{len(all_dialogs)}条对话记录", flush=True)
|
||||
return all_dialogs
|
||||
|
||||
def export_to_excel(self, dialog: List[Dict], file_path: str, category: str, subcategory: str):
|
||||
"""将整个对话作为一条记录保存到Excel文件(追加模式)"""
|
||||
try:
|
||||
# 合并对话内容,格式为:1. [客服]内容
|
||||
dialog_text = "\n".join(
|
||||
[f"{turn['turn']}. {turn['speaker']} {turn['content']}"
|
||||
for turn in dialog]
|
||||
)
|
||||
|
||||
# 创建包含元数据的DataFrame
|
||||
record = {
|
||||
"分类": category,
|
||||
"子分类": subcategory,
|
||||
"对话轮数": len(dialog),
|
||||
"对话内容": dialog_text,
|
||||
}
|
||||
|
||||
df = pd.DataFrame([record])
|
||||
|
||||
# 如果文件存在则追加,否则创建新文件
|
||||
if os.path.exists(file_path):
|
||||
with pd.ExcelWriter(file_path, mode='a', engine='openpyxl', if_sheet_exists='overlay') as writer:
|
||||
# 读取现有数据
|
||||
existing_df = pd.read_excel(file_path)
|
||||
# 合并新旧数据
|
||||
combined_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
# 写入合并后的数据
|
||||
combined_df.to_excel(writer, index=False)
|
||||
else:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
df.to_excel(file_path, index=False)
|
||||
|
||||
print(f"对话已成功保存到: {file_path}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存Excel文件时出错: {str(e)}", flush=True)
|
||||
|
||||
def generate_complex_opening(self, category: str, subcategory: str) -> List[Tuple]:
|
||||
"""生成带复杂情形的开场对话"""
|
||||
phase = []
|
||||
special_case = random.choice(self.special_cases + [None]*3)
|
||||
|
||||
# 首先让客服说话
|
||||
response_text = "您好,我是政府热线服务,很高兴为您服务"
|
||||
if special_case == "老年人口齿不清":
|
||||
response_text += "(放慢语速)请您慢慢说"
|
||||
phase.append(("客服", "greeting", response_text))
|
||||
|
||||
# 然后市民反馈问题
|
||||
citizen_traits = {
|
||||
"方言": random.choice(["带浓重口音", "夹杂方言词汇", "语法不规范"]),
|
||||
"老年人": random.choice(["说话缓慢", "重复语句", "耳背听不清"]),
|
||||
"情绪化": random.choice(["不断打断", "提高音量", "带哭腔"])
|
||||
}
|
||||
opening_prompt = f"""生成市民反映{subcategory}问题的电话开场白,要求:
|
||||
1. 必须包含"您好"等礼貌用语
|
||||
2. 体现真实通话特征:{citizen_traits.get(special_case, "正常沟通")}
|
||||
3. 包含具体问题细节"""
|
||||
opening = self.safe_llm_call(
|
||||
prompt=opening_prompt,
|
||||
system="你擅长模拟各类人群的真实对话",
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
try:
|
||||
opening_data = json.loads(opening)
|
||||
opening_text = opening_data.get("text", f"您好,我要反映{subcategory}问题")
|
||||
if special_case == "方言沟通":
|
||||
opening_text = self.add_dialect_features(opening_text)
|
||||
except:
|
||||
opening_text = f"您好,我想投诉{subcategory}问题"
|
||||
phase.append(("市民", "open_call", opening_text))
|
||||
|
||||
# 如果需要确认问题
|
||||
if special_case in ["方言沟通", "老年人口齿不清", "信号断续"]:
|
||||
phase.append(("客服", "double_check", f"抱歉,刚才没有听清楚,您是说{subcategory}问题对吗?"))
|
||||
phase.append(("市民", "clarify", random.choice([
|
||||
"对,就是这个问题",
|
||||
f"不是,是{random.choice(['更严重','其他'])}的问题",
|
||||
"(声音断断续续)喂...听得到吗?"
|
||||
])))
|
||||
return phase
|
||||
|
||||
    def generate_obstacle_base_phase(self, knowledge: Dict, scene: str) -> List[Tuple]:
        """生成带沟通障碍的基础信息采集"""
        phase = []
        required_fields = ["时间", "地点", "事件描述", "联系方式", "姓氏"]
        for field in required_fields:
            if random.random() < 0.1:
                unclear_question = self.safe_llm_call(
                    prompt=f"仅返回生成有歧义的{field}的询问话术,仅返回询问话术,不返回额外内容",
                    system="故意制造1-2处不明确表述"
                ) or f"那个...关于{field}的情况能不能说下?"
                phase.append(("客服", "unclear_question", unclear_question))
                phase.append(("市民", "confused", "您问的是什么?我没听明白"))
                question = self.safe_llm_call(
                    prompt=f"仅返回重新生成清晰的{field}询问话术",
                    system="使用最简明的表达"
                ) or f"请提供{field}的具体信息"
                phase.append(("客服", "retry_question", question))
            else:
                question = self.safe_llm_call(
                    prompt=f"仅返回生成政务热线询问{field}的标准话术,场景:{scene},仅返回询问话术,不返回额外内容",
                    system="要求:1.使用敬语 2.明确信息要求"
                ) or f"请问{scene}的{field}是?"
                phase.append(("客服", "info_request", question))
            answer, needs_clarify = self.generate_complex_answer(scene, field, question)
            phase.append(("市民", "info_response", answer))
            if needs_clarify:
                clarify_question = self.safe_llm_call(
                    prompt=f"仅返回根据模糊回答'{answer}'生成澄清{field}的追问,仅返回追问内容,不返回额外内容",
                    system="要求:1.在追问中指出不明确处 2.进行礼貌的追问"
                ) or f"您提供的{field}不够具体,请补充(例:{self.get_field_example(field)})"
                phase.append(("客服", "clarify_request", clarify_question))
                if random.random() < 0.1:
                    phase.append(("市民", "refuse", random.choice([
                        "这么麻烦不说了!",
                        "你们政府办事就是繁琐",
                        f"{field}有什么好问的!"
                    ])))
                    phase.append(("客服", "calm_down", random.choice([
                        "理解您的心情,但详细信息能帮助我们更快解决问题",
                        "抱歉给您带来不便,这是必要流程"
                    ])))
                phase.append(("市民", "clarified_response", f"哦,应该是{self.get_field_example(field)}"))
        return phase

    def generate_complex_answer(self, scene: str, field: str, question) -> Tuple[str, bool]:
        """生成带复杂特征的市民回答"""
        if random.random() < 0.15:
            special_answers = {
                "时间": [
                    ("就...就那个...前几天", True),
                    ("(背景嘈杂)喂?时间啊...上周?", True),
                    ("我不记得了!你们自己查!", False)
                ],
                "地点": [
                    ("俺们村东头那个...那个啥来着", True),
                    ("(信号不好)在...哗哗...超市附近", True),
                    ("这么简单的问题都处理不了?", False)
                ]
            }
            return random.choice(special_answers.get(field, [("这个我说不好", True)]))
        answers = {
            "时间": [
                (f"{random.choice(['今天','昨天'])}{random.randint(1,12)}点左右", False),
                (f"持续{random.randint(2,24)}小时了", False)
            ],
            "地点": [
                (f"{self.faker.building_number()}号{random.choice(['东侧','南门'])}", False),
                (f"{self.faker.street_name()}附近", True)
            ],
            "联系方式": [
                (f"{self.faker.phone_number()[:3]}****", True),
                (f"固话:{self.faker.phone_number()[:4]}-{self.faker.phone_number()[-4:]}", False)
            ],
            "姓氏": [
                (f"免贵姓{self.faker.last_name()}", False),
                ("叫我老李就行", True)
            ]
        }
        # common_answer 仅在 field 不在 answers 中时作为兜底回答
        common_answer = self.safe_llm_call(
            prompt=f"""仅返回模拟市民对'{question}'的真实回答,要求:1. 包含具体{field}的细节数据。 2. 反映真实诉求和情绪梯度。""",
            system="你是一个普通市民,回答要口语化并带生活细节"
        )

        return random.choice(answers.get(field, [(common_answer, False)]))

    def generate_verification_with_challenges(self, previous_dialog: List[Tuple]) -> List[Tuple]:
        """生成带挑战的信息确认环节"""
        phase = []
        collected_info = {}
        for turn in previous_dialog:
            if turn[1] in ["info_response", "clarified_response"]:
                for field in ["时间", "地点", "姓氏"]:
                    if field in turn[2]:
                        collected_info[field] = turn[2]
                        if random.random() < 0.1:
                            collected_info[field] = self.get_wrong_info(field)
        if collected_info:
            if random.random() < 0.05:
                wrong_field = random.choice(list(collected_info.keys()))
                correct_value = collected_info[wrong_field]
                collected_info[wrong_field] = self.get_wrong_info(wrong_field)
            verification_text = self.safe_llm_call(
                prompt="仅返回根据以下信息生成确认话术:" + json.dumps(collected_info, ensure_ascii=False),
                system="要求:1.逐项确认 2.允许修正"
            ) or f"我确认下:时间:{collected_info.get('时间','')},地点:{collected_info.get('地点','')}..."
            phase.append(("客服", "info_verification", verification_text))
            if random.random() < 0.3:
                correction_field = random.choice(list(collected_info.keys()))
                phase.append(("市民", "correction",
                              f"{correction_field}不对!应该是{self.get_field_example(correction_field)}"))
                if random.random() < 0.1:
                    phase.append(("市民", "angry", "你们连基本信息都记错!"))
                    phase.append(("客服", "apology", "非常抱歉,这是我们的失误"))
                phase.append(("客服", "acknowledge_correction", f"已更正{correction_field}信息"))
                phase.append(("市民", "final_confirmation", "现在对了"))
            else:
                phase.append(("市民", "confirmation", "对,没错"))
        return phase

    def generate_technical_extend_phase(self, knowledge: Dict, scene: str) -> List[Tuple]:
        """生成带技术障碍的扩展追问"""
        phase = []
        for question_config in knowledge.get("extend_questions", []):
            # 确保question变量总是有值
            question = question_config.get('prompt', '')  # 默认值

            if random.random() < 0.05:
                tech_question = self.safe_llm_call(
                    prompt=f"仅返回生成包含专业术语的{scene}问题",
                    system="使用3个以上专业词汇"
                ) or f"请问{scene}的{random.choice(['频谱特征','声压级衰减曲线'])}是怎样的?"
                phase.append(("客服", "technical_question", tech_question))
                phase.append(("市民", "not_understand", "这些专业名词听不懂"))
                simplified = self.safe_llm_call(
                    prompt=f"仅将'{tech_question}'转化为通俗问题",
                    system="用生活化比喻解释"
                ) or f"就是问{scene}的具体表现是怎样的"
                question = simplified  # 更新question变量
                phase.append(("客服", "simplified_question", simplified))
            else:
                generated_question = self.safe_llm_call(
                    prompt=f"仅返回基于{scene}场景生成的追问:{question_config.get('prompt','')}",
                    system="要求:1.分步骤询问 2.适度专业"
                )
                question = generated_question or question_config.get('prompt', '')  # 确保question有值
                phase.append(("客服", "extend_question", question))

            # 现在question变量肯定有值
            if random.random() < 0.15:
                phase.append(("市民", "broken_response", "喂?...听得到吗?...我说到哪了?"))
                phase.append(("客服", "reassure", "电话不太稳定,请您继续"))

            answer = self.generate_realistic_answer(
                question, scene, question_config.get("theme", ""), "extend"
            )
            phase.append(("市民", "extend_answer", answer))

            if random.random() < 0.1:
                phase.append(("客服", "request_material", "需要您提供现场照片或录音证据"))
                phase.append(("市民", "material_response", random.choice([
                    "我手机里有,怎么发给你们?",
                    "现在拍不了,你们自己来看!"
                ])))
                phase.append(("客服", "guide", "可以通过微信公众号'市民服务'上传"))
        return phase

    def generate_final_confirmation(self, knowledge: Dict, scene: str) -> List[Tuple]:
        """生成最终确认"""
        phase = []
        confirmation = self.safe_llm_call(
            prompt=f"仅返回生成{scene}问题的最终确认话术",
            system="包含:1.处理时限 2.反馈方式 3.应急联系人"
        ) or f"我们将在{random.choice(['24小时','3个工作日'])}内处理您的{scene}问题"
        phase.append(("客服", "final_confirmation", confirmation))
        if random.random() < 0.2:
            phase.append(("市民", "follow_up", random.choice([
                "如果超时没处理怎么办?",
                "我要找哪个部门跟进?"
            ])))
            phase.append(("客服", "reply", random.choice([
                "可拨打监督电话12345查询进度",
                "我们会主动给您回复"
            ])))
        return phase

    def generate_scene_knowledge(self, category: str, subcategory: str) -> Dict:
        """动态生成场景知识图谱"""
        prompt = f"""作为政务热线专家,请为【{category}->{subcategory}】场景生成知识配置,包含:
1. 3-5个必问基础字段(如时间、地点)
2. 3个专业追问方向及追问话术模板
3. 该场景涉及的相关部门和处理时限参考
仅返回JSON格式,结构示例:
{{
"base_fields": [
{{"field": "时间", "prompt": "询问具体时间的标准话术"}},
{{"field": "地点", "prompt": "询问详细位置的专业话术"}}
],
"extend_questions": [
{{"theme": "历史记录", "prompt": "追问历史投诉情况的专业话术"}},
{{"theme": "紧急程度", "prompt": "评估问题紧急程度的询问方式"}}
],
"departments": ["城管局", "环保局"],
"time_ranges": ["24小时内", "3个工作日"]
}}"""
        response = self.safe_llm_call(
            prompt=prompt,
            system="你是有10年经验的政务热线系统架构师",
            response_format={"type": "json_object"}
        )
        try:
            knowledge = json.loads(response)
            knowledge["confirmation_template"] = self.generate_confirmation_template(
                category, subcategory, knowledge.get("departments", []), knowledge.get("time_ranges", [])
            )
            return knowledge
        except Exception:
            return self.get_fallback_knowledge(category, subcategory)

    def generate_confirmation_template(self, category: str, subcategory: str,
                                       departments: List[str], time_ranges: List[str]) -> str:
        """生成确认话术模板"""
        prompt = f"""为【{category}->{subcategory}】创建确认话术模板,要求包含:
1. 处理部门:{departments}
2. 预计时限:{time_ranges}
3. 至少2种后续跟进方式
模板示例:"我们将协调{{department}}在{{timeframe}}内处理,可通过{{phone}}或{{wechat}}查询进展"
"""
        return self.safe_llm_call(
            prompt=prompt,
            system="你需创建可参数化的文本模板,用{}标记变量位置"
        ) or f"我们将尽快处理您的{subcategory}问题"

    def generate_realistic_answer(self, question: str, scene: str,
                                  field: str, answer_type: str) -> str:
        """生成高真实性回答"""
        prompt = f"""仅返回模拟市民对【{scene}】问题中'{question}'的真实回答,要求:
1. 包含具体{field}的细节数据
2. 反映真实诉求和情绪梯度
3. 使用该场景典型市民的语言特征"""
        system = {
            "base": "你是一个普通市民,回答要口语化并带生活细节",
            "extend": "你是有相关专业知识的市民,回答要包含技术参数和量化描述"
        }[answer_type]
        answer = self.safe_llm_call(prompt=prompt, system=system)
        return answer or self.get_field_example(field)

    def get_field_example(self, field: str) -> str:
        """获取字段示例"""
        examples = {
            "时间": "2023年10月15日下午3点20分",
            "地点": "朝阳区建国路88号地下二层停车场",
            "联系方式": "13800138000或010-12345678",
            "姓氏": "张先生/李女士"
        }
        return examples.get(field, "具体情况是这样的...")

    def get_fallback_knowledge(self, category: str, subcategory: str) -> Dict:
        """应急知识库"""
        return {
            "base_fields": [
                {"field": "时间", "prompt": f"请问{subcategory}发生的具体时间?"},
                {"field": "地点", "prompt": f"请说明{category}问题的详细位置?"}
            ],
            "extend_questions": [
                {"theme": "基本情况", "prompt": f"请描述{subcategory}的具体表现?"}
            ],
            "confirmation_template": f"我们将处理您的{category}问题",
            "departments": ["相关部门"],
            "time_ranges": ["尽快"]
        }

    def add_dialect_features(self, text: str) -> str:
        """添加方言特征"""
        dialects = {
            "北方方言": [("我", "俺"), ("的", "滴"), ("这个", "这玩意儿")],
            "南方方言": [("是不是", "系唔系"), ("不知道", "母鸡"), ("说", "讲")]
        }
        dialect_type, replacements = random.choice(list(dialects.items()))
        for orig, rep in replacements:
            if orig in text:
                return text.replace(orig, rep)
        return text + random.choice(["晓得伐?", "中不中?", "得啵?"])

    def get_wrong_info(self, field) -> str:
        """生成错误信息"""
        wrong_examples = {
            "时间": random.choice(["昨天", "上周", "记不清了"]),
            "地点": random.choice(["东边", "路口", "大概位置"]),
            "姓氏": random.choice(["王", "李", "张"])
        }
        return wrong_examples.get(field, "信息有误")

    def safe_llm_call(self, prompt: str, system: str = None, **kwargs) -> str:
        """带熔断机制的API调用"""
        try:
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})

            data = {
                "model": self.model_name,
                "messages": messages,
                "temperature": 0.7,
                "max_tokens": 400
            }

            # 处理response_format参数
            if "response_format" in kwargs:
                data["response_format"] = kwargs["response_format"]

            response = requests.post(
                self.model_url,
                headers=self.headers,
                json=data,
                timeout=60
            )

            if response.status_code == 200:
                return response.json()["choices"][0]["message"]["content"]
            else:
                print(f"API调用失败: {response.status_code}, {response.text}", flush=True)
                return ""

        except Exception as e:
            print(f"API异常: {str(e)}", flush=True)
            return ""

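    # 说明(示意):safe_llm_call 在失败时直接返回空字符串,调用方统一用 `or 默认话术` 兜底。
    # 下面是一个假设性的简单重试封装示例(并非真正的熔断器),类中其他方法并未调用它,仅供参考。
    def safe_llm_call_with_retry(self, prompt: str, system: str = None,
                                 retries: int = 2, backoff: float = 1.0, **kwargs) -> str:
        """示意:对 safe_llm_call 做简单重试 + 线性退避"""
        import time  # 局部导入,避免依赖文件头部未展示的 import 列表

        for attempt in range(retries + 1):
            result = self.safe_llm_call(prompt, system=system, **kwargs)
            if result:
                return result
            if attempt < retries:
                time.sleep(backoff * (attempt + 1))  # 失败后等待再重试
        return ""
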
    def format_output(self, dialog: List[Tuple]) -> List[Dict]:
        """格式化输出,移除[xxx]类型标签"""
        formatted = []
        for idx, (speaker, dtype, content) in enumerate(dialog):
            # 移除类型标签,只保留说话人
            formatted.append({
                "turn": idx + 1,
                "speaker": f"[{speaker}]",
                "content": content
            })
        return formatted

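# 输出结构示意(仅为说明性注释,非主流程代码):format_output 会把
# [("客服", "greeting", "您好..."), ("市民", "open_call", "我要反映...")]
# 整理成
# [{"turn": 1, "speaker": "[客服]", "content": "您好..."},
#  {"turn": 2, "speaker": "[市民]", "content": "我要反映..."}]
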
if __name__ == "__main__":
    import multiprocessing
    import threading

    generator = FullyDynamicGenerator()

    # 示例文件路径
    excel_path = "/data/zhaochsh01/buquan/12345/zaoshu/count_3level.xlsx"

    # 读取并生成categories_config
    categories_config = read_categories_config(excel_path)

    # 打印结果
    print("生成的categories_config:", flush=True)
    for level2, level3_list in categories_config.items():
        print(f"{level2}: {level3_list}", flush=True)

    num_per_subcategory = 2  # 每个子分类生成2条数据
    output_file = "./output/政务热线对话记录更新.xlsx"

    # 批量生成数据
    generator.generate_dialogs_in_batch(
        categories=categories_config,
        num_per_subcategory=num_per_subcategory,
        export_path=output_file
    )

    # 示例:打印最后生成的5条记录
    sample_df = pd.read_excel(output_file)
    print("\n=== 最后5条记录示例 ===", flush=True)
    print(sample_df.tail(), flush=True)

205
智数员工/zhaopin_zaoshu.py
Normal file
@ -0,0 +1,205 @@
import json
import random
import asyncio
from typing import List, Dict, Any, Tuple
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
from openpyxl import Workbook
import requests
import logging

logging.basicConfig(level=logging.INFO)


# 配置参数
class Config:
    OUTPUT_FILE = "recruitment_data.xlsx"
    FIXED_QUESTIONS = [
        "上个月面试了多少人",
        "本周安排了几个面试",
        "招聘进度如何",
        "有多少候选人进入二面",
        "销售岗位的招聘情况",
        "技术岗位的简历筛选数量",
        "最近一周的offer发放数量",
        "哪个部门的招聘完成率最高",
        "招聘成本是否超出预算",
        "候选人平均面试周期是多长"
    ]
    LOCATIONS = ["北京", "上海", "广州", "深圳", "杭州", "", "成都"]
    INTENTS = ["招聘数据", "招聘进度", "其他", "成本分析", "效率统计"]
    COMMISSIONER_TYPES = ["yxz", "hrbp", "recruiter", "manager"]
    USER_NAMES = ["张招聘", "李HR", "王人事", "赵经理", "刘专员"]

async def chat(input_content):
    # 修正:requests.post 的第一个参数是 URL,不能写成 api_url= 关键字参数
    response = requests.post(
        "http://100.105.1.227:8000/v1/chat/completions",
        headers={
            "Content-Type": "application/json",
            "Authorization": "7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c"
        },
        json={
            "model": "Qwen3-72B",
            "stream": False,
            "temperature": 0.6,
            "TopP": 0.95,
            "TopK": 20,
            "MinP": 0,
            "messages": [{"role": "user", "content": input_content}]
        },
        timeout=180
    )

    result = ""  # 默认返回空串,避免请求失败时 result 未定义
    if response.status_code == 200:
        try:
            result = response.json()["choices"][0]["message"]["content"]
        except Exception as e:
            logging.error(f"Error processing API response: {e}")
    else:
        logging.error(f"API request failed with status code: {response.status_code}")
    await asyncio.sleep(0.1)
    return result

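# 说明(示意):requests.post 是同步阻塞调用,放在协程里会阻塞事件循环,
# 使后面的 asyncio.gather 实际近似串行。下面给出一个假设性的替代写法示例,
# 借助 asyncio.to_thread 把阻塞请求放进线程池;本文件其余代码并未调用它。
async def chat_in_thread(input_content):
    def _blocking_call():
        resp = requests.post(
            "http://100.105.1.227:8000/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": "7c3eafb5-2d6e-100d-ab0f-7b2c1cdafb3c"
            },
            json={
                "model": "Qwen3-72B",
                "stream": False,
                "temperature": 0.6,
                "messages": [{"role": "user", "content": input_content}]
            },
            timeout=180
        )
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]

    # 在默认线程池中执行阻塞请求,事件循环可继续调度其他协程
    return await asyncio.to_thread(_blocking_call)
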
# 模拟模型生成多样化问题
async def generate_diverse_questions() -> str:
    # 这里应该是实际调用模型生成多样化问题的代码
    # 模拟生成几个变体问题

    input_content = """你是一个资深HR分析师。请生成一个招聘数据分析的查询请求,要求:
- 聚焦在以下至少一个方面:面试、offer、入职、渠道效果、成本、周期时间
- 包含具体的时间范围(如最近一周/上月/本季度)
- 可选项包含部门/岗位/地域等维度
- 直接返回问题,不要任何解释

例如:
对比北京和上海地区过去两个月销售岗位的offer接受率"""
    # 修正:chat 是协程,必须 await,否则拿到的是 coroutine 对象而不是字符串
    gen_question = await chat(input_content)
    await asyncio.sleep(0.1)

    return gen_question

# 生成招聘相关的输入数据
async def generate_input_data(use_fixed: bool = True) -> Dict[str, Any]:
    # 说明:use_fixed 参数当前未使用,固定问题与生成问题的比例由下面的 0.3 阈值控制
    if random.random() > 0.3:
        base_question = random.choice(Config.FIXED_QUESTIONS)
    else:
        base_question = await generate_diverse_questions()

    return {
        "messages": [{
            "role": "user",
            "content": base_question
        }],
        "location": random.choice(Config.LOCATIONS),
        "uuid": str(random.randint(10**18, 10**19 - 1)),  # randint 需要整数参数,不能传 1e18 这类浮点数
        "intent": random.choice(Config.INTENTS),
        "loginUserName": random.choice(Config.USER_NAMES),
        "loginUserId": "hr_" + str(random.randint(1000, 9999)),
        "commissioner_type": random.choice(Config.COMMISSIONER_TYPES)
    }

# 处理单个请求
async def process_request(input_data: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    try:
        user_content = input_data["messages"][0]["content"]
        input_content = f"""
你是一个专业招聘数据分析助手。请按以下规则处理问题:
1. 如果问题已经明确且可清晰回答,直接返回原问题
2. 如果问题模糊或不完整,按标准改写:
- 补充时间范围(最近/上月/本季度等)
- 明确量化指标(数量/比率/趋势等)
- 指定具体对象(岗位/部门/渠道等)
3. 直接返回最终问题,不要任何解释

待处理问题:{user_content}
"""
        rewritten_question = await chat(input_content)

        output_data = {
            "code": "0",
            "message": "",
            "result": rewritten_question
        }
        return input_data, output_data
    except Exception as e:
        output_data = {
            "code": "1",
            "message": str(e),
            "result": ""
        }
        return input_data, output_data

# 保存数据到Excel
def save_to_excel(data: List[Dict[str, Any]], filename: str):
    rows = []
    for item in data:
        input_data = item["input"]
        output_data = item["output"]

        row = {
            "输入问题": input_data["messages"][0]["content"],
            "输出问题": output_data["result"],
            "地点": input_data["location"],
            "UUID": input_data["uuid"],
            "意图": input_data["intent"],
            "用户名": input_data["loginUserName"],
            "用户ID": input_data["loginUserId"],
            "专员类型": input_data["commissioner_type"],
            "状态码": output_data["code"],
            "消息": output_data["message"]
        }
        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_excel(filename, index=False, engine='openpyxl')
    print(f"数据已保存到 {filename}")

# 并发生成数据
async def generate_data(num_samples: int) -> List[Dict[str, Any]]:
    # 首先生成所有输入数据
    input_tasks = [generate_input_data() for _ in range(num_samples)]
    input_data_list = await asyncio.gather(*input_tasks)

    # 然后并发处理所有请求
    process_tasks = [process_request(input_data) for input_data in input_data_list]
    results = await asyncio.gather(*process_tasks)

    # 组合结果
    output = []
    for input_data, output_data in results:
        output.append({
            "input": input_data,
            "output": output_data
        })

    return output

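# 说明(示意):generate_data 会一次性并发全部 num_samples 个请求,样本量大时可能压垮推理服务。
# 下面是一个假设性的限流版本示例,用 asyncio.Semaphore 控制同时在途的请求数;main() 并未调用它。
async def generate_data_limited(num_samples: int, max_concurrency: int = 20) -> List[Dict[str, Any]]:
    sem = asyncio.Semaphore(max_concurrency)

    async def _one() -> Dict[str, Any]:
        async with sem:  # 同一时刻最多 max_concurrency 个请求在途
            input_data = await generate_input_data()
            inp, outp = await process_request(input_data)
            return {"input": inp, "output": outp}

    return list(await asyncio.gather(*[_one() for _ in range(num_samples)]))
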
# 主函数
async def main():
    try:
        num_samples = 2000
        print(f"开始生成 {num_samples} 条招聘数据...")
        data_pairs = await generate_data(num_samples)

        save_to_excel(data_pairs, Config.OUTPUT_FILE)

        # 打印前3条样本
        print("\n样本示例:")
        for i, pair in enumerate(data_pairs[:3], 1):
            print(f"样本 {i}:")
            print("输入问题:", pair["input"]["messages"][0]["content"])
            print("输出问题:", pair["output"]["result"])
            print("-" * 50)

    except Exception as e:
        print(f"发生错误: {e}")


if __name__ == "__main__":
    asyncio.run(main())