Files
yixiaogao/backend/pkg/database/repository.go

456 lines
13 KiB
Go
Raw Permalink Normal View History

2025-11-27 18:40:08 +08:00
package database
import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)
// OfficialAccountRepository reads and writes rows of the official_accounts
// table.
type OfficialAccountRepository struct {
	db *DB // handle through which every query of this repository is issued
}

// NewOfficialAccountRepository returns an OfficialAccountRepository that
// runs its queries through db.
func NewOfficialAccountRepository(db *DB) *OfficialAccountRepository {
	repo := new(OfficialAccountRepository)
	repo.db = db
	return repo
}
// Create inserts account into official_accounts and returns the ID of the
// new row. Only Biz, Nickname, Homepage and Description are written; any ID
// or timestamps already set on account are ignored.
func (r *OfficialAccountRepository) Create(account *OfficialAccount) (int64, error) {
	const insertSQL = `
INSERT INTO official_accounts (biz, nickname, homepage, description)
VALUES (?, ?, ?, ?)
`
	res, execErr := r.db.Exec(insertSQL,
		account.Biz, account.Nickname, account.Homepage, account.Description)
	if execErr != nil {
		return 0, execErr
	}
	return res.LastInsertId()
}
// GetByBiz looks up an official account by its biz identifier.
//
// It returns (nil, nil) when no row matches, and a non-nil error for any
// other query or scan failure.
func (r *OfficialAccountRepository) GetByBiz(biz string) (*OfficialAccount, error) {
	account := &OfficialAccount{}
	err := r.db.QueryRow(`
SELECT id, biz, nickname, homepage, description, created_at, updated_at
FROM official_accounts WHERE biz = ?
`, biz).Scan(&account.ID, &account.Biz, &account.Nickname, &account.Homepage,
		&account.Description, &account.CreatedAt, &account.UpdatedAt)
	// errors.Is (rather than ==) still matches if ErrNoRows arrives wrapped,
	// e.g. by the DB wrapper.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return account, nil
}
// GetByID looks up an official account by its primary key.
//
// It returns (nil, nil) when no row matches, and a non-nil error for any
// other query or scan failure.
func (r *OfficialAccountRepository) GetByID(id int64) (*OfficialAccount, error) {
	account := &OfficialAccount{}
	err := r.db.QueryRow(`
SELECT id, biz, nickname, homepage, description, created_at, updated_at
FROM official_accounts WHERE id = ?
`, id).Scan(&account.ID, &account.Biz, &account.Nickname, &account.Homepage,
		&account.Description, &account.CreatedAt, &account.UpdatedAt)
	// errors.Is (rather than ==) still matches if ErrNoRows arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return account, nil
}
// List returns every official account, ordered by created_at descending.
//
// The result is nil (not an empty slice) when the table has no rows. Any
// query, scan or row-iteration error is returned instead of a partial list.
func (r *OfficialAccountRepository) List() ([]*OfficialAccount, error) {
	rows, err := r.db.Query(`
SELECT id, biz, nickname, homepage, description, created_at, updated_at
FROM official_accounts ORDER BY created_at DESC
`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var accounts []*OfficialAccount
	for rows.Next() {
		account := &OfficialAccount{}
		if err := rows.Scan(&account.ID, &account.Biz, &account.Nickname, &account.Homepage,
			&account.Description, &account.CreatedAt, &account.UpdatedAt); err != nil {
			return nil, err
		}
		accounts = append(accounts, account)
	}
	// rows.Next returns false both at end of data and on an iteration error;
	// rows.Err tells them apart so a truncated result is not reported as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return accounts, nil
}
// Update overwrites the nickname, homepage and description of the account
// whose id equals account.ID, and sets updated_at to CURRENT_TIMESTAMP.
// Biz is never changed. An ID that matches no row is not reported as an
// error; only a failure of the UPDATE itself is returned.
func (r *OfficialAccountRepository) Update(account *OfficialAccount) error {
	args := []interface{}{account.Nickname, account.Homepage, account.Description, account.ID}
	_, err := r.db.Exec(`
UPDATE official_accounts
SET nickname = ?, homepage = ?, description = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`, args...)
	return err
}
// ArticleRepository reads and writes rows of the articles table.
type ArticleRepository struct {
	db *DB // handle through which every query of this repository is issued
}

// NewArticleRepository returns an ArticleRepository that runs its queries
// through db.
func NewArticleRepository(db *DB) *ArticleRepository {
	var repo ArticleRepository
	repo.db = db
	return &repo
}
// Create inserts article into the articles table and returns the ID of the
// new row. ID, CreatedAt and UpdatedAt on article are not written.
func (r *ArticleRepository) Create(article *Article) (int64, error) {
	// Column order must match the INSERT column list below.
	values := []interface{}{
		article.OfficialID, article.Title, article.Author, article.Link,
		article.PublishTime, article.CreateTime, article.CommentID,
		article.ReadNum, article.LikeNum, article.ShareNum,
		article.ContentPreview, article.ParagraphCount,
	}
	res, err := r.db.Exec(`
INSERT INTO articles (
official_id, title, author, link, publish_time, create_time,
comment_id, read_num, like_num, share_num, content_preview, paragraph_count
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`, values...)
	if err != nil {
		return 0, err
	}
	return res.LastInsertId()
}
// GetByID fetches the article with the given primary key.
//
// It returns (nil, nil) when no such article exists, and a non-nil error for
// any other query or scan failure.
func (r *ArticleRepository) GetByID(id int64) (*Article, error) {
	article := &Article{}
	err := r.db.QueryRow(`
SELECT id, official_id, title, author, link, publish_time, create_time,
comment_id, read_num, like_num, share_num, content_preview,
paragraph_count, created_at, updated_at
FROM articles WHERE id = ?
`, id).Scan(&article.ID, &article.OfficialID, &article.Title, &article.Author,
		&article.Link, &article.PublishTime, &article.CreateTime, &article.CommentID,
		&article.ReadNum, &article.LikeNum, &article.ShareNum, &article.ContentPreview,
		&article.ParagraphCount, &article.CreatedAt, &article.UpdatedAt)
	// errors.Is (rather than ==) still matches if ErrNoRows arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return article, nil
}
// GetByLink fetches the article whose link column equals link exactly.
//
// It returns (nil, nil) when no such article exists, and a non-nil error for
// any other query or scan failure. If several rows share the link, QueryRow
// uses the first one returned.
func (r *ArticleRepository) GetByLink(link string) (*Article, error) {
	article := &Article{}
	err := r.db.QueryRow(`
SELECT id, official_id, title, author, link, publish_time, create_time,
comment_id, read_num, like_num, share_num, content_preview,
paragraph_count, created_at, updated_at
FROM articles WHERE link = ?
`, link).Scan(&article.ID, &article.OfficialID, &article.Title, &article.Author,
		&article.Link, &article.PublishTime, &article.CreateTime, &article.CommentID,
		&article.ReadNum, &article.LikeNum, &article.ShareNum, &article.ContentPreview,
		&article.ParagraphCount, &article.CreatedAt, &article.UpdatedAt)
	// errors.Is (rather than ==) still matches if ErrNoRows arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return article, nil
}
// List returns one page of articles, newest publish_time first, together with
// the total number of articles matching the filter.
//
// officialID > 0 restricts results to that official account; otherwise all
// articles are included. page is 1-based, and values below 1 are treated as 1.
// pageSize is passed straight through as the SQL LIMIT. Articles whose
// official account row is missing are still listed, with an empty
// OfficialName.
func (r *ArticleRepository) List(officialID int64, page, pageSize int) ([]*ArticleListItem, int, error) {
	// whereClause only ever holds a constant string; the ID itself is bound
	// as a placeholder argument, so the Sprintf below cannot inject SQL.
	whereClause := ""
	args := []interface{}{}
	if officialID > 0 {
		whereClause = "WHERE a.official_id = ?"
		args = append(args, officialID)
	}

	// Total number of matching articles, for pagination.
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM articles a %s", whereClause)
	var total int
	if err := r.db.QueryRow(countQuery, args...).Scan(&total); err != nil {
		return nil, 0, err
	}

	// Avoid a negative OFFSET, which not every SQL engine accepts.
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize

	// COALESCE: the LEFT JOIN yields a NULL nickname when the account row is
	// missing, and database/sql refuses to scan NULL into a string.
	listQuery := fmt.Sprintf(`
SELECT a.id, a.title, a.author, a.publish_time, a.read_num, a.like_num,
a.content_preview, COALESCE(o.nickname, '')
FROM articles a
LEFT JOIN official_accounts o ON a.official_id = o.id
%s
ORDER BY a.publish_time DESC
LIMIT ? OFFSET ?
`, whereClause)
	args = append(args, pageSize, offset)
	rows, err := r.db.Query(listQuery, args...)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	var items []*ArticleListItem
	for rows.Next() {
		item := &ArticleListItem{}
		if err := rows.Scan(&item.ID, &item.Title, &item.Author, &item.PublishTime,
			&item.ReadNum, &item.LikeNum, &item.ContentPreview, &item.OfficialName); err != nil {
			return nil, 0, err
		}
		items = append(items, item)
	}
	// Distinguish end of data from an iteration error.
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return items, total, nil
}
// Search returns one page of articles whose title or author contains keyword
// (SQL LIKE substring match), newest publish_time first, together with the
// total number of matches.
//
// page is 1-based, and values below 1 are treated as 1. pageSize is passed
// straight through as the SQL LIMIT. Articles whose official account row is
// missing are still returned, with an empty OfficialName.
func (r *ArticleRepository) Search(keyword string, page, pageSize int) ([]*ArticleListItem, int, error) {
	// NOTE(review): '%' and '_' inside keyword are not escaped, so they still
	// act as LIKE wildcards.
	pattern := "%" + keyword + "%"

	// Total number of matching articles, for pagination.
	var total int
	if err := r.db.QueryRow(`
SELECT COUNT(*) FROM articles WHERE title LIKE ? OR author LIKE ?
`, pattern, pattern).Scan(&total); err != nil {
		return nil, 0, err
	}

	// Avoid a negative OFFSET, which not every SQL engine accepts.
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize

	// COALESCE: the LEFT JOIN yields a NULL nickname when the account row is
	// missing, and database/sql refuses to scan NULL into a string.
	rows, err := r.db.Query(`
SELECT a.id, a.title, a.author, a.publish_time, a.read_num, a.like_num,
a.content_preview, COALESCE(o.nickname, '')
FROM articles a
LEFT JOIN official_accounts o ON a.official_id = o.id
WHERE a.title LIKE ? OR a.author LIKE ?
ORDER BY a.publish_time DESC
LIMIT ? OFFSET ?
`, pattern, pattern, pageSize, offset)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	var items []*ArticleListItem
	for rows.Next() {
		item := &ArticleListItem{}
		if err := rows.Scan(&item.ID, &item.Title, &item.Author, &item.PublishTime,
			&item.ReadNum, &item.LikeNum, &item.ContentPreview, &item.OfficialName); err != nil {
			return nil, 0, err
		}
		items = append(items, item)
	}
	// Distinguish end of data from an iteration error.
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return items, total, nil
}
// Update writes the read, like and share counts of the article whose id
// equals article.ID, and sets updated_at to CURRENT_TIMESTAMP. No other
// columns are touched. An ID that matches no row is not reported as an
// error; only a failure of the UPDATE itself is returned.
func (r *ArticleRepository) Update(article *Article) error {
	if _, err := r.db.Exec(`
UPDATE articles
SET read_num = ?, like_num = ?, share_num = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`, article.ReadNum, article.LikeNum, article.ShareNum, article.ID); err != nil {
		return err
	}
	return nil
}
// ArticleContentRepository reads and writes rows of the article_contents
// table, and assembles full article details from it.
type ArticleContentRepository struct {
	db *DB // handle through which every query of this repository is issued
}

// NewArticleContentRepository returns an ArticleContentRepository that runs
// its queries through db.
func NewArticleContentRepository(db *DB) *ArticleContentRepository {
	repo := ArticleContentRepository{db: db}
	return &repo
}
// Create stores content in article_contents and returns the ID of the new
// row. Paragraphs and Images are written as given. GetArticleDetail decodes
// these columns as JSON, so they are presumably expected to be JSON-encoded
// already (e.g. via StringsToJSON); TODO confirm with callers.
func (r *ArticleContentRepository) Create(content *ArticleContent) (int64, error) {
	const insertContentSQL = `
INSERT INTO article_contents (article_id, html_content, text_content, paragraphs, images)
VALUES (?, ?, ?, ?, ?)
`
	res, err := r.db.Exec(insertContentSQL,
		content.ArticleID, content.HtmlContent, content.TextContent,
		content.Paragraphs, content.Images)
	if err != nil {
		return 0, err
	}
	return res.LastInsertId()
}
// GetByArticleID fetches the stored content of the article with the given ID.
//
// It returns (nil, nil) when no content row exists for that article, and a
// non-nil error for any other query or scan failure.
func (r *ArticleContentRepository) GetByArticleID(articleID int64) (*ArticleContent, error) {
	content := &ArticleContent{}
	err := r.db.QueryRow(`
SELECT id, article_id, html_content, text_content, paragraphs, images, created_at
FROM article_contents WHERE article_id = ?
`, articleID).Scan(&content.ID, &content.ArticleID, &content.HtmlContent,
		&content.TextContent, &content.Paragraphs, &content.Images, &content.CreatedAt)
	// errors.Is (rather than ==) still matches if ErrNoRows arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return content, nil
}
// GetArticleDetail loads an article together with its official account
// nickname and its stored content.
//
// It returns (nil, nil) when no article has the given ID. The account and the
// content are LEFT JOINed, so an article lacking either still loads. In that
// case the joined columns (nickname, html/text content, paragraphs, images)
// come back as empty strings.
//
// The paragraphs and images columns are decoded from JSON on a best-effort
// basis: malformed JSON is ignored rather than reported.
func (r *ArticleContentRepository) GetArticleDetail(articleID int64) (*ArticleDetail, error) {
	detail := &ArticleDetail{}
	var paragraphsJSON, imagesJSON string
	// Every column from a LEFT JOINed table is COALESCEd: when the joined row
	// is absent the column is NULL, and database/sql cannot scan NULL into a
	// string, so such articles previously failed to load at all.
	err := r.db.QueryRow(`
SELECT a.id, a.official_id, a.title, a.author, a.link, a.publish_time,
a.create_time, a.comment_id, a.read_num, a.like_num, a.share_num,
a.content_preview, a.paragraph_count, a.created_at, a.updated_at,
COALESCE(o.nickname, ''), COALESCE(c.html_content, ''), COALESCE(c.text_content, ''),
COALESCE(c.paragraphs, ''), COALESCE(c.images, '')
FROM articles a
LEFT JOIN official_accounts o ON a.official_id = o.id
LEFT JOIN article_contents c ON a.id = c.article_id
WHERE a.id = ?
`, articleID).Scan(
		&detail.ID, &detail.OfficialID, &detail.Title, &detail.Author,
		&detail.Link, &detail.PublishTime, &detail.CreateTime, &detail.CommentID,
		&detail.ReadNum, &detail.LikeNum, &detail.ShareNum, &detail.ContentPreview,
		&detail.ParagraphCount, &detail.CreatedAt, &detail.UpdatedAt,
		&detail.OfficialName, &detail.HtmlContent, &detail.TextContent,
		&paragraphsJSON, &imagesJSON,
	)
	// errors.Is (rather than ==) still matches if ErrNoRows arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	// Best-effort decode of the JSON array columns: a decode error is
	// deliberately discarded so the rest of the detail is still returned.
	if paragraphsJSON != "" {
		_ = json.Unmarshal([]byte(paragraphsJSON), &detail.Paragraphs)
	}
	if imagesJSON != "" {
		_ = json.Unmarshal([]byte(imagesJSON), &detail.Images)
	}
	return detail, nil
}
// GetStatistics returns database-wide aggregates in a single query:
//   - the number of official accounts,
//   - the number of articles,
//   - the summed read_num of all articles,
//   - the summed like_num of all articles.
//
// Both sums are 0 when there are no articles.
func (db *DB) GetStatistics() (*Statistics, error) {
	const statsSQL = `
SELECT
(SELECT COUNT(*) FROM official_accounts) as total_officials,
(SELECT COUNT(*) FROM articles) as total_articles,
(SELECT COALESCE(SUM(read_num), 0) FROM articles) as total_read_num,
(SELECT COALESCE(SUM(like_num), 0) FROM articles) as total_like_num
`
	var stats Statistics
	row := db.QueryRow(statsSQL)
	if err := row.Scan(&stats.TotalOfficials, &stats.TotalArticles, &stats.TotalReadNum, &stats.TotalLikeNum); err != nil {
		return nil, err
	}
	return &stats, nil
}
// BatchInsertArticles inserts all articles inside one transaction, reusing a
// single prepared INSERT OR IGNORE statement.
//
// With SQLite's OR IGNORE conflict resolution, rows that hit a constraint
// conflict (e.g. a duplicate unique key) are skipped instead of failing the
// batch. Any other error aborts the loop, and the transaction is rolled back.
// An empty slice is a no-op and opens no transaction.
func (r *ArticleRepository) BatchInsertArticles(articles []*Article) error {
	if len(articles) == 0 {
		return nil
	}

	tx, err := r.db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback does nothing; on any early
	// return it discards the partially inserted batch.
	defer tx.Rollback()

	insert, err := tx.Prepare(`
INSERT OR IGNORE INTO articles (
official_id, title, author, link, publish_time, create_time,
comment_id, read_num, like_num, share_num, content_preview, paragraph_count
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`)
	if err != nil {
		return err
	}
	defer insert.Close()

	for i := range articles {
		a := articles[i]
		if _, execErr := insert.Exec(
			a.OfficialID, a.Title, a.Author, a.Link,
			a.PublishTime, a.CreateTime, a.CommentID,
			a.ReadNum, a.LikeNum, a.ShareNum,
			a.ContentPreview, a.ParagraphCount,
		); execErr != nil {
			return execErr
		}
	}
	return tx.Commit()
}
// StringsToJSON encodes strs as a JSON array string. Both nil and empty
// slices yield "[]" (never "null").
func StringsToJSON(strs []string) string {
	if len(strs) > 0 {
		// Marshalling a []string cannot fail, so the error is discarded.
		encoded, _ := json.Marshal(strs)
		return string(encoded)
	}
	return "[]"
}
// GeneratePreview shortens content to a preview of at most maxLen bytes,
// plus a trailing "..." when truncation happens.
//
// Content that already fits is returned unchanged, whitespace included.
// Otherwise the text is normalized in three steps:
//   - carriage returns are deleted;
//   - newlines and other runs of whitespace collapse to a single space;
//   - leading and trailing whitespace is dropped.
//
// If the normalized text still exceeds maxLen, it is cut and "..." is
// appended. The cut is moved back to the nearest UTF-8 character boundary,
// so multi-byte text such as Chinese is never split mid-character and the
// result stays valid UTF-8. A negative maxLen is treated as 0.
func GeneratePreview(content string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(content) <= maxLen {
		return content
	}
	// Remove line breaks and redundant whitespace.
	content = strings.ReplaceAll(content, "\n", " ")
	content = strings.ReplaceAll(content, "\r", "")
	content = strings.Join(strings.Fields(content), " ")
	if len(content) <= maxLen {
		return content
	}
	// Ranging over a string yields the starting byte offset of each rune.
	// Keep the last offset that does not exceed maxLen, so the cut lands on
	// a character boundary. Slicing at maxLen directly could split a
	// multi-byte character.
	cut := 0
	for i := range content {
		if i > maxLen {
			break
		}
		cut = i
	}
	return content[:cut] + "..."
}