1283 lines
44 KiB
Python
1283 lines
44 KiB
Python
"""
|
||
故事相关API路由
|
||
"""
|
||
import random
|
||
import asyncio
|
||
from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, update, func, distinct
|
||
from typing import Optional, List
|
||
from pydantic import BaseModel
|
||
|
||
from app.database import get_db
|
||
from app.models.story import Story, StoryNode, StoryChoice, StoryCharacter, StoryDraft, DraftStatus
|
||
|
||
router = APIRouter()
|
||
|
||
|
||
# ========== 请求/响应模型 ==========
|
||
class LikeRequest(BaseModel):
    """Request body of the like endpoint."""

    # True adds a like, False removes one (see toggle_like).
    like: bool
|
||
|
||
|
||
class RewriteRequest(BaseModel):
    """Request body for rewriting a story ending with AI."""

    # Name of the ending the user wants rewritten.
    ending_name: str
    # Text of the ending being rewritten.
    ending_content: str
    # Free-form rewrite instruction supplied by the user.
    prompt: str
|
||
|
||
|
||
class PathHistoryItem(BaseModel):
    """One entry of the path history sent with a branch-rewrite request."""

    # Key of the story node for this entry.
    nodeKey: str
    # Node text; defaults to an empty string when omitted.
    content: str = ""
    # Choice text for this entry; defaults to an empty string when omitted.
    choice: str = ""
|
||
|
||
|
||
class RewriteBranchRequest(BaseModel):
    """Request body for rewriting a mid-story chapter into a new branch."""

    # Id of the requesting user.
    userId: int
    # Key of the node the rewrite starts from.
    currentNodeKey: str
    # Ordered history entries forwarded to the AI service.
    pathHistory: List[PathHistoryItem]
    # Text of the current node.
    currentContent: str
    # Free-form rewrite instruction supplied by the user.
    prompt: str
|
||
|
||
|
||
class NodeImageUpdate(BaseModel):
    """Image settings for a single story node, addressed by node key."""

    # Key of the node to update.
    nodeKey: str
    # New background image; an empty string leaves the stored value untouched.
    backgroundImage: str = ""
    # New character image; an empty string leaves the stored value untouched.
    characterImage: str = ""
|
||
|
||
|
||
class CharacterImageUpdate(BaseModel):
    """Avatar setting for a single story character."""

    # Primary key of the character to update.
    characterId: int
    # New avatar URL; an empty string leaves the stored value untouched.
    avatarUrl: str = ""
|
||
|
||
|
||
class ImageConfigRequest(BaseModel):
    """Bulk image configuration for a story: cover, nodes and characters."""

    # New cover URL; an empty string leaves the cover untouched.
    coverUrl: str = ""
    # Per-node image updates.
    nodes: List[NodeImageUpdate] = []
    # Per-character avatar updates.
    characters: List[CharacterImageUpdate] = []
|
||
|
||
|
||
class GenerateImageRequest(BaseModel):
    """Request body for generating one image and optionally saving it."""

    # Text prompt handed to the image generation service.
    prompt: str
    # anime / realistic / illustration
    style: str = "anime"
    # character / background / cover
    category: str = "character"
    # Story to attach the generated image to (optional).
    storyId: Optional[int] = None
    # coverUrl / backgroundImage / characterImage / avatarUrl
    targetField: Optional[str] = None
    # nodeKey or characterId, depending on targetField
    targetKey: Optional[str] = None
|
||
|
||
|
||
class AICreateStoryRequest(BaseModel):
    """Request body for creating a brand-new story with AI."""

    # Id of the requesting user.
    userId: int
    # Genre / subject matter of the story.
    genre: str
    # Keywords describing the story.
    keywords: str
    # Optional protagonist description.
    protagonist: Optional[str] = None
    # Optional core conflict.
    conflict: Optional[str] = None
|
||
|
||
|
||
# ========== API接口 ==========
|
||
|
||
@router.get("")
async def get_stories(
    category: Optional[str] = Query(None),
    featured: bool = Query(False),
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db)
):
    """Return a page of stories whose status is 1.

    Args:
        category: When given, only stories in this category are returned.
        featured: When true, only featured stories are returned.
        limit: Page size (1-100).
        offset: Number of rows to skip.
        db: Async database session.

    Returns:
        ``{"code": 0, "data": [...]}`` where each item is a story summary,
        ordered featured-first and then by descending play count.
    """
    # Collect all filters first; passing them together ANDs them just like
    # chained ``.where()`` calls would.
    conditions = [Story.status == 1]
    if category:
        conditions.append(Story.category == category)
    if featured:
        conditions.append(Story.is_featured == True)  # noqa: E712 (SQL expression)

    stmt = (
        select(Story)
        .where(*conditions)
        .order_by(Story.is_featured.desc(), Story.play_count.desc())
        .limit(limit)
        .offset(offset)
    )
    rows = (await db.execute(stmt)).scalars().all()

    payload = []
    for story in rows:
        payload.append({
            "id": story.id,
            "title": story.title,
            "cover_url": story.cover_url,
            "description": story.description,
            "category": story.category,
            "play_count": story.play_count,
            "like_count": story.like_count,
            "is_featured": story.is_featured
        })

    return {"code": 0, "data": payload}
|
||
|
||
|
||
@router.get("/hot")
async def get_hot_stories(
    limit: int = Query(10, ge=1, le=50),
    db: AsyncSession = Depends(get_db)
):
    """Return the most-played stories whose status is 1.

    Args:
        limit: Maximum number of stories to return (1-50).
        db: Async database session.

    Returns:
        ``{"code": 0, "data": [...]}`` ordered by descending play count.
    """
    stmt = select(Story).where(Story.status == 1)
    stmt = stmt.order_by(Story.play_count.desc())
    stmt = stmt.limit(limit)

    hot_rows = (await db.execute(stmt)).scalars().all()

    payload = []
    for story in hot_rows:
        payload.append({
            "id": story.id,
            "title": story.title,
            "cover_url": story.cover_url,
            "description": story.description,
            "category": story.category,
            "play_count": story.play_count,
            "like_count": story.like_count
        })

    return {"code": 0, "data": payload}
