Files
ai_wht_wechat/go_backend/service/scheduler_service.go
2025-12-19 22:36:48 +08:00

562 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"ai_xhs/config"
"ai_xhs/database"
"ai_xhs/models"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
"github.com/robfig/cron/v3"
)
// SchedulerService runs the auto-publish job on a cron schedule and bounds
// how many articles are published at the same time.
type SchedulerService struct {
// cron is the scheduler that triggers AutoPublishArticles.
cron *cron.Cron
// maxConcurrent is the configured number of simultaneous publishes; it is the capacity of publishSem.
maxConcurrent int
// publishTimeout is the per-article publish timeout in seconds; values <= 0 disable the timeout.
publishTimeout int
publishSem chan struct{} // semaphore that caps the number of concurrent publishes
}
// NewSchedulerService creates a scheduler service.
//
// maxConcurrent is the maximum number of articles published at the same time
// and publishTimeout is the per-article timeout in seconds (<= 0 disables it).
//
// A maxConcurrent below 1 is raised to 1. The value is used as the capacity of
// the publish semaphore: a capacity of 0 would make every acquire block
// forever (nothing ever receives before a slot is taken), and a negative
// capacity makes make() panic.
func NewSchedulerService(maxConcurrent, publishTimeout int) *SchedulerService {
	if maxConcurrent < 1 {
		log.Printf("[警告] maxConcurrent=%d 无效, 已调整为 1", maxConcurrent)
		maxConcurrent = 1
	}

	return &SchedulerService{
		// WithSeconds enables 6-field cron expressions
		// (second minute hour day-of-month month day-of-week).
		cron:           cron.New(cron.WithSeconds()),
		maxConcurrent:  maxConcurrent,
		publishTimeout: publishTimeout,
		publishSem:     make(chan struct{}, maxConcurrent),
	}
}
// Start registers AutoPublishArticles under cronExpr and starts the cron
// scheduler. It returns an error when the expression cannot be registered.
func (s *SchedulerService) Start(cronExpr string) error {
	// Register the job first; the scheduler is only started once that succeeded.
	if _, addErr := s.cron.AddFunc(cronExpr, s.AutoPublishArticles); addErr != nil {
		return fmt.Errorf("添加定时任务失败: %w", addErr)
	}

	s.cron.Start()
	log.Printf("定时发布任务已启动Cron表达式: %s", cronExpr)

	return nil
}
// Stop stops the cron scheduler and logs that the publish job has stopped.
// The value returned by the scheduler's Stop call is not used here.
func (s *SchedulerService) Stop() {
	s.cron.Stop()

	log.Println("定时发布任务已停止")
}
// Fallback per-run limits, used by AutoPublishArticles when the corresponding
// scheduler configuration value is <= 0.
const (
// defaultMaxArticlesPerUserPerRun is the fallback cap on articles selected per user in a single run.
defaultMaxArticlesPerUserPerRun = 5
// defaultMaxFailuresPerUserPerRun is the fallback number of failures after which a user is paused for the rest of the run.
defaultMaxFailuresPerUserPerRun = 3
)
// fetchProxyFromPool asks the configured proxy-pool endpoint for a proxy and
// returns it as a URL such as http://ip:port.
//
// It returns ("", nil) when no ProxyFetchURL is configured. Only the first
// line of the response is used; "http://" is prepended when the line carries
// no http:// or https:// scheme. An error is returned when the request fails,
// the status is not 200, or the body is blank.
func fetchProxyFromPool() (string, error) {
	// The expected reply is a short "ip:port" list. Capping the read keeps a
	// misbehaving endpoint from making us buffer an arbitrarily large body.
	const maxProxyResponseBytes = 64 << 10

	poolURL := config.AppConfig.Scheduler.ProxyFetchURL
	if poolURL == "" {
		return "", nil
	}

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(poolURL)
	if err != nil {
		return "", fmt.Errorf("请求代理池接口失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("代理池接口返回非200状态码: %d", resp.StatusCode)
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxProxyResponseBytes))
	if err != nil {
		return "", fmt.Errorf("读取代理池响应失败: %w", err)
	}

	content := strings.TrimSpace(string(body))
	if content == "" {
		return "", errors.New("代理池返回内容为空")
	}

	// Multi-line replies are allowed; take the first line only. Because
	// content is already trimmed, its first line always starts with a
	// non-space character and therefore cannot be empty.
	line := content
	if nl := strings.IndexByte(content, '\n'); nl >= 0 {
		line = strings.TrimSpace(content[:nl])
	}

	// Keep an explicit scheme; otherwise default to plain HTTP.
	if strings.HasPrefix(line, "http://") || strings.HasPrefix(line, "https://") {
		return line, nil
	}
	return "http://" + line, nil
}
// limitArticlesPerUserPerRun keeps at most perUserLimit articles per user and
// returns them in their original input order.
//
// An article belongs to its PublishUserID when set, otherwise to its
// CreatedUserID. For each user the first perUserLimit articles (in input
// order) are kept. A perUserLimit <= 0 disables the limit and returns the
// input slice unchanged.
//
// The selection is done in one pass with a per-user counter instead of
// grouping into a map and ranging over it, so the result order is
// deterministic rather than depending on map iteration order.
func limitArticlesPerUserPerRun(articles []models.Article, perUserLimit int) []models.Article {
	if perUserLimit <= 0 {
		return articles
	}

	keptPerUser := make(map[int]int)
	limited := make([]models.Article, 0, len(articles))
	for _, art := range articles {
		userID := art.CreatedUserID
		if art.PublishUserID != nil {
			userID = *art.PublishUserID
		}

		if keptPerUser[userID] >= perUserLimit {
			continue
		}
		keptPerUser[userID]++
		limited = append(limited, art)
	}
	return limited
}
// filterByDailyAndHourlyLimit drops articles whose user has reached the daily
// or hourly publish cap.
//
// An article belongs to its PublishUserID when set, otherwise to its
// CreatedUserID. For every user involved, the number of articles with status
// "published" since local midnight and since the start of the current hour is
// counted in the database. A limit <= 0 disables that check; when both are
// disabled the input is returned unchanged.
//
// Articles accepted earlier in this same call count against the quota too.
// Without that, a user one article below the cap would have every candidate
// of the run accepted and end up above the cap.
//
// If a count query fails, a warning is logged and the user's published count
// is treated as 0 for that window (best effort; the run is not aborted).
func (s *SchedulerService) filterByDailyAndHourlyLimit(articles []models.Article, maxDaily, maxHourly int) []models.Article {
	if maxDaily <= 0 && maxHourly <= 0 {
		return articles
	}

	ownerOf := func(art models.Article) int {
		if art.PublishUserID != nil {
			return *art.PublishUserID
		}
		return art.CreatedUserID
	}

	// countPublishedSince returns how many articles of userID were published at or after since.
	countPublishedSince := func(userID int, since time.Time) (int, error) {
		var n int64
		err := database.DB.Model(&models.Article{}).
			Where("status = ? AND publish_time >= ? AND (publish_user_id = ? OR (publish_user_id IS NULL AND created_user_id = ?))",
				"published", since, userID, userID).
			Count(&n).Error
		return int(n), err
	}

	// Collect the distinct users that appear in this batch.
	userIDs := make(map[int]bool)
	for _, art := range articles {
		userIDs[ownerOf(art)] = true
	}

	now := time.Now()
	todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
	currentHourStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())

	userDailyPublished := make(map[int]int)
	userHourlyPublished := make(map[int]int)
	for userID := range userIDs {
		if maxDaily > 0 {
			if n, err := countPublishedSince(userID, todayStart); err != nil {
				log.Printf("[警告] 查询用户 %d 当日已发布数量失败: %v", userID, err)
			} else {
				userDailyPublished[userID] = n
			}
		}
		if maxHourly > 0 {
			if n, err := countPublishedSince(userID, currentHourStart); err != nil {
				log.Printf("[警告] 查询用户 %d 当前小时已发布数量失败: %v", userID, err)
			} else {
				userHourlyPublished[userID] = n
			}
		}
	}

	filtered := make([]models.Article, 0, len(articles))
	acceptedThisRun := make(map[int]int) // articles already let through in this call, per user
	loggedDaily := make(map[int]bool)    // log each capped user only once per window
	loggedHourly := make(map[int]bool)
	for _, art := range articles {
		userID := ownerOf(art)
		planned := acceptedThisRun[userID]

		// The logged count is published + accepted-in-this-run, i.e. the
		// figure that was actually compared against the cap.
		if dailyTotal := userDailyPublished[userID] + planned; maxDaily > 0 && dailyTotal >= maxDaily {
			if !loggedDaily[userID] {
				log.Printf("[频控] 用户 %d 今日已发布 %d 篇,达到每日上限 %d跳过后续文案", userID, dailyTotal, maxDaily)
				loggedDaily[userID] = true
			}
			continue
		}

		if hourlyTotal := userHourlyPublished[userID] + planned; maxHourly > 0 && hourlyTotal >= maxHourly {
			if !loggedHourly[userID] {
				log.Printf("[频控] 用户 %d 当前小时已发布 %d 篇,达到每小时上限 %d跳过后续文案", userID, hourlyTotal, maxHourly)
				loggedHourly[userID] = true
			}
			continue
		}

		acceptedThisRun[userID]++
		filtered = append(filtered, art)
	}
	return filtered
}
// AutoPublishArticles is the cron job body: it publishes every article whose
// status is "published_review".
//
// The candidate list is first capped per user for this run
// (MaxArticlesPerUserPerRun), then filtered by the per-user daily/hourly
// limits when configured. The remaining articles are published concurrently,
// bounded by publishSem. Each publish is preceded by a random 3-10 second
// delay. When a user accumulates MaxFailuresPerUserPerRun failures in this
// run, the user's articles that have not been started yet are skipped.
//
// NOTE(review): nothing visible here marks an article as "in progress", so two
// overlapping runs would select the same rows. Confirm the cron interval is
// longer than a run, or wrap the job so overlapping runs are skipped.
func (s *SchedulerService) AutoPublishArticles() {
	log.Println("========== 开始执行定时发布任务 ==========")
	startTime := time.Now()

	// Load everything that is waiting to be published.
	var articles []models.Article
	if err := database.DB.Where("status = ?", "published_review").Find(&articles).Error; err != nil {
		log.Printf("查询待发布文案失败: %v", err)
		return
	}
	if len(articles) == 0 {
		log.Println("没有待发布的文案")
		return
	}

	originalTotal := len(articles)
	perUserLimit := config.AppConfig.Scheduler.MaxArticlesPerUserPerRun
	if perUserLimit <= 0 {
		perUserLimit = defaultMaxArticlesPerUserPerRun
	}
	articles = limitArticlesPerUserPerRun(articles, perUserLimit)
	log.Printf("找到 %d 篇待发布文案,按照每个用户每轮最多 %d 篇,本次计划发布 %d 篇", originalTotal, perUserLimit, len(articles))

	// Apply the per-user daily / hourly caps when at least one is configured.
	maxDaily := config.AppConfig.Scheduler.MaxDailyArticlesPerUser
	maxHourly := config.AppConfig.Scheduler.MaxHourlyArticlesPerUser
	if maxDaily > 0 || maxHourly > 0 {
		beforeFilterCount := len(articles)
		articles = s.filterByDailyAndHourlyLimit(articles, maxDaily, maxHourly)
		log.Printf("应用每日/每小时上限过滤:过滤前 %d 篇,过滤后 %d 篇", beforeFilterCount, len(articles))
	}
	if len(articles) == 0 {
		log.Println("所有文案均因频率限制被过滤,本轮无任务")
		return
	}

	failLimit := config.AppConfig.Scheduler.MaxFailuresPerUserPerRun
	if failLimit <= 0 {
		failLimit = defaultMaxFailuresPerUserPerRun
	}

	var (
		wg           sync.WaitGroup
		mu           sync.Mutex // guards the counters and maps below
		successCount int
		failCount    int
		skippedCount int // only touched by this goroutine
	)
	userFailCount := make(map[int]int)
	pausedUsers := make(map[int]bool)

	for _, article := range articles {
		userID := article.CreatedUserID
		if article.PublishUserID != nil {
			userID = *article.PublishUserID
		}

		// Take a slot first. This can block for a long time, and failures
		// recorded while we wait must be visible to the pause check below.
		// A worker records its result before it releases its slot.
		s.publishSem <- struct{}{}

		mu.Lock()
		paused := pausedUsers[userID]
		mu.Unlock()
		if paused {
			<-s.publishSem // give the unused slot back
			skippedCount++
			log.Printf("用户 %d 在本轮已暂停,跳过文案 ID: %d", userID, article.ID)
			continue
		}

		wg.Add(1)
		go func(art models.Article, uid int) {
			defer wg.Done()
			defer func() { <-s.publishSem }()

			// Random 3-10 second delay before each publish.
			sleepSeconds := 3 + rand.Intn(8)
			time.Sleep(time.Duration(sleepSeconds) * time.Second)

			err := s.publishArticle(art)

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				failCount++
				userFailCount[uid]++
				if userFailCount[uid] >= failLimit && !pausedUsers[uid] {
					pausedUsers[uid] = true
					log.Printf("用户 %d 在本轮定时任务中失败次数达到 %d 次,暂停本轮后续发布", uid, userFailCount[uid])
				}
				log.Printf("发布失败 [文案ID: %d, 标题: %s]: %v", art.ID, art.Title, err)
				return
			}
			successCount++
			log.Printf("发布成功 [文案ID: %d, 标题: %s]", art.ID, art.Title)
		}(article, userID)
	}

	// Wait for every started publish to finish before reporting.
	wg.Wait()

	duration := time.Since(startTime)
	log.Printf("========== 定时发布任务完成 ==========")
	log.Printf("总计: %d 篇, 成功: %d 篇, 失败: %d 篇, 耗时: %v",
		len(articles), successCount, failCount, duration)
	if skippedCount > 0 {
		// Explains why success + failure can be lower than the total above.
		log.Printf("因用户本轮暂停而跳过: %d 篇", skippedCount)
	}
}
// publishArticle publishes one article by running the Python publish script.
//
// It loads the publishing user (PublishUserID, else CreatedUserID), requires
// a bound XHS account with a cookie, collects the article's images and tags,
// writes a temporary JSON config, runs xhs_publish.py with it and reads the
// result from the last JSON object line on the script's stdout.
//
// On script failure, missing JSON result, or a result without success=true
// the article status is set to "failed" and an error is returned. On success
// the status is set to "published" and nil is returned. Errors that occur
// before the script is started are returned without changing the status.
func (s *SchedulerService) publishArticle(article models.Article) error {
	// 1. Load the publishing user; fall back to the creator when none is set.
	var user models.User
	if article.PublishUserID != nil {
		if err := database.DB.First(&user, *article.PublishUserID).Error; err != nil {
			return fmt.Errorf("获取发布用户信息失败: %w", err)
		}
	} else if err := database.DB.First(&user, article.CreatedUserID).Error; err != nil {
		return fmt.Errorf("获取创建用户信息失败: %w", err)
	}

	// 2. The user must have a bound XHS account and a stored cookie.
	if user.IsBoundXHS != 1 || user.XHSCookie == "" {
		return errors.New("用户未绑定小红书账号或Cookie已失效")
	}

	// 3. Load the article images in display order.
	var articleImages []models.ArticleImage
	if err := database.DB.Where("article_id = ?", article.ID).
		Order("sort_order ASC").
		Find(&articleImages).Error; err != nil {
		return fmt.Errorf("获取文章图片失败: %w", err)
	}

	// 4. Keep only the non-empty image URLs.
	var imageURLs []string
	for _, img := range articleImages {
		if img.ImageURL != "" {
			imageURLs = append(imageURLs, img.ImageURL)
		}
	}

	// 5. Tags are optional: a failed lookup simply means no tags.
	var tags []string
	var articleTag models.ArticleTag
	if err := database.DB.Where("article_id = ?", article.ID).First(&articleTag).Error; err == nil {
		if articleTag.CozeTag != "" {
			tags = parseTags(articleTag.CozeTag)
		}
	}

	// 6. The cookie is stored as a JSON string. The raw value is a login
	// credential, so it is deliberately NOT included in the error: this
	// error is written to the log by the caller.
	var cookies interface{}
	if err := json.Unmarshal([]byte(user.XHSCookie), &cookies); err != nil {
		return fmt.Errorf("解析Cookie失败: %w", err)
	}

	// 7. Build the config consumed by the publish script.
	publishConfig := map[string]interface{}{
		"cookies": cookies, // decoded cookie object or array
		"title":   article.Title,
		"content": article.Content,
		"images":  imageURLs,
		"tags":    tags,
	}

	// A statically configured proxy wins; otherwise try the proxy pool.
	// A pool failure is logged and the publish continues without a proxy.
	proxyToUse := config.AppConfig.Scheduler.Proxy
	if proxyToUse == "" && config.AppConfig.Scheduler.ProxyFetchURL != "" {
		if dynamicProxy, err := fetchProxyFromPool(); err != nil {
			log.Printf("[代理池] 获取代理失败: %v", err)
		} else if dynamicProxy != "" {
			proxyToUse = dynamicProxy
			log.Printf("[代理池] 使用动态代理: %s", proxyToUse)
		}
	}
	if proxyToUse != "" {
		publishConfig["proxy"] = proxyToUse
	}
	if ua := config.AppConfig.Scheduler.UserAgent; ua != "" {
		publishConfig["user_agent"] = ua
	}

	// 8. Write the temporary config file.
	tempDir := filepath.Join("..", "backend", "temp")
	if err := os.MkdirAll(tempDir, 0755); err != nil {
		return fmt.Errorf("创建临时目录失败: %w", err)
	}
	configFile := filepath.Join(tempDir, fmt.Sprintf("publish_%d_%d.json", article.ID, time.Now().Unix()))
	configData, err := json.MarshalIndent(publishConfig, "", " ")
	if err != nil {
		return fmt.Errorf("生成配置文件失败: %w", err)
	}
	// 0600: the file contains the user's cookies.
	if err := os.WriteFile(configFile, configData, 0600); err != nil {
		return fmt.Errorf("保存配置文件失败: %w", err)
	}
	defer os.Remove(configFile) // remove the temp file once publishing is done

	// 9. Run the Python publish script from the backend directory.
	backendDir := filepath.Join("..", "backend")
	pythonScript := filepath.Join(backendDir, "xhs_publish.py")
	pythonCmd := getPythonPath(backendDir)
	cmd := exec.Command(pythonCmd, pythonScript, "--config", configFile)
	cmd.Dir = backendDir

	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	// Start first and arm the timeout afterwards: cmd.Process is only set
	// once Start has succeeded, so the timer callback can never see a nil
	// process or race with process creation.
	timedOut := make(chan struct{})
	err = cmd.Start()
	if err == nil {
		var timer *time.Timer
		if s.publishTimeout > 0 {
			timer = time.AfterFunc(time.Duration(s.publishTimeout)*time.Second, func() {
				close(timedOut)
				_ = cmd.Process.Kill() // fails harmlessly if the process already exited
			})
		}
		err = cmd.Wait()
		if timer != nil {
			timer.Stop()
		}
	}

	// The script logs to stderr; surface it regardless of the outcome.
	if stderr.Len() > 0 {
		log.Printf("[Python日志-发布文案%d] %s", article.ID, stderr.String())
	}

	if err != nil {
		select {
		case <-timedOut:
			// Make a kill caused by our own timeout recognizable.
			err = fmt.Errorf("发布超时(%d秒): %w", s.publishTimeout, err)
		default:
		}
		s.updateArticleStatus(article.ID, "failed", fmt.Sprintf("发布失败: %v", err))
		return fmt.Errorf("执行Python脚本失败: %w, stderr: %s", err, stderr.String())
	}

	// 10. The script may print log lines to stdout as well, so scan from the
	// last line backwards for the first line that parses as a JSON object.
	outputStr := stdout.String()
	var result map[string]interface{}
	found := false
	lines := strings.Split(strings.TrimSpace(outputStr), "\n")
	for i := len(lines) - 1; i >= 0; i-- {
		line := strings.TrimSpace(lines[i])
		if !strings.HasPrefix(line, "{") {
			continue
		}
		// Decode into a fresh map per attempt so a failed attempt cannot
		// leave partial keys behind in the final result.
		var candidate map[string]interface{}
		if err := json.Unmarshal([]byte(line), &candidate); err == nil {
			result = candidate
			found = true
			log.Printf("成功解析JSON结果(第%d行): %s", i+1, line)
			break
		}
	}
	if !found {
		errMsg := "Python脚本未返回有效JSON结果"
		s.updateArticleStatus(article.ID, "failed", errMsg)
		log.Printf("完整输出内容:\n%s", outputStr)
		if stderr.Len() > 0 {
			log.Printf("错误输出:\n%s", stderr.String())
		}
		return fmt.Errorf("%s, output: %s", errMsg, outputStr)
	}

	// 11. Anything other than success == true counts as a failure.
	if success, ok := result["success"].(bool); !ok || !success {
		errMsg := "未知错误"
		if errStr, ok := result["error"].(string); ok {
			errMsg = errStr
		}
		s.updateArticleStatus(article.ID, "failed", errMsg)
		return fmt.Errorf("发布失败: %s", errMsg)
	}

	// 12. Mark the article as published.
	s.updateArticleStatus(article.ID, "published", "发布成功")
	return nil
}
// updateArticleStatus writes a new status for the article with the given ID.
//
// A non-empty message is stored in review_comment, and publish_time is set to
// the current time when the new status is "published". A failed update is
// logged and not returned to the caller.
func (s *SchedulerService) updateArticleStatus(articleID int, status, message string) {
	fields := map[string]interface{}{"status": status}
	if status == "published" {
		fields["publish_time"] = time.Now()
	}
	if message != "" {
		fields["review_comment"] = message
	}

	result := database.DB.Model(&models.Article{}).Where("id = ?", articleID).Updates(fields)
	if result.Error != nil {
		log.Printf("更新文章%d状态失败: %v", articleID, result.Error)
	}
}
// parseTags splits a tag string into individual tags.
//
// Separators are the ASCII comma, semicolon and space, tab, newline and
// carriage return, plus the CJK forms: ideographic comma (U+3001), full-width
// comma (U+FF0C), full-width semicolon (U+FF1B) and ideographic space
// (U+3000). Surrounding whitespace is trimmed and empty entries are dropped.
//
// It returns nil when the input contains no tags.
func parseTags(tagStr string) []string {
	isSeparator := func(r rune) bool {
		switch r {
		case ',', ';', ' ', '\t', '\n', '\r',
			'\u3001', '\uFF0C', '\uFF1B', '\u3000':
			return true
		}
		return false
	}

	// tags stays nil unless something is appended, so "no tags" is still
	// reported as a nil slice.
	var tags []string
	for _, field := range strings.FieldsFunc(tagStr, isSeparator) {
		if tag := strings.TrimSpace(field); tag != "" {
			tags = append(tags, tag)
		}
	}
	return tags
}