345 lines
9.7 KiB
Python
345 lines
9.7 KiB
Python
|
|
"""
|
|||
|
|
草稿箱路由 - AI异步改写功能
|
|||
|
|
"""
|
|||
|
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
from sqlalchemy import select, update, delete
|
|||
|
|
from sqlalchemy.sql import func
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from typing import List, Optional
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
from app.database import get_db
|
|||
|
|
from app.models.story import Story, StoryDraft, DraftStatus
|
|||
|
|
|
|||
|
|
# Every draft-box endpoint below is mounted under /drafts and grouped under one OpenAPI tag.
router = APIRouter(tags=["草稿箱"], prefix="/drafts")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============ 请求/响应模型 ============
|
|||
|
|
|
|||
|
|
class PathHistoryItem(BaseModel):
    """One step of the reading path sent in ``CreateDraftRequest.pathHistory``.

    create_draft copies each item into a plain dict with the same three keys
    before storing it on the draft and handing it to the AI rewrite.
    """

    # Node identifier for this step (presumably a key of the story's node graph — confirm).
    nodeKey: str
    # Text shown at this step (assumed to be the node's content as displayed).
    content: str
    # The option the reader picked at this step (assumed label text — confirm with client).
    choice: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class CreateDraftRequest(BaseModel):
    """Request body for ``POST /drafts`` (create_draft).

    Field names double as the JSON keys, since no aliases are configured.
    """

    # Owner of the new draft, taken from the body as-is (no auth check is visible in this file).
    userId: int
    # Story to rewrite; create_draft returns 404 if it does not exist.
    storyId: int
    # Stored on the draft as current_node_key.
    currentNodeKey: str
    # Steps read so far; stored on the draft as a list of plain dicts.
    pathHistory: List[PathHistoryItem]
    # Stored as current_content and later passed to the AI rewrite.
    currentContent: str
    # The user's rewrite instruction; create_draft rejects an empty value with 400.
    prompt: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DraftResponse(BaseModel):
    """Summary shape of a draft as listed to the client.

    NOTE(review): no route in this file references this model; get_drafts builds
    plain dicts with the same keys instead. Confirm whether it is used elsewhere.
    """

    id: int
    storyId: int
    # Title of the parent story (get_drafts fills it via a join on Story).
    storyTitle: str
    title: str
    userPrompt: str
    # Draft status as its enum value string (get_drafts falls back to "pending").
    status: str
    isRead: bool
    # get_drafts formats timestamps as "%Y-%m-%d %H:%M".
    createdAt: str
    # None until the background rewrite finishes (successfully or not).
    completedAt: Optional[str] = None

    class Config:
        # Allow construction from ORM objects (pydantic v2 option name).
        from_attributes = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============ 后台任务 ============
|
|||
|
|
|
|||
|
|
async def _mark_draft_failed(db: AsyncSession, draft_id: int, message: str) -> None:
    """Persist a terminal ``failed`` status for draft ``draft_id`` and commit.

    Uses an UPDATE keyed by id instead of mutating an ORM instance, so it works
    even when the draft was never loaded or the session had to be rolled back.
    """
    await db.execute(
        update(StoryDraft)
        .where(StoryDraft.id == draft_id)
        .values(
            status=DraftStatus.failed,
            error_message=message[:500],
            completed_at=datetime.now(),
        )
        # No need to sync in-session objects; callers stop using the draft afterwards,
        # and syncing could try to load expired attributes after a rollback.
        .execution_options(synchronize_session=False)
    )
    await db.commit()


