414 lines
14 KiB
Python
414 lines
14 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
SQLite 数据库管理模块
|
|||
|
|
用于替换原有的 JSON 文件存储方式
|
|||
|
|
"""
|
|||
|
|
import sqlite3
|
|||
|
|
import os
|
|||
|
|
import logging
|
|||
|
|
from datetime import datetime
|
|||
|
|
from contextlib import contextmanager
|
|||
|
|
import threading
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Database:
    """SQLite database manager.

    Replaces the previous JSON-file storage. Thread safety is achieved by
    giving each thread its own lazily-opened connection (see
    get_connection()).
    """

    def __init__(self, db_path="data/baijiahao.db"):
        """Create the manager and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        # Per-thread storage: sqlite3 connections should not be shared
        # across threads, so each thread opens its own on first use.
        self._local = threading.local()
        # Creates the file, tables and indexes if missing (idempotent).
        self._ensure_database()
|
|||
|
|
|
|||
|
|
def _ensure_database(self):
|
|||
|
|
"""确保数据库文件和表结构存在"""
|
|||
|
|
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
|||
|
|
|
|||
|
|
with self.get_connection() as conn:
|
|||
|
|
cursor = conn.cursor()
|
|||
|
|
|
|||
|
|
# 创建任务表
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE TABLE IF NOT EXISTS tasks (
|
|||
|
|
task_id TEXT PRIMARY KEY,
|
|||
|
|
url TEXT NOT NULL,
|
|||
|
|
months REAL NOT NULL,
|
|||
|
|
use_proxy INTEGER NOT NULL,
|
|||
|
|
proxy_api_url TEXT,
|
|||
|
|
username TEXT,
|
|||
|
|
status TEXT NOT NULL,
|
|||
|
|
created_at TEXT NOT NULL,
|
|||
|
|
started_at TEXT,
|
|||
|
|
completed_at TEXT,
|
|||
|
|
paused_at TEXT,
|
|||
|
|
progress INTEGER DEFAULT 0,
|
|||
|
|
current_step TEXT,
|
|||
|
|
total_articles INTEGER DEFAULT 0,
|
|||
|
|
processed_articles INTEGER DEFAULT 0,
|
|||
|
|
error TEXT,
|
|||
|
|
result_file TEXT,
|
|||
|
|
retry_count INTEGER DEFAULT 0,
|
|||
|
|
last_error TEXT,
|
|||
|
|
articles_only INTEGER DEFAULT 1,
|
|||
|
|
last_page INTEGER DEFAULT 0,
|
|||
|
|
last_ctime TEXT
|
|||
|
|
)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
# 创建任务日志表
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE TABLE IF NOT EXISTS task_logs (
|
|||
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|||
|
|
task_id TEXT NOT NULL,
|
|||
|
|
timestamp TEXT NOT NULL,
|
|||
|
|
message TEXT NOT NULL,
|
|||
|
|
level TEXT DEFAULT 'info',
|
|||
|
|
FOREIGN KEY (task_id) REFERENCES tasks(task_id) ON DELETE CASCADE
|
|||
|
|
)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
# 创建文章缓存表(用于断点续传)
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE TABLE IF NOT EXISTS article_cache (
|
|||
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|||
|
|
task_id TEXT NOT NULL,
|
|||
|
|
title TEXT NOT NULL,
|
|||
|
|
url TEXT,
|
|||
|
|
publish_time TEXT,
|
|||
|
|
page_num INTEGER,
|
|||
|
|
created_at TEXT NOT NULL,
|
|||
|
|
FOREIGN KEY (task_id) REFERENCES tasks(task_id) ON DELETE CASCADE
|
|||
|
|
)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
# 创建索引提升查询性能
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE INDEX IF NOT EXISTS idx_tasks_status
|
|||
|
|
ON tasks(status)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE INDEX IF NOT EXISTS idx_tasks_username
|
|||
|
|
ON tasks(username)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE INDEX IF NOT EXISTS idx_tasks_created_at
|
|||
|
|
ON tasks(created_at DESC)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
# 为日志表创建索引
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE INDEX IF NOT EXISTS idx_task_logs_task_id
|
|||
|
|
ON task_logs(task_id)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE INDEX IF NOT EXISTS idx_task_logs_timestamp
|
|||
|
|
ON task_logs(timestamp)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
# 为文章缓存表创建索引
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE INDEX IF NOT EXISTS idx_article_cache_task_id
|
|||
|
|
ON article_cache(task_id)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
cursor.execute('''
|
|||
|
|
CREATE INDEX IF NOT EXISTS idx_article_cache_page
|
|||
|
|
ON article_cache(task_id, page_num)
|
|||
|
|
''')
|
|||
|
|
|
|||
|
|
conn.commit()
|
|||
|
|
logger.info(f"数据库初始化完成: {self.db_path}")
|
|||
|
|
|
|||
|
|
@contextmanager
|
|||
|
|
def get_connection(self):
|
|||
|
|
"""获取线程安全的数据库连接(上下文管理器)"""
|
|||
|
|
if not hasattr(self._local, 'conn') or self._local.conn is None:
|
|||
|
|
self._local.conn = sqlite3.connect(
|
|||
|
|
self.db_path,
|
|||
|
|
check_same_thread=False,
|
|||
|
|
timeout=30.0
|
|||
|
|
)
|
|||
|
|
# 设置返回字典而不是元组
|
|||
|
|
self._local.conn.row_factory = sqlite3.Row
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
yield self._local.conn
|
|||
|
|
except Exception as e:
|
|||
|
|
self._local.conn.rollback()
|
|||
|
|
logger.error(f"数据库操作失败: {e}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
def close(self):
|
|||
|
|
"""关闭当前线程的数据库连接"""
|
|||
|
|
if hasattr(self._local, 'conn') and self._local.conn is not None:
|
|||
|
|
self._local.conn.close()
|
|||
|
|
self._local.conn = None
|
|||
|
|
|
|||
|
|
def add_task_log(self, task_id, message, level='info', timestamp=None):
|
|||
|
|
"""添加任务日志
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
task_id: 任务ID
|
|||
|
|
message: 日志消息
|
|||
|
|
level: 日志级别 (info/success/warning/error)
|
|||
|
|
timestamp: 时间戳,默认为当前时间
|
|||
|
|
"""
|
|||
|
|
if timestamp is None:
|
|||
|
|
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|||
|
|
|
|||
|
|
with self.get_connection() as conn:
|
|||
|
|
cursor = conn.cursor()
|
|||
|
|
cursor.execute('''
|
|||
|
|
INSERT INTO task_logs (task_id, timestamp, message, level)
|
|||
|
|
VALUES (?, ?, ?, ?)
|
|||
|
|
''', (task_id, timestamp, message, level))
|
|||
|
|
conn.commit()
|
|||
|
|
|
|||
|
|
def get_task_logs(self, task_id, limit=None):
|
|||
|
|
"""获取任务的所有日志
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
task_id: 任务ID
|
|||
|
|
limit: 限制返回数量,默认返回所有
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
list: 日志列表,按时间顺序
|
|||
|
|
"""
|
|||
|
|
with self.get_connection() as conn:
|
|||
|
|
cursor = conn.cursor()
|
|||
|
|
|
|||
|
|
if limit:
|
|||
|
|
cursor.execute('''
|
|||
|
|
SELECT task_id, timestamp, message, level
|
|||
|
|
FROM task_logs
|
|||
|
|
WHERE task_id = ?
|
|||
|
|
ORDER BY id ASC
|
|||
|
|
LIMIT ?
|
|||
|
|
''', (task_id, limit))
|
|||
|
|
else:
|
|||
|
|
cursor.execute('''
|
|||
|
|
SELECT task_id, timestamp, message, level
|
|||
|
|
FROM task_logs
|
|||
|
|
WHERE task_id = ?
|
|||
|
|
ORDER BY id ASC
|
|||
|
|
''', (task_id,))
|
|||
|
|
|
|||
|
|
rows = cursor.fetchall()
|
|||
|
|
return [dict(row) for row in rows]
|
|||
|
|
|
|||
|
|
def clear_task_logs(self, task_id):
|
|||
|
|
"""清除任务的所有日志
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
task_id: 任务ID
|
|||
|
|
"""
|
|||
|
|
with self.get_connection() as conn:
|
|||
|
|
cursor = conn.cursor()
|
|||
|
|
cursor.execute('DELETE FROM task_logs WHERE task_id = ?', (task_id,))
|
|||
|
|
conn.commit()
|
|||
|
|
|
|||
|
|
def save_articles_batch(self, task_id, articles, page_num):
|
|||
|
|
"""批量保存文章到缓存表
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
task_id: 任务ID
|
|||
|
|
articles: 文章列表 [{'title': ..., 'url': ..., 'publish_time': ...}, ...]
|
|||
|
|
page_num: 页码
|
|||
|
|
"""
|
|||
|
|
if not articles:
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|||
|
|
|
|||
|
|
with self.get_connection() as conn:
|
|||
|
|
cursor = conn.cursor()
|
|||
|
|
|
|||
|
|
# 批量插入
|
|||
|
|
data = [
|
|||
|
|
(task_id,
|
|||
|
|
article.get('标题', article.get('title', '')),
|
|||
|
|
article.get('链接', article.get('url', '')),
|
|||
|
|
article.get('发布时间', article.get('publish_time', '')),
|
|||
|
|
page_num,
|
|||
|
|
timestamp)
|
|||
|
|
for article in articles
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
cursor.executemany('''
|
|||
|
|
INSERT INTO article_cache (task_id, title, url, publish_time, page_num, created_at)
|
|||
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|||
|
|
''', data)
|
|||
|
|
|
|||
|
|
conn.commit()
|
|||
|
|
logger.debug(f"保存 {len(articles)} 篇文章到缓存(任务{task_id},第{page_num}页)")
|
|||
|
|
|
|||
|
|
def get_cached_articles(self, task_id):
|
|||
|
|
"""获取任务的所有缓存文章
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
task_id: 任务ID
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
list: 文章列表
|
|||
|
|
"""
|
|||
|
|
with self.get_connection() as conn:
|
|||
|
|
cursor = conn.cursor()
|
|||
|
|
cursor.execute('''
|
|||
|
|
SELECT title, url, publish_time, page_num
|
|||
|
|
FROM article_cache
|
|||
|
|
WHERE task_id = ?
|
|||
|
|
ORDER BY id ASC
|
|||
|
|
''', (task_id,))
|
|||
|
|
|
|||
|
|
rows = cursor.fetchall()
|
|||
|
|
return [
|
|||
|
|
{
|
|||
|
|
'标题': row['title'],
|
|||
|
|
'链接': row['url'],
|
|||
|
|
'发布时间': row['publish_time'],
|
|||
|
|
'page_num': row['page_num']
|
|||
|
|
}
|
|||
|
|
for row in rows
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
def get_cached_article_count(self, task_id):
|
|||
|
|
"""获取任务已缓存的文章数量
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
task_id: 任务ID
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
int: 文章数量
|
|||
|
|
"""
|
|||
|
|
with self.get_connection() as conn:
|
|||
|
|
cursor = conn.cursor()
|
|||
|
|
cursor.execute('''
|
|||
|
|
SELECT COUNT(*) as count FROM article_cache WHERE task_id = ?
|
|||
|
|
''', (task_id,))
|
|||
|
|
|
|||
|
|
result = cursor.fetchone()
|
|||
|
|
return result['count'] if result else 0
|
|||
|
|
|
|||
|
|
def clear_article_cache(self, task_id):
|
|||
|
|
"""清除任务的所有缓存文章
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
task_id: 任务ID
|
|||
|
|
"""
|
|||
|
|
with self.get_connection() as conn:
|
|||
|
|
cursor = conn.cursor()
|
|||
|
|
cursor.execute('DELETE FROM article_cache WHERE task_id = ?', (task_id,))
|
|||
|
|
conn.commit()
|
|||
|
|
logger.debug(f"清除任务 {task_id} 的所有缓存文章")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Global database instance, created lazily by get_database() and shared by
# all threads (each thread still gets its own sqlite3 connection).
_db_instance = None
# Guards the one-time creation of _db_instance (double-checked locking).
_db_lock = threading.Lock()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_database():
    """Return the process-wide Database instance (lazy singleton).

    Uses double-checked locking: the unlocked check keeps the common
    already-created path lock-free, and the re-check inside the lock
    ensures only one thread ever constructs the instance.
    """
    global _db_instance
    if _db_instance is None:
        with _db_lock:
            # Re-check under the lock: another thread may have created the
            # instance between the first check and lock acquisition.
            if _db_instance is None:
                _db_instance = Database()
    return _db_instance
|
|||
|
|
|
|||
|
|
|
|||
|
|
def migrate_from_json(json_file="data/task_queue.json"):
    """Migrate task data from the legacy JSON queue file into SQLite.

    Best-effort, idempotent migration: tasks whose task_id already exists
    in the database are skipped, per-task failures are logged and do not
    abort the run, and on any success the original JSON file is copied to
    a ".backup" sibling (the original itself is left in place).

    Args:
        json_file: Path of the legacy JSON file.

    Returns:
        int: Number of tasks successfully migrated (0 on missing file,
        empty file, or a top-level failure).
    """
    import json

    if not os.path.exists(json_file):
        logger.info("未找到旧的 JSON 文件,跳过数据迁移")
        return 0

    try:
        # Read the legacy JSON data.
        with open(json_file, 'r', encoding='utf-8') as f:
            tasks = json.load(f)

        if not tasks:
            logger.info("JSON 文件中没有任务数据")
            return 0

        db = get_database()
        migrated_count = 0

        with db.get_connection() as conn:
            cursor = conn.cursor()

            for task in tasks:
                try:
                    # Skip tasks that were already migrated on a
                    # previous run (task_id is the primary key).
                    cursor.execute(
                        "SELECT task_id FROM tasks WHERE task_id = ?",
                        (task["task_id"],)
                    )

                    if cursor.fetchone():
                        logger.debug(f"任务 {task['task_id']} 已存在,跳过")
                        continue

                    # Insert the task row. Columns absent from the legacy
                    # format (retry_count, last_page, ...) rely on their
                    # schema defaults.
                    cursor.execute('''
                        INSERT INTO tasks (
                            task_id, url, months, use_proxy, proxy_api_url,
                            username, status, created_at, started_at, completed_at,
                            progress, current_step, total_articles, processed_articles,
                            error, result_file
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ''', (
                        task["task_id"],
                        task["url"],
                        task["months"],
                        1 if task["use_proxy"] else 0,  # bool -> INTEGER column
                        task.get("proxy_api_url"),
                        task.get("username"),
                        task["status"],
                        task["created_at"],
                        task.get("started_at"),
                        task.get("completed_at"),
                        task.get("progress", 0),
                        task.get("current_step"),
                        task.get("total_articles", 0),
                        task.get("processed_articles", 0),
                        task.get("error"),
                        task.get("result_file")
                    ))

                    migrated_count += 1

                except Exception as e:
                    # Best-effort: log the failed task and keep going; the
                    # single commit below persists whatever succeeded.
                    logger.error(f"迁移任务 {task.get('task_id')} 失败: {e}")

            conn.commit()

        logger.info(f"成功迁移 {migrated_count} 个任务到数据库")

        # Back up the original JSON file (copy, not move — the original
        # stays in place; re-runs are harmless thanks to the skip above).
        backup_file = json_file + ".backup"
        if migrated_count > 0:
            import shutil
            shutil.copy2(json_file, backup_file)
            logger.info(f"原 JSON 文件已备份到: {backup_file}")

        return migrated_count

    except Exception as e:
        logger.error(f"数据迁移失败: {e}")
        return 0
|