Files
ai_wht_wechat/go_backend/service/scheduler_service.go
2026-01-06 19:36:42 +08:00

561 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"ai_xhs/config"
"ai_xhs/database"
"ai_xhs/models"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/robfig/cron/v3"
)
// SchedulerService runs the scheduled auto-publish job.
type SchedulerService struct {
// cron is the scheduler the publish job is registered on.
cron *cron.Cron
// maxConcurrent is the capacity publishSem is created with.
maxConcurrent int
// publishTimeout is the HTTP timeout, in seconds, of one publish request;
// values <= 0 fall back to 120 seconds when the request is sent.
publishTimeout int
publishSem chan struct{} // semaphore limiting how many publishes run at once
}
// NewSchedulerService builds a scheduler service.
//
// maxConcurrent is the maximum number of publishes allowed to run at the same
// time. Values below 1 are raised to 1: a zero-capacity semaphore would make
// the publish loop block forever on its first acquire (the releasing goroutine
// is only started after the acquire), and a negative capacity makes make()
// panic. publishTimeout is the per-request HTTP timeout in seconds.
//
// The cron instance accepts 6-field expressions (second minute hour
// day-of-month month day-of-week). Triggers that fire while the previous
// invocation of the same job is still running are skipped, so a slow run
// cannot overlap with the next one and publish the same pending articles
// twice.
func NewSchedulerService(maxConcurrent, publishTimeout int) *SchedulerService {
	if maxConcurrent < 1 {
		maxConcurrent = 1
	}
	return &SchedulerService{
		cron: cron.New(
			cron.WithSeconds(),
			cron.WithChain(cron.SkipIfStillRunning(cron.DefaultLogger)),
		),
		maxConcurrent:  maxConcurrent,
		publishTimeout: publishTimeout,
		publishSem:     make(chan struct{}, maxConcurrent),
	}
}
// Start registers AutoPublishArticles under cronExpr and starts the cron
// scheduler. It returns an error, without starting anything, when the
// expression cannot be registered.
func (s *SchedulerService) Start(cronExpr string) error {
	if _, addErr := s.cron.AddFunc(cronExpr, s.AutoPublishArticles); addErr != nil {
		return fmt.Errorf("添加定时任务失败: %w", addErr)
	}

	s.cron.Start()
	log.Printf("定时发布任务已启动Cron表达式: %s", cronExpr)
	return nil
}
// Stop stops the cron scheduler and logs that the publish job has stopped.
func (s *SchedulerService) Stop() {
	// The value returned by Stop is deliberately discarded: this method does
	// not wait for a job that is still running.
	_ = s.cron.Stop()
	log.Println("定时发布任务已停止")
}
// Fallback per-run limits, used when the corresponding scheduler
// configuration value is unset (<= 0).
const (
// defaultMaxArticlesPerUserPerRun caps how many articles of one user are
// scheduled in a single run.
defaultMaxArticlesPerUserPerRun = 5
// defaultMaxFailuresPerUserPerRun is the number of failed publishes after
// which a user is paused for the rest of the run.
defaultMaxFailuresPerUserPerRun = 3
)
// fetchProxyFromPool asks the configured proxy-pool endpoint for a proxy and
// returns it as a URL with a scheme (for example http://ip:port).
//
// It returns ("", nil) when no endpoint is configured. The request uses a
// 10-second timeout; a non-200 status, an unreadable body or an empty body is
// reported as an error. Only the first line of the response is used, and
// "http://" is prepended when that line carries no http/https scheme.
func fetchProxyFromPool() (string, error) {
	endpoint := config.AppConfig.Scheduler.ProxyFetchURL
	if endpoint == "" {
		return "", nil
	}

	httpClient := &http.Client{Timeout: 10 * time.Second}
	resp, err := httpClient.Get(endpoint)
	if err != nil {
		return "", fmt.Errorf("请求代理池接口失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("代理池接口返回非200状态码: %d", resp.StatusCode)
	}

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("读取代理池响应失败: %w", err)
	}

	body := strings.TrimSpace(string(raw))
	if body == "" {
		return "", fmt.Errorf("代理池返回内容为空")
	}

	// The pool may answer with several lines; only the first one is used.
	firstLine := body
	if nl := strings.IndexByte(body, '\n'); nl >= 0 {
		firstLine = body[:nl]
	}
	firstLine = strings.TrimSpace(firstLine)
	if firstLine == "" {
		return "", fmt.Errorf("代理池首行内容为空")
	}

	switch {
	case strings.HasPrefix(firstLine, "http://"), strings.HasPrefix(firstLine, "https://"):
		// Already a full URL.
		return firstLine, nil
	default:
		// Bare ip:port: default to plain HTTP.
		return "http://" + firstLine, nil
	}
}
// limitArticlesPerUserPerRun keeps at most perUserLimit articles per user and
// returns them in their original relative order.
//
// An article is attributed to its PublishUserID when set, otherwise to its
// CreatedUserID. For every user the first perUserLimit articles in input
// order are kept. A perUserLimit <= 0 disables the limit and returns the
// input slice unchanged.
//
// The result is built in a single pass with a per-user counter rather than by
// grouping into a map and ranging over it, because map iteration order is
// random and would shuffle the articles between runs.
func limitArticlesPerUserPerRun(articles []models.Article, perUserLimit int) []models.Article {
	if perUserLimit <= 0 {
		return articles
	}

	kept := make(map[int]int)
	limited := make([]models.Article, 0, len(articles))
	for _, art := range articles {
		userID := art.CreatedUserID
		if art.PublishUserID != nil {
			userID = *art.PublishUserID
		}
		if kept[userID] >= perUserLimit {
			continue
		}
		kept[userID]++
		limited = append(limited, art)
	}
	return limited
}
// filterByDailyAndHourlyLimit drops articles whose user has used up the daily
// or hourly publish quota and returns the remaining ones in input order.
//
// An article is attributed to its PublishUserID when set, otherwise to its
// CreatedUserID. maxDaily / maxHourly <= 0 disables the respective check; when
// both are disabled the input slice is returned unchanged.
//
// For every user the number of articles already published since local
// midnight and since the start of the current hour is read from the database.
// Articles admitted by this call count against both quotas as well, so a user
// who is just below a limit cannot exceed it through the articles scheduled in
// this run. If a count query fails it is logged and the user is treated as
// having published nothing in that window (fail-open).
func (s *SchedulerService) filterByDailyAndHourlyLimit(articles []models.Article, maxDaily, maxHourly int) []models.Article {
	if maxDaily <= 0 && maxHourly <= 0 {
		return articles
	}

	// ownerOf resolves the user an article is published for.
	ownerOf := func(art models.Article) int {
		if art.PublishUserID != nil {
			return *art.PublishUserID
		}
		return art.CreatedUserID
	}

	// countPublishedSince counts the user's published articles whose
	// publish_time is at or after since.
	countPublishedSince := func(userID int, since time.Time) (int, error) {
		var n int64
		err := database.DB.Model(&models.Article{}).
			Where("status = ? AND publish_time >= ? AND (publish_user_id = ? OR (publish_user_id IS NULL AND created_user_id = ?))",
				"published", since, userID, userID).
			Count(&n).Error
		return int(n), err
	}

	// Collect the distinct users involved.
	userIDs := make(map[int]struct{})
	for _, art := range articles {
		userIDs[ownerOf(art)] = struct{}{}
	}

	now := time.Now()
	todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
	currentHourStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())

	userDailyPublished := make(map[int]int, len(userIDs))
	userHourlyPublished := make(map[int]int, len(userIDs))
	for userID := range userIDs {
		if maxDaily > 0 {
			if n, err := countPublishedSince(userID, todayStart); err != nil {
				log.Printf("[警告] 查询用户 %d 当日已发布数量失败: %v", userID, err)
			} else {
				userDailyPublished[userID] = n
			}
		}
		if maxHourly > 0 {
			if n, err := countPublishedSince(userID, currentHourStart); err != nil {
				log.Printf("[警告] 查询用户 %d 当前小时已发布数量失败: %v", userID, err)
			} else {
				userHourlyPublished[userID] = n
			}
		}
	}

	filtered := make([]models.Article, 0, len(articles))
	admitted := make(map[int]int) // articles let through by this call, per user
	skippedUsersDailyMap := make(map[int]bool)
	skippedUsersHourlyMap := make(map[int]bool)
	for _, art := range articles {
		userID := ownerOf(art)

		// Daily quota: already published today plus admitted in this run.
		if maxDaily > 0 && userDailyPublished[userID]+admitted[userID] >= maxDaily {
			if !skippedUsersDailyMap[userID] {
				// Log once per user to keep the output readable.
				log.Printf("[频控] 用户 %d 今日已发布 %d 篇,本轮已计划 %d 篇,达到每日上限 %d,跳过后续文案",
					userID, userDailyPublished[userID], admitted[userID], maxDaily)
				skippedUsersDailyMap[userID] = true
			}
			continue
		}

		// Hourly quota: already published this hour plus admitted in this run.
		if maxHourly > 0 && userHourlyPublished[userID]+admitted[userID] >= maxHourly {
			if !skippedUsersHourlyMap[userID] {
				log.Printf("[频控] 用户 %d 当前小时已发布 %d 篇,本轮已计划 %d 篇,达到每小时上限 %d,跳过后续文案",
					userID, userHourlyPublished[userID], admitted[userID], maxHourly)
				skippedUsersHourlyMap[userID] = true
			}
			continue
		}

		admitted[userID]++
		filtered = append(filtered, art)
	}
	return filtered
}
// AutoPublishArticles is the body of the scheduled job. It loads every article
// with status "published_review", applies the per-run, daily and hourly
// per-user limits, and publishes the remaining articles concurrently, bounded
// by the publishSem semaphore.
//
// Each publish is preceded by a random 3-10 second delay. Once a user has
// accumulated the configured number of failures in this run, the user is
// paused: none of that user's remaining articles are published in this run.
// The pause is checked after a semaphore slot has been obtained and again
// after the delay, because both waits can be long enough for the failure
// limit to be reached in the meantime. The method returns after all started
// publishes have finished and a summary has been logged.
func (s *SchedulerService) AutoPublishArticles() {
	log.Println("========== 开始执行定时发布任务 ==========")
	startTime := time.Now()

	// Load every article waiting to be published.
	var articles []models.Article
	if err := database.DB.Where("status = ?", "published_review").Find(&articles).Error; err != nil {
		log.Printf("查询待发布文案失败: %v", err)
		return
	}
	if len(articles) == 0 {
		log.Println("没有待发布的文案")
		return
	}

	// Per-run cap per user.
	originalTotal := len(articles)
	perUserLimit := config.AppConfig.Scheduler.MaxArticlesPerUserPerRun
	if perUserLimit <= 0 {
		perUserLimit = defaultMaxArticlesPerUserPerRun
	}
	articles = limitArticlesPerUserPerRun(articles, perUserLimit)
	log.Printf("找到 %d 篇待发布文案,按照每个用户每轮最多 %d 篇,本次计划发布 %d 篇", originalTotal, perUserLimit, len(articles))

	// Daily / hourly caps per user.
	maxDaily := config.AppConfig.Scheduler.MaxDailyArticlesPerUser
	maxHourly := config.AppConfig.Scheduler.MaxHourlyArticlesPerUser
	if maxDaily > 0 || maxHourly > 0 {
		beforeFilterCount := len(articles)
		articles = s.filterByDailyAndHourlyLimit(articles, maxDaily, maxHourly)
		log.Printf("应用每日/每小时上限过滤:过滤前 %d 篇,过滤后 %d 篇", beforeFilterCount, len(articles))
	}
	if len(articles) == 0 {
		log.Println("所有文案均因频率限制被过滤,本轮无任务")
		return
	}

	failLimit := config.AppConfig.Scheduler.MaxFailuresPerUserPerRun
	if failLimit <= 0 {
		failLimit = defaultMaxFailuresPerUserPerRun
	}

	var (
		wg            sync.WaitGroup
		mu            sync.Mutex // guards the counters and maps below
		successCount  int
		failCount     int
		userFailCount = make(map[int]int)
		pausedUsers   = make(map[int]bool)
	)

	isPaused := func(uid int) bool {
		mu.Lock()
		defer mu.Unlock()
		return pausedUsers[uid]
	}

	for _, article := range articles {
		userID := article.CreatedUserID
		if article.PublishUserID != nil {
			userID = *article.PublishUserID
		}

		// Acquire a slot first: this may block while earlier publishes run,
		// and their failures must be visible to the pause check below.
		s.publishSem <- struct{}{}
		if isPaused(userID) {
			<-s.publishSem
			log.Printf("用户 %d 在本轮已暂停,跳过文案 ID: %d", userID, article.ID)
			continue
		}

		wg.Add(1)
		go func(art models.Article, uid int) {
			defer wg.Done()
			defer func() { <-s.publishSem }()

			// Random 3-10s jitter before each publish.
			sleepSeconds := 3 + rand.Intn(8)
			time.Sleep(time.Duration(sleepSeconds) * time.Second)

			// The user may have hit the failure limit during the jitter.
			if isPaused(uid) {
				log.Printf("用户 %d 在本轮已暂停,跳过文案 ID: %d", uid, art.ID)
				return
			}

			err := s.publishArticle(art)

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				failCount++
				userFailCount[uid]++
				if userFailCount[uid] >= failLimit && !pausedUsers[uid] {
					pausedUsers[uid] = true
					log.Printf("用户 %d 在本轮定时任务中失败次数达到 %d 次,暂停本轮后续发布", uid, userFailCount[uid])
				}
				log.Printf("发布失败 [文案ID: %d, 标题: %s]: %v", art.ID, art.Title, err)
				return
			}
			successCount++
			log.Printf("发布成功 [文案ID: %d, 标题: %s]", art.ID, art.Title)
		}(article, userID)
	}

	// Wait for every started publish to finish.
	wg.Wait()

	duration := time.Since(startTime)
	log.Printf("========== 定时发布任务完成 ==========")
	log.Printf("总计: %d 篇, 成功: %d 篇, 失败: %d 篇, 耗时: %v",
		len(articles), successCount, failCount, duration)
}
// decodeXHSCookie parses the credential JSON stored on an author record.
//
// It returns the decoded value and whether it is a login_state object (a JSON
// object that has a "cookies" key). A JSON object without that key and a JSON
// array are both accepted as plain cookies. Any other JSON value, or invalid
// JSON, yields an error.
func decodeXHSCookie(raw string) (interface{}, bool, error) {
	var data interface{}
	if err := json.Unmarshal([]byte(raw), &data); err != nil {
		return nil, false, err
	}
	switch v := data.(type) {
	case map[string]interface{}:
		_, isLoginState := v["cookies"]
		return v, isLoginState, nil
	case []interface{}:
		return v, false, nil
	default:
		return nil, false, fmt.Errorf("不支持的JSON类型 %T", data)
	}
}

