"""Story-related API routes."""
import json
import random
import re
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import distinct, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models.story import Story, StoryCharacter, StoryChoice, StoryNode

router = APIRouter()

# Maximum number of characters of node content shown by GET /{story_id}/images.
_CONTENT_PREVIEW_LEN = 50

# A flat (non-nested) JSON object that mentions "ending_name" before "content";
# used to pull the structured part out of a free-form AI reply.
_REWRITE_JSON_RE = re.compile(r'\{[^{}]*"ending_name"[^{}]*"content"[^{}]*\}', re.DOTALL)


# ========== Request / response models ==========

class LikeRequest(BaseModel):
    """Body of POST /{story_id}/like."""

    like: bool  # True = like (+1), False = unlike (-1)


class RewriteRequest(BaseModel):
    """Body of POST /{story_id}/rewrite (AI rewrite of an ending)."""

    ending_name: str
    ending_content: str
    prompt: str


class PathHistoryItem(BaseModel):
    """One step of the path the player took before asking for a branch rewrite."""

    nodeKey: str
    content: str = ""
    choice: str = ""


class RewriteBranchRequest(BaseModel):
    """Body of POST /{story_id}/rewrite-branch (AI rewrite of a mid-story chapter)."""

    userId: int  # accepted from the client; not read by the handler in this module
    currentNodeKey: str  # accepted from the client; not read by the handler in this module
    pathHistory: List[PathHistoryItem]
    currentContent: str
    prompt: str


class NodeImageUpdate(BaseModel):
    """Image URLs for one node; an empty string leaves that column unchanged."""

    nodeKey: str
    backgroundImage: str = ""
    characterImage: str = ""


class CharacterImageUpdate(BaseModel):
    """Avatar URL for one character; an empty string leaves it unchanged."""

    characterId: int
    avatarUrl: str = ""


class ImageConfigRequest(BaseModel):
    """Body of PUT /{story_id}/images; empty values are skipped."""

    coverUrl: str = ""
    nodes: List[NodeImageUpdate] = []
    characters: List[CharacterImageUpdate] = []


class GenerateImageRequest(BaseModel):
    """Body of POST /generate-image."""

    prompt: str
    style: str = "anime"  # anime/realistic/illustration
    category: str = "character"  # character/background/cover
    storyId: Optional[int] = None
    targetField: Optional[str] = None  # coverUrl/backgroundImage/characterImage/avatarUrl
    targetKey: Optional[str] = None  # nodeKey or characterId


# ========== Helpers ==========

def _content_preview(text: Optional[str], limit: int = _CONTENT_PREVIEW_LEN) -> str:
    """Return *text* cut to *limit* characters plus "..." when longer; None -> ""."""
    text = text or ""
    return text[:limit] + "..." if len(text) > limit else text


def _character_context(characters) -> List[dict]:
    """Serialize StoryCharacter rows into the dicts passed to the AI rewrite service."""
    return [
        {
            "name": c.name,
            "role_type": c.role_type,
            "gender": c.gender,
            "age_range": c.age_range,
            "appearance": c.appearance,
            "personality": c.personality,
        }
        for c in characters
    ]


def _parse_rewrite_result(content, default_name: str):
    """Extract ``(ending_name, content)`` from an AI reply that may contain JSON.

    First looks for a flat JSON object holding "ending_name" and "content";
    if none is found, tries to parse the whole reply. When nothing parses to
    a JSON object, the raw reply and *default_name* are returned unchanged.
    """
    if not isinstance(content, str):
        return default_name, content
    match = _REWRITE_JSON_RE.search(content)
    candidate = match.group() if match else content
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        # Not JSON: keep the raw reply.
        return default_name, content
    if not isinstance(parsed, dict):
        return default_name, content
    return parsed.get("ending_name", default_name), parsed.get("content", content)


# ========== API endpoints ==========

