# Exported from repository listing: ai_baijiahao/database.py (414 lines, 14 KiB, Python)

# -*- coding: utf-8 -*-
"""
SQLite 数据库管理模块
用于替换原有的 JSON 文件存储方式
"""
import sqlite3
import os
import logging
from datetime import datetime
from contextlib import contextmanager
import threading
logger = logging.getLogger(__name__)
class Database:
    """Thread-safe SQLite database manager.

    Replaces the project's previous JSON-file storage.  A separate
    connection is cached for each thread (via ``threading.local``)
    because sqlite3 connection objects must not be shared between
    threads.
    """

    def __init__(self, db_path="data/baijiahao.db"):
        """Create the manager and make sure the schema exists.

        Args:
            db_path: path to the SQLite database file; its parent
                directory is created on demand.
        """
        self.db_path = db_path
        self._local = threading.local()  # per-thread connection cache
        self._ensure_database()

    def _ensure_database(self):
        """Ensure the database file, tables and indexes all exist."""
        # dirname() is "" for a bare filename and makedirs("") raises,
        # so only create the parent directory when there actually is one.
        parent_dir = os.path.dirname(self.db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with self.get_connection() as conn:
            cursor = conn.cursor()
            # Task table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    url TEXT NOT NULL,
                    months REAL NOT NULL,
                    use_proxy INTEGER NOT NULL,
                    proxy_api_url TEXT,
                    username TEXT,
                    status TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    started_at TEXT,
                    completed_at TEXT,
                    paused_at TEXT,
                    progress INTEGER DEFAULT 0,
                    current_step TEXT,
                    total_articles INTEGER DEFAULT 0,
                    processed_articles INTEGER DEFAULT 0,
                    error TEXT,
                    result_file TEXT,
                    retry_count INTEGER DEFAULT 0,
                    last_error TEXT,
                    articles_only INTEGER DEFAULT 1,
                    last_page INTEGER DEFAULT 0,
                    last_ctime TEXT
                )
            ''')
            # Per-task log table.
            # NOTE(review): the ON DELETE CASCADE clauses below only fire
            # when "PRAGMA foreign_keys = ON" is issued per connection;
            # it is not enabled here, so child rows are removed manually
            # via clear_task_logs()/clear_article_cache() — confirm this
            # is intentional before enabling enforcement.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS task_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    message TEXT NOT NULL,
                    level TEXT DEFAULT 'info',
                    FOREIGN KEY (task_id) REFERENCES tasks(task_id) ON DELETE CASCADE
                )
            ''')
            # Article cache table (supports resuming interrupted tasks)
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS article_cache (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT NOT NULL,
                    title TEXT NOT NULL,
                    url TEXT,
                    publish_time TEXT,
                    page_num INTEGER,
                    created_at TEXT NOT NULL,
                    FOREIGN KEY (task_id) REFERENCES tasks(task_id) ON DELETE CASCADE
                )
            ''')
            # Indexes for the common query patterns
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_tasks_status
                ON tasks(status)
            ''')
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_tasks_username
                ON tasks(username)
            ''')
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_tasks_created_at
                ON tasks(created_at DESC)
            ''')
            # Log-table indexes
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_task_logs_task_id
                ON task_logs(task_id)
            ''')
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_task_logs_timestamp
                ON task_logs(timestamp)
            ''')
            # Article-cache indexes
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_article_cache_task_id
                ON article_cache(task_id)
            ''')
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_article_cache_page
                ON article_cache(task_id, page_num)
            ''')
            conn.commit()
            logging.getLogger(__name__).info(f"数据库初始化完成: {self.db_path}")

    @contextmanager
    def get_connection(self):
        """Yield this thread's cached database connection (context manager).

        A connection is opened lazily on first use in each thread.  On any
        exception the transaction is rolled back, the error logged, and the
        exception re-raised; committing is the caller's responsibility.
        """
        if getattr(self._local, 'conn', None) is None:
            conn = sqlite3.connect(
                self.db_path,
                check_same_thread=False,
                timeout=30.0
            )
            # Return sqlite3.Row so callers can turn rows into dicts.
            conn.row_factory = sqlite3.Row
            self._local.conn = conn
        try:
            yield self._local.conn
        except Exception as e:
            self._local.conn.rollback()
            logging.getLogger(__name__).error(f"数据库操作失败: {e}")
            raise

    def close(self):
        """Close and discard the current thread's connection, if any."""
        conn = getattr(self._local, 'conn', None)
        if conn is not None:
            conn.close()
            self._local.conn = None

    def add_task_log(self, task_id, message, level='info', timestamp=None):
        """Append one log entry for a task.

        Args:
            task_id: task identifier.
            message: log message text.
            level: log level (info/success/warning/error).
            timestamp: 'YYYY-MM-DD HH:MM:SS' string; defaults to now.
        """
        if timestamp is None:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT INTO task_logs (task_id, timestamp, message, level)
                VALUES (?, ?, ?, ?)
            ''', (task_id, timestamp, message, level))
            conn.commit()

    def get_task_logs(self, task_id, limit=None):
        """Return a task's log entries in insertion order.

        Args:
            task_id: task identifier.
            limit: maximum number of entries; a falsy value means "all"
                (note limit=0 therefore also returns everything).

        Returns:
            list[dict]: entries with task_id/timestamp/message/level keys.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            if limit:
                cursor.execute('''
                    SELECT task_id, timestamp, message, level
                    FROM task_logs
                    WHERE task_id = ?
                    ORDER BY id ASC
                    LIMIT ?
                ''', (task_id, limit))
            else:
                cursor.execute('''
                    SELECT task_id, timestamp, message, level
                    FROM task_logs
                    WHERE task_id = ?
                    ORDER BY id ASC
                ''', (task_id,))
            return [dict(row) for row in cursor.fetchall()]

    def clear_task_logs(self, task_id):
        """Delete all log entries of the given task."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('DELETE FROM task_logs WHERE task_id = ?', (task_id,))
            conn.commit()

    def save_articles_batch(self, task_id, articles, page_num):
        """Insert one page of scraped articles into the cache table.

        Accepts both Chinese ('标题'/'链接'/'发布时间') and English
        ('title'/'url'/'publish_time') keys, preferring the Chinese ones.

        Args:
            task_id: task identifier.
            articles: list of article dicts; an empty list is a no-op.
            page_num: page number the articles were scraped from.
        """
        if not articles:
            return
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        with self.get_connection() as conn:
            cursor = conn.cursor()
            # Batch insert in a single executemany round-trip.
            data = [
                (task_id,
                 article.get('标题', article.get('title', '')),
                 article.get('链接', article.get('url', '')),
                 article.get('发布时间', article.get('publish_time', '')),
                 page_num,
                 timestamp)
                for article in articles
            ]
            cursor.executemany('''
                INSERT INTO article_cache (task_id, title, url, publish_time, page_num, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', data)
            conn.commit()
            logging.getLogger(__name__).debug(f"保存 {len(articles)} 篇文章到缓存(任务{task_id},第{page_num}页)")

    def get_cached_articles(self, task_id):
        """Return all cached articles of a task in insertion order.

        Returns:
            list[dict]: dicts with the Chinese keys '标题'/'链接'/'发布时间'
            plus 'page_num'.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT title, url, publish_time, page_num
                FROM article_cache
                WHERE task_id = ?
                ORDER BY id ASC
            ''', (task_id,))
            return [
                {
                    '标题': row['title'],
                    '链接': row['url'],
                    '发布时间': row['publish_time'],
                    'page_num': row['page_num']
                }
                for row in cursor.fetchall()
            ]

    def get_cached_article_count(self, task_id):
        """Return how many articles are cached for the given task."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT COUNT(*) as count FROM article_cache WHERE task_id = ?
            ''', (task_id,))
            result = cursor.fetchone()
            return result['count'] if result else 0

    def clear_article_cache(self, task_id):
        """Delete all cached articles of the given task."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('DELETE FROM article_cache WHERE task_id = ?', (task_id,))
            conn.commit()
            logging.getLogger(__name__).debug(f"清除任务 {task_id} 的所有缓存文章")
# Module-level singleton state for the shared Database instance
_db_instance = None
_db_lock = threading.Lock()


def get_database():
    """Return the process-wide Database instance, creating it on first use.

    Double-checked locking: the fast path returns without taking the lock,
    while concurrent first calls construct the instance exactly once.
    """
    global _db_instance
    if _db_instance is not None:
        return _db_instance
    with _db_lock:
        if _db_instance is None:
            _db_instance = Database()
    return _db_instance
def migrate_from_json(json_file="data/task_queue.json"):
    """Migrate legacy task data from a JSON file into the SQLite database.

    The migration is idempotent: tasks whose task_id already exists in the
    database are skipped, and a per-task failure does not abort the rest.
    After a successful migration the original file is kept and a copy is
    written next to it as ``<json_file>.backup``.

    Args:
        json_file: path of the legacy JSON task-queue file.

    Returns:
        int: number of tasks successfully migrated (0 when the file is
        missing or empty, or when the migration fails as a whole).
    """
    # Local imports: migration is a one-shot operation, so these are not
    # needed at module import time.
    import json
    import shutil

    if not os.path.exists(json_file):
        logger.info("未找到旧的 JSON 文件,跳过数据迁移")
        return 0
    try:
        # Load the legacy JSON payload
        with open(json_file, 'r', encoding='utf-8') as f:
            tasks = json.load(f)
        if not tasks:
            logger.info("JSON 文件中没有任务数据")
            return 0
        db = get_database()
        migrated_count = 0
        with db.get_connection() as conn:
            cursor = conn.cursor()
            for task in tasks:
                try:
                    # Skip tasks that were already migrated (idempotence)
                    cursor.execute(
                        "SELECT task_id FROM tasks WHERE task_id = ?",
                        (task["task_id"],)
                    )
                    if cursor.fetchone():
                        logger.debug(f"任务 {task['task_id']} 已存在,跳过")
                        continue
                    # Insert the task row; optional fields fall back to
                    # None / 0 via dict.get.
                    cursor.execute('''
                        INSERT INTO tasks (
                            task_id, url, months, use_proxy, proxy_api_url,
                            username, status, created_at, started_at, completed_at,
                            progress, current_step, total_articles, processed_articles,
                            error, result_file
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ''', (
                        task["task_id"],
                        task["url"],
                        task["months"],
                        1 if task["use_proxy"] else 0,  # bool -> INTEGER column
                        task.get("proxy_api_url"),
                        task.get("username"),
                        task["status"],
                        task["created_at"],
                        task.get("started_at"),
                        task.get("completed_at"),
                        task.get("progress", 0),
                        task.get("current_step"),
                        task.get("total_articles", 0),
                        task.get("processed_articles", 0),
                        task.get("error"),
                        task.get("result_file")
                    ))
                    migrated_count += 1
                except Exception as e:
                    # One malformed task must not abort the whole migration.
                    logger.error(f"迁移任务 {task.get('task_id')} 失败: {e}")
            # Commit everything in one transaction after the loop.
            conn.commit()
        logger.info(f"成功迁移 {migrated_count} 个任务到数据库")
        if migrated_count > 0:
            # Keep a safety copy of the legacy file rather than deleting it.
            backup_file = json_file + ".backup"
            shutil.copy2(json_file, backup_file)
            logger.info(f"原 JSON 文件已备份到: {backup_file}")
        return migrated_count
    except Exception:
        # Best-effort entry point: log the full traceback (not just str(e))
        # and report zero migrated tasks instead of raising.
        logger.exception("数据迁移失败")
        return 0