fix: 修复云存储上传API路径及添加测试接口
This commit is contained in:
@@ -8,9 +8,12 @@ from sqlalchemy.sql import func
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
import base64
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.story import Story, StoryDraft, DraftStatus, StoryCharacter
|
||||
from app.config import get_settings
|
||||
|
||||
router = APIRouter(prefix="/drafts", tags=["草稿箱"])
|
||||
|
||||
@@ -36,6 +39,148 @@ async def get_story_characters(db: AsyncSession, story_id: int) -> List[dict]:
|
||||
]
|
||||
|
||||
|
||||
async def upload_to_cloud_storage(image_bytes: bytes, cloud_path: str) -> str:
|
||||
"""
|
||||
上传图片到微信云存储(云托管容器内调用)
|
||||
cloud_path: 云存储路径,如 stories/1/drafts/10/branch_1/background.jpg
|
||||
返回: 文件访问路径
|
||||
"""
|
||||
import httpx
|
||||
|
||||
env_id = os.environ.get('TCB_ENV') or os.environ.get('CBR_ENV_ID')
|
||||
if not env_id:
|
||||
# 尝试从配置获取
|
||||
settings = get_settings()
|
||||
env_id = getattr(settings, 'wx_cloud_env', None)
|
||||
|
||||
if not env_id:
|
||||
raise Exception("未检测到云环境ID")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
# 云托管内网调用云开发 API(不需要 access_token)
|
||||
# 参考: https://developers.weixin.qq.com/miniprogram/dev/wxcloudrun/src/development/storage/service/upload.html
|
||||
|
||||
# 1. 获取上传链接
|
||||
resp = await client.post(
|
||||
"http://api.weixin.qq.com/tcb/uploadfile",
|
||||
json={
|
||||
"env": env_id,
|
||||
"path": cloud_path
|
||||
},
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"获取上传链接失败: {resp.status_code} - {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
if data.get("errcode", 0) != 0:
|
||||
raise Exception(f"获取上传链接失败: {data.get('errmsg')}")
|
||||
|
||||
upload_url = data.get("url")
|
||||
authorization = data.get("authorization")
|
||||
token = data.get("token")
|
||||
cos_file_id = data.get("cos_file_id")
|
||||
file_id = data.get("file_id")
|
||||
|
||||
# 2. 上传文件到 COS
|
||||
form_data = {
|
||||
"key": cloud_path,
|
||||
"Signature": authorization,
|
||||
"x-cos-security-token": token,
|
||||
"x-cos-meta-fileid": cos_file_id,
|
||||
}
|
||||
|
||||
files = {"file": ("background.jpg", image_bytes, "image/jpeg")}
|
||||
|
||||
upload_resp = await client.post(upload_url, data=form_data, files=files)
|
||||
|
||||
if upload_resp.status_code not in [200, 204]:
|
||||
raise Exception(f"上传文件失败: {upload_resp.status_code} - {upload_resp.text[:200]}")
|
||||
|
||||
print(f" [CloudStorage] 文件上传成功: {file_id}")
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
print(f"[upload_to_cloud_storage] 上传失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def generate_draft_images(story_id: int, draft_id: int, ai_nodes: dict, story_category: str):
|
||||
"""
|
||||
为草稿的 AI 生成节点生成背景图
|
||||
本地环境:保存到文件系统 /uploads/stories/{story_id}/drafts/{draft_id}/{node_key}/background.jpg
|
||||
云端环境:上传到云存储
|
||||
"""
|
||||
from app.services.image_gen import ImageGenService
|
||||
|
||||
if not ai_nodes:
|
||||
return
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# 检测是否是云端环境(TCB_ENV 或 CBR_ENV_ID 是云托管容器自动注入的)
|
||||
is_cloud = os.environ.get('TCB_ENV') or os.environ.get('CBR_ENV_ID')
|
||||
|
||||
# 本地环境使用文件系统
|
||||
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
|
||||
draft_dir = os.path.join(base_dir, "stories", str(story_id), "drafts", str(draft_id))
|
||||
|
||||
service = ImageGenService()
|
||||
|
||||
for node_key, node_data in ai_nodes.items():
|
||||
if not isinstance(node_data, dict):
|
||||
continue
|
||||
|
||||
content = node_data.get('content', '')[:150]
|
||||
if not content:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 生成背景图 - 强调情绪表达
|
||||
bg_prompt = f"Background scene for {story_category} story. Scene: {content}. Wide shot, atmospheric, no characters, anime style. Strong emotional expression, dramatic mood, vivid colors reflecting the scene's emotion."
|
||||
result = await service.generate_image(bg_prompt, "background", "anime")
|
||||
|
||||
if result and result.get("success"):
|
||||
image_bytes = base64.b64decode(result["image_data"])
|
||||
# 路径格式和本地一致:uploads/stories/{story_id}/drafts/{draft_id}/{node_key}/background.jpg
|
||||
cloud_path = f"uploads/stories/{story_id}/drafts/{draft_id}/{node_key}/background.jpg"
|
||||
|
||||
# 云端环境:上传到云存储
|
||||
if is_cloud:
|
||||
try:
|
||||
file_id = await upload_to_cloud_storage(image_bytes, cloud_path)
|
||||
# 云存储返回的 file_id 格式: cloud://env-id.xxx/path
|
||||
# 前端通过 CDN 地址访问: https://7072-prod-xxx.tcb.qcloud.la/uploads/...
|
||||
node_data['background_url'] = f"/{cloud_path}"
|
||||
print(f" ✓ 云端草稿节点 {node_key} 背景图上传成功")
|
||||
except Exception as cloud_e:
|
||||
print(f" ✗ 云端上传失败: {cloud_e}")
|
||||
continue
|
||||
|
||||
# 本地环境:保存到文件系统
|
||||
node_dir = os.path.join(draft_dir, node_key)
|
||||
os.makedirs(node_dir, exist_ok=True)
|
||||
bg_path = os.path.join(node_dir, "background.jpg")
|
||||
|
||||
with open(bg_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 更新节点数据,添加图片路径
|
||||
node_data['background_url'] = f"/uploads/stories/{story_id}/drafts/{draft_id}/{node_key}/background.jpg"
|
||||
print(f" ✓ 草稿节点 {node_key} 背景图生成成功")
|
||||
else:
|
||||
print(f" ✗ 草稿节点 {node_key} 背景图生成失败: {result.get('error') if result else 'Unknown'}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ 草稿节点 {node_key} 图片生成异常: {e}")
|
||||
|
||||
# 避免请求过快
|
||||
import asyncio
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
class PathHistoryItem(BaseModel):
|
||||
@@ -121,11 +266,14 @@ async def process_ai_rewrite(draft_id: int):
|
||||
|
||||
# 获取故事角色
|
||||
characters = await get_story_characters(db, story.id)
|
||||
print(f"[process_ai_rewrite] 获取到角色数: {len(characters)}")
|
||||
|
||||
# 转换路径历史格式
|
||||
path_history = draft.path_history or []
|
||||
print(f"[process_ai_rewrite] 路径历史长度: {len(path_history)}")
|
||||
|
||||
# 调用AI服务
|
||||
print(f"[process_ai_rewrite] 开始调用 AI 服务...")
|
||||
ai_result = await ai_service.rewrite_branch(
|
||||
story_title=story.title,
|
||||
story_category=story.category or "未知",
|
||||
@@ -134,9 +282,21 @@ async def process_ai_rewrite(draft_id: int):
|
||||
user_prompt=draft.user_prompt,
|
||||
characters=characters
|
||||
)
|
||||
print(f"[process_ai_rewrite] AI 服务返回: {bool(ai_result)}")
|
||||
|
||||
if ai_result and ai_result.get("nodes"):
|
||||
# 成功
|
||||
# 成功 - 尝试生成配图(失败不影响改写结果)
|
||||
try:
|
||||
print(f"[process_ai_rewrite] AI生成成功,开始生成配图...")
|
||||
await generate_draft_images(
|
||||
story_id=draft.story_id,
|
||||
draft_id=draft.id,
|
||||
ai_nodes=ai_result["nodes"],
|
||||
story_category=story.category or "都市言情"
|
||||
)
|
||||
except Exception as img_e:
|
||||
print(f"[process_ai_rewrite] 配图生成失败(不影响改写结果): {img_e}")
|
||||
|
||||
draft.status = DraftStatus.completed
|
||||
draft.ai_nodes = ai_result["nodes"]
|
||||
draft.entry_node_key = ai_result.get("entryNodeKey", "branch_1")
|
||||
@@ -232,8 +392,7 @@ async def process_ai_rewrite_ending(draft_id: int):
|
||||
pass
|
||||
|
||||
# 成功 - 存储为对象格式(与故事节点格式一致)
|
||||
draft.status = DraftStatus.completed
|
||||
draft.ai_nodes = {
|
||||
ai_nodes = {
|
||||
"ending_rewrite": {
|
||||
"content": content,
|
||||
"speaker": "旁白",
|
||||
@@ -242,6 +401,21 @@ async def process_ai_rewrite_ending(draft_id: int):
|
||||
"ending_type": "rewrite"
|
||||
}
|
||||
}
|
||||
|
||||
# 生成配图(失败不影响改写结果)
|
||||
try:
|
||||
print(f"[process_ai_rewrite_ending] AI生成成功,开始生成配图...")
|
||||
await generate_draft_images(
|
||||
story_id=draft.story_id,
|
||||
draft_id=draft.id,
|
||||
ai_nodes=ai_nodes,
|
||||
story_category=story.category or "都市言情"
|
||||
)
|
||||
except Exception as img_e:
|
||||
print(f"[process_ai_rewrite_ending] 配图生成失败(不影响改写结果): {img_e}")
|
||||
|
||||
draft.status = DraftStatus.completed
|
||||
draft.ai_nodes = ai_nodes
|
||||
draft.entry_node_key = "ending_rewrite"
|
||||
draft.tokens_used = ai_result.get("tokens_used", 0)
|
||||
draft.title = f"{story.title}-{new_ending_name}"
|
||||
@@ -313,7 +487,18 @@ async def process_ai_continue_ending(draft_id: int):
|
||||
)
|
||||
|
||||
if ai_result and ai_result.get("nodes"):
|
||||
# 成功 - 存储多节点分支格式
|
||||
# 成功 - 尝试生成配图(失败不影响续写结果)
|
||||
try:
|
||||
print(f"[process_ai_continue_ending] AI生成成功,开始生成配图...")
|
||||
await generate_draft_images(
|
||||
story_id=draft.story_id,
|
||||
draft_id=draft.id,
|
||||
ai_nodes=ai_result["nodes"],
|
||||
story_category=story.category or "都市言情"
|
||||
)
|
||||
except Exception as img_e:
|
||||
print(f"[process_ai_continue_ending] 配图生成失败(不影响续写结果): {img_e}")
|
||||
|
||||
draft.status = DraftStatus.completed
|
||||
draft.ai_nodes = ai_result["nodes"]
|
||||
draft.entry_node_key = ai_result.get("entryNodeKey", "continue_1")
|
||||
@@ -648,7 +833,7 @@ async def delete_draft(
|
||||
userId: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除草稿"""
|
||||
"""删除草稿(同时清理图片文件)"""
|
||||
result = await db.execute(
|
||||
select(StoryDraft).where(
|
||||
StoryDraft.id == draft_id,
|
||||
@@ -660,6 +845,19 @@ async def delete_draft(
|
||||
if not draft:
|
||||
raise HTTPException(status_code=404, detail="草稿不存在")
|
||||
|
||||
# 删除草稿对应的图片文件夹
|
||||
try:
|
||||
import shutil
|
||||
settings = get_settings()
|
||||
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..', settings.upload_path))
|
||||
draft_dir = os.path.join(base_dir, "stories", str(draft.story_id), "drafts", str(draft_id))
|
||||
if os.path.exists(draft_dir):
|
||||
shutil.rmtree(draft_dir)
|
||||
print(f"[delete_draft] 已清理图片目录: {draft_dir}")
|
||||
except Exception as e:
|
||||
print(f"[delete_draft] 清理图片失败: {e}")
|
||||
# 图片清理失败不影响草稿删除
|
||||
|
||||
await db.delete(draft)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -39,6 +39,32 @@ class RewriteBranchRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
class NodeImageUpdate(BaseModel):
|
||||
nodeKey: str
|
||||
backgroundImage: str = ""
|
||||
characterImage: str = ""
|
||||
|
||||
|
||||
class CharacterImageUpdate(BaseModel):
|
||||
characterId: int
|
||||
avatarUrl: str = ""
|
||||
|
||||
|
||||
class ImageConfigRequest(BaseModel):
|
||||
coverUrl: str = ""
|
||||
nodes: List[NodeImageUpdate] = []
|
||||
characters: List[CharacterImageUpdate] = []
|
||||
|
||||
|
||||
class GenerateImageRequest(BaseModel):
|
||||
prompt: str
|
||||
style: str = "anime" # anime/realistic/illustration
|
||||
category: str = "character" # character/background/cover
|
||||
storyId: Optional[int] = None
|
||||
targetField: Optional[str] = None # coverUrl/backgroundImage/characterImage/avatarUrl
|
||||
targetKey: Optional[str] = None # nodeKey 或 characterId
|
||||
|
||||
|
||||
# ========== API接口 ==========
|
||||
|
||||
@router.get("")
|
||||
@@ -110,6 +136,222 @@ async def get_categories(db: AsyncSession = Depends(get_db)):
|
||||
return {"code": 0, "data": categories}
|
||||
|
||||
|
||||
@router.get("/test-image-gen")
|
||||
async def test_image_generation():
|
||||
"""测试图片生成服务是否正常"""
|
||||
from app.config import get_settings
|
||||
from app.services.image_gen import get_image_gen_service
|
||||
import httpx
|
||||
|
||||
settings = get_settings()
|
||||
results = {
|
||||
"api_key_configured": bool(settings.gemini_api_key),
|
||||
"api_key_preview": settings.gemini_api_key[:8] + "..." if settings.gemini_api_key else "未配置",
|
||||
"base_url": "https://work.poloapi.com/v1beta",
|
||||
"network_test": None,
|
||||
"generate_test": None
|
||||
}
|
||||
|
||||
# 测试网络连接
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get("https://work.poloapi.com")
|
||||
results["network_test"] = {
|
||||
"success": response.status_code < 500,
|
||||
"status_code": response.status_code,
|
||||
"message": "网络连接正常" if response.status_code < 500 else "服务端错误"
|
||||
}
|
||||
except Exception as e:
|
||||
results["network_test"] = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "网络连接失败"
|
||||
}
|
||||
|
||||
# 测试实际生图(简单测试)
|
||||
if settings.gemini_api_key:
|
||||
try:
|
||||
service = get_image_gen_service()
|
||||
gen_result = await service.generate_image(
|
||||
prompt="a simple red circle on white background",
|
||||
image_type="avatar",
|
||||
style="illustration"
|
||||
)
|
||||
results["generate_test"] = {
|
||||
"success": gen_result.get("success", False),
|
||||
"error": gen_result.get("error") if not gen_result.get("success") else None,
|
||||
"has_image_data": bool(gen_result.get("image_data"))
|
||||
}
|
||||
except Exception as e:
|
||||
results["generate_test"] = {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
else:
|
||||
results["generate_test"] = {
|
||||
"success": False,
|
||||
"error": "API Key 未配置"
|
||||
}
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "测试完成",
|
||||
"data": results
|
||||
}
|
||||
|
||||
|
||||
@router.get("/test-deepseek")
|
||||
async def test_deepseek():
|
||||
"""测试 DeepSeek AI 服务是否正常"""
|
||||
from app.config import get_settings
|
||||
import httpx
|
||||
|
||||
settings = get_settings()
|
||||
results = {
|
||||
"ai_service_enabled": settings.ai_service_enabled,
|
||||
"provider": settings.ai_provider,
|
||||
"api_key_configured": bool(settings.deepseek_api_key),
|
||||
"api_key_preview": settings.deepseek_api_key[:8] + "..." + settings.deepseek_api_key[-4:] if settings.deepseek_api_key and len(settings.deepseek_api_key) > 12 else "未配置或太短",
|
||||
"base_url": settings.deepseek_base_url,
|
||||
"model": settings.deepseek_model,
|
||||
"network_test": None,
|
||||
"api_test": None
|
||||
}
|
||||
|
||||
# 测试网络连接
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get("https://api.deepseek.com")
|
||||
results["network_test"] = {
|
||||
"success": response.status_code < 500,
|
||||
"status_code": response.status_code,
|
||||
"message": "网络连接正常"
|
||||
}
|
||||
except Exception as e:
|
||||
results["network_test"] = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "网络连接失败"
|
||||
}
|
||||
|
||||
# 测试 API 调用
|
||||
if settings.deepseek_api_key:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.deepseek_base_url}/chat/completions",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {settings.deepseek_api_key}"
|
||||
},
|
||||
json={
|
||||
"model": settings.deepseek_model,
|
||||
"messages": [{"role": "user", "content": "说'测试成功'两个字"}],
|
||||
"max_tokens": 10
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
results["api_test"] = {
|
||||
"success": True,
|
||||
"status_code": 200,
|
||||
"response": content[:50]
|
||||
}
|
||||
else:
|
||||
results["api_test"] = {
|
||||
"success": False,
|
||||
"status_code": response.status_code,
|
||||
"error": response.text[:200]
|
||||
}
|
||||
except Exception as e:
|
||||
results["api_test"] = {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
else:
|
||||
results["api_test"] = {
|
||||
"success": False,
|
||||
"error": "API Key 未配置"
|
||||
}
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "测试完成",
|
||||
"data": results
|
||||
}
|
||||
|
||||
|
||||
@router.get("/test-cloud-upload")
|
||||
async def test_cloud_upload():
|
||||
"""测试云存储上传是否正常"""
|
||||
import os
|
||||
import httpx
|
||||
|
||||
tcb_env = os.environ.get("TCB_ENV") or os.environ.get("CBR_ENV_ID")
|
||||
|
||||
results = {
|
||||
"env_id": tcb_env,
|
||||
"is_cloud": bool(tcb_env)
|
||||
}
|
||||
|
||||
if not tcb_env:
|
||||
return {
|
||||
"code": 1,
|
||||
"message": "非云托管环境,无法测试云存储上传",
|
||||
"data": results
|
||||
}
|
||||
|
||||
try:
|
||||
# 测试获取上传链接
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
"http://api.weixin.qq.com/tcb/uploadfile",
|
||||
json={
|
||||
"env": tcb_env,
|
||||
"path": "test/cloud_upload_test.txt"
|
||||
},
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
results["status_code"] = resp.status_code
|
||||
results["response"] = resp.text[:500] if resp.text else ""
|
||||
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
if data.get("errcode", 0) == 0:
|
||||
results["upload_url"] = data.get("url", "")[:100]
|
||||
results["file_id"] = data.get("file_id", "")
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "云存储上传链接获取成功",
|
||||
"data": results
|
||||
}
|
||||
else:
|
||||
results["errcode"] = data.get("errcode")
|
||||
results["errmsg"] = data.get("errmsg")
|
||||
return {
|
||||
"code": 1,
|
||||
"message": f"获取上传链接失败: {data.get('errmsg')}",
|
||||
"data": results
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"code": 1,
|
||||
"message": f"请求失败: HTTP {resp.status_code}",
|
||||
"data": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
results["error"] = str(e)
|
||||
return {
|
||||
"code": 1,
|
||||
"message": f"测试异常: {str(e)}",
|
||||
"data": results
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{story_id}")
|
||||
async def get_story_detail(story_id: int, db: AsyncSession = Depends(get_db)):
|
||||
"""获取故事详情(含节点和选项)"""
|
||||
@@ -372,3 +614,227 @@ async def ai_rewrite_branch(
|
||||
"error": "AI服务暂时不可用"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{story_id}/images")
|
||||
async def get_story_images(story_id: int, db: AsyncSession = Depends(get_db)):
|
||||
"""获取故事的所有图片配置"""
|
||||
# 获取故事封面
|
||||
result = await db.execute(select(Story).where(Story.id == story_id))
|
||||
story = result.scalar_one_or_none()
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="故事不存在")
|
||||
|
||||
# 获取所有节点的图片
|
||||
nodes_result = await db.execute(
|
||||
select(StoryNode).where(StoryNode.story_id == story_id).order_by(StoryNode.sort_order)
|
||||
)
|
||||
nodes = nodes_result.scalars().all()
|
||||
|
||||
# 获取所有角色的头像
|
||||
chars_result = await db.execute(
|
||||
select(StoryCharacter).where(StoryCharacter.story_id == story_id)
|
||||
)
|
||||
characters = chars_result.scalars().all()
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"storyId": story_id,
|
||||
"title": story.title,
|
||||
"coverUrl": story.cover_url or "",
|
||||
"nodes": [
|
||||
{
|
||||
"nodeKey": n.node_key,
|
||||
"content": n.content[:50] + "..." if len(n.content) > 50 else n.content,
|
||||
"backgroundImage": n.background_image or "",
|
||||
"characterImage": n.character_image or "",
|
||||
"isEnding": n.is_ending,
|
||||
"endingName": n.ending_name or ""
|
||||
}
|
||||
for n in nodes
|
||||
],
|
||||
"characters": [
|
||||
{
|
||||
"characterId": c.id,
|
||||
"name": c.name,
|
||||
"roleType": c.role_type,
|
||||
"avatarUrl": c.avatar_url or "",
|
||||
"avatarPrompt": c.avatar_prompt or ""
|
||||
}
|
||||
for c in characters
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{story_id}/images")
|
||||
async def update_story_images(
|
||||
story_id: int,
|
||||
request: ImageConfigRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""批量更新故事的图片配置"""
|
||||
# 验证故事存在
|
||||
result = await db.execute(select(Story).where(Story.id == story_id))
|
||||
story = result.scalar_one_or_none()
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="故事不存在")
|
||||
|
||||
updated = {"cover": False, "nodes": 0, "characters": 0}
|
||||
|
||||
# 更新封面
|
||||
if request.coverUrl:
|
||||
await db.execute(
|
||||
update(Story).where(Story.id == story_id).values(cover_url=request.coverUrl)
|
||||
)
|
||||
updated["cover"] = True
|
||||
|
||||
# 更新节点图片
|
||||
for node_img in request.nodes:
|
||||
values = {}
|
||||
if node_img.backgroundImage:
|
||||
values["background_image"] = node_img.backgroundImage
|
||||
if node_img.characterImage:
|
||||
values["character_image"] = node_img.characterImage
|
||||
if values:
|
||||
await db.execute(
|
||||
update(StoryNode)
|
||||
.where(StoryNode.story_id == story_id, StoryNode.node_key == node_img.nodeKey)
|
||||
.values(**values)
|
||||
)
|
||||
updated["nodes"] += 1
|
||||
|
||||
# 更新角色头像
|
||||
for char_img in request.characters:
|
||||
if char_img.avatarUrl:
|
||||
await db.execute(
|
||||
update(StoryCharacter)
|
||||
.where(StoryCharacter.id == char_img.characterId, StoryCharacter.story_id == story_id)
|
||||
.values(avatar_url=char_img.avatarUrl)
|
||||
)
|
||||
updated["characters"] += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "更新成功",
|
||||
"data": updated
|
||||
}
|
||||
|
||||
|
||||
@router.post("/generate-image")
|
||||
async def generate_story_image(
|
||||
request: GenerateImageRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""使用AI生成图片并可选保存到故事"""
|
||||
from app.services.image_gen import get_image_gen_service
|
||||
|
||||
# 生成图片
|
||||
result = await get_image_gen_service().generate_and_save(
|
||||
prompt=request.prompt,
|
||||
category=request.category,
|
||||
style=request.style
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
return {
|
||||
"code": 1,
|
||||
"message": result.get("error", "生成失败"),
|
||||
"data": None
|
||||
}
|
||||
|
||||
image_url = result["url"]
|
||||
|
||||
# 如果指定了故事和目标字段,自动更新
|
||||
if request.storyId and request.targetField:
|
||||
if request.targetField == "coverUrl":
|
||||
await db.execute(
|
||||
update(Story).where(Story.id == request.storyId).values(cover_url=image_url)
|
||||
)
|
||||
elif request.targetField == "backgroundImage" and request.targetKey:
|
||||
await db.execute(
|
||||
update(StoryNode)
|
||||
.where(StoryNode.story_id == request.storyId, StoryNode.node_key == request.targetKey)
|
||||
.values(background_image=image_url)
|
||||
)
|
||||
elif request.targetField == "characterImage" and request.targetKey:
|
||||
await db.execute(
|
||||
update(StoryNode)
|
||||
.where(StoryNode.story_id == request.storyId, StoryNode.node_key == request.targetKey)
|
||||
.values(character_image=image_url)
|
||||
)
|
||||
elif request.targetField == "avatarUrl" and request.targetKey:
|
||||
await db.execute(
|
||||
update(StoryCharacter)
|
||||
.where(StoryCharacter.story_id == request.storyId, StoryCharacter.id == int(request.targetKey))
|
||||
.values(avatar_url=image_url)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "生成成功",
|
||||
"data": {
|
||||
"url": image_url,
|
||||
"filename": result.get("filename"),
|
||||
"saved": bool(request.storyId and request.targetField)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{story_id}/generate-all-images")
|
||||
async def generate_all_story_images(
|
||||
story_id: int,
|
||||
style: str = "anime",
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""为故事批量生成所有角色头像"""
|
||||
from app.services.image_gen import get_image_gen_service
|
||||
image_service = get_image_gen_service()
|
||||
|
||||
# 获取所有角色
|
||||
result = await db.execute(
|
||||
select(StoryCharacter).where(StoryCharacter.story_id == story_id)
|
||||
)
|
||||
characters = result.scalars().all()
|
||||
|
||||
if not characters:
|
||||
return {"code": 1, "message": "故事没有角色数据", "data": None}
|
||||
|
||||
generated = []
|
||||
failed = []
|
||||
|
||||
for char in characters:
|
||||
# 使用avatar_prompt或自动构建
|
||||
prompt = char.avatar_prompt or f"{char.name}, {char.gender}, {char.appearance or ''}"
|
||||
|
||||
gen_result = await image_service.generate_and_save(
|
||||
prompt=prompt,
|
||||
category="character",
|
||||
style=style
|
||||
)
|
||||
|
||||
if gen_result.get("success"):
|
||||
# 更新数据库
|
||||
await db.execute(
|
||||
update(StoryCharacter)
|
||||
.where(StoryCharacter.id == char.id)
|
||||
.values(avatar_url=gen_result["url"])
|
||||
)
|
||||
generated.append({"id": char.id, "name": char.name, "url": gen_result["url"]})
|
||||
else:
|
||||
failed.append({"id": char.id, "name": char.name, "error": gen_result.get("error")})
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"message": f"生成完成: {len(generated)}成功, {len(failed)}失败",
|
||||
"data": {
|
||||
"generated": generated,
|
||||
"failed": failed
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user