feat: 完善AI改写草稿箱功能 - 修复重头游玩、评分、数据刷新等问题
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -23,6 +23,9 @@ AsyncSessionLocal = sessionmaker(
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
# 后台任务使用的会话工厂
|
||||
async_session_factory = AsyncSessionLocal
|
||||
|
||||
# 基类
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.config import get_settings
|
||||
from app.routers import story, user
|
||||
from app.routers import story, user, drafts
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
@@ -29,6 +29,7 @@ app.add_middleware(
|
||||
# 注册路由
|
||||
app.include_router(story.router, prefix="/api/stories", tags=["故事"])
|
||||
app.include_router(user.router, prefix="/api/user", tags=["用户"])
|
||||
app.include_router(drafts.router, prefix="/api", tags=["草稿箱"])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
Binary file not shown.
@@ -1,10 +1,11 @@
|
||||
"""
|
||||
故事相关ORM模型
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean, TIMESTAMP, ForeignKey
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean, TIMESTAMP, ForeignKey, Enum, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import enum
|
||||
|
||||
|
||||
class Story(Base):
|
||||
@@ -64,3 +65,43 @@ class StoryChoice(Base):
|
||||
created_at = Column(TIMESTAMP, server_default=func.now())
|
||||
|
||||
node = relationship("StoryNode", back_populates="choices")
|
||||
|
||||
|
||||
class DraftStatus(enum.Enum):
|
||||
"""草稿状态枚举"""
|
||||
pending = "pending"
|
||||
processing = "processing"
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class StoryDraft(Base):
|
||||
"""AI改写草稿表"""
|
||||
__tablename__ = "story_drafts"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
story_id = Column(Integer, ForeignKey("stories.id", ondelete="CASCADE"), nullable=False)
|
||||
title = Column(String(100), default="")
|
||||
|
||||
# 用户输入
|
||||
path_history = Column(JSON, default=None) # 用户之前的选择路径
|
||||
current_node_key = Column(String(50), default="")
|
||||
current_content = Column(Text, default="")
|
||||
user_prompt = Column(String(500), nullable=False)
|
||||
|
||||
# AI生成结果
|
||||
ai_nodes = Column(JSON, default=None) # AI生成的新节点
|
||||
entry_node_key = Column(String(50), default="")
|
||||
tokens_used = Column(Integer, default=0)
|
||||
|
||||
# 状态
|
||||
status = Column(Enum(DraftStatus), default=DraftStatus.pending)
|
||||
error_message = Column(String(500), default="")
|
||||
is_read = Column(Boolean, default=False) # 用户是否已查看
|
||||
|
||||
created_at = Column(TIMESTAMP, server_default=func.now())
|
||||
completed_at = Column(TIMESTAMP, default=None)
|
||||
|
||||
# 关联
|
||||
story = relationship("Story")
|
||||
|
||||
BIN
server/app/routers/__pycache__/drafts.cpython-310.pyc
Normal file
BIN
server/app/routers/__pycache__/drafts.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
344
server/app/routers/drafts.py
Normal file
344
server/app/routers/drafts.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
草稿箱路由 - AI异步改写功能
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.sql import func
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.story import Story, StoryDraft, DraftStatus
|
||||
|
||||
router = APIRouter(prefix="/drafts", tags=["草稿箱"])
|
||||
|
||||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
class PathHistoryItem(BaseModel):
|
||||
nodeKey: str
|
||||
content: str
|
||||
choice: str
|
||||
|
||||
|
||||
class CreateDraftRequest(BaseModel):
|
||||
userId: int
|
||||
storyId: int
|
||||
currentNodeKey: str
|
||||
pathHistory: List[PathHistoryItem]
|
||||
currentContent: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class DraftResponse(BaseModel):
|
||||
id: int
|
||||
storyId: int
|
||||
storyTitle: str
|
||||
title: str
|
||||
userPrompt: str
|
||||
status: str
|
||||
isRead: bool
|
||||
createdAt: str
|
||||
completedAt: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ============ 后台任务 ============
|
||||
|
||||
async def process_ai_rewrite(draft_id: int):
|
||||
"""后台异步处理AI改写"""
|
||||
from app.database import async_session_factory
|
||||
from app.services.ai import ai_service
|
||||
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
# 获取草稿
|
||||
result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
|
||||
draft = result.scalar_one_or_none()
|
||||
|
||||
if not draft:
|
||||
return
|
||||
|
||||
# 更新状态为处理中
|
||||
draft.status = DraftStatus.processing
|
||||
await db.commit()
|
||||
|
||||
# 获取故事信息
|
||||
story_result = await db.execute(select(Story).where(Story.id == draft.story_id))
|
||||
story = story_result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = "故事不存在"
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
return
|
||||
|
||||
# 转换路径历史格式
|
||||
path_history = draft.path_history or []
|
||||
|
||||
# 调用AI服务
|
||||
ai_result = await ai_service.rewrite_branch(
|
||||
story_title=story.title,
|
||||
story_category=story.category or "未知",
|
||||
path_history=path_history,
|
||||
current_content=draft.current_content or "",
|
||||
user_prompt=draft.user_prompt
|
||||
)
|
||||
|
||||
if ai_result and ai_result.get("nodes"):
|
||||
# 成功
|
||||
draft.status = DraftStatus.completed
|
||||
draft.ai_nodes = ai_result["nodes"]
|
||||
draft.entry_node_key = ai_result.get("entryNodeKey", "branch_1")
|
||||
draft.tokens_used = ai_result.get("tokens_used", 0)
|
||||
draft.title = f"{story.title}-改写"
|
||||
else:
|
||||
# 失败
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = "AI服务暂时不可用"
|
||||
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"[process_ai_rewrite] 异常: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# 更新失败状态
|
||||
try:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = str(e)[:500]
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# ============ API 路由 ============
|
||||
|
||||
@router.post("")
|
||||
async def create_draft(
|
||||
request: CreateDraftRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""提交AI改写任务(异步处理)"""
|
||||
if not request.prompt:
|
||||
raise HTTPException(status_code=400, detail="请输入改写指令")
|
||||
|
||||
# 获取故事信息
|
||||
result = await db.execute(select(Story).where(Story.id == request.storyId))
|
||||
story = result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="故事不存在")
|
||||
|
||||
# 转换路径历史
|
||||
path_history = [
|
||||
{"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice}
|
||||
for item in request.pathHistory
|
||||
]
|
||||
|
||||
# 创建草稿记录
|
||||
draft = StoryDraft(
|
||||
user_id=request.userId,
|
||||
story_id=request.storyId,
|
||||
title=f"{story.title}-改写",
|
||||
path_history=path_history,
|
||||
current_node_key=request.currentNodeKey,
|
||||
current_content=request.currentContent,
|
||||
user_prompt=request.prompt,
|
||||
status=DraftStatus.pending
|
||||
)
|
||||
|
||||
db.add(draft)
|
||||
await db.commit()
|
||||
await db.refresh(draft)
|
||||
|
||||
# 添加后台任务
|
||||
background_tasks.add_task(process_ai_rewrite, draft.id)
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"draftId": draft.id,
|
||||
"message": "已提交,AI正在生成中..."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_drafts(
|
||||
userId: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取用户的草稿列表"""
|
||||
result = await db.execute(
|
||||
select(StoryDraft, Story.title.label("story_title"))
|
||||
.join(Story, StoryDraft.story_id == Story.id)
|
||||
.where(StoryDraft.user_id == userId)
|
||||
.order_by(StoryDraft.created_at.desc())
|
||||
)
|
||||
|
||||
drafts = []
|
||||
for row in result:
|
||||
draft = row[0]
|
||||
story_title = row[1]
|
||||
drafts.append({
|
||||
"id": draft.id,
|
||||
"storyId": draft.story_id,
|
||||
"storyTitle": story_title,
|
||||
"title": draft.title,
|
||||
"userPrompt": draft.user_prompt,
|
||||
"status": draft.status.value if draft.status else "pending",
|
||||
"isRead": draft.is_read,
|
||||
"createdAt": draft.created_at.strftime("%Y-%m-%d %H:%M") if draft.created_at else "",
|
||||
"completedAt": draft.completed_at.strftime("%Y-%m-%d %H:%M") if draft.completed_at else None
|
||||
})
|
||||
|
||||
return {"code": 0, "data": drafts}
|
||||
|
||||
|
||||
@router.get("/check-new")
|
||||
async def check_new_drafts(
|
||||
userId: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""检查是否有新完成的草稿(用于弹窗通知)"""
|
||||
result = await db.execute(
|
||||
select(StoryDraft)
|
||||
.where(
|
||||
StoryDraft.user_id == userId,
|
||||
StoryDraft.status == DraftStatus.completed,
|
||||
StoryDraft.is_read == False
|
||||
)
|
||||
)
|
||||
|
||||
unread_drafts = result.scalars().all()
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"hasNew": len(unread_drafts) > 0,
|
||||
"count": len(unread_drafts),
|
||||
"drafts": [
|
||||
{
|
||||
"id": d.id,
|
||||
"title": d.title,
|
||||
"userPrompt": d.user_prompt
|
||||
}
|
||||
for d in unread_drafts[:3] # 最多返回3个
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{draft_id}")
|
||||
async def get_draft_detail(
|
||||
draft_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取草稿详情"""
|
||||
result = await db.execute(
|
||||
select(StoryDraft, Story)
|
||||
.join(Story, StoryDraft.story_id == Story.id)
|
||||
.where(StoryDraft.id == draft_id)
|
||||
)
|
||||
|
||||
row = result.first()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="草稿不存在")
|
||||
|
||||
draft, story = row
|
||||
|
||||
# 标记为已读
|
||||
if not draft.is_read:
|
||||
draft.is_read = True
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"id": draft.id,
|
||||
"storyId": draft.story_id,
|
||||
"storyTitle": story.title,
|
||||
"storyCategory": story.category,
|
||||
"title": draft.title,
|
||||
"pathHistory": draft.path_history,
|
||||
"currentNodeKey": draft.current_node_key,
|
||||
"currentContent": draft.current_content,
|
||||
"userPrompt": draft.user_prompt,
|
||||
"aiNodes": draft.ai_nodes,
|
||||
"entryNodeKey": draft.entry_node_key,
|
||||
"tokensUsed": draft.tokens_used,
|
||||
"status": draft.status.value if draft.status else "pending",
|
||||
"errorMessage": draft.error_message,
|
||||
"createdAt": draft.created_at.strftime("%Y-%m-%d %H:%M") if draft.created_at else "",
|
||||
"completedAt": draft.completed_at.strftime("%Y-%m-%d %H:%M") if draft.completed_at else None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{draft_id}")
|
||||
async def delete_draft(
|
||||
draft_id: int,
|
||||
userId: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除草稿"""
|
||||
result = await db.execute(
|
||||
select(StoryDraft).where(
|
||||
StoryDraft.id == draft_id,
|
||||
StoryDraft.user_id == userId
|
||||
)
|
||||
)
|
||||
|
||||
draft = result.scalar_one_or_none()
|
||||
if not draft:
|
||||
raise HTTPException(status_code=404, detail="草稿不存在")
|
||||
|
||||
await db.delete(draft)
|
||||
await db.commit()
|
||||
|
||||
return {"code": 0, "message": "删除成功"}
|
||||
|
||||
|
||||
@router.put("/{draft_id}/read")
|
||||
async def mark_draft_read(
|
||||
draft_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""标记草稿为已读"""
|
||||
await db.execute(
|
||||
update(StoryDraft)
|
||||
.where(StoryDraft.id == draft_id)
|
||||
.values(is_read=True)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"code": 0, "message": "已标记为已读"}
|
||||
|
||||
|
||||
@router.put("/batch-read")
|
||||
async def mark_all_drafts_read(
|
||||
userId: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""批量标记所有未读草稿为已读"""
|
||||
await db.execute(
|
||||
update(StoryDraft)
|
||||
.where(
|
||||
StoryDraft.user_id == userId,
|
||||
StoryDraft.is_read == False
|
||||
)
|
||||
.values(is_read=True)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"code": 0, "message": "已全部标记为已读"}
|
||||
Binary file not shown.
@@ -131,12 +131,24 @@ class AIService:
|
||||
2. 生成 4-6 个新节点,形成有层次的剧情发展(起承转合)
|
||||
3. 每个节点内容 150-300 字,要分 2-3 个自然段(用\n\n分隔),包含:场景描写、人物对话、心理活动
|
||||
4. 每个非结局节点有 2 个选项,选项要有明显的剧情差异和后果
|
||||
5. 必须以结局收尾,结局内容要 200-400 字,分 2-3 段,有情感冲击力
|
||||
6. 严格符合用户的改写意图,围绕用户指令展开剧情
|
||||
7. 保持原故事的人物性格、语言风格和世界观
|
||||
8. 对话要自然生动,描写要有画面感
|
||||
5. 严格符合用户的改写意图,围绕用户指令展开剧情
|
||||
6. 保持原故事的人物性格、语言风格和世界观
|
||||
7. 对话要自然生动,描写要有画面感
|
||||
|
||||
【重要】内容分段示例:
|
||||
【关于结局 - 极其重要!】
|
||||
★★★ 每一条分支路径的尽头必须是结局节点 ★★★
|
||||
- 结局节点必须设置 "is_ending": true
|
||||
- 结局内容要 200-400 字,分 2-3 段,有情感冲击力
|
||||
- 结局名称 4-8 字,体现剧情走向
|
||||
- 如果有2个选项分支,最终必须有2个不同的结局
|
||||
- 不允许出现没有结局的"死胡同"节点
|
||||
- 每个结局必须有 "ending_score" 评分(0-100):
|
||||
- good 好结局:80-100分
|
||||
- bad 坏结局:20-50分
|
||||
- neutral 中立结局:50-70分
|
||||
- special 特殊结局:70-90分
|
||||
|
||||
【内容分段示例】
|
||||
"content": "他的声音在耳边响起,像是一阵温柔的风。\n\n\"我喜欢你。\"他说,目光坚定地看着你。\n\n你的心跳漏了一拍,一时间不知该如何回应。"
|
||||
|
||||
【输出格式】(严格JSON,不要有任何额外文字)
|
||||
@@ -153,14 +165,50 @@ class AIService:
|
||||
"branch_2a": {
|
||||
"content": "...",
|
||||
"speaker": "旁白",
|
||||
"choices": [...]
|
||||
"choices": [
|
||||
{"text": "选项C", "nextNodeKey": "branch_ending_good"},
|
||||
{"text": "选项D", "nextNodeKey": "branch_ending_bad"}
|
||||
]
|
||||
},
|
||||
"branch_2b": {
|
||||
"content": "...",
|
||||
"speaker": "旁白",
|
||||
"choices": [
|
||||
{"text": "选项E", "nextNodeKey": "branch_ending_neutral"},
|
||||
{"text": "选项F", "nextNodeKey": "branch_ending_special"}
|
||||
]
|
||||
},
|
||||
"branch_ending_good": {
|
||||
"content": "好结局内容(200-400字)...",
|
||||
"content": "好结局内容(200-400字)...\n\n【达成结局:xxx】",
|
||||
"speaker": "旁白",
|
||||
"is_ending": true,
|
||||
"ending_name": "结局名称(4-8字)",
|
||||
"ending_type": "good"
|
||||
"ending_name": "结局名称",
|
||||
"ending_type": "good",
|
||||
"ending_score": 90
|
||||
},
|
||||
"branch_ending_bad": {
|
||||
"content": "坏结局内容...\n\n【达成结局:xxx】",
|
||||
"speaker": "旁白",
|
||||
"is_ending": true,
|
||||
"ending_name": "结局名称",
|
||||
"ending_type": "bad",
|
||||
"ending_score": 40
|
||||
},
|
||||
"branch_ending_neutral": {
|
||||
"content": "中立结局...\n\n【达成结局:xxx】",
|
||||
"speaker": "旁白",
|
||||
"is_ending": true,
|
||||
"ending_name": "结局名称",
|
||||
"ending_type": "neutral",
|
||||
"ending_score": 60
|
||||
},
|
||||
"branch_ending_special": {
|
||||
"content": "特殊结局...\n\n【达成结局:xxx】",
|
||||
"speaker": "旁白",
|
||||
"is_ending": true,
|
||||
"ending_name": "结局名称",
|
||||
"ending_type": "special",
|
||||
"ending_score": 80
|
||||
}
|
||||
},
|
||||
"entryNodeKey": "branch_1"
|
||||
|
||||
@@ -564,3 +564,39 @@ CREATE TABLE sensitive_words (
|
||||
UNIQUE KEY uk_word (word),
|
||||
INDEX idx_category (category)
|
||||
) ENGINE=InnoDB COMMENT='敏感词表';
|
||||
|
||||
-- ============================================
|
||||
-- 九、AI改写草稿箱
|
||||
-- ============================================
|
||||
|
||||
-- AI改写草稿表
|
||||
CREATE TABLE story_drafts (
|
||||
id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||
user_id BIGINT NOT NULL COMMENT '用户ID',
|
||||
story_id BIGINT NOT NULL COMMENT '原故事ID',
|
||||
title VARCHAR(100) DEFAULT '' COMMENT '草稿标题',
|
||||
|
||||
-- 用户输入
|
||||
path_history JSON COMMENT '用户之前的选择路径',
|
||||
current_node_key VARCHAR(50) DEFAULT '' COMMENT '改写起始节点',
|
||||
current_content TEXT COMMENT '当前节点内容',
|
||||
user_prompt VARCHAR(500) NOT NULL COMMENT '用户改写指令',
|
||||
|
||||
-- AI生成结果
|
||||
ai_nodes JSON COMMENT 'AI生成的新节点',
|
||||
entry_node_key VARCHAR(50) DEFAULT '' COMMENT '入口节点',
|
||||
tokens_used INT DEFAULT 0 COMMENT '消耗token数',
|
||||
|
||||
-- 状态
|
||||
status ENUM('pending', 'processing', 'completed', 'failed') DEFAULT 'pending' COMMENT '状态',
|
||||
error_message VARCHAR(500) DEFAULT '' COMMENT '失败原因',
|
||||
is_read BOOLEAN DEFAULT FALSE COMMENT '用户是否已查看',
|
||||
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
completed_at TIMESTAMP NULL COMMENT '完成时间',
|
||||
|
||||
INDEX idx_user (user_id),
|
||||
INDEX idx_story (story_id),
|
||||
INDEX idx_status (status),
|
||||
INDEX idx_user_unread (user_id, is_read)
|
||||
) ENGINE=InnoDB COMMENT='AI改写草稿表';
|
||||
|
||||
Reference in New Issue
Block a user