feat: AI改写面板样式优化、封面图片显示、max_tokens调整至8192
This commit is contained in:
@@ -2,14 +2,15 @@
|
||||
故事相关API路由
|
||||
"""
|
||||
import random
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, func, distinct
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.story import Story, StoryNode, StoryChoice, StoryCharacter
|
||||
from app.models.story import Story, StoryNode, StoryChoice, StoryCharacter, StoryDraft, DraftStatus
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -65,6 +66,15 @@ class GenerateImageRequest(BaseModel):
|
||||
targetKey: Optional[str] = None # nodeKey 或 characterId
|
||||
|
||||
|
||||
class AICreateStoryRequest(BaseModel):
|
||||
"""AI创作全新故事请求"""
|
||||
userId: int
|
||||
genre: str # 题材
|
||||
keywords: str # 关键词
|
||||
protagonist: Optional[str] = None # 主角设定
|
||||
conflict: Optional[str] = None # 核心冲突
|
||||
|
||||
|
||||
# ========== API接口 ==========
|
||||
|
||||
@router.get("")
|
||||
@@ -837,4 +847,437 @@ async def generate_all_story_images(
|
||||
"generated": generated,
|
||||
"failed": failed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ========== AI创作全新故事 ==========
|
||||
|
||||
@router.post("/ai-create")
|
||||
async def ai_create_story(
|
||||
request: AICreateStoryRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""AI创作全新故事(异步处理)- 只存储到 story_drafts,不创建 Story"""
|
||||
from app.models.user import User
|
||||
|
||||
# 验证用户
|
||||
user_result = await db.execute(select(User).where(User.id == request.userId))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 获取或创建虚拟故事(用于满足 story_id 外键约束)
|
||||
virtual_story = await get_or_create_virtual_story(db)
|
||||
|
||||
# 创建草稿记录(完整故事内容将存储在 ai_nodes 中)
|
||||
draft = StoryDraft(
|
||||
user_id=request.userId,
|
||||
story_id=virtual_story.id, # 使用虚拟故事ID
|
||||
title="AI创作中...",
|
||||
user_prompt=f"题材:{request.genre}, 关键词:{request.keywords}, 主角:{request.protagonist or '无'}, 冲突:{request.conflict or '无'}",
|
||||
draft_type="create",
|
||||
status=DraftStatus.pending
|
||||
)
|
||||
db.add(draft)
|
||||
await db.commit()
|
||||
await db.refresh(draft)
|
||||
|
||||
# 添加后台任务
|
||||
background_tasks.add_task(
|
||||
process_ai_create_story,
|
||||
draft.id,
|
||||
request.userId,
|
||||
request.genre,
|
||||
request.keywords,
|
||||
request.protagonist,
|
||||
request.conflict
|
||||
)
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"draftId": draft.id,
|
||||
"message": "故事创作已开始,完成后将保存到草稿箱"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def get_or_create_virtual_story(db: AsyncSession) -> Story:
|
||||
"""获取或创建用于AI创作的虚拟故事(满足外键约束)"""
|
||||
# 查找已存在的虚拟故事
|
||||
result = await db.execute(
|
||||
select(Story).where(Story.title == "[系统] AI创作占位故事")
|
||||
)
|
||||
virtual_story = result.scalar_one_or_none()
|
||||
|
||||
if not virtual_story:
|
||||
# 创建虚拟故事
|
||||
virtual_story = Story(
|
||||
title="[系统] AI创作占位故事",
|
||||
description="此故事仅用于AI创作功能的外键占位,不可游玩",
|
||||
category="系统",
|
||||
status=-99, # 特殊状态,不会出现在任何列表中
|
||||
cover_url="",
|
||||
author_id=1 # 系统用户
|
||||
)
|
||||
db.add(virtual_story)
|
||||
await db.commit()
|
||||
await db.refresh(virtual_story)
|
||||
|
||||
return virtual_story
|
||||
|
||||
|
||||
@router.get("/ai-create/{draft_id}/status")
|
||||
async def get_ai_create_status(
|
||||
draft_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取AI创作状态(通过 draft_id 查询)"""
|
||||
draft_result = await db.execute(
|
||||
select(StoryDraft).where(StoryDraft.id == draft_id)
|
||||
)
|
||||
draft = draft_result.scalar_one_or_none()
|
||||
|
||||
if not draft:
|
||||
raise HTTPException(status_code=404, detail="草稿不存在")
|
||||
|
||||
is_completed = draft.status == DraftStatus.completed
|
||||
is_failed = draft.status == DraftStatus.failed
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"draftId": draft.id,
|
||||
"status": -1 if is_failed else (1 if is_completed else 0),
|
||||
"title": draft.title,
|
||||
"isCompleted": is_completed,
|
||||
"isFailed": is_failed,
|
||||
"errorMessage": draft.error_message if is_failed else None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/ai-create/{draft_id}/publish")
|
||||
async def publish_ai_created_story(
|
||||
draft_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""发布AI创作的草稿到'我的作品'"""
|
||||
draft_result = await db.execute(
|
||||
select(StoryDraft).where(StoryDraft.id == draft_id)
|
||||
)
|
||||
draft = draft_result.scalar_one_or_none()
|
||||
|
||||
if not draft:
|
||||
raise HTTPException(status_code=404, detail="草稿不存在")
|
||||
|
||||
if draft.status != DraftStatus.completed:
|
||||
raise HTTPException(status_code=400, detail="草稿尚未完成或已失败")
|
||||
|
||||
if draft.published_to_center:
|
||||
raise HTTPException(status_code=400, detail="草稿已发布")
|
||||
|
||||
# 标记为已发布
|
||||
draft.published_to_center = True
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"draftId": draft.id,
|
||||
"title": draft.title,
|
||||
"message": "发布成功!可在'我的作品'中查看"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def generate_draft_cover(
|
||||
story_id: int,
|
||||
draft_id: int,
|
||||
title: str,
|
||||
description: str,
|
||||
category: str
|
||||
) -> str:
|
||||
"""
|
||||
为AI创作的草稿生成封面图片
|
||||
返回封面图片的URL路径
|
||||
"""
|
||||
from app.services.image_gen import ImageGenService
|
||||
from app.config import get_settings
|
||||
import os
|
||||
import base64
|
||||
|
||||
settings = get_settings()
|
||||
service = ImageGenService()
|
||||
|
||||
# 检测是否是云端环境
|
||||
is_cloud = os.environ.get('TCB_ENV') or os.environ.get('CBR_ENV_ID')
|
||||
|
||||
# 生成封面图
|
||||
cover_prompt = f"Book cover for {category} story titled '{title}'. {description[:100] if description else ''}. Vertical cover image, anime style, vibrant colors, eye-catching design, high quality illustration."
|
||||
|
||||
print(f"[generate_draft_cover] 生成封面图: {title}")
|
||||
result = await service.generate_image(cover_prompt, "cover", "anime")
|
||||
|
||||
if not result or not result.get("success"):
|
||||
print(f"[generate_draft_cover] 封面图生成失败: {result.get('error') if result else 'Unknown'}")
|
||||
return None
|
||||
|
||||
image_bytes = base64.b64decode(result["image_data"])
|
||||
cover_path = f"uploads/stories/{story_id}/drafts/{draft_id}/cover.jpg"
|
||||
|
||||
if is_cloud:
|
||||
# 云端环境:上传到云存储
|
||||
try:
|
||||
from app.routers.drafts import upload_to_cloud_storage
|
||||
await upload_to_cloud_storage(image_bytes, cover_path)
|
||||
print(f"[generate_draft_cover] ✓ 云端封面图上传成功")
|
||||
return f"/{cover_path}"
|
||||
except Exception as e:
|
||||
print(f"[generate_draft_cover] 云端上传失败: {e}")
|
||||
return None
|
||||
else:
|
||||
# 本地环境:保存到文件系统
|
||||
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
|
||||
full_path = os.path.join(base_dir, "stories", str(story_id), "drafts", str(draft_id), "cover.jpg")
|
||||
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
print(f"[generate_draft_cover] ✓ 本地封面图保存成功")
|
||||
return f"/{cover_path}"
|
||||
|
||||
|
||||
async def process_ai_create_story(
|
||||
draft_id: int,
|
||||
user_id: int,
|
||||
genre: str,
|
||||
keywords: str,
|
||||
protagonist: str = None,
|
||||
conflict: str = None
|
||||
):
|
||||
"""后台异步处理AI创作故事 - 将完整故事内容存入 ai_nodes"""
|
||||
from app.database import async_session_factory
|
||||
from app.services.ai import ai_service
|
||||
|
||||
print(f"\n[process_ai_create_story] ========== 开始创作 ==========")
|
||||
print(f"[process_ai_create_story] draft_id={draft_id}, user_id={user_id}")
|
||||
print(f"[process_ai_create_story] genre={genre}, keywords={keywords}")
|
||||
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
# 获取草稿记录
|
||||
draft_result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
|
||||
draft = draft_result.scalar_one_or_none()
|
||||
|
||||
if not draft:
|
||||
print(f"[process_ai_create_story] 草稿不存在")
|
||||
return
|
||||
|
||||
# 调用AI服务创作故事
|
||||
print(f"[process_ai_create_story] 开始调用AI服务...")
|
||||
ai_result = await ai_service.create_story(
|
||||
genre=genre,
|
||||
keywords=keywords,
|
||||
protagonist=protagonist,
|
||||
conflict=conflict,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if not ai_result:
|
||||
print(f"[process_ai_create_story] AI创作失败")
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = "AI创作失败"
|
||||
await db.commit()
|
||||
return
|
||||
|
||||
print(f"[process_ai_create_story] AI创作成功,开始生成配图...")
|
||||
|
||||
# 获取故事节点并生成背景图(失败不影响创作结果)
|
||||
story_nodes = ai_result.get("nodes", {})
|
||||
story_category = ai_result.get("category", genre)
|
||||
story_title = ai_result.get("title", "未命名故事")
|
||||
story_description = ai_result.get("description", "")
|
||||
|
||||
# 生成封面图
|
||||
try:
|
||||
cover_url = await generate_draft_cover(
|
||||
story_id=draft.story_id,
|
||||
draft_id=draft_id,
|
||||
title=story_title,
|
||||
description=story_description,
|
||||
category=story_category
|
||||
)
|
||||
if cover_url:
|
||||
ai_result["coverUrl"] = cover_url
|
||||
print(f"[process_ai_create_story] 封面图生成成功: {cover_url}")
|
||||
except Exception as cover_e:
|
||||
print(f"[process_ai_create_story] 封面图生成失败: {cover_e}")
|
||||
|
||||
# 生成节点背景图
|
||||
if story_nodes:
|
||||
try:
|
||||
from app.routers.drafts import generate_draft_images
|
||||
await generate_draft_images(
|
||||
story_id=draft.story_id,
|
||||
draft_id=draft_id,
|
||||
ai_nodes=story_nodes,
|
||||
story_category=story_category
|
||||
)
|
||||
print(f"[process_ai_create_story] 配图生成完成")
|
||||
except Exception as img_e:
|
||||
print(f"[process_ai_create_story] 配图生成失败(不影响创作结果): {img_e}")
|
||||
|
||||
print(f"[process_ai_create_story] 保存到草稿...")
|
||||
|
||||
# 将完整故事内容存入 ai_nodes(包含已生成的 background_url)
|
||||
draft.title = ai_result.get("title", "未命名故事")
|
||||
draft.ai_nodes = ai_result # 存储完整的AI结果(包含 title, description, characters, nodes, startNodeKey)
|
||||
draft.entry_node_key = ai_result.get("startNodeKey", "start")
|
||||
draft.status = DraftStatus.completed
|
||||
|
||||
await db.commit()
|
||||
|
||||
print(f"[process_ai_create_story] ========== 创作完成(已保存到草稿箱) ==========")
|
||||
print(f"[process_ai_create_story] 故事标题: {draft.title}")
|
||||
print(f"[process_ai_create_story] 节点数量: {len(story_nodes)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[process_ai_create_story] 异常: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
try:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = str(e)[:200]
|
||||
await db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
async def generate_story_images(story_id: int, ai_result: dict, genre: str):
|
||||
"""为AI创作的故事生成图片"""
|
||||
from app.database import async_session_factory
|
||||
from app.services.image_gen import ImageGenService
|
||||
from app.config import get_settings
|
||||
import os
|
||||
import base64
|
||||
|
||||
print(f"\n[generate_story_images] 开始为故事 {story_id} 生成图片")
|
||||
|
||||
settings = get_settings()
|
||||
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
|
||||
story_dir = os.path.join(base_dir, "stories", str(story_id))
|
||||
|
||||
service = ImageGenService()
|
||||
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
# 1. 生成封面图
|
||||
print(f"[generate_story_images] 生成封面图...")
|
||||
title = ai_result.get("title", "")
|
||||
description = ai_result.get("description", "")
|
||||
|
||||
cover_prompt = f"Book cover for {genre} story titled '{title}'. {description[:100]}. Vertical cover image, anime style, vibrant colors, eye-catching design."
|
||||
cover_result = await service.generate_image(cover_prompt, "cover", "anime")
|
||||
|
||||
if cover_result and cover_result.get("success"):
|
||||
cover_dir = os.path.join(story_dir, "cover")
|
||||
os.makedirs(cover_dir, exist_ok=True)
|
||||
cover_path = os.path.join(cover_dir, "cover.jpg")
|
||||
|
||||
with open(cover_path, "wb") as f:
|
||||
f.write(base64.b64decode(cover_result["image_data"]))
|
||||
|
||||
# 更新数据库
|
||||
await db.execute(
|
||||
update(Story)
|
||||
.where(Story.id == story_id)
|
||||
.values(cover_url=f"/uploads/stories/{story_id}/cover/cover.jpg")
|
||||
)
|
||||
print(f" ✓ 封面图生成成功")
|
||||
else:
|
||||
print(f" ✗ 封面图生成失败")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 2. 生成角色头像
|
||||
print(f"[generate_story_images] 生成角色头像...")
|
||||
characters = ai_result.get("characters", [])
|
||||
|
||||
char_result = await db.execute(
|
||||
select(StoryCharacter).where(StoryCharacter.story_id == story_id)
|
||||
)
|
||||
db_characters = char_result.scalars().all()
|
||||
|
||||
char_dir = os.path.join(story_dir, "characters")
|
||||
os.makedirs(char_dir, exist_ok=True)
|
||||
|
||||
for db_char in db_characters:
|
||||
# 找到对应的AI生成数据
|
||||
char_data = next((c for c in characters if c.get("name") == db_char.name), None)
|
||||
|
||||
appearance = db_char.appearance or ""
|
||||
avatar_prompt = f"Character portrait: {db_char.name}, {db_char.gender}, {appearance}. Anime style avatar, head and shoulders, clear face, high quality."
|
||||
|
||||
avatar_result = await service.generate_image(avatar_prompt, "avatar", "anime")
|
||||
|
||||
if avatar_result and avatar_result.get("success"):
|
||||
avatar_path = os.path.join(char_dir, f"{db_char.id}.jpg")
|
||||
|
||||
with open(avatar_path, "wb") as f:
|
||||
f.write(base64.b64decode(avatar_result["image_data"]))
|
||||
|
||||
await db.execute(
|
||||
update(StoryCharacter)
|
||||
.where(StoryCharacter.id == db_char.id)
|
||||
.values(avatar_url=f"/uploads/stories/{story_id}/characters/{db_char.id}.jpg")
|
||||
)
|
||||
print(f" ✓ 角色 {db_char.name} 头像生成成功")
|
||||
else:
|
||||
print(f" ✗ 角色 {db_char.name} 头像生成失败")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 3. 生成节点背景图
|
||||
print(f"[generate_story_images] 生成节点背景图...")
|
||||
nodes_data = ai_result.get("nodes", {})
|
||||
|
||||
nodes_dir = os.path.join(story_dir, "nodes")
|
||||
|
||||
for node_key, node_data in nodes_data.items():
|
||||
content = node_data.get("content", "")[:150]
|
||||
|
||||
bg_prompt = f"Background scene for {genre} story. Scene: {content}. Wide shot, atmospheric, no characters, anime style, vivid colors."
|
||||
bg_result = await service.generate_image(bg_prompt, "background", "anime")
|
||||
|
||||
if bg_result and bg_result.get("success"):
|
||||
node_dir = os.path.join(nodes_dir, node_key)
|
||||
os.makedirs(node_dir, exist_ok=True)
|
||||
bg_path = os.path.join(node_dir, "background.jpg")
|
||||
|
||||
with open(bg_path, "wb") as f:
|
||||
f.write(base64.b64decode(bg_result["image_data"]))
|
||||
|
||||
await db.execute(
|
||||
update(StoryNode)
|
||||
.where(StoryNode.story_id == story_id)
|
||||
.where(StoryNode.node_key == node_key)
|
||||
.values(background_image=f"/uploads/stories/{story_id}/nodes/{node_key}/background.jpg")
|
||||
)
|
||||
print(f" ✓ 节点 {node_key} 背景图生成成功")
|
||||
else:
|
||||
print(f" ✗ 节点 {node_key} 背景图生成失败")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await db.commit()
|
||||
print(f"[generate_story_images] 图片生成完成")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[generate_story_images] 异常: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Reference in New Issue
Block a user