# -*- coding: utf-8 -*-
"""
SQLite database management module.

Replaces the previous JSON-file based storage. It provides:

* ``Database`` - a small wrapper around ``sqlite3`` that keeps one connection
  per thread and exposes helpers for task logs and the article cache.
* ``get_database`` - accessor for the process-wide ``Database`` instance.
* ``migrate_from_json`` - one-off import of tasks from the legacy JSON file.
"""

import json
import logging
import os
import shutil
import sqlite3
import threading
from contextlib import contextmanager
from datetime import datetime

logger = logging.getLogger(__name__)

# Format used for every timestamp this module generates itself.
_TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S'

# DDL executed (idempotently) every time a Database is constructed.
#
# NOTE(review): the FOREIGN KEY ... ON DELETE CASCADE clauses below are only
# enforced by SQLite when ``PRAGMA foreign_keys = ON`` has been issued on the
# connection. This module never issues it, so deleting a task does NOT remove
# its logs / cached articles automatically. Enabling the pragma would also
# make inserts for unknown task ids fail, so it is left to a deliberate,
# separately-reviewed change.
_SCHEMA_STATEMENTS = (
    # Task table.
    '''
    CREATE TABLE IF NOT EXISTS tasks (
        task_id TEXT PRIMARY KEY,
        url TEXT NOT NULL,
        months REAL NOT NULL,
        use_proxy INTEGER NOT NULL,
        proxy_api_url TEXT,
        username TEXT,
        status TEXT NOT NULL,
        created_at TEXT NOT NULL,
        started_at TEXT,
        completed_at TEXT,
        paused_at TEXT,
        progress INTEGER DEFAULT 0,
        current_step TEXT,
        total_articles INTEGER DEFAULT 0,
        processed_articles INTEGER DEFAULT 0,
        error TEXT,
        result_file TEXT,
        retry_count INTEGER DEFAULT 0,
        last_error TEXT,
        articles_only INTEGER DEFAULT 1,
        last_page INTEGER DEFAULT 0,
        last_ctime TEXT
    )
    ''',
    # Task log table.
    '''
    CREATE TABLE IF NOT EXISTS task_logs (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        task_id TEXT NOT NULL,
        timestamp TEXT NOT NULL,
        message TEXT NOT NULL,
        level TEXT DEFAULT 'info',
        FOREIGN KEY (task_id) REFERENCES tasks(task_id) ON DELETE CASCADE
    )
    ''',
    # Article cache table (used to resume interrupted tasks).
    '''
    CREATE TABLE IF NOT EXISTS article_cache (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        task_id TEXT NOT NULL,
        title TEXT NOT NULL,
        url TEXT,
        publish_time TEXT,
        page_num INTEGER,
        created_at TEXT NOT NULL,
        FOREIGN KEY (task_id) REFERENCES tasks(task_id) ON DELETE CASCADE
    )
    ''',
    # Indexes on the task table.
    '''
    CREATE INDEX IF NOT EXISTS idx_tasks_status
    ON tasks(status)
    ''',
    '''
    CREATE INDEX IF NOT EXISTS idx_tasks_username
    ON tasks(username)
    ''',
    '''
    CREATE INDEX IF NOT EXISTS idx_tasks_created_at
    ON tasks(created_at DESC)
    ''',
    # Indexes on the log table.
    '''
    CREATE INDEX IF NOT EXISTS idx_task_logs_task_id
    ON task_logs(task_id)
    ''',
    '''
    CREATE INDEX IF NOT EXISTS idx_task_logs_timestamp
    ON task_logs(timestamp)
    ''',
    # Indexes on the article cache table.
    '''
    CREATE INDEX IF NOT EXISTS idx_article_cache_task_id
    ON article_cache(task_id)
    ''',
    '''
    CREATE INDEX IF NOT EXISTS idx_article_cache_page
    ON article_cache(task_id, page_num)
    ''',
)


