package services

import (
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"time"

	"github.com/google/uuid"
	"gorm.io/gorm"

	"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/models"
)

// TestService implements the test workflow: template lookup, the session
// lifecycle (create, start, pause, resume, complete), answer submission and
// result/statistics queries. All persistence goes through db.
type TestService struct {
	db *gorm.DB
}

// NewTestService creates a TestService backed by the given database handle.
func NewTestService(db *gorm.DB) *TestService {
	return &TestService{db: db}
}

// GetTestTemplates returns one page of active test templates, newest first,
// optionally filtered by testType and difficulty (nil means "any"), together
// with the total number of matching templates.
//
// page is 1-based; values below 1 are treated as the first page.
func (s *TestService) GetTestTemplates(testType *models.TestType, difficulty *models.TestDifficulty, page, pageSize int) ([]models.TestTemplate, int64, error) {
	var templates []models.TestTemplate
	var total int64

	query := s.db.Model(&models.TestTemplate{}).Where("is_active = ?", true)
	if testType != nil {
		query = query.Where("type = ?", *testType)
	}
	if difficulty != nil {
		query = query.Where("difficulty = ?", *difficulty)
	}
	// The same conditions feed both Count and Find; Session makes the chain
	// safe to reuse so the Count cannot leak state into the page query.
	query = query.Session(&gorm.Session{})

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := 0
	if page > 1 {
		offset = (page - 1) * pageSize
	}
	if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&templates).Error; err != nil {
		return nil, 0, err
	}
	return templates, total, nil
}

// GetTestTemplateByID returns the active template with the given ID, or the
// database error (e.g. gorm.ErrRecordNotFound) if there is none.
func (s *TestService) GetTestTemplateByID(id string) (*models.TestTemplate, error) {
	var template models.TestTemplate
	if err := s.db.Where("id = ? AND is_active = ?", id, true).First(&template).Error; err != nil {
		return nil, err
	}
	return &template, nil
}

// CreateTestSession creates a pending session of template templateID for
// userID. The template's questions, ordered by order_index, are linked to the
// session in the same transaction: either the session and all its question
// links are stored, or nothing is.
//
// It fails if the template does not exist (or is inactive) or has no
// questions configured.
func (s *TestService) CreateTestSession(templateID, userID string) (*models.TestSession, error) {
	template, err := s.GetTestTemplateByID(templateID)
	if err != nil {
		return nil, fmt.Errorf("模板不存在: %w", err)
	}

	var questions []models.TestQuestion
	if err := s.db.Where("template_id = ?", templateID).Order("order_index ASC").Find(&questions).Error; err != nil {
		return nil, err
	}
	if len(questions) == 0 {
		return nil, errors.New("模板没有配置题目")
	}

	now := time.Now()
	session := &models.TestSession{
		ID:                   uuid.New().String(),
		TemplateID:           templateID,
		UserID:               userID,
		Status:               models.TestStatusPending,
		TimeRemaining:        template.Duration,
		CurrentQuestionIndex: 0,
		CreatedAt:            now,
		UpdatedAt:            now,
	}

	// The session's question order is fixed here, by position in the
	// template's ordered question list.
	links := make([]models.TestSessionQuestion, len(questions))
	for i, q := range questions {
		links[i] = models.TestSessionQuestion{
			SessionID:  session.ID,
			QuestionID: q.ID,
			OrderIndex: i,
		}
	}

	// Transaction commits when the callback returns nil and rolls back on an
	// error or a panic (re-raising the panic), so a failure can never leave a
	// half-created session or be reported as success.
	err = s.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(session).Error; err != nil {
			return err
		}
		return tx.Create(&links).Error
	})
	if err != nil {
		return nil, err
	}

	// Attach the already-loaded associations for the caller.
	session.Template = template
	session.Questions = questions
	return session, nil
}

