offline_data_model_pipline/data_generate/zw12345/yanpanbaogaozongjie_demo.py

206 lines
7.8 KiB
Python
Raw Normal View History

2025-05-12 14:18:19 +08:00
from openai import OpenAI, APIError, RateLimitError, AuthenticationError
import csv
import json
import re
import ast
from typing import List, Dict, Tuple
class GovDataGenerator:
    """Generate keyword lists for 12345 government-hotline domains with an LLM.

    For each domain title the generator:

    1. asks the model for a small domain config (categories + one special
       requirement) and caches it per title,
    2. builds a case-generation prompt from that config,
    3. parses the ``关键词:...`` lines of the reply into individual keywords,
    4. filters / orders the keywords and can dump the result to a TSV file.

    NOTE: an earlier revision overwrote the ``api_key`` argument with a literal
    key embedded in the constructor. That literal has been removed; any key
    that was ever committed should be treated as compromised and rotated.
    """

    DEFAULT_BASE_URL = "https://api.deepseek.com"
    DEFAULT_MODEL = "deepseek-chat"

    # One "关键词:<items>" run per case; accepts the ASCII and the full-width
    # colon because Chinese model output frequently uses the latter.
    _CASE_PATTERN = re.compile(r'关键词[::](.*?)(?=\n\[案例|\n\n|$)', re.DOTALL)
    # Separators accepted between keywords (ASCII comma, full-width comma,
    # ideographic comma).
    _KEYWORD_SEPARATORS = re.compile(r'[,,、]')
    _WHITESPACE = re.compile(r'\s+')
    _LOCATION_CHARS = re.compile(r'[市区街道]')
    _PRIORITY_TERMS = ('社保', '医保', '房产证', '施工许可')

    def __init__(self, api_key: str,
                 base_url: str = DEFAULT_BASE_URL,
                 model: str = DEFAULT_MODEL):
        """Create the API client and the prompt templates.

        Args:
            api_key: API key for the endpoint. When empty, the
                ``DEEPSEEK_API_KEY`` environment variable is used instead.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model: Model name sent with every chat-completion request.

        Raises:
            ValueError: If no key is passed and the environment variable is
                not set either.
        """
        import os  # local import keeps this class self-contained

        resolved_key = api_key or os.environ.get("DEEPSEEK_API_KEY", "")
        if not resolved_key:
            raise ValueError("未提供API密钥:请传入 api_key 或设置环境变量 DEEPSEEK_API_KEY")
        self.model = model
        self.client = OpenAI(api_key=resolved_key, base_url=base_url)
        # title -> (categories, requirements); filled by generate_dynamic_config
        self.config_cache = {}
        self._init_prompts()

    def _init_prompts(self):
        """Initialize all prompt templates."""
        # NOTE(review): the format example below shows "关键词" with no colon
        # and no separators between items, while process_response expects
        # "关键词:" followed by comma-separated items. The punctuation may
        # have been lost from this template - confirm the intended wording.
        self.base_prompt = """请生成与【{title}】相关的5个真实政务12345业务案例每个案例包含
1. 当事人如张先生+ 问题场景如XX街道XX小区
2. 业务类型如社保缴纳+ 具体问题如未回复
3. 涉及单位如医保局+ 时间要素如8月17日
4. 证件编号如370181XXXXXXXXXX+ 政策条件如连续缴费12个月
每个案例生成6-8个关键词格式
[案例1] 关键词关键词1关键词2...
[案例2] 关键词关键词A关键词B...
请用中文逗号分隔不要编号"""
        # Doubled braces are literal braces after str.format().
        self.config_prompt = """请根据政务领域【{title}】生成:
1. 3-5个核心业务分类categories
2. 1条特别生成要求requirements
示例保险领域
categories: ["医疗保险", "失业保险", "养老保险", "生育保险"]
requirements: "需包含医保报销和生育津贴案例各1个"
请用JSON格式返回{{"categories": [], "requirements": ""}}"""

    def _call_gpt(self, prompt: str, **kwargs) -> str:
        """Send one user prompt through ``self.client`` and return the reply text.

        Args:
            prompt: The user message.
            **kwargs: Optional ``temperature`` (default 0.7) and ``max_tokens``
                (default 800).

        Returns:
            The content of the first choice, or ``""`` if the content is empty.

        Raises:
            Exception: For every failure, with a message that depends on the
                error class; the original error is chained as ``__cause__``.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                temperature=kwargs.get('temperature', 0.7),
                max_tokens=kwargs.get('max_tokens', 800),
                messages=[
                    {"role": "system", "content": "You are a helpful assistant"},
                    {"role": "user", "content": prompt},
                ],
                stream=False,
            )
            # content may be None; callers run regexes over it and need a str
            return response.choices[0].message.content or ""
        except RateLimitError as e:
            raise Exception("请求超频请等待后重试错误码429") from e
        except AuthenticationError as e:
            raise Exception("API密钥无效请检查密钥是否正确") from e
        except APIError as e:
            raise Exception(f"API错误: {e.code} - {e.message}") from e
        except Exception as e:
            raise Exception(f"请求失败: {str(e)}") from e

    def generate_dynamic_config(self, title: str) -> Tuple[List[str], str]:
        """Return ``(categories, requirements)`` for a domain, cached per title.

        On any failure (API error, malformed reply) a message is printed and
        ``([], "")`` is returned without caching, so a later call retries.
        """
        if title in self.config_cache:
            return self.config_cache[title]
        try:
            prompt = self.config_prompt.format(title=title)
            raw_text = self._call_gpt(prompt, temperature=0.5, max_tokens=300)
            parsed = self.safe_parse_config(raw_text)
            categories = parsed.get("categories", [])
            requirements = parsed.get("requirements", "")
            # Validate the data structure before it is cached.
            if not isinstance(categories, list) or not isinstance(requirements, str):
                raise ValueError("配置格式错误")
            # Downstream code joins the categories and does substring tests
            # with them, so every entry must be a non-empty string.
            categories = [str(c).strip() for c in categories if str(c).strip()]
        except Exception as e:
            print(f"配置生成失败: {str(e)},使用默认配置")
            return [], ""
        self.config_cache[title] = (categories, requirements)
        return self.config_cache[title]

    def safe_parse_config(self, text: str) -> Dict:
        """Parse the model's config reply into a dict, as tolerantly as possible.

        Strategy, applied to the outermost ``{...}`` span of ``text``:
        strict JSON, then a Python-literal parse (handles single quotes),
        then a regex salvage of the two expected fields.

        Returns:
            A dict. ``{"categories": [], "requirements": ""}`` when ``text``
            contains no ``{...}`` span at all.
        """
        empty = {"categories": [], "requirements": ""}
        match = re.search(r'\{.*\}', text or "", re.DOTALL)
        if match is None:
            return empty
        json_str = match.group()

        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass

        try:
            literal = ast.literal_eval(json_str.replace('"', "'"))
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            literal = None
        # A brace expression can also evaluate to a set; only a dict is usable.
        if isinstance(literal, dict):
            return literal

        return self._salvage_config(text)

    @staticmethod
    def _salvage_config(text: str) -> Dict:
        """Last-resort regex extraction of categories/requirements from ``text``."""
        req_match = re.search(
            r'requirements["\']?\s*[::]\s*["\']([^"\'\n]*)["\']', text)
        requirements = req_match.group(1) if req_match else ""

        array_match = re.search(
            r'categories["\']?\s*[::]\s*\[(.*?)\]', text, re.DOTALL)
        if array_match:
            candidates = re.findall(r'["\']([^"\']+)["\']', array_match.group(1))
        else:
            # No recognizable array: fall back to every double-quoted string.
            candidates = re.findall(r'"([^"]+)"', text)

        # dict.fromkeys de-duplicates while keeping first-seen order; the field
        # names and the requirement text itself are not categories.
        skip = {"categories", "requirements", requirements}
        categories = [c.strip() for c in dict.fromkeys(candidates)
                      if c.strip() and c not in skip]
        return {"categories": categories[:3], "requirements": requirements}

    def _build_prompt(self, title: str, categories: List[str], requirements: str) -> str:
        """Fill the base prompt and append the optional config hints."""
        prompt = self.base_prompt.format(title=title)
        if requirements:
            prompt += f"\n特别要求:{requirements}"
        if categories:
            prompt += f"\n参考分类:{', '.join(categories[:3])}..."
        return prompt

    def get_prompt(self, title: str) -> str:
        """Build the case-generation prompt for ``title`` from its domain config."""
        categories, requirements = self.generate_dynamic_config(title)
        return self._build_prompt(title, categories, requirements)

    def generate_keywords(self, title: str) -> List[str]:
        """Run the full pipeline for one domain and return its keywords.

        Never raises: on any failure a message is printed and ``[]`` is returned.
        """
        try:
            # Resolve the config once and reuse it for both the prompt and the
            # sampling step, so a failed (uncached) config request is not
            # issued twice.
            categories, requirements = self.generate_dynamic_config(title)
            prompt = self._build_prompt(title, categories, requirements)
            raw_text = self._call_gpt(prompt)
            return self.process_response(raw_text, categories)
        except Exception as e:
            print(f"关键词生成失败: {str(e)}")
            return []

    def process_response(self, raw_text: str, categories: List[str]) -> List[str]:
        """Turn the model's case listing into an ordered, de-duplicated keyword list.

        Args:
            raw_text: Model reply containing ``关键词:<comma separated items>`` runs.
            categories: Domain categories; keywords containing one of the first
                three categories are moved to the front (at most 5 per category).

        Returns:
            At most 50 keywords after ``post_process`` filtering.
        """
        keywords = []
        for case in self._CASE_PATTERN.findall(raw_text or ""):
            parts = (part.strip() for part in self._KEYWORD_SEPARATORS.split(case))
            keywords.extend(part for part in parts if part)

        # Stratified sampling: category hits first.
        sampled = []
        for cat in (categories or [])[:3]:
            sampled.extend([k for k in keywords if cat in k][:5])

        # Merge and de-duplicate, keeping first-seen order.
        final = list(dict.fromkeys(sampled + keywords))
        return self.post_process(final[:60])[:50]

    def post_process(self, keywords: List[str]) -> List[str]:
        """Normalize, filter and priority-sort keywords.

        An entry is kept only if, after whitespace normalization, it is at
        least 4 characters long, contains no ``...``, and contains a digit or
        one of the characters 市/区/街/道. Kept entries that contain a priority
        term are moved to the front; relative order is otherwise preserved.

        NOTE(review): entries with neither a digit nor a location character are
        dropped rather than merely ranked lower - confirm this is intended.
        """
        processed = []
        for k in keywords:
            k = self._WHITESPACE.sub(' ', k).strip()
            if len(k) < 4 or '...' in k:
                continue
            if any(c.isdigit() for c in k) or self._LOCATION_CHARS.search(k):
                processed.append(k)
        # sorted() is stable even with reverse=True, so ties keep input order.
        return sorted(processed,
                      key=lambda x: any(t in x for t in self._PRIORITY_TERMS),
                      reverse=True)

    def save_to_tsv(self, data: List[dict], filename: str):
        """Write results to a tab-separated file with a ``title``/``contents`` header.

        Args:
            data: Items with a ``title`` string and a ``keywords`` list; the
                list is stored as a JSON array (non-ASCII kept readable).
            filename: Output path; overwritten if it exists.
        """
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, delimiter='\t')
            writer.writerow(['title', 'contents'])
            for item in data:
                writer.writerow([
                    item['title'],
                    json.dumps(item['keywords'], ensure_ascii=False),
                ])
def main() -> None:
    """Generate keywords for the sample domains and save them to a TSV file.

    The API key is taken from the ``DEEPSEEK_API_KEY`` environment variable;
    when it is unset the placeholder string is passed through unchanged.
    """
    import os  # local import: only this entry point needs it

    api_key = os.environ.get("DEEPSEEK_API_KEY", "sk-your-api-key-here")
    generator = GovDataGenerator(api_key)

    # Generate the data, one entry per domain.
    results = []
    for domain in ["保险", "城市管理", "房产"]:
        print(f"正在处理:{domain}")
        keywords = generator.generate_keywords(domain)
        results.append({
            "title": domain,
            "keywords": keywords,
        })
        print(f"生成完成,获得{len(keywords)}个关键词")

    # Save once, after every domain has been processed.
    generator.save_to_tsv(results, "government_data.tsv")
    print("数据已保存至 government_data.tsv")


if __name__ == "__main__":
    main()