|
||
|
||
|
||
@router.get("/categories")
async def get_categories(db: AsyncSession = Depends(get_db)):
    """Return the distinct categories of stories whose status is 1."""
    stmt = select(distinct(Story.category)).where(Story.status == 1)
    outcome = await db.execute(stmt)

    # The statement selects a single column, so ``scalars()`` yields exactly
    # the first element of every row.
    names = list(outcome.scalars().all())

    return {"code": 0, "data": names}
|
||
|
||
|
||
@router.get("/test-image-gen")
async def test_image_generation():
    """Diagnose the image generation service.

    Reports whether the Gemini API key is configured, probes the upstream
    host over HTTP, and, when a key exists, attempts one small generation.

    Returns:
        ``{"code": 0, "message": ..., "data": {...}}``; the individual
        checks report their own ``success`` flag inside ``data``.
    """
    from app.config import get_settings
    from app.services.image_gen import get_image_gen_service
    import httpx

    settings = get_settings()
    api_key = settings.gemini_api_key

    # Only the first 8 characters of the key are echoed back.
    if api_key:
        key_preview = api_key[:8] + "..."
    else:
        key_preview = "未配置"

    report = {
        "api_key_configured": bool(api_key),
        "api_key_preview": key_preview,
        "base_url": "https://work.poloapi.com/v1beta",
        "network_test": None,
        "generate_test": None
    }

    # Step 1: plain connectivity probe.
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            probe = await client.get("https://work.poloapi.com")
            reachable = probe.status_code < 500
            report["network_test"] = {
                "success": reachable,
                "status_code": probe.status_code,
                "message": "网络连接正常" if reachable else "服务端错误"
            }
    except Exception as exc:
        report["network_test"] = {
            "success": False,
            "error": str(exc),
            "message": "网络连接失败"
        }

    # Step 2: a real (tiny) generation, skipped when no key is configured.
    if not api_key:
        report["generate_test"] = {
            "success": False,
            "error": "API Key 未配置"
        }
    else:
        try:
            generator = get_image_gen_service()
            outcome = await generator.generate_image(
                prompt="a simple red circle on white background",
                image_type="avatar",
                style="illustration"
            )
            if outcome.get("success"):
                failure_reason = None
            else:
                failure_reason = outcome.get("error")
            report["generate_test"] = {
                "success": outcome.get("success", False),
                "error": failure_reason,
                "has_image_data": bool(outcome.get("image_data"))
            }
        except Exception as exc:
            report["generate_test"] = {
                "success": False,
                "error": str(exc)
            }

    return {
        "code": 0,
        "message": "测试完成",
        "data": report
    }
|
||
|
||
|
||
@router.get("/test-deepseek")
async def test_deepseek():
    """Diagnose the DeepSeek AI service.

    Reports the configured provider/model, probes ``api.deepseek.com`` over
    HTTP, and, when an API key exists, sends one minimal chat completion.

    Returns:
        ``{"code": 0, "message": ..., "data": {...}}``; the individual
        checks report their own ``success`` flag inside ``data``.
    """
    from app.config import get_settings
    import httpx

    settings = get_settings()
    api_key = settings.deepseek_api_key

    # Show only the head and tail of the key, and only when it is long
    # enough that the two slices cannot overlap.
    if api_key and len(api_key) > 12:
        key_preview = api_key[:8] + "..." + api_key[-4:]
    else:
        key_preview = "未配置或太短"

    results = {
        "ai_service_enabled": settings.ai_service_enabled,
        "provider": settings.ai_provider,
        "api_key_configured": bool(api_key),
        "api_key_preview": key_preview,
        "base_url": settings.deepseek_base_url,
        "model": settings.deepseek_model,
        "network_test": None,
        "api_test": None
    }

    # Connectivity probe.
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get("https://api.deepseek.com")
            reachable = response.status_code < 500
            results["network_test"] = {
                "success": reachable,
                "status_code": response.status_code,
                # Keep the message in agreement with the success flag
                # (same wording as the image-generation diagnostic).
                "message": "网络连接正常" if reachable else "服务端错误"
            }
    except Exception as e:
        results["network_test"] = {
            "success": False,
            "error": str(e),
            "message": "网络连接失败"
        }

    # Real API call.
    if api_key:
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{settings.deepseek_base_url}/chat/completions",
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {api_key}"
                    },
                    json={
                        "model": settings.deepseek_model,
                        "messages": [{"role": "user", "content": "说'测试成功'两个字"}],
                        "max_tokens": 10
                    }
                )

                if response.status_code == 200:
                    data = response.json()
                    content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
                    results["api_test"] = {
                        "success": True,
                        "status_code": 200,
                        "response": content[:50]
                    }
                else:
                    results["api_test"] = {
                        "success": False,
                        "status_code": response.status_code,
                        "error": response.text[:200]
                    }
        except Exception as e:
            # Also covers malformed response bodies (e.g. an empty
            # ``choices`` list raising IndexError above).
            results["api_test"] = {
                "success": False,
                "error": str(e)
            }
    else:
        results["api_test"] = {
            "success": False,
            "error": "API Key 未配置"
        }

    return {
        "code": 0,
        "message": "测试完成",
        "data": results
    }
