206 lines
7.8 KiB
Python
206 lines
7.8 KiB
Python
|
from openai import OpenAI, APIError, RateLimitError, AuthenticationError
|
|||
|
import csv
|
|||
|
import json
|
|||
|
import re
|
|||
|
import ast
|
|||
|
from typing import List, Dict, Tuple
|
|||
|
|
|||
|
|
|||
|
class GovDataGenerator:
    """Generate keyword lists for government 12345-hotline domains with a chat LLM.

    For each domain title the generator:
    (1) asks the model for a small config (business categories plus one special
        requirement);
    (2) builds a case-generation prompt from it;
    (3) asks the model for example cases with keywords;
    (4) extracts, de-duplicates, filters and ranks those keywords.
    """

    # OpenAI-compatible endpoint used when the caller does not pass ``base_url``.
    DEFAULT_BASE_URL = "https://api.deepseek.com"
    # Model name sent with every chat completion request.
    MODEL = "deepseek-chat"
    # Field names of the config reply; never treated as category values.
    _CONFIG_KEYS = ("categories", "requirements")
    # Keyword separators: ASCII comma, full-width comma, ideographic enumeration
    # comma and newline. The model may use any of them.
    _KEYWORD_SEP = re.compile(r"[,\uff0c\u3001\n]")
    # "关键词" followed by an ASCII or full-width colon. The capture runs up to
    # the next "[案例" marker, a blank line, or the end of the reply.
    _KEYWORD_LINE = re.compile(
        r"关键词[:\uff1a](.*?)(?=\n\s*\[案例|\n\s*\n|$)", re.DOTALL
    )

    def __init__(self, api_key: str, base_url: str = DEFAULT_BASE_URL):
        """Create the API client and load the prompt templates.

        Args:
            api_key: Key for the endpoint at ``base_url``, used exactly as given.
                Never hard-code credentials in this file.
            base_url: OpenAI-compatible API endpoint (DeepSeek by default).
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        # title -> (categories, requirements); filled by generate_dynamic_config.
        self.config_cache: Dict[str, Tuple[List[str], str]] = {}
        self._init_prompts()

    def _init_prompts(self):
        """Initialize all prompt templates (both use a ``{title}`` placeholder)."""
        # Case-generation prompt: 5 cases, 6-8 keywords each, one line per case.
        self.base_prompt = """请生成与【{title}】相关的5个真实政务12345业务案例,每个案例包含:
1. 当事人(如张先生)+ 问题场景(如XX街道XX小区)
2. 业务类型(如社保缴纳)+ 具体问题(如未回复)
3. 涉及单位(如医保局)+ 时间要素(如8月17日)
4. 证件编号(如370181XXXXXXXXXX)+ 政策条件(如连续缴费12个月)

每个案例生成6-8个关键词,格式:
[案例1] 关键词:关键词1,关键词2,...
[案例2] 关键词:关键词A,关键词B,...
请用中文逗号分隔,不要编号"""

        # Config prompt: asks for a JSON object with "categories" and
        # "requirements". The braces are doubled because of str.format.
        self.config_prompt = """请根据政务领域【{title}】生成:
1. 3-5个核心业务分类(categories)
2. 1条特别生成要求(requirements)

示例(保险领域):
categories: ["医疗保险", "失业保险", "养老保险", "生育保险"]
requirements: "需包含医保报销和生育津贴案例各1个"

