Files
ai_baijiahao/task_queue.py

358 lines
14 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
"""
任务队列管理模块
支持离线处理、进度跟踪、结果导出
使用 SQLite 数据库存储替代原 JSON 文件
"""
import logging
import os
import threading
import time
from datetime import datetime, timedelta
from enum import Enum

from database import get_database, migrate_from_json
# Module-level logger; TaskQueue uses it to record task lifecycle events
# (e.g. additions, resumptions from PAUSED, terminations and deletions).
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
    """Task status.

    Each member's ``.value`` is the string TaskQueue writes into (and
    filters on in) the ``status`` column of the ``tasks`` table.
    """
    PENDING = "pending"  # ready; waiting for a worker thread to pick it up
    PROCESSING = "processing"  # in progress
    COMPLETED = "completed"  # finished
    FAILED = "failed"  # failed, or terminated/deleted by the user
    PAUSED = "paused"  # paused; TaskQueue.get_pending_task re-queues it after a delay (10 minutes)
class TaskQueue:
    """Task queue manager backed by the SQLite ``tasks`` table.

    Every public method serializes its database work through ``self.lock``
    (a non-reentrant ``threading.Lock``). Private helpers documented as
    "must be called with self.lock held" therefore never acquire it again.
    """

    # How long a PAUSED task waits before get_pending_task() re-queues it.
    PAUSE_RESUME_DELAY = timedelta(minutes=10)
    # Format of every *_at timestamp column written or parsed by this class.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, queue_file="data/task_queue.json", results_dir="data/results"):
        """Open the database, ensure directories exist and migrate legacy data.

        Args:
            queue_file: path of the legacy JSON queue file; passed to
                ``migrate_from_json`` whenever the file exists.
            results_dir: directory for result files; created if missing.
        """
        self.results_dir = results_dir
        self.lock = threading.Lock()
        self.db = get_database()
        # Last millisecond value handed out as a task ID (see _next_task_id).
        self._last_task_ms = 0
        self._ensure_dirs()
        # Migrate data from the legacy JSON file.
        # NOTE(review): this runs on every start-up while the file exists;
        # presumably migrate_from_json is idempotent or retires the file — TODO confirm.
        if os.path.exists(queue_file):
            migrate_from_json(queue_file)

    def _ensure_dirs(self):
        """Create the results directory if it does not exist yet."""
        os.makedirs(self.results_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @classmethod
    def _now_str(cls):
        """Return the current local time formatted with ``_TIME_FORMAT``."""
        return datetime.now().strftime(cls._TIME_FORMAT)

    def _next_task_id(self):
        """Return a new ``task_<milliseconds>`` ID that is unique within this instance.

        Must be called with ``self.lock`` held. Using the raw millisecond
        clock alone would hand out identical IDs for two tasks added within
        the same millisecond, so the value is bumped past the last one issued.
        """
        now_ms = int(time.time() * 1000)
        last_ms = getattr(self, "_last_task_ms", 0)
        if now_ms <= last_ms:
            now_ms = last_ms + 1
        self._last_task_ms = now_ms
        return f"task_{now_ms}"

    @staticmethod
    def _row_to_task(row):
        """Convert a ``tasks`` row into a plain dict with boolean flags.

        ``use_proxy`` and ``articles_only`` are stored as 0/1 integers. A
        missing or NULL ``articles_only`` is treated as True, matching the
        ``articles_only=True`` default of add_task().
        """
        task = dict(row)
        task['use_proxy'] = bool(task.get('use_proxy'))
        articles_only = task.get('articles_only')
        task['articles_only'] = True if articles_only is None else bool(articles_only)
        return task

    @staticmethod
    def _terminable_statuses():
        """Return the status values of tasks that are still alive and may be terminated.

        PAUSED is included because get_pending_task() automatically re-queues
        paused tasks; leaving it out meant a paused task could not be
        cancelled and would silently resume later.
        """
        return (
            TaskStatus.PENDING.value,
            TaskStatus.PROCESSING.value,
            TaskStatus.PAUSED.value,
        )

    def _execute_update(self, update_fields, task_id):
        """Run ``UPDATE tasks SET <fields> WHERE task_id = ?`` and commit.

        Column names are interpolated into the SQL text, so each must be a
        plain identifier; all values are bound as parameters. Must be called
        with ``self.lock`` held.

        Args:
            update_fields: mapping of column name -> new value.
            task_id: ID of the task row to update.

        Returns:
            True if a row was updated.

        Raises:
            ValueError: if a column name is not a valid identifier.
        """
        for column in update_fields:
            if not column.isidentifier():
                raise ValueError(f"invalid column name: {column!r}")
        set_clause = ", ".join(f"{column} = ?" for column in update_fields)
        values = [*update_fields.values(), task_id]
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"UPDATE tasks SET {set_clause} WHERE task_id = ?", values)
            conn.commit()
            return cursor.rowcount > 0

    def _resume_expired_paused_tasks(self, conn, cursor):
        """Switch PAUSED tasks back to PENDING once PAUSE_RESUME_DELAY has elapsed.

        Must be called with ``self.lock`` held. Only ``status``,
        ``current_step``, ``retry_count`` and ``paused_at`` are reset, so
        checkpoint columns such as ``last_page``/``last_ctime`` are kept.
        Rows whose ``paused_at`` is empty are left alone. Rows whose
        ``paused_at`` cannot be parsed are logged and skipped rather than
        aborting the whole lookup.
        """
        now = datetime.now()
        cursor.execute(
            "SELECT * FROM tasks WHERE status = ? ORDER BY paused_at ASC",
            (TaskStatus.PAUSED.value,)
        )
        # Convert to dicts so .get() works whatever the row factory is
        # (sqlite3.Row supports indexing and keys(), but has no .get method).
        paused_tasks = [dict(row) for row in cursor.fetchall()]
        for task in paused_tasks:
            task_id = task['task_id']
            paused_at_str = task.get('paused_at')
            if not paused_at_str:
                continue
            try:
                paused_at = datetime.strptime(paused_at_str, self._TIME_FORMAT)
            except (TypeError, ValueError):
                logger.warning(f"任务 {task_id} 的 paused_at 无法解析: {paused_at_str!r},跳过自动恢复")
                continue
            if now - paused_at < self.PAUSE_RESUME_DELAY:
                continue
            cursor.execute("""
                UPDATE tasks SET
                    status = ?,
                    current_step = ?,
                    retry_count = 0,
                    paused_at = NULL
                WHERE task_id = ?
            """, (TaskStatus.PENDING.value, "等待处理(从断点继续)", task_id))
            conn.commit()
            last_page = task.get('last_page')
            if last_page is None:
                last_page = 1
            logger.info(f"任务 {task_id} 已从暂停状态恢复,将从第{last_page}页继续")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_task(self, url, months=6, use_proxy=False, proxy_api_url=None, username=None, articles_only=True):
        """Add a new task to the queue in PENDING state.

        Args:
            url: Baijiahao URL to crawl.
            months: number of months to fetch.
            use_proxy: whether to use a proxy.
            proxy_api_url: proxy API address.
            username: owning user name.
            articles_only: whether to crawl articles only and skip videos.

        Returns:
            task_id: ``task_<milliseconds>`` ID, unique within this queue instance.
        """
        with self.lock:
            task_id = self._next_task_id()
            created_at = self._now_str()
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO tasks (
                        task_id, url, months, use_proxy, proxy_api_url,
                        username, status, created_at, progress, current_step,
                        total_articles, processed_articles, articles_only
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    task_id, url, months, 1 if use_proxy else 0, proxy_api_url,
                    username, TaskStatus.PENDING.value, created_at, 0, "等待处理",
                    0, 0, 1 if articles_only else 0
                ))
                conn.commit()
        logger.info(f"添加任务: {task_id}")
        return task_id

    def get_task(self, task_id):
        """Return the task dict for ``task_id``, or None if it does not exist."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,))
                row = cursor.fetchone()
                return self._row_to_task(row) if row else None

    def get_all_tasks(self, username=None):
        """Return all tasks, newest first, optionally filtered by ``username``.

        ``created_at`` has one-second resolution, so ``task_id`` is used as a
        tie-breaker to keep the order deterministic.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT * FROM tasks WHERE username = ? ORDER BY created_at DESC, task_id DESC",
                        (username,)
                    )
                else:
                    cursor.execute("SELECT * FROM tasks ORDER BY created_at DESC, task_id DESC")
                return [self._row_to_task(row) for row in cursor.fetchall()]

    def get_pending_task(self):
        """Return the oldest PENDING task after re-queuing expired PAUSED tasks.

        Paused tasks whose pause has lasted at least ``PAUSE_RESUME_DELAY``
        are switched back to PENDING first, so this same call can return them.

        Returns:
            The task dict, or None if no task is pending.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                self._resume_expired_paused_tasks(conn, cursor)
                # created_at has one-second resolution; task_id (millisecond-based
                # for tasks created by add_task) breaks ties deterministically.
                cursor.execute(
                    "SELECT * FROM tasks WHERE status = ? ORDER BY created_at ASC, task_id ASC LIMIT 1",
                    (TaskStatus.PENDING.value,)
                )
                row = cursor.fetchone()
                return self._row_to_task(row) if row else None

    def update_task_status(self, task_id, status, **kwargs):
        """Update a task's status and, optionally, other columns.

        Entering PROCESSING stamps ``started_at``; entering COMPLETED or
        FAILED stamps ``completed_at``. Both ``TaskStatus`` members and their
        raw string values are accepted.

        Args:
            task_id: task ID.
            status: new status (``TaskStatus`` or its string value).
            **kwargs: additional column -> value pairs to update; they are
                applied last and may override the fields above.

        Returns:
            True if a row was updated.

        Raises:
            ValueError: if a kwargs key is not a valid column identifier.
        """
        with self.lock:
            status_value = status.value if isinstance(status, TaskStatus) else status
            update_fields = {"status": status_value}
            # Compare raw values so callers passing plain strings also get timestamps.
            if status_value == TaskStatus.PROCESSING.value:
                update_fields["started_at"] = self._now_str()
            elif status_value in (TaskStatus.COMPLETED.value, TaskStatus.FAILED.value):
                update_fields["completed_at"] = self._now_str()
            update_fields.update(kwargs)
            return self._execute_update(update_fields, task_id)

    def update_task_progress(self, task_id, progress, current_step=None, processed_articles=None):
        """Update a task's progress.

        Args:
            task_id: task ID.
            progress: progress percentage; clamped to the range 0-100.
            current_step: description of the current step (unchanged if None).
            processed_articles: number of processed articles (unchanged if None).

        Returns:
            True if a row was updated.
        """
        with self.lock:
            update_fields = {"progress": min(100, max(0, progress))}
            if current_step is not None:
                update_fields["current_step"] = current_step
            if processed_articles is not None:
                update_fields["processed_articles"] = processed_articles
            return self._execute_update(update_fields, task_id)

    def get_queue_stats(self, username=None):
        """Return task counts per status plus the total, optionally for one user."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT status, COUNT(*) AS count FROM tasks WHERE username = ? GROUP BY status",
                        (username,)
                    )
                else:
                    cursor.execute("SELECT status, COUNT(*) AS count FROM tasks GROUP BY status")
                status_counts = {row["status"]: row["count"] for row in cursor.fetchall()}
        # Every row lands in exactly one GROUP BY bucket (NULL status included),
        # so the bucket counts add up to the total; no second COUNT(*) needed.
        return {
            "total": sum(status_counts.values()),
            "pending": status_counts.get(TaskStatus.PENDING.value, 0),
            "processing": status_counts.get(TaskStatus.PROCESSING.value, 0),
            "completed": status_counts.get(TaskStatus.COMPLETED.value, 0),
            "failed": status_counts.get(TaskStatus.FAILED.value, 0),
            "paused": status_counts.get(TaskStatus.PAUSED.value, 0),
        }

    def delete_task(self, task_id):
        """Delete a task, first terminating it if it is still alive.

        Returns:
            True if the task existed and was deleted, False otherwise.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT status FROM tasks WHERE task_id = ?", (task_id,))
                row = cursor.fetchone()
                if not row:
                    return False
                # Mark a live task as FAILED before removing it.
                # NOTE(review): presumably so a worker reading the row through
                # another connection sees a terminal state — TODO confirm.
                if row["status"] in self._terminable_statuses():
                    cursor.execute('''
                        UPDATE tasks SET
                            status = ?,
                            error = ?,
                            current_step = ?,
                            completed_at = ?
                        WHERE task_id = ?
                    ''', (
                        TaskStatus.FAILED.value,
                        "任务已被用户删除",
                        "任务已终止",
                        self._now_str(),
                        task_id
                    ))
                    conn.commit()
                    logger.info(f"终止任务: {task_id}")
                cursor.execute("DELETE FROM tasks WHERE task_id = ?", (task_id,))
                conn.commit()
                deleted = cursor.rowcount > 0
        if deleted:
            logger.info(f"删除任务: {task_id}")
        return deleted

    def cancel_task(self, task_id):
        """Terminate a live (pending, processing or paused) task by marking it FAILED.

        Returns:
            True if the task was live and has been marked FAILED; False if it
            does not exist or has already finished.
        """
        statuses = self._terminable_statuses()
        placeholders = ", ".join("?" for _ in statuses)
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                # A single conditional UPDATE makes the status check and the
                # change atomic, instead of SELECT-then-UPDATE.
                cursor.execute(f'''
                    UPDATE tasks SET
                        status = ?,
                        error = ?,
                        current_step = ?,
                        completed_at = ?
                    WHERE task_id = ? AND status IN ({placeholders})
                ''', (
                    TaskStatus.FAILED.value,
                    "任务已被用户终止",
                    "任务已终止",
                    self._now_str(),
                    task_id,
                    *statuses
                ))
                conn.commit()
                return cursor.rowcount > 0
# Process-wide task-queue instance, created lazily by get_task_queue().
_task_queue = None
# Guards the lazy creation above so that concurrent first calls share a
# single TaskQueue (and therefore a single TaskQueue.lock).
_task_queue_lock = threading.Lock()


def get_task_queue():
    """Return the process-wide TaskQueue, creating it on first use.

    Uses double-checked locking. The unlocked fast path skips the lock once
    the instance exists. The locked re-check stops threads racing on the
    first call from each building their own TaskQueue. Two instances would
    each have their own ``threading.Lock``, defeating TaskQueue's
    serialization, and could also run the legacy JSON migration twice.
    """
    global _task_queue
    if _task_queue is None:
        with _task_queue_lock:
            if _task_queue is None:
                _task_queue = TaskQueue()
    return _task_queue