Files
ai_game/server/app/routers/drafts.py

987 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
草稿箱路由 - AI异步改写功能
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.sql import func
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime
import os
import base64
from app.database import get_db
from app.models.story import Story, StoryDraft, DraftStatus, StoryCharacter
from app.config import get_settings
# Router for the draft-box feature: every route below is mounted under /drafts
# and grouped under the "草稿箱" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/drafts", tags=["草稿箱"])
# ============ Helper functions ============
async def get_story_characters(db: AsyncSession, story_id: int) -> List[dict]:
    """Load every character of a story as a list of plain dicts.

    Args:
        db: Open async session used for the query.
        story_id: Primary key of the story whose characters are fetched.

    Returns:
        One dict per ``StoryCharacter`` row, with the keys ``name``,
        ``role_type``, ``gender``, ``age_range``, ``appearance`` and
        ``personality`` (the shape passed to the AI service calls in this module).
    """
    exported_fields = ("name", "role_type", "gender", "age_range", "appearance", "personality")
    query = select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    rows = (await db.execute(query)).scalars().all()
    return [
        {field: getattr(character, field) for field in exported_fields}
        for character in rows
    ]
async def upload_to_cloud_storage(image_bytes: bytes, cloud_path: str) -> str:
    """Upload an image to WeChat cloud storage from inside a Cloud Run (云托管) container.

    Works in two steps. First it asks the intranet ``tcb/uploadfile`` API for a
    signed COS upload target. Then it POSTs the bytes to that target as a
    multipart form.

    Args:
        image_bytes: Raw JPEG bytes to upload.
        cloud_path: Object path inside cloud storage, e.g.
            ``stories/1/drafts/10/branch_1/background.jpg``.

    Returns:
        The ``file_id`` reported by the ``tcb/uploadfile`` response. This is an
        identifier, not a URL.

    Raises:
        Exception: if no cloud environment id is configured, if either HTTP call
            fails, or if the signing response has no upload URL.
    """
    import posixpath

    import httpx

    env_id = os.environ.get('TCB_ENV') or os.environ.get('CBR_ENV_ID')
    if not env_id:
        # Fall back to application settings when the container did not inject an id.
        env_id = getattr(get_settings(), 'wx_cloud_env', None)
    if not env_id:
        raise Exception("未检测到云环境ID")
    # Name the multipart part after the object instead of a fixed file name.
    file_name = posixpath.basename(cloud_path) or "background.jpg"
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            # Intranet calls from Cloud Run need no access_token.
            # Reference: https://developers.weixin.qq.com/miniprogram/dev/wxcloudrun/src/development/storage/service/upload.html
            # 1. Request a signed upload target.
            resp = await client.post(
                "http://api.weixin.qq.com/tcb/uploadfile",
                json={"env": env_id, "path": cloud_path},
                headers={"Content-Type": "application/json"},
            )
            if resp.status_code != 200:
                raise Exception(f"获取上传链接失败: {resp.status_code} - {resp.text[:200]}")
            data = resp.json()
            if data.get("errcode", 0) != 0:
                raise Exception(f"获取上传链接失败: {data.get('errmsg')}")
            upload_url = data.get("url")
            if not upload_url:
                # Without this check httpx fails later with an unhelpful URL error.
                raise Exception("获取上传链接失败: 响应缺少 url 字段")
            file_id = data.get("file_id")
            # 2. Upload the bytes to COS with the signature returned above.
            form_data = {
                "key": cloud_path,
                "Signature": data.get("authorization"),
                "x-cos-security-token": data.get("token"),
                "x-cos-meta-fileid": data.get("cos_file_id"),
            }
            files = {"file": (file_name, image_bytes, "image/jpeg")}
            upload_resp = await client.post(upload_url, data=form_data, files=files)
            if upload_resp.status_code not in (200, 204):
                raise Exception(f"上传文件失败: {upload_resp.status_code} - {upload_resp.text[:200]}")
            print(f" [CloudStorage] 文件上传成功: {file_id}")
            return file_id
    except Exception as e:
        print(f"[upload_to_cloud_storage] 上传失败: {e}")
        raise
