""" 草稿箱路由 - AI异步改写功能 """ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, update, delete from sqlalchemy.sql import func from pydantic import BaseModel from typing import List, Optional from datetime import datetime from app.database import get_db from app.models.story import Story, StoryDraft, DraftStatus router = APIRouter(prefix="/drafts", tags=["草稿箱"]) # ============ 请求/响应模型 ============ class PathHistoryItem(BaseModel): nodeKey: str content: str choice: str class CreateDraftRequest(BaseModel): userId: int storyId: int currentNodeKey: str pathHistory: List[PathHistoryItem] currentContent: str prompt: str class DraftResponse(BaseModel): id: int storyId: int storyTitle: str title: str userPrompt: str status: str isRead: bool createdAt: str completedAt: Optional[str] = None class Config: from_attributes = True # ============ 后台任务 ============ async def process_ai_rewrite(draft_id: int): """后台异步处理AI改写""" from app.database import async_session_factory from app.services.ai import ai_service async with async_session_factory() as db: try: # 获取草稿 result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id)) draft = result.scalar_one_or_none() if not draft: return # 更新状态为处理中 draft.status = DraftStatus.processing await db.commit() # 获取故事信息 story_result = await db.execute(select(Story).where(Story.id == draft.story_id)) story = story_result.scalar_one_or_none() if not story: draft.status = DraftStatus.failed draft.error_message = "故事不存在" draft.completed_at = datetime.now() await db.commit() return # 转换路径历史格式 path_history = draft.path_history or [] # 调用AI服务 ai_result = await ai_service.rewrite_branch( story_title=story.title, story_category=story.category or "未知", path_history=path_history, current_content=draft.current_content or "", user_prompt=draft.user_prompt ) if ai_result and ai_result.get("nodes"): # 成功 draft.status = DraftStatus.completed draft.ai_nodes = ai_result["nodes"] draft.entry_node_key = ai_result.get("entryNodeKey", "branch_1") draft.tokens_used = ai_result.get("tokens_used", 0) draft.title = f"{story.title}-改写" else: # 失败 draft.status = DraftStatus.failed draft.error_message = "AI服务暂时不可用" draft.completed_at = datetime.now() await db.commit() except Exception as e: print(f"[process_ai_rewrite] 异常: {e}") import traceback traceback.print_exc() # 更新失败状态 try: draft.status = DraftStatus.failed draft.error_message = str(e)[:500] draft.completed_at = datetime.now() await db.commit() except: pass # ============ API 路由 ============ @router.post("") async def create_draft( request: CreateDraftRequest, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db) ): """提交AI改写任务(异步处理)""" if not request.prompt: raise HTTPException(status_code=400, detail="请输入改写指令") # 获取故事信息 result = await db.execute(select(Story).where(Story.id == request.storyId)) story = result.scalar_one_or_none() if not story: raise HTTPException(status_code=404, detail="故事不存在") # 转换路径历史 path_history = [ {"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice} for item in request.pathHistory ] # 创建草稿记录 draft = StoryDraft( user_id=request.userId, story_id=request.storyId, title=f"{story.title}-改写", path_history=path_history, current_node_key=request.currentNodeKey, current_content=request.currentContent, user_prompt=request.prompt, status=DraftStatus.pending ) db.add(draft) await db.commit() await db.refresh(draft) # 添加后台任务 background_tasks.add_task(process_ai_rewrite, draft.id) return { "code": 0, "data": { "draftId": draft.id, "message": "已提交,AI正在生成中..." } } @router.get("") async def get_drafts( userId: int, db: AsyncSession = Depends(get_db) ): """获取用户的草稿列表""" result = await db.execute( select(StoryDraft, Story.title.label("story_title")) .join(Story, StoryDraft.story_id == Story.id) .where(StoryDraft.user_id == userId) .order_by(StoryDraft.created_at.desc()) ) drafts = [] for row in result: draft = row[0] story_title = row[1] drafts.append({ "id": draft.id, "storyId": draft.story_id, "storyTitle": story_title, "title": draft.title, "userPrompt": draft.user_prompt, "status": draft.status.value if draft.status else "pending", "isRead": draft.is_read, "createdAt": draft.created_at.strftime("%Y-%m-%d %H:%M") if draft.created_at else "", "completedAt": draft.completed_at.strftime("%Y-%m-%d %H:%M") if draft.completed_at else None }) return {"code": 0, "data": drafts} @router.get("/check-new") async def check_new_drafts( userId: int, db: AsyncSession = Depends(get_db) ): """检查是否有新完成的草稿(用于弹窗通知)""" result = await db.execute( select(StoryDraft) .where( StoryDraft.user_id == userId, StoryDraft.status == DraftStatus.completed, StoryDraft.is_read == False ) ) unread_drafts = result.scalars().all() return { "code": 0, "data": { "hasNew": len(unread_drafts) > 0, "count": len(unread_drafts), "drafts": [ { "id": d.id, "title": d.title, "userPrompt": d.user_prompt } for d in unread_drafts[:3] # 最多返回3个 ] } } @router.get("/{draft_id}") async def get_draft_detail( draft_id: int, db: AsyncSession = Depends(get_db) ): """获取草稿详情""" result = await db.execute( select(StoryDraft, Story) .join(Story, StoryDraft.story_id == Story.id) .where(StoryDraft.id == draft_id) ) row = result.first() if not row: raise HTTPException(status_code=404, detail="草稿不存在") draft, story = row # 标记为已读 if not draft.is_read: draft.is_read = True await db.commit() return { "code": 0, "data": { "id": draft.id, "storyId": draft.story_id, "storyTitle": story.title, "storyCategory": story.category, "title": draft.title, "pathHistory": draft.path_history, "currentNodeKey": draft.current_node_key, "currentContent": draft.current_content, "userPrompt": draft.user_prompt, "aiNodes": draft.ai_nodes, "entryNodeKey": draft.entry_node_key, "tokensUsed": draft.tokens_used, "status": draft.status.value if draft.status else "pending", "errorMessage": draft.error_message, "createdAt": draft.created_at.strftime("%Y-%m-%d %H:%M") if draft.created_at else "", "completedAt": draft.completed_at.strftime("%Y-%m-%d %H:%M") if draft.completed_at else None } } @router.delete("/{draft_id}") async def delete_draft( draft_id: int, userId: int, db: AsyncSession = Depends(get_db) ): """删除草稿""" result = await db.execute( select(StoryDraft).where( StoryDraft.id == draft_id, StoryDraft.user_id == userId ) ) draft = result.scalar_one_or_none() if not draft: raise HTTPException(status_code=404, detail="草稿不存在") await db.delete(draft) await db.commit() return {"code": 0, "message": "删除成功"} @router.put("/{draft_id}/read") async def mark_draft_read( draft_id: int, db: AsyncSession = Depends(get_db) ): """标记草稿为已读""" await db.execute( update(StoryDraft) .where(StoryDraft.id == draft_id) .values(is_read=True) ) await db.commit() return {"code": 0, "message": "已标记为已读"} @router.put("/batch-read") async def mark_all_drafts_read( userId: int, db: AsyncSession = Depends(get_db) ): """批量标记所有未读草稿为已读""" await db.execute( update(StoryDraft) .where( StoryDraft.user_id == userId, StoryDraft.is_read == False ) .values(is_read=True) ) await db.commit() return {"code": 0, "message": "已全部标记为已读"}