Files
ai_game/server/app/routers/story.py

1283 lines
44 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
故事相关API路由
"""
import random
import asyncio
from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, func, distinct
from typing import Optional, List
from pydantic import BaseModel
from app.database import get_db
from app.models.story import Story, StoryNode, StoryChoice, StoryCharacter, StoryDraft, DraftStatus
# FastAPI router that collects every story-related endpoint defined in this module.
router = APIRouter()
# ========== 请求/响应模型 ==========
class LikeRequest(BaseModel):
    """Body for POST /{story_id}/like."""

    like: bool  # True = like (counter +1), False = unlike (counter -1)
class RewriteRequest(BaseModel):
    """Body for POST /{story_id}/rewrite (AI rewrite of an ending)."""

    ending_name: str  # name of the ending being rewritten
    ending_content: str  # original ending text, forwarded to the AI as context
    prompt: str  # the user's rewrite instruction; the endpoint rejects an empty value
class PathHistoryItem(BaseModel):
    """One step of the path the player took before the branch rewrite."""

    nodeKey: str
    content: str = ""  # presumably the node text shown at this step — confirm with client
    choice: str = ""  # presumably the choice text picked at this step — confirm with client
class RewriteBranchRequest(BaseModel):
    """Body for POST /{story_id}/rewrite-branch (AI rewrite of a mid-story chapter)."""

    userId: int  # not read by the endpoint in this module
    currentNodeKey: str  # not read by the endpoint in this module
    pathHistory: List[PathHistoryItem]  # steps leading here, forwarded to the AI service
    currentContent: str  # text of the current node, forwarded to the AI service
    prompt: str  # the user's rewrite instruction; the endpoint rejects an empty value
class NodeImageUpdate(BaseModel):
    """Image URLs to apply to one node (matched by nodeKey); empty strings are skipped."""

    nodeKey: str
    backgroundImage: str = ""
    characterImage: str = ""
class CharacterImageUpdate(BaseModel):
    """Avatar URL to apply to one character of the story; an empty avatarUrl is skipped."""

    characterId: int
    avatarUrl: str = ""
class ImageConfigRequest(BaseModel):
    """Body for PUT /{story_id}/images: batch update of cover, node and avatar images."""

    coverUrl: str = ""  # empty string leaves the cover unchanged
    nodes: List[NodeImageUpdate] = []
    characters: List[CharacterImageUpdate] = []
class GenerateImageRequest(BaseModel):
    """Body for POST /generate-image; can optionally write the result into a story."""

    prompt: str
    style: str = "anime"  # anime/realistic/illustration
    category: str = "character"  # character/background/cover
    storyId: Optional[int] = None  # story to write the image into (optional)
    targetField: Optional[str] = None  # coverUrl/backgroundImage/characterImage/avatarUrl
    targetKey: Optional[str] = None  # nodeKey for node images, characterId (as text) for avatarUrl
class AICreateStoryRequest(BaseModel):
    """Request to have the AI create a brand-new story."""

    userId: int
    genre: str  # genre / theme
    keywords: str  # keywords
    protagonist: Optional[str] = None  # protagonist setup
    conflict: Optional[str] = None  # core conflict
# ========== API接口 ==========
@router.get("")
async def get_stories(
    category: Optional[str] = Query(None),
    featured: bool = Query(False),
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db)
):
    """List published stories (status == 1), featured first, then most played.

    Args:
        category: when given, only stories in this category are returned.
        featured: when True, only featured stories are returned.
        limit: page size (1-100).
        offset: number of rows to skip.

    Returns:
        ``{"code": 0, "data": [...]}`` with one summary dict per story.
    """
    query = select(Story).where(Story.status == 1)
    if category:
        query = query.where(Story.category == category)
    if featured:
        query = query.where(Story.is_featured == True)  # noqa: E712 - SQL expression
    # Story.id is a unique tie-breaker. Without it, stories with equal
    # (is_featured, play_count) have no defined order, so LIMIT/OFFSET pages
    # can repeat or skip rows between requests.
    query = (
        query.order_by(Story.is_featured.desc(), Story.play_count.desc(), Story.id.desc())
        .limit(limit)
        .offset(offset)
    )
    result = await db.execute(query)
    stories = result.scalars().all()
    data = [{
        "id": s.id,
        "title": s.title,
        "cover_url": s.cover_url,
        "description": s.description,
        "category": s.category,
        "play_count": s.play_count,
        "like_count": s.like_count,
        "is_featured": s.is_featured
    } for s in stories]
    return {"code": 0, "data": data}
@router.get("/hot")
async def get_hot_stories(
    limit: int = Query(10, ge=1, le=50),
    db: AsyncSession = Depends(get_db)
):
    """Return the most-played published stories.

    Returns:
        ``{"code": 0, "data": [...]}``: at most ``limit`` story summaries with
        status == 1, ordered by play_count descending.
    """
    stmt = (
        select(Story)
        .where(Story.status == 1)
        .order_by(Story.play_count.desc())
        .limit(limit)
    )
    rows = (await db.execute(stmt)).scalars().all()

    def _summary(story):
        # Public card fields of a story (no is_featured flag on this endpoint).
        return {
            "id": story.id,
            "title": story.title,
            "cover_url": story.cover_url,
            "description": story.description,
            "category": story.category,
            "play_count": story.play_count,
            "like_count": story.like_count,
        }

    return {"code": 0, "data": [_summary(story) for story in rows]}
@router.get("/categories")
async def get_categories(db: AsyncSession = Depends(get_db)):
    """Return the distinct categories of published stories (status == 1).

    NULL and empty categories are excluded. They cannot be used as the
    ``category`` filter of the list endpoint, which treats a falsy value as
    "no filter". The result is sorted so its order is stable between calls.

    Returns:
        ``{"code": 0, "data": [category, ...]}``.
    """
    query = (
        select(distinct(Story.category))
        .where(
            Story.status == 1,
            Story.category.isnot(None),
            Story.category != "",
        )
        .order_by(Story.category)
    )
    result = await db.execute(query)
    categories = [row[0] for row in result.all()]
    return {"code": 0, "data": categories}
@router.get("/test-image-gen")
async def test_image_generation():
    """Diagnostic endpoint: check image-generation configuration, connectivity and a live call.

    Reports whether the Gemini API key is configured (with an 8-character
    preview), whether https://work.poloapi.com is reachable, and the outcome
    of one small real generation request.

    NOTE(review): this route has no auth check, reveals part of the API key
    and triggers a real generation on every GET - consider restricting it.

    Returns:
        ``{"code": 0, "message": "测试完成", "data": results}``. Failures are
        reported inside ``results``, not as HTTP errors.
    """
    from app.config import get_settings
    from app.services.image_gen import get_image_gen_service
    import httpx

    settings = get_settings()
    results = {
        "api_key_configured": bool(settings.gemini_api_key),
        "api_key_preview": settings.gemini_api_key[:8] + "..." if settings.gemini_api_key else "未配置",
        "base_url": "https://work.poloapi.com/v1beta",
        "network_test": None,
        "generate_test": None
    }

    # Network reachability: any status below 500 counts as reachable.
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get("https://work.poloapi.com")
        results["network_test"] = {
            "success": response.status_code < 500,
            "status_code": response.status_code,
            "message": "网络连接正常" if response.status_code < 500 else "服务端错误"
        }
    except Exception as e:
        results["network_test"] = {
            "success": False,
            "error": str(e),
            "message": "网络连接失败"
        }

    # Real generation test (tiny prompt), only when a key is configured.
    if settings.gemini_api_key:
        try:
            service = get_image_gen_service()
            gen_result = await service.generate_image(
                prompt="a simple red circle on white background",
                image_type="avatar",
                style="illustration"
            )
            # The service may return None/empty (other callers in this module
            # guard for it). Report that as a failed generation instead of
            # crashing on .get() with a misleading AttributeError.
            if not gen_result:
                gen_result = {"success": False, "error": "生成服务无返回结果"}
            succeeded = gen_result.get("success", False)
            results["generate_test"] = {
                "success": succeeded,
                "error": gen_result.get("error") if not succeeded else None,
                "has_image_data": bool(gen_result.get("image_data"))
            }
        except Exception as e:
            results["generate_test"] = {
                "success": False,
                "error": str(e)
            }
    else:
        results["generate_test"] = {
            "success": False,
            "error": "API Key 未配置"
        }

    return {
        "code": 0,
        "message": "测试完成",
        "data": results
    }
@router.get("/test-deepseek")
async def test_deepseek():
    """Diagnostic endpoint: check DeepSeek configuration, connectivity and a live chat call.

    Reports the AI settings (with a masked key preview), whether
    https://api.deepseek.com is reachable, and the result of one tiny
    ``/chat/completions`` request (max_tokens=10).

    NOTE(review): this route has no auth check and reveals the first 8 and
    last 4 characters of the key - consider restricting it.

    Returns:
        ``{"code": 0, "message": "测试完成", "data": results}``. Failures are
        reported inside ``results``, not as HTTP errors.
    """
    from app.config import get_settings
    import httpx

    settings = get_settings()
    results = {
        "ai_service_enabled": settings.ai_service_enabled,
        "provider": settings.ai_provider,
        "api_key_configured": bool(settings.deepseek_api_key),
        "api_key_preview": settings.deepseek_api_key[:8] + "..." + settings.deepseek_api_key[-4:] if settings.deepseek_api_key and len(settings.deepseek_api_key) > 12 else "未配置或太短",
        "base_url": settings.deepseek_base_url,
        "model": settings.deepseek_model,
        "network_test": None,
        "api_test": None
    }

    # Network reachability: any status below 500 counts as reachable.
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get("https://api.deepseek.com")
        reachable = response.status_code < 500
        results["network_test"] = {
            "success": reachable,
            "status_code": response.status_code,
            # Report 5xx as a server error (as /test-image-gen does) instead of
            # always claiming the connection is fine.
            "message": "网络连接正常" if reachable else "服务端错误"
        }
    except Exception as e:
        results["network_test"] = {
            "success": False,
            "error": str(e),
            "message": "网络连接失败"
        }

    # Live API call, only when a key is configured.
    if settings.deepseek_api_key:
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{settings.deepseek_base_url}/chat/completions",
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {settings.deepseek_api_key}"
                    },
                    json={
                        "model": settings.deepseek_model,
                        "messages": [{"role": "user", "content": "'测试成功'两个字"}],
                        "max_tokens": 10
                    }
                )
            if response.status_code == 200:
                data = response.json()
                # Tolerate an empty `choices` list or a null message/content,
                # which would otherwise raise and mislabel a successful call.
                choices = data.get("choices") or [{}]
                message = choices[0].get("message") or {}
                content = message.get("content") or ""
                results["api_test"] = {
                    "success": True,
                    "status_code": 200,
                    "response": content[:50]
                }
            else:
                results["api_test"] = {
                    "success": False,
                    "status_code": response.status_code,
                    "error": response.text[:200]
                }
        except Exception as e:
            results["api_test"] = {
                "success": False,
                "error": str(e)
            }
    else:
        results["api_test"] = {
            "success": False,
            "error": "API Key 未配置"
        }

    return {
        "code": 0,
        "message": "测试完成",
        "data": results
    }
@router.get("/test-cloud-upload")
async def test_cloud_upload():
    """Diagnostic endpoint: check that a cloud-storage upload link can be obtained.

    Only meaningful inside WeChat cloud hosting, detected through the
    ``TCB_ENV`` or ``CBR_ENV_ID`` environment variables. It asks
    ``http://api.weixin.qq.com/tcb/uploadfile`` for an upload link for
    ``test/cloud_upload_test.txt`` and reports the raw outcome.

    Returns:
        ``{"code": 0 | 1, "message": ..., "data": results}``. ``code`` is 0
        only when the link was obtained (errcode 0).
    """
    import os
    import httpx

    env_id = os.environ.get("TCB_ENV") or os.environ.get("CBR_ENV_ID")
    results = {
        "env_id": env_id,
        "is_cloud": bool(env_id)
    }

    def _reply(code, message):
        # Every exit path shares this envelope around the same `results` dict.
        return {"code": code, "message": message, "data": results}

    if not env_id:
        return _reply(1, "非云托管环境,无法测试云存储上传")

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                "http://api.weixin.qq.com/tcb/uploadfile",
                json={"env": env_id, "path": "test/cloud_upload_test.txt"},
                headers={"Content-Type": "application/json"}
            )
        results["status_code"] = resp.status_code
        results["response"] = resp.text[:500] if resp.text else ""

        if resp.status_code != 200:
            return _reply(1, f"请求失败: HTTP {resp.status_code}")

        payload = resp.json()
        if payload.get("errcode", 0) != 0:
            results["errcode"] = payload.get("errcode")
            results["errmsg"] = payload.get("errmsg")
            return _reply(1, f"获取上传链接失败: {payload.get('errmsg')}")

        results["upload_url"] = payload.get("url", "")[:100]
        results["file_id"] = payload.get("file_id", "")
        return _reply(0, "云存储上传链接获取成功")
    except Exception as exc:
        results["error"] = str(exc)
        return _reply(1, f"测试异常: {str(exc)}")
@router.get("/{story_id}")
async def get_story_detail(story_id: int, db: AsyncSession = Depends(get_db)):
    """Return a published story (status == 1) with all its nodes and choices.

    ``data["nodes"]`` maps ``node_key`` to the node's fields plus a
    ``choices`` list of ``{"text", "nextNodeKey", "isLocked"}``. Nodes and
    choices are both read in ``sort_order``.

    Raises:
        HTTPException 404: no published story with this id.
    """
    result = await db.execute(select(Story).where(Story.id == story_id, Story.status == 1))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    nodes_result = await db.execute(
        select(StoryNode).where(StoryNode.story_id == story_id).order_by(StoryNode.sort_order)
    )
    nodes = nodes_result.scalars().all()

    choices_result = await db.execute(
        select(StoryChoice).where(StoryChoice.story_id == story_id).order_by(StoryChoice.sort_order)
    )
    choices = choices_result.scalars().all()

    nodes_map = {}
    for node in nodes:
        nodes_map[node.node_key] = {
            "id": node.id,
            "node_key": node.node_key,
            "content": node.content,
            "speaker": node.speaker,
            "background_image": node.background_image,
            "character_image": node.character_image,
            "bgm": node.bgm,
            "is_ending": node.is_ending,
            "ending_name": node.ending_name,
            "ending_score": node.ending_score,
            "ending_type": node.ending_type,
            "choices": []
        }

    # Choices reference their node by primary key. An id -> node_key index
    # attaches them in O(nodes + choices) instead of scanning every node for
    # every choice.
    key_by_node_id = {node.id: node.node_key for node in nodes}
    for choice in choices:
        if choice.node_id not in key_by_node_id:
            continue  # choice whose node_id matches no node of this story
        nodes_map[key_by_node_id[choice.node_id]]["choices"].append({
            "text": choice.text,
            "nextNodeKey": choice.next_node_key,
            "isLocked": choice.is_locked
        })

    data = {
        "id": story.id,
        "title": story.title,
        "cover_url": story.cover_url,
        "description": story.description,
        "category": story.category,
        "author_id": story.author_id,
        "play_count": story.play_count,
        "like_count": story.like_count,
        "is_featured": story.is_featured,
        "nodes": nodes_map
    }
    return {"code": 0, "data": data}
@router.post("/{story_id}/play")
async def record_play(story_id: int, db: AsyncSession = Depends(get_db)):
    """Increment the play counter of a story.

    The increment is expressed in SQL (``play_count = play_count + 1``), so
    it does not read the current value first. No existence check is made:
    an unknown ``story_id`` still gets a success response.
    """
    bump = (
        update(Story)
        .where(Story.id == story_id)
        .values(play_count=Story.play_count + 1)
    )
    await db.execute(bump)
    await db.commit()
    return {"code": 0, "message": "记录成功"}
@router.post("/{story_id}/like")
async def toggle_like(story_id: int, request: LikeRequest, db: AsyncSession = Depends(get_db)):
    """Like (+1) or unlike (-1) a story.

    The counter is adjusted in SQL. An unlike never takes ``like_count``
    below zero, because the decrement only applies to rows whose count is
    still positive.

    NOTE(review): likes are not tracked per user, so repeated calls keep
    changing the counter.
    """
    delta = 1 if request.like else -1
    stmt = update(Story).where(Story.id == story_id)
    if delta < 0:
        # Duplicate or stale unlike requests must not produce a negative count.
        stmt = stmt.where(Story.like_count > 0)
    await db.execute(stmt.values(like_count=Story.like_count + delta))
    await db.commit()
    return {"code": 0, "message": "点赞成功" if request.like else "取消点赞成功"}
def _parse_ai_ending(content, default_name):
    """Extract ``(ending_name, content)`` from an AI reply that may contain JSON.

    Candidates are tried in order until one parses to a dict:
      1. the body of a Markdown code fence (models often wrap JSON in one),
      2. the whole reply,
      3. a flat ``{... "ending_name" ... "content" ...}`` object inside text,
      4. the widest ``{...}`` span inside text.
    Missing, null or non-string fields fall back to ``default_name`` and the
    raw reply. If nothing parses, ``(default_name, content)`` is returned.
    """
    import json
    import re

    text = content.strip()
    candidates = []
    fenced = re.search(r"```(?:json)?\s*(\{.*\})\s*```", text, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    candidates.append(text)
    flat = re.search(r'\{[^{}]*"ending_name"[^{}]*"content"[^{}]*\}', text, re.DOTALL)
    if flat:
        candidates.append(flat.group())
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if not isinstance(parsed, dict):
            continue
        name = parsed.get("ending_name")
        body = parsed.get("content")
        return (
            name if isinstance(name, str) and name else default_name,
            body if isinstance(body, str) and body else content,
        )
    return default_name, content


@router.post("/{story_id}/rewrite")
async def ai_rewrite_ending(story_id: int, request: RewriteRequest, db: AsyncSession = Depends(get_db)):
    """Rewrite a story ending with the AI service according to the user's instruction.

    The story title, category and characters are passed to
    ``ai_service.rewrite_ending`` as context. If the service returns content,
    the reply is parsed for an ``ending_name``/``content`` JSON object.
    Otherwise a local template ending is returned.

    Raises:
        HTTPException 400: empty or whitespace-only prompt.
        HTTPException 404: unknown story id.

    Returns:
        ``{"code": 0, "data": {...}}`` describing the new ending node.
    """
    if not request.prompt or not request.prompt.strip():
        raise HTTPException(status_code=400, detail="请输入改写指令")

    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    char_result = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = [
        {
            "name": c.name,
            "role_type": c.role_type,
            "gender": c.gender,
            "age_range": c.age_range,
            "appearance": c.appearance,
            "personality": c.personality
        }
        for c in char_result.scalars().all()
    ]

    from app.services.ai import ai_service
    ai_result = await ai_service.rewrite_ending(
        story_title=story.title,
        story_category=story.category or "未知",
        ending_name=request.ending_name or "未知结局",
        ending_content=request.ending_content or "",
        user_prompt=request.prompt,
        characters=characters
    )

    if ai_result and ai_result.get("content"):
        ending_name, content = _parse_ai_ending(
            ai_result["content"], f"{request.ending_name}AI改写"
        )
        return {
            "code": 0,
            "data": {
                "content": content,
                "speaker": "旁白",
                "is_ending": True,
                "ending_name": ending_name,
                "ending_type": "rewrite",
                "tokens_used": ai_result.get("tokens_used", 0)
            }
        }

    # Fallback when the AI service is unavailable: a randomly chosen template.
    templates = [
        f"根据你的愿望「{request.prompt}」,故事有了新的发展...\n\n",
        f"命运的齿轮开始转动,{request.prompt}...\n\n",
        f"在另一个平行世界里,{request.prompt}成为了现实...\n\n"
    ]
    template = random.choice(templates)
    new_content = (
        template +
        "原本的结局被改写,新的故事在这里展开。\n\n" +
        "【提示】AI服务暂时不可用这是模板内容。"
    )
    return {
        "code": 0,
        "data": {
            "content": new_content,
            "speaker": "旁白",
            "is_ending": True,
            "ending_name": f"{request.ending_name}(改写版)",
            "ending_type": "rewrite"
        }
    }
@router.post("/{story_id}/rewrite-branch")
async def ai_rewrite_branch(
    story_id: int,
    request: RewriteBranchRequest,
    db: AsyncSession = Depends(get_db)
):
    """Ask the AI to rewrite a mid-story chapter into a new branch of nodes.

    The story title, category, characters, the player's path history and the
    current node text are forwarded to ``ai_service.rewrite_branch``.

    Raises:
        HTTPException 400: empty prompt.
        HTTPException 404: unknown story id.

    Returns:
        ``{"code": 0, "data": {...}}``. When the AI yields no nodes,
        ``nodes``/``entryNodeKey`` are None and ``error`` explains why. No
        template fallback is produced.
    """
    if not request.prompt:
        raise HTTPException(status_code=400, detail="请输入改写指令")

    story = (
        await db.execute(select(Story).where(Story.id == story_id))
    ).scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    character_rows = (
        await db.execute(select(StoryCharacter).where(StoryCharacter.story_id == story_id))
    ).scalars().all()
    characters = []
    for row in character_rows:
        characters.append({
            "name": row.name,
            "role_type": row.role_type,
            "gender": row.gender,
            "age_range": row.age_range,
            "appearance": row.appearance,
            "personality": row.personality
        })

    # The AI service takes plain dicts rather than pydantic models.
    path_history = []
    for step in request.pathHistory:
        path_history.append({"nodeKey": step.nodeKey, "content": step.content, "choice": step.choice})

    from app.services.ai import ai_service
    ai_result = await ai_service.rewrite_branch(
        story_title=story.title,
        story_category=story.category or "未知",
        path_history=path_history,
        current_content=request.currentContent,
        user_prompt=request.prompt,
        characters=characters
    )

    branch_nodes = ai_result.get("nodes") if ai_result else None
    if not branch_nodes:
        return {
            "code": 0,
            "data": {
                "nodes": None,
                "entryNodeKey": None,
                "tokensUsed": 0,
                "error": "AI服务暂时不可用"
            }
        }
    return {
        "code": 0,
        "data": {
            "nodes": branch_nodes,
            "entryNodeKey": ai_result.get("entryNodeKey", "branch_1"),
            "tokensUsed": ai_result.get("tokens_used", 0)
        }
    }
@router.get("/{story_id}/images")
async def get_story_images(story_id: int, db: AsyncSession = Depends(get_db)):
    """Return every image slot of a story: cover, per-node images and character avatars.

    The story's status is not checked. Missing URLs are returned as empty
    strings, and node content is shortened to a 50-character preview.

    Raises:
        HTTPException 404: no story with this id.
    """
    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    nodes_result = await db.execute(
        select(StoryNode).where(StoryNode.story_id == story_id).order_by(StoryNode.sort_order)
    )
    nodes = nodes_result.scalars().all()

    chars_result = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = chars_result.scalars().all()

    def _preview(text):
        # NULL content would make len() raise; show it as an empty preview.
        text = text or ""
        return text[:50] + "..." if len(text) > 50 else text

    return {
        "code": 0,
        "data": {
            "storyId": story_id,
            "title": story.title,
            "coverUrl": story.cover_url or "",
            "nodes": [
                {
                    "nodeKey": n.node_key,
                    "content": _preview(n.content),
                    "backgroundImage": n.background_image or "",
                    "characterImage": n.character_image or "",
                    "isEnding": n.is_ending,
                    "endingName": n.ending_name or ""
                }
                for n in nodes
            ],
            "characters": [
                {
                    "characterId": c.id,
                    "name": c.name,
                    "roleType": c.role_type,
                    "avatarUrl": c.avatar_url or "",
                    "avatarPrompt": c.avatar_prompt or ""
                }
                for c in characters
            ]
        }
    }
@router.put("/{story_id}/images")
async def update_story_images(
    story_id: int,
    request: ImageConfigRequest,
    db: AsyncSession = Depends(get_db)
):
    """Batch-update a story's cover, node images and character avatars.

    Empty strings in the request mean "leave unchanged". Node updates are
    matched by (story_id, nodeKey) and avatar updates by
    (story_id, characterId). All changes are committed together.

    Raises:
        HTTPException 404: no story with this id.

    Returns:
        ``data`` with ``cover`` (bool) and the number of request entries for
        ``nodes``/``characters`` that matched an existing row of this story.
    """
    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    updated = {"cover": False, "nodes": 0, "characters": 0}

    if request.coverUrl:
        await db.execute(
            update(Story).where(Story.id == story_id).values(cover_url=request.coverUrl)
        )
        updated["cover"] = True

    for node_img in request.nodes:
        values = {}
        if node_img.backgroundImage:
            values["background_image"] = node_img.backgroundImage
        if node_img.characterImage:
            values["character_image"] = node_img.characterImage
        if not values:
            continue
        node_update = await db.execute(
            update(StoryNode)
            .where(StoryNode.story_id == story_id, StoryNode.node_key == node_img.nodeKey)
            .values(**values)
        )
        # Count only entries whose nodeKey actually exists in this story.
        if node_update.rowcount:
            updated["nodes"] += 1

    for char_img in request.characters:
        if not char_img.avatarUrl:
            continue
        char_update = await db.execute(
            update(StoryCharacter)
            .where(StoryCharacter.id == char_img.characterId, StoryCharacter.story_id == story_id)
            .values(avatar_url=char_img.avatarUrl)
        )
        # Count only characters that belong to this story.
        if char_update.rowcount:
            updated["characters"] += 1

    await db.commit()
    return {
        "code": 0,
        "message": "更新成功",
        "data": updated
    }
@router.post("/generate-image")
async def generate_story_image(
    request: GenerateImageRequest,
    db: AsyncSession = Depends(get_db)
):
    """Generate an image with the AI image service and optionally write it into a story.

    When ``storyId`` and ``targetField`` are given, the generated URL is stored
    in one of: ``coverUrl`` (Story.cover_url), ``backgroundImage`` /
    ``characterImage`` (StoryNode, matched by nodeKey = ``targetKey``), or
    ``avatarUrl`` (StoryCharacter, matched by id = int(``targetKey``)).

    Returns:
        ``code`` 1 when generation fails. Otherwise ``data`` has ``url``,
        ``filename`` and ``saved``, which is True only if a row was actually
        updated.
    """
    from app.services.image_gen import get_image_gen_service

    result = await get_image_gen_service().generate_and_save(
        prompt=request.prompt,
        category=request.category,
        style=request.style
    )
    if not result.get("success"):
        return {
            "code": 1,
            "message": result.get("error", "生成失败"),
            "data": None
        }
    image_url = result["url"]

    # Resolve the optional write-back target into a single UPDATE statement.
    stmt = None
    target_field = request.targetField
    target_key = request.targetKey
    if request.storyId and target_field:
        if target_field == "coverUrl":
            stmt = update(Story).where(Story.id == request.storyId).values(cover_url=image_url)
        elif target_field == "backgroundImage" and target_key:
            stmt = (
                update(StoryNode)
                .where(StoryNode.story_id == request.storyId, StoryNode.node_key == target_key)
                .values(background_image=image_url)
            )
        elif target_field == "characterImage" and target_key:
            stmt = (
                update(StoryNode)
                .where(StoryNode.story_id == request.storyId, StoryNode.node_key == target_key)
                .values(character_image=image_url)
            )
        elif target_field == "avatarUrl" and target_key:
            # A non-numeric key must not turn an already generated image into a 500.
            try:
                character_id = int(target_key)
            except ValueError:
                character_id = None
            if character_id is not None:
                stmt = (
                    update(StoryCharacter)
                    .where(StoryCharacter.story_id == request.storyId, StoryCharacter.id == character_id)
                    .values(avatar_url=image_url)
                )

    saved = False
    if stmt is not None:
        update_result = await db.execute(stmt)
        await db.commit()
        saved = bool(update_result.rowcount)

    return {
        "code": 0,
        "message": "生成成功",
        "data": {
            "url": image_url,
            "filename": result.get("filename"),
            "saved": saved
        }
    }
@router.post("/{story_id}/generate-all-images")
async def generate_all_story_images(
    story_id: int,
    style: str = "anime",
    db: AsyncSession = Depends(get_db)
):
    """Generate an avatar for every character of a story, one after another.

    Each character uses its stored ``avatar_prompt``, or a prompt built from
    its name, gender and appearance. Successful results update
    ``avatar_url``, and all updates are committed once at the end.

    Returns:
        ``code`` 1 when the story has no characters. Otherwise ``data`` lists
        the ``generated`` and ``failed`` characters.
    """
    from app.services.image_gen import get_image_gen_service
    image_service = get_image_gen_service()

    result = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = result.scalars().all()
    if not characters:
        return {"code": 1, "message": "故事没有角色数据", "data": None}

    generated = []
    failed = []
    for char in characters:
        # Build the fallback prompt only from fields that are set, so a missing
        # gender/appearance does not reach the model as the literal text "None".
        prompt = char.avatar_prompt or ", ".join(
            f"{part}" for part in (char.name, char.gender, char.appearance) if part
        )
        gen_result = await image_service.generate_and_save(
            prompt=prompt,
            category="character",
            style=style
        )
        if gen_result.get("success"):
            await db.execute(
                update(StoryCharacter)
                .where(StoryCharacter.id == char.id)
                .values(avatar_url=gen_result["url"])
            )
            generated.append({"id": char.id, "name": char.name, "url": gen_result["url"]})
        else:
            failed.append({"id": char.id, "name": char.name, "error": gen_result.get("error")})

    await db.commit()
    return {
        "code": 0,
        "message": f"生成完成: {len(generated)}成功, {len(failed)}失败",
        "data": {
            "generated": generated,
            "failed": failed
        }
    }
# ========== AI创作全新故事 ==========
@router.post("/ai-create")
async def ai_create_story(
    request: AICreateStoryRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Start an asynchronous AI story-creation job.

    Only a pending StoryDraft is created here, pointing at the placeholder
    story from ``get_or_create_virtual_story``. No Story row is created. The
    generation itself runs in ``process_ai_create_story`` as a background
    task, which stores the full result in the draft's ``ai_nodes``.

    Raises:
        HTTPException 404: the user does not exist.

    Returns:
        ``{"code": 0, "data": {"draftId": ..., "message": ...}}``. Progress is
        available from ``GET /ai-create/{draft_id}/status``.
    """
    from app.models.user import User

    owner = (
        await db.execute(select(User).where(User.id == request.userId))
    ).scalar_one_or_none()
    if not owner:
        raise HTTPException(status_code=404, detail="用户不存在")

    # StoryDraft.story_id needs a real Story row, so drafts share a placeholder.
    placeholder = await get_or_create_virtual_story(db)

    prompt_summary = (
        f"题材:{request.genre}, 关键词:{request.keywords}, "
        f"主角:{request.protagonist or ''}, 冲突:{request.conflict or ''}"
    )
    draft = StoryDraft(
        user_id=request.userId,
        story_id=placeholder.id,
        title="AI创作中...",
        user_prompt=prompt_summary,
        draft_type="create",
        status=DraftStatus.pending
    )
    db.add(draft)
    await db.commit()
    await db.refresh(draft)

    task_args = (
        draft.id,
        request.userId,
        request.genre,
        request.keywords,
        request.protagonist,
        request.conflict,
    )
    background_tasks.add_task(process_ai_create_story, *task_args)

    return {
        "code": 0,
        "data": {
            "draftId": draft.id,
            "message": "故事创作已开始,完成后将保存到草稿箱"
        }
    }
