# ai_baijiahao/task_worker.py
# -*- coding: utf-8 -*-
"""
任务处理器 - 后台并发处理队列中的任务
支持动态调整并发数,通过 SocketIO 实时推送进度和日志
"""
import threading
import time
import logging
import traceback

import psutil

from task_queue import get_task_queue, TaskStatus

logger = logging.getLogger(__name__)

# Global variable holding the SocketIO instance (registered by the web app
# so progress and logs can be pushed in real time)
_socketio_instance = None


def set_socketio(socketio):
    """Register the SocketIO instance."""
    global _socketio_instance
    _socketio_instance = socketio


def emit_log(task_id, message, level='info'):
    """Persist a task log entry to the database."""
    from datetime import datetime
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Save to the database
    try:
        from database import get_database
        db = get_database()
        db.add_task_log(task_id, message, level, timestamp)
    except Exception as e:
        logger.error(f"Failed to save log to database: {e}")
    logger.info(f"[{task_id}] {message}")


def emit_progress(task_id, progress, current_step='', **kwargs):
    """Record a task progress update."""
    logger.info(f"[{task_id}] progress: {progress}% - {current_step}")


class TaskWorker:
    """Task worker thread pool (supports dynamic concurrency)."""

    def __init__(self, min_workers=1, max_workers=3):
        self.queue = get_task_queue()
        self.running = False
        self.min_workers = min_workers      # minimum worker count
        self.max_workers = max_workers      # maximum worker count
        self.current_workers = min_workers  # current worker count
        self.worker_threads = []            # list of worker threads
        self.processing_tasks = set()       # IDs of tasks currently in flight
        self.lock = threading.Lock()

    def start(self):
        """Start the worker thread pool."""
        if self.running:
            logger.warning("Worker pool is already running")
            return
        self.running = True
        # Start the initial worker threads
        for i in range(self.min_workers):
            self._start_worker(i)
        # Start the dynamic-adjustment thread
        self.adjust_thread = threading.Thread(target=self._adjust_workers, daemon=True)
        self.adjust_thread.start()
        logger.info(f"Task worker started (initial workers: {self.min_workers}, max workers: {self.max_workers})")

    def _start_worker(self, worker_id):
        """Start a single worker thread."""
        thread = threading.Thread(target=self._work_loop, args=(worker_id,), daemon=True)
        thread.start()
        with self.lock:
            self.worker_threads.append(thread)
        logger.info(f"Worker thread #{worker_id} started")

    def stop(self):
        """Stop the worker thread pool."""
        self.running = False
        for thread in self.worker_threads:
            if thread and thread.is_alive():
                thread.join(timeout=5)
        logger.info("Task worker stopped")

    def _work_loop(self, worker_id):
        """Work loop for a single worker thread."""
        logger.info(f"Worker thread #{worker_id} entered loop")
        while self.running:
            # Exit naturally once the pool has been scaled down below this id
            if worker_id >= self.current_workers:
                break
            try:
                # Fetch a pending task
                task = self.queue.get_pending_task()
                if task:
                    task_id = task["task_id"]
                    # Claim the task unless another thread is already processing it
                    claimed = False
                    with self.lock:
                        if task_id not in self.processing_tasks:
                            self.processing_tasks.add(task_id)
                            claimed = True
                    if not claimed:
                        time.sleep(1)  # back off instead of spinning on a claimed task
                        continue
                    try:
                        logger.info(f"Worker thread #{worker_id} picked up task: {task_id}")
                        self._process_task(task, worker_id)
                    finally:
                        # Remove from the in-flight set when done
                        with self.lock:
                            self.processing_tasks.discard(task_id)
                else:
                    # No task available; sleep briefly
                    time.sleep(2)
            except Exception as e:
                logger.error(f"Worker thread #{worker_id} error: {e}")
                logger.error(traceback.format_exc())
                time.sleep(5)
        logger.info(f"Worker thread #{worker_id} exited loop")

    def _adjust_workers(self):
        """Dynamically adjust the number of worker threads."""
        logger.info("Dynamic adjustment thread started")
        while self.running:
            try:
                time.sleep(10)  # check every 10 seconds
                # Gather system resource usage
                cpu_percent = psutil.cpu_percent(interval=1)
                memory_percent = psutil.virtual_memory().percent
                # Gather queue statistics
                pending_count = len([t for t in self.queue.get_all_tasks() if t.get('status') == 'pending'])
                processing_count = len(self.processing_tasks)
                # Decide on a target worker count
                target_workers = self._calculate_target_workers(
                    pending_count,
                    processing_count,
                    cpu_percent,
                    memory_percent
                )
                # Apply the adjustment
                if target_workers > self.current_workers:
                    # Scale up: start additional threads
                    for i in range(self.current_workers, target_workers):
                        self._start_worker(i)
                    logger.info(f"Scaling workers up: {self.current_workers} -> {target_workers}")
                    self.current_workers = target_workers
                elif target_workers < self.current_workers:
                    # Scale down: surplus threads exit naturally, no forced termination
                    logger.info(f"Scaling workers down: {self.current_workers} -> {target_workers}")
                    self.current_workers = target_workers
            except Exception as e:
                logger.error(f"Worker adjustment error: {e}")
                logger.error(traceback.format_exc())
                time.sleep(30)

    def _calculate_target_workers(self, pending_count, processing_count, cpu_percent, memory_percent):
        """Compute the target number of worker threads."""
        # Basic policy:
        # 1. No pending tasks: keep the minimum worker count.
        # 2. Large backlog and ample system resources: add a worker.
        # 3. System under pressure: shed a worker.
        if pending_count == 0:
            return self.min_workers
        # System under pressure (CPU > 80% or memory > 85%)
        if cpu_percent > 80 or memory_percent > 85:
            logger.warning(f"System under pressure (CPU: {cpu_percent}%, memory: {memory_percent}%)")
            return max(self.min_workers, self.current_workers - 1)
        # Ample resources
        if cpu_percent < 50 and memory_percent < 70:
            # Scale with the size of the backlog
            if pending_count >= 3:
                return min(self.max_workers, self.current_workers + 1)
            elif pending_count >= 1:
                return min(self.max_workers, max(2, self.current_workers))
        # Otherwise keep the current worker count
        return self.current_workers
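
    # Worked example of the policy above, for illustration only: with
    # pending_count=4, cpu_percent=35, memory_percent=60 and
    # current_workers=1, the pool grows to min(max_workers, 1 + 1) = 2;
    # a later sample at cpu_percent=90 sheds back to max(min_workers, 2 - 1).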

    def _process_task(self, task, worker_id):
        """Process a single task end to end."""
        task_id = task["task_id"]
        logger.info(f"Worker thread #{worker_id} processing task: {task_id}")
        emit_log(task_id, f"Task processing started (worker #{worker_id})")
        # Current retry count
        retry_count = task.get("retry_count", 0)
        try:
            # Mark the task as processing
            self.queue.update_task_status(
                task_id,
                TaskStatus.PROCESSING,
                current_step="Preparing"
            )
            emit_progress(task_id, 5, "Preparing")
            # Import required modules
            from app import BaijiahaoScraper
            import pandas as pd
            import os
            # Step 1: parse the URL and resolve the user's UK
            self.queue.update_task_progress(task_id, 10, "Parsing URL for UK")
            emit_progress(task_id, 10, "Parsing URL for UK")
            emit_log(task_id, f"URL: {task['url']}")
            url = task["url"]
            use_proxy = task.get("use_proxy", False)
            proxy_api_url = task.get("proxy_api_url")
            articles_only = task.get("articles_only", True)  # whether to scrape articles only
            if use_proxy:
                emit_log(task_id, "Proxy IP pool enabled", "info")
            if articles_only:
                emit_log(task_id, "Article filter enabled (video content will be skipped)", "info")
            # Extract the app_id
            import re
            app_id_match = re.search(r'app_id=([^&\s]+)', url)
            if not app_id_match:
                raise Exception("Could not extract app_id from the URL")
            app_id = app_id_match.group(1)
            emit_log(task_id, f"Parsed app_id: {app_id}")
            # Resolve the UK
            emit_log(task_id, "Fetching user UK...")
            try:
                uk, cookies = BaijiahaoScraper.get_uk_from_app_id(
                    app_id,
                    use_proxy=use_proxy,
                    proxy_api_url=proxy_api_url
                )
                emit_log(task_id, f"Got UK: {uk[:20]}...")
            except Exception as uk_error:
                emit_log(task_id, f"Failed to get UK: {str(uk_error)}", "error")
                raise
            # Step 2: initialize the scraper
            self.queue.update_task_progress(task_id, 20, "Initializing scraper")
            emit_progress(task_id, 20, "Initializing scraper")
            emit_log(task_id, "Creating scraper instance...")
            scraper = BaijiahaoScraper(
                uk=uk,
                cookies=cookies,
                use_proxy=use_proxy,
                proxy_api_url=proxy_api_url
            )
            # Step 3: fetch the article list
            months = task.get("months", 6)
            # Check for checkpoint (resume) data
            last_page = task.get("last_page", 0)
            last_ctime = task.get("last_ctime")
            start_page = 1
            start_ctime = None
            from database import get_database
            db = get_database()
            if last_page > 0 and last_ctime:
                # Resume from the checkpoint
                start_page = last_page
                start_ctime = last_ctime
                emit_log(task_id, f"🔄 Checkpoint found; resuming from page {start_page}", "info")
                # Check whether the cache already holds data
                cached_count = db.get_cached_article_count(task_id)
                if cached_count > 0:
                    emit_log(task_id, f"💾 {cached_count} articles already cached; continuing...", "info")
            else:
                # Fresh task: clear any stale cache
                db.clear_article_cache(task_id)
            self.queue.update_task_progress(task_id, 30, f"Fetching article list (last {months} months)")
            emit_progress(task_id, 30, f"Fetching article list (last {months} months)")
            emit_log(task_id, f"Fetching articles from the last {months} months...")
            emit_log(task_id, "Note: scraping is slow (8-12 seconds per page); please be patient...", "info")
            emit_log(task_id, "Scraping through the proxy IP pool; IPs rotate automatically to cope with anti-scraping measures...", "info")

            # Save callback: persist each page to the database as soon as it arrives
            def save_page_to_db(page, articles, ctime):
                """Save one page of results to the database cache."""
                if articles:
                    db.save_articles_batch(task_id, articles, page)
                    emit_log(task_id, f"💾 Page {page} saved ({len(articles)} articles)", "success")
                # Update the task's checkpoint info
                self.queue.update_task_status(
                    task_id,
                    TaskStatus.PROCESSING,
                    last_page=page,
                    last_ctime=ctime
                )
                # Update progress (rough estimate within the 30-80% range)
                total_cached = db.get_cached_article_count(task_id)
                progress = min(30 + int(page * 2), 80)  # +2% per page, capped at 80%
                self.queue.update_task_progress(
                    task_id,
                    progress,
                    f"Fetching page {page}...",
                    processed_articles=total_cached
                )
                emit_progress(task_id, progress, f"Fetching page {page}...", processed_articles=total_cached)
            # Call get_articles with the callback and checkpoint parameters
            result = scraper.get_articles(
                months=months,
                app_id=app_id,
                articles_only=articles_only,
                task_id=task_id,
                on_page_fetched=save_page_to_db,
                start_page=start_page,
                start_ctime=start_ctime
            )
            if not result or not result.get('completed'):
                # Incomplete: keep the checkpoint so the task can resume later
                raise Exception(f"Fetch incomplete; saved through page {(result or {}).get('last_page', start_page)}")
            # Load all cached articles back from the database
            articles = db.get_cached_articles(task_id)
            if not articles:
                raise Exception("No article data retrieved")
            # Update the total article count
            total = len(articles)
            self.queue.update_task_status(
                task_id,
                TaskStatus.PROCESSING,
                total_articles=total
            )
            emit_log(task_id, f"Fetched {total} articles", "success")
            # Step 4: generate the Excel file directly from the database data
            self.queue.update_task_progress(task_id, 90, "Generating Excel file")
            emit_progress(task_id, 90, "Generating Excel file")
            emit_log(task_id, "Generating Excel file...")
            df = pd.DataFrame(articles)
            # Build the filename
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"baijiahao_articles_{app_id}_{timestamp}.xlsx"
            result_file = os.path.join(self.queue.results_dir, filename)
            # Write the Excel file
            with pd.ExcelWriter(result_file, engine='openpyxl') as writer:
                df.to_excel(writer, index=False, sheet_name='Articles')
                # Adjust column widths
                worksheet = writer.sheets['Articles']
                worksheet.column_dimensions['A'].width = 80  # title column
                worksheet.column_dimensions['B'].width = 20  # publish-time column
            emit_log(task_id, f"Excel file generated: {filename}")
            # Clear the cached data now that the task is complete
            db.clear_article_cache(task_id)
            emit_log(task_id, "🗑️ Cached data cleared", "info")
            # Step 5: done
            self.queue.update_task_status(
                task_id,
                TaskStatus.COMPLETED,
                progress=100,
                current_step="Completed",
                result_file=filename,
                processed_articles=total
            )
            emit_progress(task_id, 100, "Completed", result_file=filename)
            emit_log(task_id, f"✅ Task complete! Exported {total} articles", "success")
            logger.info(f"Worker thread #{worker_id} completed task: {task_id}, exported {total} articles")
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Worker thread #{worker_id} task failed: {task_id}, error: {error_msg}")
            # Log the full stack trace
            error_traceback = traceback.format_exc()
            logger.error(error_traceback)
            # Push the error to the frontend log (line by line)
            emit_log(task_id, f"❌ Task failed: {error_msg}", "error")
            # Push the traceback details (each line as its own log entry)
            for line in error_traceback.split('\n'):
                if line.strip():
                    emit_log(task_id, line, "error")
            # Decide between retrying and pausing
            retry_count += 1
            # Check for cached data (its presence means partial success)
            from database import get_database
            db = get_database()
            cached_count = db.get_cached_article_count(task_id)
            # Allow more retries when some pages were already cached
            max_retries = 10 if cached_count > 0 else 3
            # Too many consecutive failures: pause the task for 10 minutes
            if retry_count >= max_retries:
                from datetime import datetime
                paused_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                if cached_count > 0:
                    emit_log(task_id, f"⚠️ {retry_count} consecutive failures; {cached_count} articles cached; retrying in 10 minutes", "warning")
                else:
                    emit_log(task_id, f"⚠️ {retry_count} consecutive failures; task paused, automatic retry in 10 minutes", "warning")
                self.queue.update_task_status(
                    task_id,
                    TaskStatus.PAUSED,
                    error=error_msg,
                    last_error=error_msg,
                    retry_count=retry_count,
                    paused_at=paused_at,
                    current_step=f"Paused (retry in 10 minutes) - error: {error_msg}"
                )
                emit_progress(task_id, 0, "Paused (retry in 10 minutes)")
                logger.warning(f"Task {task_id} paused at {paused_at}; resumes in 10 minutes ({cached_count} articles cached)")
            else:
                # Below the retry limit: mark as pending so it is retried
                if cached_count > 0:
                    emit_log(task_id, f"⚠️ Task failed; retry #{retry_count + 1} upcoming ({cached_count} articles cached)", "warning")
                else:
                    emit_log(task_id, f"⚠️ Task failed; retry #{retry_count + 1} upcoming", "warning")
                self.queue.update_task_status(
                    task_id,
                    TaskStatus.PENDING,
                    error=error_msg,
                    last_error=error_msg,
                    retry_count=retry_count,
                    current_step=f"Awaiting retry ({retry_count} failures so far) - {error_msg}"
                )
                emit_progress(task_id, 0, "Awaiting retry")


# Global worker instance
_worker = None


def get_task_worker():
    """Return the global task worker instance."""
    global _worker
    if _worker is None:
        _worker = TaskWorker()
    return _worker


def start_task_worker():
    """Start the task worker."""
    worker = get_task_worker()
    worker.start()


def stop_task_worker():
    """Stop the task worker."""
    worker = get_task_worker()
    worker.stop()
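

# Minimal standalone usage sketch, for illustration only: in the real
# deployment a web app presumably calls set_socketio() with its SocketIO
# instance and start_task_worker() during startup; running this module
# directly just drives the pool against the shared task queue.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    start_task_worker()  # spins up min_workers daemon threads
    try:
        while True:
            time.sleep(60)  # keep the main thread alive for the daemon workers
    except KeyboardInterrupt:
        stop_task_worker()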