Initial commit: 百家号文章采集系统
This commit is contained in:
487
task_worker.py
Normal file
487
task_worker.py
Normal file
@@ -0,0 +1,487 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
任务处理器 - 后台并发处理队列中的任务
|
||||
支持动态调整并发数,通过 SocketIO 实时推送进度和日志
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
import traceback
|
||||
import psutil
|
||||
from task_queue import get_task_queue, TaskStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level holder for the SocketIO instance (set via set_socketio();
# remains None until the web layer registers one).
_socketio_instance = None
|
||||
|
||||
def set_socketio(socketio):
    """Register the application-wide SocketIO instance.

    Stored in the module global ``_socketio_instance`` so other helpers
    in this module can push realtime events through it.
    """
    global _socketio_instance
    _socketio_instance = socketio
|
||||
|
||||
def emit_log(task_id, message, level='info'):
    """Persist one task log entry to the database and mirror it to the logger.

    Args:
        task_id: identifier of the task the message belongs to.
        message: human-readable log text.
        level: severity label stored alongside the entry (default ``'info'``).
    """
    from datetime import datetime

    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    # Best-effort persistence: a storage failure must never crash the worker.
    try:
        from database import get_database
        get_database().add_task_log(task_id, message, level, stamp)
    except Exception as e:
        logger.error(f"保存日志到数据库失败: {e}")

    logger.info(f"[{task_id}] {message}")
|
||||
|
||||
def emit_progress(task_id, progress, current_step='', **kwargs):
    """Log a progress update for *task_id*.

    Extra keyword arguments (e.g. ``result_file``, ``processed_articles``)
    are accepted for caller compatibility but are not used here.
    """
    text = f"[{task_id}] 进度: {progress}% - {current_step}"
    logger.info(text)
|
||||
|
||||
|
||||
class TaskWorker:
    """Pool of background threads that drain the task queue.

    Starts with ``min_workers`` threads plus a monitor thread that grows or
    shrinks the pool (between ``min_workers`` and ``max_workers``) based on
    system load and queue backlog, pushing progress/logs as it goes.
    """

    def __init__(self, min_workers=1, max_workers=3):
        # Concurrency bounds and the currently targeted thread count.
        self.min_workers = min_workers
        self.max_workers = max_workers
        self.current_workers = min_workers

        # Shared mutable state, guarded by self.lock.
        self.lock = threading.Lock()
        self.worker_threads = []       # worker Thread objects
        self.processing_tasks = set()  # task ids currently being handled

        self.queue = get_task_queue()
        self.running = False
|
||||
|
||||
def start(self):
|
||||
"""启动工作线程池"""
|
||||
if self.running:
|
||||
logger.warning("工作线程池已经在运行")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
|
||||
# 启动初始工作线程
|
||||
for i in range(self.min_workers):
|
||||
self._start_worker(i)
|
||||
|
||||
# 启动动态调整线程
|
||||
self.adjust_thread = threading.Thread(target=self._adjust_workers, daemon=True)
|
||||
self.adjust_thread.start()
|
||||
|
||||
logger.info(f"任务处理器已启动(初始并发数: {self.min_workers},最大并发数: {self.max_workers})")
|
||||
|
||||
def _start_worker(self, worker_id):
|
||||
"""启动一个工作线程"""
|
||||
thread = threading.Thread(target=self._work_loop, args=(worker_id,), daemon=True)
|
||||
thread.start()
|
||||
with self.lock:
|
||||
self.worker_threads.append(thread)
|
||||
logger.info(f"工作线程 #{worker_id} 已启动")
|
||||
|
||||
def stop(self):
|
||||
"""停止工作线程池"""
|
||||
self.running = False
|
||||
for thread in self.worker_threads:
|
||||
if thread and thread.is_alive():
|
||||
thread.join(timeout=5)
|
||||
logger.info("任务处理器已停止")
|
||||
|
||||
def _work_loop(self, worker_id):
|
||||
"""工作循环(单个线程)"""
|
||||
logger.info(f"工作线程 #{worker_id} 进入循环")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# 获取待处理任务
|
||||
task = self.queue.get_pending_task()
|
||||
|
||||
if task:
|
||||
task_id = task["task_id"]
|
||||
|
||||
# 检查是否已经有其他线程在处理这个任务
|
||||
with self.lock:
|
||||
if task_id in self.processing_tasks:
|
||||
continue
|
||||
self.processing_tasks.add(task_id)
|
||||
|
||||
try:
|
||||
logger.info(f"工作线程 #{worker_id} 开始处理任务: {task_id}")
|
||||
self._process_task(task, worker_id)
|
||||
finally:
|
||||
# 处理完成后从集合中移除
|
||||
with self.lock:
|
||||
self.processing_tasks.discard(task_id)
|
||||
else:
|
||||
# 没有任务,休息一会
|
||||
time.sleep(2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作线程 #{worker_id} 错误: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
time.sleep(5)
|
||||
|
||||
logger.info(f"工作线程 #{worker_id} 退出循环")
|
||||
|
||||
def _adjust_workers(self):
|
||||
"""动态调整工作线程数量"""
|
||||
logger.info("动态调整线程已启动")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
time.sleep(10) # 每10秒检查一次
|
||||
|
||||
# 获取系统资源信息
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
memory_percent = psutil.virtual_memory().percent
|
||||
|
||||
# 获取队列信息
|
||||
pending_count = len([t for t in self.queue.get_all_tasks() if t.get('status') == 'pending'])
|
||||
processing_count = len(self.processing_tasks)
|
||||
|
||||
# 决策逻辑
|
||||
target_workers = self._calculate_target_workers(
|
||||
pending_count,
|
||||
processing_count,
|
||||
cpu_percent,
|
||||
memory_percent
|
||||
)
|
||||
|
||||
# 调整线程数
|
||||
if target_workers > self.current_workers:
|
||||
# 增加线程
|
||||
for i in range(self.current_workers, target_workers):
|
||||
self._start_worker(i)
|
||||
logger.info(f"增加工作线程: {self.current_workers} -> {target_workers}")
|
||||
self.current_workers = target_workers
|
||||
elif target_workers < self.current_workers:
|
||||
# 减少线程(自然退出,不强制终止)
|
||||
logger.info(f"准备减少工作线程: {self.current_workers} -> {target_workers}")
|
||||
self.current_workers = target_workers
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调整线程数错误: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
time.sleep(30)
|
||||
|
||||
def _calculate_target_workers(self, pending_count, processing_count, cpu_percent, memory_percent):
|
||||
"""计算目标线程数"""
|
||||
# 基本逻辑:
|
||||
# 1. 如果没有待处理任务,保持最小线程数
|
||||
# 2. 如果有很多待处理任务,且系统资源充足,增加线程
|
||||
# 3. 如果系统资源紧张,减少线程
|
||||
|
||||
if pending_count == 0:
|
||||
return self.min_workers
|
||||
|
||||
# 系统资源紧张(CPU>80% 或 内存>85%)
|
||||
if cpu_percent > 80 or memory_percent > 85:
|
||||
logger.warning(f"系统资源紧张 (CPU: {cpu_percent}%, 内存: {memory_percent}%)")
|
||||
return max(self.min_workers, self.current_workers - 1)
|
||||
|
||||
# 系统资源充足
|
||||
if cpu_percent < 50 and memory_percent < 70:
|
||||
# 根据待处理任务数决定线程数
|
||||
if pending_count >= 3:
|
||||
return min(self.max_workers, self.current_workers + 1)
|
||||
elif pending_count >= 1:
|
||||
return min(self.max_workers, max(2, self.current_workers))
|
||||
|
||||
# 默认保持当前线程数
|
||||
return self.current_workers
|
||||
|
||||
    def _process_task(self, task, worker_id):
        """Process a single scraping task end to end.

        Pipeline: parse app_id from the task URL -> resolve the user UK ->
        scrape the article list (page-by-page, checkpointed to the DB cache)
        -> export everything to an Excel file -> mark the task completed.
        On failure, retries are counted; past the limit the task is paused
        (10-minute back-off), otherwise it is re-queued as pending.

        Args:
            task: task dict from the queue; reads "task_id", "url", and the
                optional keys "use_proxy", "proxy_api_url", "articles_only",
                "months", "last_page", "last_ctime", "retry_count".
            worker_id: id of the worker thread handling this task (logging).
        """
        task_id = task["task_id"]
        logger.info(f"工作线程 #{worker_id} 开始处理任务: {task_id}")
        emit_log(task_id, f"任务开始处理 (Worker #{worker_id})")

        # Retry counter carried on the task itself; incremented on failure.
        retry_count = task.get("retry_count", 0)

        try:
            # Mark the task as in-progress before any real work starts.
            self.queue.update_task_status(
                task_id,
                TaskStatus.PROCESSING,
                current_step="准备处理"
            )
            emit_progress(task_id, 5, "准备处理")

            # Imported lazily to avoid a circular import with app.py and to
            # keep module import cheap.
            from app import BaijiahaoScraper
            import pandas as pd
            import os

            # Step 1: parse the URL and resolve the UK.
            self.queue.update_task_progress(task_id, 10, "解析URL获取UK")
            emit_progress(task_id, 10, "解析URL获取UK")
            emit_log(task_id, f"URL: {task['url']}")

            url = task["url"]
            use_proxy = task.get("use_proxy", False)
            proxy_api_url = task.get("proxy_api_url")
            articles_only = task.get("articles_only", True)  # skip video posts when True

            if use_proxy:
                emit_log(task_id, "已启用代理IP池", "info")

            if articles_only:
                emit_log(task_id, "已启用文章过滤(跳过视频内容)", "info")

            # Extract app_id from the URL query string.
            import re
            app_id_match = re.search(r'app_id=([^&\s]+)', url)
            if not app_id_match:
                raise Exception("无法从 URL 中提取 app_id")

            app_id = app_id_match.group(1)
            emit_log(task_id, f"解析到 app_id: {app_id}")

            # Resolve the user's UK token (needed by the scraper API).
            emit_log(task_id, "正在获取用户 UK...")
            try:
                uk, cookies = BaijiahaoScraper.get_uk_from_app_id(
                    app_id,
                    use_proxy=use_proxy,
                    proxy_api_url=proxy_api_url
                )
                emit_log(task_id, f"成功获取 UK: {uk[:20]}...")
            except Exception as uk_error:
                # Log the specific failure, then let the outer handler
                # drive the retry/pause logic.
                emit_log(task_id, f"获取UK失败: {str(uk_error)}", "error")
                raise

            # Step 2: build the scraper instance.
            self.queue.update_task_progress(task_id, 20, "初始化爬虫")
            emit_progress(task_id, 20, "初始化爬虫")
            emit_log(task_id, "初始化爬虫实例...")

            scraper = BaijiahaoScraper(
                uk=uk,
                cookies=cookies,
                use_proxy=use_proxy,
                proxy_api_url=proxy_api_url
            )

            # Step 3: fetch the article list.
            months = task.get("months", 6)

            # Resume-from-checkpoint data saved by a previous attempt.
            last_page = task.get("last_page", 0)
            last_ctime = task.get("last_ctime")
            start_page = 1
            start_ctime = None

            # NOTE(review): branch structure reconstructed from the original
            # comments — the else-branch is the "new task" path; confirm
            # against the original indentation.
            if last_page > 0 and last_ctime:
                # Resume: continue from the recorded page/ctime checkpoint.
                start_page = last_page
                start_ctime = last_ctime
                emit_log(task_id, f"🔄 检测到断点数据,从第{start_page}页继续爬取", "info")

                # Report any articles already cached by the earlier attempt.
                from database import get_database
                db = get_database()
                cached_count = db.get_cached_article_count(task_id)
                if cached_count > 0:
                    emit_log(task_id, f"💾 已缓存 {cached_count} 篇文章,将继续爬取...", "info")
            else:
                # Fresh task: clear any stale cache left from earlier runs.
                from database import get_database
                db = get_database()
                db.clear_article_cache(task_id)

            self.queue.update_task_progress(task_id, 30, f"获取文章列表(近{months}个月)")
            emit_progress(task_id, 30, f"获取文章列表(近{months}个月)")
            emit_log(task_id, f"开始获取近 {months} 个月的文章...")
            emit_log(task_id, "提示:抓取过程较慢(8-12秒/页),请耐心等待...", "info")
            emit_log(task_id, "系统正在使用代理IP池抓取数据,过程中会自动切换IP应对反爬...", "info")

            # Per-page save callback: persists every fetched page immediately
            # so an interruption loses at most one page of work.
            from database import get_database
            db = get_database()

            def save_page_to_db(page, articles, ctime):
                """Persist one page of articles to the DB cache and update the checkpoint/progress."""
                if articles:
                    db.save_articles_batch(task_id, articles, page)
                    emit_log(task_id, f"💾 第{page}页数据已保存,{len(articles)}篇文章", "success")

                    # Record the resume checkpoint on the task.
                    self.queue.update_task_status(
                        task_id,
                        TaskStatus.PROCESSING,
                        last_page=page,
                        last_ctime=ctime
                    )

                    # Rough progress estimate within the 30-80% band:
                    # +2% per page, capped at 80%.
                    total_cached = db.get_cached_article_count(task_id)
                    progress = min(30 + int(page * 2), 80)
                    self.queue.update_task_progress(
                        task_id,
                        progress,
                        f"正在抓取第{page}页...",
                        processed_articles=total_cached
                    )
                    emit_progress(task_id, progress, f"正在抓取第{page}页...", processed_articles=total_cached)

            # Run the scrape with the per-page callback and resume point.
            result = scraper.get_articles(
                months=months,
                app_id=app_id,
                articles_only=articles_only,
                task_id=task_id,
                on_page_fetched=save_page_to_db,
                start_page=start_page,
                start_ctime=start_ctime
            )

            if not result or not result.get('completed'):
                # Incomplete run: keep the checkpoint so the retry resumes.
                raise Exception(f"抓取未完成,已保存到第{result.get('last_page', start_page)}页")

            # Read everything back from the DB cache for the export.
            articles = db.get_cached_articles(task_id)

            if not articles:
                raise Exception("未获取到文章数据")

            # Record the final article count on the task.
            total = len(articles)
            self.queue.update_task_status(
                task_id,
                TaskStatus.PROCESSING,
                total_articles=total
            )
            emit_log(task_id, f"成功获取 {total} 篇文章", "success")

            # Step 4: export the cached articles to an Excel file.
            self.queue.update_task_progress(task_id, 90, "生成Excel文件")
            emit_progress(task_id, 90, "生成Excel文件")
            emit_log(task_id, "正在生成 Excel 文件...")

            df = pd.DataFrame(articles)

            # Build a unique result filename: 百家号文章_<app_id>_<timestamp>.xlsx
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"百家号文章_{app_id}_{timestamp}.xlsx"
            result_file = os.path.join(self.queue.results_dir, filename)

            # Write the sheet and widen the title/time columns for readability.
            with pd.ExcelWriter(result_file, engine='openpyxl') as writer:
                df.to_excel(writer, index=False, sheet_name='文章列表')

                worksheet = writer.sheets['文章列表']
                worksheet.column_dimensions['A'].width = 80  # title column
                worksheet.column_dimensions['B'].width = 20  # time column

            # NOTE(review): this f-string has no placeholder — "(unknown)"
            # looks like lost text (probably {filename}); confirm upstream.
            emit_log(task_id, f"Excel 文件已生成: (unknown)")

            # The cache is no longer needed once the export succeeded.
            db.clear_article_cache(task_id)
            emit_log(task_id, "🗑️ 已清除缓存数据", "info")

            # Step 5: mark the task completed with the result file attached.
            self.queue.update_task_status(
                task_id,
                TaskStatus.COMPLETED,
                progress=100,
                current_step="处理完成",
                result_file=filename,
                processed_articles=total
            )
            emit_progress(task_id, 100, "处理完成", result_file=filename)
            emit_log(task_id, f"✅ 任务完成!导出 {total} 篇文章", "success")

            logger.info(f"工作线程 #{worker_id} 任务完成: {task_id}, 导出 {total} 篇文章")

        except Exception as e:
            error_msg = str(e)
            logger.error(f"工作线程 #{worker_id} 任务失败: {task_id}, 错误: {error_msg}")

            # Capture the full stack for the server log...
            error_traceback = traceback.format_exc()
            logger.error(error_traceback)

            # ...and surface the failure to the task's own log stream.
            emit_log(task_id, f"❌ 任务失败: {error_msg}", "error")

            # Push each traceback line as its own log entry.
            for line in error_traceback.split('\n'):
                if line.strip():
                    emit_log(task_id, line, "error")

            # Decide between retrying now and pausing the task.
            retry_count += 1

            # Cached articles mean the run was partially successful.
            from database import get_database
            db = get_database()
            cached_count = db.get_cached_article_count(task_id)

            # Partial progress earns a higher retry budget (10 vs 3).
            max_retries = 10 if cached_count > 0 else 3

            if retry_count >= max_retries:
                # Too many consecutive failures: pause for 10 minutes.
                from datetime import datetime
                paused_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

                if cached_count > 0:
                    emit_log(task_id, f"⚠️ 连续失败{retry_count}次,已缓存{cached_count}篇文章,10分钟后继续尝试", "warning")
                else:
                    emit_log(task_id, f"⚠️ 连续失败{retry_count}次,任务将暂停10分钟后自动重试", "warning")

                self.queue.update_task_status(
                    task_id,
                    TaskStatus.PAUSED,
                    error=error_msg,
                    last_error=error_msg,
                    retry_count=retry_count,
                    paused_at=paused_at,
                    current_step=f"暂停中(10分钟后重试) - 错误: {error_msg}"
                )
                emit_progress(task_id, 0, f"暂停中(10分钟后重试)")

                logger.warning(f"任务 {task_id} 已暂停,将在 {paused_at} 后10分钟恢复,已缓存{cached_count}篇")
            else:
                # Budget remains: push the task back to pending for a retry.
                if cached_count > 0:
                    emit_log(task_id, f"⚠️ 任务失败,将进行第{retry_count + 1}次重试(已缓存{cached_count}篇)", "warning")
                else:
                    emit_log(task_id, f"⚠️ 任务失败,将进行第{retry_count + 1}次重试", "warning")

                self.queue.update_task_status(
                    task_id,
                    TaskStatus.PENDING,
                    error=error_msg,
                    last_error=error_msg,
                    retry_count=retry_count,
                    current_step=f"等待重试 (已失败{retry_count}次) - {error_msg}"
                )
                emit_progress(task_id, 0, f"等待重试")
|
||||
|
||||
|
||||
# Process-wide TaskWorker singleton; lazily created by get_task_worker().
_worker = None
|
||||
|
||||
|
||||
def get_task_worker():
    """Return the process-wide TaskWorker singleton, creating it on first use."""
    global _worker
    if _worker is not None:
        return _worker
    _worker = TaskWorker()
    return _worker
|
||||
|
||||
|
||||
def start_task_worker():
    """Convenience entry point: fetch the singleton worker and start it."""
    get_task_worker().start()
|
||||
|
||||
|
||||
def stop_task_worker():
    """Convenience entry point: fetch the singleton worker and stop it."""
    get_task_worker().stop()
|
||||
Reference in New Issue
Block a user