@router.get("")
async def get_stories(
    category: Optional[str] = Query(None),
    featured: bool = Query(False),
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db)
):
    """List stories with status == 1, featured first, then by play count.

    ``category`` filters by exact category; ``featured`` keeps only featured
    stories. Paginated with ``limit`` / ``offset``.
    """
    query = select(Story).where(Story.status == 1)
    if category:
        query = query.where(Story.category == category)
    if featured:
        # SQL expression, not a Python identity test.
        query = query.where(Story.is_featured == True)  # noqa: E712
    query = query.order_by(Story.is_featured.desc(), Story.play_count.desc())
    query = query.limit(limit).offset(offset)
    result = await db.execute(query)
    stories = result.scalars().all()
    data = [{
        "id": s.id,
        "title": s.title,
        "cover_url": s.cover_url,
        "description": s.description,
        "category": s.category,
        "play_count": s.play_count,
        "like_count": s.like_count,
        "is_featured": s.is_featured
    } for s in stories]
    return {"code": 0, "data": data}


@router.get("/hot")
async def get_hot_stories(
    limit: int = Query(10, ge=1, le=50),
    db: AsyncSession = Depends(get_db)
):
    """Return the ``limit`` most-played stories with status == 1."""
    query = select(Story).where(Story.status == 1).order_by(Story.play_count.desc()).limit(limit)
    result = await db.execute(query)
    stories = result.scalars().all()
    data = [{
        "id": s.id,
        "title": s.title,
        "cover_url": s.cover_url,
        "description": s.description,
        "category": s.category,
        "play_count": s.play_count,
        "like_count": s.like_count
    } for s in stories]
    return {"code": 0, "data": data}


@router.get("/categories")
async def get_categories(db: AsyncSession = Depends(get_db)):
    """Return the distinct categories of stories with status == 1."""
    query = select(distinct(Story.category)).where(Story.status == 1)
    result = await db.execute(query)
    categories = [row[0] for row in result.all()]
    return {"code": 0, "data": categories}


@router.get("/test-image-gen")
async def test_image_generation():
    """Diagnose the image-generation service: key config, network, one sample image.

    NOTE(review): the response includes the first 8 characters of the API key;
    consider restricting this diagnostic endpoint.
    """
    from app.config import get_settings
    from app.services.image_gen import get_image_gen_service
    import httpx

    settings = get_settings()
    results = {
        "api_key_configured": bool(settings.gemini_api_key),
        "api_key_preview": settings.gemini_api_key[:8] + "..." if settings.gemini_api_key else "未配置",
        "base_url": "https://work.poloapi.com/v1beta",
        "network_test": None,
        "generate_test": None
    }

    # Network reachability check.
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get("https://work.poloapi.com")
            results["network_test"] = {
                "success": response.status_code < 500,
                "status_code": response.status_code,
                "message": "网络连接正常" if response.status_code < 500 else "服务端错误"
            }
    except Exception as e:
        results["network_test"] = {
            "success": False,
            "error": str(e),
            "message": "网络连接失败"
        }

    # Real generation with a trivial prompt.
    if settings.gemini_api_key:
        try:
            service = get_image_gen_service()
            gen_result = await service.generate_image(
                prompt="a simple red circle on white background",
                image_type="avatar",
                style="illustration"
            )
            results["generate_test"] = {
                "success": gen_result.get("success", False),
                "error": gen_result.get("error") if not gen_result.get("success") else None,
                "has_image_data": bool(gen_result.get("image_data"))
            }
        except Exception as e:
            results["generate_test"] = {
                "success": False,
                "error": str(e)
            }
    else:
        results["generate_test"] = {
            "success": False,
            "error": "API Key 未配置"
        }

    return {
        "code": 0,
        "message": "测试完成",
        "data": results
    }