// GetTestSession loads a session with its template and its answers (each with
// its question) preloaded, plus the session's questions in session order.
func (s *TestService) GetTestSession(sessionID string) (*models.TestSession, error) {
	var session models.TestSession
	if err := s.db.Preload("Template").Preload("Answers.Question").First(&session, "id = ?", sessionID).Error; err != nil {
		return nil, err
	}

	// Questions are linked through test_session_questions; its order_index
	// holds the order fixed when the session was created.
	var questions []models.TestQuestion
	if err := s.db.Raw(`
		SELECT q.* FROM test_questions q
		INNER JOIN test_session_questions sq ON q.id = sq.question_id
		WHERE sq.session_id = ?
		ORDER BY sq.order_index ASC
	`, sessionID).Scan(&questions).Error; err != nil {
		return nil, err
	}

	session.Questions = questions
	return &session, nil
}

// StartTest moves a pending or paused session to in-progress.
//
// StartTime is recorded the first time the session is started and is kept if
// a paused session is started again; PausedAt is cleared, as ResumeTest does.
func (s *TestService) StartTest(sessionID string) (*models.TestSession, error) {
	session, err := s.GetTestSession(sessionID)
	if err != nil {
		return nil, err
	}
	if session.Status != models.TestStatusPending && session.Status != models.TestStatusPaused {
		return nil, errors.New("测试状态不允许开始")
	}

	now := time.Now()
	session.Status = models.TestStatusInProgress
	if session.StartTime == nil {
		session.StartTime = &now
	}
	session.PausedAt = nil
	session.UpdatedAt = now

	if err := s.db.Save(session).Error; err != nil {
		return nil, err
	}
	return session, nil
}

// SubmitAnswer records the answer to questionID in an in-progress session,
// overwriting any earlier answer to the same question. The answer is graded
// immediately with checkAnswer: a correct answer scores the question's Points,
// an incorrect one scores 0. It returns the reloaded session.
//
// questionID must be one of the session's questions; otherwise an error
// wrapping gorm.ErrRecordNotFound is returned.
func (s *TestService) SubmitAnswer(sessionID, questionID, answer string) (*models.TestSession, error) {
	session, err := s.GetTestSession(sessionID)
	if err != nil {
		return nil, err
	}
	if session.Status != models.TestStatusInProgress {
		return nil, errors.New("测试未在进行中")
	}

	// Only questions linked to this session may be answered. The session
	// already carries them, so no separate question lookup is needed.
	var question *models.TestQuestion
	for i := range session.Questions {
		if session.Questions[i].ID == questionID {
			question = &session.Questions[i]
			break
		}
	}
	if question == nil {
		return nil, fmt.Errorf("题目不属于该测试会话: %w", gorm.ErrRecordNotFound)
	}

	isCorrect := s.checkAnswer(*question, answer)
	score := 0
	if isCorrect {
		score = question.Points
	}
	now := time.Now()

	var existing models.TestAnswer
	err = s.db.Where("session_id = ? AND question_id = ?", sessionID, questionID).First(&existing).Error
	switch {
	case err == nil:
		// Already answered: overwrite the previous answer and its grading.
		existing.Answer = answer
		existing.IsCorrect = &isCorrect
		existing.Score = score
		existing.UpdatedAt = now
		if err := s.db.Save(&existing).Error; err != nil {
			return nil, err
		}
	case errors.Is(err, gorm.ErrRecordNotFound):
		// First answer to this question.
		newAnswer := &models.TestAnswer{
			ID:         uuid.New().String(),
			SessionID:  sessionID,
			QuestionID: questionID,
			Answer:     answer,
			IsCorrect:  &isCorrect,
			Score:      score,
			CreatedAt:  now,
			UpdatedAt:  now,
		}
		if err := s.db.Create(newAnswer).Error; err != nil {
			return nil, err
		}
	default:
		// A real database failure: do not mistake it for "not answered yet".
		return nil, err
	}

	return s.GetTestSession(sessionID)
}