|
||
|
||
|
||
@router.get("/test-cloud-upload")
async def test_cloud_upload():
    """Diagnose cloud-storage uploads by requesting an upload link.

    The environment id is read from ``TCB_ENV`` or ``CBR_ENV_ID``. Without
    one the check is skipped and ``code`` 1 is returned.

    Returns:
        ``{"code": 0 | 1, "message": ..., "data": {...}}`` where ``data``
        carries the raw diagnostics collected along the way.
    """
    import os
    import httpx

    env_id = os.environ.get("TCB_ENV") or os.environ.get("CBR_ENV_ID")

    report = {
        "env_id": env_id,
        "is_cloud": bool(env_id)
    }

    def _failure(message):
        # Every non-success outcome shares this envelope.
        return {"code": 1, "message": message, "data": report}

    if not env_id:
        return _failure("非云托管环境,无法测试云存储上传")

    try:
        # Ask the upload API for an upload link.
        async with httpx.AsyncClient(timeout=30.0) as client:
            reply = await client.post(
                "http://api.weixin.qq.com/tcb/uploadfile",
                json={
                    "env": env_id,
                    "path": "test/cloud_upload_test.txt"
                },
                headers={"Content-Type": "application/json"}
            )

            report["status_code"] = reply.status_code
            report["response"] = reply.text[:500] if reply.text else ""

            if reply.status_code != 200:
                return _failure(f"请求失败: HTTP {reply.status_code}")

            body = reply.json()
            if body.get("errcode", 0) != 0:
                report["errcode"] = body.get("errcode")
                report["errmsg"] = body.get("errmsg")
                return _failure(f"获取上传链接失败: {body.get('errmsg')}")

            report["upload_url"] = body.get("url", "")[:100]
            report["file_id"] = body.get("file_id", "")
            return {
                "code": 0,
                "message": "云存储上传链接获取成功",
                "data": report
            }

    except Exception as exc:
        report["error"] = str(exc)
        return _failure(f"测试异常: {str(exc)}")
|
||
|
||
|
||
@router.get("/{story_id}")
async def get_story_detail(story_id: int, db: AsyncSession = Depends(get_db)):
    """Return one story together with its nodes and their choices.

    Args:
        story_id: Primary key of the story.
        db: Async database session.

    Returns:
        ``{"code": 0, "data": {...}}`` where ``data["nodes"]`` maps each
        node key to the node payload, including its ``choices`` list.

    Raises:
        HTTPException: 404 when no story with this id and status 1 exists.
    """
    result = await db.execute(select(Story).where(Story.id == story_id, Story.status == 1))
    story = result.scalar_one_or_none()

    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    nodes_result = await db.execute(
        select(StoryNode).where(StoryNode.story_id == story_id).order_by(StoryNode.sort_order)
    )
    nodes = nodes_result.scalars().all()

    choices_result = await db.execute(
        select(StoryChoice).where(StoryChoice.story_id == story_id).order_by(StoryChoice.sort_order)
    )
    choices = choices_result.scalars().all()

    # Build the node payloads and, in the same pass, an id -> node_key index
    # so that each choice can be attached in O(1) instead of re-scanning the
    # node list for every choice.
    nodes_map = {}
    node_key_by_id = {}
    for node in nodes:
        # setdefault keeps the first node seen for an id, matching the
        # original "first match wins" scan.
        node_key_by_id.setdefault(node.id, node.node_key)
        nodes_map[node.node_key] = {
            "id": node.id,
            "node_key": node.node_key,
            "content": node.content,
            "speaker": node.speaker,
            "background_image": node.background_image,
            "character_image": node.character_image,
            "bgm": node.bgm,
            "is_ending": node.is_ending,
            "ending_name": node.ending_name,
            "ending_score": node.ending_score,
            "ending_type": node.ending_type,
            "choices": []
        }

    for choice in choices:
        # Choices whose node is not part of this story are skipped.
        if choice.node_id not in node_key_by_id:
            continue
        owner_key = node_key_by_id[choice.node_id]
        nodes_map[owner_key]["choices"].append({
            "text": choice.text,
            "nextNodeKey": choice.next_node_key,
            "isLocked": choice.is_locked
        })

    data = {
        "id": story.id,
        "title": story.title,
        "cover_url": story.cover_url,
        "description": story.description,
        "category": story.category,
        "author_id": story.author_id,
        "play_count": story.play_count,
        "like_count": story.like_count,
        "is_featured": story.is_featured,
        "nodes": nodes_map
    }

    return {"code": 0, "data": data}
|
||
|
||
|
||
@router.post("/{story_id}/play")
async def record_play(story_id: int, db: AsyncSession = Depends(get_db)):
    """Increment the play counter of a story by one.

    The increment is expressed in SQL (``play_count = play_count + 1``), so
    the current value is not read into Python first.
    """
    bump_plays = (
        update(Story)
        .where(Story.id == story_id)
        .values(play_count=Story.play_count + 1)
    )
    await db.execute(bump_plays)
    await db.commit()

    return {"code": 0, "message": "记录成功"}
|
||
|
||
|
||
@router.post("/{story_id}/like")
async def toggle_like(story_id: int, request: LikeRequest, db: AsyncSession = Depends(get_db)):
    """Add or remove one like on a story.

    Args:
        story_id: Primary key of the story.
        request: ``like=True`` increments the counter, ``like=False``
            decrements it.
        db: Async database session.

    Returns:
        ``{"code": 0, "message": ...}`` with a message matching the action.
    """
    stmt = update(Story).where(Story.id == story_id)
    if request.like:
        stmt = stmt.values(like_count=Story.like_count + 1)
    else:
        # Never let repeated "unlike" calls push the counter below zero:
        # the decrement only applies while the stored count is positive.
        stmt = stmt.where(Story.like_count > 0).values(like_count=Story.like_count - 1)

    await db.execute(stmt)
    await db.commit()

    return {"code": 0, "message": "点赞成功" if request.like else "取消点赞成功"}