@router.get("/test-deepseek")
async def test_deepseek():
    """Diagnose the DeepSeek AI service: key config, network, one chat completion.

    NOTE(review): the response includes a partial API key preview; consider
    restricting this diagnostic endpoint.
    """
    from app.config import get_settings
    import httpx

    settings = get_settings()
    api_key = settings.deepseek_api_key
    results = {
        "ai_service_enabled": settings.ai_service_enabled,
        "provider": settings.ai_provider,
        "api_key_configured": bool(api_key),
        "api_key_preview": api_key[:8] + "..." + api_key[-4:] if api_key and len(api_key) > 12 else "未配置或太短",
        "base_url": settings.deepseek_base_url,
        "model": settings.deepseek_model,
        "network_test": None,
        "api_test": None
    }

    # Network reachability check (5xx reported as a server error, consistent
    # with /test-image-gen).
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get("https://api.deepseek.com")
            results["network_test"] = {
                "success": response.status_code < 500,
                "status_code": response.status_code,
                "message": "网络连接正常" if response.status_code < 500 else "服务端错误"
            }
    except Exception as e:
        results["network_test"] = {
            "success": False,
            "error": str(e),
            "message": "网络连接失败"
        }

    # Minimal chat-completion call.
    if api_key:
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{settings.deepseek_base_url}/chat/completions",
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {api_key}"
                    },
                    json={
                        "model": settings.deepseek_model,
                        "messages": [{"role": "user", "content": "说'测试成功'两个字"}],
                        "max_tokens": 10
                    }
                )
                if response.status_code == 200:
                    data = response.json()
                    # Tolerate an empty "choices" list or a null content so a
                    # successful HTTP call is not reported as an exception.
                    first_choice = (data.get("choices") or [{}])[0]
                    content = (first_choice.get("message") or {}).get("content") or ""
                    results["api_test"] = {
                        "success": True,
                        "status_code": 200,
                        "response": content[:50]
                    }
                else:
                    results["api_test"] = {
                        "success": False,
                        "status_code": response.status_code,
                        "error": response.text[:200]
                    }
        except Exception as e:
            results["api_test"] = {
                "success": False,
                "error": str(e)
            }
    else:
        results["api_test"] = {
            "success": False,
            "error": "API Key 未配置"
        }

    return {
        "code": 0,
        "message": "测试完成",
        "data": results
    }


@router.get("/test-cloud-upload")
async def test_cloud_upload():
    """Check that an upload URL can be obtained from the WeChat cloud-storage API.

    Only meaningful when TCB_ENV or CBR_ENV_ID is set in the environment;
    otherwise returns code 1 immediately.
    """
    import os
    import httpx

    tcb_env = os.environ.get("TCB_ENV") or os.environ.get("CBR_ENV_ID")
    results = {
        "env_id": tcb_env,
        "is_cloud": bool(tcb_env)
    }

    if not tcb_env:
        return {
            "code": 1,
            "message": "非云托管环境,无法测试云存储上传",
            "data": results
        }

    try:
        # Request an upload URL for a fixed test path.
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                "http://api.weixin.qq.com/tcb/uploadfile",
                json={
                    "env": tcb_env,
                    "path": "test/cloud_upload_test.txt"
                },
                headers={"Content-Type": "application/json"}
            )
            results["status_code"] = resp.status_code
            results["response"] = resp.text[:500] if resp.text else ""

            if resp.status_code != 200:
                return {
                    "code": 1,
                    "message": f"请求失败: HTTP {resp.status_code}",
                    "data": results
                }

            data = resp.json()
            if data.get("errcode", 0) == 0:
                results["upload_url"] = data.get("url", "")[:100]
                results["file_id"] = data.get("file_id", "")
                return {
                    "code": 0,
                    "message": "云存储上传链接获取成功",
                    "data": results
                }

            results["errcode"] = data.get("errcode")
            results["errmsg"] = data.get("errmsg")
            return {
                "code": 1,
                "message": f"获取上传链接失败: {data.get('errmsg')}",
                "data": results
            }
    except Exception as e:
        results["error"] = str(e)
        return {
            "code": 1,
            "message": f"测试异常: {str(e)}",
            "data": results
        }


@router.get("/{story_id}")
async def get_story_detail(story_id: int, db: AsyncSession = Depends(get_db)):
    """Return a story (status == 1) with its nodes keyed by node_key.

    Each node carries its choices, ordered by ``StoryChoice.sort_order``.

    Raises:
        HTTPException(404): no story with status == 1 has this id.
    """
    result = await db.execute(select(Story).where(Story.id == story_id, Story.status == 1))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    nodes_result = await db.execute(
        select(StoryNode).where(StoryNode.story_id == story_id).order_by(StoryNode.sort_order)
    )
    nodes = nodes_result.scalars().all()

    choices_result = await db.execute(
        select(StoryChoice).where(StoryChoice.story_id == story_id).order_by(StoryChoice.sort_order)
    )
    choices = choices_result.scalars().all()

    nodes_map = {}
    for node in nodes:
        nodes_map[node.node_key] = {
            "id": node.id,
            "node_key": node.node_key,
            "content": node.content,
            "speaker": node.speaker,
            "background_image": node.background_image,
            "character_image": node.character_image,
            "bgm": node.bgm,
            "is_ending": node.is_ending,
            "ending_name": node.ending_name,
            "ending_score": node.ending_score,
            "ending_type": node.ending_type,
            "choices": []
        }

    # Attach each choice to its owning node through a node-id -> node-key index
    # (one dict lookup per choice instead of scanning every node).
    key_by_node_id = {node.id: node.node_key for node in nodes}
    for choice in choices:
        if choice.node_id not in key_by_node_id:
            continue
        nodes_map[key_by_node_id[choice.node_id]]["choices"].append({
            "text": choice.text,
            "nextNodeKey": choice.next_node_key,
            "isLocked": choice.is_locked
        })

    data = {
        "id": story.id,
        "title": story.title,
        "cover_url": story.cover_url,
        "description": story.description,
        "category": story.category,
        "author_id": story.author_id,
        "play_count": story.play_count,
        "like_count": story.like_count,
        "is_featured": story.is_featured,
        "nodes": nodes_map
    }
    return {"code": 0, "data": data}


