"""
AI服务封装模块
支持多种AI提供商:DeepSeek, OpenAI, Claude, 通义千问
"""
|
||
from typing import Optional, Dict, Any, List
|
||
import httpx
|
||
import json
|
||
import re
|
||
|
||
class AIService:
    """Thin async wrapper around several LLM chat-completion providers.

    Credentials, endpoint and model are resolved once at construction time:
    DeepSeek / OpenAI from application settings, Claude / Qwen from
    environment variables.
    """

    def __init__(self):
        """Resolve provider configuration from app settings / environment."""
        import os

        from app.config import get_settings
        settings = get_settings()

        self.enabled = settings.ai_service_enabled
        self.provider = settings.ai_provider

        print(f"[AI服务初始化] enabled={self.enabled}, provider={self.provider}")

        # Per-provider credentials, endpoint and default model.
        if self.provider == "deepseek":
            self.api_key = settings.deepseek_api_key
            self.base_url = settings.deepseek_base_url
            self.model = settings.deepseek_model
            # Log only a short prefix of the key: the previous revision
            # printed 20 characters of the secret, which leaks far too much
            # entropy into application logs.
            masked_key = self.api_key[:6] + '...' if self.api_key else 'None'
            print(f"[AI服务初始化] DeepSeek配置: api_key={masked_key}, base_url={self.base_url}, model={self.model}")
        elif self.provider == "openai":
            self.api_key = settings.openai_api_key
            self.base_url = settings.openai_base_url
            self.model = "gpt-3.5-turbo"
        elif self.provider == "claude":
            # Claude credentials come from the environment, not app settings.
            self.api_key = os.getenv("CLAUDE_API_KEY", "")
            self.base_url = "https://api.anthropic.com"
            self.model = "claude-3-haiku-20240307"
        elif self.provider == "qwen":
            self.api_key = os.getenv("DASHSCOPE_API_KEY", "")
            self.base_url = "https://dashscope.aliyuncs.com/api/v1"
            self.model = "qwen-plus"
        else:
            # Unknown provider: leave the service constructible but unusable
            # (every public method early-returns when api_key is falsy).
            self.api_key = None
            self.base_url = None
            self.model = None
async def rewrite_ending(
|
||
self,
|
||
story_title: str,
|
||
story_category: str,
|
||
ending_name: str,
|
||
ending_content: str,
|
||
user_prompt: str
|
||
) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
AI改写结局
|
||
:return: {"content": str, "tokens_used": int} 或 None
|
||
"""
|
||
if not self.enabled or not self.api_key:
|
||
return None
|
||
|
||
# 构建Prompt
|
||
system_prompt = """你是一个专业的互动故事创作专家。根据用户的改写指令,重新创作故事结局。
|
||
要求:
|
||
1. 保持原故事的世界观和人物性格
|
||
2. 结局要有张力和情感冲击
|
||
3. 结局内容字数控制在200-400字
|
||
4. 为新结局取一个4-8字的新名字,体现改写后的剧情走向
|
||
5. 输出格式必须是JSON:{"ending_name": "新结局名称", "content": "结局内容"}"""
|
||
|
||
user_prompt_text = f"""故事标题:{story_title}
|
||
故事分类:{story_category}
|
||
原结局名称:{ending_name}
|
||
原结局内容:{ending_content[:500]}
|
||
---
|
||
用户改写指令:{user_prompt}
|
||
---
|
||
请创作新的结局(输出JSON格式):"""
|
||
|
||
try:
|
||
if self.provider == "openai":
|
||
return await self._call_openai(system_prompt, user_prompt_text)
|
||
elif self.provider == "claude":
|
||
return await self._call_claude(user_prompt_text)
|
||
elif self.provider == "qwen":
|
||
return await self._call_qwen(system_prompt, user_prompt_text)
|
||
elif self.provider == "deepseek":
|
||
return await self._call_deepseek(system_prompt, user_prompt_text)
|
||
except Exception as e:
|
||
print(f"AI调用失败:{e}")
|
||
return None
|
||
|
||
return None
|
||
|
||
    async def rewrite_branch(
        self,
        story_title: str,
        story_category: str,
        path_history: List[Dict[str, str]],
        current_content: str,
        user_prompt: str
    ) -> Optional[Dict[str, Any]]:
        """Rewrite a mid-story chapter: generate a whole new plot branch.

        :param story_title: title of the story being played
        :param story_category: story category label
        :param path_history: nodes the player already passed; each item is a
            dict with "content" and optionally "choice"
        :param current_content: text of the node the rewrite starts from
        :param user_prompt: the player's rewrite instruction
        :return: parsed branch dict with "nodes", "entryNodeKey" and
            "tokens_used" added, or None on any failure
        """
        print(f"\n[rewrite_branch] ========== 开始调用 ==========")
        print(f"[rewrite_branch] story_title={story_title}, category={story_category}")
        print(f"[rewrite_branch] user_prompt={user_prompt}")
        print(f"[rewrite_branch] path_history长度={len(path_history)}")
        print(f"[rewrite_branch] current_content长度={len(current_content)}")
        print(f"[rewrite_branch] enabled={self.enabled}, api_key存在={bool(self.api_key)}")

        if not self.enabled or not self.api_key:
            print(f"[rewrite_branch] 服务未启用或API Key为空,返回None")
            return None

        # Flatten the traversal history into numbered text for the prompt.
        path_text = ""
        for i, item in enumerate(path_history, 1):
            path_text += f"第{i}段:{item.get('content', '')}\n"
            if item.get('choice'):
                path_text += f"  → 用户选择:{item['choice']}\n"

        # System prompt: fixes the task, the writing rules, a paragraphing
        # example, and the exact JSON schema the model must emit.
        system_prompt = """你是一个专业的互动故事续写专家。用户正在玩一个互动故事,想要在当前位置改变剧情走向。