|
||
|
||
|
||
@router.post("/{story_id}/rewrite")
async def ai_rewrite_ending(story_id: int, request: RewriteRequest, db: AsyncSession = Depends(get_db)):
    """Rewrite a story ending with the AI service.

    When the AI service yields no content, a canned template built from the
    user's prompt is returned instead.

    Args:
        story_id: Primary key of the story.
        request: Ending name/content plus the user's rewrite instruction.
        db: Async database session.

    Returns:
        ``{"code": 0, "data": {...}}`` describing the rewritten ending node.

    Raises:
        HTTPException: 400 when the prompt is empty, 404 when the story
            does not exist.
    """
    import json
    import re

    if not request.prompt:
        raise HTTPException(status_code=400, detail="请输入改写指令")

    story_row = await db.execute(select(Story).where(Story.id == story_id))
    story = story_row.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    # Character sheets give the AI service context about the cast.
    cast_rows = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = []
    for member in cast_rows.scalars().all():
        characters.append({
            "name": member.name,
            "role_type": member.role_type,
            "gender": member.gender,
            "age_range": member.age_range,
            "appearance": member.appearance,
            "personality": member.personality
        })

    from app.services.ai import ai_service

    ai_result = await ai_service.rewrite_ending(
        story_title=story.title,
        story_category=story.category or "未知",
        ending_name=request.ending_name or "未知结局",
        ending_content=request.ending_content or "",
        user_prompt=request.prompt,
        characters=characters
    )

    if ai_result and ai_result.get("content"):
        content = ai_result["content"]
        ending_name = f"{request.ending_name}(AI改写)"

        # The reply may be a JSON object (possibly embedded in other text)
        # carrying "ending_name" and "content"; otherwise it is used as-is.
        try:
            embedded = re.search(r'\{[^{}]*"ending_name"[^{}]*"content"[^{}]*\}', content, re.DOTALL)
            candidate = embedded.group() if embedded else content
            parsed = json.loads(candidate)
            ending_name = parsed.get("ending_name", ending_name)
            content = parsed.get("content", content)
        except (json.JSONDecodeError, AttributeError):
            # Not JSON, or JSON that is not an object: keep the raw text.
            pass

        return {
            "code": 0,
            "data": {
                "content": content,
                "speaker": "旁白",
                "is_ending": True,
                "ending_name": ending_name,
                "ending_type": "rewrite",
                "tokens_used": ai_result.get("tokens_used", 0)
            }
        }

    # Fallback when the AI service produced nothing.
    openings = [
        f"根据你的愿望「{request.prompt}」,故事有了新的发展...\n\n",
        f"命运的齿轮开始转动,{request.prompt}...\n\n",
        f"在另一个平行世界里,{request.prompt}成为了现实...\n\n"
    ]
    fallback_text = "".join([
        random.choice(openings),
        "原本的结局被改写,新的故事在这里展开。\n\n",
        "【提示】AI服务暂时不可用,这是模板内容。"
    ])

    return {
        "code": 0,
        "data": {
            "content": fallback_text,
            "speaker": "旁白",
            "is_ending": True,
            "ending_name": f"{request.ending_name}(改写版)",
            "ending_type": "rewrite"
        }
    }
|
||
|
||
|
||
@router.post("/{story_id}/rewrite-branch")
async def ai_rewrite_branch(
    story_id: int,
    request: RewriteBranchRequest,
    db: AsyncSession = Depends(get_db)
):
    """Rewrite a mid-story chapter with AI, producing a new branch.

    Args:
        story_id: Primary key of the story.
        request: Path history, current content and the rewrite instruction.
        db: Async database session.

    Returns:
        ``{"code": 0, "data": {...}}``. When the AI service yields no nodes,
        ``nodes`` and ``entryNodeKey`` are ``None`` and an ``error`` message
        is included; no template fallback is used.

    Raises:
        HTTPException: 400 when the prompt is empty, 404 when the story
            does not exist.
    """
    if not request.prompt:
        raise HTTPException(status_code=400, detail="请输入改写指令")

    story_row = await db.execute(select(Story).where(Story.id == story_id))
    story = story_row.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    cast_rows = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = []
    for member in cast_rows.scalars().all():
        characters.append({
            "name": member.name,
            "role_type": member.role_type,
            "gender": member.gender,
            "age_range": member.age_range,
            "appearance": member.appearance,
            "personality": member.personality
        })

    # The AI service receives plain dicts rather than Pydantic models.
    path_history = []
    for step in request.pathHistory:
        path_history.append(
            {"nodeKey": step.nodeKey, "content": step.content, "choice": step.choice}
        )

    from app.services.ai import ai_service

    ai_result = await ai_service.rewrite_branch(
        story_title=story.title,
        story_category=story.category or "未知",
        path_history=path_history,
        current_content=request.currentContent,
        user_prompt=request.prompt,
        characters=characters
    )

    if not ai_result or not ai_result.get("nodes"):
        # Nothing usable came back: report it rather than inventing content.
        return {
            "code": 0,
            "data": {
                "nodes": None,
                "entryNodeKey": None,
                "tokensUsed": 0,
                "error": "AI服务暂时不可用"
            }
        }

    return {
        "code": 0,
        "data": {
            "nodes": ai_result["nodes"],
            "entryNodeKey": ai_result.get("entryNodeKey", "branch_1"),
            "tokensUsed": ai_result.get("tokens_used", 0)
        }
    }
|
||
|
||
|
||
@router.get("/{story_id}/images")
async def get_story_images(story_id: int, db: AsyncSession = Depends(get_db)):
    """Return every image setting of a story: cover, nodes and characters.

    Args:
        story_id: Primary key of the story.
        db: Async database session.

    Returns:
        ``{"code": 0, "data": {...}}``. Missing image values are reported as
        empty strings, and node content is shortened to a 50-character
        preview.

    Raises:
        HTTPException: 404 when the story does not exist.
    """
    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    nodes_result = await db.execute(
        select(StoryNode).where(StoryNode.story_id == story_id).order_by(StoryNode.sort_order)
    )
    nodes = nodes_result.scalars().all()

    chars_result = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = chars_result.scalars().all()

    def _preview(text):
        # Treat a missing content value like the other optional fields
        # instead of failing on len(None).
        text = text or ""
        return text[:50] + "..." if len(text) > 50 else text

    return {
        "code": 0,
        "data": {
            "storyId": story_id,
            "title": story.title,
            "coverUrl": story.cover_url or "",
            "nodes": [
                {
                    "nodeKey": n.node_key,
                    "content": _preview(n.content),
                    "backgroundImage": n.background_image or "",
                    "characterImage": n.character_image or "",
                    "isEnding": n.is_ending,
                    "endingName": n.ending_name or ""
                }
                for n in nodes
            ],
            "characters": [
                {
                    "characterId": c.id,
                    "name": c.name,
                    "roleType": c.role_type,
                    "avatarUrl": c.avatar_url or "",
                    "avatarPrompt": c.avatar_prompt or ""
                }
                for c in characters
            ]
        }
    }
