# Original listing metadata (from the file viewer this was copied from): 206 lines, 7.8 KiB, Python.
import ast
import csv
import json
import os
import re
from typing import Dict, List, Tuple

from openai import APIError, AuthenticationError, OpenAI, RateLimitError


class GovDataGenerator:
    """Generate 12345-hotline style keyword lists for government service domains.

    Two prompts are sent to an OpenAI-compatible chat-completions endpoint
    (DeepSeek by default): one asks for a per-domain configuration (business
    categories plus one extra requirement), the other asks for example cases
    with keywords. The keywords are then parsed, filtered and ranked locally.
    """

    DEFAULT_BASE_URL = "https://api.deepseek.com"
    DEFAULT_MODEL = "deepseek-chat"
    # Environment variable consulted when no API key is passed to the constructor.
    API_KEY_ENV = "DEEPSEEK_API_KEY"
    # Keywords containing any of these terms are moved to the front of the result.
    PRIORITY_TERMS = ('社保', '医保', '房产证', '施工许可')

    # "关键词:" or "关键词:" (optionally markdown-bolded), captured up to the next
    # "[案例" marker, a blank line, or the end of the text.
    _KEYWORDS_RE = re.compile(
        r'关键词[*\s]*[::]\s*(.*?)(?=\n\s*\[案例|\n\s*\n|\Z)', re.DOTALL
    )
    # The prompt asks for Chinese commas, so accept ASCII ",", full-width ",",
    # the enumeration comma "、" and line breaks as separators.
    _KEYWORD_SEP_RE = re.compile(r'[,,、\n]')
    # NOTE(review): this is a character class, so any single 市/区/街/道 matches
    # (e.g. "渠道"); confirm whether "街道" was meant as a whole word.
    _LOCATION_RE = re.compile(r'[市区街道]')

    def __init__(self, api_key: str = "", base_url: str = DEFAULT_BASE_URL,
                 model: str = DEFAULT_MODEL):
        """Create the API client.

        Args:
            api_key: API key for the endpoint. When empty, the value of the
                ``DEEPSEEK_API_KEY`` environment variable is used instead.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model: Chat model name sent with every request.

        Raises:
            ValueError: if no key is given and the environment variable is unset.
        """
        # Keys must never live in source control: prefer the caller's value,
        # then the environment.
        api_key = api_key or os.environ.get(self.API_KEY_ENV, "")
        if not api_key:
            raise ValueError(f"未提供API密钥:请传入 api_key 或设置环境变量 {self.API_KEY_ENV}")
        self.model = model
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        # title -> (categories, requirements); only successful lookups are stored.
        self.config_cache: Dict[str, Tuple[List[str], str]] = {}
        self._init_prompts()

    def _init_prompts(self):
        """Initialize the two prompt templates (both are ``str.format`` templates with ``{title}``)."""
        self.base_prompt = (
            "请生成与【{title}】相关的5个真实政务12345业务案例,每个案例包含:\n"
            "1. 当事人(如张先生)+ 问题场景(如XX街道XX小区)\n"
            "2. 业务类型(如社保缴纳)+ 具体问题(如未回复)\n"
            "3. 涉及单位(如医保局)+ 时间要素(如8月17日)\n"
            "4. 证件编号(如370181XXXXXXXXXX)+ 政策条件(如连续缴费12个月)\n"
            "\n"
            "每个案例生成6-8个关键词,格式:\n"
            "[案例1] 关键词:关键词1,关键词2,...\n"
            "[案例2] 关键词:关键词A,关键词B,...\n"
            "请用中文逗号分隔,不要编号"
        )

        # Literal braces are doubled because this template goes through str.format.
        self.config_prompt = (
            "请根据政务领域【{title}】生成:\n"
            "1. 3-5个核心业务分类(categories)\n"
            "2. 1条特别生成要求(requirements)\n"
            "\n"
            "示例(保险领域):\n"
            'categories: ["医疗保险", "失业保险", "养老保险", "生育保险"]\n'
            'requirements: "需包含医保报销和生育津贴案例各1个"\n'
            "\n"
            '请用JSON格式返回:{{"categories": [], "requirements": ""}}'
        )

    def _call_gpt(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Uses the openai>=1.0 client interface (``client.chat.completions.create``).

        Keyword Args:
            temperature: sampling temperature (default 0.7).
            max_tokens: reply length cap (default 800).
            model: overrides the instance's model for this call.

        Returns:
            The reply content, or ``""`` when the API returns no content.

        Raises:
            Exception: with a user-facing (Chinese) message; the original SDK
                error is chained as ``__cause__``.
        """
        try:
            response = self.client.chat.completions.create(
                model=kwargs.get('model', self.model),
                temperature=kwargs.get('temperature', 0.7),
                max_tokens=kwargs.get('max_tokens', 800),
                messages=[
                    {"role": "system", "content": "You are a helpful assistant"},
                    {"role": "user", "content": prompt},
                ],
                stream=False,
            )
            # The SDK types message.content as optional; normalize to str.
            return response.choices[0].message.content or ""
        # RateLimitError/AuthenticationError subclass APIError, so they come first.
        except RateLimitError as e:
            raise Exception("请求超频,请等待后重试(错误码429)") from e
        except AuthenticationError as e:
            raise Exception("API密钥无效,请检查密钥是否正确") from e
        except APIError as e:
            raise Exception(f"API错误: {e.code} - {e.message}") from e
        except Exception as e:
            raise Exception(f"请求失败: {str(e)}") from e

    def generate_dynamic_config(self, title: str) -> Tuple[List[str], str]:
        """Return ``(categories, requirements)`` for ``title``, asking the model on first use.

        Successful results are cached per title in ``self.config_cache``. On any
        failure (API error or unusable response) a message is printed and
        ``([], "")`` is returned. Failures are not cached, so a later call retries.
        """
        if title in self.config_cache:
            return self.config_cache[title]

        try:
            prompt = self.config_prompt.format(title=title)
            raw_text = self._call_gpt(prompt, temperature=0.5, max_tokens=300)
            categories, requirements = self._validate_config(self.safe_parse_config(raw_text))
        except Exception as e:
            print(f"配置生成失败: {str(e)},使用默认配置")
            return [], ""

        self.config_cache[title] = (categories, requirements)
        return self.config_cache[title]

    @staticmethod
    def _validate_config(parsed) -> Tuple[List[str], str]:
        """Coerce a parsed config into ``(categories, requirements)``.

        Missing keys default to ``[]`` / ``""``. Non-string and blank category
        entries are dropped, so later ``in`` tests and ``join`` cannot fail.

        Raises:
            ValueError: if ``parsed`` is not a dict, or a field has the wrong type.
        """
        if not isinstance(parsed, dict):
            raise ValueError("配置格式错误")
        categories = parsed.get("categories", [])
        requirements = parsed.get("requirements", "")
        if not isinstance(categories, list) or not isinstance(requirements, str):
            raise ValueError("配置格式错误")
        cleaned = [c.strip() for c in categories if isinstance(c, str) and c.strip()]
        return cleaned, requirements.strip()

    def safe_parse_config(self, text: str) -> Dict:
        """Best-effort extraction of the ``{"categories": [...], "requirements": "..."}`` object.

        Tries the following in order and never raises:

        1. Strict JSON for the first ``{...}`` object in the text.
        2. Python-literal syntax (single quotes, trailing commas) on the outermost ``{...}`` span.
        3. Regex scraping of the two fields from the raw text.

        Returns:
            The parsed dict. This is ``{"categories": [], "requirements": ""}``
            when ``text`` is empty or nothing usable is found.
        """
        if not text:
            return {"categories": [], "requirements": ""}

        start = text.find("{")
        if start != -1:
            # raw_decode parses one complete object and ignores trailing chatter
            # (a closing code fence, an explanation after the JSON, ...).
            try:
                obj, _ = json.JSONDecoder().raw_decode(text, start)
            except json.JSONDecodeError:
                obj = None
            if isinstance(obj, dict):
                return obj

            match = re.search(r'\{.*\}', text, re.DOTALL)
            if match:
                # literal_eval only builds literals, so it is safe on model output.
                try:
                    obj = ast.literal_eval(match.group())
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    obj = None
                if isinstance(obj, dict):
                    return obj

        return self._scrape_config(text)

    @staticmethod
    def _scrape_config(text: str) -> Dict:
        """Last-resort regex extraction of ``categories`` / ``requirements`` from free text."""
        categories: List[str] = []
        cat_match = re.search(r'categories["\']?\s*[::]\s*\[(.*?)\]', text, re.DOTALL)
        if cat_match:
            inner = cat_match.group(1)
            # Prefer quoted items; fall back to comma-separated bare words.
            items = (re.findall(r'["\'“”]([^"\'“”]+)["\'“”]', inner)
                     or re.split(r'[,,、]', inner))
            # dict.fromkeys removes duplicates while keeping the model's order.
            categories = list(dict.fromkeys(s.strip() for s in items if s.strip()))

        req_match = re.search(r'requirements["\']?\s*[::]\s*["\'“]?([^"\'“”\n}]*)', text)
        requirements = req_match.group(1).strip() if req_match else ""
        return {"categories": categories, "requirements": requirements}

    def get_prompt(self, title: str) -> str:
        """Build the keyword-generation prompt for ``title``, fetching its config if needed."""
        categories, requirements = self.generate_dynamic_config(title)
        return self._build_prompt(title, categories, requirements)

    def _build_prompt(self, title: str, categories: List[str], requirements: str) -> str:
        """Fill ``base_prompt`` and append the optional requirement and up to 3 categories."""
        prompt = self.base_prompt.format(title=title)
        if requirements:
            prompt += f"\n特别要求:{requirements}"
        if categories:
            # Only signal truncation when categories were actually cut off.
            more = "..." if len(categories) > 3 else ""
            prompt += f"\n参考分类:{', '.join(categories[:3])}{more}"
        return prompt

    def generate_keywords(self, title: str) -> List[str]:
        """Run the whole pipeline for ``title`` and return the final keyword list.

        Any error is printed and an empty list is returned.
        """
        try:
            # Fetch the config once and reuse it. Going through get_prompt()
            # would call the config API a second time whenever the first call
            # failed (failures are not cached).
            categories, requirements = self.generate_dynamic_config(title)
            prompt = self._build_prompt(title, categories, requirements)
            raw_text = self._call_gpt(prompt)
            return self.process_response(raw_text, categories)
        except Exception as e:
            print(f"关键词生成失败: {str(e)}")
            return []

    def _extract_keywords(self, raw_text: str) -> List[str]:
        """Split every "关键词:" line of the model reply into individual keywords, in order."""
        return [
            k.strip()
            for case in self._KEYWORDS_RE.findall(raw_text or "")
            for k in self._KEYWORD_SEP_RE.split(case)
            if k.strip()
        ]

    def process_response(self, raw_text: str, categories: List[str], limit: int = 50) -> List[str]:
        """Turn the model reply into a de-duplicated, filtered, ranked keyword list.

        Keywords that mention one of the first 3 categories (at most 5 per
        category) are placed first, the rest follow in reply order. The
        combined list then goes through ``post_process`` and is capped at
        ``limit`` entries.
        """
        keywords = self._extract_keywords(raw_text)

        sampled: List[str] = []
        for cat in (categories or [])[:3]:
            sampled.extend([k for k in keywords if cat in k][:5])

        # dict.fromkeys de-duplicates while preserving first-seen order.
        ordered = list(dict.fromkeys(sampled + keywords))
        # Filter first, then cap: truncating before filtering would discard
        # valid keywords that happen to come late in the reply.
        return self.post_process(ordered)[:limit]

    def post_process(self, keywords: List[str]) -> List[str]:
        """Normalize, filter and rank keywords.

        - Collapses whitespace runs to one space and strips the ends.
        - Drops entries shorter than 4 characters or containing "..." / "…" (placeholders).
        - Keeps only entries that contain a digit or one of 市/区/街/道.
        - Drops duplicates created by the normalization.
        - Stable-sorts entries containing a ``PRIORITY_TERMS`` term to the front.
        """
        processed: List[str] = []
        seen = set()
        for raw in keywords:
            k = re.sub(r'\s+', ' ', raw).strip()
            if len(k) < 4 or '...' in k or '…' in k:
                continue
            if not (any(c.isdigit() for c in k) or self._LOCATION_RE.search(k)):
                continue
            if k in seen:
                continue
            seen.add(k)
            processed.append(k)

        # sorted() stays stable with reverse=True, so reply order is kept per group.
        return sorted(processed,
                      key=lambda x: any(t in x for t in self.PRIORITY_TERMS),
                      reverse=True)

    def save_to_tsv(self, data: List[dict], filename: str):
        """Write ``data`` to ``filename`` as a UTF-8 TSV with columns ``title`` and ``contents``.

        Each item needs ``title`` and ``keywords`` keys. ``keywords`` is stored
        as a JSON array with non-ASCII characters kept readable.
        """
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, delimiter='\t')
            writer.writerow(['title', 'contents'])
            writer.writerows(
                [item['title'], json.dumps(item['keywords'], ensure_ascii=False)]
                for item in data
            )


if __name__ == "__main__":
    # Usage example. The API key comes from the environment rather than being
    # written into the source; export DEEPSEEK_API_KEY before running.
    api_key = os.environ.get("DEEPSEEK_API_KEY", "")
    if not api_key:
        raise SystemExit("请先设置环境变量 DEEPSEEK_API_KEY")
    generator = GovDataGenerator(api_key)

    # Generate keywords for each domain.
    results = []
    for domain in ["保险", "城市管理", "房产"]:
        print(f"正在处理:{domain}")
        keywords = generator.generate_keywords(domain)
        results.append({"title": domain, "keywords": keywords})
        print(f"生成完成,获得{len(keywords)}个关键词")

    # Persist the results.
    output_path = "government_data.tsv"
    generator.save_to_tsv(results, output_path)
    print(f"数据已保存至 {output_path}")