async def get_or_create_virtual_story(db: AsyncSession) -> Story:
    """Return the placeholder Story that AI-created drafts reference, creating it if missing.

    StoryDraft rows need a ``story_id``, but a brand-new AI story has no Story
    yet, so every such draft points at this placeholder. It has status -99,
    which the list/detail endpoints in this module never return because they
    filter on status == 1.

    The lookup is restricted to status -99 and takes the oldest match. This
    tolerates duplicates (two concurrent first calls can each insert one),
    which would otherwise make ``scalar_one_or_none()`` raise
    MultipleResultsFound on every later call. It also ignores user stories
    that happen to use the same title.
    """
    placeholder_title = "[系统] AI创作占位故事"
    result = await db.execute(
        select(Story)
        .where(Story.title == placeholder_title, Story.status == -99)
        .order_by(Story.id)
        .limit(1)
    )
    virtual_story = result.scalars().first()
    if virtual_story is None:
        virtual_story = Story(
            title=placeholder_title,
            description="此故事仅用于AI创作功能的外键占位不可游玩",
            category="系统",
            status=-99,  # special status: excluded from the status == 1 listings
            cover_url="",
            author_id=1  # NOTE(review): assumes user id 1 is the system account - confirm
        )
        db.add(virtual_story)
        await db.commit()
        await db.refresh(virtual_story)
    return virtual_story