【任务】
请从当前节点开始,创作新的剧情分支。新剧情的第一句话必须紧接当前节点最后的情节,仿佛是同一段故事的自然延续,不能有任何跳跃感。

【写作要求】
1. 第一个节点必须紧密衔接当前剧情,像是同一段话的下一句
2. 生成 4-6 个新节点,形成有层次的剧情发展(起承转合)
3. 每个节点内容 150-300 字,要分 2-3 个自然段(用\n\n分隔),包含:场景描写、人物对话、心理活动
4. 每个非结局节点有 2 个选项,选项要有明显的剧情差异和后果
5. 必须以结局收尾,结局内容要 200-400 字,分 2-3 段,有情感冲击力
6. 严格符合用户的改写意图,围绕用户指令展开剧情
7. 保持原故事的人物性格、语言风格和世界观
8. 对话要自然生动,描写要有画面感

【重要】内容分段示例:
"content": "他的声音在耳边响起,像是一阵温柔的风。\n\n\"我喜欢你。\"他说,目光坚定地看着你。\n\n你的心跳漏了一拍,一时间不知该如何回应。"

【输出格式】(严格JSON,不要有任何额外文字)
{
"nodes": {
"branch_1": {
"content": "新剧情第一段(150-300字)...",
"speaker": "旁白",
"choices": [
{"text": "选项A(5-15字)", "nextNodeKey": "branch_2a"},
{"text": "选项B(5-15字)", "nextNodeKey": "branch_2b"}
]
},
"branch_2a": {
"content": "...",
"speaker": "旁白",
"choices": [...]
},
"branch_ending_good": {
"content": "好结局内容(200-400字)...",
"speaker": "旁白",
"is_ending": true,
"ending_name": "结局名称(4-8字)",
"ending_type": "good"
}
},
"entryNodeKey": "branch_1"
}"""

        # User prompt: story info, traversal history, current node, and the
        # player's rewrite instruction.
        user_prompt_text = f"""【原故事信息】
故事标题:{story_title}
故事分类:{story_category}

【用户已走过的剧情】
{path_text}

【当前节点】
{current_content}

【用户改写指令】
{user_prompt}