@router.post("/{story_id}/play")
async def record_play(story_id: int, db: AsyncSession = Depends(get_db)):
    """Increment the story's play counter atomically in SQL."""
    await db.execute(
        update(Story).where(Story.id == story_id).values(play_count=Story.play_count + 1)
    )
    await db.commit()
    return {"code": 0, "message": "记录成功"}


@router.post("/{story_id}/like")
async def toggle_like(story_id: int, request: LikeRequest, db: AsyncSession = Depends(get_db)):
    """Like (+1) or unlike (-1) a story; the counter never drops below zero."""
    delta = 1 if request.like else -1
    stmt = update(Story).where(Story.id == story_id)
    if delta < 0:
        # Guard against repeated unlikes driving like_count negative.
        stmt = stmt.where(Story.like_count > 0)
    await db.execute(stmt.values(like_count=Story.like_count + delta))
    await db.commit()
    return {"code": 0, "message": "点赞成功" if request.like else "取消点赞成功"}


@router.post("/{story_id}/rewrite")
async def ai_rewrite_ending(story_id: int, request: RewriteRequest, db: AsyncSession = Depends(get_db)):
    """Rewrite an ending with the AI service, falling back to a template.

    When the AI reply contains JSON with "ending_name"/"content", those values
    are used. Otherwise the raw reply is used under "<ending_name>(AI改写)".

    Raises:
        HTTPException(400): the prompt is empty or whitespace only.
        HTTPException(404): the story does not exist.
    """
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="请输入改写指令")

    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    char_result = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = _character_context(char_result.scalars().all())

    from app.services.ai import ai_service
    ai_result = await ai_service.rewrite_ending(
        story_title=story.title,
        story_category=story.category or "未知",
        ending_name=request.ending_name or "未知结局",
        ending_content=request.ending_content or "",
        user_prompt=request.prompt,
        characters=characters
    )

    if ai_result and ai_result.get("content"):
        ending_name, content = _parse_rewrite_result(
            ai_result["content"], f"{request.ending_name}(AI改写)"
        )
        return {
            "code": 0,
            "data": {
                "content": content,
                "speaker": "旁白",
                "is_ending": True,
                "ending_name": ending_name,
                "ending_type": "rewrite",
                "tokens_used": ai_result.get("tokens_used", 0)
            }
        }

    # Fallback when the AI service is unavailable: a canned template.
    templates = [
        f"根据你的愿望「{request.prompt}」,故事有了新的发展...\n\n",
        f"命运的齿轮开始转动,{request.prompt}...\n\n",
        f"在另一个平行世界里,{request.prompt}成为了现实...\n\n"
    ]
    template = random.choice(templates)
    new_content = (
        template +
        "原本的结局被改写,新的故事在这里展开。\n\n" +
        "【提示】AI服务暂时不可用,这是模板内容。"
    )

    return {
        "code": 0,
        "data": {
            "content": new_content,
            "speaker": "旁白",
            "is_ending": True,
            "ending_name": f"{request.ending_name}(改写版)",
            "ending_type": "rewrite"
        }
    }


