212 lines
7.7 KiB
Python
212 lines
7.7 KiB
Python
|
|
"""
|
|||
|
|
AI服务封装模块
|
|||
|
|
支持多种AI提供商:DeepSeek, OpenAI, Claude, 通义千问
|
|||
|
|
"""
|
|||
|
|
from typing import Optional, Dict, Any
|
|||
|
|
import httpx
|
|||
|
|
|
|||
|
|
class AIService:
    """Async client that rewrites story endings with one of several LLM providers.

    The provider is selected once, at construction time, from the application
    settings (``ai_provider``). Recognised values are ``"deepseek"``,
    ``"openai"``, ``"claude"`` and ``"qwen"``; any other value leaves
    ``api_key`` as ``None``, so ``rewrite_ending`` always returns ``None``.
    """

    # Timeout in seconds applied to every provider request.
    _TIMEOUT = 30.0

    # Instructions sent to every provider. They ask the model to answer with a
    # JSON object containing "ending_name" and "content"; the reply itself is
    # returned to the caller unparsed.
    _SYSTEM_PROMPT = (
        "你是一个专业的互动故事创作专家。根据用户的改写指令,重新创作故事结局。\n"
        "要求:\n"
        "1. 保持原故事的世界观和人物性格\n"
        "2. 结局要有张力和情感冲击\n"
        "3. 结局内容字数控制在200-400字\n"
        "4. 为新结局取一个4-8字的新名字,体现改写后的剧情走向\n"
        '5. 输出格式必须是JSON:{"ending_name": "新结局名称", "content": "结局内容"}'
    )

    def __init__(self):
        """Load the provider configuration.

        ``enabled`` and ``provider`` come from ``app.config.get_settings()``.
        DeepSeek and OpenAI credentials also come from settings; Claude and
        Qwen keys are read from the ``CLAUDE_API_KEY`` / ``DASHSCOPE_API_KEY``
        environment variables (empty string when unset).
        """
        import os

        from app.config import get_settings

        settings = get_settings()

        self.enabled = settings.ai_service_enabled
        self.provider = settings.ai_provider

        # Provider-specific endpoint, credentials and model.
        if self.provider == "deepseek":
            self.api_key = settings.deepseek_api_key
            self.base_url = settings.deepseek_base_url
            self.model = settings.deepseek_model
        elif self.provider == "openai":
            self.api_key = settings.openai_api_key
            self.base_url = settings.openai_base_url
            self.model = "gpt-3.5-turbo"
        elif self.provider == "claude":
            # Claude credentials are not part of the settings object.
            self.api_key = os.getenv("CLAUDE_API_KEY", "")
            self.base_url = "https://api.anthropic.com"
            self.model = "claude-3-haiku-20240307"
        elif self.provider == "qwen":
            self.api_key = os.getenv("DASHSCOPE_API_KEY", "")
            self.base_url = "https://dashscope.aliyuncs.com/api/v1"
            self.model = "qwen-plus"
        else:
            self.api_key = None
            self.base_url = None
            self.model = None

    async def rewrite_ending(
        self,
        story_title: str,
        story_category: str,
        ending_name: str,
        ending_content: str,
        user_prompt: str
    ) -> Optional[Dict[str, Any]]:
        """Ask the configured provider to rewrite a story ending.

        :param story_title: story title, included in the prompt.
        :param story_category: story category, included in the prompt.
        :param ending_name: name of the original ending.
        :param ending_content: original ending text; only its first 500
            characters are sent to the model.
        :param user_prompt: the user's rewrite instruction.
        :return: ``{"content": str, "tokens_used": int}`` where ``content`` is
            the model's raw (stripped) reply, or ``None`` when the service is
            disabled, no API key is configured, the provider is unknown, or
            the provider call fails.
        """
        if not self.enabled or not self.api_key:
            return None

        system_prompt = self._SYSTEM_PROMPT
        user_prompt_text = self._build_user_prompt(
            story_title, story_category, ending_name, ending_content, user_prompt
        )

        try:
            if self.provider == "openai":
                return await self._call_openai(system_prompt, user_prompt_text)
            if self.provider == "claude":
                # Claude takes system instructions separately; without them the
                # JSON / length / naming requirements would never reach it.
                return await self._call_claude(user_prompt_text, system_prompt=system_prompt)
            if self.provider == "qwen":
                return await self._call_qwen(system_prompt, user_prompt_text)
            if self.provider == "deepseek":
                return await self._call_deepseek(system_prompt, user_prompt_text)
        except Exception as e:
            # Best effort: any provider failure degrades to "no rewrite".
            print(f"AI调用失败:{e}")
            return None

        # Provider not recognised.
        return None

    @staticmethod
    def _build_user_prompt(
        story_title: str,
        story_category: str,
        ending_name: str,
        ending_content: str,
        user_prompt: str
    ) -> str:
        """Assemble the user message describing the story and the rewrite request."""
        return (
            f"故事标题:{story_title}\n"
            f"故事分类:{story_category}\n"
            f"原结局名称:{ending_name}\n"
            # Only the first 500 characters of the original ending are sent,
            # bounding the size of this part of the prompt.
            f"原结局内容:{ending_content[:500]}\n"
            "---\n"
            f"用户改写指令:{user_prompt}\n"
            "---\n"
            "请创作新的结局(输出JSON格式):"
        )

    def _api_url(self, path: str) -> str:
        """Join ``self.base_url`` and ``path`` without producing a double slash."""
        return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"

    def _bearer_headers(self) -> Dict[str, str]:
        """Headers for providers that authenticate with a Bearer token."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def _build_chat_payload(self, system_prompt: str, user_prompt: str) -> Dict[str, Any]:
        """Request body for OpenAI-compatible ``/chat/completions`` endpoints."""
        return {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            "temperature": 0.8,
            # NOTE(review): 500 tokens may be tight for a 200-400 character
            # Chinese ending wrapped in JSON; a truncated reply will not parse.
            "max_tokens": 500
        }

    def _build_claude_payload(self, prompt: str, system_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Request body for the Anthropic Messages API.

        :param prompt: the user message.
        :param system_prompt: optional system instructions; included as the
            top-level ``system`` field when non-empty.
        """
        payload: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": 1024,
            "messages": [{"role": "user", "content": prompt}]
        }
        if system_prompt:
            # The Messages API has no "system" role; it uses a top-level field.
            payload["system"] = system_prompt
        return payload

    async def _post_json(self, url: str, headers: Dict[str, str], payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON and return the decoded JSON response.

        :raises httpx.HTTPStatusError: on a 4xx/5xx response.
        :raises httpx.HTTPError: on transport failures, including timeouts.
        :raises ValueError: if the response body is not valid JSON.
        """
        async with httpx.AsyncClient(timeout=self._TIMEOUT) as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()

    async def _call_openai(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
        """Call the OpenAI Chat Completions API.

        Exceptions propagate; ``rewrite_ending`` converts them to ``None``.
        """
        result = await self._post_json(
            self._api_url("chat/completions"),
            self._bearer_headers(),
            self._build_chat_payload(system_prompt, user_prompt),
        )

        content = result["choices"][0]["message"]["content"]
        # "usage" may be missing or null; report 0 like the other providers.
        tokens = (result.get("usage") or {}).get("total_tokens", 0)

        return {"content": content.strip(), "tokens_used": tokens}

    async def _call_claude(self, prompt: str, system_prompt: Optional[str] = None) -> Optional[Dict]:
        """Call the Anthropic Messages API.

        :param prompt: the user message.
        :param system_prompt: optional system instructions.
        :return: ``{"content": str, "tokens_used": int}``; ``tokens_used`` is
            input + output tokens, matching the ``total_tokens`` figure the
            other providers report. Exceptions propagate to the caller.
        """
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }
        result = await self._post_json(
            self._api_url("v1/messages"),
            headers,
            self._build_claude_payload(prompt, system_prompt),
        )

        content = result["content"][0]["text"]
        usage = result.get("usage") or {}
        tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)

        return {"content": content.strip(), "tokens_used": tokens}

    async def _call_qwen(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
        """Call the DashScope (Tongyi Qianwen) native text-generation API.

        Exceptions propagate; ``rewrite_ending`` converts them to ``None``.
        """
        payload = {
            "model": self.model,
            "input": {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
            },
            "parameters": {
                # Requests the chat-style output.choices[...] shape parsed below.
                "result_format": "message",
                "temperature": 0.8
            }
        }
        result = await self._post_json(
            self._api_url("services/aigc/text-generation/generation"),
            self._bearer_headers(),
            payload,
        )

        content = result["output"]["choices"][0]["message"]["content"]
        tokens = (result.get("usage") or {}).get("total_tokens", 0)

        return {"content": content.strip(), "tokens_used": tokens}

    async def _call_deepseek(self, system_prompt: str, user_prompt: str) -> Optional[Dict]:
        """Call the DeepSeek (OpenAI-compatible) chat completions API.

        Unlike the other providers, failures are handled here: HTTP errors,
        replies without choices, and any other exception are printed and
        ``None`` is returned.
        """
        try:
            result = await self._post_json(
                self._api_url("chat/completions"),
                self._bearer_headers(),
                self._build_chat_payload(system_prompt, user_prompt),
            )

            choices = result.get("choices")
            if not choices:
                print(f"DeepSeek API 返回异常:{result}")
                return None

            content = choices[0]["message"]["content"]
            tokens = (result.get("usage") or {}).get("total_tokens", 0)

            return {"content": content.strip(), "tokens_used": tokens}
        except httpx.HTTPStatusError as e:
            print(f"DeepSeek HTTP 错误:{e.response.status_code} - {e.response.text}")
            return None
        except Exception as e:
            print(f"DeepSeek 调用失败:{e}")
            return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Shared module-level instance (singleton pattern): it is constructed once,
# when this module is first imported, and every importer reuses it.
ai_service: AIService = AIService()
|