async def generate_draft_images(story_id: int, draft_id: int, ai_nodes: dict, story_category: str):
    """Generate a background image for each AI-generated node of a draft.

    One image request is made for every node value that is a dict with
    non-empty ``content``. On success the node dict is updated **in place**
    with ``background_url =
    /uploads/stories/{story_id}/drafts/{draft_id}/{node_key}/background.jpg``.
    Where the JPEG goes depends on the environment:

    * Cloud environment (``TCB_ENV`` or ``CBR_ENV_ID`` set): it is uploaded to
      cloud storage under the same relative path.
    * Otherwise: it is written under ``settings.upload_path`` on disk.

    If one node fails, the error is logged and the loop moves on to the next node.

    Args:
        story_id: Story the draft belongs to (part of the storage path).
        draft_id: Draft whose nodes are illustrated (part of the storage path).
        ai_nodes: Mapping of node key to node dict; mutated in place.
        story_category: Inserted into the image prompt.
    """
    import asyncio

    if not ai_nodes:
        return
    # Imported only once there is work to do.
    from app.services.image_gen import ImageGenService

    settings = get_settings()
    # Cloud Run containers inject TCB_ENV / CBR_ENV_ID automatically.
    is_cloud = os.environ.get('TCB_ENV') or os.environ.get('CBR_ENV_ID')
    base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
    draft_dir = os.path.join(base_dir, "stories", str(story_id), "drafts", str(draft_id))
    service = ImageGenService()

    for node_key, node_data in ai_nodes.items():
        if not isinstance(node_data, dict):
            continue
        # `or ''` also covers an explicit None. Slicing None would raise
        # TypeError outside the per-node try and abort every remaining node.
        content = str(node_data.get('content') or '')[:150]
        if not content:
            continue
        try:
            # Background prompt that asks for strong emotional expression.
            bg_prompt = f"Background scene for {story_category} story. Scene: {content}. Wide shot, atmospheric, no characters, anime style. Strong emotional expression, dramatic mood, vivid colors reflecting the scene's emotion."
            result = await service.generate_image(bg_prompt, "background", "anime")
            if result and result.get("success"):
                image_bytes = base64.b64decode(result["image_data"])
                # The same relative path is used locally and in cloud storage.
                relative_path = f"uploads/stories/{story_id}/drafts/{draft_id}/{node_key}/background.jpg"
                if is_cloud:
                    try:
                        await upload_to_cloud_storage(image_bytes, relative_path)
                        # The frontend resolves this path against the cloud-storage CDN domain.
                        node_data['background_url'] = f"/{relative_path}"
                        print(f" ✓ 云端草稿节点 {node_key} 背景图上传成功")
                    except Exception as cloud_e:
                        print(f" ✗ 云端上传失败: {cloud_e}")
                else:
                    node_dir = os.path.join(draft_dir, node_key)
                    os.makedirs(node_dir, exist_ok=True)
                    with open(os.path.join(node_dir, "background.jpg"), "wb") as f:
                        f.write(image_bytes)
                    node_data['background_url'] = f"/{relative_path}"
                    print(f" ✓ 草稿节点 {node_key} 背景图生成成功")
            else:
                print(f" ✗ 草稿节点 {node_key} 背景图生成失败: {result.get('error') if result else 'Unknown'}")
        except Exception as e:
            print(f" ✗ 草稿节点 {node_key} 图片生成异常: {e}")
        # Throttle the image API after every attempted node, cloud uploads included.
        await asyncio.sleep(1)
# ============ Request / response models ============
class PathHistoryItem(BaseModel):
    """One step of the player's path, as sent in ``CreateDraftRequest.pathHistory``."""
    # create_draft copies these three fields verbatim into the draft's path_history.
    nodeKey: str
    content: str
    choice: str
class CreateDraftRequest(BaseModel):
    """Body of ``POST /drafts``: request an AI rewrite of the current story branch."""
    userId: int  # owner of the new draft, taken from the body as-is
    storyId: int  # story being rewritten; the route returns 404 if it does not exist
    currentNodeKey: str  # stored as the draft's current_node_key
    currentContent: str  # stored as the draft's current_content
    pathHistory: List[PathHistoryItem]  # steps played so far, stored as path_history
    prompt: str  # the user's rewrite instruction; an empty value is rejected with 400
