feat: 完善AI改写草稿箱功能 - 修复重头游玩、评分、数据刷新等问题

This commit is contained in:
wangwuww111
2026-03-09 14:15:00 +08:00
parent bbdccfa843
commit 18db6a8cc6
17 changed files with 1385 additions and 99 deletions

View File

@@ -23,6 +23,9 @@ AsyncSessionLocal = sessionmaker(
expire_on_commit=False
)
# 后台任务使用的会话工厂
async_session_factory = AsyncSessionLocal
# 基类
Base = declarative_base()

View File

@@ -6,7 +6,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import get_settings
from app.routers import story, user
from app.routers import story, user, drafts
settings = get_settings()
@@ -29,6 +29,7 @@ app.add_middleware(
# 注册路由
app.include_router(story.router, prefix="/api/stories", tags=["故事"])
app.include_router(user.router, prefix="/api/user", tags=["用户"])
app.include_router(drafts.router, prefix="/api", tags=["草稿箱"])
@app.get("/")

View File

@@ -1,10 +1,11 @@
"""
故事相关ORM模型
"""
from sqlalchemy import Column, Integer, String, Text, Boolean, TIMESTAMP, ForeignKey
from sqlalchemy import Column, Integer, String, Text, Boolean, TIMESTAMP, ForeignKey, Enum, JSON
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.database import Base
import enum
class Story(Base):
@@ -64,3 +65,43 @@ class StoryChoice(Base):
created_at = Column(TIMESTAMP, server_default=func.now())
node = relationship("StoryNode", back_populates="choices")
class DraftStatus(enum.Enum):
"""草稿状态枚举"""
pending = "pending"
processing = "processing"
completed = "completed"
failed = "failed"
class StoryDraft(Base):
"""AI改写草稿表"""
__tablename__ = "story_drafts"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
story_id = Column(Integer, ForeignKey("stories.id", ondelete="CASCADE"), nullable=False)
title = Column(String(100), default="")
# 用户输入
path_history = Column(JSON, default=None) # 用户之前的选择路径
current_node_key = Column(String(50), default="")
current_content = Column(Text, default="")
user_prompt = Column(String(500), nullable=False)
# AI生成结果
ai_nodes = Column(JSON, default=None) # AI生成的新节点
entry_node_key = Column(String(50), default="")
tokens_used = Column(Integer, default=0)
# 状态
status = Column(Enum(DraftStatus), default=DraftStatus.pending)
error_message = Column(String(500), default="")
is_read = Column(Boolean, default=False) # 用户是否已查看
created_at = Column(TIMESTAMP, server_default=func.now())
completed_at = Column(TIMESTAMP, default=None)
# 关联
story = relationship("Story")

Binary file not shown.

View File

@@ -0,0 +1,344 @@
"""
草稿箱路由 - AI异步改写功能
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.sql import func
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime
from app.database import get_db
from app.models.story import Story, StoryDraft, DraftStatus
router = APIRouter(prefix="/drafts", tags=["草稿箱"])
# ============ 请求/响应模型 ============
class PathHistoryItem(BaseModel):
nodeKey: str
content: str
choice: str
class CreateDraftRequest(BaseModel):
userId: int
storyId: int
currentNodeKey: str
pathHistory: List[PathHistoryItem]
currentContent: str
prompt: str
class DraftResponse(BaseModel):
id: int
storyId: int
storyTitle: str
title: str
userPrompt: str
status: str
isRead: bool
createdAt: str
completedAt: Optional[str] = None
class Config:
from_attributes = True
# ============ 后台任务 ============
async def process_ai_rewrite(draft_id: int):
"""后台异步处理AI改写"""
from app.database import async_session_factory
from app.services.ai import ai_service
async with async_session_factory() as db:
try:
# 获取草稿
result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
draft = result.scalar_one_or_none()
if not draft:
return
# 更新状态为处理中
draft.status = DraftStatus.processing
await db.commit()
# 获取故事信息
story_result = await db.execute(select(Story).where(Story.id == draft.story_id))
story = story_result.scalar_one_or_none()
if not story:
draft.status = DraftStatus.failed
draft.error_message = "故事不存在"
draft.completed_at = datetime.now()
await db.commit()
return
# 转换路径历史格式
path_history = draft.path_history or []
# 调用AI服务
ai_result = await ai_service.rewrite_branch(
story_title=story.title,
story_category=story.category or "未知",
path_history=path_history,
current_content=draft.current_content or "",
user_prompt=draft.user_prompt
)
if ai_result and ai_result.get("nodes"):
# 成功
draft.status = DraftStatus.completed
draft.ai_nodes = ai_result["nodes"]
draft.entry_node_key = ai_result.get("entryNodeKey", "branch_1")
draft.tokens_used = ai_result.get("tokens_used", 0)
draft.title = f"{story.title}-改写"
else:
# 失败
draft.status = DraftStatus.failed
draft.error_message = "AI服务暂时不可用"
draft.completed_at = datetime.now()
await db.commit()
except Exception as e:
print(f"[process_ai_rewrite] 异常: {e}")
import traceback
traceback.print_exc()
# 更新失败状态
try:
draft.status = DraftStatus.failed
draft.error_message = str(e)[:500]
draft.completed_at = datetime.now()
await db.commit()
except:
pass
# ============ API 路由 ============
@router.post("")
async def create_draft(
request: CreateDraftRequest,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db)
):
"""提交AI改写任务异步处理"""
if not request.prompt:
raise HTTPException(status_code=400, detail="请输入改写指令")
# 获取故事信息
result = await db.execute(select(Story).where(Story.id == request.storyId))
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="故事不存在")
# 转换路径历史
path_history = [
{"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice}
for item in request.pathHistory
]
# 创建草稿记录
draft = StoryDraft(
user_id=request.userId,
story_id=request.storyId,
title=f"{story.title}-改写",
path_history=path_history,
current_node_key=request.currentNodeKey,
current_content=request.currentContent,
user_prompt=request.prompt,
status=DraftStatus.pending
)
db.add(draft)
await db.commit()
await db.refresh(draft)
# 添加后台任务
background_tasks.add_task(process_ai_rewrite, draft.id)
return {
"code": 0,
"data": {
"draftId": draft.id,
"message": "已提交AI正在生成中..."
}
}
@router.get("")
async def get_drafts(
userId: int,
db: AsyncSession = Depends(get_db)
):
"""获取用户的草稿列表"""
result = await db.execute(
select(StoryDraft, Story.title.label("story_title"))
.join(Story, StoryDraft.story_id == Story.id)
.where(StoryDraft.user_id == userId)
.order_by(StoryDraft.created_at.desc())
)
drafts = []
for row in result:
draft = row[0]
story_title = row[1]
drafts.append({
"id": draft.id,
"storyId": draft.story_id,
"storyTitle": story_title,
"title": draft.title,
"userPrompt": draft.user_prompt,
"status": draft.status.value if draft.status else "pending",
"isRead": draft.is_read,
"createdAt": draft.created_at.strftime("%Y-%m-%d %H:%M") if draft.created_at else "",
"completedAt": draft.completed_at.strftime("%Y-%m-%d %H:%M") if draft.completed_at else None
})
return {"code": 0, "data": drafts}
@router.get("/check-new")
async def check_new_drafts(
userId: int,
db: AsyncSession = Depends(get_db)
):
"""检查是否有新完成的草稿(用于弹窗通知)"""
result = await db.execute(
select(StoryDraft)
.where(
StoryDraft.user_id == userId,
StoryDraft.status == DraftStatus.completed,
StoryDraft.is_read == False
)
)
unread_drafts = result.scalars().all()
return {
"code": 0,
"data": {
"hasNew": len(unread_drafts) > 0,
"count": len(unread_drafts),
"drafts": [
{
"id": d.id,
"title": d.title,
"userPrompt": d.user_prompt
}
for d in unread_drafts[:3] # 最多返回3个
]
}
}
@router.get("/{draft_id}")
async def get_draft_detail(
draft_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取草稿详情"""
result = await db.execute(
select(StoryDraft, Story)
.join(Story, StoryDraft.story_id == Story.id)
.where(StoryDraft.id == draft_id)
)
row = result.first()
if not row:
raise HTTPException(status_code=404, detail="草稿不存在")
draft, story = row
# 标记为已读
if not draft.is_read:
draft.is_read = True
await db.commit()
return {
"code": 0,
"data": {
"id": draft.id,
"storyId": draft.story_id,
"storyTitle": story.title,
"storyCategory": story.category,
"title": draft.title,
"pathHistory": draft.path_history,
"currentNodeKey": draft.current_node_key,
"currentContent": draft.current_content,
"userPrompt": draft.user_prompt,
"aiNodes": draft.ai_nodes,
"entryNodeKey": draft.entry_node_key,
"tokensUsed": draft.tokens_used,
"status": draft.status.value if draft.status else "pending",
"errorMessage": draft.error_message,
"createdAt": draft.created_at.strftime("%Y-%m-%d %H:%M") if draft.created_at else "",
"completedAt": draft.completed_at.strftime("%Y-%m-%d %H:%M") if draft.completed_at else None
}
}
@router.delete("/{draft_id}")
async def delete_draft(
draft_id: int,
userId: int,
db: AsyncSession = Depends(get_db)
):
"""删除草稿"""
result = await db.execute(
select(StoryDraft).where(
StoryDraft.id == draft_id,
StoryDraft.user_id == userId
)
)
draft = result.scalar_one_or_none()
if not draft:
raise HTTPException(status_code=404, detail="草稿不存在")
await db.delete(draft)
await db.commit()
return {"code": 0, "message": "删除成功"}
@router.put("/{draft_id}/read")
async def mark_draft_read(
draft_id: int,
db: AsyncSession = Depends(get_db)
):
"""标记草稿为已读"""
await db.execute(
update(StoryDraft)
.where(StoryDraft.id == draft_id)
.values(is_read=True)
)
await db.commit()
return {"code": 0, "message": "已标记为已读"}
@router.put("/batch-read")
async def mark_all_drafts_read(
userId: int,
db: AsyncSession = Depends(get_db)
):
"""批量标记所有未读草稿为已读"""
await db.execute(
update(StoryDraft)
.where(
StoryDraft.user_id == userId,
StoryDraft.is_read == False
)
.values(is_read=True)
)
await db.commit()
return {"code": 0, "message": "已全部标记为已读"}

View File

@@ -131,12 +131,24 @@ class AIService:
2. 生成 4-6 个新节点,形成有层次的剧情发展(起承转合)
3. 每个节点内容 150-300 字,要分 2-3 个自然段(用\n\n分隔),包含:场景描写、人物对话、心理活动
4. 每个非结局节点有 2 个选项,选项要有明显的剧情差异和后果
5. 必须以结局收尾,结局内容要 200-400 字,分 2-3 段,有情感冲击力
6. 严格符合用户的改写意图,围绕用户指令展开剧情
7. 保持原故事的人物性格、语言风格和世界观
8. 对话要自然生动,描写要有画面感
5. 严格符合用户的改写意图,围绕用户指令展开剧情
6. 保持原故事的人物性格、语言风格和世界观
7. 对话要自然生动,描写要有画面感
重要】内容分段示例:
关于结局 - 极其重要!】
★★★ 每一条分支路径的尽头必须是结局节点 ★★★
- 结局节点必须设置 "is_ending": true
- 结局内容要 200-400 字,分 2-3 段,有情感冲击力
- 结局名称 4-8 字,体现剧情走向
- 如果有2个选项分支最终必须有2个不同的结局
- 不允许出现没有结局的"死胡同"节点
- 每个结局必须有 "ending_score" 评分0-100
- good 好结局80-100分
- bad 坏结局20-50分
- neutral 中立结局50-70分
- special 特殊结局70-90分
【内容分段示例】
"content": "他的声音在耳边响起,像是一阵温柔的风。\n\n\"我喜欢你。\"他说,目光坚定地看着你。\n\n你的心跳漏了一拍,一时间不知该如何回应。"
【输出格式】严格JSON不要有任何额外文字
@@ -153,14 +165,50 @@ class AIService:
"branch_2a": {
"content": "...",
"speaker": "旁白",
"choices": [...]
"choices": [
{"text": "选项C", "nextNodeKey": "branch_ending_good"},
{"text": "选项D", "nextNodeKey": "branch_ending_bad"}
]
},
"branch_2b": {
"content": "...",
"speaker": "旁白",
"choices": [
{"text": "选项E", "nextNodeKey": "branch_ending_neutral"},
{"text": "选项F", "nextNodeKey": "branch_ending_special"}
]
},
"branch_ending_good": {
"content": "好结局内容200-400字...",
"content": "好结局内容200-400字...\n\n【达成结局xxx】",
"speaker": "旁白",
"is_ending": true,
"ending_name": "结局名称4-8字",
"ending_type": "good"
"ending_name": "结局名称",
"ending_type": "good",
"ending_score": 90
},
"branch_ending_bad": {
"content": "坏结局内容...\n\n【达成结局xxx】",
"speaker": "旁白",
"is_ending": true,
"ending_name": "结局名称",
"ending_type": "bad",
"ending_score": 40
},
"branch_ending_neutral": {
"content": "中立结局...\n\n【达成结局xxx】",
"speaker": "旁白",
"is_ending": true,
"ending_name": "结局名称",
"ending_type": "neutral",
"ending_score": 60
},
"branch_ending_special": {
"content": "特殊结局...\n\n【达成结局xxx】",
"speaker": "旁白",
"is_ending": true,
"ending_name": "结局名称",
"ending_type": "special",
"ending_score": 80
}
},
"entryNodeKey": "branch_1"

