feat: 实装AI改写结局功能 - 接入DeepSeek API - AI动态生成新结局名称 - 新增rewrite类型结局样式 - 修复请求超时问题
This commit is contained in:
@@ -19,10 +19,23 @@ class Settings(BaseSettings):
|
||||
server_port: int = 3000
|
||||
debug: bool = True
|
||||
|
||||
# AI服务配置(预留)
|
||||
# AI 服务配置
|
||||
ai_service_enabled: bool = True
|
||||
ai_provider: str = "deepseek"
|
||||
|
||||
# DeepSeek 配置
|
||||
deepseek_api_key: str = ""
|
||||
deepseek_base_url: str = "https://api.deepseek.com/v1"
|
||||
deepseek_model: str = "deepseek-chat"
|
||||
|
||||
# OpenAI 配置(备用)
|
||||
openai_api_key: str = ""
|
||||
openai_base_url: str = "https://api.openai.com/v1"
|
||||
|
||||
# 微信小游戏配置(预留)
|
||||
wx_appid: str = ""
|
||||
wx_secret: str = ""
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return f"mysql+aiomysql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}"
|
||||
|
||||
@@ -187,6 +187,9 @@ async def toggle_like(story_id: int, request: LikeRequest, db: AsyncSession = De
|
||||
@router.post("/{story_id}/rewrite")
|
||||
async def ai_rewrite_ending(story_id: int, request: RewriteRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""AI改写结局"""
|
||||
import json
|
||||
import re
|
||||
|
||||
if not request.prompt:
|
||||
raise HTTPException(status_code=400, detail="请输入改写指令")
|
||||
|
||||
@@ -194,7 +197,54 @@ async def ai_rewrite_ending(story_id: int, request: RewriteRequest, db: AsyncSes
|
||||
result = await db.execute(select(Story).where(Story.id == story_id))
|
||||
story = result.scalar_one_or_none()
|
||||
|
||||
# 模拟AI生成(后续替换为真实API调用)
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="故事不存在")
|
||||
|
||||
# 调用 AI 服务
|
||||
from app.services.ai import ai_service
|
||||
|
||||
ai_result = await ai_service.rewrite_ending(
|
||||
story_title=story.title,
|
||||
story_category=story.category or "未知",
|
||||
ending_name=request.ending_name or "未知结局",
|
||||
ending_content=request.ending_content or "",
|
||||
user_prompt=request.prompt
|
||||
)
|
||||
|
||||
if ai_result and ai_result.get("content"):
|
||||
content = ai_result["content"]
|
||||
ending_name = f"{request.ending_name}(AI改写)"
|
||||
|
||||
# 尝试解析 JSON 格式的返回
|
||||
try:
|
||||
# 提取 JSON 部分
|
||||
json_match = re.search(r'\{[^{}]*"ending_name"[^{}]*"content"[^{}]*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
parsed = json.loads(json_match.group())
|
||||
ending_name = parsed.get("ending_name", ending_name)
|
||||
content = parsed.get("content", content)
|
||||
else:
|
||||
# 尝试直接解析整个内容
|
||||
parsed = json.loads(content)
|
||||
ending_name = parsed.get("ending_name", ending_name)
|
||||
content = parsed.get("content", content)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
# 解析失败,使用原始内容
|
||||
pass
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"content": content,
|
||||
"speaker": "旁白",
|
||||
"is_ending": True,
|
||||
"ending_name": ending_name,
|
||||
"ending_type": "rewrite",
|
||||
"tokens_used": ai_result.get("tokens_used", 0)
|
||||
}
|
||||
}
|
||||
|
||||
# AI 服务不可用时的降级处理
|
||||
templates = [
|
||||
f"根据你的愿望「{request.prompt}」,故事有了新的发展...\n\n",
|
||||
f"命运的齿轮开始转动,{request.prompt}...\n\n",
|
||||
@@ -205,8 +255,7 @@ async def ai_rewrite_ending(story_id: int, request: RewriteRequest, db: AsyncSes
|
||||
new_content = (
|
||||
template +
|
||||
"原本的结局被改写,新的故事在这里展开。\n\n" +
|
||||
f"【AI改写提示】这是基于「{request.prompt}」生成的新结局。\n" +
|
||||
"实际部署时,这里将由AI大模型根据上下文生成更精彩的内容。"
|
||||
f"【提示】AI服务暂时不可用,这是模板内容。"
|
||||
)
|
||||
|
||||
return {
|
||||
|
||||
211
server/app/services/ai.py
Normal file
211
server/app/services/ai.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
AI服务封装模块
|
||||
支持多种AI提供商:DeepSeek, OpenAI, Claude, 通义千问
|
||||
"""
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
|
||||
class AIService:
|
||||
def __init__(self):
|
||||
from app.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
self.enabled = settings.ai_service_enabled
|
||||
self.provider = settings.ai_provider
|
||||
|
||||
# 根据提供商初始化配置
|
||||
if self.provider == "deepseek":
|
||||
self.api_key = settings.deepseek_api_key
|
||||
self.base_url = settings.deepseek_base_url
|
||||
self.model = settings.deepseek_model
|
||||
elif self.provider == "openai":
|
||||
self.api_key = settings.openai_api_key
|
||||
self.base_url = settings.openai_base_url
|
||||
self.model = "gpt-3.5-turbo"
|
||||
elif self.provider == "claude":
|
||||
# Claude 需要从环境变量读取
|
||||
import os
|
||||
self.api_key = os.getenv("CLAUDE_API_KEY", "")
|
||||
self.base_url = "https://api.anthropic.com"
|
||||
self.model = "claude-3-haiku-20240307"
|
||||
elif self.provider == "qwen":
|
||||
import os
|
||||
self.api_key = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
self.base_url = "https://dashscope.aliyuncs.com/api/v1"
|
||||
self.model = "qwen-plus"
|
||||
else:
|
||||
self.api_key = None
|
||||
self.base_url = None
|
||||
self.model = None
|
||||
|
||||
async def rewrite_ending(
|
||||
self,
|
||||
story_title: str,
|
||||
story_category: str,
|
||||
ending_name: str,
|
||||
ending_content: str,
|
||||
user_prompt: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
AI改写结局
|
||||
:return: {"content": str, "tokens_used": int} 或 None
|
||||
"""
|
||||
if not self.enabled or not self.api_key:
|
||||
return None
|
||||
|
||||
# 构建Prompt
|
||||
system_prompt = """你是一个专业的互动故事创作专家。根据用户的改写指令,重新创作故事结局。
|
||||
要求:
|
||||
1. 保持原故事的世界观和人物性格
|
||||
2. 结局要有张力和情感冲击
|
||||
3. 结局内容字数控制在200-400字
|
||||
4. 为新结局取一个4-8字的新名字,体现改写后的剧情走向
|
||||
5. 输出格式必须是JSON:{"ending_name": "新结局名称", "content": "结局内容"}"""
|
||||
|
||||
user_prompt_text = f"""故事标题:{story_title}
|
||||
故事分类:{story_category}
|
||||
原结局名称:{ending_name}
|
||||
原结局内容:{ending_content[:500]}
|
||||
---
|
||||
用户改写指令:{user_prompt}
|
||||
---
|
||||
请创作新的结局(输出JSON格式):"""
|
||||
|
||||
try:
|
||||
if self.provider == "openai":
|
||||
return await self._call_openai(system_prompt, user_prompt_text)
|
||||
elif self.provider == "claude":
|
||||
return await self._call_claude(user_prompt_text)
|
||||
elif self.provider == "qwen":
|
||||
return await self._call_qwen(system_prompt, user_prompt_text)
|
||||
elif self.provider == "deepseek":
|
||||
return await self._call_deepseek(system_prompt, user_prompt_text)
|
||||
except Exception as e:
|
||||
print(f"AI调用失败:{e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def _call_openai(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||||
"""调用OpenAI API"""
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 500
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
tokens = result["usage"]["total_tokens"]
|
||||
|
||||
return {"content": content.strip(), "tokens_used": tokens}
|
||||
|
||||
async def _call_claude(self, prompt: str) -> Optional[Dict]:
|
||||
"""调用Claude API"""
|
||||
url = "https://api.anthropic.com/v1/messages"
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
content = result["content"][0]["text"]
|
||||
tokens = result.get("usage", {}).get("output_tokens", 0)
|
||||
|
||||
return {"content": content.strip(), "tokens_used": tokens}
|
||||
|
||||
async def _call_qwen(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||||
"""调用通义千问API"""
|
||||
url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"input": {
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
},
|
||||
"parameters": {
|
||||
"result_format": "message",
|
||||
"temperature": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
content = result["output"]["choices"][0]["message"]["content"]
|
||||
tokens = result.get("usage", {}).get("total_tokens", 0)
|
||||
|
||||
return {"content": content.strip(), "tokens_used": tokens}
|
||||
|
||||
async def _call_deepseek(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
|
||||
"""调用 DeepSeek API"""
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 500
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
tokens = result.get("usage", {}).get("total_tokens", 0)
|
||||
|
||||
return {"content": content.strip(), "tokens_used": tokens}
|
||||
else:
|
||||
print(f"DeepSeek API 返回异常:{result}")
|
||||
return None
|
||||
except httpx.HTTPStatusError as e:
|
||||
print(f"DeepSeek HTTP 错误:{e.response.status_code} - {e.response.text}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"DeepSeek 调用失败:{e}")
|
||||
return None
|
||||
|
||||
|
||||
# 单例模式
|
||||
ai_service = AIService()
|
||||
Reference in New Issue
Block a user