Files
ai_game/server/app/routers/user.py

574 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
用户相关API路由
"""
from fastapi import APIRouter, Depends, Query, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, func, text, delete
from typing import Optional
from pydantic import BaseModel
from app.database import get_db
from app.models.user import User, UserProgress, UserEnding, PlayRecord
from app.models.story import Story
# Router holding the user endpoints defined in this module; any URL prefix is
# presumably applied where the application includes it (not visible here).
router = APIRouter()
# ========== 请求/响应模型 ==========
class LoginRequest(BaseModel):
    """Request body for ``POST /login``."""

    # Client login code; ``login`` currently uses it directly as the openid.
    code: str
    # Optional profile data; ``login`` reads the "nickname", "avatarUrl" and
    # "gender" keys from it when it creates a new user.
    userInfo: Optional[dict] = None
class ProfileRequest(BaseModel):
    """Request body for ``POST /profile``; the target user comes from the ``userId`` query parameter."""

    nickname: str
    avatarUrl: str  # written to User.avatar_url
    gender: int = 0  # written to User.gender as-is
class ProgressRequest(BaseModel):
    """Request body for ``POST /progress``."""

    userId: int
    storyId: int
    currentNodeKey: str  # written to UserProgress.current_node_key
    isCompleted: bool = False
    # Ending name; an ending is only recorded when this is non-empty and
    # isCompleted is true.
    endingReached: str = ""
class LikeRequest(BaseModel):
    """Request body for ``POST /like``."""

    userId: int
    storyId: int
    # Desired like state; the endpoint stores this value, it does not flip it.
    isLiked: bool
class CollectRequest(BaseModel):
    """Request body for ``POST /collect``."""

    userId: int
    storyId: int
    # Desired collected state; the endpoint stores this value, it does not flip it.
    isCollected: bool
class PlayRecordRequest(BaseModel):
    """Request body for ``POST /play-record``."""

    userId: int
    storyId: int
    endingName: str
    endingType: str = ""
    # Stored as-is in PlayRecord.path_history; existing records whose JSON form
    # of this list is identical are treated as the same path and replaced.
    pathHistory: list
# ========== API接口 ==========
@router.post("/login")
async def login(request: LoginRequest, db: AsyncSession = Depends(get_db)):
    """WeChat login: return the user for the login code, creating it on first use.

    NOTE: a real deployment must exchange ``code`` for an openid through the
    WeChat API. This implementation uses the code itself as the openid, so
    each distinct code maps to its own user record.

    Returns ``{"code": 0, "data": {...}}`` with the user id, openid, profile
    fields and the play statistics stored on the user.
    """
    # Simplification: the login code doubles as the openid (see docstring).
    openid = request.code

    # Look the user up by openid; create it on first login.
    result = await db.execute(select(User).where(User.openid == openid))
    user = result.scalar_one_or_none()
    if not user:
        user_info = request.userInfo or {}
        # ``or`` instead of a ``.get`` default: an explicit JSON null must also
        # fall back to the default rather than being written as NULL.
        user = User(
            openid=openid,
            nickname=user_info.get("nickname") or "",
            avatar_url=user_info.get("avatarUrl") or "",
            gender=user_info.get("gender") or 0,
        )
        db.add(user)
        await db.commit()
        # Reload server-generated values (id, defaults) for the response.
        await db.refresh(user)

    return {
        "code": 0,
        "data": {
            "userId": user.id,
            "openid": user.openid,
            "nickname": user.nickname,
            "avatarUrl": user.avatar_url,
            "gender": user.gender,
            "total_play_count": user.total_play_count,
            "total_endings": user.total_endings
        }
    }
@router.post("/profile")
async def update_profile(
    request: ProfileRequest,
    user_id: int = Query(..., alias="userId"),
    db: AsyncSession = Depends(get_db),
):
    """Overwrite the nickname, avatar URL and gender of the user given by ``userId``.

    The UPDATE is issued unconditionally, and the same success payload is
    returned whether or not a user row matched.
    """
    changes = dict(
        nickname=request.nickname,
        avatar_url=request.avatarUrl,
        gender=request.gender,
    )
    statement = update(User).where(User.id == user_id).values(**changes)
    await db.execute(statement)
    await db.commit()
    return {"code": 0, "message": "更新成功"}
@router.get("/progress")
async def get_progress(
    user_id: int = Query(..., alias="userId"),
    story_id: Optional[str] = Query(None, alias="storyId"),
    db: AsyncSession = Depends(get_db)
):
    """Return the user's story progress, most recently updated first.

    ``storyId`` is taken as a string because clients may send the literal
    ``"null"``. That value, a missing or empty value, or any non-integer text
    means "all stories". With a (non-zero) story filter, ``data`` is a single
    progress dict or None; otherwise it is a list of dicts.
    """
    # Parse storyId leniently; anything unusable leaves the filter off.
    wanted_story = None
    if story_id and story_id != "null":
        try:
            wanted_story = int(story_id)
        except ValueError:
            wanted_story = None

    # Truthiness decides filtering, so a story id of 0 also means "all stories".
    single = bool(wanted_story)

    stmt = (
        select(UserProgress, Story.title.label("story_title"), Story.cover_url)
        .join(Story, UserProgress.story_id == Story.id)
        .where(UserProgress.user_id == user_id)
    )
    if single:
        stmt = stmt.where(UserProgress.story_id == wanted_story)
    stmt = stmt.order_by(UserProgress.updated_at.desc())

    records = []
    for row in (await db.execute(stmt)).all():
        up = row.UserProgress
        records.append({
            "id": up.id,
            "user_id": up.user_id,
            "story_id": up.story_id,
            "story_title": row.story_title,
            "cover_url": row.cover_url,
            "current_node_key": up.current_node_key,
            "is_completed": up.is_completed,
            "ending_reached": up.ending_reached,
            "is_liked": up.is_liked,
            "is_collected": up.is_collected,
            "play_count": up.play_count
        })

    if not single:
        return {"code": 0, "data": records}
    return {"code": 0, "data": records[0] if records else None}
@router.post("/progress")
async def save_progress(request: ProgressRequest, db: AsyncSession = Depends(get_db)):
    """Create or update a user's progress on a story and record reached endings.

    * Existing progress row: ``current_node_key``, ``is_completed`` and
      ``ending_reached`` are overwritten and ``play_count`` is incremented.
    * No row yet: a new ``UserProgress`` is inserted.
    * When ``isCompleted`` is true and ``endingReached`` is non-empty:
      - the ending is added to ``UserEnding`` unless this user already
        unlocked it for this story;
      - the user's ``total_play_count`` is incremented;
      - ``total_endings`` is set to the user's ``UserEnding`` row count.

    All changes are committed together at the end.
    """
    user_id = request.userId
    story_id = request.storyId

    # ``.first()`` rather than ``scalar_one_or_none()``: the check-then-insert
    # pattern can leave duplicate (user_id, story_id) rows, and
    # scalar_one_or_none() would then raise MultipleResultsFound on every
    # later save for this user/story.
    result = await db.execute(
        select(UserProgress).where(
            UserProgress.user_id == user_id,
            UserProgress.story_id == story_id,
        )
    )
    progress = result.scalars().first()

    if progress:
        await db.execute(
            update(UserProgress)
            .where(UserProgress.id == progress.id)
            .values(
                current_node_key=request.currentNodeKey,
                is_completed=request.isCompleted,
                ending_reached=request.endingReached,
                play_count=UserProgress.play_count + 1,
            )
        )
    else:
        db.add(
            UserProgress(
                user_id=user_id,
                story_id=story_id,
                current_node_key=request.currentNodeKey,
                is_completed=request.isCompleted,
                ending_reached=request.endingReached,
            )
        )

    if request.isCompleted and request.endingReached:
        # Existence test only needs one id; limit(1) + first() also tolerates
        # duplicate UserEnding rows.
        existing_ending = await db.execute(
            select(UserEnding.id)
            .where(
                UserEnding.user_id == user_id,
                UserEnding.story_id == story_id,
                UserEnding.ending_name == request.endingReached,
            )
            .limit(1)
        )
        if existing_ending.first() is None:
            db.add(
                UserEnding(
                    user_id=user_id,
                    story_id=story_id,
                    ending_name=request.endingReached,
                )
            )

        # Flush explicitly so a just-added ending is counted even if the
        # session was configured without autoflush.
        await db.flush()
        count_result = await db.execute(
            select(func.count())
            .select_from(UserEnding)
            .where(UserEnding.user_id == user_id)
        )
        total_endings = count_result.scalar()

        # One UPDATE for both user statistics.
        await db.execute(
            update(User)
            .where(User.id == user_id)
            .values(
                total_play_count=User.total_play_count + 1,
                total_endings=total_endings,
            )
        )

    await db.commit()
    return {"code": 0, "message": "保存成功"}
@router.post("/like")
async def toggle_like(request: LikeRequest, db: AsyncSession = Depends(get_db)):
    """Set the user's like flag for a story to ``isLiked``.

    Updates the user's existing progress row for the story, or inserts a new
    one holding only the flag. ``Story.like_count`` is not touched here.
    """
    user_id = request.userId
    story_id = request.storyId

    # ``.first()``: tolerate duplicate (user_id, story_id) rows instead of
    # raising MultipleResultsFound on every call once one exists.
    result = await db.execute(
        select(UserProgress).where(
            UserProgress.user_id == user_id,
            UserProgress.story_id == story_id,
        )
    )
    progress = result.scalars().first()

    if progress:
        await db.execute(
            update(UserProgress)
            .where(UserProgress.id == progress.id)
            .values(is_liked=request.isLiked)
        )
    else:
        db.add(
            UserProgress(
                user_id=user_id,
                story_id=story_id,
                is_liked=request.isLiked,
            )
        )

    await db.commit()
    return {"code": 0, "message": "点赞成功" if request.isLiked else "取消点赞成功"}
@router.post("/collect")
async def toggle_collect(request: CollectRequest, db: AsyncSession = Depends(get_db)):
    """Set the user's collected flag for a story to ``isCollected``.

    Updates the user's existing progress row for the story, or inserts a new
    one holding only the flag.
    """
    user_id = request.userId
    story_id = request.storyId

    # ``.first()``: tolerate duplicate (user_id, story_id) rows instead of
    # raising MultipleResultsFound on every call once one exists.
    result = await db.execute(
        select(UserProgress).where(
            UserProgress.user_id == user_id,
            UserProgress.story_id == story_id,
        )
    )
    progress = result.scalars().first()

    if progress:
        await db.execute(
            update(UserProgress)
            .where(UserProgress.id == progress.id)
            .values(is_collected=request.isCollected)
        )
    else:
        db.add(
            UserProgress(
                user_id=user_id,
                story_id=story_id,
                is_collected=request.isCollected,
            )
        )

    await db.commit()
    return {"code": 0, "message": "收藏成功" if request.isCollected else "取消收藏成功"}
@router.get("/collections")
async def get_collections(user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """Return the stories the user has collected, most recently updated progress first."""
    collected_stmt = (
        select(Story)
        .join(UserProgress, Story.id == UserProgress.story_id)
        .where(
            UserProgress.user_id == user_id,
            UserProgress.is_collected == True,  # noqa: E712 -- builds a SQL comparison
        )
        .order_by(UserProgress.updated_at.desc())
    )
    stories = (await db.execute(collected_stmt)).scalars().all()

    def summarize(story):
        # Public card fields of a story for the collection list.
        return {
            "id": story.id,
            "title": story.title,
            "cover_url": story.cover_url,
            "description": story.description,
            "category": story.category,
            "play_count": story.play_count,
            "like_count": story.like_count
        }

    return {"code": 0, "data": [summarize(s) for s in stories]}
@router.get("/endings")
async def get_unlocked_endings(
    user_id: int = Query(..., alias="userId"),
    story_id: Optional[str] = Query(None, alias="storyId"),
    db: AsyncSession = Depends(get_db)
):
    """Return the endings the user has unlocked, optionally limited to one story.

    ``storyId`` is taken as a string because clients may send the literal
    ``"null"``. That value, a missing or empty value, non-integer text, or 0
    means "all stories". No ORDER BY is applied.
    """
    # Parse storyId leniently; anything unusable leaves the filter off.
    parsed_story = None
    if story_id and story_id != "null":
        try:
            parsed_story = int(story_id)
        except ValueError:
            parsed_story = None

    conditions = [UserEnding.user_id == user_id]
    if parsed_story:
        conditions.append(UserEnding.story_id == parsed_story)
    endings = (await db.execute(select(UserEnding).where(*conditions))).scalars().all()

    data = []
    for ending in endings:
        unlocked = ending.unlocked_at
        data.append({
            "id": ending.id,
            "story_id": ending.story_id,
            "ending_name": ending.ending_name,
            "ending_score": ending.ending_score,
            "unlocked_at": None if not unlocked else str(unlocked)
        })
    return {"code": 0, "data": data}
@router.get("/my-works")
async def get_my_works(user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """List the stories whose ``author_id`` is the user, newest first.

    Best-effort endpoint: on any failure it still answers ``code`` 0 with an
    empty list (client contract unchanged), but the exception is logged so
    database or schema problems are visible instead of silently hidden.
    """
    import logging

    try:
        result = await db.execute(
            select(Story).where(Story.author_id == user_id).order_by(Story.created_at.desc())
        )
        works = result.scalars().all()
        data = [{
            "id": w.id,
            "title": w.title,
            "description": w.description,
            "category": w.category,
            "cover_url": w.cover_url,
            "play_count": w.play_count,
            "like_count": w.like_count,
            "status": w.status,
            "created_at": str(w.created_at) if w.created_at else None,
            "updated_at": str(w.updated_at) if w.updated_at else None
        } for w in works]
    except Exception:
        # Keep the best-effort response, but record why it is empty.
        logging.getLogger(__name__).exception(
            "Failed to load works for user_id=%s; returning an empty list", user_id
        )
        return {"code": 0, "data": []}
    return {"code": 0, "data": data}
@router.get("/drafts")
async def get_drafts(user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """Drafts box placeholder: always returns an empty list.

    The ``story_drafts`` table may not exist, so nothing is queried yet;
    ``user_id`` and ``db`` are accepted but unused.
    """
    no_drafts: list = []
    return {"code": 0, "data": no_drafts}
@router.get("/recent-played")
async def get_recent_played(
    user_id: int = Query(..., alias="userId"),
    limit: int = Query(10, ge=1, le=50),
    db: AsyncSession = Depends(get_db)
):
    """Return up to ``limit`` (1-50, default 10) stories the user has progress on.

    Ordered by the progress row's ``updated_at``, newest first.
    """
    stmt = (
        select(Story, UserProgress)
        .join(UserProgress, Story.id == UserProgress.story_id)
        .where(UserProgress.user_id == user_id)
        .order_by(UserProgress.updated_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)

    played = []
    for story, progress in result.all():
        finished = progress.is_completed
        played.append({
            "id": story.id,
            "title": story.title,
            "category": story.category,
            "description": story.description,
            "cover_url": story.cover_url,
            "current_node_key": progress.current_node_key,
            "is_completed": finished,
            "progress": "已完成" if finished else "进行中"
        })
    return {"code": 0, "data": played}
@router.get("/ai-history")
async def get_ai_history(user_id: int = Query(..., alias="userId"), limit: int = Query(20), db: AsyncSession = Depends(get_db)):
    """AI creation history placeholder: always returns an empty list.

    ``user_id``, ``limit`` and ``db`` are accepted for a future implementation
    but are not used yet.
    """
    return dict(code=0, data=[])
@router.get("/ai-quota")
async def get_ai_quota(user_id: int = Query(..., alias="userId"), db: AsyncSession = Depends(get_db)):
    """Return the user's AI quota.

    Not backed by storage yet: every caller receives the same fixed payload
    (daily=3, used=0, purchased=0, gift=0); ``user_id`` and ``db`` are unused.
    """
    quota = dict(daily=3, used=0, purchased=0, gift=0)
    return {"code": 0, "data": quota}
# ========== 游玩记录 API ==========
@router.post("/play-record")
async def save_play_record(request: PlayRecordRequest, db: AsyncSession = Depends(get_db)):
    """Store a finished play-through, keeping only the newest record per path.

    Every existing record of this user/story whose ``path_history`` serialises
    to the same canonical JSON as ``pathHistory`` is deleted. The new record
    is then inserted in the same commit. Responds with the new record id.
    """
    import json

    def canonical(path):
        # sort_keys makes key order inside path entries irrelevant to equality.
        return json.dumps(path, sort_keys=True, ensure_ascii=False)

    target = canonical(request.pathHistory)

    lookup = await db.execute(
        select(PlayRecord).where(
            PlayRecord.user_id == request.userId,
            PlayRecord.story_id == request.storyId,
        )
    )
    duplicates = [old for old in lookup.scalars().all() if canonical(old.path_history) == target]
    for old in duplicates:
        await db.delete(old)

    record = PlayRecord(
        user_id=request.userId,
        story_id=request.storyId,
        ending_name=request.endingName,
        ending_type=request.endingType,
        path_history=request.pathHistory
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)

    payload = {"recordId": record.id, "message": "记录保存成功"}
    return {"code": 0, "data": payload}
@router.get("/play-records")
async def get_play_records(
    user_id: int = Query(..., alias="userId"),
    story_id: Optional[str] = Query(None, alias="storyId"),
    db: AsyncSession = Depends(get_db)
):
    """List the user's play records.

    * With a story filter: every record of that story, newest first, as
      ``id`` / ``endingName`` / ``endingType`` / ``createdAt``.
    * Without one: one summary per story (title, cover, latest ending and
      time, number of records), ordered by each story's latest record.

    ``storyId`` is parsed leniently. A missing or empty value, the literal
    ``"null"`` (which clients may send), non-integer text, or 0 means "no
    story filter" instead of triggering a 422 validation error. Integer ids
    behave exactly as before.
    """
    story_filter = None
    if story_id and story_id != "null":
        try:
            story_filter = int(story_id)
        except ValueError:
            story_filter = None

    time_format = "%Y-%m-%d %H:%M"

    if story_filter:
        result = await db.execute(
            select(PlayRecord)
            .where(PlayRecord.user_id == user_id, PlayRecord.story_id == story_filter)
            .order_by(PlayRecord.created_at.desc())
        )
        data = [{
            "id": r.id,
            "endingName": r.ending_name,
            "endingType": r.ending_type,
            "createdAt": r.created_at.strftime(time_format) if r.created_at else ""
        } for r in result.scalars().all()]
    else:
        result = await db.execute(
            select(PlayRecord, Story.title, Story.cover_url)
            .join(Story, PlayRecord.story_id == Story.id)
            .where(PlayRecord.user_id == user_id)
            .order_by(PlayRecord.created_at.desc())
        )
        # Rows arrive newest first, so the first row seen for a story is its
        # latest record; later rows only bump the count.
        story_map = {}
        for row in result.all():
            record = row.PlayRecord
            summary = story_map.get(record.story_id)
            if summary is None:
                summary = story_map[record.story_id] = {
                    "storyId": record.story_id,
                    "storyTitle": row.title,
                    "coverUrl": row.cover_url,
                    "latestEnding": record.ending_name,
                    "latestTime": record.created_at.strftime(time_format) if record.created_at else "",
                    "recordCount": 0
                }
            summary["recordCount"] += 1
        data = list(story_map.values())
    return {"code": 0, "data": data}
@router.get("/play-records/{record_id}")
async def get_play_record_detail(
    record_id: int,
    db: AsyncSession = Depends(get_db),
    user_id: Optional[int] = Query(None, alias="userId"),
):
    """Return one play record, including its full ``pathHistory``.

    ``userId`` is optional for backward compatibility. When given, only that
    user's records are visible. Without it, any record id can be read by any
    caller, so clients should send it. Unknown ids, and ids owned by another
    user when ``userId`` is given, get a body with ``code`` 404 (the HTTP
    status stays 200).
    """
    stmt = (
        select(PlayRecord, Story.title)
        .join(Story, PlayRecord.story_id == Story.id)
        .where(PlayRecord.id == record_id)
    )
    if user_id is not None:
        # Ownership filter: a foreign record is indistinguishable from a missing one.
        stmt = stmt.where(PlayRecord.user_id == user_id)

    row = (await db.execute(stmt)).first()
    if not row:
        return {"code": 404, "message": "记录不存在"}

    record = row.PlayRecord
    return {
        "code": 0,
        "data": {
            "id": record.id,
            "storyId": record.story_id,
            "storyTitle": row.title,
            "endingName": record.ending_name,
            "endingType": record.ending_type,
            "pathHistory": record.path_history,
            "createdAt": record.created_at.strftime("%Y-%m-%d %H:%M") if record.created_at else ""
        }
    }
@router.delete("/play-records/{record_id}")
async def delete_play_record(
    record_id: int,
    db: AsyncSession = Depends(get_db),
    user_id: Optional[int] = Query(None, alias="userId"),
):
    """Delete one play record.

    ``userId`` is optional for backward compatibility. When given, only a
    record owned by that user can be deleted. Without it, any caller can
    delete any record id, so clients should send it. A missing record, or one
    owned by another user when ``userId`` is given, gets a body with ``code``
    404 (the HTTP status stays 200).
    """
    query = select(PlayRecord).where(PlayRecord.id == record_id)
    if user_id is not None:
        # Ownership filter: a foreign record is indistinguishable from a missing one.
        query = query.where(PlayRecord.user_id == user_id)

    record = (await db.execute(query)).scalar_one_or_none()
    if not record:
        return {"code": 404, "message": "记录不存在"}

    await db.delete(record)
    await db.commit()
    return {"code": 0, "message": "删除成功"}