class CreateEndingDraftRequest(BaseModel):
    """Ending-rewrite request: body of ``POST /drafts/ending``."""
    userId: int
    storyId: int
    endingName: str  # saved in the draft's current_node_key field
    endingContent: str  # saved in the draft's current_content field
    prompt: str  # the user's rewrite instruction; an empty value is rejected with 400
    pathHistory: list = []  # optional play-path history; items are not validated
class ContinueEndingDraftRequest(BaseModel):
    """Ending-continuation request: body of ``POST /drafts/continue-ending``."""
    userId: int
    storyId: int
    endingName: str  # saved in the draft's current_node_key field
    endingContent: str  # saved in the draft's current_content field
    prompt: str  # the user's continuation instruction; an empty value is rejected with 400
    pathHistory: list = []  # optional play-path history; items are not validated
# NOTE(review): the visible routes build plain dicts instead of using this model;
# confirm whether it is still referenced anywhere.
class DraftResponse(BaseModel):
    """Summary view of a draft: ids, titles, prompt, status and timestamps."""
    id: int
    storyId: int
    storyTitle: str
    title: str
    userPrompt: str
    status: str
    isRead: bool
    createdAt: str
    completedAt: Optional[str] = None

    class Config:
        # Allow construction from ORM objects via attribute access (pydantic v2 key).
        from_attributes = True
# ============ Background tasks ============
async def process_ai_rewrite(draft_id: int):
    """Background task: run the AI branch rewrite for one draft.

    Steps:
    1. Mark the draft ``processing``.
    2. Load its story and characters.
    3. Call ``ai_service.rewrite_branch``.
    4. Try to generate node images (best effort).
    5. Store the result as ``completed``, or mark the draft ``failed``.

    ``create_draft`` schedules this as a FastAPI background task, so it opens
    its own session.

    Args:
        draft_id: Primary key of the ``StoryDraft`` to process.
    """
    import traceback

    from app.database import async_session_factory
    from app.services.ai import ai_service

    async with async_session_factory() as db:
        try:
            result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
            draft = result.scalar_one_or_none()
            if not draft:
                return
            draft.status = DraftStatus.processing
            await db.commit()

            story_result = await db.execute(select(Story).where(Story.id == draft.story_id))
            story = story_result.scalar_one_or_none()
            if not story:
                draft.status = DraftStatus.failed
                draft.error_message = "故事不存在"
                draft.completed_at = datetime.now()
                await db.commit()
                return

            characters = await get_story_characters(db, story.id)
            print(f"[process_ai_rewrite] 获取到角色数: {len(characters)}")
            path_history = draft.path_history or []
            print(f"[process_ai_rewrite] 路径历史长度: {len(path_history)}")

            print("[process_ai_rewrite] 开始调用 AI 服务...")
            ai_result = await ai_service.rewrite_branch(
                story_title=story.title,
                story_category=story.category or "未知",
                path_history=path_history,
                current_content=draft.current_content or "",
                user_prompt=draft.user_prompt,
                characters=characters
            )
            print(f"[process_ai_rewrite] AI 服务返回: {bool(ai_result)}")

            if ai_result and ai_result.get("nodes"):
                # Images are optional: a failure here must not fail the rewrite.
                try:
                    print("[process_ai_rewrite] AI生成成功开始生成配图...")
                    await generate_draft_images(
                        story_id=draft.story_id,
                        draft_id=draft.id,
                        ai_nodes=ai_result["nodes"],
                        story_category=story.category or "都市言情"
                    )
                except Exception as img_e:
                    print(f"[process_ai_rewrite] 配图生成失败(不影响改写结果): {img_e}")
                draft.status = DraftStatus.completed
                draft.ai_nodes = ai_result["nodes"]
                draft.entry_node_key = ai_result.get("entryNodeKey", "branch_1")
                draft.tokens_used = ai_result.get("tokens_used", 0)
                draft.title = f"{story.title}-改写"
            else:
                draft.status = DraftStatus.failed
                draft.error_message = "AI服务暂时不可用"
            draft.completed_at = datetime.now()
            await db.commit()
        except Exception as e:
            print(f"[process_ai_rewrite] 异常: {e}")
            traceback.print_exc()
            # `draft` may be unbound here, and the session may need a rollback
            # after a failed flush. Mark the row failed by id on a clean transaction.
            try:
                await db.rollback()
                await db.execute(
                    update(StoryDraft)
                    .where(StoryDraft.id == draft_id)
                    .values(
                        status=DraftStatus.failed,
                        error_message=str(e)[:500],
                        completed_at=datetime.now(),
                    )
                    .execution_options(synchronize_session=False)
                )
                await db.commit()
            except Exception as mark_e:
                print(f"[process_ai_rewrite] 标记失败状态时出错: {mark_e}")