@router.get("/ai-create/{draft_id}/status")
async def get_ai_create_status(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Report the progress of an AI story-creation job, looked up by draft id.

    ``status`` is -1 for failed, 1 for completed and 0 otherwise (still
    running). ``errorMessage`` is only filled for failed drafts.

    Raises:
        HTTPException 404: no draft with this id.
    """
    draft = (
        await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
    ).scalar_one_or_none()
    if not draft:
        raise HTTPException(status_code=404, detail="草稿不存在")

    is_completed = draft.status == DraftStatus.completed
    is_failed = draft.status == DraftStatus.failed
    if is_failed:
        status_code = -1
    elif is_completed:
        status_code = 1
    else:
        status_code = 0

    payload = {
        "draftId": draft.id,
        "status": status_code,
        "title": draft.title,
        "isCompleted": is_completed,
        "isFailed": is_failed,
        "errorMessage": draft.error_message if is_failed else None
    }
    return {"code": 0, "data": payload}
@router.post("/ai-create/{draft_id}/publish")
async def publish_ai_created_story(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Publish a completed AI-created draft to the user's works list.

    Publishing only sets ``published_to_center`` on the draft.

    Raises:
        HTTPException 404: no draft with this id.
        HTTPException 400: the draft is not completed, or is already published.
    """
    draft = (
        await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
    ).scalar_one_or_none()
    if not draft:
        raise HTTPException(status_code=404, detail="草稿不存在")

    rejection = None
    if draft.status != DraftStatus.completed:
        rejection = "草稿尚未完成或已失败"
    elif draft.published_to_center:
        rejection = "草稿已发布"
    if rejection:
        raise HTTPException(status_code=400, detail=rejection)

    draft.published_to_center = True
    await db.commit()

    payload = {
        "draftId": draft.id,
        "title": draft.title,
        "message": "发布成功!可在'我的作品'中查看"
    }
    return {"code": 0, "data": payload}
async def generate_draft_cover(
    story_id: int,
    draft_id: int,
    title: str,
    description: str,
    category: str
) -> Optional[str]:
    """Generate a cover image for an AI-created draft and store it.

    In cloud hosting (``TCB_ENV`` or ``CBR_ENV_ID`` set) the image is uploaded
    with ``app.routers.drafts.upload_to_cloud_storage``. Otherwise it is
    written under ``settings.upload_path`` on the local filesystem.

    Returns:
        ``/uploads/stories/{story_id}/drafts/{draft_id}/cover.jpg`` on success,
        or None when generation, decoding or storage fails. Failures are
        logged, not raised.
    """
    from app.services.image_gen import ImageGenService
    from app.config import get_settings
    import os
    import base64

    settings = get_settings()
    service = ImageGenService()
    is_cloud = os.environ.get('TCB_ENV') or os.environ.get('CBR_ENV_ID')

    cover_prompt = f"Book cover for {category} story titled '{title}'. {description[:100] if description else ''}. Vertical cover image, anime style, vibrant colors, eye-catching design, high quality illustration."
    print(f"[generate_draft_cover] 生成封面图: {title}")
    result = await service.generate_image(cover_prompt, "cover", "anime")
    if not result or not result.get("success"):
        print(f"[generate_draft_cover] 封面图生成失败: {result.get('error') if result else 'Unknown'}")
        return None

    # A "successful" result without valid base64 data must not raise here.
    try:
        image_bytes = base64.b64decode(result.get("image_data") or "", validate=False)
    except ValueError as decode_e:  # binascii.Error subclasses ValueError
        print(f"[generate_draft_cover] 封面图数据解码失败: {decode_e}")
        return None
    if not image_bytes:
        print("[generate_draft_cover] 封面图数据为空")
        return None

    # NOTE(review): the URL always starts with "uploads/", while local files go
    # under settings.upload_path - assumes the two match; confirm.
    cover_path = f"uploads/stories/{story_id}/drafts/{draft_id}/cover.jpg"

    if is_cloud:
        try:
            from app.routers.drafts import upload_to_cloud_storage
            await upload_to_cloud_storage(image_bytes, cover_path)
            print(f"[generate_draft_cover] ✓ 云端封面图上传成功")
            return f"/{cover_path}"
        except Exception as e:
            print(f"[generate_draft_cover] 云端上传失败: {e}")
            return None

    # Local environment: write to the filesystem. A failure returns None,
    # matching the cloud branch above.
    try:
        base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
        full_path = os.path.join(base_dir, "stories", str(story_id), "drafts", str(draft_id), "cover.jpg")
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        with open(full_path, "wb") as f:
            f.write(image_bytes)
    except OSError as e:
        print(f"[generate_draft_cover] 本地保存失败: {e}")
        return None
    print(f"[generate_draft_cover] ✓ 本地封面图保存成功")
    return f"/{cover_path}"
async def _mark_ai_create_failed(db, draft_id: int, message: str) -> None:
    """Best-effort: mark a draft as failed after an unexpected error in the AI-create task.

    The session is rolled back first, because an exception raised during a
    flush or commit leaves it unusable until rollback. The status is then
    written with a plain UPDATE, which works even if the draft row was never
    loaded. Errors here are logged and swallowed so the background task
    itself never raises.
    """
    try:
        await db.rollback()
        await db.execute(
            update(StoryDraft)
            .where(StoryDraft.id == draft_id)
            .values(status=DraftStatus.failed, error_message=message)
            .execution_options(synchronize_session=False)
        )
        await db.commit()
    except Exception as mark_error:
        print(f"[process_ai_create_story] 标记草稿失败状态时出错: {mark_error}")


async def process_ai_create_story(
    draft_id: int,
    user_id: int,
    genre: str,
    keywords: str,
    protagonist: Optional[str] = None,
    conflict: Optional[str] = None
):
    """Background job: create a story with the AI and save it into the draft's ``ai_nodes``.

    Steps:
      1. Load the draft; stop silently if it no longer exists.
      2. Call ``ai_service.create_story``. An empty result marks the draft failed.
      3. Best-effort cover and node-background generation. Failures are logged
         and do not affect the outcome.
      4. Store the full AI result (title, description, characters, nodes,
         startNodeKey, optional coverUrl) and mark the draft completed.
    Any unexpected exception marks the draft failed with a truncated message.
    """
    from app.database import async_session_factory
    from app.services.ai import ai_service

    print(f"\n[process_ai_create_story] ========== 开始创作 ==========")
    print(f"[process_ai_create_story] draft_id={draft_id}, user_id={user_id}")
    print(f"[process_ai_create_story] genre={genre}, keywords={keywords}")

    async with async_session_factory() as db:
        try:
            draft_result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
            draft = draft_result.scalar_one_or_none()
            if not draft:
                print(f"[process_ai_create_story] 草稿不存在")
                return

            print(f"[process_ai_create_story] 开始调用AI服务...")
            ai_result = await ai_service.create_story(
                genre=genre,
                keywords=keywords,
                protagonist=protagonist,
                conflict=conflict,
                user_id=user_id
            )
            if not ai_result:
                print(f"[process_ai_create_story] AI创作失败")
                draft.status = DraftStatus.failed
                draft.error_message = "AI创作失败"
                await db.commit()
                return

            print(f"[process_ai_create_story] AI创作成功开始生成配图...")
            # `or` also covers keys that are present but null in the AI output.
            # A null `nodes` would otherwise make len() fail after the final
            # commit and flip a completed draft to failed.
            story_nodes = ai_result.get("nodes") or {}
            story_category = ai_result.get("category", genre)
            story_title = ai_result.get("title", "未命名故事")
            story_description = ai_result.get("description", "")
            story_id = draft.story_id

            # Cover image (best-effort).
            try:
                cover_url = await generate_draft_cover(
                    story_id=story_id,
                    draft_id=draft_id,
                    title=story_title,
                    description=story_description,
                    category=story_category
                )
                if cover_url:
                    ai_result["coverUrl"] = cover_url
                    print(f"[process_ai_create_story] 封面图生成成功: {cover_url}")
            except Exception as cover_e:
                print(f"[process_ai_create_story] 封面图生成失败: {cover_e}")

            # Node background images (best-effort).
            if story_nodes:
                try:
                    from app.routers.drafts import generate_draft_images
                    await generate_draft_images(
                        story_id=story_id,
                        draft_id=draft_id,
                        ai_nodes=story_nodes,
                        story_category=story_category
                    )
                    print(f"[process_ai_create_story] 配图生成完成")
                except Exception as img_e:
                    print(f"[process_ai_create_story] 配图生成失败(不影响创作结果): {img_e}")

            print(f"[process_ai_create_story] 保存到草稿...")
            draft.title = story_title
            draft.ai_nodes = ai_result  # full AI result: title, description, characters, nodes, startNodeKey
            draft.entry_node_key = ai_result.get("startNodeKey", "start")
            draft.status = DraftStatus.completed
            await db.commit()

            # Log from locals, not ORM attributes: those may be expired after
            # commit, and an async lazy-load here would raise.
            print(f"[process_ai_create_story] ========== 创作完成(已保存到草稿箱) ==========")
            print(f"[process_ai_create_story] 故事标题: {story_title}")
            print(f"[process_ai_create_story] 节点数量: {len(story_nodes)}")
        except Exception as e:
            print(f"[process_ai_create_story] 异常: {e}")
            import traceback
            traceback.print_exc()
            await _mark_ai_create_failed(db, draft_id, str(e)[:200])
async def generate_story_images(story_id: int, ai_result: dict, genre: str):
    """Generate and store cover, avatar and background images for a story.

    Images are written under ``{settings.upload_path}/stories/{story_id}/``
    (resolved relative to two directories above this file), and the matching
    URL columns are set:
      1. cover -> ``cover/cover.jpg``, Story.cover_url;
      2. one avatar per StoryCharacter row of the story ->
         ``characters/{id}.jpg``, StoryCharacter.avatar_url;
      3. one background per entry of ``ai_result["nodes"]`` ->
         ``nodes/{node_key}/background.jpg``, StoryNode.background_image.
    The job pauses 1 second between image requests, and failed images are
    logged and skipped. All DB changes are committed once at the end, so an
    exception aborts the remaining work and no DB change is kept (files
    already written stay on disk).
    """
    from app.database import async_session_factory
    from app.services.image_gen import ImageGenService
    from app.config import get_settings
    import os
    import base64

    print(f"\n[generate_story_images] 开始为故事 {story_id} 生成图片")
    settings = get_settings()
    base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
    story_dir = os.path.join(base_dir, "stories", str(story_id))
    service = ImageGenService()

    async with async_session_factory() as db:
        try:
            # 1. Cover image
            print(f"[generate_story_images] 生成封面图...")
            # `or ""` also covers keys that are present but null in the AI output.
            title = ai_result.get("title") or ""
            description = ai_result.get("description") or ""
            cover_prompt = f"Book cover for {genre} story titled '{title}'. {description[:100]}. Vertical cover image, anime style, vibrant colors, eye-catching design."
            cover_result = await service.generate_image(cover_prompt, "cover", "anime")
            if cover_result and cover_result.get("success"):
                cover_dir = os.path.join(story_dir, "cover")
                os.makedirs(cover_dir, exist_ok=True)
                cover_path = os.path.join(cover_dir, "cover.jpg")
                with open(cover_path, "wb") as f:
                    f.write(base64.b64decode(cover_result["image_data"]))
                await db.execute(
                    update(Story)
                    .where(Story.id == story_id)
                    .values(cover_url=f"/uploads/stories/{story_id}/cover/cover.jpg")
                )
                print(f" ✓ 封面图生成成功")
            else:
                print(f" ✗ 封面图生成失败")
            await asyncio.sleep(1)

            # 2. Character avatars (from the DB rows of this story)
            print(f"[generate_story_images] 生成角色头像...")
            char_result = await db.execute(
                select(StoryCharacter).where(StoryCharacter.story_id == story_id)
            )
            db_characters = char_result.scalars().all()
            char_dir = os.path.join(story_dir, "characters")
            os.makedirs(char_dir, exist_ok=True)
            for db_char in db_characters:
                appearance = db_char.appearance or ""
                avatar_prompt = f"Character portrait: {db_char.name}, {db_char.gender}, {appearance}. Anime style avatar, head and shoulders, clear face, high quality."
                avatar_result = await service.generate_image(avatar_prompt, "avatar", "anime")
                if avatar_result and avatar_result.get("success"):
                    avatar_path = os.path.join(char_dir, f"{db_char.id}.jpg")
                    with open(avatar_path, "wb") as f:
                        f.write(base64.b64decode(avatar_result["image_data"]))
                    await db.execute(
                        update(StoryCharacter)
                        .where(StoryCharacter.id == db_char.id)
                        .values(avatar_url=f"/uploads/stories/{story_id}/characters/{db_char.id}.jpg")
                    )
                    print(f" ✓ 角色 {db_char.name} 头像生成成功")
                else:
                    print(f" ✗ 角色 {db_char.name} 头像生成失败")
                await asyncio.sleep(1)

            # 3. Node background images
            print(f"[generate_story_images] 生成节点背景图...")
            nodes_data = ai_result.get("nodes") or {}
            nodes_dir = os.path.join(story_dir, "nodes")
            for node_key, node_data in nodes_data.items():
                # node_key comes from AI output and becomes a directory name.
                # Reject anything that is not a single plain path component
                # so a key like "../x" cannot write outside nodes_dir.
                if not node_key or node_key in (".", "..") or os.path.basename(node_key) != node_key:
                    print(f" ✗ 节点 {node_key} 名称非法, 跳过")
                    continue
                content = (node_data.get("content") or "")[:150]
                bg_prompt = f"Background scene for {genre} story. Scene: {content}. Wide shot, atmospheric, no characters, anime style, vivid colors."
                bg_result = await service.generate_image(bg_prompt, "background", "anime")
                if bg_result and bg_result.get("success"):
                    node_dir = os.path.join(nodes_dir, node_key)
                    os.makedirs(node_dir, exist_ok=True)
                    bg_path = os.path.join(node_dir, "background.jpg")
                    with open(bg_path, "wb") as f:
                        f.write(base64.b64decode(bg_result["image_data"]))
                    await db.execute(
                        update(StoryNode)
                        .where(StoryNode.story_id == story_id)
                        .where(StoryNode.node_key == node_key)
                        .values(background_image=f"/uploads/stories/{story_id}/nodes/{node_key}/background.jpg")
                    )
                    print(f" ✓ 节点 {node_key} 背景图生成成功")
                else:
                    print(f" ✗ 节点 {node_key} 背景图生成失败")
                await asyncio.sleep(1)

            await db.commit()
            print(f"[generate_story_images] 图片生成完成")
        except Exception as e:
            print(f"[generate_story_images] 异常: {e}")
            import traceback
            traceback.print_exc()