358 lines
14 KiB
Python
358 lines
14 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
任务队列管理模块
|
||
支持离线处理、进度跟踪、结果导出
|
||
使用 SQLite 数据库存储(替代原 JSON 文件)
|
||
"""
|
||
import os
|
||
import threading
|
||
import time
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
import logging
|
||
from database import get_database, migrate_from_json
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TaskStatus(Enum):
    """Lifecycle states of a queued task; each value is the string stored in the database."""

    # Ready: prepared and waiting for a worker thread.
    PENDING = "pending"
    # A worker is currently handling the task.
    PROCESSING = "processing"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
    # Paused; resumed automatically after a set time.
    PAUSED = "paused"
|
||
|
||
|
||
class TaskQueue:
    """Task queue manager backed by the shared SQLite database.

    Every public method that touches the database does so while holding
    ``self.lock`` and through a connection obtained from
    ``self.db.get_connection()``.  Tasks are handed to callers as plain dicts
    keyed by column name, with the ``use_proxy`` and ``articles_only``
    integer flags converted to ``bool``.
    """

    # strftime/strptime pattern for every timestamp this class writes or parses.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
    # Minutes a paused task must wait before get_pending_task() re-queues it.
    PAUSE_RESUME_MINUTES = 10

    def __init__(self, queue_file="data/task_queue.json", results_dir="data/results"):
        """Bind to the shared database and make sure the results directory exists.

        Args:
            queue_file: Path of the legacy JSON queue file.  When the file
                exists it is handed to ``migrate_from_json``.
            results_dir: Directory that is created if it is missing.
        """
        self.results_dir = results_dir
        self.lock = threading.Lock()
        self.db = get_database()
        self._ensure_dirs()

        # Import tasks from the legacy JSON file when one is present.
        # NOTE(review): the migration is meant to happen only once; that has
        # to be guaranteed by migrate_from_json itself -- confirm.
        if os.path.exists(queue_file):
            migrate_from_json(queue_file)

    def _ensure_dirs(self):
        """Create the results directory if it does not exist yet."""
        os.makedirs(self.results_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @classmethod
    def _now_str(cls):
        """Return the current local time formatted with ``_TIME_FORMAT``."""
        return datetime.now().strftime(cls._TIME_FORMAT)

    @staticmethod
    def _row_to_task(row):
        """Convert a database row into a task dict with boolean flags."""
        task = dict(row)
        # The flags are stored as integers; expose them as booleans.
        task['use_proxy'] = bool(task['use_proxy'])
        # A missing articles_only column is treated as enabled.
        task['articles_only'] = bool(task.get('articles_only', 1))
        return task

    @staticmethod
    def _is_active(status_value):
        """Return True for the statuses that cancel/delete must terminate."""
        return status_value in (TaskStatus.PENDING.value, TaskStatus.PROCESSING.value)

    def _mark_terminated(self, cursor, task_id, reason):
        """Flag one task as failed with ``reason`` as its error text (no commit)."""
        cursor.execute('''
            UPDATE tasks SET
                status = ?,
                error = ?,
                current_step = ?,
                completed_at = ?
            WHERE task_id = ?
        ''', (
            TaskStatus.FAILED.value,
            reason,
            "任务已终止",
            self._now_str(),
            task_id
        ))

    def _update_columns(self, task_id, fields):
        """Write ``fields`` (column name -> value) to one task row.

        Column names are interpolated into the SQL text, so each one must be
        a plain identifier; values are always bound as parameters.

        Returns:
            True if the UPDATE matched at least one row.

        Raises:
            ValueError: If a column name is not a valid identifier.
        """
        for column in fields:
            if not (isinstance(column, str) and column.isidentifier()):
                raise ValueError(f"invalid column name: {column!r}")

        assignments = ", ".join(f"{column} = ?" for column in fields)
        params = [*fields.values(), task_id]

        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    f"UPDATE tasks SET {assignments} WHERE task_id = ?",
                    params
                )
                conn.commit()
                return cursor.rowcount > 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_task(self, url, months=6, use_proxy=False, proxy_api_url=None, username=None, articles_only=True):
        """Insert a new pending task into the queue.

        Args:
            url: Baijiahao URL to process.
            months: Number of months to fetch.
            use_proxy: Whether a proxy should be used.
            proxy_api_url: Proxy API address.
            username: Owner of the task.
            articles_only: Whether to crawl articles only (skip videos).

        Returns:
            The generated task ID.
        """
        with self.lock:
            # NOTE(review): the ID is the current time in milliseconds, so two
            # tasks created within the same millisecond would get the same ID.
            task_id = f"task_{int(time.time() * 1000)}"
            created_at = self._now_str()

            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO tasks (
                        task_id, url, months, use_proxy, proxy_api_url,
                        username, status, created_at, progress, current_step,
                        total_articles, processed_articles, articles_only
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    task_id, url, months, 1 if use_proxy else 0, proxy_api_url,
                    username, TaskStatus.PENDING.value, created_at, 0, "等待处理",
                    0, 0, 1 if articles_only else 0
                ))
                conn.commit()

            logger.info("添加任务: %s", task_id)
            return task_id

    def get_task(self, task_id):
        """Return the task with the given ID as a dict, or None if it does not exist."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT * FROM tasks WHERE task_id = ?",
                    (task_id,)
                )
                row = cursor.fetchone()
                return self._row_to_task(row) if row else None

    def get_all_tasks(self, username=None):
        """Return all tasks, newest first, optionally restricted to one user."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT * FROM tasks WHERE username = ? ORDER BY created_at DESC",
                        (username,)
                    )
                else:
                    cursor.execute("SELECT * FROM tasks ORDER BY created_at DESC")
                return [self._row_to_task(row) for row in cursor.fetchall()]

    def get_pending_task(self):
        """Return the oldest pending task, or None when nothing is pending.

        Before looking for pending work, every paused task whose ``paused_at``
        is at least ``PAUSE_RESUME_MINUTES`` old is switched back to pending
        (its retry counter is reset and ``paused_at`` is cleared).  Paused
        tasks without a ``paused_at`` value are left untouched.
        """
        # timedelta is not among the module-level imports, so import it here.
        from datetime import timedelta

        resume_after = timedelta(minutes=self.PAUSE_RESUME_MINUTES)

        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                now = datetime.now()

                # Step 1: re-queue paused tasks whose waiting time has elapsed.
                cursor.execute(
                    "SELECT * FROM tasks WHERE status = ? ORDER BY paused_at ASC",
                    (TaskStatus.PAUSED.value,)
                )
                for row in cursor.fetchall():
                    # Work on a dict: database row objects such as sqlite3.Row
                    # do not provide .get().
                    paused = dict(row)
                    paused_at_str = paused['paused_at']
                    if not paused_at_str:
                        continue
                    paused_at = datetime.strptime(paused_at_str, self._TIME_FORMAT)
                    if now - paused_at < resume_after:
                        continue

                    task_id = paused['task_id']
                    # Only the columns below change, so last_page / last_ctime
                    # keep their values and the task continues where it stopped.
                    cursor.execute("""
                        UPDATE tasks SET
                            status = ?,
                            current_step = ?,
                            retry_count = 0,
                            paused_at = NULL
                        WHERE task_id = ?
                    """, (TaskStatus.PENDING.value, "等待处理(从断点继续)", task_id))
                    conn.commit()
                    logger.info(
                        "任务 %s 已从暂停状态恢复,将从第%s页继续",
                        task_id, paused.get('last_page', 1)
                    )

                # Step 2: hand out the oldest pending task.
                cursor.execute(
                    "SELECT * FROM tasks WHERE status = ? ORDER BY created_at ASC LIMIT 1",
                    (TaskStatus.PENDING.value,)
                )
                row = cursor.fetchone()
                return self._row_to_task(row) if row else None

    def update_task_status(self, task_id, status, **kwargs):
        """Set a task's status and any extra columns.

        ``started_at`` is stamped when the new status is "processing" and
        ``completed_at`` when it is "completed" or "failed"; this applies
        whether ``status`` is given as a ``TaskStatus`` member or as its
        string value.

        Args:
            task_id: Task ID.
            status: New status (``TaskStatus`` member or its string value).
            **kwargs: Additional columns to update; these are applied last and
                may override the automatic timestamps.

        Returns:
            True if a task row was updated.

        Raises:
            ValueError: If a keyword name is not a valid column identifier.
        """
        status_value = status.value if isinstance(status, TaskStatus) else status
        fields = {"status": status_value}

        # Compare the normalised value so string statuses get timestamps too.
        if status_value == TaskStatus.PROCESSING.value:
            fields["started_at"] = self._now_str()
        elif status_value in (TaskStatus.COMPLETED.value, TaskStatus.FAILED.value):
            fields["completed_at"] = self._now_str()

        fields.update(kwargs)
        return self._update_columns(task_id, fields)

    def update_task_progress(self, task_id, progress, current_step=None, processed_articles=None):
        """Update a task's progress.

        Args:
            task_id: Task ID.
            progress: Progress percentage; clamped to the range 0-100.
            current_step: Description of the current step (unchanged if None).
            processed_articles: Number of processed articles (unchanged if None).

        Returns:
            True if a task row was updated.
        """
        fields = {"progress": min(100, max(0, progress))}
        if current_step is not None:
            fields["current_step"] = current_step
        if processed_articles is not None:
            fields["processed_articles"] = processed_articles
        return self._update_columns(task_id, fields)

    def get_queue_stats(self, username=None):
        """Return task counts: ``total`` plus one entry per status, optionally for one user."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT status, COUNT(*) as count FROM tasks WHERE username = ? GROUP BY status",
                        (username,)
                    )
                else:
                    cursor.execute("SELECT status, COUNT(*) as count FROM tasks GROUP BY status")
                status_counts = {row["status"]: row["count"] for row in cursor.fetchall()}

        # Each row falls into exactly one status group, so the group sizes
        # add up to the overall row count.
        return {
            "total": sum(status_counts.values()),
            "pending": status_counts.get(TaskStatus.PENDING.value, 0),
            "processing": status_counts.get(TaskStatus.PROCESSING.value, 0),
            "completed": status_counts.get(TaskStatus.COMPLETED.value, 0),
            "failed": status_counts.get(TaskStatus.FAILED.value, 0),
            "paused": status_counts.get(TaskStatus.PAUSED.value, 0)
        }

    def delete_task(self, task_id):
        """Delete a task, first marking it as terminated if it is pending or processing.

        Returns:
            True if a row was deleted, False if the task does not exist.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()

                cursor.execute("SELECT status FROM tasks WHERE task_id = ?", (task_id,))
                row = cursor.fetchone()
                if not row:
                    return False

                # A task that is still active is flagged as failed and that
                # state is committed before the row is removed.
                if self._is_active(row["status"]):
                    self._mark_terminated(cursor, task_id, "任务已被用户删除")
                    conn.commit()
                    logger.info("终止任务: %s", task_id)

                cursor.execute("DELETE FROM tasks WHERE task_id = ?", (task_id,))
                conn.commit()

                if cursor.rowcount > 0:
                    logger.info("删除任务: %s", task_id)
                    return True
                return False

    def cancel_task(self, task_id):
        """Terminate a task by marking it as failed.

        Only pending or processing tasks are affected.

        Returns:
            True if the task was marked as failed, otherwise False.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()

                cursor.execute("SELECT status FROM tasks WHERE task_id = ?", (task_id,))
                row = cursor.fetchone()
                if not row or not self._is_active(row["status"]):
                    return False

                self._mark_terminated(cursor, task_id, "任务已被用户终止")
                conn.commit()
                return cursor.rowcount > 0
|
||
|
||
# Global task queue instance
|
||
_task_queue = None  # set to a TaskQueue by get_task_queue() on first use
|
||
|
||
|
||
def get_task_queue():
    """Return the module-wide TaskQueue, constructing it on the first call."""
    global _task_queue
    # Fast path: the instance was already built by an earlier call.
    if _task_queue is not None:
        return _task_queue
    _task_queue = TaskQueue()
    return _task_queue
|