""" 草稿箱路由 - AI异步改写功能 """ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, update, delete from sqlalchemy.sql import func from pydantic import BaseModel from typing import List, Optional from datetime import datetime from app.database import get_db from app.models.story import Story, StoryDraft, DraftStatus router = APIRouter(prefix="/drafts", tags=["草稿箱"]) # ============ 请求/响应模型 ============ class PathHistoryItem(BaseModel): nodeKey: str content: str choice: str class CreateDraftRequest(BaseModel): userId: int storyId: int currentNodeKey: str pathHistory: List[PathHistoryItem] currentContent: str prompt: str class CreateEndingDraftRequest(BaseModel): """结局改写请求""" userId: int storyId: int endingName: str endingContent: str prompt: str pathHistory: list = [] # 游玩路径历史(可选) class ContinueEndingDraftRequest(BaseModel): """结局续写请求""" userId: int storyId: int endingName: str endingContent: str prompt: str pathHistory: list = [] # 游玩路径历史(可选) class DraftResponse(BaseModel): id: int storyId: int storyTitle: str title: str userPrompt: str status: str isRead: bool createdAt: str completedAt: Optional[str] = None class Config: from_attributes = True # ============ 后台任务 ============ async def process_ai_rewrite(draft_id: int): """后台异步处理AI改写""" from app.database import async_session_factory from app.services.ai import ai_service async with async_session_factory() as db: try: # 获取草稿 result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id)) draft = result.scalar_one_or_none() if not draft: return # 更新状态为处理中 draft.status = DraftStatus.processing await db.commit() # 获取故事信息 story_result = await db.execute(select(Story).where(Story.id == draft.story_id)) story = story_result.scalar_one_or_none() if not story: draft.status = DraftStatus.failed draft.error_message = "故事不存在" draft.completed_at = datetime.now() await db.commit() return # 转换路径历史格式 path_history = draft.path_history or [] # 调用AI服务 ai_result = await ai_service.rewrite_branch( story_title=story.title, story_category=story.category or "未知", path_history=path_history, current_content=draft.current_content or "", user_prompt=draft.user_prompt ) if ai_result and ai_result.get("nodes"): # 成功 draft.status = DraftStatus.completed draft.ai_nodes = ai_result["nodes"] draft.entry_node_key = ai_result.get("entryNodeKey", "branch_1") draft.tokens_used = ai_result.get("tokens_used", 0) draft.title = f"{story.title}-改写" else: # 失败 draft.status = DraftStatus.failed draft.error_message = "AI服务暂时不可用" draft.completed_at = datetime.now() await db.commit() except Exception as e: print(f"[process_ai_rewrite] 异常: {e}") import traceback traceback.print_exc() # 更新失败状态 try: draft.status = DraftStatus.failed draft.error_message = str(e)[:500] draft.completed_at = datetime.now() await db.commit() except: pass async def process_ai_rewrite_ending(draft_id: int): """后台异步处理AI改写结局""" from app.database import async_session_factory from app.services.ai import ai_service import json import re async with async_session_factory() as db: try: # 获取草稿 result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id)) draft = result.scalar_one_or_none() if not draft: return # 更新状态为处理中 draft.status = DraftStatus.processing await db.commit() # 获取故事信息 story_result = await db.execute(select(Story).where(Story.id == draft.story_id)) story = story_result.scalar_one_or_none() if not story: draft.status = DraftStatus.failed draft.error_message = "故事不存在" draft.completed_at = datetime.now() await db.commit() return # 从草稿字段获取结局信息 ending_name = draft.current_node_key or "未知结局" ending_content = draft.current_content or "" # 调用AI服务改写结局 ai_result = await ai_service.rewrite_ending( story_title=story.title, story_category=story.category or "未知", ending_name=ending_name, ending_content=ending_content, user_prompt=draft.user_prompt ) if ai_result and ai_result.get("content"): content = ai_result["content"] new_ending_name = f"{ending_name}(AI改写)" # 尝试解析 JSON 格式的返回 try: json_match = re.search(r'\{[^{}]*"ending_name"[^{}]*"content"[^{}]*\}', content, re.DOTALL) if json_match: parsed = json.loads(json_match.group()) new_ending_name = parsed.get("ending_name", new_ending_name) content = parsed.get("content", content) else: parsed = json.loads(content) new_ending_name = parsed.get("ending_name", new_ending_name) content = parsed.get("content", content) except (json.JSONDecodeError, AttributeError): pass # 成功 - 存储为对象格式(与故事节点格式一致) draft.status = DraftStatus.completed draft.ai_nodes = { "ending_rewrite": { "content": content, "speaker": "旁白", "is_ending": True, "ending_name": new_ending_name, "ending_type": "rewrite" } } draft.entry_node_key = "ending_rewrite" draft.tokens_used = ai_result.get("tokens_used", 0) draft.title = f"{story.title}-{new_ending_name}" else: draft.status = DraftStatus.failed draft.error_message = "AI服务暂时不可用" draft.completed_at = datetime.now() await db.commit() except Exception as e: print(f"[process_ai_rewrite_ending] 异常: {e}") import traceback traceback.print_exc() try: draft.status = DraftStatus.failed draft.error_message = str(e)[:500] draft.completed_at = datetime.now() await db.commit() except: pass async def process_ai_continue_ending(draft_id: int): """后台异步处理AI续写结局""" from app.database import async_session_factory from app.services.ai import ai_service async with async_session_factory() as db: try: # 获取草稿 result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id)) draft = result.scalar_one_or_none() if not draft: return # 更新状态为处理中 draft.status = DraftStatus.processing await db.commit() # 获取故事信息 story_result = await db.execute(select(Story).where(Story.id == draft.story_id)) story = story_result.scalar_one_or_none() if not story: draft.status = DraftStatus.failed draft.error_message = "故事不存在" draft.completed_at = datetime.now() await db.commit() return # 从草稿字段获取结局信息 ending_name = draft.current_node_key or "未知结局" ending_content = draft.current_content or "" # 调用AI服务续写结局 ai_result = await ai_service.continue_ending( story_title=story.title, story_category=story.category or "未知", ending_name=ending_name, ending_content=ending_content, user_prompt=draft.user_prompt ) if ai_result and ai_result.get("nodes"): # 成功 - 存储多节点分支格式 draft.status = DraftStatus.completed draft.ai_nodes = ai_result["nodes"] draft.entry_node_key = ai_result.get("entryNodeKey", "continue_1") draft.tokens_used = ai_result.get("tokens_used", 0) draft.title = f"{story.title}-{ending_name}续写" else: draft.status = DraftStatus.failed draft.error_message = "AI服务暂时不可用" draft.completed_at = datetime.now() await db.commit() except Exception as e: print(f"[process_ai_continue_ending] 异常: {e}") import traceback traceback.print_exc() try: draft.status = DraftStatus.failed draft.error_message = str(e)[:500] draft.completed_at = datetime.now() await db.commit() except: pass # ============ API 路由 ============ @router.post("") async def create_draft( request: CreateDraftRequest, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db) ): """提交AI改写任务(异步处理)""" if not request.prompt: raise HTTPException(status_code=400, detail="请输入改写指令") # 获取故事信息 result = await db.execute(select(Story).where(Story.id == request.storyId)) story = result.scalar_one_or_none() if not story: raise HTTPException(status_code=404, detail="故事不存在") # 转换路径历史 path_history = [ {"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice} for item in request.pathHistory ] # 创建草稿记录 draft = StoryDraft( user_id=request.userId, story_id=request.storyId, title=f"{story.title}-改写", path_history=path_history, current_node_key=request.currentNodeKey, current_content=request.currentContent, user_prompt=request.prompt, status=DraftStatus.pending ) db.add(draft) await db.commit() await db.refresh(draft) # 添加后台任务 background_tasks.add_task(process_ai_rewrite, draft.id) return { "code": 0, "data": { "draftId": draft.id, "message": "已提交,AI正在生成中..." } } @router.post("/ending") async def create_ending_draft( request: CreateEndingDraftRequest, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db) ): """提交AI改写结局任务(异步处理)""" if not request.prompt: raise HTTPException(status_code=400, detail="请输入改写指令") # 获取故事信息 result = await db.execute(select(Story).where(Story.id == request.storyId)) story = result.scalar_one_or_none() if not story: raise HTTPException(status_code=404, detail="故事不存在") # 创建草稿记录,保存游玩路径和结局信息 draft = StoryDraft( user_id=request.userId, story_id=request.storyId, title=f"{story.title}-结局改写", path_history=request.pathHistory, # 保存游玩路径 current_node_key=request.endingName, # 保存结局名称 current_content=request.endingContent, # 保存结局内容 user_prompt=request.prompt, status=DraftStatus.pending ) db.add(draft) await db.commit() await db.refresh(draft) # 添加后台任务 background_tasks.add_task(process_ai_rewrite_ending, draft.id) return { "code": 0, "data": { "draftId": draft.id, "message": "已提交,AI正在生成新结局..." } } @router.post("/continue-ending") async def create_continue_ending_draft( request: ContinueEndingDraftRequest, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db) ): """提交AI续写结局任务(异步处理)""" if not request.prompt: raise HTTPException(status_code=400, detail="请输入续写指令") # 获取故事信息 result = await db.execute(select(Story).where(Story.id == request.storyId)) story = result.scalar_one_or_none() if not story: raise HTTPException(status_code=404, detail="故事不存在") # 创建草稿记录,保存游玩路径和结局信息 draft = StoryDraft( user_id=request.userId, story_id=request.storyId, title=f"{story.title}-结局续写", path_history=request.pathHistory, # 保存游玩路径 current_node_key=request.endingName, # 保存结局名称 current_content=request.endingContent, # 保存结局内容 user_prompt=request.prompt, status=DraftStatus.pending ) db.add(draft) await db.commit() await db.refresh(draft) # 添加后台任务 background_tasks.add_task(process_ai_continue_ending, draft.id) return { "code": 0, "data": { "draftId": draft.id, "message": "已提交,AI正在续写故事..." } } @router.get("") async def get_drafts( userId: int, db: AsyncSession = Depends(get_db) ): """获取用户的草稿列表""" result = await db.execute( select(StoryDraft, Story.title.label("story_title")) .join(Story, StoryDraft.story_id == Story.id) .where(StoryDraft.user_id == userId) .order_by(StoryDraft.created_at.desc()) ) drafts = [] for row in result: draft = row[0] story_title = row[1] drafts.append({ "id": draft.id, "storyId": draft.story_id, "storyTitle": story_title, "title": draft.title, "userPrompt": draft.user_prompt, "status": draft.status.value if draft.status else "pending", "isRead": draft.is_read, "createdAt": draft.created_at.strftime("%Y-%m-%d %H:%M") if draft.created_at else "", "completedAt": draft.completed_at.strftime("%Y-%m-%d %H:%M") if draft.completed_at else None }) return {"code": 0, "data": drafts} @router.get("/check-new") async def check_new_drafts( userId: int, db: AsyncSession = Depends(get_db) ): """检查是否有新完成的草稿(用于弹窗通知)""" result = await db.execute( select(StoryDraft) .where( StoryDraft.user_id == userId, StoryDraft.status == DraftStatus.completed, StoryDraft.is_read == False ) ) unread_drafts = result.scalars().all() return { "code": 0, "data": { "hasNew": len(unread_drafts) > 0, "count": len(unread_drafts), "drafts": [ { "id": d.id, "title": d.title, "userPrompt": d.user_prompt } for d in unread_drafts[:3] # 最多返回3个 ] } } @router.get("/{draft_id}") async def get_draft_detail( draft_id: int, db: AsyncSession = Depends(get_db) ): """获取草稿详情""" result = await db.execute( select(StoryDraft, Story) .join(Story, StoryDraft.story_id == Story.id) .where(StoryDraft.id == draft_id) ) row = result.first() if not row: raise HTTPException(status_code=404, detail="草稿不存在") draft, story = row # 标记为已读 if not draft.is_read: draft.is_read = True await db.commit() return { "code": 0, "data": { "id": draft.id, "storyId": draft.story_id, "storyTitle": story.title, "storyCategory": story.category, "title": draft.title, "pathHistory": draft.path_history, "currentNodeKey": draft.current_node_key, "currentContent": draft.current_content, "userPrompt": draft.user_prompt, "aiNodes": draft.ai_nodes, "entryNodeKey": draft.entry_node_key, "tokensUsed": draft.tokens_used, "status": draft.status.value if draft.status else "pending", "errorMessage": draft.error_message, "createdAt": draft.created_at.strftime("%Y-%m-%d %H:%M") if draft.created_at else "", "completedAt": draft.completed_at.strftime("%Y-%m-%d %H:%M") if draft.completed_at else None } } @router.delete("/{draft_id}") async def delete_draft( draft_id: int, userId: int, db: AsyncSession = Depends(get_db) ): """删除草稿""" result = await db.execute( select(StoryDraft).where( StoryDraft.id == draft_id, StoryDraft.user_id == userId ) ) draft = result.scalar_one_or_none() if not draft: raise HTTPException(status_code=404, detail="草稿不存在") await db.delete(draft) await db.commit() return {"code": 0, "message": "删除成功"} @router.put("/{draft_id}/read") async def mark_draft_read( draft_id: int, db: AsyncSession = Depends(get_db) ): """标记草稿为已读""" await db.execute( update(StoryDraft) .where(StoryDraft.id == draft_id) .values(is_read=True) ) await db.commit() return {"code": 0, "message": "已标记为已读"} @router.put("/batch-read") async def mark_all_drafts_read( userId: int, db: AsyncSession = Depends(get_db) ): """批量标记所有未读草稿为已读""" await db.execute( update(StoryDraft) .where( StoryDraft.user_id == userId, StoryDraft.is_read == False ) .values(is_read=True) ) await db.commit() return {"code": 0, "message": "已全部标记为已读"}