init
This commit is contained in:
497
serve/internal/services/test_service.go
Normal file
497
serve/internal/services/test_service.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Nanqipro/YunQue-Tech-Projects/ai_english_learning/serve/internal/models"
|
||||
)
|
||||
|
||||
// TestService 测试服务
|
||||
type TestService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewTestService 创建测试服务实例
|
||||
func NewTestService(db *gorm.DB) *TestService {
|
||||
return &TestService{db: db}
|
||||
}
|
||||
|
||||
// GetTestTemplates 获取测试模板列表
|
||||
func (s *TestService) GetTestTemplates(testType *models.TestType, difficulty *models.TestDifficulty, page, pageSize int) ([]models.TestTemplate, int64, error) {
|
||||
var templates []models.TestTemplate
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&models.TestTemplate{}).Where("is_active = ?", true)
|
||||
|
||||
if testType != nil {
|
||||
query = query.Where("type = ?", *testType)
|
||||
}
|
||||
if difficulty != nil {
|
||||
query = query.Where("difficulty = ?", *difficulty)
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&templates).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return templates, total, nil
|
||||
}
|
||||
|
||||
// GetTestTemplateByID 根据ID获取测试模板
|
||||
func (s *TestService) GetTestTemplateByID(id string) (*models.TestTemplate, error) {
|
||||
var template models.TestTemplate
|
||||
if err := s.db.Where("id = ? AND is_active = ?", id, true).First(&template).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &template, nil
|
||||
}
|
||||
|
||||
// CreateTestSession 创建测试会话
|
||||
func (s *TestService) CreateTestSession(templateID, userID string) (*models.TestSession, error) {
|
||||
// 获取模板
|
||||
template, err := s.GetTestTemplateByID(templateID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("模板不存在: %w", err)
|
||||
}
|
||||
|
||||
// 获取模板对应的题目
|
||||
var questions []models.TestQuestion
|
||||
if err := s.db.Where("template_id = ?", templateID).Order("order_index ASC").Find(&questions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(questions) == 0 {
|
||||
return nil, errors.New("模板没有配置题目")
|
||||
}
|
||||
|
||||
// 创建会话
|
||||
session := &models.TestSession{
|
||||
ID: uuid.New().String(),
|
||||
TemplateID: templateID,
|
||||
UserID: userID,
|
||||
Status: models.TestStatusPending,
|
||||
TimeRemaining: template.Duration,
|
||||
CurrentQuestionIndex: 0,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// 保存会话
|
||||
if err := tx.Create(session).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 关联题目
|
||||
for i, question := range questions {
|
||||
sessionQuestion := models.TestSessionQuestion{
|
||||
SessionID: session.ID,
|
||||
QuestionID: question.ID,
|
||||
OrderIndex: i,
|
||||
}
|
||||
if err := tx.Create(&sessionQuestion).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 加载关联数据
|
||||
session.Template = template
|
||||
session.Questions = questions
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// GetTestSession 获取测试会话
|
||||
func (s *TestService) GetTestSession(sessionID string) (*models.TestSession, error) {
|
||||
var session models.TestSession
|
||||
if err := s.db.Preload("Template").Preload("Answers.Question").First(&session, "id = ?", sessionID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 加载题目
|
||||
var questions []models.TestQuestion
|
||||
if err := s.db.Raw(`
|
||||
SELECT q.* FROM test_questions q
|
||||
INNER JOIN test_session_questions sq ON q.id = sq.question_id
|
||||
WHERE sq.session_id = ?
|
||||
ORDER BY sq.order_index ASC
|
||||
`, sessionID).Scan(&questions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session.Questions = questions
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// StartTest 开始测试
|
||||
func (s *TestService) StartTest(sessionID string) (*models.TestSession, error) {
|
||||
session, err := s.GetTestSession(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if session.Status != models.TestStatusPending && session.Status != models.TestStatusPaused {
|
||||
return nil, errors.New("测试状态不允许开始")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
session.Status = models.TestStatusInProgress
|
||||
session.StartTime = &now
|
||||
session.UpdatedAt = now
|
||||
|
||||
if err := s.db.Save(session).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SubmitAnswer 提交答案
|
||||
func (s *TestService) SubmitAnswer(sessionID, questionID, answer string) (*models.TestSession, error) {
|
||||
session, err := s.GetTestSession(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if session.Status != models.TestStatusInProgress {
|
||||
return nil, errors.New("测试未在进行中")
|
||||
}
|
||||
|
||||
// 查找题目
|
||||
var question models.TestQuestion
|
||||
if err := s.db.First(&question, "id = ?", questionID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查是否已经回答过
|
||||
var existingAnswer models.TestAnswer
|
||||
err = s.db.Where("session_id = ? AND question_id = ?", sessionID, questionID).First(&existingAnswer).Error
|
||||
if err == nil {
|
||||
// 更新已有答案
|
||||
existingAnswer.Answer = answer
|
||||
existingAnswer.UpdatedAt = time.Now()
|
||||
|
||||
// 判断答案是否正确
|
||||
isCorrect := s.checkAnswer(question, answer)
|
||||
existingAnswer.IsCorrect = &isCorrect
|
||||
if isCorrect {
|
||||
existingAnswer.Score = question.Points
|
||||
} else {
|
||||
existingAnswer.Score = 0
|
||||
}
|
||||
|
||||
if err := s.db.Save(&existingAnswer).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// 创建新答案
|
||||
isCorrect := s.checkAnswer(question, answer)
|
||||
score := 0
|
||||
if isCorrect {
|
||||
score = question.Points
|
||||
}
|
||||
|
||||
testAnswer := &models.TestAnswer{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
QuestionID: questionID,
|
||||
Answer: answer,
|
||||
IsCorrect: &isCorrect,
|
||||
Score: score,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.db.Create(testAnswer).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 重新加载会话
|
||||
return s.GetTestSession(sessionID)
|
||||
}
|
||||
|
||||
// checkAnswer 检查答案是否正确
|
||||
func (s *TestService) checkAnswer(question models.TestQuestion, answer string) bool {
|
||||
switch question.QuestionType {
|
||||
case models.QuestionTypeSingleChoice, models.QuestionTypeTrueFalse:
|
||||
return answer == question.CorrectAnswer
|
||||
case models.QuestionTypeMultipleChoice:
|
||||
// 多选题需要比较JSON数组
|
||||
var userAnswers, correctAnswers []string
|
||||
json.Unmarshal([]byte(answer), &userAnswers)
|
||||
json.Unmarshal([]byte(question.CorrectAnswer), &correctAnswers)
|
||||
|
||||
if len(userAnswers) != len(correctAnswers) {
|
||||
return false
|
||||
}
|
||||
|
||||
answerMap := make(map[string]bool)
|
||||
for _, a := range correctAnswers {
|
||||
answerMap[a] = true
|
||||
}
|
||||
|
||||
for _, a := range userAnswers {
|
||||
if !answerMap[a] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case models.QuestionTypeFillBlank, models.QuestionTypeShortAnswer:
|
||||
// 简单的字符串比较,实际应用中可能需要更复杂的匹配逻辑
|
||||
return answer == question.CorrectAnswer
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// PauseTest 暂停测试
|
||||
func (s *TestService) PauseTest(sessionID string) (*models.TestSession, error) {
|
||||
session, err := s.GetTestSession(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if session.Status != models.TestStatusInProgress {
|
||||
return nil, errors.New("测试未在进行中")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
session.Status = models.TestStatusPaused
|
||||
session.PausedAt = &now
|
||||
session.UpdatedAt = now
|
||||
|
||||
if err := s.db.Save(session).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// ResumeTest 恢复测试
|
||||
func (s *TestService) ResumeTest(sessionID string) (*models.TestSession, error) {
|
||||
session, err := s.GetTestSession(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if session.Status != models.TestStatusPaused {
|
||||
return nil, errors.New("测试未暂停")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
session.Status = models.TestStatusInProgress
|
||||
session.PausedAt = nil
|
||||
session.UpdatedAt = now
|
||||
|
||||
if err := s.db.Save(session).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// CompleteTest 完成测试
|
||||
func (s *TestService) CompleteTest(sessionID string) (*models.TestResult, error) {
|
||||
session, err := s.GetTestSession(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if session.Status != models.TestStatusInProgress {
|
||||
return nil, errors.New("测试未在进行中")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
session.Status = models.TestStatusCompleted
|
||||
session.EndTime = &now
|
||||
session.UpdatedAt = now
|
||||
|
||||
// 计算结果
|
||||
var answers []models.TestAnswer
|
||||
if err := s.db.Preload("Question").Where("session_id = ?", sessionID).Find(&answers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
totalScore := 0
|
||||
maxScore := 0
|
||||
correctCount := 0
|
||||
wrongCount := 0
|
||||
skippedCount := len(session.Questions) - len(answers)
|
||||
timeSpent := 0
|
||||
|
||||
skillScores := make(map[models.SkillType]int)
|
||||
skillMaxScores := make(map[models.SkillType]int)
|
||||
|
||||
for _, answer := range answers {
|
||||
if answer.Question != nil {
|
||||
maxScore += answer.Question.Points
|
||||
skillMaxScores[answer.Question.SkillType] += answer.Question.Points
|
||||
|
||||
if answer.IsCorrect != nil && *answer.IsCorrect {
|
||||
totalScore += answer.Score
|
||||
correctCount++
|
||||
skillScores[answer.Question.SkillType] += answer.Score
|
||||
} else {
|
||||
wrongCount++
|
||||
}
|
||||
|
||||
timeSpent += answer.TimeSpent
|
||||
}
|
||||
}
|
||||
|
||||
percentage := 0.0
|
||||
if maxScore > 0 {
|
||||
percentage = float64(totalScore) / float64(maxScore) * 100
|
||||
}
|
||||
|
||||
// 构建技能得分JSON
|
||||
skillScoresJSON, _ := json.Marshal(skillScores)
|
||||
|
||||
// 创建测试结果
|
||||
result := &models.TestResult{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
UserID: session.UserID,
|
||||
TemplateID: session.TemplateID,
|
||||
TotalScore: totalScore,
|
||||
MaxScore: maxScore,
|
||||
Percentage: percentage,
|
||||
CorrectCount: correctCount,
|
||||
WrongCount: wrongCount,
|
||||
SkippedCount: skippedCount,
|
||||
TimeSpent: timeSpent,
|
||||
SkillScores: string(skillScoresJSON),
|
||||
Passed: totalScore >= session.Template.PassingScore,
|
||||
CompletedAt: now,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// 保存会话状态
|
||||
if err := tx.Save(session).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保存结果
|
||||
if err := tx.Create(result).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 加载关联数据
|
||||
result.Session = session
|
||||
result.Template = session.Template
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetUserTestHistory 获取用户测试历史
|
||||
func (s *TestService) GetUserTestHistory(userID string, page, pageSize int) ([]models.TestResult, int64, error) {
|
||||
var results []models.TestResult
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&models.TestResult{}).Where("user_id = ?", userID)
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Preload("Template").Order("completed_at DESC").Offset(offset).Limit(pageSize).Find(&results).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return results, total, nil
|
||||
}
|
||||
|
||||
// GetTestResultByID 获取测试结果详情
|
||||
func (s *TestService) GetTestResultByID(resultID string) (*models.TestResult, error) {
|
||||
var result models.TestResult
|
||||
if err := s.db.Preload("Template").Preload("Session.Answers.Question").First(&result, "id = ?", resultID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// GetUserTestStats 获取用户测试统计
|
||||
func (s *TestService) GetUserTestStats(userID string) (map[string]interface{}, error) {
|
||||
var stats struct {
|
||||
TotalTests int64
|
||||
CompletedTests int64
|
||||
AverageScore float64
|
||||
PassRate float64
|
||||
}
|
||||
|
||||
// 总测试数
|
||||
s.db.Model(&models.TestSession{}).Where("user_id = ?", userID).Count(&stats.TotalTests)
|
||||
|
||||
// 完成的测试数
|
||||
s.db.Model(&models.TestSession{}).Where("user_id = ? AND status = ?", userID, models.TestStatusCompleted).Count(&stats.CompletedTests)
|
||||
|
||||
// 平均分和通过率
|
||||
var results []models.TestResult
|
||||
s.db.Where("user_id = ?", userID).Find(&results)
|
||||
|
||||
if len(results) > 0 {
|
||||
totalPercentage := 0.0
|
||||
passedCount := 0
|
||||
for _, result := range results {
|
||||
totalPercentage += result.Percentage
|
||||
if result.Passed {
|
||||
passedCount++
|
||||
}
|
||||
}
|
||||
stats.AverageScore = totalPercentage / float64(len(results))
|
||||
stats.PassRate = float64(passedCount) / float64(len(results)) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_tests": stats.TotalTests,
|
||||
"completed_tests": stats.CompletedTests,
|
||||
"average_score": stats.AverageScore,
|
||||
"pass_rate": stats.PassRate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteTestResult 删除测试结果
|
||||
func (s *TestService) DeleteTestResult(resultID string) error {
|
||||
return s.db.Delete(&models.TestResult{}, "id = ?", resultID).Error
|
||||
}
|
||||
Reference in New Issue
Block a user