feat: 添加测试用户到种子数据, AI改写功能优化, 前端联调修复
This commit is contained in:
@@ -32,6 +32,24 @@ class CreateDraftRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
class CreateEndingDraftRequest(BaseModel):
|
||||
"""结局改写请求"""
|
||||
userId: int
|
||||
storyId: int
|
||||
endingName: str
|
||||
endingContent: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class ContinueEndingDraftRequest(BaseModel):
|
||||
"""结局续写请求"""
|
||||
userId: int
|
||||
storyId: int
|
||||
endingName: str
|
||||
endingContent: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class DraftResponse(BaseModel):
|
||||
id: int
|
||||
storyId: int
|
||||
@@ -120,6 +138,174 @@ async def process_ai_rewrite(draft_id: int):
|
||||
pass
|
||||
|
||||
|
||||
async def process_ai_rewrite_ending(draft_id: int):
|
||||
"""后台异步处理AI改写结局"""
|
||||
from app.database import async_session_factory
|
||||
from app.services.ai import ai_service
|
||||
import json
|
||||
import re
|
||||
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
# 获取草稿
|
||||
result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
|
||||
draft = result.scalar_one_or_none()
|
||||
|
||||
if not draft:
|
||||
return
|
||||
|
||||
# 更新状态为处理中
|
||||
draft.status = DraftStatus.processing
|
||||
await db.commit()
|
||||
|
||||
# 获取故事信息
|
||||
story_result = await db.execute(select(Story).where(Story.id == draft.story_id))
|
||||
story = story_result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = "故事不存在"
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
return
|
||||
|
||||
# 从 path_history 获取结局信息
|
||||
ending_info = draft.path_history or {}
|
||||
ending_name = ending_info.get("endingName", "未知结局")
|
||||
ending_content = ending_info.get("endingContent", "")
|
||||
|
||||
# 调用AI服务改写结局
|
||||
ai_result = await ai_service.rewrite_ending(
|
||||
story_title=story.title,
|
||||
story_category=story.category or "未知",
|
||||
ending_name=ending_name,
|
||||
ending_content=ending_content,
|
||||
user_prompt=draft.user_prompt
|
||||
)
|
||||
|
||||
if ai_result and ai_result.get("content"):
|
||||
content = ai_result["content"]
|
||||
new_ending_name = f"{ending_name}(AI改写)"
|
||||
|
||||
# 尝试解析 JSON 格式的返回
|
||||
try:
|
||||
json_match = re.search(r'\{[^{}]*"ending_name"[^{}]*"content"[^{}]*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
parsed = json.loads(json_match.group())
|
||||
new_ending_name = parsed.get("ending_name", new_ending_name)
|
||||
content = parsed.get("content", content)
|
||||
else:
|
||||
parsed = json.loads(content)
|
||||
new_ending_name = parsed.get("ending_name", new_ending_name)
|
||||
content = parsed.get("content", content)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
|
||||
# 成功 - 存储为单节点结局格式
|
||||
draft.status = DraftStatus.completed
|
||||
draft.ai_nodes = [{
|
||||
"nodeKey": "ending_rewrite",
|
||||
"content": content,
|
||||
"speaker": "旁白",
|
||||
"isEnding": True,
|
||||
"endingName": new_ending_name,
|
||||
"endingType": "rewrite"
|
||||
}]
|
||||
draft.entry_node_key = "ending_rewrite"
|
||||
draft.tokens_used = ai_result.get("tokens_used", 0)
|
||||
draft.title = f"{story.title}-{new_ending_name}"
|
||||
else:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = "AI服务暂时不可用"
|
||||
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"[process_ai_rewrite_ending] 异常: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
try:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = str(e)[:500]
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
async def process_ai_continue_ending(draft_id: int):
|
||||
"""后台异步处理AI续写结局"""
|
||||
from app.database import async_session_factory
|
||||
from app.services.ai import ai_service
|
||||
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
# 获取草稿
|
||||
result = await db.execute(select(StoryDraft).where(StoryDraft.id == draft_id))
|
||||
draft = result.scalar_one_or_none()
|
||||
|
||||
if not draft:
|
||||
return
|
||||
|
||||
# 更新状态为处理中
|
||||
draft.status = DraftStatus.processing
|
||||
await db.commit()
|
||||
|
||||
# 获取故事信息
|
||||
story_result = await db.execute(select(Story).where(Story.id == draft.story_id))
|
||||
story = story_result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = "故事不存在"
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
return
|
||||
|
||||
# 从 path_history 获取结局信息
|
||||
ending_info = draft.path_history or {}
|
||||
ending_name = ending_info.get("endingName", "未知结局")
|
||||
ending_content = ending_info.get("endingContent", "")
|
||||
|
||||
# 调用AI服务续写结局
|
||||
ai_result = await ai_service.continue_ending(
|
||||
story_title=story.title,
|
||||
story_category=story.category or "未知",
|
||||
ending_name=ending_name,
|
||||
ending_content=ending_content,
|
||||
user_prompt=draft.user_prompt
|
||||
)
|
||||
|
||||
if ai_result and ai_result.get("nodes"):
|
||||
# 成功 - 存储多节点分支格式
|
||||
draft.status = DraftStatus.completed
|
||||
draft.ai_nodes = ai_result["nodes"]
|
||||
draft.entry_node_key = ai_result.get("entryNodeKey", "continue_1")
|
||||
draft.tokens_used = ai_result.get("tokens_used", 0)
|
||||
draft.title = f"{story.title}-{ending_name}续写"
|
||||
else:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = "AI服务暂时不可用"
|
||||
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"[process_ai_continue_ending] 异常: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
try:
|
||||
draft.status = DraftStatus.failed
|
||||
draft.error_message = str(e)[:500]
|
||||
draft.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# ============ API 路由 ============
|
||||
|
||||
@router.post("")
|
||||
@@ -173,6 +359,96 @@ async def create_draft(
|
||||
}
|
||||
|
||||
|
||||
@router.post("/ending")
|
||||
async def create_ending_draft(
|
||||
request: CreateEndingDraftRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""提交AI改写结局任务(异步处理)"""
|
||||
if not request.prompt:
|
||||
raise HTTPException(status_code=400, detail="请输入改写指令")
|
||||
|
||||
# 获取故事信息
|
||||
result = await db.execute(select(Story).where(Story.id == request.storyId))
|
||||
story = result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="故事不存在")
|
||||
|
||||
# 创建草稿记录,将结局信息存在 path_history
|
||||
draft = StoryDraft(
|
||||
user_id=request.userId,
|
||||
story_id=request.storyId,
|
||||
title=f"{story.title}-结局改写",
|
||||
path_history={"endingName": request.endingName, "endingContent": request.endingContent},
|
||||
current_node_key="ending",
|
||||
current_content=request.endingContent,
|
||||
user_prompt=request.prompt,
|
||||
status=DraftStatus.pending
|
||||
)
|
||||
|
||||
db.add(draft)
|
||||
await db.commit()
|
||||
await db.refresh(draft)
|
||||
|
||||
# 添加后台任务
|
||||
background_tasks.add_task(process_ai_rewrite_ending, draft.id)
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"draftId": draft.id,
|
||||
"message": "已提交,AI正在生成新结局..."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/continue-ending")
|
||||
async def create_continue_ending_draft(
|
||||
request: ContinueEndingDraftRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""提交AI续写结局任务(异步处理)"""
|
||||
if not request.prompt:
|
||||
raise HTTPException(status_code=400, detail="请输入续写指令")
|
||||
|
||||
# 获取故事信息
|
||||
result = await db.execute(select(Story).where(Story.id == request.storyId))
|
||||
story = result.scalar_one_or_none()
|
||||
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="故事不存在")
|
||||
|
||||
# 创建草稿记录,将结局信息存在 path_history
|
||||
draft = StoryDraft(
|
||||
user_id=request.userId,
|
||||
story_id=request.storyId,
|
||||
title=f"{story.title}-结局续写",
|
||||
path_history={"endingName": request.endingName, "endingContent": request.endingContent},
|
||||
current_node_key="ending",
|
||||
current_content=request.endingContent,
|
||||
user_prompt=request.prompt,
|
||||
status=DraftStatus.pending
|
||||
)
|
||||
|
||||
db.add(draft)
|
||||
await db.commit()
|
||||
await db.refresh(draft)
|
||||
|
||||
# 添加后台任务
|
||||
background_tasks.add_task(process_ai_continue_ending, draft.id)
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"draftId": draft.id,
|
||||
"message": "已提交,AI正在续写故事..."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_drafts(
|
||||
userId: int,
|
||||
|
||||
Reference in New Issue
Block a user