@router.post("/{story_id}/rewrite-branch")
async def ai_rewrite_branch(
    story_id: int,
    request: RewriteBranchRequest,
    db: AsyncSession = Depends(get_db)
):
    """Generate a new mid-story branch with the AI service.

    No template fallback is used. When the AI returns no nodes, the response
    has ``nodes``/``entryNodeKey`` set to None and an ``error`` message.

    Raises:
        HTTPException(400): the prompt is empty or whitespace only.
        HTTPException(404): the story does not exist.
    """
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="请输入改写指令")

    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    char_result = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = _character_context(char_result.scalars().all())

    # The AI service takes plain dicts, not Pydantic models.
    path_history = [
        {"nodeKey": item.nodeKey, "content": item.content, "choice": item.choice}
        for item in request.pathHistory
    ]

    from app.services.ai import ai_service
    ai_result = await ai_service.rewrite_branch(
        story_title=story.title,
        story_category=story.category or "未知",
        path_history=path_history,
        current_content=request.currentContent,
        user_prompt=request.prompt,
        characters=characters
    )

    if ai_result and ai_result.get("nodes"):
        return {
            "code": 0,
            "data": {
                "nodes": ai_result["nodes"],
                "entryNodeKey": ai_result.get("entryNodeKey", "branch_1"),
                "tokensUsed": ai_result.get("tokens_used", 0)
            }
        }

    return {
        "code": 0,
        "data": {
            "nodes": None,
            "entryNodeKey": None,
            "tokensUsed": 0,
            "error": "AI服务暂时不可用"
        }
    }


@router.get("/{story_id}/images")
async def get_story_images(story_id: int, db: AsyncSession = Depends(get_db)):
    """Return a story's image configuration.

    Includes the cover, each node's images with a short content preview, and
    each character's avatar.

    Raises:
        HTTPException(404): the story does not exist.
    """
    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    nodes_result = await db.execute(
        select(StoryNode).where(StoryNode.story_id == story_id).order_by(StoryNode.sort_order)
    )
    nodes = nodes_result.scalars().all()

    chars_result = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = chars_result.scalars().all()

    return {
        "code": 0,
        "data": {
            "storyId": story_id,
            "title": story.title,
            "coverUrl": story.cover_url or "",
            "nodes": [
                {
                    "nodeKey": n.node_key,
                    # _content_preview tolerates a NULL content column.
                    "content": _content_preview(n.content),
                    "backgroundImage": n.background_image or "",
                    "characterImage": n.character_image or "",
                    "isEnding": n.is_ending,
                    "endingName": n.ending_name or ""
                }
                for n in nodes
            ],
            "characters": [
                {
                    "characterId": c.id,
                    "name": c.name,
                    "roleType": c.role_type,
                    "avatarUrl": c.avatar_url or "",
                    "avatarPrompt": c.avatar_prompt or ""
                }
                for c in characters
            ]
        }
    }


@router.put("/{story_id}/images")
async def update_story_images(
    story_id: int,
    request: ImageConfigRequest,
    db: AsyncSession = Depends(get_db)
):
    """Batch-update cover, node images and character avatars; empty values are skipped.

    Returns counts of the update statements issued per kind.

    Raises:
        HTTPException(404): the story does not exist.
    """
    result = await db.execute(select(Story).where(Story.id == story_id))
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="故事不存在")

    updated = {"cover": False, "nodes": 0, "characters": 0}

    if request.coverUrl:
        await db.execute(
            update(Story).where(Story.id == story_id).values(cover_url=request.coverUrl)
        )
        updated["cover"] = True

    for node_img in request.nodes:
        values = {}
        if node_img.backgroundImage:
            values["background_image"] = node_img.backgroundImage
        if node_img.characterImage:
            values["character_image"] = node_img.characterImage
        if values:
            await db.execute(
                update(StoryNode)
                .where(StoryNode.story_id == story_id, StoryNode.node_key == node_img.nodeKey)
                .values(**values)
            )
            updated["nodes"] += 1

    for char_img in request.characters:
        if char_img.avatarUrl:
            # Scoped to story_id so one story cannot overwrite another's characters.
            await db.execute(
                update(StoryCharacter)
                .where(StoryCharacter.id == char_img.characterId, StoryCharacter.story_id == story_id)
                .values(avatar_url=char_img.avatarUrl)
            )
            updated["characters"] += 1

    await db.commit()

    return {
        "code": 0,
        "message": "更新成功",
        "data": updated
    }


