340 lines
10 KiB
Python
340 lines
10 KiB
Python
"""
|
||
故事相关API路由
|
||
"""
|
||
import random
|
||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, update, func, distinct
|
||
from typing import Optional, List
|
||
from pydantic import BaseModel
|
||
|
||
from app.database import get_db
|
||
from app.models.story import Story, StoryNode, StoryChoice
|
||
|
||
# Router for the story endpoints defined below. Paths here are relative
# ("", "/hot", "/{story_id}", ...), so presumably the app mounts it under a
# prefix elsewhere — confirm against the include_router call.
router = APIRouter()
|
||
|
||
|
||
# ========== 请求/响应模型 ==========
|
||
class LikeRequest(BaseModel):
    """Request body for liking or unliking a story."""

    # True adds a like (+1 to like_count); False removes one (-1).
    like: bool
|
||
|
||
|
||
class RewriteRequest(BaseModel):
    """Request body for the AI ending-rewrite endpoint."""

    # Name of the ending being rewritten. The handler already falls back to a
    # placeholder when this is empty, so it is optional here (default "").
    ending_name: str = ""
    # Original ending text passed to the AI as context; optional (default "").
    ending_content: str = ""
    # The user's rewrite instruction (required).
    prompt: str
|
||
|
||
|
||
class PathHistoryItem(BaseModel):
    """One visited step of the player's path, sent with a branch-rewrite request."""

    # Key of the visited node.
    nodeKey: str
    # Text of that node; optional.
    content: str = ""
    # Presumably the text of the choice taken at that node — confirm with the client; optional.
    choice: str = ""
|
||
|
||
|
||
class RewriteBranchRequest(BaseModel):
    """Request body for the AI branch-rewrite endpoint."""

    # Requesting user's id. NOTE(review): required, but the rewrite-branch
    # handler does not read it.
    userId: int
    # Key of the node being rewritten. NOTE(review): also not read by the handler.
    currentNodeKey: str
    # Nodes visited so far; forwarded to the AI service as context.
    pathHistory: List[PathHistoryItem]
    # Text of the current node, forwarded to the AI service.
    currentContent: str
    # The user's rewrite instruction.
    prompt: str
|
||
|
||
|
||
# ========== API接口 ==========
|
||
|
||
@router.get("")
async def get_stories(
    category: Optional[str] = Query(None),
    featured: bool = Query(False),
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db)
):
    """Return one page of published stories (``status == 1``).

    Ordering: featured stories first, then by descending play count, then by
    descending id.

    Args:
        category: When non-empty, only stories in this category are returned.
        featured: When true, only featured stories are returned.
        limit: Page size (1-100).
        offset: Number of rows to skip.
        db: Async database session (injected).

    Returns:
        ``{"code": 0, "data": [...]}``, one summary dict per story.
    """
    query = select(Story).where(Story.status == 1)

    if category:
        query = query.where(Story.category == category)
    if featured:
        # `==` (not `is`) is required so SQLAlchemy builds a SQL expression.
        query = query.where(Story.is_featured == True)  # noqa: E712

    # Story.id breaks ties between rows with equal (is_featured, play_count).
    # Without a unique final sort key the database may return tied rows in any
    # order, so limit/offset pages could repeat or skip stories.
    # (Assumes id is unique, e.g. the primary key.)
    query = query.order_by(
        Story.is_featured.desc(),
        Story.play_count.desc(),
        Story.id.desc(),
    )
    query = query.limit(limit).offset(offset)

    result = await db.execute(query)
    stories = result.scalars().all()

    data = [{
        "id": s.id,
        "title": s.title,
        "cover_url": s.cover_url,
        "description": s.description,
        "category": s.category,
        "play_count": s.play_count,
        "like_count": s.like_count,
        "is_featured": s.is_featured
    } for s in stories]

    return {"code": 0, "data": data}