|
||
|
||
|
||
@router.put("/{story_id}/images")
async def update_story_images(
    story_id: int,
    request: ImageConfigRequest,
    db: AsyncSession = Depends(get_db)
):
    """Apply a batch of image updates to a story in one commit.

    Empty strings in the request mean "leave unchanged".

    Args:
        story_id: Primary key of the story.
        request: Cover URL plus per-node and per-character image updates.
        db: Async database session.

    Returns:
        ``{"code": 0, "message": ..., "data": {"cover", "nodes",
        "characters"}}`` counting the update statements that were issued.

    Raises:
        HTTPException: 404 when the story does not exist.
    """
    story_row = await db.execute(select(Story).where(Story.id == story_id))
    if not story_row.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="故事不存在")

    summary = {"cover": False, "nodes": 0, "characters": 0}

    # Cover
    if request.coverUrl:
        await db.execute(
            update(Story).where(Story.id == story_id).values(cover_url=request.coverUrl)
        )
        summary["cover"] = True

    # Node images
    for node_img in request.nodes:
        candidates = (
            ("background_image", node_img.backgroundImage),
            ("character_image", node_img.characterImage),
        )
        changes = {column: value for column, value in candidates if value}
        if not changes:
            continue
        await db.execute(
            update(StoryNode)
            .where(StoryNode.story_id == story_id, StoryNode.node_key == node_img.nodeKey)
            .values(**changes)
        )
        summary["nodes"] += 1

    # Character avatars
    for char_img in request.characters:
        if not char_img.avatarUrl:
            continue
        await db.execute(
            update(StoryCharacter)
            .where(StoryCharacter.id == char_img.characterId, StoryCharacter.story_id == story_id)
            .values(avatar_url=char_img.avatarUrl)
        )
        summary["characters"] += 1

    await db.commit()

    return {
        "code": 0,
        "message": "更新成功",
        "data": summary
    }
|
||
|
||
|
||
@router.post("/generate-image")
async def generate_story_image(
    request: GenerateImageRequest,
    db: AsyncSession = Depends(get_db)
):
    """Generate an image with AI and optionally store its URL on a story.

    Args:
        request: Prompt/style/category plus an optional save target
            (``storyId`` + ``targetField`` [+ ``targetKey``]).
        db: Async database session.

    Returns:
        ``{"code": 0, ...}`` with the image URL on success, or
        ``{"code": 1, ...}`` when generation fails. ``data["saved"]`` is
        true only when an update statement was actually issued.

    Raises:
        HTTPException: 400 when ``targetField`` is ``avatarUrl`` and
            ``targetKey`` is not a numeric character id.
    """
    from app.services.image_gen import get_image_gen_service

    wants_save = bool(request.storyId and request.targetField)

    # Validate the avatar target up front so a bad key is rejected before
    # an image is generated, instead of surfacing as an unhandled ValueError
    # afterwards.
    character_id = None
    if wants_save and request.targetField == "avatarUrl" and request.targetKey:
        try:
            character_id = int(request.targetKey)
        except ValueError:
            raise HTTPException(status_code=400, detail="targetKey 必须是数字类型的角色ID") from None

    result = await get_image_gen_service().generate_and_save(
        prompt=request.prompt,
        category=request.category,
        style=request.style
    )

    if not result.get("success"):
        return {
            "code": 1,
            "message": result.get("error", "生成失败"),
            "data": None
        }

    image_url = result["url"]

    # Pick the update statement matching the requested target, if any.
    stmt = None
    if wants_save:
        field = request.targetField
        key = request.targetKey
        if field == "coverUrl":
            stmt = update(Story).where(Story.id == request.storyId).values(cover_url=image_url)
        elif field == "backgroundImage" and key:
            stmt = (
                update(StoryNode)
                .where(StoryNode.story_id == request.storyId, StoryNode.node_key == key)
                .values(background_image=image_url)
            )
        elif field == "characterImage" and key:
            stmt = (
                update(StoryNode)
                .where(StoryNode.story_id == request.storyId, StoryNode.node_key == key)
                .values(character_image=image_url)
            )
        elif field == "avatarUrl" and key:
            stmt = (
                update(StoryCharacter)
                .where(StoryCharacter.story_id == request.storyId, StoryCharacter.id == character_id)
                .values(avatar_url=image_url)
            )

    if stmt is not None:
        await db.execute(stmt)
        await db.commit()

    return {
        "code": 0,
        "message": "生成成功",
        "data": {
            "url": image_url,
            "filename": result.get("filename"),
            # Unknown target fields or a missing targetKey issue no update,
            # so they must not be reported as saved.
            "saved": stmt is not None
        }
    }
|
||
|
||
|
||
@router.post("/{story_id}/generate-all-images")
async def generate_all_story_images(
    story_id: int,
    style: str = "anime",
    db: AsyncSession = Depends(get_db)
):
    """Generate an avatar for every character of a story.

    Characters are processed one after another; successes are written to the
    database and everything is committed once at the end.

    Args:
        story_id: Primary key of the story.
        style: Image style forwarded to the generation service.
        db: Async database session.

    Returns:
        ``{"code": 0, ...}`` with ``generated`` and ``failed`` lists, or
        ``{"code": 1, ...}`` when the story has no characters.
    """
    from app.services.image_gen import get_image_gen_service
    image_service = get_image_gen_service()

    cast_rows = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    cast = cast_rows.scalars().all()

    if not cast:
        return {"code": 1, "message": "故事没有角色数据", "data": None}

    succeeded = []
    errored = []

    for member in cast:
        # Prefer the stored avatar prompt; otherwise derive one from the
        # character sheet.
        prompt = member.avatar_prompt or f"{member.name}, {member.gender}, {member.appearance or ''}"

        outcome = await image_service.generate_and_save(
            prompt=prompt,
            category="character",
            style=style
        )

        if not outcome.get("success"):
            errored.append({"id": member.id, "name": member.name, "error": outcome.get("error")})
            continue

        await db.execute(
            update(StoryCharacter)
            .where(StoryCharacter.id == member.id)
            .values(avatar_url=outcome["url"])
        )
        succeeded.append({"id": member.id, "name": member.name, "url": outcome["url"]})

    await db.commit()

    return {
        "code": 0,
        "message": f"生成完成: {len(succeeded)}成功, {len(errored)}失败",
        "data": {
            "generated": succeeded,
            "failed": errored
        }
    }