请用JSON格式返回:{{"categories": [], "requirements": ""}}"""

    def _call_gpt(self, prompt: str, **kwargs) -> str:
        """Send *prompt* as a single user turn and return the reply text.

        Uses the openai>=1.0 client interface (``client.chat.completions.create``).

        Args:
            prompt: User message content.
            **kwargs: ``temperature`` (default 0.7) and ``max_tokens``
                (default 800); other keys are ignored.

        Returns:
            The first choice's message content, or "" when the content is None.

        Raises:
            Exception: Carries a user-facing (Chinese) message for rate limiting,
                an invalid key, other API errors, or any other failure. The
                underlying exception is chained as ``__cause__``.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.MODEL,
                temperature=kwargs.get('temperature', 0.7),
                max_tokens=kwargs.get('max_tokens', 800),
                messages=[
                    {"role": "system", "content": "You are a helpful assistant"},
                    {"role": "user", "content": prompt},
                ],
                stream=False,
            )
            # The SDK types ``message.content`` as Optional; normalise None to ""
            # so callers can always treat the result as text.
            return response.choices[0].message.content or ""
        # RateLimitError/AuthenticationError subclass APIError, so they must be
        # handled first.
        except RateLimitError as e:
            raise Exception("请求超频,请等待后重试(错误码429)") from e
        except AuthenticationError as e:
            raise Exception("API密钥无效,请检查密钥是否正确") from e
        except APIError as e:
            raise Exception(f"API错误: {e.code} - {e.message}") from e
        except Exception as e:
            raise Exception(f"请求失败: {str(e)}") from e

    def generate_dynamic_config(self, title: str) -> Tuple[List[str], str]:
        """Return ``(categories, requirements)`` for the domain *title*.

        Successful results are cached per title. On any failure (request error
        or malformed reply) a message is printed and ``([], "")`` is returned.
        Failures are not cached, so a later call retries.
        """
        if title in self.config_cache:
            return self.config_cache[title]

        try:
            prompt = self.config_prompt.format(title=title)
            raw_text = self._call_gpt(prompt, temperature=0.5, max_tokens=300)
            parsed = self.safe_parse_config(raw_text)

            # Validate the structure before trusting model output.
            if not isinstance(parsed, dict):
                raise ValueError("配置格式错误")
            categories = parsed.get("categories", [])
            requirements = parsed.get("requirements", "")
            if not isinstance(categories, list) or not isinstance(requirements, str):
                raise ValueError("配置格式错误")
            # Keep only non-empty string categories so later str.join and
            # substring tests cannot fail on numbers or nested values.
            categories = [c.strip() for c in categories if isinstance(c, str) and c.strip()]

            self.config_cache[title] = (categories, requirements.strip())
            return self.config_cache[title]

        except Exception as e:
            # Deliberate best-effort: keyword generation still works without a config.
            print(f"配置生成失败: {str(e)},使用默认配置")
            return [], ""

    def safe_parse_config(self, text: str) -> Dict:
        """Best-effort parse of the config reply into a dict.

        Works on the outermost ``{...}`` span of *text* and tries, in order:
        1. strict JSON;
        2. a Python-literal parse of that span;
        3. a regex scrape of the raw text.

        Returns ``{"categories": [], "requirements": ""}`` when *text* is empty
        or contains no braces.
        """
        empty = {"categories": [], "requirements": ""}
        if not text:
            return empty
        match = re.search(r'\{.*\}', text, re.DOTALL)
        if match is None:
            return empty
        json_str = match.group()

        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass

        try:
            # Handles Python-style dicts (single quotes, trailing commas).
            # literal_eval only evaluates literals, never code, so it is safe on
            # model output.
            return ast.literal_eval(json_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            pass

        return self._scrape_config(text)

    @classmethod
    def _scrape_config(cls, text: str) -> Dict:
        """Last-resort regex scrape of ``categories``/``requirements`` from *text*.

        Categories are the first 3 distinct double-quoted strings inside the
        ``categories [...]`` list, or anywhere in *text* when no such list is
        found, excluding the field names. Requirements is the value following
        ``requirements:``, or "" when there is none.
        """
        bracket = re.search(r'categories\W*\[(.*?)\]', text, re.DOTALL)
        scope = bracket.group(1) if bracket else text
        quoted = re.findall(r'"([^"]+)"', scope)
        # dict.fromkeys keeps first-seen order. A set would make the chosen
        # categories vary between runs because str hashing is randomised.
        categories = [q for q in dict.fromkeys(quoted) if q not in cls._CONFIG_KEYS][:3]

        req = re.search(r'requirements["\']?\s*[:\uff1a]\s*["\']?([^"\'\n}]*)', text)
        requirements = req.group(1).strip() if req else ""
        return {"categories": categories, "requirements": requirements}

    def get_prompt(self, title: str) -> str:
        """Build the case-generation prompt for *title*.

        May trigger a config request through ``generate_dynamic_config``.
        """
        categories, requirements = self.generate_dynamic_config(title)
        return self._build_prompt(title, categories, requirements)

    def _build_prompt(self, title: str, categories: List[str], requirements: str) -> str:
        """Fill the base prompt and append the requirement and up to 3 categories."""
        prompt = self.base_prompt.format(title=title)
        if requirements:
            prompt += f"\n特别要求:{requirements}"
        if categories:
            prompt += f"\n参考分类:{', '.join(categories[:3])}..."
        return prompt

    def generate_keywords(self, title: str) -> List[str]:
        """Run the full pipeline for *title* and return the final keyword list.

        Any failure is printed and yields ``[]``. This is the top-level
        best-effort boundary.
        """
        try:
            # Fetch the config once and reuse it for both the prompt and the
            # sampling step, so a failed config request is not repeated.
            categories, requirements = self.generate_dynamic_config(title)
            prompt = self._build_prompt(title, categories, requirements)
            raw_text = self._call_gpt(prompt)
            return self.process_response(raw_text, categories)
        except Exception as e:
            print(f"关键词生成失败: {str(e)}")
            return []

    def process_response(self, raw_text: str, categories: List[str]) -> List[str]:
        """Extract keywords from a case-generation reply and return at most 50.

        Args:
            raw_text: Model reply expected to contain lines such as
                ``[案例1] 关键词:词1,词2,...``.
            categories: Domain categories. For each of the first 3, up to 5
                keywords containing it are moved to the front before filtering.

        Returns:
            Keywords after order-preserving de-duplication. Only the first 60
            are considered; they pass through ``post_process`` and the result
            is truncated to 50.
        """
        keywords = [
            kw.strip()
            for chunk in self._KEYWORD_LINE.findall(raw_text or "")
            for kw in self._KEYWORD_SEP.split(chunk)
            if kw.strip()
        ]

        # Stratified sampling: pull category-matching keywords to the front.
        sampled: List[str] = []
        for cat in (categories or [])[:3]:
            sampled.extend([k for k in keywords if cat in k][:5])

        # Merge and de-duplicate while keeping first-seen order.
        merged = list(dict.fromkeys(sampled + keywords))
        return self.post_process(merged[:60])[:50]

    def post_process(self, keywords: List[str]) -> List[str]:
        """Normalise, filter and rank keywords.

        * Runs of whitespace collapse to one space; ends are stripped.
        * Entries shorter than 4 characters or containing "..." are dropped.
        * Only entries containing a digit or one of 市/区/街/道 are kept.
        * Entries containing a priority term move to the front. The sort is
          stable, so relative order is otherwise preserved.
        """
        processed = []
        for k in keywords:
            k = re.sub(r'\s+', ' ', k).strip()
            if len(k) < 4 or '...' in k:
                continue
            # NOTE(review): entries without a digit or location character are
            # discarded rather than ranked lower. Confirm filtering (not
            # boosting) is intended. ``[市区街道]`` is a character class (any one
            # of 市, 区, 街, 道), not the word "街道".
            if any(c.isdigit() for c in k) or re.search(r'[市区街道]', k):
                processed.append(k)

        priority_terms = ['社保', '医保', '房产证', '施工许可']
        return sorted(processed,
                      key=lambda x: any(t in x for t in priority_terms),
                      reverse=True)

    def save_to_tsv(self, data: List[dict], filename: str):
        """Write a ``title``/``contents`` TSV file (UTF-8).

        Each item in *data* needs ``title`` and ``keywords`` keys. The keyword
        list is stored in ``contents`` as a JSON array, with non-ASCII text
        kept as-is.
        """
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, delimiter='\t')
            writer.writerow(['title', 'contents'])
            writer.writerows(
                [item['title'], json.dumps(item['keywords'], ensure_ascii=False)]
                for item in data
            )
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
    import os

    # Usage example. The API key comes from the environment so no credential
    # lives in source; the placeholder is used only when the variable is unset.
    generator = GovDataGenerator(os.environ.get("DEEPSEEK_API_KEY", "sk-your-api-key-here"))

    # Generate keywords for each sample domain.
    results = []
    for domain in ["保险", "城市管理", "房产"]:
        print(f"正在处理:{domain}")
        keywords = generator.generate_keywords(domain)
        results.append({
            "title": domain,
            "keywords": keywords,
        })
        print(f"生成完成,获得{len(keywords)}个关键词")

    # Save the results.
    generator.save_to_tsv(results, "government_data.tsv")
    print("数据已保存至 government_data.tsv")
|