358 lines
14 KiB
Python
358 lines
14 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
任务队列管理模块
|
|||
|
|
支持离线处理、进度跟踪、结果导出
|
|||
|
|
使用 SQLite 数据库存储(替代原 JSON 文件)
|
|||
|
|
"""
|
|||
|
|
import os
|
|||
|
|
import threading
|
|||
|
|
import time
|
|||
|
|
from datetime import datetime
|
|||
|
|
from enum import Enum
|
|||
|
|
import logging
|
|||
|
|
from database import get_database, migrate_from_json
|
|||
|
|
|
|||
|
|
# Module-level logger named after this module
logger = logging.getLogger(name=__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TaskStatus(Enum):
    """Lifecycle states of a queued task; each ``.value`` is what is stored in ``tasks.status``."""

    # Ready: waiting for a worker thread to pick it up
    PENDING = "pending"
    # In progress
    PROCESSING = "processing"
    # Finished successfully
    COMPLETED = "completed"
    # Finished with an error
    FAILED = "failed"
    # Paused; moved back to PENDING automatically after a delay
    PAUSED = "paused"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TaskQueue:
    """Task queue manager backed by SQLite (the database returned by ``get_database()``).

    Every public method serialises its database work through ``self.lock``.
    Task rows are returned as plain dicts with ``use_proxy`` and
    ``articles_only`` converted to ``bool``.
    """

    # Format of every timestamp this class writes (created_at, started_at,
    # completed_at) and the primary format expected when reading paused_at.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    # Seconds a PAUSED task must stay paused before get_pending_task() moves
    # it back to PENDING (previously hard-coded as 10 minutes).
    PAUSE_RESUME_SECONDS = 10 * 60

    def __init__(self, queue_file="data/task_queue.json", results_dir="data/results"):
        """Create the queue.

        Args:
            queue_file: Legacy JSON queue file; if it exists it is handed to
                ``migrate_from_json()``.
            results_dir: Directory for exported results; created if missing.
        """
        self.results_dir = results_dir
        self.lock = threading.Lock()
        self.db = get_database()
        # Last millisecond value issued as a task ID (see _next_task_id).
        self._last_task_ms = 0
        self._ensure_dirs()

        # Migrate data from the legacy JSON file (intended to happen only once;
        # whether repeated calls are a no-op depends on migrate_from_json).
        if os.path.exists(queue_file):
            migrate_from_json(queue_file)

    def _ensure_dirs(self):
        """Make sure the results directory exists."""
        os.makedirs(self.results_dir, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _now_str(cls):
        """Return the current local time formatted with ``_TIME_FORMAT``."""
        return datetime.now().strftime(cls._TIME_FORMAT)

    @classmethod
    def _parse_timestamp(cls, value):
        """Parse a stored timestamp; return a naive datetime, or None if unparseable.

        ``_TIME_FORMAT`` is tried first, then ISO 8601. Timezone-aware ISO
        values are converted to naive local time so they can be compared with
        ``datetime.now()``.
        """
        if not value:
            return None
        try:
            return datetime.strptime(value, cls._TIME_FORMAT)
        except (TypeError, ValueError):
            pass
        try:
            parsed = datetime.fromisoformat(value)
        except (TypeError, ValueError):
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone().replace(tzinfo=None)
        return parsed

    def _next_task_id(self):
        """Return a new task ID of the form ``task_<milliseconds>``.

        The millisecond value is forced to increase strictly, so two tasks
        created within the same millisecond no longer get the same ID.
        Callers must hold ``self.lock``.
        """
        now_ms = int(time.time() * 1000)
        last_ms = getattr(self, "_last_task_ms", 0)
        if now_ms <= last_ms:
            now_ms = last_ms + 1
        self._last_task_ms = now_ms
        return f"task_{now_ms}"

    @staticmethod
    def _row_to_task(row):
        """Convert a DB row into a dict with integer flag columns turned into bools."""
        task = dict(row)
        task['use_proxy'] = bool(task['use_proxy'])
        # If the articles_only column is absent, treat the task as articles-only.
        task['articles_only'] = bool(task.get('articles_only', 1))
        return task

    @staticmethod
    def _active_status_values():
        """Status values of tasks that are still live: they can run or resume."""
        return (
            TaskStatus.PENDING.value,
            TaskStatus.PROCESSING.value,
            TaskStatus.PAUSED.value,
        )

    def _apply_update(self, task_id, fields):
        """Write ``fields`` (column -> value) to the task row. Caller holds ``self.lock``.

        Returns:
            True if a row was updated.

        Raises:
            ValueError: if a column name is not a plain identifier. Column
                names are interpolated into the SQL, so they are validated.
        """
        bad_keys = [key for key in fields if not str(key).isidentifier()]
        if bad_keys:
            raise ValueError(f"Invalid column name(s): {bad_keys!r}")

        set_clause = ", ".join(f"{key} = ?" for key in fields)
        values = [*fields.values(), task_id]

        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"UPDATE tasks SET {set_clause} WHERE task_id = ?", values)
            conn.commit()
            return cursor.rowcount > 0

    # --------------------------------------------------------------- public API

    def add_task(self, url, months=6, use_proxy=False, proxy_api_url=None, username=None, articles_only=True):
        """Add a new task to the queue in PENDING state.

        Args:
            url: Baijiahao URL.
            months: Number of months to fetch.
            use_proxy: Whether to use a proxy.
            proxy_api_url: Proxy API address.
            username: Owner of the task.
            articles_only: Crawl articles only (skip videos).

        Returns:
            The new task ID.
        """
        with self.lock:
            task_id = self._next_task_id()
            created_at = self._now_str()

            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO tasks (
                        task_id, url, months, use_proxy, proxy_api_url,
                        username, status, created_at, progress, current_step,
                        total_articles, processed_articles, articles_only
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    task_id, url, months, 1 if use_proxy else 0, proxy_api_url,
                    username, TaskStatus.PENDING.value, created_at, 0, "等待处理",
                    0, 0, 1 if articles_only else 0
                ))
                conn.commit()

        logger.info("添加任务: %s", task_id)
        return task_id

    def get_task(self, task_id):
        """Return the task dict for ``task_id``, or None if it does not exist."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT * FROM tasks WHERE task_id = ?",
                    (task_id,)
                )
                row = cursor.fetchone()
                return self._row_to_task(row) if row else None

    def get_all_tasks(self, username=None):
        """Return all tasks, newest first, optionally only those owned by ``username``."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT * FROM tasks WHERE username = ? ORDER BY created_at DESC",
                        (username,)
                    )
                else:
                    cursor.execute("SELECT * FROM tasks ORDER BY created_at DESC")
                return [self._row_to_task(row) for row in cursor.fetchall()]

    def _resume_expired_pauses(self, conn, cursor):
        """Move PAUSED tasks paused for >= PAUSE_RESUME_SECONDS back to PENDING.

        The resumed tasks get ``retry_count = 0`` and ``paused_at = NULL``.
        ``last_page``/``last_ctime`` are left untouched so the task can continue
        from its checkpoint. Caller holds ``self.lock``.
        """
        current_time = datetime.now()
        cursor.execute(
            "SELECT * FROM tasks WHERE status = ? ORDER BY paused_at ASC",
            (TaskStatus.PAUSED.value,)
        )
        # Materialise before issuing UPDATEs on the same cursor.
        paused_tasks = [dict(row) for row in cursor.fetchall()]

        resumed = []
        for paused in paused_tasks:
            raw_paused_at = paused.get('paused_at')
            paused_at = self._parse_timestamp(raw_paused_at)
            if paused_at is None:
                if raw_paused_at:
                    # Don't let one bad row break polling for every other task.
                    logger.warning(
                        "任务 %s 的 paused_at 无法解析: %r,跳过自动恢复",
                        paused['task_id'], raw_paused_at
                    )
                continue
            if (current_time - paused_at).total_seconds() < self.PAUSE_RESUME_SECONDS:
                continue

            cursor.execute("""
                UPDATE tasks SET
                    status = ?,
                    current_step = ?,
                    retry_count = 0,
                    paused_at = NULL
                WHERE task_id = ?
            """, (TaskStatus.PENDING.value, "等待处理(从断点继续)", paused['task_id']))
            # dict(...) works for both sqlite3.Row and dict rows; sqlite3.Row
            # itself has no .get(), which the previous code relied on.
            resumed.append((paused['task_id'], paused.get('last_page') or 1))

        if resumed:
            conn.commit()
            for task_id, last_page in resumed:
                logger.info("任务 %s 已从暂停状态恢复,将从第%s页继续", task_id, last_page)

    def get_pending_task(self):
        """Return the oldest PENDING task, or None.

        PAUSED tasks whose pause has expired are first moved back to PENDING
        (see ``PAUSE_RESUME_SECONDS``), so they can be returned by this call.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                self._resume_expired_pauses(conn, cursor)

                # task_id breaks ties between tasks created in the same second.
                cursor.execute(
                    "SELECT * FROM tasks WHERE status = ? ORDER BY created_at ASC, task_id ASC LIMIT 1",
                    (TaskStatus.PENDING.value,)
                )
                row = cursor.fetchone()
                return self._row_to_task(row) if row else None

    def update_task_status(self, task_id, status, **kwargs):
        """Update a task's status and any extra columns.

        Entering PROCESSING sets ``started_at``. Entering COMPLETED or FAILED
        sets ``completed_at``. This applies whether ``status`` is given as a
        TaskStatus member or as its string value. Columns passed in
        ``kwargs`` override those timestamps.

        Args:
            task_id: Task ID.
            status: New status (TaskStatus or its string value).
            **kwargs: Additional columns to update.

        Returns:
            True if the task row was updated.

        Raises:
            ValueError: if a kwargs key is not a valid column identifier.
        """
        if isinstance(status, TaskStatus):
            status_enum = status
        else:
            try:
                status_enum = TaskStatus(status)
            except (ValueError, TypeError):
                status_enum = None  # unknown value: stored as-is, no timestamps
        status_value = status_enum.value if status_enum is not None else status

        with self.lock:
            update_fields = {"status": status_value}
            if status_enum is TaskStatus.PROCESSING:
                update_fields["started_at"] = self._now_str()
            elif status_enum in (TaskStatus.COMPLETED, TaskStatus.FAILED):
                update_fields["completed_at"] = self._now_str()
            update_fields.update(kwargs)
            return self._apply_update(task_id, update_fields)

    def update_task_progress(self, task_id, progress, current_step=None, processed_articles=None):
        """Update task progress.

        Args:
            task_id: Task ID.
            progress: Progress percentage, clamped to 0-100.
            current_step: Description of the current step (unchanged if None).
            processed_articles: Number of processed articles (unchanged if None).

        Returns:
            True if the task row was updated.
        """
        with self.lock:
            update_fields = {"progress": min(100, max(0, progress))}
            if current_step is not None:
                update_fields["current_step"] = current_step
            if processed_articles is not None:
                update_fields["processed_articles"] = processed_articles
            return self._apply_update(task_id, update_fields)

    def get_queue_stats(self, username=None):
        """Return task counts: total plus one entry per status, optionally for one user."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT status, COUNT(*) as count FROM tasks WHERE username = ? GROUP BY status",
                        (username,)
                    )
                else:
                    cursor.execute("SELECT status, COUNT(*) as count FROM tasks GROUP BY status")
                status_counts = {row["status"]: row["count"] for row in cursor.fetchall()}

        # The per-status groups partition the same row set, so their sum is the total.
        return {
            "total": sum(status_counts.values()),
            "pending": status_counts.get(TaskStatus.PENDING.value, 0),
            "processing": status_counts.get(TaskStatus.PROCESSING.value, 0),
            "completed": status_counts.get(TaskStatus.COMPLETED.value, 0),
            "failed": status_counts.get(TaskStatus.FAILED.value, 0),
            "paused": status_counts.get(TaskStatus.PAUSED.value, 0),
        }

    def delete_task(self, task_id):
        """Delete a task, first marking it FAILED if it is still live.

        Returns:
            True if the task existed and was deleted, False otherwise.
        """
        deleted = False
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT status FROM tasks WHERE task_id = ?", (task_id,))
                row = cursor.fetchone()
                if not row:
                    return False

                # The termination is committed before the delete. NOTE(review):
                # presumably so another connection reading the row in between sees
                # a terminal status. Confirm this is still needed.
                if row["status"] in self._active_status_values():
                    cursor.execute('''
                        UPDATE tasks SET
                            status = ?,
                            error = ?,
                            current_step = ?,
                            completed_at = ?
                        WHERE task_id = ?
                    ''', (
                        TaskStatus.FAILED.value,
                        "任务已被用户删除",
                        "任务已终止",
                        self._now_str(),
                        task_id
                    ))
                    conn.commit()
                    logger.info("终止任务: %s", task_id)

                cursor.execute("DELETE FROM tasks WHERE task_id = ?", (task_id,))
                conn.commit()
                deleted = cursor.rowcount > 0

        if deleted:
            logger.info("删除任务: %s", task_id)
        return deleted

    def cancel_task(self, task_id):
        """Terminate a pending, processing or paused task by marking it FAILED.

        Paused tasks are included because otherwise they would auto-resume
        after being "cancelled". The check and the update happen in one
        conditional UPDATE, so they are atomic.

        Returns:
            True if a live task was terminated, False if it does not exist or
            has already finished.
        """
        active = self._active_status_values()
        placeholders = ", ".join(["?"] * len(active))
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(f'''
                    UPDATE tasks SET
                        status = ?,
                        error = ?,
                        current_step = ?,
                        completed_at = ?
                    WHERE task_id = ? AND status IN ({placeholders})
                ''', (
                    TaskStatus.FAILED.value,
                    "任务已被用户终止",
                    "任务已终止",
                    self._now_str(),
                    task_id,
                    *active
                ))
                conn.commit()
                return cursor.rowcount > 0
|
|||
|
|
|
|||
|
|
# 全局队列实例
|
|||
|
|
# Process-wide queue instance, created lazily by get_task_queue()
_task_queue = None
# Guards creation of _task_queue so concurrent first calls share one instance
# (and therefore one TaskQueue.lock) instead of each building their own.
_task_queue_lock = threading.Lock()


def get_task_queue():
    """Return the global TaskQueue, creating it on first use (thread-safe)."""
    global _task_queue
    if _task_queue is None:
        with _task_queue_lock:
            # Re-check: another thread may have created it while we waited.
            if _task_queue is None:
                _task_queue = TaskQueue()
    return _task_queue
|