// checkAnswer reports whether answer is correct for question.
//
// Single-choice, true/false, fill-in-the-blank and short-answer questions use
// exact string comparison. For multiple-choice questions both answer and
// question.CorrectAnswer are decoded as JSON string arrays. They match when
// they hold the same options with the same multiplicity, in any order.
// Malformed JSON on either side counts as incorrect. Unknown question types
// are never correct.
func (s *TestService) checkAnswer(question models.TestQuestion, answer string) bool {
	switch question.QuestionType {
	case models.QuestionTypeSingleChoice, models.QuestionTypeTrueFalse:
		return answer == question.CorrectAnswer
	case models.QuestionTypeMultipleChoice:
		var userAnswers, correctAnswers []string
		if err := json.Unmarshal([]byte(answer), &userAnswers); err != nil {
			return false
		}
		if err := json.Unmarshal([]byte(question.CorrectAnswer), &correctAnswers); err != nil {
			return false
		}
		if len(userAnswers) != len(correctAnswers) {
			return false
		}
		// Compare sorted copies: order does not matter, but a duplicated
		// option can no longer stand in for a missing one.
		sort.Strings(userAnswers)
		sort.Strings(correctAnswers)
		for i := range userAnswers {
			if userAnswers[i] != correctAnswers[i] {
				return false
			}
		}
		return true
	case models.QuestionTypeFillBlank, models.QuestionTypeShortAnswer:
		// Exact match only; free-text grading may need fuzzier matching.
		return answer == question.CorrectAnswer
	default:
		return false
	}
}

// PauseTest pauses an in-progress session and records the pause time.
func (s *TestService) PauseTest(sessionID string) (*models.TestSession, error) {
	session, err := s.GetTestSession(sessionID)
	if err != nil {
		return nil, err
	}
	if session.Status != models.TestStatusInProgress {
		return nil, errors.New("测试未在进行中")
	}

	now := time.Now()
	session.Status = models.TestStatusPaused
	session.PausedAt = &now
	session.UpdatedAt = now

	if err := s.db.Save(session).Error; err != nil {
		return nil, err
	}
	return session, nil
}

// ResumeTest moves a paused session back to in-progress and clears PausedAt.
func (s *TestService) ResumeTest(sessionID string) (*models.TestSession, error) {
	session, err := s.GetTestSession(sessionID)
	if err != nil {
		return nil, err
	}
	if session.Status != models.TestStatusPaused {
		return nil, errors.New("测试未暂停")
	}

	now := time.Now()
	session.Status = models.TestStatusInProgress
	session.PausedAt = nil
	session.UpdatedAt = now

	if err := s.db.Save(session).Error; err != nil {
		return nil, err
	}
	return session, nil
}

// CompleteTest finishes an in-progress session. It aggregates the stored
// answers into a TestResult, then saves the completed session and the new
// result in one transaction. It returns the result with Session and Template
// attached.
//
// MaxScore is the sum of Points over every question in the session, so
// skipped questions lower the percentage instead of being ignored. Passed
// compares TotalScore against the template's PassingScore.
func (s *TestService) CompleteTest(sessionID string) (*models.TestResult, error) {
	session, err := s.GetTestSession(sessionID)
	if err != nil {
		return nil, err
	}
	if session.Status != models.TestStatusInProgress {
		return nil, errors.New("测试未在进行中")
	}

	now := time.Now()
	session.Status = models.TestStatusCompleted
	session.EndTime = &now
	session.UpdatedAt = now

	var answers []models.TestAnswer
	if err := s.db.Preload("Question").Where("session_id = ?", sessionID).Find(&answers).Error; err != nil {
		return nil, err
	}

	// The attainable maximum comes from the full question set, not from the
	// subset the user happened to answer.
	maxScore := 0
	for _, q := range session.Questions {
		maxScore += q.Points
	}

	totalScore, correctCount, wrongCount, timeSpent := 0, 0, 0, 0
	skillScores := make(map[models.SkillType]int)
	for _, a := range answers {
		if a.Question == nil {
			// Answer whose question could not be loaded: not graded.
			continue
		}
		timeSpent += a.TimeSpent
		if a.IsCorrect != nil && *a.IsCorrect {
			totalScore += a.Score
			correctCount++
			skillScores[a.Question.SkillType] += a.Score
		} else {
			wrongCount++
		}
	}

	skippedCount := len(session.Questions) - len(answers)
	if skippedCount < 0 {
		skippedCount = 0
	}

	percentage := 0.0
	if maxScore > 0 {
		percentage = float64(totalScore) / float64(maxScore) * 100
	}

	// Per-skill points earned, stored as a JSON object on the result.
	skillScoresJSON, err := json.Marshal(skillScores)
	if err != nil {
		return nil, fmt.Errorf("序列化技能得分失败: %w", err)
	}

	result := &models.TestResult{
		ID:           uuid.New().String(),
		SessionID:    sessionID,
		UserID:       session.UserID,
		TemplateID:   session.TemplateID,
		TotalScore:   totalScore,
		MaxScore:     maxScore,
		Percentage:   percentage,
		CorrectCount: correctCount,
		WrongCount:   wrongCount,
		SkippedCount: skippedCount,
		TimeSpent:    timeSpent,
		SkillScores:  string(skillScoresJSON),
		Passed:       totalScore >= session.Template.PassingScore,
		CompletedAt:  now,
		CreatedAt:    now,
		UpdatedAt:    now,
	}

	// Session status and result are written atomically. Transaction rolls
	// back on error or panic (re-raising the panic), so a failure is never
	// reported as success.
	err = s.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Save(session).Error; err != nil {
			return err
		}
		return tx.Create(result).Error
	})
	if err != nil {
		return nil, err
	}

	// Attach the already-loaded associations for the caller.
	result.Session = session
	result.Template = session.Template
	return result, nil
}