View File

@@ -564,3 +564,39 @@ CREATE TABLE sensitive_words (
UNIQUE KEY uk_word (word),
INDEX idx_category (category)
) ENGINE=InnoDB COMMENT='敏感词表';
-- ============================================
-- 九、AI改写草稿箱
-- ============================================
-- AI改写草稿表
CREATE TABLE story_drafts (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
user_id BIGINT NOT NULL COMMENT '用户ID',
story_id BIGINT NOT NULL COMMENT '原故事ID',
title VARCHAR(100) DEFAULT '' COMMENT '草稿标题',
-- 用户输入
path_history JSON COMMENT '用户之前的选择路径',
current_node_key VARCHAR(50) DEFAULT '' COMMENT '改写起始节点',
current_content TEXT COMMENT '当前节点内容',
user_prompt VARCHAR(500) NOT NULL COMMENT '用户改写指令',
-- AI生成结果
ai_nodes JSON COMMENT 'AI生成的新节点',
entry_node_key VARCHAR(50) DEFAULT '' COMMENT '入口节点',
tokens_used INT DEFAULT 0 COMMENT '消耗token数',
-- 状态
status ENUM('pending', 'processing', 'completed', 'failed') DEFAULT 'pending' COMMENT '状态',
error_message VARCHAR(500) DEFAULT '' COMMENT '失败原因',
is_read BOOLEAN DEFAULT FALSE COMMENT '用户是否已查看',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
completed_at TIMESTAMP NULL COMMENT '完成时间',
INDEX idx_user (user_id),
INDEX idx_story (story_id),
INDEX idx_status (status),
INDEX idx_user_unread (user_id, is_read)
) ENGINE=InnoDB COMMENT='AI改写草稿表';