Initial commit: 百家号文章采集系统
This commit is contained in:
413
database.py
Normal file
413
database.py
Normal file
@@ -0,0 +1,413 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
SQLite 数据库管理模块
|
||||
用于替换原有的 JSON 文件存储方式
|
||||
"""
|
||||
import sqlite3
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from contextlib import contextmanager
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
"""SQLite 数据库管理器"""
|
||||
|
||||
def __init__(self, db_path="data/baijiahao.db"):
|
||||
self.db_path = db_path
|
||||
self._local = threading.local()
|
||||
self._ensure_database()
|
||||
|
||||
def _ensure_database(self):
|
||||
"""确保数据库文件和表结构存在"""
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建任务表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
url TEXT NOT NULL,
|
||||
months REAL NOT NULL,
|
||||
use_proxy INTEGER NOT NULL,
|
||||
proxy_api_url TEXT,
|
||||
username TEXT,
|
||||
status TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
paused_at TEXT,
|
||||
progress INTEGER DEFAULT 0,
|
||||
current_step TEXT,
|
||||
total_articles INTEGER DEFAULT 0,
|
||||
processed_articles INTEGER DEFAULT 0,
|
||||
error TEXT,
|
||||
result_file TEXT,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
last_error TEXT,
|
||||
articles_only INTEGER DEFAULT 1,
|
||||
last_page INTEGER DEFAULT 0,
|
||||
last_ctime TEXT
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建任务日志表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS task_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
level TEXT DEFAULT 'info',
|
||||
FOREIGN KEY (task_id) REFERENCES tasks(task_id) ON DELETE CASCADE
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建文章缓存表(用于断点续传)
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS article_cache (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
url TEXT,
|
||||
publish_time TEXT,
|
||||
page_num INTEGER,
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (task_id) REFERENCES tasks(task_id) ON DELETE CASCADE
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建索引提升查询性能
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_status
|
||||
ON tasks(status)
|
||||
''')
|
||||
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_username
|
||||
ON tasks(username)
|
||||
''')
|
||||
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_created_at
|
||||
ON tasks(created_at DESC)
|
||||
''')
|
||||
|
||||
# 为日志表创建索引
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_task_logs_task_id
|
||||
ON task_logs(task_id)
|
||||
''')
|
||||
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_task_logs_timestamp
|
||||
ON task_logs(timestamp)
|
||||
''')
|
||||
|
||||
# 为文章缓存表创建索引
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_article_cache_task_id
|
||||
ON article_cache(task_id)
|
||||
''')
|
||||
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_article_cache_page
|
||||
ON article_cache(task_id, page_num)
|
||||
''')
|
||||
|
||||
conn.commit()
|
||||
logger.info(f"数据库初始化完成: {self.db_path}")
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""获取线程安全的数据库连接(上下文管理器)"""
|
||||
if not hasattr(self._local, 'conn') or self._local.conn is None:
|
||||
self._local.conn = sqlite3.connect(
|
||||
self.db_path,
|
||||
check_same_thread=False,
|
||||
timeout=30.0
|
||||
)
|
||||
# 设置返回字典而不是元组
|
||||
self._local.conn.row_factory = sqlite3.Row
|
||||
|
||||
try:
|
||||
yield self._local.conn
|
||||
except Exception as e:
|
||||
self._local.conn.rollback()
|
||||
logger.error(f"数据库操作失败: {e}")
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
"""关闭当前线程的数据库连接"""
|
||||
if hasattr(self._local, 'conn') and self._local.conn is not None:
|
||||
self._local.conn.close()
|
||||
self._local.conn = None
|
||||
|
||||
def add_task_log(self, task_id, message, level='info', timestamp=None):
|
||||
"""添加任务日志
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
message: 日志消息
|
||||
level: 日志级别 (info/success/warning/error)
|
||||
timestamp: 时间戳,默认为当前时间
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO task_logs (task_id, timestamp, message, level)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (task_id, timestamp, message, level))
|
||||
conn.commit()
|
||||
|
||||
def get_task_logs(self, task_id, limit=None):
|
||||
"""获取任务的所有日志
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
limit: 限制返回数量,默认返回所有
|
||||
|
||||
Returns:
|
||||
list: 日志列表,按时间顺序
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
if limit:
|
||||
cursor.execute('''
|
||||
SELECT task_id, timestamp, message, level
|
||||
FROM task_logs
|
||||
WHERE task_id = ?
|
||||
ORDER BY id ASC
|
||||
LIMIT ?
|
||||
''', (task_id, limit))
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT task_id, timestamp, message, level
|
||||
FROM task_logs
|
||||
WHERE task_id = ?
|
||||
ORDER BY id ASC
|
||||
''', (task_id,))
|
||||
|
||||
rows = cursor.fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def clear_task_logs(self, task_id):
|
||||
"""清除任务的所有日志
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM task_logs WHERE task_id = ?', (task_id,))
|
||||
conn.commit()
|
||||
|
||||
def save_articles_batch(self, task_id, articles, page_num):
|
||||
"""批量保存文章到缓存表
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
articles: 文章列表 [{'title': ..., 'url': ..., 'publish_time': ...}, ...]
|
||||
page_num: 页码
|
||||
"""
|
||||
if not articles:
|
||||
return
|
||||
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 批量插入
|
||||
data = [
|
||||
(task_id,
|
||||
article.get('标题', article.get('title', '')),
|
||||
article.get('链接', article.get('url', '')),
|
||||
article.get('发布时间', article.get('publish_time', '')),
|
||||
page_num,
|
||||
timestamp)
|
||||
for article in articles
|
||||
]
|
||||
|
||||
cursor.executemany('''
|
||||
INSERT INTO article_cache (task_id, title, url, publish_time, page_num, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', data)
|
||||
|
||||
conn.commit()
|
||||
logger.debug(f"保存 {len(articles)} 篇文章到缓存(任务{task_id},第{page_num}页)")
|
||||
|
||||
def get_cached_articles(self, task_id):
|
||||
"""获取任务的所有缓存文章
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
list: 文章列表
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT title, url, publish_time, page_num
|
||||
FROM article_cache
|
||||
WHERE task_id = ?
|
||||
ORDER BY id ASC
|
||||
''', (task_id,))
|
||||
|
||||
rows = cursor.fetchall()
|
||||
return [
|
||||
{
|
||||
'标题': row['title'],
|
||||
'链接': row['url'],
|
||||
'发布时间': row['publish_time'],
|
||||
'page_num': row['page_num']
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def get_cached_article_count(self, task_id):
|
||||
"""获取任务已缓存的文章数量
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
int: 文章数量
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) as count FROM article_cache WHERE task_id = ?
|
||||
''', (task_id,))
|
||||
|
||||
result = cursor.fetchone()
|
||||
return result['count'] if result else 0
|
||||
|
||||
def clear_article_cache(self, task_id):
|
||||
"""清除任务的所有缓存文章
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM article_cache WHERE task_id = ?', (task_id,))
|
||||
conn.commit()
|
||||
logger.debug(f"清除任务 {task_id} 的所有缓存文章")
|
||||
|
||||
|
||||
# Process-wide singleton database handle; the lock only guards first creation.
_db_instance = None
_db_lock = threading.Lock()


def get_database():
    """Return the global Database instance, creating it lazily (singleton)."""
    global _db_instance
    # Fast path: already created — no locking needed under the GIL.
    if _db_instance is not None:
        return _db_instance
    with _db_lock:
        # Re-check inside the lock: another thread may have won the race.
        if _db_instance is None:
            _db_instance = Database()
    return _db_instance
|
||||
|
||||
|
||||
def migrate_from_json(json_file="data/task_queue.json"):
    """Migrate legacy JSON task data into the SQLite database.

    Args:
        json_file: path of the old JSON task-queue file.

    Returns:
        migrated_count: number of tasks successfully migrated (0 on error
        or when there is nothing to migrate).
    """
    import json

    if not os.path.exists(json_file):
        logger.info("未找到旧的 JSON 文件,跳过数据迁移")
        return 0

    insert_sql = '''
        INSERT INTO tasks (
            task_id, url, months, use_proxy, proxy_api_url,
            username, status, created_at, started_at, completed_at,
            progress, current_step, total_articles, processed_articles,
            error, result_file
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    '''

    try:
        # Load the legacy task list.
        with open(json_file, 'r', encoding='utf-8') as f:
            legacy_tasks = json.load(f)

        if not legacy_tasks:
            logger.info("JSON 文件中没有任务数据")
            return 0

        db = get_database()
        migrated = 0

        with db.get_connection() as conn:
            cursor = conn.cursor()

            for task in legacy_tasks:
                try:
                    # Skip tasks that were already migrated in a previous run.
                    cursor.execute(
                        "SELECT task_id FROM tasks WHERE task_id = ?",
                        (task["task_id"],),
                    )
                    if cursor.fetchone() is not None:
                        logger.debug(f"任务 {task['task_id']} 已存在,跳过")
                        continue

                    row = (
                        task["task_id"],
                        task["url"],
                        task["months"],
                        1 if task["use_proxy"] else 0,  # bool -> INTEGER column
                        task.get("proxy_api_url"),
                        task.get("username"),
                        task["status"],
                        task["created_at"],
                        task.get("started_at"),
                        task.get("completed_at"),
                        task.get("progress", 0),
                        task.get("current_step"),
                        task.get("total_articles", 0),
                        task.get("processed_articles", 0),
                        task.get("error"),
                        task.get("result_file"),
                    )
                    cursor.execute(insert_sql, row)
                    migrated += 1

                except Exception as e:
                    # A bad record shouldn't abort the whole migration.
                    logger.error(f"迁移任务 {task.get('task_id')} 失败: {e}")

            conn.commit()

        logger.info(f"成功迁移 {migrated} 个任务到数据库")

        # Keep a backup of the original JSON file once anything was migrated.
        if migrated > 0:
            import shutil
            backup_file = json_file + ".backup"
            shutil.copy2(json_file, backup_file)
            logger.info(f"原 JSON 文件已备份到: {backup_file}")

        return migrated

    except Exception as e:
        logger.error(f"数据迁移失败: {e}")
        return 0
|
||||
Reference in New Issue
Block a user