def _parse_rewritten_ending(raw, default_name: str) -> "tuple[str, str]":
    """Extract ``(ending_name, content)`` from the AI's rewritten-ending reply.

    The reply may be a JSON object with ``ending_name`` / ``content`` keys,
    possibly surrounded by other text such as Markdown fences, or it may be
    plain text. The outermost ``{...}`` span is tried first, then the whole
    reply. Any field that is missing or not a non-empty string falls back to
    ``default_name`` or to the raw reply.
    """
    import json

    if not isinstance(raw, str):
        raw = str(raw)
    start, end = raw.find("{"), raw.rfind("}")
    candidates = [raw[start:end + 1]] if 0 <= start < end else []
    if raw not in candidates:
        candidates.append(raw)
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if not isinstance(parsed, dict):
            continue
        name = parsed.get("ending_name")
        text = parsed.get("content")
        return (
            name if isinstance(name, str) and name else default_name,
            text if isinstance(text, str) and text else raw,
        )
    return default_name, raw


async def process_ai_rewrite_ending(draft_id: int):
    """Background task: run the AI rewrite of a story ending for one draft.

    The ending name and text are read from the draft's ``current_node_key`` and
    ``current_content`` fields, where ``create_ending_draft`` stores them. The
    rewritten ending is saved as a single ``ending_rewrite`` node in the same
    dict shape as a story node. The draft is then marked ``completed`` or ``failed``.

    Args:
        draft_id: Primary key of the ``StoryDraft`` to process.
    """
    import traceback

    from app.database import async_session_factory
    from app.services.ai import ai_service

    async with async_session_factory() as db:
        try:
            result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
            draft = result.scalar_one_or_none()
            if not draft:
                return
            draft.status = DraftStatus.processing
            await db.commit()

            story_result = await db.execute(select(Story).where(Story.id == draft.story_id))
            story = story_result.scalar_one_or_none()
            if not story:
                draft.status = DraftStatus.failed
                draft.error_message = "故事不存在"
                draft.completed_at = datetime.now()
                await db.commit()
                return

            characters = await get_story_characters(db, story.id)
            ending_name = draft.current_node_key or "未知结局"
            ending_content = draft.current_content or ""

            ai_result = await ai_service.rewrite_ending(
                story_title=story.title,
                story_category=story.category or "未知",
                ending_name=ending_name,
                ending_content=ending_content,
                user_prompt=draft.user_prompt,
                characters=characters
            )

            if ai_result and ai_result.get("content"):
                new_ending_name, content = _parse_rewritten_ending(
                    ai_result["content"], f"{ending_name}AI改写"
                )
                ai_nodes = {
                    "ending_rewrite": {
                        "content": content,
                        "speaker": "旁白",
                        "is_ending": True,
                        "ending_name": new_ending_name,
                        "ending_type": "rewrite"
                    }
                }
                # Images are optional: a failure here must not fail the rewrite.
                try:
                    print("[process_ai_rewrite_ending] AI生成成功开始生成配图...")
                    await generate_draft_images(
                        story_id=draft.story_id,
                        draft_id=draft.id,
                        ai_nodes=ai_nodes,
                        story_category=story.category or "都市言情"
                    )
                except Exception as img_e:
                    print(f"[process_ai_rewrite_ending] 配图生成失败(不影响改写结果): {img_e}")
                draft.status = DraftStatus.completed
                draft.ai_nodes = ai_nodes
                draft.entry_node_key = "ending_rewrite"
                draft.tokens_used = ai_result.get("tokens_used", 0)
                draft.title = f"{story.title}-{new_ending_name}"
            else:
                draft.status = DraftStatus.failed
                draft.error_message = "AI服务暂时不可用"
            draft.completed_at = datetime.now()
            await db.commit()
        except Exception as e:
            print(f"[process_ai_rewrite_ending] 异常: {e}")
            traceback.print_exc()
            # `draft` may be unbound here, and the session may need a rollback
            # after a failed flush. Mark the row failed by id on a clean transaction.
            try:
                await db.rollback()
                await db.execute(
                    update(StoryDraft)
                    .where(StoryDraft.id == draft_id)
                    .values(
                        status=DraftStatus.failed,
                        error_message=str(e)[:500],
                        completed_at=datetime.now(),
                    )
                    .execution_options(synchronize_session=False)
                )
                await db.commit()
            except Exception as mark_e:
                print(f"[process_ai_rewrite_ending] 标记失败状态时出错: {mark_e}")
