ai_baijiahao/task_worker.py
# -*- coding: utf-8 -*-
"""
任务处理器 - 后台并发处理队列中的任务
支持动态调整并发数通过 SocketIO 实时推送进度和日志
"""
import threading
import time
import logging
import traceback
import psutil
from task_queue import get_task_queue, TaskStatus
logger = logging.getLogger(__name__)

# Module-level holder for the SocketIO instance
_socketio_instance = None


def set_socketio(socketio):
    """Register the SocketIO instance."""
    global _socketio_instance
    _socketio_instance = socketio


def emit_log(task_id, message, level='info'):
    """Persist a task log entry to the database."""
    from datetime import datetime
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Save the log entry to the database
    try:
        from database import get_database
        db = get_database()
        db.add_task_log(task_id, message, level, timestamp)
    except Exception as e:
        logger.error(f"保存日志到数据库失败: {e}")
    logger.info(f"[{task_id}] {message}")


def emit_progress(task_id, progress, current_step='', **kwargs):
    """Update task progress."""
    logger.info(f"[{task_id}] 进度: {progress}% - {current_step}")


class TaskWorker:
"""任务处理工作线程池(支持动态并发)"""
def __init__(self, min_workers=1, max_workers=3):
self.queue = get_task_queue()
self.running = False
self.min_workers = min_workers # 最小并发数
self.max_workers = max_workers # 最大并发数
self.current_workers = min_workers # 当前并发数
self.worker_threads = [] # 工作线程列表
self.processing_tasks = set() # 正在处理的任务ID集合
self.lock = threading.Lock()
    def start(self):
        """Start the worker thread pool."""
        if self.running:
            logger.warning("工作线程池已经在运行")
            return
        self.running = True
        # Start the initial worker threads
        for i in range(self.min_workers):
            self._start_worker(i)
        # Start the thread that dynamically adjusts the pool size
        self.adjust_thread = threading.Thread(target=self._adjust_workers, daemon=True)
        self.adjust_thread.start()
        logger.info(f"任务处理器已启动(初始并发数: {self.min_workers},最大并发数: {self.max_workers})")

    def _start_worker(self, worker_id):
        """Start a single worker thread."""
        thread = threading.Thread(target=self._work_loop, args=(worker_id,), daemon=True)
        thread.start()
        with self.lock:
            self.worker_threads.append(thread)
        logger.info(f"工作线程 #{worker_id} 已启动")

    def stop(self):
        """Stop the worker thread pool."""
        self.running = False
        for thread in self.worker_threads:
            if thread and thread.is_alive():
                thread.join(timeout=5)
        logger.info("任务处理器已停止")

    def _work_loop(self, worker_id):
        """Main loop for a single worker thread."""
        logger.info(f"工作线程 #{worker_id} 进入循环")
        while self.running:
            try:
                # Fetch a pending task
                task = self.queue.get_pending_task()
                if task:
                    task_id = task["task_id"]
                    # Skip the task if another thread is already processing it
                    with self.lock:
                        if task_id in self.processing_tasks:
                            continue
                        self.processing_tasks.add(task_id)
                    try:
                        logger.info(f"工作线程 #{worker_id} 开始处理任务: {task_id}")
                        self._process_task(task, worker_id)
                    finally:
                        # Remove the task from the in-progress set when done
                        with self.lock:
                            self.processing_tasks.discard(task_id)
                else:
                    # No task available, sleep for a while
                    time.sleep(2)
            except Exception as e:
                logger.error(f"工作线程 #{worker_id} 错误: {e}")
                logger.error(traceback.format_exc())
                time.sleep(5)
        logger.info(f"工作线程 #{worker_id} 退出循环")

    def _adjust_workers(self):
        """Dynamically adjust the number of worker threads."""
        logger.info("动态调整线程已启动")
        while self.running:
            try:
                time.sleep(10)  # Check every 10 seconds
                # Collect system resource usage
                cpu_percent = psutil.cpu_percent(interval=1)
                memory_percent = psutil.virtual_memory().percent
                # Collect queue statistics
                pending_count = len([t for t in self.queue.get_all_tasks() if t.get('status') == 'pending'])
                processing_count = len(self.processing_tasks)
                # Decide on the target number of workers
                target_workers = self._calculate_target_workers(
                    pending_count,
                    processing_count,
                    cpu_percent,
                    memory_percent
                )
                # Apply the adjustment
                if target_workers > self.current_workers:
                    # Add worker threads
                    for i in range(self.current_workers, target_workers):
                        self._start_worker(i)
                    logger.info(f"增加工作线程: {self.current_workers} -> {target_workers}")
                    self.current_workers = target_workers
                elif target_workers < self.current_workers:
                    # Lower the target (threads are left to exit naturally, never force-killed)
                    logger.info(f"准备减少工作线程: {self.current_workers} -> {target_workers}")
                    self.current_workers = target_workers
            except Exception as e:
                logger.error(f"调整线程数错误: {e}")
                logger.error(traceback.format_exc())
                time.sleep(30)

    def _calculate_target_workers(self, pending_count, processing_count, cpu_percent, memory_percent):
        """Calculate the target number of worker threads."""
        # Basic rules:
        # 1. No pending tasks: keep the minimum number of workers.
        # 2. Many pending tasks and plenty of free resources: add workers.
        # 3. System resources are tight: remove workers.
        if pending_count == 0:
            return self.min_workers
        # Resources are tight (CPU > 80% or memory > 85%)
        if cpu_percent > 80 or memory_percent > 85:
            logger.warning(f"系统资源紧张 (CPU: {cpu_percent}%, 内存: {memory_percent}%)")
            return max(self.min_workers, self.current_workers - 1)
        # Resources are plentiful
        if cpu_percent < 50 and memory_percent < 70:
            # Scale with the number of pending tasks
            if pending_count >= 3:
                return min(self.max_workers, self.current_workers + 1)
            elif pending_count >= 1:
                return min(self.max_workers, max(2, self.current_workers))
        # Otherwise keep the current number of workers
        return self.current_workers

    def _process_task(self, task, worker_id):
        """Process a single task."""
        task_id = task["task_id"]
        logger.info(f"工作线程 #{worker_id} 开始处理任务: {task_id}")
        emit_log(task_id, f"任务开始处理 (Worker #{worker_id})")
        # Current retry count for this task
        retry_count = task.get("retry_count", 0)
        try:
            # Mark the task as processing
            self.queue.update_task_status(
                task_id,
                TaskStatus.PROCESSING,
                current_step="准备处理"
            )
            emit_progress(task_id, 5, "准备处理")
            # Import the modules needed for scraping
            from app import BaijiahaoScraper
            import pandas as pd
            import os
            # Step 1: parse the URL and resolve the user's UK
            self.queue.update_task_progress(task_id, 10, "解析URL获取UK")
            emit_progress(task_id, 10, "解析URL获取UK")
            emit_log(task_id, f"URL: {task['url']}")
            url = task["url"]
            use_proxy = task.get("use_proxy", False)
            proxy_api_url = task.get("proxy_api_url")
            articles_only = task.get("articles_only", True)  # Whether to scrape articles only
            if use_proxy:
                emit_log(task_id, "已启用代理IP池", "info")
            if articles_only:
                emit_log(task_id, "已启用文章过滤(跳过视频内容)", "info")
            # Extract app_id from the URL
            import re
            app_id_match = re.search(r'app_id=([^&\s]+)', url)
            if not app_id_match:
                raise Exception("无法从 URL 中提取 app_id")
            app_id = app_id_match.group(1)
            emit_log(task_id, f"解析到 app_id: {app_id}")
            # Resolve the UK
            emit_log(task_id, "正在获取用户 UK...")
            try:
                uk, cookies = BaijiahaoScraper.get_uk_from_app_id(
                    app_id,
                    use_proxy=use_proxy,
                    proxy_api_url=proxy_api_url
                )
                emit_log(task_id, f"成功获取 UK: {uk[:20]}...")
            except Exception as uk_error:
                emit_log(task_id, f"获取UK失败: {str(uk_error)}", "error")
                raise
            # Step 2: initialize the scraper
            self.queue.update_task_progress(task_id, 20, "初始化爬虫")
            emit_progress(task_id, 20, "初始化爬虫")
            emit_log(task_id, "初始化爬虫实例...")
            scraper = BaijiahaoScraper(
                uk=uk,
                cookies=cookies,
                use_proxy=use_proxy,
                proxy_api_url=proxy_api_url
            )
            # Step 3: fetch the article list
            months = task.get("months", 6)
            # Check for checkpoint data (resume support)
            last_page = task.get("last_page", 0)
            last_ctime = task.get("last_ctime")
            start_page = 1
            start_ctime = None
            if last_page > 0 and last_ctime:
                # Resume from the checkpoint
                start_page = last_page
                start_ctime = last_ctime
                emit_log(task_id, f"🔄 检测到断点数据,从第{start_page}页继续爬取", "info")
                # Check whether cached data already exists
                from database import get_database
                db = get_database()
                cached_count = db.get_cached_article_count(task_id)
                if cached_count > 0:
                    emit_log(task_id, f"💾 已缓存 {cached_count} 篇文章,将继续爬取...", "info")
            else:
                # New task: clear any stale cache left from earlier runs
                from database import get_database
                db = get_database()
                db.clear_article_cache(task_id)
            self.queue.update_task_progress(task_id, 30, f"获取文章列表(近{months}个月)")
            emit_progress(task_id, 30, f"获取文章列表(近{months}个月)")
            emit_log(task_id, f"开始获取近 {months} 个月的文章...")
            emit_log(task_id, "提示:抓取过程较慢(8-12秒/页),请耐心等待...", "info")
            emit_log(task_id, "系统正在使用代理IP池抓取数据,过程中会自动切换IP应对反爬...", "info")
            # Callback: persist each page to the database as soon as it is fetched
            from database import get_database
            db = get_database()

            def save_page_to_db(page, articles, ctime):
                """Save one page of articles to the database cache."""
                if articles:
                    db.save_articles_batch(task_id, articles, page)
                    emit_log(task_id, f"💾 第{page}页数据已保存,{len(articles)}篇文章", "success")
                # Update the task's checkpoint information
                self.queue.update_task_status(
                    task_id,
                    TaskStatus.PROCESSING,
                    last_page=page,
                    last_ctime=ctime
                )
                # Update progress (rough estimate within the 30-80% range)
                total_cached = db.get_cached_article_count(task_id)
                progress = min(30 + int(page * 2), 80)  # Each page adds 2%, capped at 80%
                self.queue.update_task_progress(
                    task_id,
                    progress,
                    f"正在抓取第{page}页...",
                    processed_articles=total_cached
                )
                emit_progress(task_id, progress, f"正在抓取第{page}页...", processed_articles=total_cached)

            # Call get_articles with the callback and the checkpoint parameters
            result = scraper.get_articles(
                months=months,
                app_id=app_id,
                articles_only=articles_only,
                task_id=task_id,
                on_page_fetched=save_page_to_db,
                start_page=start_page,
                start_ctime=start_ctime
            )
            if not result or not result.get('completed'):
                # Not completed: keep the checkpoint so the task can resume later
                last_saved_page = result.get('last_page', start_page) if result else start_page
                raise Exception(f"抓取未完成,已保存到第{last_saved_page}页")
            # Read all cached articles back from the database
            articles = db.get_cached_articles(task_id)
            if not articles:
                raise Exception("未获取到文章数据")
            # Update the total article count
            total = len(articles)
            self.queue.update_task_status(
                task_id,
                TaskStatus.PROCESSING,
                total_articles=total
            )
            emit_log(task_id, f"成功获取 {total} 篇文章", "success")
            # Step 4: generate the Excel file (directly from the cached data)
            self.queue.update_task_progress(task_id, 90, "生成Excel文件")
            emit_progress(task_id, 90, "生成Excel文件")
            emit_log(task_id, "正在生成 Excel 文件...")
            df = pd.DataFrame(articles)
            # Build the output file name
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"百家号文章_{app_id}_{timestamp}.xlsx"
            result_file = os.path.join(self.queue.results_dir, filename)
            # Write the Excel file
            with pd.ExcelWriter(result_file, engine='openpyxl') as writer:
                df.to_excel(writer, index=False, sheet_name='文章列表')
                # Adjust column widths
                worksheet = writer.sheets['文章列表']
                worksheet.column_dimensions['A'].width = 80  # Title column
                worksheet.column_dimensions['B'].width = 20  # Time column
            emit_log(task_id, f"Excel 文件已生成: {filename}")
            # Clear the cached data (the task is finished with it)
            db.clear_article_cache(task_id)
            emit_log(task_id, "🗑️ 已清除缓存数据", "info")
            # Step 5: mark the task as completed
            self.queue.update_task_status(
                task_id,
                TaskStatus.COMPLETED,
                progress=100,
                current_step="处理完成",
                result_file=filename,
                processed_articles=total
            )
            emit_progress(task_id, 100, "处理完成", result_file=filename)
            emit_log(task_id, f"✅ 任务完成!导出 {total} 篇文章", "success")
            logger.info(f"工作线程 #{worker_id} 任务完成: {task_id}, 导出 {total} 篇文章")
        except Exception as e:
            error_msg = str(e)
            logger.error(f"工作线程 #{worker_id} 任务失败: {task_id}, 错误: {error_msg}")
            # Log the full traceback
            error_traceback = traceback.format_exc()
            logger.error(error_traceback)
            # Push the error to the frontend as well (line by line)
            emit_log(task_id, f"❌ 任务失败: {error_msg}", "error")
            # Push the traceback details (each line as a separate log entry)
            for line in error_traceback.split('\n'):
                if line.strip():
                    emit_log(task_id, line, "error")
            # Decide whether to retry or pause
            retry_count += 1
            # Check whether any data has been cached (indicates partial success)
            from database import get_database
            db = get_database()
            cached_count = db.get_cached_article_count(task_id)
            # If cached data exists the task partially succeeded, so allow more retries
            max_retries = 10 if cached_count > 0 else 3  # Allow 10 retries when cached data exists
            # Too many consecutive failures: pause the task for 10 minutes
            if retry_count >= max_retries:
                from datetime import datetime
                paused_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                if cached_count > 0:
                    emit_log(task_id, f"⚠️ 连续失败{retry_count}次,已缓存{cached_count}篇文章,10分钟后继续尝试", "warning")
                else:
                    emit_log(task_id, f"⚠️ 连续失败{retry_count}次,任务将暂停,10分钟后自动重试", "warning")
                self.queue.update_task_status(
                    task_id,
                    TaskStatus.PAUSED,
                    error=error_msg,
                    last_error=error_msg,
                    retry_count=retry_count,
                    paused_at=paused_at,
                    current_step=f"暂停中,10分钟后重试 - 错误: {error_msg}"
                )
                emit_progress(task_id, 0, "暂停中,10分钟后重试")
                logger.warning(f"任务 {task_id} 已暂停,将在 {paused_at} 后10分钟恢复(已缓存{cached_count}篇)")
            else:
                # Retry limit not reached: put the task back into pending for another attempt
                if cached_count > 0:
                    emit_log(task_id, f"⚠️ 任务失败,将进行第{retry_count + 1}次重试(已缓存{cached_count}篇)", "warning")
                else:
                    emit_log(task_id, f"⚠️ 任务失败,将进行第{retry_count + 1}次重试", "warning")
                self.queue.update_task_status(
                    task_id,
                    TaskStatus.PENDING,
                    error=error_msg,
                    last_error=error_msg,
                    retry_count=retry_count,
                    current_step=f"等待重试 (已失败{retry_count}次) - {error_msg}"
                )
                emit_progress(task_id, 0, "等待重试")


# Global worker instance
_worker = None


def get_task_worker():
    """Return the global task worker instance."""
    global _worker
    if _worker is None:
        _worker = TaskWorker()
    return _worker


def start_task_worker():
    """Start the task worker."""
    worker = get_task_worker()
    worker.start()


def stop_task_worker():
    """Stop the task worker."""
    worker = get_task_worker()
    worker.stop()
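

# Minimal usage sketch (assumption: in production this module is imported by the
# Flask/SocketIO app, which calls set_socketio() and start_task_worker() at startup;
# the block below only illustrates the intended call order when running standalone).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    start_task_worker()      # spin up the worker pool and the adjuster thread
    try:
        while True:
            time.sleep(1)    # keep the main thread alive; workers run as daemon threads
    except KeyboardInterrupt:
        stop_task_worker()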