|
||
|
||
|
||
@router.get("/hot")
async def get_hot_stories(
    limit: int = Query(10, ge=1, le=50),
    db: AsyncSession = Depends(get_db)
):
    """List the most-played published stories.

    Args:
        limit: Maximum number of stories to return (1-50).
        db: Async database session (injected).

    Returns:
        ``{"code": 0, "data": [...]}`` sorted by ``play_count``, highest first.
    """
    stmt = (
        select(Story)
        .where(Story.status == 1)
        .order_by(Story.play_count.desc())
        .limit(limit)
    )
    rows = (await db.execute(stmt)).scalars().all()

    def summarize(story):
        # Compact card for a hot-list entry.
        return {
            "id": story.id,
            "title": story.title,
            "cover_url": story.cover_url,
            "description": story.description,
            "category": story.category,
            "play_count": story.play_count,
            "like_count": story.like_count,
        }

    return {"code": 0, "data": [summarize(story) for story in rows]}
|
||
|
||
|
||
@router.get("/categories")
async def get_categories(db: AsyncSession = Depends(get_db)):
    """Return the distinct categories of published stories (``status == 1``).

    NULL and empty categories are excluded. They are not real categories, and
    the list endpoint ignores an empty ``category`` filter, so a client could
    never select them. The result is sorted so its order is stable across calls.

    Returns:
        ``{"code": 0, "data": [category, ...]}``.
    """
    query = (
        select(distinct(Story.category))
        .where(
            Story.status == 1,
            Story.category.is_not(None),
            Story.category != "",
        )
        .order_by(Story.category)
    )
    result = await db.execute(query)
    categories = [row[0] for row in result.all()]

    return {"code": 0, "data": categories}
|
||
|
||
|
||
@router.get("/{story_id}")
async def get_story_detail(story_id: int, db: AsyncSession = Depends(get_db)):
    """Return a published story together with its nodes and choices.

    Args:
        story_id: Id of the story; only stories with ``status == 1`` are found.
        db: Async database session (injected).

    Raises:
        HTTPException: 404 if the story does not exist or is not published.

    Returns:
        ``{"code": 0, "data": {...}}``. ``data["nodes"]`` maps each
        ``node_key`` to that node's fields plus a ``choices`` list. Nodes are
        inserted in ``StoryNode.sort_order`` order; choices are appended in
        ``StoryChoice.sort_order`` order.
    """
    # Load the story itself.
    result = await db.execute(select(Story).where(Story.id == story_id, Story.status == 1))
    story = result.scalar_one_or_none()

    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    # Load all nodes of the story.
    nodes_result = await db.execute(
        select(StoryNode).where(StoryNode.story_id == story_id).order_by(StoryNode.sort_order)
    )
    nodes = nodes_result.scalars().all()

    # Load all choices of the story.
    choices_result = await db.execute(
        select(StoryChoice).where(StoryChoice.story_id == story_id).order_by(StoryChoice.sort_order)
    )
    choices = choices_result.scalars().all()

    # Build node_key -> serialized node, and node.id -> node_key. The second
    # map lets each choice find its node in O(1) instead of rescanning every
    # node per choice (O(nodes * choices)).
    nodes_map = {}
    key_by_node_id = {}
    for node in nodes:
        nodes_map[node.node_key] = {
            "id": node.id,
            "node_key": node.node_key,
            "content": node.content,
            "speaker": node.speaker,
            "background_image": node.background_image,
            "character_image": node.character_image,
            "bgm": node.bgm,
            "is_ending": node.is_ending,
            "ending_name": node.ending_name,
            "ending_score": node.ending_score,
            "ending_type": node.ending_type,
            "choices": []
        }
        # setdefault keeps the first node seen for an id (first-match semantics).
        key_by_node_id.setdefault(node.id, node.node_key)

    for choice in choices:
        # Choices that reference a node not belonging to this story are skipped.
        if choice.node_id not in key_by_node_id:
            continue
        nodes_map[key_by_node_id[choice.node_id]]["choices"].append({
            "text": choice.text,
            "nextNodeKey": choice.next_node_key,
            "isLocked": choice.is_locked
        })

    data = {
        "id": story.id,
        "title": story.title,
        "cover_url": story.cover_url,
        "description": story.description,
        "category": story.category,
        "author_id": story.author_id,
        "play_count": story.play_count,
        "like_count": story.like_count,
        "is_featured": story.is_featured,
        "nodes": nodes_map
    }

    return {"code": 0, "data": data}