async def process_ai_continue_ending(draft_id: int):
    """Background task: have the AI continue the story past an ending.

    The ending name and text are read from the draft's ``current_node_key`` and
    ``current_content`` fields, where ``create_continue_ending_draft`` stores
    them. The generated nodes are saved on the draft, and it is marked
    ``completed`` or ``failed``.

    Args:
        draft_id: Primary key of the ``StoryDraft`` to process.
    """
    import traceback

    from app.database import async_session_factory
    from app.services.ai import ai_service

    async with async_session_factory() as db:
        try:
            result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
            draft = result.scalar_one_or_none()
            if not draft:
                return
            draft.status = DraftStatus.processing
            await db.commit()

            story_result = await db.execute(select(Story).where(Story.id == draft.story_id))
            story = story_result.scalar_one_or_none()
            if not story:
                draft.status = DraftStatus.failed
                draft.error_message = "故事不存在"
                draft.completed_at = datetime.now()
                await db.commit()
                return

            characters = await get_story_characters(db, story.id)
            ending_name = draft.current_node_key or "未知结局"
            ending_content = draft.current_content or ""

            ai_result = await ai_service.continue_ending(
                story_title=story.title,
                story_category=story.category or "未知",
                ending_name=ending_name,
                ending_content=ending_content,
                user_prompt=draft.user_prompt,
                characters=characters
            )

            if ai_result and ai_result.get("nodes"):
                # Images are optional: a failure here must not fail the continuation.
                try:
                    print("[process_ai_continue_ending] AI生成成功开始生成配图...")
                    await generate_draft_images(
                        story_id=draft.story_id,
                        draft_id=draft.id,
                        ai_nodes=ai_result["nodes"],
                        story_category=story.category or "都市言情"
                    )
                except Exception as img_e:
                    print(f"[process_ai_continue_ending] 配图生成失败(不影响续写结果): {img_e}")
                draft.status = DraftStatus.completed
                draft.ai_nodes = ai_result["nodes"]
                draft.entry_node_key = ai_result.get("entryNodeKey", "continue_1")
                draft.tokens_used = ai_result.get("tokens_used", 0)
                draft.title = f"{story.title}-{ending_name}续写"
            else:
                draft.status = DraftStatus.failed
                draft.error_message = "AI服务暂时不可用"
            draft.completed_at = datetime.now()
            await db.commit()
        except Exception as e:
            print(f"[process_ai_continue_ending] 异常: {e}")
            traceback.print_exc()
            # `draft` may be unbound here, and the session may need a rollback
            # after a failed flush. Mark the row failed by id on a clean transaction.
            try:
                await db.rollback()
                await db.execute(
                    update(StoryDraft)
                    .where(StoryDraft.id == draft_id)
                    .values(
                        status=DraftStatus.failed,
                        error_message=str(e)[:500],
                        completed_at=datetime.now(),
                    )
                    .execution_options(synchronize_session=False)
                )
                await db.commit()
            except Exception as mark_e:
                print(f"[process_ai_continue_ending] 标记失败状态时出错: {mark_e}")
