# -*- coding: utf-8 -*-
"""
Task queue management module.

Supports offline processing, progress tracking and result export.
Tasks are stored in an SQLite database (replacing the former JSON file).
"""
import logging
import os
import threading
import time
from datetime import datetime, timedelta
from enum import Enum

from database import get_database, migrate_from_json

logger = logging.getLogger(__name__)

# Timestamp format used for every *_at column written and parsed by this module.
_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"


def _now_str():
    """Return the current local time formatted with ``_TIME_FORMAT``."""
    return datetime.now().strftime(_TIME_FORMAT)


class TaskStatus(Enum):
    """Task status."""
    PENDING = "pending"        # Ready, waiting for a worker thread
    PROCESSING = "processing"  # In progress
    COMPLETED = "completed"    # Finished
    FAILED = "failed"          # Failed
    PAUSED = "paused"          # Paused (resumed automatically after a delay)


# Statuses from which a task can still be terminated by the user. PAUSED is
# included because get_pending_task() resumes paused tasks automatically;
# leaving it out would make a paused task impossible to stop.
_ACTIVE_STATUSES = (
    TaskStatus.PENDING.value,
    TaskStatus.PROCESSING.value,
    TaskStatus.PAUSED.value,
)


class TaskQueue:
    """Task queue manager backed by the SQLite database."""

    # How long a task stays PAUSED before get_pending_task() resumes it.
    PAUSE_RESUME_DELAY = timedelta(minutes=10)

    def __init__(self, queue_file="data/task_queue.json", results_dir="data/results"):
        """Create the queue.

        Args:
            queue_file: Path of the legacy JSON queue file. If it exists, it is
                passed to ``migrate_from_json``.
            results_dir: Directory for result files; created if missing.
        """
        self.results_dir = results_dir
        self.lock = threading.Lock()
        # Millisecond timestamp used by the most recent task ID (see _next_task_id).
        self._last_task_ms = 0
        self.db = get_database()
        self._ensure_dirs()

        # Migrate data from the legacy JSON file. The original note says this
        # runs only once; presumably migrate_from_json is idempotent or removes
        # the file -- TODO confirm.
        if os.path.exists(queue_file):
            migrate_from_json(queue_file)

    def _ensure_dirs(self):
        """Make sure the required directories exist."""
        os.makedirs(self.results_dir, exist_ok=True)

    @staticmethod
    def _row_to_task(row):
        """Convert a database row into a plain dict with boolean flags.

        SQLite stores booleans as integers, so ``use_proxy`` and
        ``articles_only`` are converted back to ``bool`` (``articles_only``
        defaults to True when the column is absent).
        """
        task = dict(row)
        task['use_proxy'] = bool(task['use_proxy'])
        task['articles_only'] = bool(task.get('articles_only', 1))
        return task

    def _next_task_id(self):
        """Return a unique ``task_<milliseconds>`` ID. Caller must hold ``self.lock``.

        A plain millisecond timestamp collides when two tasks are added within
        the same millisecond, so the timestamp is bumped past the last one used;
        IDs therefore stay unique and strictly increasing within this instance.
        """
        task_ms = max(int(time.time() * 1000), self._last_task_ms + 1)
        self._last_task_ms = task_ms
        return f"task_{task_ms}"

    def add_task(self, url, months=6, use_proxy=False, proxy_api_url=None,
                 username=None, articles_only=True):
        """Add a new task to the queue.

        Args:
            url: Baijiahao URL.
            months: Number of months to fetch.
            use_proxy: Whether to use a proxy.
            proxy_api_url: Proxy API address.
            username: Owning user name.
            articles_only: Only crawl articles (skip videos).

        Returns:
            str: The new task ID.
        """
        with self.lock:
            task_id = self._next_task_id()
            created_at = _now_str()

            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO tasks (
                        task_id, url, months, use_proxy, proxy_api_url,
                        username, status, created_at, progress, current_step,
                        total_articles, processed_articles, articles_only
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    task_id,
                    url,
                    months,
                    1 if use_proxy else 0,
                    proxy_api_url,
                    username,
                    TaskStatus.PENDING.value,
                    created_at,
                    0,
                    "等待处理",
                    0,
                    0,
                    1 if articles_only else 0,
                ))
                conn.commit()

            logger.info("添加任务: %s", task_id)
            return task_id

    def get_task(self, task_id):
        """Return the task with ``task_id`` as a dict, or None if it does not exist."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,))
                row = cursor.fetchone()
                return self._row_to_task(row) if row else None

    def get_all_tasks(self, username=None):
        """Return all tasks, newest first, optionally filtered by ``username``.

        A falsy ``username`` (None or "") means no filter. ``task_id`` breaks
        ties between tasks created in the same second.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT * FROM tasks WHERE username = ? "
                        "ORDER BY created_at DESC, task_id DESC",
                        (username,)
                    )
                else:
                    cursor.execute(
                        "SELECT * FROM tasks ORDER BY created_at DESC, task_id DESC"
                    )
                return [self._row_to_task(row) for row in cursor.fetchall()]

    def get_pending_task(self):
        """Return the oldest pending task as a dict, or None if there is none.

        Before looking, paused tasks whose pause has lasted at least
        ``PAUSE_RESUME_DELAY`` are switched back to pending. The returned task
        is NOT marked as processing here.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                self._resume_expired_paused_tasks(conn, cursor)

                cursor.execute(
                    "SELECT * FROM tasks WHERE status = ? "
                    "ORDER BY created_at ASC, task_id ASC LIMIT 1",
                    (TaskStatus.PENDING.value,)
                )
                row = cursor.fetchone()
                return self._row_to_task(row) if row else None

    def _resume_expired_paused_tasks(self, conn, cursor):
        """Set paused tasks back to pending once their pause delay has elapsed.

        Caller must hold ``self.lock``. ``last_page`` / ``last_ctime`` are not
        touched, so a resumed task continues from its checkpoint.
        """
        current_time = datetime.now()
        cursor.execute(
            "SELECT * FROM tasks WHERE status = ? ORDER BY paused_at ASC",
            (TaskStatus.PAUSED.value,)
        )
        # fetchall() first: the same cursor is reused for the UPDATEs below.
        for row in cursor.fetchall():
            # Convert to a dict: sqlite3.Row supports ['key'] but has no .get().
            paused_task = dict(row)
            task_id = paused_task['task_id']
            paused_at_str = paused_task.get('paused_at')
            if not paused_at_str:
                continue
            try:
                paused_at = datetime.strptime(paused_at_str, _TIME_FORMAT)
            except (TypeError, ValueError):
                # A malformed timestamp must not make every poll raise.
                logger.warning("任务 %s 的暂停时间格式无效: %r", task_id, paused_at_str)
                continue
            if current_time - paused_at < self.PAUSE_RESUME_DELAY:
                continue

            cursor.execute("""
                UPDATE tasks
                SET status = ?, current_step = ?, retry_count = 0, paused_at = NULL
                WHERE task_id = ?
            """, (TaskStatus.PENDING.value, "等待处理(从断点继续)", task_id))
            conn.commit()

            last_page = paused_task.get('last_page')
            if last_page is None:
                last_page = 1
            logger.info("任务 %s 已从暂停状态恢复,将从第%s页继续", task_id, last_page)

    def _update_fields(self, task_id, update_fields):
        """Write ``update_fields`` (column name -> value) to the task's row.

        Column names are interpolated into the SQL text, so they must come from
        trusted code, never from user input; values are bound as parameters.

        Returns:
            bool: True if a row was updated.
        """
        set_clause = ", ".join(f"{key} = ?" for key in update_fields)
        values = [*update_fields.values(), task_id]
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(f"UPDATE tasks SET {set_clause} WHERE task_id = ?", values)
                conn.commit()
                return cursor.rowcount > 0

    def update_task_status(self, task_id, status, **kwargs):
        """Update a task's status.

        Args:
            task_id: Task ID.
            status: New status, as a ``TaskStatus`` or its string value.
            **kwargs: Extra columns to update. They are applied last, so they
                override the status and timestamp columns set here.

        Returns:
            bool: True if a row was updated.
        """
        status_value = status.value if isinstance(status, TaskStatus) else status
        update_fields = {"status": status_value}

        # Compare the normalized value so string statuses ("processing", ...)
        # get their timestamps too, not only TaskStatus members.
        if status_value == TaskStatus.PROCESSING.value:
            update_fields["started_at"] = _now_str()
        elif status_value in (TaskStatus.COMPLETED.value, TaskStatus.FAILED.value):
            update_fields["completed_at"] = _now_str()

        update_fields.update(kwargs)
        return self._update_fields(task_id, update_fields)

    def update_task_progress(self, task_id, progress, current_step=None, processed_articles=None):
        """Update a task's progress.

        Args:
            task_id: Task ID.
            progress: Progress percentage; clamped to 0-100.
            current_step: Description of the current step (unchanged if None).
            processed_articles: Number of processed articles (unchanged if None).

        Returns:
            bool: True if a row was updated.
        """
        update_fields = {"progress": min(100, max(0, progress))}
        if current_step is not None:
            update_fields["current_step"] = current_step
        if processed_articles is not None:
            update_fields["processed_articles"] = processed_articles
        return self._update_fields(task_id, update_fields)

    def get_queue_stats(self, username=None):
        """Return task counts: ``total`` plus one entry per status value.

        A falsy ``username`` means all users.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT status, COUNT(*) as count FROM tasks "
                        "WHERE username = ? GROUP BY status",
                        (username,)
                    )
                else:
                    cursor.execute(
                        "SELECT status, COUNT(*) as count FROM tasks GROUP BY status"
                    )
                status_counts = {row["status"]: row["count"] for row in cursor.fetchall()}

        # The groups cover every row matching the filter, so their sum equals
        # COUNT(*) -- and comes from the same snapshot as the per-status counts.
        stats = {"total": sum(status_counts.values())}
        stats.update({s.value: status_counts.get(s.value, 0) for s in TaskStatus})
        return stats

    @staticmethod
    def _mark_terminated(cursor, task_id, error):
        """Mark a task FAILED with ``error`` (not committed); return True if a row changed."""
        cursor.execute('''
            UPDATE tasks
            SET status = ?, error = ?, current_step = ?, completed_at = ?
            WHERE task_id = ?
        ''', (
            TaskStatus.FAILED.value,
            error,
            "任务已终止",
            _now_str(),
            task_id,
        ))
        return cursor.rowcount > 0

    def delete_task(self, task_id):
        """Delete a task, terminating it first if it is still active.

        Returns:
            bool: True if the task existed and was deleted.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()

                cursor.execute("SELECT status FROM tasks WHERE task_id = ?", (task_id,))
                row = cursor.fetchone()
                if not row:
                    return False

                # Still active: terminate first.
                if row["status"] in _ACTIVE_STATUSES:
                    self._mark_terminated(cursor, task_id, "任务已被用户删除")
                    conn.commit()
                    logger.info("终止任务: %s", task_id)

                cursor.execute("DELETE FROM tasks WHERE task_id = ?", (task_id,))
                conn.commit()
                if cursor.rowcount > 0:
                    logger.info("删除任务: %s", task_id)
                    return True
                return False

    def cancel_task(self, task_id):
        """Terminate a task: mark a pending, processing or paused task as failed.

        Returns:
            bool: True if the task was active and has been marked failed.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT status FROM tasks WHERE task_id = ?", (task_id,))
                row = cursor.fetchone()
                if not row or row["status"] not in _ACTIVE_STATUSES:
                    return False

                cancelled = self._mark_terminated(cursor, task_id, "任务已被用户终止")
                conn.commit()
                return cancelled


# Global queue instance, created lazily by get_task_queue().
_task_queue = None
_task_queue_lock = threading.Lock()


def get_task_queue():
    """Return the process-wide TaskQueue, creating it on first use.

    Creation is guarded by a lock so concurrent first calls cannot build two
    instances (each with its own lock, which would defeat serialization).
    """
    global _task_queue
    if _task_queue is None:
        with _task_queue_lock:
            if _task_queue is None:
                _task_queue = TaskQueue()
    return _task_queue