class Database:
    """SQLite database manager.

    Each thread that uses an instance gets its own ``sqlite3`` connection,
    stored in a ``threading.local`` and created lazily on first use.
    """

    def __init__(self, db_path="data/baijiahao.db"):
        """Open (creating if necessary) the database at *db_path*.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._local = threading.local()
        self._ensure_database()

    def _ensure_database(self):
        """Make sure the database directory, tables and indexes exist."""
        # A bare file name (or ":memory:") has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError - only create the parent
        # directory when there actually is one.
        parent_dir = os.path.dirname(self.db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with self.get_connection() as conn:
            cursor = conn.cursor()
            for statement in _SCHEMA_STATEMENTS:
                cursor.execute(statement)
            conn.commit()

        logger.info(f"数据库初始化完成: {self.db_path}")

    @contextmanager
    def get_connection(self):
        """Yield the calling thread's database connection.

        The connection is created on first use and then reused by later
        calls from the same thread; it is NOT closed when the ``with`` block
        exits (use :meth:`close` for that). Committing is the caller's job.

        If the ``with`` body raises, the connection is rolled back, the
        error is logged and the exception is re-raised.

        Yields:
            sqlite3.Connection: connection whose rows are ``sqlite3.Row``
            objects (addressable by column name).
        """
        conn = getattr(self._local, 'conn', None)
        if conn is None:
            conn = sqlite3.connect(
                self.db_path,
                check_same_thread=False,
                timeout=30.0,
            )
            # Return rows that support access by column name.
            conn.row_factory = sqlite3.Row
            self._local.conn = conn

        try:
            yield conn
        except Exception as e:
            conn.rollback()
            logger.error(f"数据库操作失败: {e}")
            raise

    def close(self):
        """Close the calling thread's connection, if it has one."""
        conn = getattr(self._local, 'conn', None)
        if conn is not None:
            conn.close()
            self._local.conn = None

    def add_task_log(self, task_id, message, level='info', timestamp=None):
        """Append one log entry for a task.

        Args:
            task_id: Task id.
            message: Log message.
            level: Log level (info/success/warning/error).
            timestamp: Timestamp string; defaults to the current local time
                formatted as ``YYYY-MM-DD HH:MM:SS``.
        """
        if timestamp is None:
            timestamp = datetime.now().strftime(_TIMESTAMP_FORMAT)

        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                '''
                INSERT INTO task_logs (task_id, timestamp, message, level)
                VALUES (?, ?, ?, ?)
                ''',
                (task_id, timestamp, message, level),
            )
            conn.commit()

    def get_task_logs(self, task_id, limit=None):
        """Return the log entries of a task, oldest first.

        Args:
            task_id: Task id.
            limit: Maximum number of entries to return. Any falsy value
                (``None`` or ``0``) means "no limit".

        Returns:
            list: dicts with the keys ``task_id``, ``timestamp``,
            ``message`` and ``level``, ordered by insertion.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            if limit:
                cursor.execute(
                    '''
                    SELECT task_id, timestamp, message, level
                    FROM task_logs
                    WHERE task_id = ?
                    ORDER BY id ASC
                    LIMIT ?
                    ''',
                    (task_id, limit),
                )
            else:
                cursor.execute(
                    '''
                    SELECT task_id, timestamp, message, level
                    FROM task_logs
                    WHERE task_id = ?
                    ORDER BY id ASC
                    ''',
                    (task_id,),
                )
            rows = cursor.fetchall()

        return [dict(row) for row in rows]

    def clear_task_logs(self, task_id):
        """Delete every log entry of a task.

        Args:
            task_id: Task id.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                'DELETE FROM task_logs WHERE task_id = ?', (task_id,)
            )
            conn.commit()

    def save_articles_batch(self, task_id, articles, page_num):
        """Insert a page worth of articles into the cache table.

        Each article dict may use either the Chinese keys
        (``标题`` / ``链接`` / ``发布时间``) or the English ones
        (``title`` / ``url`` / ``publish_time``); the Chinese key wins when
        both are present and a missing field is stored as ``''``.

        Args:
            task_id: Task id.
            articles: List of article dicts. Nothing happens when empty.
            page_num: Page number the articles were collected from.
        """
        if not articles:
            return

        created_at = datetime.now().strftime(_TIMESTAMP_FORMAT)
        rows = [
            (
                task_id,
                article.get('标题', article.get('title', '')),
                article.get('链接', article.get('url', '')),
                article.get('发布时间', article.get('publish_time', '')),
                page_num,
                created_at,
            )
            for article in articles
        ]

        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.executemany(
                '''
                INSERT INTO article_cache
                    (task_id, title, url, publish_time, page_num, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                ''',
                rows,
            )
            conn.commit()

        logger.debug(f"保存 {len(articles)} 篇文章到缓存(任务{task_id},第{page_num}页)")

    def get_cached_articles(self, task_id):
        """Return every cached article of a task, in insertion order.

        Args:
            task_id: Task id.

        Returns:
            list: dicts with the keys ``标题``, ``链接``, ``发布时间`` and
            ``page_num``.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                '''
                SELECT title, url, publish_time, page_num
                FROM article_cache
                WHERE task_id = ?
                ORDER BY id ASC
                ''',
                (task_id,),
            )
            rows = cursor.fetchall()

        return [
            {
                '标题': row['title'],
                '链接': row['url'],
                '发布时间': row['publish_time'],
                'page_num': row['page_num'],
            }
            for row in rows
        ]

    def get_cached_article_count(self, task_id):
        """Return how many articles are cached for a task.

        Args:
            task_id: Task id.

        Returns:
            int: Number of cached articles.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                '''
                SELECT COUNT(*) as count
                FROM article_cache
                WHERE task_id = ?
                ''',
                (task_id,),
            )
            result = cursor.fetchone()

        return result['count'] if result else 0

    def clear_article_cache(self, task_id):
        """Delete every cached article of a task.

        Args:
            task_id: Task id.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                'DELETE FROM article_cache WHERE task_id = ?', (task_id,)
            )
            conn.commit()

        logger.debug(f"清除任务 {task_id} 的所有缓存文章")


# Process-wide database instance, created lazily by get_database().
_db_instance = None
_db_lock = threading.Lock()


def get_database():
    """Return the process-wide ``Database`` instance, creating it on first use.

    Creation happens under ``_db_lock`` and the instance is re-checked after
    the lock is acquired, so concurrent first callers share one instance.
    """
    global _db_instance
    if _db_instance is None:
        with _db_lock:
            if _db_instance is None:
                _db_instance = Database()
    return _db_instance


def migrate_from_json(json_file="data/task_queue.json"):
    """Import tasks from the legacy JSON file into the SQLite database.

    Tasks whose ``task_id`` already exists in the database are skipped. A
    task that cannot be inserted (missing field, wrong shape, ...) is logged
    and skipped without affecting the other tasks. When at least one task
    was migrated, the JSON file is copied to ``<json_file>.backup``.

    Args:
        json_file: Path of the legacy JSON file.

    Returns:
        int: Number of tasks migrated. ``0`` when the file is missing or
        empty, or when the migration as a whole failed (the failure is
        logged, not raised).
    """
    if not os.path.exists(json_file):
        logger.info("未找到旧的 JSON 文件,跳过数据迁移")
        return 0

    try:
        # Load the legacy JSON data.
        with open(json_file, 'r', encoding='utf-8') as f:
            tasks = json.load(f)

        if not tasks:
            logger.info("JSON 文件中没有任务数据")
            return 0

        db = get_database()
        migrated_count = 0

        with db.get_connection() as conn:
            cursor = conn.cursor()

            for task in tasks:
                try:
                    # Skip tasks that are already in the database.
                    cursor.execute(
                        "SELECT task_id FROM tasks WHERE task_id = ?",
                        (task["task_id"],),
                    )
                    if cursor.fetchone():
                        logger.debug(f"任务 {task['task_id']} 已存在,跳过")
                        continue

                    # Insert the task.
                    cursor.execute(
                        '''
                        INSERT INTO tasks (
                            task_id, url, months, use_proxy, proxy_api_url,
                            username, status, created_at, started_at,
                            completed_at, progress, current_step,
                            total_articles, processed_articles, error,
                            result_file
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                        ''',
                        (
                            task["task_id"],
                            task["url"],
                            task["months"],
                            1 if task["use_proxy"] else 0,
                            task.get("proxy_api_url"),
                            task.get("username"),
                            task["status"],
                            task["created_at"],
                            task.get("started_at"),
                            task.get("completed_at"),
                            task.get("progress", 0),
                            task.get("current_step"),
                            task.get("total_articles", 0),
                            task.get("processed_articles", 0),
                            task.get("error"),
                            task.get("result_file"),
                        ),
                    )
                    migrated_count += 1

                except Exception as e:
                    # The entry may not even be a dict; the handler must not
                    # raise itself, or a single malformed entry would abort
                    # (and roll back) the whole migration.
                    task_id = (
                        task.get('task_id') if isinstance(task, dict) else None
                    )
                    logger.error(f"迁移任务 {task_id} 失败: {e}")

            conn.commit()

        logger.info(f"成功迁移 {migrated_count} 个任务到数据库")

        # Keep a backup copy of the original JSON file.
        backup_file = json_file + ".backup"
        if migrated_count > 0:
            shutil.copy2(json_file, backup_file)
            logger.info(f"原 JSON 文件已备份到: {backup_file}")

        return migrated_count

    except Exception as e:
        logger.error(f"数据迁移失败: {e}")
        return 0