Files
ai_english/serve/internal/services/test_service.go

498 lines
12 KiB
Go
Raw Permalink Normal View History

2025-11-17 14:09:17 +08:00
package services
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/models"
)
// TestService implements the test workflow (templates, sessions, answers and
// results). All persistence goes through the wrapped GORM handle.
type TestService struct {
	db *gorm.DB
}

// NewTestService returns a TestService that issues every query through db.
func NewTestService(db *gorm.DB) *TestService {
	svc := new(TestService)
	svc.db = db
	return svc
}
// GetTestTemplates lists active test templates, newest first.
//
// testType and difficulty are optional filters; nil means "do not filter".
// It returns one page of templates (page is 1-based, pageSize rows per page)
// together with the total number of matching templates.
//
// NOTE(review): page and pageSize are used as given; callers presumably
// validate page >= 1 and pageSize >= 1 — confirm in the handler layer.
func (s *TestService) GetTestTemplates(testType *models.TestType, difficulty *models.TestDifficulty, page, pageSize int) ([]models.TestTemplate, int64, error) {
	q := s.db.Model(&models.TestTemplate{}).Where("is_active = ?", true)
	if testType != nil {
		q = q.Where("type = ?", *testType)
	}
	if difficulty != nil {
		q = q.Where("difficulty = ?", *difficulty)
	}

	var count int64
	if err := q.Count(&count).Error; err != nil {
		return nil, 0, err
	}

	var rows []models.TestTemplate
	err := q.Order("created_at DESC").
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&rows).Error
	if err != nil {
		return nil, 0, err
	}
	return rows, count, nil
}
// GetTestTemplateByID loads the active template with the given id.
// It returns the GORM error unchanged (for example a not-found error) when no
// active template matches.
func (s *TestService) GetTestTemplateByID(id string) (*models.TestTemplate, error) {
	tpl := new(models.TestTemplate)
	err := s.db.Where("id = ? AND is_active = ?", id, true).First(tpl).Error
	if err != nil {
		return nil, err
	}
	return tpl, nil
}
// CreateTestSession creates a pending session for userID based on templateID.
//
// The template must be active and must have at least one question. The
// session row and one test_session_questions link per question (ordered by
// the template's order_index) are written in a single transaction. The
// returned session has Template and Questions populated.
func (s *TestService) CreateTestSession(templateID, userID string) (*models.TestSession, error) {
	template, err := s.GetTestTemplateByID(templateID)
	if err != nil {
		return nil, fmt.Errorf("模板不存在: %w", err)
	}

	var questions []models.TestQuestion
	if err := s.db.Where("template_id = ?", templateID).Order("order_index ASC").Find(&questions).Error; err != nil {
		return nil, err
	}
	if len(questions) == 0 {
		return nil, errors.New("模板没有配置题目")
	}

	now := time.Now()
	session := &models.TestSession{
		ID:                   uuid.New().String(),
		TemplateID:           templateID,
		UserID:               userID,
		Status:               models.TestStatusPending,
		TimeRemaining:        template.Duration,
		CurrentQuestionIndex: 0,
		CreatedAt:            now,
		UpdatedAt:            now,
	}

	// Link rows pin the question order for this session.
	links := make([]models.TestSessionQuestion, 0, len(questions))
	for i, q := range questions {
		links = append(links, models.TestSessionQuestion{
			SessionID:  session.ID,
			QuestionID: q.ID,
			OrderIndex: i,
		})
	}

	// db.Transaction checks the Begin error, rolls back on any returned error
	// and on panic, and lets the panic propagate. The previous hand-rolled
	// recover() swallowed panics and made this function return (nil, nil).
	err = s.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(session).Error; err != nil {
			return err
		}
		return tx.Create(&links).Error
	})
	if err != nil {
		return nil, err
	}

	// Associations are attached only after the insert so GORM does not try to
	// save them as part of the session row.
	session.Template = template
	session.Questions = questions
	return session, nil
}
// GetTestSession loads a session by ID with its Template and Answers (each
// with its Question) preloaded. Questions is filled from the
// test_session_questions link table, ordered by the per-session order_index.
func (s *TestService) GetTestSession(sessionID string) (*models.TestSession, error) {
	const orderedQuestionsSQL = `
SELECT q.* FROM test_questions q
INNER JOIN test_session_questions sq ON q.id = sq.question_id
WHERE sq.session_id = ?
ORDER BY sq.order_index ASC
`

	session := new(models.TestSession)
	err := s.db.Preload("Template").Preload("Answers.Question").First(session, "id = ?", sessionID).Error
	if err != nil {
		return nil, err
	}

	var ordered []models.TestQuestion
	if err = s.db.Raw(orderedQuestionsSQL, sessionID).Scan(&ordered).Error; err != nil {
		return nil, err
	}
	session.Questions = ordered
	return session, nil
}
// StartTest moves a pending or paused session to the in-progress state and
// persists it.
//
// StartTime is recorded only the first time a session starts. Starting a
// paused session keeps the original StartTime and clears PausedAt, so a
// running session never carries a stale pause timestamp.
func (s *TestService) StartTest(sessionID string) (*models.TestSession, error) {
	session, err := s.GetTestSession(sessionID)
	if err != nil {
		return nil, err
	}
	if session.Status != models.TestStatusPending && session.Status != models.TestStatusPaused {
		return nil, errors.New("测试状态不允许开始")
	}

	now := time.Now()
	session.Status = models.TestStatusInProgress
	// Restarting from a pause must not discard when the test originally began.
	if session.StartTime == nil {
		session.StartTime = &now
	}
	// An in-progress session is no longer paused.
	session.PausedAt = nil
	session.UpdatedAt = now

	if err := s.db.Save(session).Error; err != nil {
		return nil, err
	}
	return session, nil
}
// SubmitAnswer records (or replaces) the answer to questionID in an
// in-progress session and grades it immediately with checkAnswer.
//
// The question must be one of the questions assigned to the session.
// Otherwise answers to arbitrary questions would be scored into the session
// result. A correct answer earns the question's Points and a wrong one
// earns 0. It returns the reloaded session.
func (s *TestService) SubmitAnswer(sessionID, questionID, answer string) (*models.TestSession, error) {
	session, err := s.GetTestSession(sessionID)
	if err != nil {
		return nil, err
	}
	if session.Status != models.TestStatusInProgress {
		return nil, errors.New("测试未在进行中")
	}

	var question models.TestQuestion
	if err := s.db.First(&question, "id = ?", questionID).Error; err != nil {
		return nil, err
	}

	// Only questions linked to this session may be answered.
	inSession := false
	for _, q := range session.Questions {
		if q.ID == question.ID {
			inSession = true
			break
		}
	}
	if !inSession {
		return nil, errors.New("题目不属于该测试")
	}

	isCorrect := s.checkAnswer(question, answer)
	score := 0
	if isCorrect {
		score = question.Points
	}
	now := time.Now()

	var existing models.TestAnswer
	err = s.db.Where("session_id = ? AND question_id = ?", sessionID, questionID).First(&existing).Error
	switch {
	case err == nil:
		// Re-answering overwrites the previous attempt.
		existing.Answer = answer
		existing.IsCorrect = &isCorrect
		existing.Score = score
		existing.UpdatedAt = now
		if err := s.db.Save(&existing).Error; err != nil {
			return nil, err
		}
	case errors.Is(err, gorm.ErrRecordNotFound):
		testAnswer := &models.TestAnswer{
			ID:         uuid.New().String(),
			SessionID:  sessionID,
			QuestionID: questionID,
			Answer:     answer,
			IsCorrect:  &isCorrect,
			Score:      score,
			CreatedAt:  now,
			UpdatedAt:  now,
		}
		if err := s.db.Create(testAnswer).Error; err != nil {
			return nil, err
		}
	default:
		// A real database failure; do not mistake it for "not answered yet".
		return nil, err
	}

	return s.GetTestSession(sessionID)
}
// checkAnswer reports whether answer is correct for question.
//
//   - Single choice, true/false, fill-in-the-blank and short answer: exact
//     string equality with question.CorrectAnswer.
//   - Multiple choice: both answer and question.CorrectAnswer must be JSON
//     arrays of strings. They match when they contain the same options the
//     same number of times, in any order. If either side fails to parse, the
//     answer is treated as incorrect, so two malformed values no longer
//     "match" as empty lists.
//   - Any other question type is never correct.
func (s *TestService) checkAnswer(question models.TestQuestion, answer string) bool {
	switch question.QuestionType {
	case models.QuestionTypeSingleChoice, models.QuestionTypeTrueFalse:
		return answer == question.CorrectAnswer
	case models.QuestionTypeMultipleChoice:
		var userAnswers, correctAnswers []string
		if err := json.Unmarshal([]byte(answer), &userAnswers); err != nil {
			return false
		}
		if err := json.Unmarshal([]byte(question.CorrectAnswer), &correctAnswers); err != nil {
			return false
		}
		if len(userAnswers) != len(correctAnswers) {
			return false
		}
		// Multiset comparison: each chosen option must consume one remaining
		// correct option, so duplicates like ["A","A"] cannot match ["A","B"].
		remaining := make(map[string]int, len(correctAnswers))
		for _, a := range correctAnswers {
			remaining[a]++
		}
		for _, a := range userAnswers {
			if remaining[a] == 0 {
				return false
			}
			remaining[a]--
		}
		return true
	case models.QuestionTypeFillBlank, models.QuestionTypeShortAnswer:
		// Plain exact match. Real-world grading may need normalization
		// (whitespace, case, synonyms).
		return answer == question.CorrectAnswer
	default:
		return false
	}
}
// PauseTest pauses an in-progress session. It stamps PausedAt and UpdatedAt
// with the current time and persists the session.
func (s *TestService) PauseTest(sessionID string) (*models.TestSession, error) {
	sess, err := s.GetTestSession(sessionID)
	switch {
	case err != nil:
		return nil, err
	case sess.Status != models.TestStatusInProgress:
		return nil, errors.New("测试未在进行中")
	}

	pausedAt := time.Now()
	sess.Status = models.TestStatusPaused
	sess.PausedAt = &pausedAt
	sess.UpdatedAt = pausedAt

	if err = s.db.Save(sess).Error; err != nil {
		return nil, err
	}
	return sess, nil
}
// ResumeTest returns a paused session to the in-progress state. It clears
// PausedAt, refreshes UpdatedAt and persists the session.
func (s *TestService) ResumeTest(sessionID string) (*models.TestSession, error) {
	sess, err := s.GetTestSession(sessionID)
	switch {
	case err != nil:
		return nil, err
	case sess.Status != models.TestStatusPaused:
		return nil, errors.New("测试未暂停")
	}

	sess.Status = models.TestStatusInProgress
	sess.PausedAt = nil
	sess.UpdatedAt = time.Now()

	if err = s.db.Save(sess).Error; err != nil {
		return nil, err
	}
	return sess, nil
}
// CompleteTest finishes an in-progress session. It scores the recorded
// answers, marks the session completed and stores a TestResult. The session
// update and the result insert are committed in one transaction.
//
// MaxScore, and therefore Percentage, is the sum of Points over every
// question assigned to the session, so unanswered questions count against
// the result. SkillScores is a JSON object mapping skill type to points
// earned. Passed compares TotalScore with the template's PassingScore.
func (s *TestService) CompleteTest(sessionID string) (*models.TestResult, error) {
	session, err := s.GetTestSession(sessionID)
	if err != nil {
		return nil, err
	}
	if session.Status != models.TestStatusInProgress {
		return nil, errors.New("测试未在进行中")
	}
	// PassingScore comes from the template; fail cleanly instead of
	// dereferencing a nil pointer if the template row is gone.
	if session.Template == nil {
		return nil, errors.New("模板不存在")
	}

	var answers []models.TestAnswer
	if err := s.db.Preload("Question").Where("session_id = ?", sessionID).Find(&answers).Error; err != nil {
		return nil, err
	}

	// The achievable maximum covers every question in the session. Summing
	// only answered questions would let skipping inflate the percentage.
	maxScore := 0
	for _, q := range session.Questions {
		maxScore += q.Points
	}

	totalScore, correctCount, wrongCount, timeSpent := 0, 0, 0, 0
	skillScores := make(map[models.SkillType]int)
	for _, a := range answers {
		if a.Question == nil {
			// The answered question no longer exists and cannot be scored.
			continue
		}
		if a.IsCorrect != nil && *a.IsCorrect {
			totalScore += a.Score
			correctCount++
			skillScores[a.Question.SkillType] += a.Score
		} else {
			wrongCount++
		}
		timeSpent += a.TimeSpent
	}

	skippedCount := len(session.Questions) - correctCount - wrongCount
	if skippedCount < 0 {
		skippedCount = 0
	}

	percentage := 0.0
	if maxScore > 0 {
		percentage = float64(totalScore) / float64(maxScore) * 100
	}

	skillScoresJSON, err := json.Marshal(skillScores)
	if err != nil {
		return nil, err
	}

	now := time.Now()
	session.Status = models.TestStatusCompleted
	session.EndTime = &now
	session.UpdatedAt = now

	result := &models.TestResult{
		ID:           uuid.New().String(),
		SessionID:    sessionID,
		UserID:       session.UserID,
		TemplateID:   session.TemplateID,
		TotalScore:   totalScore,
		MaxScore:     maxScore,
		Percentage:   percentage,
		CorrectCount: correctCount,
		WrongCount:   wrongCount,
		SkippedCount: skippedCount,
		TimeSpent:    timeSpent,
		SkillScores:  string(skillScoresJSON),
		Passed:       totalScore >= session.Template.PassingScore,
		CompletedAt:  now,
		CreatedAt:    now,
		UpdatedAt:    now,
	}

	// db.Transaction rolls back on error or panic and re-raises the panic.
	// The previous recover() swallowed it and returned (nil, nil).
	err = s.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Save(session).Error; err != nil {
			return err
		}
		return tx.Create(result).Error
	})
	if err != nil {
		return nil, err
	}

	result.Session = session
	result.Template = session.Template
	return result, nil
}
// GetUserTestHistory returns one page of userID's test results, most recently
// completed first, each with its Template preloaded, plus the total number of
// results for that user.
//
// NOTE(review): page (1-based) and pageSize are used as given; callers
// presumably validate them — confirm in the handler layer.
func (s *TestService) GetUserTestHistory(userID string, page, pageSize int) ([]models.TestResult, int64, error) {
	q := s.db.Model(&models.TestResult{}).Where("user_id = ?", userID)

	var count int64
	if err := q.Count(&count).Error; err != nil {
		return nil, 0, err
	}

	var history []models.TestResult
	err := q.Preload("Template").
		Order("completed_at DESC").
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&history).Error
	if err != nil {
		return nil, 0, err
	}
	return history, count, nil
}
// GetTestResultByID loads one test result with its Template and its Session's
// Answers (each with the answered Question) preloaded.
func (s *TestService) GetTestResultByID(resultID string) (*models.TestResult, error) {
	res := new(models.TestResult)
	err := s.db.Preload("Template").Preload("Session.Answers.Question").First(res, "id = ?", resultID).Error
	if err != nil {
		return nil, err
	}
	return res, nil
}
// GetUserTestStats summarizes a user's testing activity. The returned map
// contains:
//
//   - "total_tests" (int64): sessions created by the user
//   - "completed_tests" (int64): the user's sessions in the completed state
//   - "average_score" (float64): mean Percentage over the user's results
//   - "pass_rate" (float64, 0-100): share of the user's results with Passed set
//
// average_score and pass_rate are 0 when the user has no results. Database
// errors are returned rather than being silently reported as zero stats.
func (s *TestService) GetUserTestStats(userID string) (map[string]interface{}, error) {
	var totalTests, completedTests int64
	if err := s.db.Model(&models.TestSession{}).Where("user_id = ?", userID).Count(&totalTests).Error; err != nil {
		return nil, err
	}
	if err := s.db.Model(&models.TestSession{}).Where("user_id = ? AND status = ?", userID, models.TestStatusCompleted).Count(&completedTests).Error; err != nil {
		return nil, err
	}

	// NOTE(review): this loads every result row for the user. An SQL AVG/SUM
	// aggregate would scale better once column names are confirmed.
	var results []models.TestResult
	if err := s.db.Where("user_id = ?", userID).Find(&results).Error; err != nil {
		return nil, err
	}

	averageScore, passRate := 0.0, 0.0
	if n := len(results); n > 0 {
		totalPercentage := 0.0
		passedCount := 0
		for _, r := range results {
			totalPercentage += r.Percentage
			if r.Passed {
				passedCount++
			}
		}
		averageScore = totalPercentage / float64(n)
		passRate = float64(passedCount) / float64(n) * 100
	}

	return map[string]interface{}{
		"total_tests":     totalTests,
		"completed_tests": completedTests,
		"average_score":   averageScore,
		"pass_rate":       passRate,
	}, nil
}
// DeleteTestResult deletes the test result with the given ID. Deleting an ID
// that does not exist is not reported as an error; only database failures are.
func (s *TestService) DeleteTestResult(resultID string) error {
	res := s.db.Delete(&models.TestResult{}, "id = ?", resultID)
	return res.Error
}