// publishArticle publishes one article through the FastAPI service and
// records the outcome on the article.
//
// Credentials are chosen in this order: a storage_state file for the author's
// phone number if one exists on disk, otherwise the author's stored cookie
// JSON, sent as "login_state" when it is an object containing a "cookies" key
// and as "cookies" otherwise (object or array).
//
// Failures before the HTTP call (user or author lookup, missing binding,
// empty or unparsable cookie, image lookup, request encoding) are returned
// without touching the article status. A failed call, an undecodable response
// or a non-zero response code marks the article "failed" with the error text;
// success marks it "published".
func (s *SchedulerService) publishArticle(article models.Article) error {
	// 1. Load the publishing user, falling back to the creator.
	var user models.User
	if article.PublishUserID != nil {
		if err := database.DB.First(&user, *article.PublishUserID).Error; err != nil {
			return fmt.Errorf("获取发布用户信息失败: %w", err)
		}
	} else {
		if err := database.DB.First(&user, article.CreatedUserID).Error; err != nil {
			return fmt.Errorf("获取创建用户信息失败: %w", err)
		}
	}

	// 2. The user must be bound to an account and have an active author
	// record carrying the cookie.
	if user.IsBoundXHS != 1 {
		return errors.New("用户未绑定小红书账号")
	}
	var author models.Author
	if err := database.DB.Where(
		"phone = ? AND enterprise_id = ? AND channel = 1 AND status = 'active'",
		user.Phone, user.EnterpriseID,
	).First(&author).Error; err != nil {
		return fmt.Errorf("未找到有效的小红书作者记录: %w", err)
	}
	if author.XHSCookie == "" {
		return errors.New("小红书Cookie已失效")
	}

	// 3. Load the article images in display order.
	var articleImages []models.ArticleImage
	if err := database.DB.Where("article_id = ?", article.ID).
		Order("sort_order ASC").
		Find(&articleImages).Error; err != nil {
		return fmt.Errorf("获取文章图片失败: %w", err)
	}

	// 4. Collect the non-empty image URLs.
	var imageURLs []string
	for _, img := range articleImages {
		if img.ImageURL != "" {
			imageURLs = append(imageURLs, img.ImageURL)
		}
	}

	// 5. Load the tags; a missing tag row is not an error.
	var tags []string
	var articleTag models.ArticleTag
	if err := database.DB.Where("article_id = ?", article.ID).First(&articleTag).Error; err == nil {
		if articleTag.CozeTag != "" {
			tags = parseTags(articleTag.CozeTag)
		}
	}

	// 6. Pick the credential source: storage_state file first, then the
	// stored cookie JSON.
	var (
		cookiesData         interface{}
		hasLoginState       bool
		useStorageStateMode bool
	)
	storageStateFile := fmt.Sprintf("../backend/storage_states/xhs_%s.json", author.XHSPhone)
	if _, err := os.Stat(storageStateFile); err == nil {
		log.Printf("[调度器] 检测到storage_state文件: %s", storageStateFile)
		useStorageStateMode = true
	} else {
		log.Printf("[调度器] storage_state文件不存在使用login_state或cookies模式")
		data, isLoginState, err := decodeXHSCookie(author.XHSCookie)
		if err != nil {
			// Bound the preview: slicing [:100] unconditionally panics on
			// cookies shorter than 100 bytes.
			// NOTE(review): this puts part of a session credential into the
			// error text, which the caller logs — consider dropping it.
			preview := author.XHSCookie
			if len(preview) > 100 {
				preview = preview[:100]
			}
			return fmt.Errorf("解析Cookie失败: %wCookie内容: %s", err, preview)
		}
		cookiesData, hasLoginState = data, isLoginState
		if hasLoginState {
			log.Printf("[调度器] 检测到login_state格式将使用完整登录状态")
		} else {
			log.Printf("[调度器] 检测到纯cookies格式")
		}
	}

	// 7. Build the request for the FastAPI service.
	fastAPIURL := config.AppConfig.XHS.PythonServiceURL
	if fastAPIURL == "" {
		fastAPIURL = "http://localhost:8000" // default address
	}
	publishEndpoint := fastAPIURL + "/api/xhs/publish-with-cookies"

	fullRequest := map[string]interface{}{
		"title":   article.Title,
		"content": article.Content,
		"images":  imageURLs,
		"topics":  tags,
	}
	switch {
	case useStorageStateMode:
		// Mode 1: only the phone number is sent.
		// NOTE(review): presumably the Python side locates the storage_state
		// file from it — confirm.
		fullRequest["phone"] = author.XHSPhone
		log.Printf("[调度器] 使用storage_state模式发布手机号: %s", author.XHSPhone)
	case hasLoginState:
		// Mode 2: full login_state object.
		fullRequest["login_state"] = cookiesData
		log.Printf("[调度器] 使用login_state模式发布")
	default:
		// Mode 3: plain cookies (object or array).
		fullRequest["cookies"] = cookiesData
		log.Printf("[调度器] 使用cookies模式发布")
	}

	requestBody, err := json.Marshal(fullRequest)
	if err != nil {
		return fmt.Errorf("构造请求数据失败: %w", err)
	}

	// 8. Send the request.
	timeout := time.Duration(s.publishTimeout) * time.Second
	if s.publishTimeout <= 0 {
		timeout = 120 * time.Second // default timeout
	}
	client := &http.Client{Timeout: timeout}
	resp, err := client.Post(publishEndpoint, "application/json", bytes.NewBuffer(requestBody))
	if err != nil {
		s.updateArticleStatus(article.ID, "failed", fmt.Sprintf("调用FastAPI服务失败: %v", err))
		return fmt.Errorf("调用FastAPI服务失败: %w", err)
	}
	defer resp.Body.Close()

	// 9. Decode the response.
	var result map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		s.updateArticleStatus(article.ID, "failed", fmt.Sprintf("解析FastAPI响应失败: %v", err))
		return fmt.Errorf("解析FastAPI响应失败: %w", err)
	}

	// 10. A numeric code of 0 means success; anything else is a failure.
	code, ok := result["code"].(float64)
	if !ok || code != 0 {
		errMsg := "未知错误"
		if msg, ok := result["message"].(string); ok {
			errMsg = msg
		}
		s.updateArticleStatus(article.ID, "failed", errMsg)
		return fmt.Errorf("发布失败: %s", errMsg)
	}

	// 11. Mark the article as published.
	s.updateArticleStatus(article.ID, "published", "发布成功")
	log.Printf("[使用FastAPI] 文章 %d 发布成功,享受浏览器池+预热加速", article.ID)
	return nil
}
// updateArticleStatus writes a new status for the article with the given ID.
//
// When status is "published" the publish_time column is set to the current
// time as well; a non-empty message is stored in review_comment. A database
// error is logged and not returned.
func (s *SchedulerService) updateArticleStatus(articleID int, status, message string) {
	fields := make(map[string]interface{}, 3)
	fields["status"] = status
	if status == "published" {
		fields["publish_time"] = time.Now()
	}
	if message != "" {
		fields["review_comment"] = message
	}

	result := database.DB.Model(&models.Article{}).Where("id = ?", articleID).Updates(fields)
	if result.Error != nil {
		log.Printf("更新文章%d状态失败: %v", articleID, result.Error)
	}
}
// parseTags splits a tag string into individual tags.
//
// Tags may be separated by ASCII or full-width commas and semicolons, the
// ideographic comma, or whitespace (space, tab, newline, carriage return,
// ideographic space). Surrounding whitespace is trimmed and empty entries are
// dropped. It returns nil when the input contains no tags.
func parseTags(tagStr string) []string {
	isSeparator := func(r rune) bool {
		switch r {
		case ',', ';', ' ', '\t', '\n', '\r',
			'\u3001', // ideographic comma
			'\uFF0C', // full-width comma
			'\uFF1B', // full-width semicolon
			'\u3000': // ideographic space
			return true
		}
		return false
	}

	// Keep the result nil (not an empty slice) when nothing is found.
	var tags []string
	for _, field := range strings.FieldsFunc(tagStr, isSeparator) {
		if tag := strings.TrimSpace(field); tag != "" {
			tags = append(tags, tag)
		}
	}
	return tags
}