# ============ API routes ============
@router.post("")
async def create_draft(
    request: CreateDraftRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Queue an AI branch-rewrite job and return immediately.

    Creates a ``pending`` ``StoryDraft`` of type ``rewrite`` that holds the
    player's path, the current node and the user's prompt. It then schedules
    ``process_ai_rewrite`` as a background task.

    Raises:
        HTTPException(400): the prompt is empty or whitespace-only.
        HTTPException(404): the story does not exist.
    """
    # A whitespace-only prompt gives the model nothing to act on; reject it
    # before a draft is created and an AI call is spent.
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="请输入改写指令")

    result = await db.execute(select(Story).where(Story.id == request.storyId))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    path_history = [
        {"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice}
        for item in request.pathHistory
    ]
    draft = StoryDraft(
        user_id=request.userId,
        story_id=request.storyId,
        title=f"{story.title}-改写",
        path_history=path_history,
        current_node_key=request.currentNodeKey,
        current_content=request.currentContent,
        user_prompt=request.prompt,
        status=DraftStatus.pending,
        draft_type='rewrite'
    )
    db.add(draft)
    await db.commit()
    await db.refresh(draft)

    background_tasks.add_task(process_ai_rewrite, draft.id)
    return {
        "code": 0,
        "data": {
            "draftId": draft.id,
            "message": "已提交AI正在生成中..."
        }
    }
@router.post("/ending")
async def create_ending_draft(
    request: CreateEndingDraftRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Queue an AI ending-rewrite job and return immediately.

    The ending name and text are stored in the draft's ``current_node_key`` and
    ``current_content`` fields, where ``process_ai_rewrite_ending`` reads them.

    Raises:
        HTTPException(400): the prompt is empty or whitespace-only.
        HTTPException(404): the story does not exist.
    """
    # A whitespace-only prompt gives the model nothing to act on.
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="请输入改写指令")

    result = await db.execute(select(Story).where(Story.id == request.storyId))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    draft = StoryDraft(
        user_id=request.userId,
        story_id=request.storyId,
        title=f"{story.title}-结局改写",
        path_history=request.pathHistory,
        current_node_key=request.endingName,  # ending name reused as the node key
        current_content=request.endingContent,  # ending text reused as the node content
        user_prompt=request.prompt,
        status=DraftStatus.pending,
        draft_type='rewrite'
    )
    db.add(draft)
    await db.commit()
    await db.refresh(draft)

    background_tasks.add_task(process_ai_rewrite_ending, draft.id)
    return {
        "code": 0,
        "data": {
            "draftId": draft.id,
            "message": "已提交AI正在生成新结局..."
        }
    }