|
||
|
||
|
||
@router.post("/{story_id}/play")
async def record_play(story_id: int, db: AsyncSession = Depends(get_db)):
    """Increment a story's play counter by one.

    The increment is a single SQL UPDATE (``play_count = play_count + 1``),
    not a read-modify-write in Python.

    Raises:
        HTTPException: 404 if no story with ``story_id`` exists.

    Returns:
        ``{"code": 0, "message": ...}`` on success.
    """
    result = await db.execute(
        update(Story).where(Story.id == story_id).values(play_count=Story.play_count + 1)
    )
    if result.rowcount == 0:
        # No row matched the id, so the story does not exist; don't report success.
        await db.rollback()
        raise HTTPException(status_code=404, detail="故事不存在")

    await db.commit()
    return {"code": 0, "message": "记录成功"}
|
||
|
||
|
||
@router.post("/{story_id}/like")
async def toggle_like(story_id: int, request: LikeRequest, db: AsyncSession = Depends(get_db)):
    """Like (+1) or unlike (-1) a story's like counter.

    No per-user like state is kept here, so the counter is only protected
    against going below zero: an unlike at ``like_count == 0`` changes nothing.

    Raises:
        HTTPException: 404 if no story with ``story_id`` exists.

    Returns:
        ``{"code": 0, "message": ...}`` on success.
    """
    exists = await db.execute(select(Story.id).where(Story.id == story_id))
    if exists.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="故事不存在")

    stmt = update(Story).where(Story.id == story_id)
    if request.like:
        stmt = stmt.values(like_count=Story.like_count + 1)
    else:
        # Only decrement positive counters so repeated unlikes cannot go negative.
        stmt = stmt.where(Story.like_count > 0).values(like_count=Story.like_count - 1)

    await db.execute(stmt)
    await db.commit()
    return {"code": 0, "message": "点赞成功" if request.like else "取消点赞成功"}
|
||
|
||
|
||
def _parse_ai_ending(raw, default_name):
    """Extract ``(ending_name, content)`` from an AI rewrite reply.

    Two JSON candidates are tried, in this order:

    1. the first flat object (no nested braces) in the reply that contains
       ``"ending_name"`` followed by ``"content"``;
    2. the whole reply, after removing a surrounding Markdown code fence
       (three backticks, optionally tagged ``json``) if there is one.

    The first candidate that decodes to a dict with a usable (non-blank
    string) ``ending_name`` or ``content`` wins. A missing or unusable field
    falls back to ``default_name`` or to the raw reply. If nothing parses, the
    raw reply is returned unchanged as the content.

    Args:
        raw: Text returned by the AI service.
        default_name: Ending name to use when the reply does not supply one.

    Returns:
        A ``(ending_name, content)`` tuple.
    """
    import json
    import re

    if not isinstance(raw, str):
        # Unexpected payload type: hand it back untouched instead of crashing.
        return default_name, raw

    candidates = []
    flat = re.search(r'\{[^{}]*"ending_name"[^{}]*"content"[^{}]*\}', raw, re.DOTALL)
    if flat:
        candidates.append(flat.group())
    whole = raw.strip()
    fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", whole, re.DOTALL)
    candidates.append(fenced.group(1) if fenced else whole)

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if not isinstance(parsed, dict):
            continue
        name = parsed.get("ending_name")
        body = parsed.get("content")
        name_ok = isinstance(name, str) and bool(name.strip())
        body_ok = isinstance(body, str) and bool(body.strip())
        if name_ok or body_ok:
            return (name if name_ok else default_name), (body if body_ok else raw)

    # Best effort: an unparseable reply is used verbatim as the ending text.
    return default_name, raw


