422 lines
12 KiB
Python
422 lines
12 KiB
Python
|
|
"""
|
|||
|
|
用户相关API路由
|
|||
|
|
"""
|
|||
|
|
from fastapi import APIRouter, Depends, Query, HTTPException
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
from sqlalchemy import select, update, func, text
|
|||
|
|
from typing import Optional
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
|
|||
|
|
from app.database import get_db
|
|||
|
|
from app.models.user import User, UserProgress, UserEnding
|
|||
|
|
from app.models.story import Story
|
|||
|
|
|
|||
|
|
# Router that the user-related endpoints in this module register on; it is
# presumably included by the application elsewhere (prefix not visible here).
router = APIRouter()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 请求/响应模型 ==========
|
|||
|
|
class LoginRequest(BaseModel):
    # Request body for POST /login.
    # Value sent by the client; login() currently uses it directly as the openid.
    code: str
    # Optional profile used only when a new user is created; login() reads the
    # "nickname", "avatarUrl" and "gender" keys from it.
    userInfo: Optional[dict] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProfileRequest(BaseModel):
    # Request body for POST /profile; all three fields overwrite the stored
    # user columns (nickname, avatar_url, gender).
    # camelCase names presumably mirror the client's JSON keys.
    nickname: str
    avatarUrl: str
    # Integer gender code; its meaning is not defined in this module.
    gender: int = 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProgressRequest(BaseModel):
    # Request body for POST /progress (save a user's position in a story).
    userId: int
    storyId: int
    # Key of the story node the user is currently on.
    currentNodeKey: str
    isCompleted: bool = False
    # Name of the ending reached; an empty string means no ending.
    endingReached: str = ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LikeRequest(BaseModel):
    # Request body for POST /like; isLiked is the desired final state
    # (True = like, False = unlike), not a toggle.
    userId: int
    storyId: int
    isLiked: bool