async def process_ai_rewrite(draft_id: int):
    """Run the AI rewrite for one draft in the background and store the outcome.

    Scheduled by create_draft through ``BackgroundTasks``. It opens its own session
    because the request-scoped one is closed by then. The draft is moved to
    ``processing``, ``ai_service.rewrite_branch`` is called, and the draft finishes
    as ``completed`` (AI nodes stored) or ``failed`` (error_message set).

    Errors raised while processing are logged and recorded on the draft rather than
    propagated, so a failure never leaves the draft stuck in ``processing`` unless
    the database itself is unreachable.

    Args:
        draft_id: Primary key of the StoryDraft to process. Returns silently if no
            such draft exists.
    """
    import logging

    from app.database import async_session_factory
    from app.services.ai import ai_service

    logger = logging.getLogger(__name__)

    async with async_session_factory() as db:
        try:
            result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
            draft = result.scalar_one_or_none()
            if not draft:
                return

            # Snapshot the AI inputs before committing, so reading them later does not
            # depend on the session's expire_on_commit setting (implicit IO in async).
            story_id = draft.story_id
            path_history = draft.path_history or []
            current_content = draft.current_content or ""
            user_prompt = draft.user_prompt

            # Publish the "processing" state before the (slow) AI call.
            draft.status = DraftStatus.processing
            await db.commit()

            story_result = await db.execute(select(Story).where(Story.id == story_id))
            story = story_result.scalar_one_or_none()
            if not story:
                await _mark_draft_failed(db, draft_id, "故事不存在")
                return

            ai_result = await ai_service.rewrite_branch(
                story_title=story.title,
                story_category=story.category or "未知",
                path_history=path_history,
                current_content=current_content,
                user_prompt=user_prompt,
            )

            if ai_result and ai_result.get("nodes"):
                draft.status = DraftStatus.completed
                draft.ai_nodes = ai_result["nodes"]
                draft.entry_node_key = ai_result.get("entryNodeKey", "branch_1")
                draft.tokens_used = ai_result.get("tokens_used", 0)
                draft.title = f"{story.title}-改写"
            else:
                # An empty or missing result is treated as a service failure.
                draft.status = DraftStatus.failed
                draft.error_message = "AI服务暂时不可用"

            draft.completed_at = datetime.now()
            await db.commit()

        except Exception as exc:
            logger.exception("[process_ai_rewrite] 异常 (draft_id=%s)", draft_id)
            try:
                # A failed flush/commit leaves the session unusable until it is rolled
                # back; without this the failure status below could never be saved.
                await db.rollback()
                await _mark_draft_failed(db, draft_id, str(exc))
            except Exception:
                # Best effort: the DB itself is failing, so just make it visible in logs.
                logger.exception(
                    "[process_ai_rewrite] could not record failure for draft %s", draft_id
                )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============ API 路由 ============