|
||
|
||
|
||
# ========== AI创作全新故事 ==========
|
||
|
||
@router.post("/ai-create")
async def ai_create_story(
    request: AICreateStoryRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Start AI creation of a brand-new story as a background task.

    Only a ``StoryDraft`` row is created here (status pending); no ``Story``
    is created for the new content. The draft points at a placeholder story
    to satisfy its ``story_id`` foreign key.

    Args:
        request: User id and the creative brief (genre, keywords, ...).
        background_tasks: FastAPI background task queue.
        db: Async database session.

    Returns:
        ``{"code": 0, "data": {"draftId", "message"}}``.

    Raises:
        HTTPException: 404 when the user does not exist.
    """
    from app.models.user import User

    user_row = await db.execute(select(User).where(User.id == request.userId))
    if not user_row.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="用户不存在")

    placeholder = await get_or_create_virtual_story(db)

    # Human-readable summary of the brief, stored alongside the draft.
    protagonist_text = request.protagonist or '无'
    conflict_text = request.conflict or '无'
    brief = f"题材:{request.genre}, 关键词:{request.keywords}, 主角:{protagonist_text}, 冲突:{conflict_text}"

    draft = StoryDraft(
        user_id=request.userId,
        story_id=placeholder.id,
        title="AI创作中...",
        user_prompt=brief,
        draft_type="create",
        status=DraftStatus.pending
    )
    db.add(draft)
    await db.commit()
    await db.refresh(draft)

    # The generation itself runs after the response has been sent.
    background_tasks.add_task(
        process_ai_create_story,
        draft.id,
        request.userId,
        request.genre,
        request.keywords,
        request.protagonist,
        request.conflict
    )

    return {
        "code": 0,
        "data": {
            "draftId": draft.id,
            "message": "故事创作已开始,完成后将保存到草稿箱"
        }
    }
|
||
|
||
|
||
async def get_or_create_virtual_story(db: AsyncSession) -> Story:
    """Return the placeholder story used by AI drafts, creating it if needed.

    AI-created drafts need a ``story_id`` to satisfy a foreign key; they all
    point at this single system-owned story.

    Args:
        db: Async database session.

    Returns:
        The existing placeholder ``Story`` (lowest id if several exist) or a
        freshly created one.
    """
    placeholder_title = "[系统] AI创作占位故事"

    # Concurrent first-time requests can each insert a placeholder; take the
    # oldest row instead of failing with MultipleResultsFound.
    result = await db.execute(
        select(Story)
        .where(Story.title == placeholder_title)
        .order_by(Story.id)
        .limit(1)
    )
    virtual_story = result.scalars().first()

    if not virtual_story:
        virtual_story = Story(
            title=placeholder_title,
            description="此故事仅用于AI创作功能的外键占位,不可游玩",
            category="系统",
            status=-99,  # special status; the list endpoints here filter on status == 1
            cover_url="",
            author_id=1  # system user
        )
        db.add(virtual_story)
        await db.commit()
        await db.refresh(virtual_story)

    return virtual_story
|
||
|
||
|
||
@router.get("/ai-create/{draft_id}/status")
async def get_ai_create_status(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Report the progress of an AI story creation, looked up by draft id.

    Returns:
        ``{"code": 0, "data": {...}}`` where ``status`` is -1 (failed),
        1 (completed) or 0 (anything else).

    Raises:
        HTTPException: 404 when the draft does not exist.
    """
    lookup = await db.execute(
        select(StoryDraft).where(StoryDraft.id == draft_id)
    )
    draft = lookup.scalar_one_or_none()
    if not draft:
        raise HTTPException(status_code=404, detail="草稿不存在")

    completed = draft.status == DraftStatus.completed
    failed = draft.status == DraftStatus.failed

    if failed:
        status_code = -1
        error_message = draft.error_message
    else:
        status_code = 1 if completed else 0
        error_message = None

    return {
        "code": 0,
        "data": {
            "draftId": draft.id,
            "status": status_code,
            "title": draft.title,
            "isCompleted": completed,
            "isFailed": failed,
            "errorMessage": error_message
        }
    }
|
||
|
||
|
||
@router.post("/ai-create/{draft_id}/publish")
async def publish_ai_created_story(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Publish a completed AI-created draft by setting its published flag.

    Returns:
        ``{"code": 0, "data": {"draftId", "title", "message"}}``.

    Raises:
        HTTPException: 404 when the draft does not exist; 400 when it is
            not completed or has already been published.
    """
    lookup = await db.execute(
        select(StoryDraft).where(StoryDraft.id == draft_id)
    )
    draft = lookup.scalar_one_or_none()

    # Preconditions, checked in order.
    if not draft:
        raise HTTPException(status_code=404, detail="草稿不存在")
    if draft.status != DraftStatus.completed:
        raise HTTPException(status_code=400, detail="草稿尚未完成或已失败")
    if draft.published_to_center:
        raise HTTPException(status_code=400, detail="草稿已发布")

    draft.published_to_center = True
    await db.commit()

    payload = {
        "draftId": draft.id,
        "title": draft.title,
        "message": "发布成功!可在'我的作品'中查看"
    }
    return {"code": 0, "data": payload}
|
||
|
||
|
||
async def generate_draft_cover(
    story_id: int,
    draft_id: int,
    title: str,
    description: str,
    category: str
) -> Optional[str]:
    """Generate and store a cover image for an AI-created draft.

    In a cloud environment (``TCB_ENV`` or ``CBR_ENV_ID`` set) the image is
    uploaded to cloud storage; otherwise it is written below the configured
    upload directory.

    Args:
        story_id: Story id used in the storage path.
        draft_id: Draft id used in the storage path.
        title: Story title, embedded in the generation prompt.
        description: Story description; its first 100 characters are used.
        category: Story category, embedded in the generation prompt.

    Returns:
        The cover URL path (``/uploads/stories/<story>/drafts/<draft>/cover.jpg``),
        or ``None`` when generation or the cloud upload fails.
    """
    from app.services.image_gen import ImageGenService
    from app.config import get_settings
    import os
    import base64

    settings = get_settings()
    service = ImageGenService()

    is_cloud = os.environ.get('TCB_ENV') or os.environ.get('CBR_ENV_ID')

    summary = description[:100] if description else ''
    cover_prompt = f"Book cover for {category} story titled '{title}'. {summary}. Vertical cover image, anime style, vibrant colors, eye-catching design, high quality illustration."

    print(f"[generate_draft_cover] 生成封面图: {title}")
    result = await service.generate_image(cover_prompt, "cover", "anime")

    if not result or not result.get("success"):
        print(f"[generate_draft_cover] 封面图生成失败: {result.get('error') if result else 'Unknown'}")
        return None

    # A "successful" result without payload cannot be stored; treat it as a
    # failure instead of raising KeyError.
    encoded = result.get("image_data")
    if not encoded:
        print(f"[generate_draft_cover] 封面图生成失败: 返回结果缺少 image_data")
        return None

    image_bytes = base64.b64decode(encoded)
    cover_path = f"uploads/stories/{story_id}/drafts/{draft_id}/cover.jpg"

    if is_cloud:
        try:
            from app.routers.drafts import upload_to_cloud_storage
            await upload_to_cloud_storage(image_bytes, cover_path)
            print(f"[generate_draft_cover] ✓ 云端封面图上传成功")
            return f"/{cover_path}"
        except Exception as e:
            print(f"[generate_draft_cover] 云端上传失败: {e}")
            return None
    else:
        # Local environment: write below <project>/<upload_path>/stories/...
        base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
        full_path = os.path.join(base_dir, "stories", str(story_id), "drafts", str(draft_id), "cover.jpg")

        os.makedirs(os.path.dirname(full_path), exist_ok=True)

        with open(full_path, "wb") as f:
            f.write(image_bytes)

        print(f"[generate_draft_cover] ✓ 本地封面图保存成功")
        return f"/{cover_path}"
|
||
|
||
|
||
async def process_ai_create_story(
    draft_id: int,
    user_id: int,
    genre: str,
    keywords: str,
    protagonist: Optional[str] = None,
    conflict: Optional[str] = None
):
    """Background task: create a story with AI and store it on the draft.

    The complete AI result is saved in ``draft.ai_nodes``. Cover and node
    image generation are best-effort: their failures are logged and do not
    fail the draft. Any other error marks the draft as failed.

    Args:
        draft_id: Id of the ``StoryDraft`` created by the request handler.
        user_id: Id of the requesting user, forwarded to the AI service.
        genre: Story genre.
        keywords: Story keywords.
        protagonist: Optional protagonist description.
        conflict: Optional core conflict.
    """
    from app.database import async_session_factory
    from app.services.ai import ai_service

    print(f"\n[process_ai_create_story] ========== 开始创作 ==========")
    print(f"[process_ai_create_story] draft_id={draft_id}, user_id={user_id}")
    print(f"[process_ai_create_story] genre={genre}, keywords={keywords}")

    async with async_session_factory() as db:
        # Bound before the try so the error handler can tell whether the
        # draft was ever loaded.
        draft = None
        try:
            draft_result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
            draft = draft_result.scalar_one_or_none()

            if not draft:
                print(f"[process_ai_create_story] 草稿不存在")
                return

            print(f"[process_ai_create_story] 开始调用AI服务...")
            ai_result = await ai_service.create_story(
                genre=genre,
                keywords=keywords,
                protagonist=protagonist,
                conflict=conflict,
                user_id=user_id
            )

            if not ai_result:
                print(f"[process_ai_create_story] AI创作失败")
                draft.status = DraftStatus.failed
                draft.error_message = "AI创作失败"
                await db.commit()
                return

            print(f"[process_ai_create_story] AI创作成功,开始生成配图...")

            story_nodes = ai_result.get("nodes", {})
            story_category = ai_result.get("category", genre)
            story_title = ai_result.get("title", "未命名故事")
            story_description = ai_result.get("description", "")

            # Cover image (best-effort).
            try:
                cover_url = await generate_draft_cover(
                    story_id=draft.story_id,
                    draft_id=draft_id,
                    title=story_title,
                    description=story_description,
                    category=story_category
                )
                if cover_url:
                    ai_result["coverUrl"] = cover_url
                    print(f"[process_ai_create_story] 封面图生成成功: {cover_url}")
            except Exception as cover_e:
                print(f"[process_ai_create_story] 封面图生成失败: {cover_e}")

            # Node background images (best-effort).
            if story_nodes:
                try:
                    from app.routers.drafts import generate_draft_images
                    await generate_draft_images(
                        story_id=draft.story_id,
                        draft_id=draft_id,
                        ai_nodes=story_nodes,
                        story_category=story_category
                    )
                    print(f"[process_ai_create_story] 配图生成完成")
                except Exception as img_e:
                    print(f"[process_ai_create_story] 配图生成失败(不影响创作结果): {img_e}")

            print(f"[process_ai_create_story] 保存到草稿...")

            # Store the whole AI result (title, description, characters,
            # nodes, startNodeKey, plus coverUrl when one was generated).
            draft.title = ai_result.get("title", "未命名故事")
            draft.ai_nodes = ai_result
            draft.entry_node_key = ai_result.get("startNodeKey", "start")
            draft.status = DraftStatus.completed

            await db.commit()

            print(f"[process_ai_create_story] ========== 创作完成(已保存到草稿箱) ==========")
            print(f"[process_ai_create_story] 故事标题: {draft.title}")
            print(f"[process_ai_create_story] 节点数量: {len(story_nodes)}")

        except Exception as e:
            print(f"[process_ai_create_story] 异常: {e}")
            import traceback
            traceback.print_exc()

            if draft is None:
                # The failure happened before the draft was loaded; there is
                # nothing to mark as failed.
                return

            try:
                # Reset the session first: it may be unusable after a failed
                # flush/commit.
                await db.rollback()
                draft.status = DraftStatus.failed
                draft.error_message = str(e)[:200]
                await db.commit()
            except Exception as mark_e:
                # Still best-effort, but no longer silent.
                print(f"[process_ai_create_story] 无法将草稿标记为失败: {mark_e}")
|
||
|
||
|
||
async def generate_story_images(story_id: int, ai_result: dict, genre: str):
    """Generate and persist the cover, character avatars and node backgrounds of a story.

    Each image is requested from ``ImageGenService``, written below the
    configured upload directory (``stories/<story_id>/...``) and the matching
    ``Story`` / ``StoryCharacter`` / ``StoryNode`` row is updated with the
    corresponding ``/uploads/...`` URL. All row updates are committed once,
    after the last image.

    A failure while generating or saving a single image is logged and treated
    as "generation failed" for that image only, so the remaining images are
    still attempted. Any other error (for example from the database) is logged
    with a traceback and swallowed; nothing is re-raised to the caller.

    Args:
        story_id: Primary key of the story the images belong to.
        ai_result: AI creation result; the ``title``, ``description`` and
            ``nodes`` entries are read. ``nodes`` is iterated as a mapping of
            node key to a dict that may carry a ``content`` entry.
        genre: Story genre, interpolated into the cover and background prompts.
    """
    from app.database import async_session_factory
    from app.services.image_gen import ImageGenService
    from app.config import get_settings
    import os
    import base64
    import traceback

    print(f"\n[generate_story_images] 开始为故事 {story_id} 生成图片")

    settings = get_settings()
    base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
    story_dir = os.path.join(base_dir, "stories", str(story_id))

    service = ImageGenService()

    async def render_to_file(prompt: str, category: str, target_path: str) -> bool:
        """Generate one anime-style image and write it to ``target_path``.

        Returns True when the file was written, False when the service
        reported failure or any error occurred along the way.
        """
        try:
            result = await service.generate_image(prompt, category, "anime")
            if not (result and result.get("success")):
                return False
            os.makedirs(os.path.dirname(target_path), exist_ok=True)
            with open(target_path, "wb") as f:
                f.write(base64.b64decode(result["image_data"]))
            return True
        except Exception as image_error:
            # Best effort: one broken image must not abort the remaining ones
            # or discard the row updates already queued in the session.
            print(f"[generate_story_images] 图片生成异常: {image_error}")
            traceback.print_exc()
            return False

    async with async_session_factory() as db:
        try:
            # 1. Cover image
            print(f"[generate_story_images] 生成封面图...")
            title = ai_result.get("title", "")
            # `or ""` also covers an explicit None value, which would break slicing.
            description = ai_result.get("description") or ""

            cover_prompt = f"Book cover for {genre} story titled '{title}'. {description[:100]}. Vertical cover image, anime style, vibrant colors, eye-catching design."
            cover_path = os.path.join(story_dir, "cover", "cover.jpg")

            if await render_to_file(cover_prompt, "cover", cover_path):
                await db.execute(
                    update(Story)
                    .where(Story.id == story_id)
                    .values(cover_url=f"/uploads/stories/{story_id}/cover/cover.jpg")
                )
                print(f" ✓ 封面图生成成功")
            else:
                print(f" ✗ 封面图生成失败")

            # Pause between requests to the image service.
            await asyncio.sleep(1)

            # 2. Character avatars (driven by the characters stored in the database)
            print(f"[generate_story_images] 生成角色头像...")

            char_result = await db.execute(
                select(StoryCharacter).where(StoryCharacter.story_id == story_id)
            )
            db_characters = char_result.scalars().all()

            char_dir = os.path.join(story_dir, "characters")
            os.makedirs(char_dir, exist_ok=True)

            for db_char in db_characters:
                appearance = db_char.appearance or ""
                avatar_prompt = f"Character portrait: {db_char.name}, {db_char.gender}, {appearance}. Anime style avatar, head and shoulders, clear face, high quality."
                avatar_path = os.path.join(char_dir, f"{db_char.id}.jpg")

                if await render_to_file(avatar_prompt, "avatar", avatar_path):
                    await db.execute(
                        update(StoryCharacter)
                        .where(StoryCharacter.id == db_char.id)
                        .values(avatar_url=f"/uploads/stories/{story_id}/characters/{db_char.id}.jpg")
                    )
                    print(f" ✓ 角色 {db_char.name} 头像生成成功")
                else:
                    print(f" ✗ 角色 {db_char.name} 头像生成失败")

                await asyncio.sleep(1)

            # 3. Node background images
            print(f"[generate_story_images] 生成节点背景图...")
            nodes_data = ai_result.get("nodes") or {}

            nodes_dir = os.path.join(story_dir, "nodes")

            # NOTE(review): node_key comes from the AI result and is used verbatim
            # as a directory name and URL segment — confirm upstream that keys
            # cannot contain path separators.
            for node_key, node_data in nodes_data.items():
                content = (node_data.get("content") or "")[:150]

                bg_prompt = f"Background scene for {genre} story. Scene: {content}. Wide shot, atmospheric, no characters, anime style, vivid colors."
                bg_path = os.path.join(nodes_dir, node_key, "background.jpg")

                if await render_to_file(bg_prompt, "background", bg_path):
                    await db.execute(
                        update(StoryNode)
                        .where(StoryNode.story_id == story_id)
                        .where(StoryNode.node_key == node_key)
                        .values(background_image=f"/uploads/stories/{story_id}/nodes/{node_key}/background.jpg")
                    )
                    print(f" ✓ 节点 {node_key} 背景图生成成功")
                else:
                    print(f" ✗ 节点 {node_key} 背景图生成失败")

                await asyncio.sleep(1)

            await db.commit()
            print(f"[generate_story_images] 图片生成完成")

        except Exception as e:
            # Image generation is best effort: log and return without re-raising.
            print(f"[generate_story_images] 异常: {e}")
            traceback.print_exc()