|
|||
|
|
|
|||
|
|
|
|||
|
|
class CollectRequest(BaseModel):
    # Request body for POST /collect; isCollected is the desired final state
    # (True = collect, False = un-collect), not a toggle.
    userId: int
    storyId: int
    isCollected: bool
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== API接口 ==========
|
|||
|
|
|
|||
|
|
@router.post("/login")
async def login(request: LoginRequest, db: AsyncSession = Depends(get_db)):
    """WeChat login: find the user for the given code, creating one if absent.

    Args:
        request: ``code`` identifies the user; the optional ``userInfo`` dict
            seeds the profile of a newly created user only (an existing
            user's profile is not touched).
        db: async database session.

    Returns:
        ``{"code": 0, "data": {...}}`` with the user's id, openid, profile
        fields and play statistics.

    Raises:
        HTTPException: 400 if ``code`` is empty or whitespace-only.
    """
    # In a real deployment the code must be exchanged for an openid through
    # the WeChat API; for now the code itself is used as the openid.
    # NOTE(review): until that exchange exists, any client that knows a
    # user's openid can log in as that user.
    openid = request.code
    if not openid.strip():
        # Without this check every blank code would resolve to one shared
        # user whose openid is "".
        raise HTTPException(status_code=400, detail="code不能为空")

    # Find or create the user
    result = await db.execute(select(User).where(User.openid == openid))
    user = result.scalar_one_or_none()

    if user is None:
        user_info = request.userInfo or {}
        user = User(
            openid=openid,
            # `or` also replaces explicit nulls sent by the client (not just
            # missing keys), so the defaults always apply.
            nickname=user_info.get("nickname") or "",
            avatar_url=user_info.get("avatarUrl") or "",
            gender=user_info.get("gender") or 0,
        )
        db.add(user)
        await db.commit()
        await db.refresh(user)

    return {
        "code": 0,
        "data": {
            "userId": user.id,
            "openid": user.openid,
            "nickname": user.nickname,
            "avatarUrl": user.avatar_url,
            "gender": user.gender,
            "total_play_count": user.total_play_count,
            "total_endings": user.total_endings,
        },
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/profile")
async def update_profile(request: ProfileRequest, user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """Overwrite the nickname, avatar URL and gender of an existing user.

    Args:
        request: new profile values.
        user_id: ``userId`` query parameter identifying the user.
        db: async database session.

    Returns:
        ``{"code": 0, "message": "更新成功"}`` on success.

    Raises:
        HTTPException: 404 if no user with ``user_id`` exists.
    """
    # Check existence first: an UPDATE that matches zero rows would otherwise
    # still be reported to the client as a success.
    existing = await db.execute(select(User.id).where(User.id == user_id))
    if existing.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="用户不存在")

    await db.execute(
        update(User).where(User.id == user_id).values(
            nickname=request.nickname,
            avatar_url=request.avatarUrl,
            gender=request.gender,
        )
    )
    await db.commit()
    return {"code": 0, "message": "更新成功"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/progress")
async def get_progress(
    user_id: int = Query(..., alias="userId"),
    story_id: Optional[str] = Query(None, alias="storyId"),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a user's story progress, joined with each story's title and cover.

    Args:
        user_id: ``userId`` query parameter.
        story_id: optional ``storyId`` query parameter. It is accepted as a
            string because clients may send the literal ``"null"``; that
            value, and anything that does not parse as an int, means
            "no story filter".
        db: async database session.

    Returns:
        ``{"code": 0, "data": ...}``. When a story filter is applied, ``data``
        is a single progress dict (or ``None`` if there is none); otherwise it
        is a list of progress dicts ordered by ``updated_at`` descending.
    """
    story_id_int = None
    if story_id and story_id != "null":
        try:
            story_id_int = int(story_id)
        except ValueError:
            # Deliberate best effort: an unparsable storyId means "no filter".
            pass

    query = (
        select(UserProgress, Story.title.label("story_title"), Story.cover_url)
        .join(Story, UserProgress.story_id == Story.id)
        .where(UserProgress.user_id == user_id)
    )

    # Compare with None rather than relying on truthiness, so storyId=0 is
    # still applied as a filter and gets the single-object response shape.
    if story_id_int is not None:
        query = query.where(UserProgress.story_id == story_id_int)

    query = query.order_by(UserProgress.updated_at.desc())
    result = await db.execute(query)
    rows = result.all()

    data = [{
        "id": row.UserProgress.id,
        "user_id": row.UserProgress.user_id,
        "story_id": row.UserProgress.story_id,
        "story_title": row.story_title,
        "cover_url": row.cover_url,
        "current_node_key": row.UserProgress.current_node_key,
        "is_completed": row.UserProgress.is_completed,
        "ending_reached": row.UserProgress.ending_reached,
        "is_liked": row.UserProgress.is_liked,
        "is_collected": row.UserProgress.is_collected,
        "play_count": row.UserProgress.play_count,
    } for row in rows]

    if story_id_int is not None:
        return {"code": 0, "data": data[0] if data else None}
    return {"code": 0, "data": data}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/progress")
async def save_progress(request: ProgressRequest, db: AsyncSession = Depends(get_db)):
    """Save (upsert) a user's progress for one story.

    If a UserProgress row for (userId, storyId) exists it is updated and its
    play_count is incremented; otherwise a new row is created. When the
    request marks the story as completed with a non-empty ending name, that
    ending is recorded once per (user, story, ending name), the user's
    total_play_count is incremented and total_endings is recomputed. All
    changes are committed together at the end.

    Returns:
        ``{"code": 0, "message": "保存成功"}``.
    """
    user_id = request.userId
    story_id = request.storyId

    # Look up an existing progress row for this user/story
    result = await db.execute(
        select(UserProgress).where(
            UserProgress.user_id == user_id,
            UserProgress.story_id == story_id
        )
    )
    progress = result.scalar_one_or_none()

    if progress:
        # Update in place; play_count is incremented on every save
        await db.execute(
            update(UserProgress).where(UserProgress.id == progress.id).values(
                current_node_key=request.currentNodeKey,
                is_completed=request.isCompleted,
                ending_reached=request.endingReached,
                play_count=UserProgress.play_count + 1
            )
        )
    else:
        # Create a new row; play_count is left to the model's default
        # (not visible in this module)
        progress = UserProgress(
            user_id=user_id,
            story_id=story_id,
            current_node_key=request.currentNodeKey,
            is_completed=request.isCompleted,
            ending_reached=request.endingReached
        )
        db.add(progress)

    # On completion, record the ending reached.
    # NOTE(review): confirm that the user-statistics updates below are meant to
    # run only on completion (as structured here) rather than on every save.
    if request.isCompleted and request.endingReached:
        # Only insert the ending if it is not already recorded
        ending_result = await db.execute(
            select(UserEnding).where(
                UserEnding.user_id == user_id,
                UserEnding.story_id == story_id,
                UserEnding.ending_name == request.endingReached
            )
        )
        if not ending_result.scalar_one_or_none():
            ending = UserEnding(
                user_id=user_id,
                story_id=story_id,
                ending_name=request.endingReached
            )
            db.add(ending)

        # Update the user's play statistics
        await db.execute(
            update(User).where(User.id == user_id).values(
                total_play_count=User.total_play_count + 1
            )
        )
        # Recompute the user's total number of unlocked endings.
        # Presumably relies on session autoflush (SQLAlchemy's default) so a
        # UserEnding added just above is included — verify the get_db config.
        count_result = await db.execute(
            select(func.count()).select_from(UserEnding).where(UserEnding.user_id == user_id)
        )
        count = count_result.scalar()
        await db.execute(
            update(User).where(User.id == user_id).values(total_endings=count)
        )

    await db.commit()
    return {"code": 0, "message": "保存成功"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/like")
async def toggle_like(request: LikeRequest, db: AsyncSession = Depends(get_db)):
    """Set or clear a user's like on a story.

    Upserts the UserProgress row for (userId, storyId). An existing row gets
    its ``is_liked`` column updated; otherwise a new row is created with only
    the like flag set. Returns a success message matching the new state.
    """
    lookup = select(UserProgress).where(
        UserProgress.user_id == request.userId,
        UserProgress.story_id == request.storyId,
    )
    existing = (await db.execute(lookup)).scalar_one_or_none()

    if existing is None:
        db.add(
            UserProgress(
                user_id=request.userId,
                story_id=request.storyId,
                is_liked=request.isLiked,
            )
        )
    else:
        set_flag = (
            update(UserProgress)
            .where(UserProgress.id == existing.id)
            .values(is_liked=request.isLiked)
        )
        await db.execute(set_flag)

    await db.commit()

    message = "点赞成功" if request.isLiked else "取消点赞成功"
    return {"code": 0, "message": message}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/collect")
async def toggle_collect(request: CollectRequest, db: AsyncSession = Depends(get_db)):
    """Set or clear a user's "collected" (favourite) flag on a story.

    Upserts the UserProgress row for (userId, storyId). An existing row gets
    its ``is_collected`` column updated; otherwise a new row is created with
    only that flag set. Returns a success message matching the new state.
    """
    uid, sid, flag = request.userId, request.storyId, request.isCollected

    found = (
        await db.execute(
            select(UserProgress).where(
                UserProgress.user_id == uid,
                UserProgress.story_id == sid,
            )
        )
    ).scalar_one_or_none()

    if found is None:
        db.add(UserProgress(user_id=uid, story_id=sid, is_collected=flag))
    else:
        await db.execute(
            update(UserProgress)
            .where(UserProgress.id == found.id)
            .values(is_collected=flag)
        )

    await db.commit()
    return {"code": 0, "message": "收藏成功" if flag else "取消收藏成功"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/collections")
async def get_collections(user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """List the stories a user has collected.

    Stories are ordered by the ``updated_at`` of the user's progress row,
    newest first, and each is summarised as a plain dict.
    """
    stmt = (
        select(Story)
        .join(UserProgress, Story.id == UserProgress.story_id)
        # `== True` builds a SQL comparison here, not a Python truth test.
        .where(UserProgress.user_id == user_id, UserProgress.is_collected == True)  # noqa: E712
        .order_by(UserProgress.updated_at.desc())
    )
    collected = (await db.execute(stmt)).scalars().all()

    def _summarize(story):
        # Public fields of a story shown in the collection list.
        return {
            "id": story.id,
            "title": story.title,
            "cover_url": story.cover_url,
            "description": story.description,
            "category": story.category,
            "play_count": story.play_count,
            "like_count": story.like_count,
        }

    return {"code": 0, "data": [_summarize(story) for story in collected]}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/endings")
async def get_unlocked_endings(
    user_id: int = Query(..., alias="userId"),
    story_id: Optional[str] = Query(None, alias="storyId"),
    db: AsyncSession = Depends(get_db),
):
    """List the endings a user has unlocked, optionally limited to one story.

    Args:
        user_id: ``userId`` query parameter.
        story_id: optional ``storyId`` query parameter. It is accepted as a
            string because clients may send the literal ``"null"``; that
            value, and anything that does not parse as an int, means
            "all stories".
        db: async database session.

    Returns:
        ``{"code": 0, "data": [...]}``, ordered by ``unlocked_at`` descending
        with ``id`` as a tiebreaker so the order is stable between calls.
    """
    story_id_int = None
    if story_id and story_id != "null":
        try:
            story_id_int = int(story_id)
        except ValueError:
            # Deliberate best effort: an unparsable storyId means "no filter".
            pass

    query = select(UserEnding).where(UserEnding.user_id == user_id)
    # Compare with None rather than relying on truthiness so storyId=0 is
    # still applied as a filter.
    if story_id_int is not None:
        query = query.where(UserEnding.story_id == story_id_int)
    # Without ORDER BY the database may return rows in any order.
    query = query.order_by(UserEnding.unlocked_at.desc(), UserEnding.id.desc())

    result = await db.execute(query)
    endings = result.scalars().all()

    data = [{
        "id": e.id,
        "story_id": e.story_id,
        "ending_name": e.ending_name,
        "ending_score": e.ending_score,
        "unlocked_at": str(e.unlocked_at) if e.unlocked_at else None,
    } for e in endings]

    return {"code": 0, "data": data}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/my-works")
async def get_my_works(user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """List the stories authored by a user, newest first.

    Best effort: if loading or serialising the works fails for any reason,
    the error is logged and an empty list is returned so the client can
    still render the page.

    Returns:
        ``{"code": 0, "data": [...]}``.
    """
    import logging

    try:
        result = await db.execute(
            select(Story).where(Story.author_id == user_id).order_by(Story.created_at.desc())
        )
        works = result.scalars().all()

        data = [{
            "id": w.id,
            "title": w.title,
            "description": w.description,
            "category": w.category,
            "cover_url": w.cover_url,
            "play_count": w.play_count,
            "like_count": w.like_count,
            "status": w.status,
            "created_at": str(w.created_at) if w.created_at else None,
            "updated_at": str(w.updated_at) if w.updated_at else None,
        } for w in works]

        return {"code": 0, "data": data}
    except Exception:
        # Keep the deliberate empty-list fallback, but don't hide the cause.
        logging.getLogger(__name__).exception("Failed to load works for user %s", user_id)
        return {"code": 0, "data": []}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/drafts")
async def get_drafts(user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """Draft box placeholder; always returns an empty list.

    The story_drafts table may not exist yet, so neither ``user_id`` nor
    ``db`` is used for now.
    """
    no_drafts: list = []
    return {"code": 0, "data": no_drafts}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/recent-played")
async def get_recent_played(
    user_id: int = Query(..., alias="userId"),
    limit: int = Query(10, ge=1, le=50),
    db: AsyncSession = Depends(get_db),
):
    """Return the stories a user touched most recently, with progress state.

    At most ``limit`` (1-50) entries, ordered by the progress row's
    ``updated_at``, newest first. ``progress`` is a display label derived
    from ``is_completed``.
    """
    stmt = (
        select(Story, UserProgress)
        .join(UserProgress, Story.id == UserProgress.story_id)
        .where(UserProgress.user_id == user_id)
        .order_by(UserProgress.updated_at.desc())
        .limit(limit)
    )

    items = []
    for story, prog in (await db.execute(stmt)).all():
        items.append({
            "id": story.id,
            "title": story.title,
            "category": story.category,
            "description": story.description,
            "cover_url": story.cover_url,
            "current_node_key": prog.current_node_key,
            "is_completed": prog.is_completed,
            "progress": "已完成" if prog.is_completed else "进行中",
        })

    return {"code": 0, "data": items}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/ai-history")
async def get_ai_history(user_id: int = Query(..., alias="userId"), limit: int = Query(20), db: AsyncSession = Depends(get_db)):
    """AI creation history placeholder; returns an empty list for any arguments."""
    history: list = []
    return {"code": 0, "data": history}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/ai-quota")
async def get_ai_quota(user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """Return the user's AI generation quota.

    Currently these are fixed default values; nothing is read from the
    database and ``user_id`` is not used.
    """
    default_quota = {"daily": 3, "used": 0, "purchased": 0, "gift": 0}
    return {"code": 0, "data": default_quota}
|