|
|||
|
|
|
|||
|
|
@router.post("")
async def create_draft(
    request: CreateDraftRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """Create a draft record and queue its AI rewrite (asynchronous processing).

    The draft is committed in ``pending`` state and its id is returned immediately.
    ``process_ai_rewrite`` is scheduled as a background task to fill in the result.

    Raises:
        HTTPException: 400 if ``prompt`` is empty or whitespace-only;
            404 if ``storyId`` does not match an existing story.
    """
    # A whitespace-only instruction gives the AI nothing to act on; reject it up
    # front instead of queuing a background job that can only yield a useless draft.
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="请输入改写指令")

    result = await db.execute(select(Story).where(Story.id == request.storyId))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    # Convert the pydantic items to plain dicts before storing them on the draft.
    path_history = [
        {"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice}
        for item in request.pathHistory
    ]

    draft = StoryDraft(
        user_id=request.userId,
        story_id=request.storyId,
        title=f"{story.title}-改写",
        path_history=path_history,
        current_node_key=request.currentNodeKey,
        current_content=request.currentContent,
        user_prompt=request.prompt,
        status=DraftStatus.pending,
    )
    db.add(draft)
    await db.commit()
    # Reload so the database-generated primary key is available.
    await db.refresh(draft)

    background_tasks.add_task(process_ai_rewrite, draft.id)

    return {
        "code": 0,
        "data": {
            "draftId": draft.id,
            "message": "已提交,AI正在生成中..."
        }
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("")
async def get_drafts(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """List every draft owned by ``userId``, newest first, with its story's title.

    Drafts whose story no longer exists are omitted because of the inner join.
    """
    listing_stmt = (
        select(StoryDraft, Story.title.label("story_title"))
        .join(Story, StoryDraft.story_id == Story.id)
        .where(StoryDraft.user_id == userId)
        .order_by(StoryDraft.created_at.desc())
    )
    rows = await db.execute(listing_stmt)

    stamp_fmt = "%Y-%m-%d %H:%M"
    # Each row is (StoryDraft, story_title); unpack it straight into the payload.
    items = [
        {
            "id": d.id,
            "storyId": d.story_id,
            "storyTitle": parent_title,
            "title": d.title,
            "userPrompt": d.user_prompt,
            "status": "pending" if not d.status else d.status.value,
            "isRead": d.is_read,
            "createdAt": "" if not d.created_at else d.created_at.strftime(stamp_fmt),
            "completedAt": None if not d.completed_at else d.completed_at.strftime(stamp_fmt),
        }
        for d, parent_title in rows
    ]

    return {"code": 0, "data": items}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/check-new")
async def check_new_drafts(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Report ``userId``'s completed-but-unread drafts (used for the popup notice).

    Returns the total count plus previews of at most three drafts, newest first.
    """
    preview_limit = 3
    unread_completed = (
        StoryDraft.user_id == userId,
        StoryDraft.status == DraftStatus.completed,
        StoryDraft.is_read == False,  # noqa: E712 - SQL comparison, not a Python truth test
    )

    # Count in SQL instead of loading every unread row just to call len() on it.
    count_result = await db.execute(
        select(func.count()).select_from(StoryDraft).where(*unread_completed)
    )
    unread_count = count_result.scalar_one()

    # Without an ORDER BY the "up to 3" previews would be arbitrary rows; show the
    # most recently completed ones (id breaks ties deterministically).
    preview_result = await db.execute(
        select(StoryDraft)
        .where(*unread_completed)
        .order_by(StoryDraft.completed_at.desc(), StoryDraft.id.desc())
        .limit(preview_limit)
    )
    previews = preview_result.scalars().all()

    return {
        "code": 0,
        "data": {
            "hasNew": unread_count > 0,
            "count": unread_count,
            "drafts": [
                {
                    "id": d.id,
                    "title": d.title,
                    "userPrompt": d.user_prompt
                }
                for d in previews
            ]
        }
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{draft_id}")
async def get_draft_detail(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Return one draft with its story's title and category, marking it as read.

    Raises:
        HTTPException: 404 if the draft does not exist (or its story is missing,
            since the query uses an inner join).
    """
    # NOTE(review): unlike delete_draft there is no userId filter here, so any caller
    # who knows a draft id can read it (and mark it read) — confirm this is intended.
    result = await db.execute(
        select(StoryDraft, Story)
        .join(Story, StoryDraft.story_id == Story.id)
        .where(StoryDraft.id == draft_id)
    )
    row = result.first()
    if not row:
        raise HTTPException(status_code=404, detail="草稿不存在")

    draft, story = row

    stamp_fmt = "%Y-%m-%d %H:%M"
    # Build the payload BEFORE committing: if the session expires objects on commit,
    # reading attributes afterwards would trigger implicit async IO.
    data = {
        "id": draft.id,
        "storyId": draft.story_id,
        "storyTitle": story.title,
        "storyCategory": story.category,
        "title": draft.title,
        "pathHistory": draft.path_history,
        "currentNodeKey": draft.current_node_key,
        "currentContent": draft.current_content,
        "userPrompt": draft.user_prompt,
        "aiNodes": draft.ai_nodes,
        "entryNodeKey": draft.entry_node_key,
        "tokensUsed": draft.tokens_used,
        "status": draft.status.value if draft.status else "pending",
        "errorMessage": draft.error_message,
        "createdAt": draft.created_at.strftime(stamp_fmt) if draft.created_at else "",
        "completedAt": draft.completed_at.strftime(stamp_fmt) if draft.completed_at else None
    }

    # Opening the detail view counts as reading the draft.
    if not draft.is_read:
        draft.is_read = True
        await db.commit()

    return {"code": 0, "data": data}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/{draft_id}")
async def delete_draft(
    draft_id: int,
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Delete a draft, but only when it belongs to ``userId``.

    Raises:
        HTTPException: 404 when no draft with this id is owned by ``userId``.
    """
    owned_draft_stmt = select(StoryDraft).where(
        StoryDraft.id == draft_id,
        StoryDraft.user_id == userId,
    )
    target = (await db.execute(owned_draft_stmt)).scalar_one_or_none()

    # A draft owned by someone else is reported exactly like a missing one.
    if not target:
        raise HTTPException(status_code=404, detail="草稿不存在")

    await db.delete(target)
    await db.commit()
    return {"code": 0, "message": "删除成功"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.put("/{draft_id}/read")
async def mark_draft_read(
    draft_id: int,
    db: AsyncSession = Depends(get_db)
):
    """Mark a single draft as read.

    Raises:
        HTTPException: 404 if no draft has this id. Previously an unknown id was
            reported as success; this now matches get_draft_detail and delete_draft.
    """
    result = await db.execute(
        update(StoryDraft)
        .where(StoryDraft.id == draft_id)
        .values(is_read=True)
    )
    # rowcount is the number of matched rows. SQLAlchemy's MySQL dialects enable
    # FOUND_ROWS by default, so an already-read draft still counts as 1 there too.
    if result.rowcount == 0:
        raise HTTPException(status_code=404, detail="草稿不存在")

    await db.commit()

    return {"code": 0, "message": "已标记为已读"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.put("/batch-read")
async def mark_all_drafts_read(
    userId: int,
    db: AsyncSession = Depends(get_db)
):
    """Flag every unread draft of ``userId`` as read with a single UPDATE.

    Applies to drafts in any status, not only completed ones.
    """
    # Successive .where() calls are ANDed together.
    mark_stmt = (
        update(StoryDraft)
        .where(StoryDraft.user_id == userId)
        .where(StoryDraft.is_read == False)  # noqa: E712 - SQL comparison
        .values(is_read=True)
    )
    await db.execute(mark_stmt)
    await db.commit()
    return {"code": 0, "message": "已全部标记为已读"}
|