offline_data_model_pipline/data_generate/zw12345/yanpanbaogaozongjie_demo.py
2025-05-12 14:18:19 +08:00

206 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from openai import OpenAI, APIError, RateLimitError, AuthenticationError
import csv
import json
import re
import ast
from typing import List, Dict, Tuple
class GovDataGenerator:
    """Generate keyword lists for government "12345" hotline topics.

    The generator sends prompts to the ``deepseek-chat`` model through the
    OpenAI-compatible client, parses the ``关键词:`` lines of the reply into
    individual keywords, filters/sorts them and can dump the result to a TSV
    file.  Per-title configuration (categories + extra requirement) is fetched
    from the model once and cached in ``config_cache``.
    """

    # Separators accepted between keywords inside one case line.
    _KEYWORD_SEPARATORS = re.compile(r'[,,、]')

    def __init__(self, api_key: str):
        """Create the API client and load the prompt templates.

        Args:
            api_key: Secret key passed to the OpenAI-compatible client.  It is
                used as given; no key is embedded in the source.

        Raises:
            ValueError: If ``api_key`` is not a non-empty string.
        """
        if not isinstance(api_key, str) or not api_key.strip():
            raise ValueError("api_key must be a non-empty string")
        base_url = "https://api.deepseek.com"
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        # title -> (categories, requirements), filled by generate_dynamic_config
        self.config_cache: Dict[str, Tuple[List[str], str]] = {}
        self._init_prompts()

    def _init_prompts(self):
        """Initialize all prompt templates."""
        # The two "[案例N] 关键词:..." example lines must show the same
        # "关键词:" prefix and comma-separated list that process_response parses.
        self.base_prompt = """请生成与【{title}】相关的5个真实政务12345业务案例每个案例包含
1. 当事人(如张先生)+ 问题场景如XX街道XX小区
2. 业务类型(如社保缴纳)+ 具体问题(如未回复)
3. 涉及单位(如医保局)+ 时间要素如8月17日
4. 证件编号如370181XXXXXXXXXX+ 政策条件如连续缴费12个月
每个案例生成6-8个关键词格式
[案例1] 关键词:关键词1,关键词2...
[案例2] 关键词:关键词A,关键词B...
请用中文逗号分隔,不要编号"""
        # Doubled braces survive str.format() as literal JSON braces.
        self.config_prompt = """请根据政务领域【{title}】生成:
1. 3-5个核心业务分类categories
2. 1条特别生成要求requirements
示例(保险领域):
categories: ["医疗保险", "失业保险", "养老保险", "生育保险"]
requirements: "需包含医保报销和生育津贴案例各1个"
请用JSON格式返回{{"categories": [], "requirements": ""}}"""

    def _call_gpt(self, prompt: str, **kwargs) -> str:
        """Send one user prompt to the chat model and return the reply text.

        Args:
            prompt: Text sent as the user message.
            **kwargs: Optional ``temperature`` (default 0.7) and
                ``max_tokens`` (default 800).

        Returns:
            The content of the first choice, or ``""`` if it is empty.

        Raises:
            RuntimeError: On any failure; the original exception is chained
                as ``__cause__``.
        """
        try:
            response = self.client.chat.completions.create(
                model="deepseek-chat",
                temperature=kwargs.get('temperature', 0.7),
                max_tokens=kwargs.get('max_tokens', 800),
                messages=[
                    {"role": "system", "content": "You are a helpful assistant"},
                    {"role": "user", "content": prompt},
                ],
                stream=False,
            )
            return response.choices[0].message.content or ""
        except RateLimitError as e:
            raise RuntimeError("请求超频请等待后重试错误码429") from e
        except AuthenticationError as e:
            raise RuntimeError("API密钥无效请检查密钥是否正确") from e
        except APIError as e:
            raise RuntimeError(f"API错误: {e.code} - {e.message}") from e
        except Exception as e:
            raise RuntimeError(f"请求失败: {str(e)}") from e

    def generate_dynamic_config(self, title: str) -> Tuple[List[str], str]:
        """Return ``(categories, requirements)`` for a topic, using the cache.

        On a cache miss the model is asked for a configuration, the reply is
        parsed with ``safe_parse_config`` and validated.  Any failure is
        printed and the default ``([], "")`` is returned without caching, so
        a later call can retry.

        Args:
            title: Topic name substituted into the configuration prompt.

        Returns:
            A tuple of the category list and the requirement string.
        """
        if title in self.config_cache:
            return self.config_cache[title]
        try:
            prompt = self.config_prompt.format(title=title)
            raw_text = self._call_gpt(prompt, temperature=0.5, max_tokens=300)
            parsed = self.safe_parse_config(raw_text)
            categories = parsed.get("categories", [])
            requirements = parsed.get("requirements", "")
            # Validate the structure before anything is cached.
            if not isinstance(categories, list) or not isinstance(requirements, str):
                raise ValueError("配置格式错误")
            # Non-string entries would break ', '.join(); an empty string
            # would match every keyword in the `cat in k` test.
            categories = [c for c in categories if isinstance(c, str) and c.strip()]
            self.config_cache[title] = (categories, requirements)
            return self.config_cache[title]
        except Exception as e:
            print(f"配置生成失败: {str(e)},使用默认配置")
            return [], ""

    def safe_parse_config(self, text: str) -> Dict:
        """Parse a configuration reply into a dict without raising.

        Strategy, in order: strict JSON on the outermost ``{...}`` span, then
        ``ast.literal_eval`` on the same span with double quotes replaced by
        single quotes, then a regex salvage of quoted strings.

        Args:
            text: Raw reply text.

        Returns:
            The parsed dict, or a dict with ``"categories"`` and
            ``"requirements"`` keys (possibly empty) when parsing fails.
        """
        empty = {"categories": [], "requirements": ""}
        if not isinstance(text, str):
            return empty
        match = re.search(r'\{.*\}', text, re.DOTALL)
        if match is None:
            return empty
        json_str = match.group()

        parsed = None
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            try:
                parsed = ast.literal_eval(json_str.replace('"', "'"))
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                parsed = None
        except RecursionError:
            return empty

        if isinstance(parsed, dict):
            return parsed
        return self._salvage_config(text)

    @staticmethod
    def _salvage_config(text: str) -> Dict:
        """Best-effort extraction of categories/requirements from broken text."""
        requirements = ""
        _, found, tail = text.partition("requirements")
        if found:
            # Value that directly follows the key: optional quote, colon, value.
            value = re.match(r'\s*["\']?\s*[::]\s*["\']?([^"\'\n}]*)', tail)
            if value:
                requirements = value.group(1).strip()
        quoted = re.findall(r'"([^"]+)"', text)
        # Key names and the requirement text are not categories; dict.fromkeys
        # de-duplicates while keeping first-seen order (deterministic output).
        skip = {"categories", "requirements", requirements}
        categories = [q for q in dict.fromkeys(quoted) if q not in skip][:3]
        return {"categories": categories, "requirements": requirements}

    def _build_prompt(self, title: str, categories: List[str], requirements: str) -> str:
        """Render the keyword prompt from an already-known configuration."""
        prompt = self.base_prompt.format(title=title)
        if requirements:
            prompt += f"\n特别要求:{requirements}"
        if categories:
            prompt += f"\n参考分类:{', '.join(categories[:3])}..."
        return prompt

    def get_prompt(self, title: str) -> str:
        """Build the keyword prompt for a topic.

        Args:
            title: Topic name.

        Returns:
            The base prompt, followed by the requirement line and the first
            three categories when the topic configuration provides them.
        """
        categories, requirements = self.generate_dynamic_config(title)
        return self._build_prompt(title, categories, requirements)

    def generate_keywords(self, title: str) -> List[str]:
        """Run the full pipeline for one topic.

        Args:
            title: Topic name.

        Returns:
            The processed keyword list, or ``[]`` if any step fails (the
            error is printed).
        """
        try:
            # Fetch the configuration once and reuse it for both the prompt
            # and the sampling step, so a failed lookup is not repeated.
            categories, requirements = self.generate_dynamic_config(title)
            prompt = self._build_prompt(title, categories, requirements)
            raw_text = self._call_gpt(prompt)
            return self.process_response(raw_text, categories)
        except Exception as e:
            print(f"关键词生成失败: {str(e)}")
            return []

    def process_response(self, raw_text: str, categories: List[str]) -> List[str]:
        """Turn the model reply into a de-duplicated, filtered keyword list.

        Args:
            raw_text: Reply text containing ``关键词:`` lines.
            categories: Category names; keywords containing one of the first
                three categories (at most five per category) are placed first.

        Returns:
            At most 50 keywords after ``post_process``.
        """
        # Each case runs from the "关键词" marker (full-width or ASCII colon)
        # up to the next case header, a blank line, or the end of the text.
        cases = re.findall(r'关键词[::](.*?)(?=\n\[案例|\n\n|$)', raw_text, re.DOTALL)
        keywords = [
            part.strip()
            for case in cases
            for part in self._KEYWORD_SEPARATORS.split(case)
            if part.strip()
        ]

        # Stratified sampling: category hits go to the front of the list.
        sampled: List[str] = []
        for cat in (categories or [])[:3]:
            sampled.extend([k for k in keywords if cat in k][:5])

        # Merge and de-duplicate, keeping first-seen order.
        final = list(dict.fromkeys(sampled + keywords))
        return self.post_process(final[:60])[:50]

    def post_process(self, keywords: List[str]) -> List[str]:
        """Normalize, filter and prioritize keywords.

        Whitespace runs are collapsed to one space.  An entry is kept only if
        it is at least 4 characters long, does not contain ``...``, and holds
        a digit or one of the characters 市/区/街/道.  Entries containing a
        priority term are moved to the front; the sort is stable, so relative
        order is otherwise preserved.

        Args:
            keywords: Candidate keywords.

        Returns:
            The filtered, prioritized list.
        """
        processed = []
        for k in keywords:
            k = re.sub(r'\s+', ' ', k).strip()
            if len(k) < 4 or '...' in k:
                continue
            if any(c.isdigit() for c in k) or re.search(r'[市区街道]', k):
                processed.append(k)
        priority_terms = ['社保', '医保', '房产证', '施工许可']
        return sorted(
            processed,
            key=lambda x: any(t in x for t in priority_terms),
            reverse=True,
        )

    def save_to_tsv(self, data: List[dict], filename: str):
        """Write results to a tab-separated file.

        Args:
            data: Items with a ``'title'`` and a ``'keywords'`` entry; the
                keywords are stored as a JSON string with non-ASCII text
                left unescaped.
            filename: Output path; the file is overwritten.
        """
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, delimiter='\t')
            writer.writerow(['title', 'contents'])
            writer.writerows(
                [item['title'], json.dumps(item['keywords'], ensure_ascii=False)]
                for item in data
            )
if __name__ == "__main__":
    # Usage example: generate keywords for a few domains and save them.
    import os

    # Take the key from the environment so a real secret never has to be
    # written into this file; the placeholder is used when it is unset.
    api_key = os.environ.get("DEEPSEEK_API_KEY") or "sk-your-api-key-here"
    generator = GovDataGenerator(api_key)

    # Generate data
    results = []
    for domain in ["保险", "城市管理", "房产"]:
        print(f"正在处理:{domain}")
        keywords = generator.generate_keywords(domain)
        results.append({
            "title": domain,
            "keywords": keywords,
        })
        print(f"生成完成,获得{len(keywords)}个关键词")

    # Save results
    generator.save_to_tsv(results, "government_data.tsv")
    print("数据已保存至 government_data.tsv")