请创作新的剧情分支(输出JSON格式):"""

        print(f"[rewrite_branch] 提示词构建完成,开始调用AI...")
        print(f"[rewrite_branch] provider={self.provider}")

        try:
            result = None
            if self.provider == "openai":
                print(f"[rewrite_branch] 调用 OpenAI...")
                result = await self._call_openai_long(system_prompt, user_prompt_text)
            elif self.provider == "claude":
                print(f"[rewrite_branch] 调用 Claude...")
                # _call_claude takes one combined prompt string.
                result = await self._call_claude(f"{system_prompt}\n\n{user_prompt_text}")
            elif self.provider == "qwen":
                print(f"[rewrite_branch] 调用 Qwen...")
                result = await self._call_qwen_long(system_prompt, user_prompt_text)
            elif self.provider == "deepseek":
                print(f"[rewrite_branch] 调用 DeepSeek...")
                result = await self._call_deepseek_long(system_prompt, user_prompt_text)

            print(f"[rewrite_branch] AI调用完成,result存在={result is not None}")

            if result and result.get("content"):
                print(f"[rewrite_branch] AI返回内容长度={len(result.get('content', ''))}")
                print(f"[rewrite_branch] AI返回内容前500字: {result.get('content', '')[:500]}")

                # Parse the model's JSON (with fence-stripping and repair
                # fallbacks inside _parse_branch_json).
                parsed = self._parse_branch_json(result["content"])
                print(f"[rewrite_branch] JSON解析结果: parsed存在={parsed is not None}")

                if parsed:
                    parsed["tokens_used"] = result.get("tokens_used", 0)
                    print(f"[rewrite_branch] 成功! nodes数量={len(parsed.get('nodes', {}))}, tokens={parsed.get('tokens_used')}")
                    return parsed
                else:
                    print(f"[rewrite_branch] JSON解析失败!")
            else:
                print(f"[rewrite_branch] AI返回为空或无content")

            return None
        except Exception as e:
            print(f"[rewrite_branch] 异常: {type(e).__name__}: {e}")
            import traceback
            traceback.print_exc()
            return None
def _parse_branch_json(self, content: str) -> Optional[Dict]:
|
||
"""解析AI返回的分支JSON"""
|
||
print(f"[_parse_branch_json] 开始解析,内容长度={len(content)}")
|
||
|
||
# 移除 markdown 代码块标记
|
||
clean_content = content.strip()
|
||
if clean_content.startswith('```'):
|
||
# 移除开头的 ```json 或 ```
|
||
clean_content = re.sub(r'^```(?:json)?\s*', '', clean_content)
|
||
# 移除结尾的 ```
|
||
clean_content = re.sub(r'\s*```$', '', clean_content)
|
||
|
||
try:
|
||
# 尝试直接解析
|
||
result = json.loads(clean_content)
|
||
print(f"[_parse_branch_json] 直接解析成功!")
|
||
return result
|
||
except json.JSONDecodeError as e:
|
||
print(f"[_parse_branch_json] 直接解析失败: {e}")
|
||
|
||
# 尝试提取JSON块
|
||
try:
|
||
# 匹配 { ... } 结构
|
||
brace_match = re.search(r'\{[\s\S]*\}', clean_content)
|
||
if brace_match:
|
||
json_str = brace_match.group(0)
|
||
print(f"[_parse_branch_json] 找到花括号块,尝试解析...")
|
||
|
||
try:
|
||
result = json.loads(json_str)
|
||
print(f"[_parse_branch_json] 花括号块解析成功!")
|
||
return result
|
||
except json.JSONDecodeError as e:
|
||
print(f"[_parse_branch_json] 花括号块解析失败: {e}")
|
||
# 打印错误位置附近的内容
|
||
error_pos = e.pos if hasattr(e, 'pos') else 0
|
||
start = max(0, error_pos - 100)
|
||
end = min(len(json_str), error_pos + 100)
|
||
print(f"[_parse_branch_json] 错误位置附近内容: ...{json_str[start:end]}...")
|
||
|
||
# 尝试修复不完整的 JSON
|
||
print(f"[_parse_branch_json] 尝试修复不完整的JSON...")
|
||
fixed_json = self._try_fix_incomplete_json(json_str)
|
||
if fixed_json:
|
||
print(f"[_parse_branch_json] JSON修复成功!")
|
||
return fixed_json
|
||
|
||
except Exception as e:
|
||
print(f"[_parse_branch_json] 提取解析异常: {e}")
|
||
|
||
print(f"[_parse_branch_json] 所有解析方法都失败了")
|
||
return None
|
||
|
||
def _try_fix_incomplete_json(self, json_str: str) -> Optional[Dict]:
|
||
"""尝试修复不完整的JSON(被截断的情况)"""
|
||
try:
|
||
# 找到已完成的节点,截断不完整的部分
|
||
# 查找最后一个完整的节点(以 } 结尾,后面跟着逗号或闭括号)
|
||
|
||
# 先找到 "nodes": { 的位置
|
||
nodes_match = re.search(r'"nodes"\s*:\s*\{', json_str)
|
||
if not nodes_match:
|
||
return None
|
||
|
||
nodes_start = nodes_match.end()
|
||
|
||
# 找所有完整的 branch 节点
|
||
branch_pattern = r'"branch_\w+"\s*:\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
|
||
branches = list(re.finditer(branch_pattern, json_str[nodes_start:]))
|
||
|
||
if not branches:
|
||
return None
|
||
|
||
# 取最后一个完整的节点的结束位置
|
||
last_complete_end = nodes_start + branches[-1].end()
|
||
|
||
# 构建修复后的 JSON
|
||
# 截取到最后一个完整节点,然后补全结构
|
||
truncated = json_str[:last_complete_end]
|
||
|
||
# 补全 JSON 结构
|
||
fixed = truncated + '\n },\n "entryNodeKey": "branch_1"\n}'
|
||
|
||
print(f"[_try_fix_incomplete_json] 修复后的JSON长度: {len(fixed)}")
|
||
result = json.loads(fixed)
|
||
|
||
# 验证结果结构
|
||
if "nodes" in result and len(result["nodes"]) > 0:
|
||
print(f"[_try_fix_incomplete_json] 修复后节点数: {len(result['nodes'])}")
|
||
return result
|
||
|
||
except Exception as e:
|
||
print(f"[_try_fix_incomplete_json] 修复失败: {e}")
|
||
|
||
return None
|
||
|
||
async def _call_deepseek_long(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||
"""调用 DeepSeek API (长文本版本)"""
|
||
print(f"[_call_deepseek_long] 开始调用...")
|
||
print(f"[_call_deepseek_long] base_url={self.base_url}")
|
||
print(f"[_call_deepseek_long] model={self.model}")
|
||
|
||
url = f"{self.base_url}/chat/completions"
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
data = {
|
||
"model": self.model,
|
||
"messages": [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
],
|
||
"temperature": 0.85,
|
||
"max_tokens": 6000 # 增加输出长度,确保JSON完整
|
||
}
|
||
|
||
print(f"[_call_deepseek_long] system_prompt长度={len(system_prompt)}")
|
||
print(f"[_call_deepseek_long] user_prompt长度={len(user_prompt)}")
|
||
|
||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||
try:
|
||
print(f"[_call_deepseek_long] 发送请求到 {url}...")
|
||
response = await client.post(url, headers=headers, json=data)
|
||
print(f"[_call_deepseek_long] 响应状态码: {response.status_code}")
|
||
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
print(f"[_call_deepseek_long] 响应JSON keys: {result.keys()}")
|
||
|
||
if "choices" in result and len(result["choices"]) > 0:
|
||
content = result["choices"][0]["message"]["content"]
|
||
tokens = result.get("usage", {}).get("total_tokens", 0)
|
||
print(f"[_call_deepseek_long] 成功! content长度={len(content)}, tokens={tokens}")
|
||
return {"content": content.strip(), "tokens_used": tokens}
|
||
else:
|
||
print(f"[_call_deepseek_long] 响应异常,无choices: {result}")
|
||
return None
|
||
except httpx.HTTPStatusError as e:
|
||
print(f"[_call_deepseek_long] HTTP错误: {e.response.status_code} - {e.response.text}")
|
||
return None
|
||
except httpx.TimeoutException as e:
|
||
print(f"[_call_deepseek_long] 请求超时: {e}")
|
||
return None
|
||
except Exception as e:
|
||
print(f"[_call_deepseek_long] 其他错误: {type(e).__name__}: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
async def _call_openai_long(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||
"""调用OpenAI API (长文本版本)"""
|
||
url = f"{self.base_url}/chat/completions"
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
data = {
|
||
"model": self.model,
|
||
"messages": [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
],
|
||
"temperature": 0.8,
|
||
"max_tokens": 2000
|
||
}
|
||
|
||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||
response = await client.post(url, headers=headers, json=data)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
content = result["choices"][0]["message"]["content"]
|
||
tokens = result["usage"]["total_tokens"]
|
||
|
||
return {"content": content.strip(), "tokens_used": tokens}
|
||
|
||
async def _call_qwen_long(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||
"""调用通义千问API (长文本版本)"""
|
||
url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
data = {
|
||
"model": self.model,
|
||
"input": {
|
||
"messages": [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
]
|
||
},
|
||
"parameters": {
|
||
"result_format": "message",
|
||
"temperature": 0.8,
|
||
"max_tokens": 2000
|
||
}
|
||
}
|
||
|
||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||
response = await client.post(url, headers=headers, json=data)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
content = result["output"]["choices"][0]["message"]["content"]
|
||
tokens = result.get("usage", {}).get("total_tokens", 0)
|
||
|
||
return {"content": content.strip(), "tokens_used": tokens}
|
||
|
||
async def _call_openai(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||
"""调用OpenAI API"""
|
||
url = f"{self.base_url}/chat/completions"
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
data = {
|
||
"model": self.model,
|
||
"messages": [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
],
|
||
"temperature": 0.8,
|
||
"max_tokens": 500
|
||
}
|
||
|
||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||
response = await client.post(url, headers=headers, json=data)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
content = result["choices"][0]["message"]["content"]
|
||
tokens = result["usage"]["total_tokens"]
|
||
|
||
return {"content": content.strip(), "tokens_used": tokens}
|
||
|
||
async def _call_claude(self, prompt: str) -> Optional[Dict]:
|
||
"""调用Claude API"""
|
||
url = "https://api.anthropic.com/v1/messages"
|
||
headers = {
|
||
"x-api-key": self.api_key,
|
||
"anthropic-version": "2023-06-01",
|
||
"content-type": "application/json"
|
||
}
|
||
data = {
|
||
"model": self.model,
|
||
"max_tokens": 1024,
|
||
"messages": [{"role": "user", "content": prompt}]
|
||
}
|
||
|
||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||
response = await client.post(url, headers=headers, json=data)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
content = result["content"][0]["text"]
|
||
tokens = result.get("usage", {}).get("output_tokens", 0)
|
||
|
||
return {"content": content.strip(), "tokens_used": tokens}
|
||
|
||
async def _call_qwen(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||
"""调用通义千问API"""
|
||
url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
data = {
|
||
"model": self.model,
|
||
"input": {
|
||
"messages": [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
]
|
||
},
|
||
"parameters": {
|
||
"result_format": "message",
|
||
"temperature": 0.8
|
||
}
|
||
}
|
||
|
||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||
response = await client.post(url, headers=headers, json=data)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
content = result["output"]["choices"][0]["message"]["content"]
|
||
tokens = result.get("usage", {}).get("total_tokens", 0)
|
||
|
||
return {"content": content.strip(), "tokens_used": tokens}
|
||
|
||
async def _call_deepseek(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||
"""调用 DeepSeek API"""
|
||
url = f"{self.base_url}/chat/completions"
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
data = {
|
||
"model": self.model,
|
||
"messages": [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
],
|
||
"temperature": 0.8,
|
||
"max_tokens": 500
|
||
}
|
||
|
||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||
try:
|
||
response = await client.post(url, headers=headers, json=data)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
if "choices" in result and len(result["choices"]) > 0:
|
||
content = result["choices"][0]["message"]["content"]
|
||
tokens = result.get("usage", {}).get("total_tokens", 0)
|
||
|
||
return {"content": content.strip(), "tokens_used": tokens}
|
||
else:
|
||
print(f"DeepSeek API 返回异常:{result}")
|
||
return None
|
||
except httpx.HTTPStatusError as e:
|
||
print(f"DeepSeek HTTP 错误:{e.response.status_code} - {e.response.text}")
|
||
return None
|
||
except Exception as e:
|
||
print(f"DeepSeek 调用失败:{e}")
|
||
return None
|
||
|
||
|
||
# Module-level singleton shared by importers. NOTE: constructed eagerly at
# import time, which triggers the settings load in AIService.__init__.
ai_service = AIService()