Files
ai_baijiahao/task_queue.py

358 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
任务队列管理模块
支持离线处理、进度跟踪、结果导出
使用 SQLite 数据库存储(替代原 JSON 文件)
"""
import logging
import os
import threading
import time
from datetime import datetime, timedelta
from enum import Enum

from database import get_database, migrate_from_json
# Module-level logger; TaskQueue uses it to report task lifecycle events
# (added / resumed / terminated / deleted).
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
    """Lifecycle state of a queued task.

    The member *values* are the strings persisted in the ``status`` column
    of the ``tasks`` table.
    """

    # Ready: waiting for a worker thread to pick it up.
    PENDING = "pending"
    # A worker is currently processing it.
    PROCESSING = "processing"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished unsuccessfully (user-terminated tasks also land here).
    FAILED = "failed"
    # Temporarily paused; put back to PENDING automatically after a delay.
    PAUSED = "paused"
class TaskQueue:
    """Task queue manager backed by the database returned by ``get_database()``.

    Every public method runs under ``self.lock`` (a per-instance
    ``threading.Lock``), so calls on the same instance from several threads
    do not interleave.
    """

    # How long a PAUSED task stays paused before get_pending_task() puts it
    # back to PENDING.
    PAUSE_RESUME_DELAY = timedelta(minutes=10)

    # Format of every *_at timestamp this class writes and parses.
    TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, queue_file="data/task_queue.json", results_dir="data/results"):
        """Open the database, create the results directory, migrate legacy data.

        Args:
            queue_file: Path of the legacy JSON queue file. If it exists it is
                passed to ``migrate_from_json``.
            results_dir: Directory for result files; created if missing.
        """
        self.results_dir = results_dir
        self.lock = threading.Lock()
        # Millisecond value of the last task id issued (see _new_task_id).
        self._last_task_ms = 0
        self.db = get_database()
        self._ensure_dirs()
        # Migrate data from the legacy JSON file.
        # NOTE(review): this runs whenever the file exists; it is only a
        # one-time migration if migrate_from_json removes/renames the file or
        # is idempotent -- confirm in database.py.
        if os.path.exists(queue_file):
            migrate_from_json(queue_file)

    def _ensure_dirs(self):
        """Make sure the results directory exists."""
        os.makedirs(self.results_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _now_str(cls):
        """Return the current local time formatted with TIME_FORMAT."""
        return datetime.now().strftime(cls.TIME_FORMAT)

    @staticmethod
    def _row_to_task(row):
        """Convert a ``tasks`` row to a plain dict, or return None for no row.

        Booleans are stored as integers, so ``use_proxy`` and
        ``articles_only`` are converted back to ``bool``. A row without an
        ``articles_only`` key (older schema) defaults to True.
        """
        if not row:
            return None
        task = dict(row)
        task['use_proxy'] = bool(task['use_proxy'])
        task['articles_only'] = bool(task.get('articles_only', 1))
        return task

    def _new_task_id(self):
        """Return a new ``task_<milliseconds>`` id, unique within this instance.

        The wall clock alone can repeat a millisecond (two adds in quick
        succession, or the clock stepping back), which would produce duplicate
        ids. The value is therefore bumped to stay strictly above the last id
        handed out. The caller must hold ``self.lock``.
        """
        now_ms = int(time.time() * 1000)
        last_ms = getattr(self, "_last_task_ms", 0)
        if now_ms <= last_ms:
            now_ms = last_ms + 1
        self._last_task_ms = now_ms
        return f"task_{now_ms}"

    @staticmethod
    def _build_set_clause(fields, task_id):
        """Build the SET clause and bound values for ``UPDATE ... WHERE task_id = ?``.

        Column names cannot be bound as parameters and are interpolated into
        the SQL. Each one must be a plain identifier, so a malformed key
        cannot inject SQL.

        Args:
            fields: Mapping of column name -> new value (insertion order kept).
            task_id: Value bound to the trailing ``WHERE task_id = ?``.

        Returns:
            ``(set_clause, values)``; ``values`` ends with ``task_id``.

        Raises:
            ValueError: If a column name is not a valid identifier.
        """
        for key in fields:
            if not isinstance(key, str) or not key.isidentifier():
                raise ValueError(f"Invalid column name: {key!r}")
        set_clause = ", ".join(f"{key} = ?" for key in fields)
        values = [*fields.values(), task_id]
        return set_clause, values

    def _update_fields(self, task_id, fields):
        """Apply ``fields`` to one task row and commit; caller holds the lock.

        Returns:
            True if a row with ``task_id`` was updated.
        """
        set_clause, values = self._build_set_clause(fields, task_id)
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"UPDATE tasks SET {set_clause} WHERE task_id = ?", values)
            conn.commit()
            return cursor.rowcount > 0

    def _terminate_active(self, cursor, task_id, error):
        """Mark a PENDING/PROCESSING task as FAILED with ``error`` (no commit).

        The status test is part of the UPDATE's WHERE clause, so check and set
        happen in one statement.

        Returns:
            True if the task was active and has been marked FAILED.
        """
        cursor.execute('''
            UPDATE tasks SET
                status = ?,
                error = ?,
                current_step = ?,
                completed_at = ?
            WHERE task_id = ? AND status IN (?, ?)
        ''', (
            TaskStatus.FAILED.value,
            error,
            "任务已终止",
            self._now_str(),
            task_id,
            TaskStatus.PENDING.value,
            TaskStatus.PROCESSING.value,
        ))
        return cursor.rowcount > 0

    def _resume_expired_pauses(self, conn, cursor):
        """Put PAUSED tasks paused for >= PAUSE_RESUME_DELAY back to PENDING.

        ``last_page`` / ``last_ctime`` are not touched, so the worker can
        continue from its checkpoint. Rows with no ``paused_at`` are skipped.
        Rows whose ``paused_at`` cannot be parsed are skipped with a warning
        instead of aborting the poll. Commits if anything changed. The caller
        holds ``self.lock``.
        """
        now = datetime.now()
        cursor.execute(
            "SELECT * FROM tasks WHERE status = ? ORDER BY paused_at ASC",
            (TaskStatus.PAUSED.value,),
        )
        resumed = []
        for row in cursor.fetchall():
            # Work on a plain dict: rows may be sqlite3.Row, which has no .get().
            task = dict(row)
            task_id = task['task_id']
            paused_at_str = task.get('paused_at')
            if not paused_at_str:
                continue
            try:
                paused_at = datetime.strptime(paused_at_str, self.TIME_FORMAT)
            except (TypeError, ValueError):
                logger.warning(f"任务 {task_id} 的暂停时间格式无效: {paused_at_str!r},跳过自动恢复")
                continue
            if now - paused_at < self.PAUSE_RESUME_DELAY:
                continue
            cursor.execute("""
                UPDATE tasks SET
                    status = ?,
                    current_step = ?,
                    retry_count = 0,
                    paused_at = NULL
                WHERE task_id = ?
            """, (TaskStatus.PENDING.value, "等待处理(从断点继续)", task_id))
            resumed.append((task_id, task.get('last_page', 1)))
        if resumed:
            conn.commit()
            for task_id, last_page in resumed:
                logger.info(f"任务 {task_id} 已从暂停状态恢复,将从第{last_page}页继续")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_task(self, url, months=6, use_proxy=False, proxy_api_url=None, username=None, articles_only=True):
        """Insert a new PENDING task and return its id.

        Args:
            url: Baijiahao URL to crawl.
            months: Number of months of content to fetch.
            use_proxy: Whether to use a proxy (stored as 0/1).
            proxy_api_url: Proxy API URL.
            username: Owner of the task.
            articles_only: Crawl articles only, skipping videos (stored as 0/1).

        Returns:
            The new task id, ``task_<milliseconds>``.
        """
        with self.lock:
            task_id = self._new_task_id()
            created_at = self._now_str()
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO tasks (
                        task_id, url, months, use_proxy, proxy_api_url,
                        username, status, created_at, progress, current_step,
                        total_articles, processed_articles, articles_only
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    task_id, url, months, 1 if use_proxy else 0, proxy_api_url,
                    username, TaskStatus.PENDING.value, created_at, 0, "等待处理",
                    0, 0, 1 if articles_only else 0
                ))
                conn.commit()
            logger.info(f"添加任务: {task_id}")
            return task_id

    def get_task(self, task_id):
        """Return the task dict for ``task_id``, or None if it does not exist."""
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,))
                return self._row_to_task(cursor.fetchone())

    def get_all_tasks(self, username=None):
        """Return all tasks, newest first, optionally only those of ``username``.

        A falsy ``username`` (None or "") returns every task.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT * FROM tasks WHERE username = ? ORDER BY created_at DESC",
                        (username,),
                    )
                else:
                    cursor.execute("SELECT * FROM tasks ORDER BY created_at DESC")
                return [self._row_to_task(row) for row in cursor.fetchall()]

    def get_pending_task(self):
        """Return the oldest PENDING task, or None.

        Before looking, PAUSED tasks whose pause has lasted at least
        PAUSE_RESUME_DELAY are moved back to PENDING and committed, even if
        no task ends up being returned.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                self._resume_expired_pauses(conn, cursor)
                # created_at has one-second resolution; task_id breaks ties so
                # the pick order is deterministic.
                cursor.execute(
                    "SELECT * FROM tasks WHERE status = ? "
                    "ORDER BY created_at ASC, task_id ASC LIMIT 1",
                    (TaskStatus.PENDING.value,),
                )
                return self._row_to_task(cursor.fetchone())

    def update_task_status(self, task_id, status, **kwargs):
        """Set a task's status, stamp lifecycle timestamps, apply extra fields.

        Args:
            task_id: Task id.
            status: A TaskStatus member or its raw string value.
            **kwargs: Additional column=value pairs. They are applied last, so
                they override computed fields.

        Returns:
            True if the task row was updated.

        Raises:
            ValueError: If a keyword name is not a valid column identifier.
        """
        with self.lock:
            status_value = status.value if isinstance(status, TaskStatus) else status
            update_fields = {"status": status_value}
            # Compare raw values so string statuses also get timestamps (an
            # Enum member never equals a plain str).
            if status_value == TaskStatus.PROCESSING.value:
                update_fields["started_at"] = self._now_str()
            elif status_value in (TaskStatus.COMPLETED.value, TaskStatus.FAILED.value):
                update_fields["completed_at"] = self._now_str()
            update_fields.update(kwargs)
            return self._update_fields(task_id, update_fields)

    def update_task_progress(self, task_id, progress, current_step=None, processed_articles=None):
        """Update a task's progress.

        Args:
            task_id: Task id.
            progress: Percentage; clamped to the range 0-100.
            current_step: Description of the current step (unchanged if None).
            processed_articles: Articles processed so far (unchanged if None).

        Returns:
            True if the task row was updated.
        """
        with self.lock:
            update_fields = {"progress": min(100, max(0, progress))}
            if current_step is not None:
                update_fields["current_step"] = current_step
            if processed_articles is not None:
                update_fields["processed_articles"] = processed_articles
            return self._update_fields(task_id, update_fields)

    def get_queue_stats(self, username=None):
        """Return task counts: total plus one entry per TaskStatus value.

        A falsy ``username`` counts every task. The total is the sum of the
        grouped counts (GROUP BY also yields rows for NULL or unknown
        statuses), so it always agrees with the per-status numbers.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                if username:
                    cursor.execute(
                        "SELECT status, COUNT(*) as count FROM tasks WHERE username = ? GROUP BY status",
                        (username,),
                    )
                else:
                    cursor.execute("SELECT status, COUNT(*) as count FROM tasks GROUP BY status")
                status_counts = {row["status"]: row["count"] for row in cursor.fetchall()}
            return {
                "total": sum(status_counts.values()),
                "pending": status_counts.get(TaskStatus.PENDING.value, 0),
                "processing": status_counts.get(TaskStatus.PROCESSING.value, 0),
                "completed": status_counts.get(TaskStatus.COMPLETED.value, 0),
                "failed": status_counts.get(TaskStatus.FAILED.value, 0),
                "paused": status_counts.get(TaskStatus.PAUSED.value, 0),
            }

    def delete_task(self, task_id):
        """Delete a task, first marking it FAILED if it is still active.

        Returns:
            True if the task existed and was deleted, False otherwise.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                # An active task is terminated and committed before the row is
                # removed, same as the cancel path.
                if self._terminate_active(cursor, task_id, "任务已被用户删除"):
                    conn.commit()
                    logger.info(f"终止任务: {task_id}")
                cursor.execute("DELETE FROM tasks WHERE task_id = ?", (task_id,))
                conn.commit()
                if cursor.rowcount > 0:
                    logger.info(f"删除任务: {task_id}")
                    return True
                return False

    def cancel_task(self, task_id):
        """Terminate a PENDING or PROCESSING task by marking it FAILED.

        Returns:
            True if the task was active and is now FAILED; False if it does
            not exist or has already finished or is paused.
        """
        with self.lock:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cancelled = self._terminate_active(cursor, task_id, "任务已被用户终止")
                if cancelled:
                    conn.commit()
                return cancelled
# Process-wide TaskQueue singleton, created lazily by get_task_queue().
_task_queue = None
# Guards the lazy creation so concurrent first calls build exactly one
# TaskQueue (and therefore share one lock and run the migration once).
_task_queue_lock = threading.Lock()


def get_task_queue():
    """Return the process-wide TaskQueue, creating it on first use.

    Uses double-checked locking: the fast path skips the lock once the
    instance exists. The second check inside the lock ensures only one
    thread constructs it.
    """
    global _task_queue
    if _task_queue is None:
        with _task_queue_lock:
            if _task_queue is None:
                _task_queue = TaskQueue()
    return _task_queue