@router.post("/{story_id}/rewrite")
async def ai_rewrite_ending(story_id: int, request: RewriteRequest, db: AsyncSession = Depends(get_db)):
    """AI-rewrite a story ending according to the user's instruction.

    Calls ``ai_service.rewrite_ending``. When it returns a non-empty
    ``content``, the reply is parsed with ``_parse_ai_ending``. Otherwise a
    randomly chosen canned template is returned so the client still gets an
    ending node.

    Raises:
        HTTPException: 400 if ``prompt`` is empty or whitespace-only;
            404 if the story does not exist.

    Returns:
        ``{"code": 0, "data": {...}}`` describing the rewritten ending node.
    """
    prompt = (request.prompt or "").strip()
    if not prompt:
        # Whitespace-only instructions are rejected too; they would waste an AI call.
        raise HTTPException(status_code=400, detail="请输入改写指令")

    # Load the story.
    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()

    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    # One fallback name, used both for the AI call and the returned titles, so
    # an empty ending_name yields "未知结局(...)" rather than a bare suffix.
    base_name = request.ending_name or "未知结局"

    # Kept as a function-local import, as in the original handler.
    from app.services.ai import ai_service

    # NOTE(review): exceptions from rewrite_ending are not caught, so the
    # template fallback below only covers a falsy/empty result. Confirm the
    # service reports outages that way rather than by raising.
    ai_result = await ai_service.rewrite_ending(
        story_title=story.title,
        story_category=story.category or "未知",
        ending_name=base_name,
        ending_content=request.ending_content or "",
        user_prompt=prompt
    )

    if ai_result and ai_result.get("content"):
        ending_name, content = _parse_ai_ending(ai_result["content"], f"{base_name}(AI改写)")
        return {
            "code": 0,
            "data": {
                "content": content,
                "speaker": "旁白",
                "is_ending": True,
                "ending_name": ending_name,
                "ending_type": "rewrite",
                "tokens_used": ai_result.get("tokens_used", 0)
            }
        }

    # Degraded path when the AI service returns nothing usable.
    templates = [
        f"根据你的愿望「{prompt}」,故事有了新的发展...\n\n",
        f"命运的齿轮开始转动,{prompt}...\n\n",
        f"在另一个平行世界里,{prompt}成为了现实...\n\n"
    ]
    new_content = (
        random.choice(templates)
        + "原本的结局被改写,新的故事在这里展开。\n\n"
        + "【提示】AI服务暂时不可用,这是模板内容。"
    )

    return {
        "code": 0,
        "data": {
            "content": new_content,
            "speaker": "旁白",
            "is_ending": True,
            "ending_name": f"{base_name}(改写版)",
            "ending_type": "rewrite"
        }
    }
|
||
|
||
|
||
@router.post("/{story_id}/rewrite-branch")
async def ai_rewrite_branch(
    story_id: int,
    request: RewriteBranchRequest,
    db: AsyncSession = Depends(get_db)
):
    """AI-rewrite a mid-story chapter, producing a new branch of nodes.

    Sends the story's title and category, the player's path so far, the
    current node text and the (stripped) instruction to
    ``ai_service.rewrite_branch``. ``request.userId`` and
    ``request.currentNodeKey`` are not used here.

    Raises:
        HTTPException: 400 if ``prompt`` is empty or whitespace-only;
            404 if the story does not exist.

    Returns:
        ``{"code": 0, "data": {...}}``. On success ``data`` carries the
        generated ``nodes`` and ``entryNodeKey`` (default ``"branch_1"``).
        When the AI yields no nodes, ``nodes``/``entryNodeKey`` are ``None``
        and ``error`` is set.
    """
    prompt = (request.prompt or "").strip()
    if not prompt:
        # Whitespace-only instructions are rejected too; they would waste an AI call.
        raise HTTPException(status_code=400, detail="请输入改写指令")

    # Load the story.
    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()

    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    # Convert the Pydantic path items to plain dicts for the AI service.
    path_history = [
        {"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice}
        for item in request.pathHistory
    ]

    # Kept as a function-local import, as in the original handler.
    from app.services.ai import ai_service

    ai_result = await ai_service.rewrite_branch(
        story_title=story.title,
        story_category=story.category or "未知",
        path_history=path_history,
        current_content=request.currentContent,
        user_prompt=prompt
    )

    if ai_result and ai_result.get("nodes"):
        return {
            "code": 0,
            "data": {
                "nodes": ai_result["nodes"],
                "entryNodeKey": ai_result.get("entryNodeKey", "branch_1"),
                "tokensUsed": ai_result.get("tokens_used", 0)
            }
        }

    # AI unavailable: deliberately return an empty result (no template fallback).
    return {
        "code": 0,
        "data": {
            "nodes": None,
            "entryNodeKey": None,
            "tokensUsed": 0,
            "error": "AI服务暂时不可用"
        }
    }