// GetUserTestHistory returns one page of userID's test results, most recently
// completed first, with each result's Template preloaded, plus the total
// number of results.
//
// page is 1-based; values below 1 are treated as the first page.
func (s *TestService) GetUserTestHistory(userID string, page, pageSize int) ([]models.TestResult, int64, error) {
	var results []models.TestResult
	var total int64

	// Session makes the chain safe to reuse for both Count and Find.
	query := s.db.Model(&models.TestResult{}).Where("user_id = ?", userID).Session(&gorm.Session{})

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := 0
	if page > 1 {
		offset = (page - 1) * pageSize
	}
	if err := query.Preload("Template").Order("completed_at DESC").Offset(offset).Limit(pageSize).Find(&results).Error; err != nil {
		return nil, 0, err
	}
	return results, total, nil
}

// GetTestResultByID returns a result with its Template and its session's
// answers (each with its question) preloaded.
func (s *TestService) GetTestResultByID(resultID string) (*models.TestResult, error) {
	var result models.TestResult
	if err := s.db.Preload("Template").Preload("Session.Answers.Question").First(&result, "id = ?", resultID).Error; err != nil {
		return nil, err
	}
	return &result, nil
}

// GetUserTestStats summarizes userID's testing activity. The returned map has
// these keys:
//   - total_tests: number of sessions
//   - completed_tests: number of completed sessions
//   - average_score: mean Percentage over the user's results
//   - pass_rate: share of passed results, in percent
//
// average_score and pass_rate are 0 when the user has no results. Any query
// failure is returned as an error.
func (s *TestService) GetUserTestStats(userID string) (map[string]interface{}, error) {
	var totalTests, completedTests int64

	if err := s.db.Model(&models.TestSession{}).Where("user_id = ?", userID).Count(&totalTests).Error; err != nil {
		return nil, err
	}
	if err := s.db.Model(&models.TestSession{}).Where("user_id = ? AND status = ?", userID, models.TestStatusCompleted).Count(&completedTests).Error; err != nil {
		return nil, err
	}

	var results []models.TestResult
	if err := s.db.Where("user_id = ?", userID).Find(&results).Error; err != nil {
		return nil, err
	}

	averageScore, passRate := 0.0, 0.0
	if len(results) > 0 {
		totalPercentage := 0.0
		passedCount := 0
		for _, r := range results {
			totalPercentage += r.Percentage
			if r.Passed {
				passedCount++
			}
		}
		averageScore = totalPercentage / float64(len(results))
		passRate = float64(passedCount) / float64(len(results)) * 100
	}

	return map[string]interface{}{
		"total_tests":     totalTests,
		"completed_tests": completedTests,
		"average_score":   averageScore,
		"pass_rate":       passRate,
	}, nil
}

// DeleteTestResult deletes the test result with the given ID.
func (s *TestService) DeleteTestResult(resultID string) error {
	return s.db.Delete(&models.TestResult{}, "id = ?", resultID).Error
}