feat: AI中间章节改写功能 + 滚动优化

This commit is contained in:
wangwuww111
2026-03-06 13:16:54 +08:00
parent 66d4bd60c1
commit bbdccfa843
9 changed files with 602 additions and 21 deletions

View File

@@ -2,8 +2,10 @@
AI服务封装模块
支持多种AI提供商DeepSeek, OpenAI, Claude, 通义千问
"""
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List
import httpx
import json
import re
class AIService:
def __init__(self):
@@ -13,11 +15,14 @@ class AIService:
self.enabled = settings.ai_service_enabled
self.provider = settings.ai_provider
print(f"[AI服务初始化] enabled={self.enabled}, provider={self.provider}")
# 根据提供商初始化配置
if self.provider == "deepseek":
self.api_key = settings.deepseek_api_key
self.base_url = settings.deepseek_base_url
self.model = settings.deepseek_model
print(f"[AI服务初始化] DeepSeek配置: api_key={self.api_key[:20] + '...' if self.api_key else 'None'}, base_url={self.base_url}, model={self.model}")
elif self.provider == "openai":
self.api_key = settings.openai_api_key
self.base_url = settings.openai_base_url
@@ -86,6 +91,351 @@ class AIService:
return None
async def rewrite_branch(
self,
story_title: str,
story_category: str,
path_history: List[Dict[str, str]],
current_content: str,
user_prompt: str
) -> Optional[Dict[str, Any]]:
"""
AI改写中间章节生成新的剧情分支
"""
print(f"\n[rewrite_branch] ========== 开始调用 ==========")
print(f"[rewrite_branch] story_title={story_title}, category={story_category}")
print(f"[rewrite_branch] user_prompt={user_prompt}")
print(f"[rewrite_branch] path_history长度={len(path_history)}")
print(f"[rewrite_branch] current_content长度={len(current_content)}")
print(f"[rewrite_branch] enabled={self.enabled}, api_key存在={bool(self.api_key)}")
if not self.enabled or not self.api_key:
print(f"[rewrite_branch] 服务未启用或API Key为空返回None")
return None
# 构建路径历史文本
path_text = ""
for i, item in enumerate(path_history, 1):
path_text += f"{i}段:{item.get('content', '')}\n"
if item.get('choice'):
path_text += f" → 用户选择:{item['choice']}\n"
# 构建系统提示词
system_prompt = """你是一个专业的互动故事续写专家。用户正在玩一个互动故事,想要在当前位置改变剧情走向。
【任务】
请从当前节点开始,创作新的剧情分支。新剧情的第一句话必须紧接当前节点最后的情节,仿佛是同一段故事的自然延续,不能有任何跳跃感。
【写作要求】
1. 第一个节点必须紧密衔接当前剧情,像是同一段话的下一句
2. 生成 4-6 个新节点,形成有层次的剧情发展(起承转合)
3. 每个节点内容 150-300 字,要分 2-3 个自然段(用\n\n分隔),包含:场景描写、人物对话、心理活动
4. 每个非结局节点有 2 个选项,选项要有明显的剧情差异和后果
5. 必须以结局收尾,结局内容要 200-400 字,分 2-3 段,有情感冲击力
6. 严格符合用户的改写意图,围绕用户指令展开剧情
7. 保持原故事的人物性格、语言风格和世界观
8. 对话要自然生动,描写要有画面感
【重要】内容分段示例:
"content": "他的声音在耳边响起,像是一阵温柔的风。\n\n\"我喜欢你。\"他说,目光坚定地看着你。\n\n你的心跳漏了一拍,一时间不知该如何回应。"
【输出格式】严格JSON不要有任何额外文字
{
"nodes": {
"branch_1": {
"content": "新剧情第一段150-300字...",
"speaker": "旁白",
"choices": [
{"text": "选项A5-15字", "nextNodeKey": "branch_2a"},
{"text": "选项B5-15字", "nextNodeKey": "branch_2b"}
]
},
"branch_2a": {
"content": "...",
"speaker": "旁白",
"choices": [...]
},
"branch_ending_good": {
"content": "好结局内容200-400字...",
"speaker": "旁白",
"is_ending": true,
"ending_name": "结局名称4-8字",
"ending_type": "good"
}
},
"entryNodeKey": "branch_1"
}"""
# 构建用户提示词
user_prompt_text = f"""【原故事信息】
故事标题:{story_title}
故事分类:{story_category}
【用户已走过的剧情】
{path_text}
【当前节点】
{current_content}
【用户改写指令】
{user_prompt}
请创作新的剧情分支输出JSON格式"""
print(f"[rewrite_branch] 提示词构建完成开始调用AI...")
print(f"[rewrite_branch] provider={self.provider}")
try:
result = None
if self.provider == "openai":
print(f"[rewrite_branch] 调用 OpenAI...")
result = await self._call_openai_long(system_prompt, user_prompt_text)
elif self.provider == "claude":
print(f"[rewrite_branch] 调用 Claude...")
result = await self._call_claude(f"{system_prompt}\n\n{user_prompt_text}")
elif self.provider == "qwen":
print(f"[rewrite_branch] 调用 Qwen...")
result = await self._call_qwen_long(system_prompt, user_prompt_text)
elif self.provider == "deepseek":
print(f"[rewrite_branch] 调用 DeepSeek...")
result = await self._call_deepseek_long(system_prompt, user_prompt_text)
print(f"[rewrite_branch] AI调用完成result存在={result is not None}")
if result and result.get("content"):
print(f"[rewrite_branch] AI返回内容长度={len(result.get('content', ''))}")
print(f"[rewrite_branch] AI返回内容前500字: {result.get('content', '')[:500]}")
# 解析JSON响应
parsed = self._parse_branch_json(result["content"])
print(f"[rewrite_branch] JSON解析结果: parsed存在={parsed is not None}")
if parsed:
parsed["tokens_used"] = result.get("tokens_used", 0)
print(f"[rewrite_branch] 成功! nodes数量={len(parsed.get('nodes', {}))}, tokens={parsed.get('tokens_used')}")
return parsed
else:
print(f"[rewrite_branch] JSON解析失败!")
else:
print(f"[rewrite_branch] AI返回为空或无content")
return None
except Exception as e:
print(f"[rewrite_branch] 异常: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
return None
def _parse_branch_json(self, content: str) -> Optional[Dict]:
"""解析AI返回的分支JSON"""
print(f"[_parse_branch_json] 开始解析,内容长度={len(content)}")
# 移除 markdown 代码块标记
clean_content = content.strip()
if clean_content.startswith('```'):
# 移除开头的 ```json 或 ```
clean_content = re.sub(r'^```(?:json)?\s*', '', clean_content)
# 移除结尾的 ```
clean_content = re.sub(r'\s*```$', '', clean_content)
try:
# 尝试直接解析
result = json.loads(clean_content)
print(f"[_parse_branch_json] 直接解析成功!")
return result
except json.JSONDecodeError as e:
print(f"[_parse_branch_json] 直接解析失败: {e}")
# 尝试提取JSON块
try:
# 匹配 { ... } 结构
brace_match = re.search(r'\{[\s\S]*\}', clean_content)
if brace_match:
json_str = brace_match.group(0)
print(f"[_parse_branch_json] 找到花括号块,尝试解析...")
try:
result = json.loads(json_str)
print(f"[_parse_branch_json] 花括号块解析成功!")
return result
except json.JSONDecodeError as e:
print(f"[_parse_branch_json] 花括号块解析失败: {e}")
# 打印错误位置附近的内容
error_pos = e.pos if hasattr(e, 'pos') else 0
start = max(0, error_pos - 100)
end = min(len(json_str), error_pos + 100)
print(f"[_parse_branch_json] 错误位置附近内容: ...{json_str[start:end]}...")
# 尝试修复不完整的 JSON
print(f"[_parse_branch_json] 尝试修复不完整的JSON...")
fixed_json = self._try_fix_incomplete_json(json_str)
if fixed_json:
print(f"[_parse_branch_json] JSON修复成功!")
return fixed_json
except Exception as e:
print(f"[_parse_branch_json] 提取解析异常: {e}")
print(f"[_parse_branch_json] 所有解析方法都失败了")
return None
def _try_fix_incomplete_json(self, json_str: str) -> Optional[Dict]:
"""尝试修复不完整的JSON被截断的情况"""
try:
# 找到已完成的节点,截断不完整的部分
# 查找最后一个完整的节点(以 } 结尾,后面跟着逗号或闭括号)
# 先找到 "nodes": { 的位置
nodes_match = re.search(r'"nodes"\s*:\s*\{', json_str)
if not nodes_match:
return None
nodes_start = nodes_match.end()
# 找所有完整的 branch 节点
branch_pattern = r'"branch_\w+"\s*:\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
branches = list(re.finditer(branch_pattern, json_str[nodes_start:]))
if not branches:
return None
# 取最后一个完整的节点的结束位置
last_complete_end = nodes_start + branches[-1].end()
# 构建修复后的 JSON
# 截取到最后一个完整节点,然后补全结构
truncated = json_str[:last_complete_end]
# 补全 JSON 结构
fixed = truncated + '\n },\n "entryNodeKey": "branch_1"\n}'
print(f"[_try_fix_incomplete_json] 修复后的JSON长度: {len(fixed)}")
result = json.loads(fixed)
# 验证结果结构
if "nodes" in result and len(result["nodes"]) > 0:
print(f"[_try_fix_incomplete_json] 修复后节点数: {len(result['nodes'])}")
return result
except Exception as e:
print(f"[_try_fix_incomplete_json] 修复失败: {e}")
return None
async def _call_deepseek_long(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
"""调用 DeepSeek API (长文本版本)"""
print(f"[_call_deepseek_long] 开始调用...")
print(f"[_call_deepseek_long] base_url={self.base_url}")
print(f"[_call_deepseek_long] model={self.model}")
url = f"{self.base_url}/chat/completions"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"temperature": 0.85,
"max_tokens": 6000 # 增加输出长度确保JSON完整
}
print(f"[_call_deepseek_long] system_prompt长度={len(system_prompt)}")
print(f"[_call_deepseek_long] user_prompt长度={len(user_prompt)}")
async with httpx.AsyncClient(timeout=300.0) as client:
try:
print(f"[_call_deepseek_long] 发送请求到 {url}...")
response = await client.post(url, headers=headers, json=data)
print(f"[_call_deepseek_long] 响应状态码: {response.status_code}")
response.raise_for_status()
result = response.json()
print(f"[_call_deepseek_long] 响应JSON keys: {result.keys()}")
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
tokens = result.get("usage", {}).get("total_tokens", 0)
print(f"[_call_deepseek_long] 成功! content长度={len(content)}, tokens={tokens}")
return {"content": content.strip(), "tokens_used": tokens}
else:
print(f"[_call_deepseek_long] 响应异常无choices: {result}")
return None
except httpx.HTTPStatusError as e:
print(f"[_call_deepseek_long] HTTP错误: {e.response.status_code} - {e.response.text}")
return None
except httpx.TimeoutException as e:
print(f"[_call_deepseek_long] 请求超时: {e}")
return None
except Exception as e:
print(f"[_call_deepseek_long] 其他错误: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
return None
async def _call_openai_long(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
"""调用OpenAI API (长文本版本)"""
url = f"{self.base_url}/chat/completions"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"temperature": 0.8,
"max_tokens": 2000
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(url, headers=headers, json=data)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
tokens = result["usage"]["total_tokens"]
return {"content": content.strip(), "tokens_used": tokens}
async def _call_qwen_long(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
"""调用通义千问API (长文本版本)"""
url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model,
"input": {
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
},
"parameters": {
"result_format": "message",
"temperature": 0.8,
"max_tokens": 2000
}
}
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(url, headers=headers, json=data)
response.raise_for_status()
result = response.json()
content = result["output"]["choices"][0]["message"]["content"]
tokens = result.get("usage", {}).get("total_tokens", 0)
return {"content": content.strip(), "tokens_used": tokens}
async def _call_openai(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
"""调用OpenAI API"""
url = f"{self.base_url}/chat/completions"