101 lines
2.8 KiB
Python
101 lines
2.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
SQLite数据库初始化脚本
|
|
自动创建开发环境(ai_mip_dev.db)和生产环境(ai_mip_prod.db)数据库
|
|
"""
|
|
|
|
import sqlite3
|
|
import os
|
|
from pathlib import Path
|
|
|
|
# 数据库文件路径
|
|
DB_DIR = Path(__file__).parent
|
|
DEV_DB = DB_DIR / "ai_mip_dev.db"
|
|
PROD_DB = DB_DIR / "ai_mip_prod.db"
|
|
|
|
# SQL脚本路径
|
|
INIT_SQL = DB_DIR / "init_sqlite.sql"
|
|
SEED_DEV_SQL = DB_DIR / "seed_dev.sql"
|
|
|
|
|
|
def execute_sql_file(conn, sql_file):
|
|
"""执行SQL文件"""
|
|
with open(sql_file, 'r', encoding='utf-8') as f:
|
|
sql_script = f.read()
|
|
|
|
# SQLite需要逐条执行语句
|
|
conn.executescript(sql_script)
|
|
conn.commit()
|
|
print(f"✓ 已执行: {sql_file.name}")
|
|
|
|
|
|
def init_database(db_path, with_seed=False):
|
|
"""初始化数据库"""
|
|
# 如果数据库已存在,询问是否覆盖
|
|
if db_path.exists():
|
|
response = input(f"\n数据库 {db_path.name} 已存在,是否覆盖? (y/n): ").strip().lower()
|
|
if response != 'y':
|
|
print(f"跳过 {db_path.name}")
|
|
return
|
|
os.remove(db_path)
|
|
print(f"已删除旧数据库: {db_path.name}")
|
|
|
|
print(f"\n创建数据库: {db_path.name}")
|
|
|
|
# 连接数据库(自动创建)
|
|
conn = sqlite3.connect(db_path)
|
|
|
|
try:
|
|
# 执行初始化SQL
|
|
execute_sql_file(conn, INIT_SQL)
|
|
|
|
# 如果需要,执行种子数据
|
|
if with_seed:
|
|
execute_sql_file(conn, SEED_DEV_SQL)
|
|
|
|
print(f"✓ 数据库 {db_path.name} 创建成功")
|
|
|
|
# 验证表是否创建成功
|
|
cursor = conn.cursor()
|
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
|
tables = cursor.fetchall()
|
|
print(f" 创建的表: {', '.join([t[0] for t in tables])}")
|
|
|
|
except Exception as e:
|
|
print(f"✗ 创建数据库失败: {str(e)}")
|
|
raise
|
|
finally:
|
|
conn.close()
|
|
|
|
|
|
def main():
|
|
print("=" * 60)
|
|
print("SQLite数据库初始化工具")
|
|
print("=" * 60)
|
|
|
|
# 检查SQL文件是否存在
|
|
if not INIT_SQL.exists():
|
|
print(f"错误: 找不到初始化脚本 {INIT_SQL}")
|
|
return
|
|
|
|
# 初始化开发数据库(带测试数据)
|
|
print("\n[1] 初始化开发环境数据库")
|
|
init_database(DEV_DB, with_seed=True)
|
|
|
|
# 初始化生产数据库(不带测试数据)
|
|
print("\n[2] 初始化生产环境数据库")
|
|
init_database(PROD_DB, with_seed=False)
|
|
|
|
print("\n" + "=" * 60)
|
|
print("数据库初始化完成")
|
|
print("=" * 60)
|
|
print(f"开发数据库: {DEV_DB}")
|
|
print(f"生产数据库: {PROD_DB}")
|
|
print("\n使用方法:")
|
|
print(" 开发环境: 在 .env.development 中设置 DATABASE_PATH=db/ai_mip_dev.db")
|
|
print(" 生产环境: 在 .env.production 中设置 DATABASE_PATH=db/ai_mip_prod.db")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|