@router.post("/continue-ending")
async def create_continue_ending_draft(
    request: ContinueEndingDraftRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Queue an AI job that continues the story past an ending, and return immediately.

    The ending name and text are stored in the draft's ``current_node_key`` and
    ``current_content`` fields, where ``process_ai_continue_ending`` reads them.

    Raises:
        HTTPException(400): the prompt is empty or whitespace-only.
        HTTPException(404): the story does not exist.
    """
    # A whitespace-only prompt gives the model nothing to act on.
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="请输入续写指令")

    result = await db.execute(select(Story).where(Story.id == request.storyId))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    draft = StoryDraft(
        user_id=request.userId,
        story_id=request.storyId,
        title=f"{story.title}-结局续写",
        path_history=request.pathHistory,
        current_node_key=request.endingName,  # ending name reused as the node key
        current_content=request.endingContent,  # ending text reused as the node content
        user_prompt=request.prompt,
        status=DraftStatus.pending,
        draft_type='continue'
    )
    db.add(draft)
    await db.commit()
    await db.refresh(draft)

    background_tasks.add_task(process_ai_continue_ending, draft.id)
    return {
        "code": 0,
        "data": {
            "draftId": draft.id,
            "message": "已提交AI正在续写故事..."
        }
    }
@router.get("")
async def get_drafts(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """List all drafts owned by a user, newest first, with their story titles.

    Returns:
        ``{"code": 0, "data": [...]}``. Each item is a summary dict without the
        AI nodes or path history. A missing ``createdAt`` is rendered as ``""``
        and a missing ``completedAt`` as ``None``.
    """
    def stamp(moment, fallback):
        return moment.strftime("%Y-%m-%d %H:%M") if moment else fallback

    statement = (
        select(StoryDraft, Story.title.label("story_title"))
        .join(Story, StoryDraft.story_id == Story.id)
        .where(StoryDraft.user_id == userId)
        .order_by(StoryDraft.created_at.desc())
    )
    rows = (await db.execute(statement)).all()
    summaries = [
        {
            "id": item.id,
            "storyId": item.story_id,
            "storyTitle": title_of_story,
            "title": item.title,
            "userPrompt": item.user_prompt,
            "status": item.status.value if item.status else "pending",
            "isRead": item.is_read,
            "publishedToCenter": item.published_to_center,
            "draftType": item.draft_type or "rewrite",
            "createdAt": stamp(item.created_at, ""),
            "completedAt": stamp(item.completed_at, None),
        }
        for item, title_of_story in rows
    ]
    return {"code": 0, "data": summaries}
@router.get("/check-new")
async def check_new_drafts(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Report completed-but-unread drafts, for the pop-up notification.

    Returns the total number of such drafts plus up to three of them, newest first.
    """
    unread_completed = (
        StoryDraft.user_id == userId,
        StoryDraft.status == DraftStatus.completed,
        StoryDraft.is_read == False,  # noqa: E712 - SQL expression, not a Python truth test
    )
    # Count in SQL instead of loading every unread row just to take len().
    count_result = await db.execute(
        select(func.count()).select_from(StoryDraft).where(*unread_completed)
    )
    unread_count = count_result.scalar_one()
    # ORDER BY makes the preview the newest drafts rather than an arbitrary three.
    preview_result = await db.execute(
        select(StoryDraft)
        .where(*unread_completed)
        .order_by(StoryDraft.created_at.desc())
        .limit(3)
    )
    preview = preview_result.scalars().all()
    return {
        "code": 0,
        "data": {
            "hasNew": unread_count > 0,
            "count": unread_count,
            "drafts": [
                {
                    "id": d.id,
                    "title": d.title,
                    "userPrompt": d.user_prompt
                }
                for d in preview
            ]
        }
    }
@router.get("/published")
async def get_published_drafts(
    userId: int,
    draftType: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List a user's completed drafts that are published to the creation center.

    Args:
        userId: Owner of the drafts.
        draftType: When given, keep only drafts whose ``draft_type`` equals it.

    Returns:
        ``{"code": 0, "data": [...]}``, newest first.
    """
    filters = [
        StoryDraft.user_id == userId,
        StoryDraft.published_to_center == True,  # noqa: E712 - SQL expression
        StoryDraft.status == DraftStatus.completed,
    ]
    if draftType:
        filters.append(StoryDraft.draft_type == draftType)

    statement = (
        select(StoryDraft, Story.title.label('story_title'))
        .join(Story, StoryDraft.story_id == Story.id)
        .where(*filters)
        .order_by(StoryDraft.created_at.desc())
    )
    records = (await db.execute(statement)).all()

    published = []
    for entry, title_of_story in records:
        created = entry.created_at
        published.append({
            "id": entry.id,
            "storyId": entry.story_id,
            "storyTitle": title_of_story or "未知故事",
            "title": entry.title or "",
            "userPrompt": entry.user_prompt,
            "draftType": entry.draft_type or "rewrite",
            "createdAt": created.strftime("%Y-%m-%d %H:%M") if created else "",
        })
    return {"code": 0, "data": published}
@router.get("/{draft_id}")
async def get_draft_detail(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Return the full detail of one draft and mark it as read.

    Raises:
        HTTPException(404): no draft with this id joined to an existing story.
    """
    statement = (
        select(StoryDraft, Story)
        .join(Story, StoryDraft.story_id == Story.id)
        .where(StoryDraft.id == draft_id)
    )
    found = (await db.execute(statement)).first()
    if found is None:
        raise HTTPException(status_code=404, detail="草稿不存在")
    entry, parent_story = found

    # Opening the detail view counts as reading the draft.
    if not entry.is_read:
        entry.is_read = True
        await db.commit()

    def stamp(moment, fallback):
        return moment.strftime("%Y-%m-%d %H:%M") if moment else fallback

    detail = {
        "id": entry.id,
        "storyId": entry.story_id,
        "storyTitle": parent_story.title,
        "storyCategory": parent_story.category,
        "title": entry.title,
        "pathHistory": entry.path_history,
        "currentNodeKey": entry.current_node_key,
        "currentContent": entry.current_content,
        "userPrompt": entry.user_prompt,
        "aiNodes": entry.ai_nodes,
        "entryNodeKey": entry.entry_node_key,
        "tokensUsed": entry.tokens_used,
        "status": entry.status.value if entry.status else "pending",
        "errorMessage": entry.error_message,
        "createdAt": stamp(entry.created_at, ""),
        "completedAt": stamp(entry.completed_at, None),
    }
    return {"code": 0, "data": detail}
@router.delete("/{draft_id}")
async def delete_draft(
    draft_id: int,
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Delete a user's draft, then remove its locally stored images.

    The database row is deleted and committed first. Image cleanup runs
    afterwards on a best-effort basis. This way a failed commit never leaves a
    surviving draft whose images are already gone, and a failed cleanup never
    blocks the deletion. Only the local upload directory is touched.

    Raises:
        HTTPException(404): the draft does not exist or is not owned by ``userId``.
    """
    import shutil

    result = await db.execute(
        select(StoryDraft).where(
            StoryDraft.id == draft_id,
            StoryDraft.user_id == userId
        )
    )
    draft = result.scalar_one_or_none()
    if not draft:
        raise HTTPException(status_code=404, detail="草稿不存在")

    # Capture what the cleanup needs before the ORM object is deleted.
    story_id = draft.story_id
    await db.delete(draft)
    await db.commit()

    try:
        settings = get_settings()
        base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
        draft_dir = os.path.join(base_dir, "stories", str(story_id), "drafts", str(draft_id))
        if os.path.exists(draft_dir):
            shutil.rmtree(draft_dir)
            print(f"[delete_draft] 已清理图片目录: {draft_dir}")
    except Exception as e:
        # The draft is already deleted; a leftover image directory is only logged.
        print(f"[delete_draft] 清理图片失败: {e}")

    return {"code": 0, "message": "删除成功"}
@router.put("/{draft_id}/read")
async def mark_draft_read(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Set ``is_read`` on a single draft.

    No ownership check is made, and an unknown id is not an error: the
    endpoint always reports success.
    """
    statement = update(StoryDraft).where(StoryDraft.id == draft_id).values(is_read=True)
    await db.execute(statement)
    await db.commit()
    return {"code": 0, "message": "已标记为已读"}
@router.put("/batch-read")
async def mark_all_drafts_read(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Mark every unread draft of ``userId`` as read in one UPDATE."""
    only_unread = (
        StoryDraft.user_id == userId,
        StoryDraft.is_read == False,  # noqa: E712 - SQL expression
    )
    statement = update(StoryDraft).where(*only_unread).values(is_read=True)
    await db.execute(statement)
    await db.commit()
    return {"code": 0, "message": "已全部标记为已读"}
@router.put("/{draft_id}/publish")
async def publish_draft_to_center(
    draft_id: int,
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Publish a completed draft owned by ``userId`` to the creation center.

    Raises:
        HTTPException(404): the draft is missing, not owned by the user, or not completed.
    """
    lookup = select(StoryDraft).where(
        StoryDraft.id == draft_id,
        StoryDraft.user_id == userId,
        StoryDraft.status == DraftStatus.completed,
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="草稿不存在或未完成")

    target.published_to_center = True
    await db.commit()
    return {"code": 0, "message": "已发布到创作中心"}
@router.put("/{draft_id}/unpublish")
async def unpublish_draft_from_center(
    draft_id: int,
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Clear ``published_to_center`` on a draft owned by ``userId``.

    Reports success even when no row matches.
    """
    owned_draft = (StoryDraft.id == draft_id, StoryDraft.user_id == userId)
    statement = update(StoryDraft).where(*owned_draft).values(published_to_center=False)
    await db.execute(statement)
    await db.commit()
    return {"code": 0, "message": "已从创作中心移除"}
@router.put("/{draft_id}/collect")
async def collect_draft(
    draft_id: int,
    userId: int,
    isCollected: bool = True,
    db: AsyncSession = Depends(get_db)
):
    """Set or clear the collected (favourite) flag on a draft owned by ``userId``."""
    owned_draft = (StoryDraft.id == draft_id, StoryDraft.user_id == userId)
    statement = update(StoryDraft).where(*owned_draft).values(is_collected=isCollected)
    await db.execute(statement)
    await db.commit()
    if isCollected:
        message = "收藏成功"
    else:
        message = "取消收藏成功"
    return {"code": 0, "message": message}
@router.get("/{draft_id}/collect-status")
async def get_draft_collect_status(
    draft_id: int,
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Return whether a draft owned by ``userId`` is collected.

    A missing draft, or a falsy stored value, is reported as ``False``.
    """
    statement = select(StoryDraft.is_collected).where(
        StoryDraft.id == draft_id,
        StoryDraft.user_id == userId,
    )
    collected = (await db.execute(statement)).scalar_one_or_none()
    return {"code": 0, "data": {"isCollected": collected if collected else False}}