Files
ai_wht_B/release/11/prompt_routes.py

325 lines
12 KiB
Python
Raw Permalink Normal View History

2026-01-06 14:18:39 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
提示词管理接口
"""
from flask import Blueprint, request, jsonify
import logging
from datetime import datetime
from auth_utils import require_auth, AuthUtils
from database_config import get_db_manager
from log_utils import log_create, log_update, log_delete, log_error, log_operation
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Blueprint grouping the prompt-management endpoints below; every route
# registered on it is served under the /api/prompts prefix.
prompt_bp = Blueprint(
    name='prompt',
    import_name=__name__,
    url_prefix='/api/prompts',
)
@prompt_bp.route('/list', methods=['GET'])
@require_auth
def get_prompts_list():
    """Return one page of the current enterprise's prompts, newest first.

    Query parameters:
        page: 1-based page number (default 1).
        pageSize: number of prompts per page (default 20).

    Every prompt row in the result gets a ``tags`` list holding the names of
    the tags linked to it through ``ai_prompt_tags_relation``.

    Returns:
        200 with ``data = {'total': <prompt count for the enterprise>,
        'list': [...]}``; 400 when the enterprise ID is unavailable or the
        paging parameters are not positive integers; 500 on any unexpected
        error.
    """
    try:
        current_user = AuthUtils.get_current_user()
        enterprise_id = current_user.get('enterprise_id')
        if not enterprise_id:
            return jsonify({
                'code': 400,
                'message': '无法获取企业ID',
                'data': None
            }), 400

        # Reject bad paging input up front. A non-numeric value used to raise
        # ValueError and page <= 0 produced a negative OFFSET; both became 500s.
        try:
            page = int(request.args.get('page', 1))
            page_size = int(request.args.get('pageSize', 20))
        except ValueError:
            page = page_size = 0
        if page < 1 or page_size < 1:
            return jsonify({
                'code': 400,
                'message': '分页参数错误',
                'data': None
            }), 400
        # NOTE(review): pageSize has no upper bound; consider capping it.
        offset = (page - 1) * page_size

        db_manager = get_db_manager()

        # Total number of prompts owned by the enterprise (for the pager).
        count_sql = "SELECT COUNT(*) as total FROM ai_prompt_workflow WHERE enterprise_id = %s"
        count_result = db_manager.execute_query(count_sql, (enterprise_id,))
        total = count_result[0]['total']

        # One page of prompts. `id` breaks ties between equal created_at values
        # so that OFFSET paging neither repeats nor skips rows between pages.
        sql = """
        SELECT id, prompt_workflow_name, workflow_id, content, usage_count, created_at, updated_at
        FROM ai_prompt_workflow
        WHERE enterprise_id = %s
        ORDER BY created_at DESC, id DESC
        LIMIT %s OFFSET %s
        """
        prompts = db_manager.execute_query(sql, (enterprise_id, page_size, offset))

        # Load the tags for the whole page with a single query instead of one
        # query per prompt.
        tags_by_prompt = {}
        if prompts:
            prompt_ids = [prompt['id'] for prompt in prompts]
            # Only the "%s" placeholders are interpolated into the SQL text;
            # the IDs themselves are still bound as query parameters.
            placeholders = ', '.join(['%s'] * len(prompt_ids))
            tag_sql = f"""
            SELECT r.prompt_workflow_id, t.tag_name
            FROM ai_prompt_tags t
            INNER JOIN ai_prompt_tags_relation r ON t.id = r.tag_id
            WHERE r.prompt_workflow_id IN ({placeholders})
            """
            for row in db_manager.execute_query(tag_sql, tuple(prompt_ids)):
                # Keys are compared as strings because the column type of
                # prompt_workflow_id is not visible here (int vs. varchar).
                tags_by_prompt.setdefault(str(row['prompt_workflow_id']), []).append(row['tag_name'])
        for prompt in prompts:
            prompt['tags'] = tags_by_prompt.get(str(prompt['id']), [])

        logger.info(f"获取提示词列表成功,总数: {total}")
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'total': total,
                'list': prompts
            },
            'timestamp': int(datetime.now().timestamp() * 1000)
        })
    except Exception as e:
        logger.error(f"[获取提示词列表] 处理请求时发生错误: {str(e)}", exc_info=True)
        return jsonify({
            'code': 500,
            'message': '服务器内部错误',
            'data': None
        }), 500
def _parse_tags(raw_tags):
    """Normalize the ``tags`` field of a request body into a list of tag names.

    An absent or empty value (``None``, ``[]``, ``''``) yields ``[]``.
    Otherwise the value must be a list of strings: each name is stripped of
    surrounding whitespace, blank names are dropped, and duplicates are
    removed while keeping first-seen order (a duplicate would otherwise
    produce two relation rows for the same tag).

    Raises:
        ValueError: if ``raw_tags`` is not a list or holds a non-string item;
            the message is meant to be returned to the client.
    """
    if not raw_tags:
        return []
    if not isinstance(raw_tags, list) or not all(isinstance(tag, str) for tag in raw_tags):
        raise ValueError('tags 必须是字符串数组')
    stripped = (tag.strip() for tag in raw_tags)
    return list(dict.fromkeys(tag for tag in stripped if tag))


def _attach_tags(db_manager, prompt_id, prompt_name, tag_names, user_id):
    """Link prompt ``prompt_id`` to every tag in ``tag_names``.

    ``ai_prompt_tags`` is a global tag table: tags are looked up by name only
    and inserted (with ``created_user_id = user_id``) when missing. Each row
    written to ``ai_prompt_tags_relation`` also stores copies of the prompt
    name and the tag name next to the IDs.
    """
    for tag_name in tag_names:
        tag_rows = db_manager.execute_query(
            "SELECT id FROM ai_prompt_tags WHERE tag_name = %s", (tag_name,)
        )
        if tag_rows:
            tag_id = tag_rows[0]['id']
        else:
            # NOTE(review): find-or-create is not atomic; two concurrent requests
            # can both insert the same tag name unless the table enforces uniqueness.
            tag_id = db_manager.execute_insert(
                "INSERT INTO ai_prompt_tags (tag_name, created_user_id) VALUES (%s, %s)",
                (tag_name, user_id),
            )
        db_manager.execute_insert(
            "INSERT INTO ai_prompt_tags_relation "
            "(prompt_workflow_id, prompt_workflow_name, tag_id, tag_name, created_user_id) "
            "VALUES (%s, %s, %s, %s, %s)",
            (prompt_id, prompt_name, tag_id, tag_name, user_id),
        )


@prompt_bp.route('/create', methods=['POST'])
@require_auth
def create_prompt():
    """Create a prompt for the current enterprise and link it to its tags.

    JSON body:
        prompt_workflow_name (required, non-empty), content (required,
        non-empty), tags (optional list of tag-name strings).

    A ``workflow_id`` of the form ``WF-XXXXXXXX`` (8 upper-case hex chars from
    a random UUID) is generated and ``usage_count`` starts at 0.

    Returns:
        200 with the new prompt's ``id`` and ``prompt_workflow_name``; 400 when
        the enterprise ID is unavailable, the body is not a non-empty JSON
        object, a required field is missing/empty, or ``tags`` is malformed;
        500 on any unexpected error.
    """
    import uuid

    try:
        current_user = AuthUtils.get_current_user()
        enterprise_id = current_user.get('enterprise_id')
        if not enterprise_id:
            return jsonify({
                'code': 400,
                'message': '无法获取企业ID',
                'data': None
            }), 400

        # silent=True: malformed JSON or a non-JSON content type yields None
        # (-> 400) instead of an HTTP exception that the generic handler below
        # would turn into a 500. Non-object bodies (e.g. a JSON list) are
        # rejected too, since the code below reads fields with .get().
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            return jsonify({
                'code': 400,
                'message': '请求参数错误',
                'data': None
            }), 400

        for field in ('prompt_workflow_name', 'content'):
            if not data.get(field):
                return jsonify({
                    'code': 400,
                    'message': f'缺少必需字段: {field}',
                    'data': None
                }), 400

        # Validate tags before writing anything, so a bad payload cannot leave
        # behind a prompt without its tags.
        try:
            tag_names = _parse_tags(data.get('tags'))
        except ValueError as exc:
            return jsonify({
                'code': 400,
                'message': str(exc),
                'data': None
            }), 400

        db_manager = get_db_manager()
        prompt_name = data['prompt_workflow_name']
        workflow_id = f"WF-{str(uuid.uuid4())[:8].upper()}"

        sql = """
        INSERT INTO ai_prompt_workflow
        (enterprise_id, prompt_workflow_name, workflow_id, content, usage_count)
        VALUES (%s, %s, %s, %s, %s)
        """
        prompt_id = db_manager.execute_insert(sql, (
            enterprise_id,
            prompt_name,
            workflow_id,
            data['content'],
            0
        ))

        _attach_tags(db_manager, prompt_id, prompt_name, tag_names,
                     current_user.get('user_id', 0))

        logger.info(f"创建提示词成功: ID {prompt_id}")
        return jsonify({
            'code': 200,
            'message': '创建成功',
            'data': {
                'id': prompt_id,
                'prompt_workflow_name': prompt_name
            },
            'timestamp': int(datetime.now().timestamp() * 1000)
        })
    except Exception as e:
        logger.error(f"[创建提示词] 处理请求时发生错误: {str(e)}", exc_info=True)
        return jsonify({
            'code': 500,
            'message': '服务器内部错误',
            'data': None
        }), 500


@prompt_bp.route('/<int:prompt_id>', methods=['PUT'])
@require_auth
def update_prompt(prompt_id):
    """Update the name, content and/or tags of one of the enterprise's prompts.

    JSON body (all optional, but the body must be a non-empty object):
        prompt_workflow_name, content: new values; may not be empty.
        tags: replaces the prompt's whole tag set (``null``/``[]`` clears it).

    ``updated_at`` is bumped only when name or content change. When the prompt
    is renamed without a ``tags`` field, the name copy kept in the existing
    ``ai_prompt_tags_relation`` rows is refreshed as well.

    Returns:
        200 on success; 400 for a missing enterprise ID or an invalid body;
        404 when the prompt does not exist or belongs to another enterprise;
        500 on any unexpected error.
    """
    try:
        current_user = AuthUtils.get_current_user()
        enterprise_id = current_user.get('enterprise_id')
        if not enterprise_id:
            return jsonify({
                'code': 400,
                'message': '无法获取企业ID',
                'data': None
            }), 400

        # silent=True: malformed/non-JSON bodies become None -> 400, not 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            return jsonify({
                'code': 400,
                'message': '请求参数错误',
                'data': None
            }), 400

        # Name and content are mandatory on create; do not allow blanking them here.
        for field in ('prompt_workflow_name', 'content'):
            if field in data and not data[field]:
                return jsonify({
                    'code': 400,
                    'message': f'字段不能为空: {field}',
                    'data': None
                }), 400

        # Validate tags before any write so a bad payload leaves the prompt untouched
        # (previously `tags: null` deleted every relation and then crashed).
        replace_tags = 'tags' in data
        tag_names = []
        if replace_tags:
            try:
                tag_names = _parse_tags(data['tags'])
            except ValueError as exc:
                return jsonify({
                    'code': 400,
                    'message': str(exc),
                    'data': None
                }), 400

        db_manager = get_db_manager()

        # The prompt must exist and belong to the caller's enterprise.
        check_sql = "SELECT id, prompt_workflow_name FROM ai_prompt_workflow WHERE id = %s AND enterprise_id = %s"
        existing = db_manager.execute_query(check_sql, (prompt_id, enterprise_id))
        if not existing:
            return jsonify({
                'code': 404,
                'message': '提示词不存在',
                'data': None
            }), 404

        # Column names come from this fixed tuple, never from the request.
        update_fields = []
        params = []
        for field in ('prompt_workflow_name', 'content'):
            if field in data:
                update_fields.append(f"{field} = %s")
                params.append(data[field])
        if update_fields:
            params.extend([prompt_id, enterprise_id])
            sql = (f"UPDATE ai_prompt_workflow SET {', '.join(update_fields)}, updated_at = NOW() "
                   "WHERE id = %s AND enterprise_id = %s")
            db_manager.execute_update(sql, params)

        # The prompt's name after this update (the new one if it was just changed);
        # computed once rather than re-queried for every tag.
        prompt_name = data.get('prompt_workflow_name', existing[0]['prompt_workflow_name'])

        if replace_tags:
            del_sql = "DELETE FROM ai_prompt_tags_relation WHERE prompt_workflow_id = %s"
            db_manager.execute_update(del_sql, (prompt_id,))
            _attach_tags(db_manager, prompt_id, prompt_name, tag_names,
                         current_user.get('user_id', 0))
        elif 'prompt_workflow_name' in data:
            # Relation rows keep their own copy of the prompt name; keep it in sync on rename.
            sync_sql = "UPDATE ai_prompt_tags_relation SET prompt_workflow_name = %s WHERE prompt_workflow_id = %s"
            db_manager.execute_update(sync_sql, (prompt_name, prompt_id))

        logger.info(f"更新提示词成功: ID {prompt_id}")
        return jsonify({
            'code': 200,
            'message': '更新成功',
            'data': None,
            'timestamp': int(datetime.now().timestamp() * 1000)
        })
    except Exception as e:
        logger.error(f"[更新提示词] 处理请求时发生错误: {str(e)}", exc_info=True)
        return jsonify({
            'code': 500,
            'message': '服务器内部错误',
            'data': None
        }), 500
@prompt_bp.route('/<int:prompt_id>', methods=['DELETE'])
@require_auth
def delete_prompt(prompt_id):
    """Delete one of the current enterprise's prompts along with its tag links.

    Returns:
        200 on success; 400 when the enterprise ID is unavailable; 404 when no
        prompt with ``prompt_id`` belongs to the enterprise; 500 on any
        unexpected error.
    """
    try:
        user = AuthUtils.get_current_user()
        ent_id = user.get('enterprise_id')
        if not ent_id:
            payload = {'code': 400, 'message': '无法获取企业ID', 'data': None}
            return jsonify(payload), 400

        db = get_db_manager()

        # Ownership check: the prompt must belong to the caller's enterprise.
        rows = db.execute_query(
            "SELECT id, prompt_workflow_name FROM ai_prompt_workflow WHERE id = %s AND enterprise_id = %s",
            (prompt_id, ent_id),
        )
        if not rows:
            payload = {'code': 404, 'message': '提示词不存在', 'data': None}
            return jsonify(payload), 404

        # Remove the tag links first, then the prompt row itself.
        for statement in (
            "DELETE FROM ai_prompt_tags_relation WHERE prompt_workflow_id = %s",
            "DELETE FROM ai_prompt_workflow WHERE id = %s",
        ):
            db.execute_update(statement, (prompt_id,))

        logger.info(f"删除提示词成功: {rows[0]['prompt_workflow_name']}")
        timestamp_ms = int(datetime.now().timestamp() * 1000)
        payload = {'code': 200, 'message': '删除成功', 'data': None, 'timestamp': timestamp_ms}
        return jsonify(payload)
    except Exception as exc:
        logger.error(f"[删除提示词] 处理请求时发生错误: {str(exc)}", exc_info=True)
        payload = {'code': 500, 'message': '服务器内部错误', 'data': None}
        return jsonify(payload), 500