@router.post("/generate-image")
async def generate_story_image(
    request: GenerateImageRequest,
    db: AsyncSession = Depends(get_db)
):
    """Generate an image with AI and optionally store its URL on a story.

    When both ``storyId`` and ``targetField`` are set, the URL is written to:
      * ``coverUrl``        -> the story's cover
      * ``backgroundImage`` -> the node whose key is ``targetKey``
      * ``characterImage``  -> the node whose key is ``targetKey``
      * ``avatarUrl``       -> the character whose id is ``int(targetKey)``
    ``data.saved`` is True only when one of those updates was actually issued.

    Raises:
        HTTPException(400): ``targetField`` is ``avatarUrl`` and ``targetKey``
            is not an integer. This is checked before any image is generated.
    """
    from app.services.image_gen import get_image_gen_service

    wants_save = bool(request.storyId and request.targetField)

    # Validate the character id up front. Previously int() failed only after
    # the image had been generated, turning a bad request into an HTTP 500.
    character_id = None
    if wants_save and request.targetField == "avatarUrl" and request.targetKey:
        try:
            character_id = int(request.targetKey)
        except ValueError:
            raise HTTPException(status_code=400, detail="targetKey 必须是整数角色ID") from None

    result = await get_image_gen_service().generate_and_save(
        prompt=request.prompt,
        category=request.category,
        style=request.style
    )

    if not result.get("success"):
        return {
            "code": 1,
            "message": result.get("error", "生成失败"),
            "data": None
        }

    image_url = result["url"]

    stmt = None
    if wants_save:
        field, key = request.targetField, request.targetKey
        if field == "coverUrl":
            stmt = update(Story).where(Story.id == request.storyId).values(cover_url=image_url)
        elif field == "backgroundImage" and key:
            stmt = (
                update(StoryNode)
                .where(StoryNode.story_id == request.storyId, StoryNode.node_key == key)
                .values(background_image=image_url)
            )
        elif field == "characterImage" and key:
            stmt = (
                update(StoryNode)
                .where(StoryNode.story_id == request.storyId, StoryNode.node_key == key)
                .values(character_image=image_url)
            )
        elif field == "avatarUrl" and character_id is not None:
            stmt = (
                update(StoryCharacter)
                .where(StoryCharacter.story_id == request.storyId, StoryCharacter.id == character_id)
                .values(avatar_url=image_url)
            )

    if stmt is not None:
        await db.execute(stmt)
        await db.commit()

    return {
        "code": 0,
        "message": "生成成功",
        "data": {
            "url": image_url,
            "filename": result.get("filename"),
            "saved": stmt is not None
        }
    }


@router.post("/{story_id}/generate-all-images")
async def generate_all_story_images(
    story_id: int,
    style: str = "anime",
    db: AsyncSession = Depends(get_db)
):
    """Generate and store an avatar for every character of the story.

    Characters are processed one after another. Successes are written to
    ``avatar_url``; failures are collected and reported, not raised.
    """
    from app.services.image_gen import get_image_gen_service
    image_service = get_image_gen_service()

    result = await db.execute(
        select(StoryCharacter).where(StoryCharacter.story_id == story_id)
    )
    characters = result.scalars().all()

    if not characters:
        return {"code": 1, "message": "故事没有角色数据", "data": None}

    generated = []
    failed = []

    for char in characters:
        # Prefer the curated avatar_prompt. Otherwise describe the character
        # from whichever attributes are set, so unset ones don't become "None".
        prompt = char.avatar_prompt or ", ".join(
            str(part) for part in (char.name, char.gender, char.appearance) if part
        )

        gen_result = await image_service.generate_and_save(
            prompt=prompt,
            category="character",
            style=style
        )

        if gen_result.get("success"):
            await db.execute(
                update(StoryCharacter)
                .where(StoryCharacter.id == char.id)
                .values(avatar_url=gen_result["url"])
            )
            generated.append({"id": char.id, "name": char.name, "url": gen_result["url"]})
        else:
            failed.append({"id": char.id, "name": char.name, "error": gen_result.get("error")})

    await db.commit()

    return {
        "code": 0,
        "message": f"生成完成: {len(generated)}成功, {len(failed)}失败",
        "data": {
            "generated": generated,
            "failed": failed
        }
    }