Files
ai_game/server/app/routers/drafts.py

345 lines
9.7 KiB
Python
Raw Normal View History

"""
Drafts router (草稿箱) - asynchronous AI rewrite feature.

A user submits a rewrite prompt for a story branch. The rewrite itself runs in a
FastAPI background task and its result is stored on a StoryDraft row, which the
user can then list, open, mark as read, and delete through the endpoints below.
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.sql import func
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime
from app.database import get_db
from app.models.story import Story, StoryDraft, DraftStatus
# Every route registered on this router is served under the /drafts prefix.
router = APIRouter(prefix="/drafts", tags=["草稿箱"])
# ============ Request / response models ============
class PathHistoryItem(BaseModel):
    """One step of the path the reader took before asking for a rewrite.

    create_draft copies these fields verbatim into a plain dict that is stored
    on the draft and later passed to the AI service.
    """
    nodeKey: str  # key of the visited story node
    content: str  # text of that node
    choice: str  # the choice made at that node (presumably the option text — confirm with client)
class CreateDraftRequest(BaseModel):
    """Request body for POST /drafts (submit an AI rewrite job)."""
    userId: int  # owner of the new draft; taken from the body as-is (no auth check in this router)
    storyId: int  # story being rewritten; must exist or the endpoint returns 404
    currentNodeKey: str  # node the reader is currently on (stored as current_node_key)
    pathHistory: List[PathHistoryItem]  # nodes visited before the current one
    currentContent: str  # text of the current node, forwarded to the AI
    prompt: str  # user's rewrite instruction; an empty prompt is rejected with 400
class DraftResponse(BaseModel):
    """Summary view of a draft, with the same keys as the items built by GET /drafts.

    NOTE(review): not referenced by the endpoints in this file, which build
    plain dicts instead; kept as the documented response shape.
    """
    id: int
    storyId: int
    storyTitle: str
    title: str
    userPrompt: str
    status: str  # DraftStatus value: pending / processing / completed / failed
    isRead: bool
    createdAt: str  # formatted "%Y-%m-%d %H:%M" by the list endpoint
    completedAt: Optional[str] = None  # None until the background task finishes
    class Config:
        # Allow building the model from ORM objects / attribute access.
        from_attributes = True
# ============ Background tasks ============
async def process_ai_rewrite(draft_id: int):
    """Run the AI rewrite for one draft (后台异步处理AI改写).

    Scheduled through BackgroundTasks by create_draft. It opens its own session
    from ``async_session_factory`` and then:

    1. loads the draft (returns silently if it does not exist);
    2. sets it to ``processing`` and commits, so other requests can see that
       status while the AI call is running;
    3. loads the story, marking the draft ``failed`` if the story is missing;
    4. calls ``ai_service.rewrite_branch``. A result with a non-empty ``nodes``
       marks the draft ``completed`` and stores nodes, entry key, tokens and
       title. Anything else marks it ``failed``.

    ``completed_at`` is set in every terminal case. Exceptions are logged and
    recorded on the draft as ``failed``; they are not propagated. Task
    cancellation (a BaseException) is not caught.

    Args:
        draft_id: primary key of the StoryDraft to process.
    """
    import logging

    from app.database import async_session_factory
    from app.services.ai import ai_service

    logger = logging.getLogger(__name__)

    async with async_session_factory() as db:
        try:
            result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
            draft = result.scalar_one_or_none()
            if not draft:
                return

            draft.status = DraftStatus.processing
            await db.commit()

            story_result = await db.execute(select(Story).where(Story.id == draft.story_id))
            story = story_result.scalar_one_or_none()
            if not story:
                draft.status = DraftStatus.failed
                draft.error_message = "故事不存在"
                draft.completed_at = datetime.now()
                await db.commit()
                return

            ai_result = await ai_service.rewrite_branch(
                story_title=story.title,
                story_category=story.category or "未知",
                path_history=draft.path_history or [],
                current_content=draft.current_content or "",
                user_prompt=draft.user_prompt,
            )

            if ai_result and ai_result.get("nodes"):
                draft.status = DraftStatus.completed
                draft.ai_nodes = ai_result["nodes"]
                draft.entry_node_key = ai_result.get("entryNodeKey", "branch_1")
                draft.tokens_used = ai_result.get("tokens_used", 0)
                draft.title = f"{story.title}-改写"
            else:
                draft.status = DraftStatus.failed
                draft.error_message = "AI服务暂时不可用"
            draft.completed_at = datetime.now()
            await db.commit()
        except Exception as e:
            logger.exception("[process_ai_rewrite] 异常: %s", e)
            try:
                # After a failed flush/commit the session refuses further work
                # until it is rolled back. Record the failure with an UPDATE
                # keyed by draft_id rather than through `draft`, which may never
                # have been assigned if the very first query raised.
                await db.rollback()
                await db.execute(
                    update(StoryDraft)
                    .where(StoryDraft.id == draft_id)
                    .values(
                        status=DraftStatus.failed,
                        error_message=str(e)[:500],
                        completed_at=datetime.now(),
                    )
                    # The session is discarded right after, so in-memory
                    # objects need no synchronisation (and must not be reloaded).
                    .execution_options(synchronize_session=False)
                )
                await db.commit()
            except Exception:
                # Best effort only: if even this fails (e.g. the DB is down),
                # log it instead of silently dropping the error.
                logger.exception(
                    "[process_ai_rewrite] could not mark draft %s as failed", draft_id
                )
# ============ API routes ============
@router.post("")
async def create_draft(
    request: CreateDraftRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Submit an AI rewrite job and return immediately (提交AI改写任务，异步处理).

    Stores a ``pending`` StoryDraft holding the reader's path, current node and
    prompt, then schedules ``process_ai_rewrite`` for that draft as a background
    task.

    Returns:
        ``{"code": 0, "data": {"draftId": ..., "message": ...}}``.

    Raises:
        HTTPException(400): the prompt is empty or only whitespace.
        HTTPException(404): the story does not exist.
    """
    # A whitespace-only prompt carries no instruction; reject it like an empty
    # one instead of spending a background AI call on it.
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="请输入改写指令")

    result = await db.execute(select(Story).where(Story.id == request.storyId))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    # Store the path as plain dicts so it can be persisted as-is on the draft.
    path_history = [
        {"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice}
        for item in request.pathHistory
    ]

    draft = StoryDraft(
        user_id=request.userId,
        story_id=request.storyId,
        title=f"{story.title}-改写",
        path_history=path_history,
        current_node_key=request.currentNodeKey,
        current_content=request.currentContent,
        user_prompt=request.prompt,
        status=DraftStatus.pending
    )
    db.add(draft)
    await db.commit()
    # Reload so the database-generated primary key is available below.
    await db.refresh(draft)

    background_tasks.add_task(process_ai_rewrite, draft.id)
    return {
        "code": 0,
        "data": {
            "draftId": draft.id,
            "message": "已提交AI正在生成中..."
        }
    }
@router.get("")
async def get_drafts(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """List every draft owned by ``userId``, newest first.

    Each draft is inner-joined with its story to expose the story title, so
    drafts whose story row no longer exists are left out. Timestamps use the
    "%Y-%m-%d %H:%M" format. A missing ``created_at`` becomes "" and a missing
    ``completed_at`` becomes None. A missing status is reported as "pending".
    """
    query = (
        select(StoryDraft, Story.title.label("story_title"))
        .join(Story, StoryDraft.story_id == Story.id)
        .where(StoryDraft.user_id == userId)
        .order_by(StoryDraft.created_at.desc())
    )
    rows = await db.execute(query)

    time_fmt = "%Y-%m-%d %H:%M"
    items = [
        {
            "id": d.id,
            "storyId": d.story_id,
            "storyTitle": title,
            "title": d.title,
            "userPrompt": d.user_prompt,
            "status": d.status.value if d.status else "pending",
            "isRead": d.is_read,
            "createdAt": d.created_at.strftime(time_fmt) if d.created_at else "",
            "completedAt": d.completed_at.strftime(time_fmt) if d.completed_at else None,
        }
        for d, title in rows
    ]
    return {"code": 0, "data": items}
@router.get("/check-new")
async def check_new_drafts(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Report completed-but-unread drafts for pop-up notifications (检查是否有新完成的草稿).

    Returns:
        ``{"code": 0, "data": {"hasNew": bool, "count": int, "drafts": [...]}}``.
        ``drafts`` previews at most 3 of them, newest first, each with
        id/title/userPrompt.
    """
    unread_completed = (
        StoryDraft.user_id == userId,
        StoryDraft.status == DraftStatus.completed,
        StoryDraft.is_read == False,  # noqa: E712 - SQL expression, not a Python bool test
    )

    # Count in the database rather than loading every matching draft (with all
    # of its columns, e.g. ai_nodes / path_history) only to call len() on it.
    count = (
        await db.execute(
            select(func.count()).select_from(StoryDraft).where(*unread_completed)
        )
    ).scalar_one()

    # Fetch only the columns the preview needs. Order newest first so the
    # "up to 3" preview is deterministic instead of arbitrary.
    preview = []
    if count:
        preview = (
            await db.execute(
                select(StoryDraft.id, StoryDraft.title, StoryDraft.user_prompt)
                .where(*unread_completed)
                .order_by(StoryDraft.created_at.desc())
                .limit(3)
            )
        ).all()

    return {
        "code": 0,
        "data": {
            "hasNew": count > 0,
            "count": count,
            "drafts": [
                {
                    "id": d.id,
                    "title": d.title,
                    "userPrompt": d.user_prompt
                }
                for d in preview
            ]
        }
    }
@router.get("/{draft_id}")
async def get_draft_detail(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Return the full content of one draft and mark it read once it is finished.

    Only drafts whose status is ``completed`` or ``failed`` are marked read.

    Raises:
        HTTPException(404): no draft with this id joined to an existing story.
    """
    result = await db.execute(
        select(StoryDraft, Story)
        .join(Story, StoryDraft.story_id == Story.id)
        .where(StoryDraft.id == draft_id)
    )
    row = result.first()
    if not row:
        raise HTTPException(status_code=404, detail="草稿不存在")
    draft, story = row

    time_fmt = "%Y-%m-%d %H:%M"
    # Build the response before any commit. If the session expires objects on
    # commit (SQLAlchemy's default; get_db is not visible here — confirm), then
    # reading attributes afterwards would need an implicit reload, which an
    # AsyncSession cannot do.
    payload = {
        "id": draft.id,
        "storyId": draft.story_id,
        "storyTitle": story.title,
        "storyCategory": story.category,
        "title": draft.title,
        "pathHistory": draft.path_history,
        "currentNodeKey": draft.current_node_key,
        "currentContent": draft.current_content,
        "userPrompt": draft.user_prompt,
        "aiNodes": draft.ai_nodes,
        "entryNodeKey": draft.entry_node_key,
        "tokensUsed": draft.tokens_used,
        "status": draft.status.value if draft.status else "pending",
        "errorMessage": draft.error_message,
        "createdAt": draft.created_at.strftime(time_fmt) if draft.created_at else "",
        "completedAt": draft.completed_at.strftime(time_fmt) if draft.completed_at else None
    }

    # Only finished drafts count as read. Marking a pending/processing draft
    # read would make /check-new (which filters is_read == False) skip it once
    # the background rewrite completes, so the user would never be notified.
    finished = draft.status in (DraftStatus.completed, DraftStatus.failed)
    if finished and not draft.is_read:
        draft.is_read = True
        await db.commit()

    return {"code": 0, "data": payload}
@router.delete("/{draft_id}")
async def delete_draft(
    draft_id: int,
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Delete a draft that belongs to ``userId``.

    Ownership is part of the lookup, so a draft owned by someone else is
    reported exactly like a missing one (404 "草稿不存在"). The row is removed
    through the ORM session (``db.delete``), not a bulk DELETE statement.
    """
    owned_draft = select(StoryDraft).where(
        StoryDraft.id == draft_id,
        StoryDraft.user_id == userId,
    )
    target = (await db.execute(owned_draft)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="草稿不存在")

    await db.delete(target)
    await db.commit()
    return {"code": 0, "message": "删除成功"}
@router.put("/{draft_id}/read")
async def mark_draft_read(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Set ``is_read`` on a single draft with one UPDATE statement.

    NOTE(review): there is no ownership check and no row-count check, so this
    reports success even when no draft has ``draft_id``.
    """
    mark_read = (
        update(StoryDraft)
        .where(StoryDraft.id == draft_id)
        .values(is_read=True)
    )
    await db.execute(mark_read)
    await db.commit()
    return {"code": 0, "message": "已标记为已读"}
@router.put("/batch-read")
async def mark_all_drafts_read(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Mark every unread draft of ``userId`` as read in one UPDATE statement.

    NOTE(review): this applies regardless of status, so drafts still
    pending/processing are also marked read. /check-new will then not report
    them when they complete. Confirm that this is intended.
    """
    still_unread = StoryDraft.is_read == False  # noqa: E712 - SQL expression
    mark_all = (
        update(StoryDraft)
        .where(StoryDraft.user_id == userId, still_unread)
        .values(is_read=True)
    )
    await db.execute(mark_all)
    await db.commit()
    return {"code": 0, "message": "已全部标记为已读"}