Files
baijiahao_data_crawl/import_csv_to_database.py

509 lines
24 KiB
Python
Raw Permalink Normal View History

2025-12-25 11:16:59 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CSV数据导入数据库脚本
功能将三个CSV文件的数据以author_name为唯一键导入到数据库
- ai_statistics.csv -> ai_statistics表
- ai_statistics_day.csv -> ai_statistics_day表
- ai_statistics_days.csv -> ai_statistics_days表
"""
import os
import sys
import csv
from decimal import Decimal
from typing import Dict, List, Optional
from datetime import datetime
from database_config import DatabaseManager, DB_CONFIG
from log_config import setup_logger
class CSVImporter:
    """Imports rows from three sibling CSV files into their database tables."""

    def __init__(self, db_config: Optional[Dict] = None):
        """Set up paths, database access and logging.

        Args:
            db_config: database configuration; falls back to
                database_config.DB_CONFIG when omitted.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        self.script_dir = base_dir
        self.db_manager = DatabaseManager(db_config)
        # Dedicated log file under <script_dir>/logs/
        self.logger = setup_logger('import_csv', os.path.join(base_dir, 'logs', 'import_csv.log'))
        # Target table name -> CSV path; the files live next to this script.
        self.csv_files = {
            table: os.path.join(base_dir, f'{table}.csv')
            for table in ('ai_statistics', 'ai_statistics_day', 'ai_statistics_days')
        }
def read_csv(self, file_path: str) -> List[Dict]:
"""读取CSV文件
Args:
file_path: CSV文件路径
Returns:
数据行列表
"""
if not os.path.exists(file_path):
print(f"[X] CSV文件不存在: {file_path}")
self.logger.error(f"CSV文件不存在: {file_path}")
return []
rows = []
try:
with open(file_path, 'r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
print(f"[OK] 读取CSV文件成功: {file_path}, 共 {len(rows)}")
self.logger.info(f"读取CSV文件成功: {file_path}, 共 {len(rows)}")
return rows
except Exception as e:
print(f"[X] 读取CSV文件失败: {file_path}, 错误: {e}")
self.logger.error(f"读取CSV文件失败: {file_path}, 错误: {e}")
return []
def convert_value(self, value: str, field_type: str):
"""转换数据类型
Args:
value: 字符串值
field_type: 字段类型 (int, float, decimal, str)
Returns:
转换后的值
"""
if value == '' or value is None:
return None
try:
if field_type == 'int':
return int(float(value))
elif field_type == 'float':
return float(value)
elif field_type == 'decimal':
return Decimal(str(value))
else:
return str(value)
except Exception as e:
print(f"[!] 数据类型转换失败: {value} -> {field_type}, 错误: {e}")
self.logger.warning(f"数据类型转换失败: {value} -> {field_type}, 错误: {e}")
return None
def get_author_id(self, author_name: str, channel: int = 1) -> int:
"""从 ai_authors 表查询 author_id
Args:
author_name: 作者名称
channel: 渠道默认1=百度
Returns:
author_id如果找不到返回0
"""
try:
sql = "SELECT id FROM ai_authors WHERE author_name = %s AND channel = %s"
result = self.db_manager.execute_query(sql, (author_name, channel), fetch_one=True)
if result:
author_id = result.get('id', 0) if isinstance(result, dict) else result[0]
self.logger.debug(f"查询author_id: {author_name} -> {author_id}")
return author_id
else:
self.logger.warning(f"未找到author_id: {author_name}, channel={channel}")
return 0
except Exception as e:
self.logger.error(f"查询author_id失败: {author_name}, 错误: {e}")
return 0
def import_ai_statistics(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics 表数据(使用批量提交)
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics 表数据")
print("="*70)
csv_file = self.csv_files['ai_statistics']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics表没有数据可导入")
return False
self.logger.info(f"开始导入ai_statistics表数据{len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
success_count = 0
failed_count = 0
batch_params = [] # 批量参数列表
# 构建SQL所有记录使用相同SQL
first_record_keys = None
sql_template = None
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
# 处理slide_ratio值
slide_ratio_value = float(self.convert_value(row.get('slide_ratio', '0'), 'float') or 0.0)
if slide_ratio_value > 10:
slide_ratio_value = slide_ratio_value / 100
slide_ratio_value = min(slide_ratio_value, 9.9999)
# 获取channel
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
# 从数据库查询正确的 author_id不使用CSV中的0
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
# 准备数据
record = {
'author_id': author_id, # 使用查询到的author_id
'author_name': author_name,
'channel': channel,
'date': row.get('date', '').strip(),
'submission_count': self.convert_value(row.get('submission_count', '0'), 'int') or 0,
'read_count': self.convert_value(row.get('read_count', '0'), 'int') or 0,
'comment_count': self.convert_value(row.get('comment_count', '0'), 'int') or 0,
'comment_rate': min(float(self.convert_value(row.get('comment_rate', '0'), 'float') or 0.0), 9.9999),
'like_count': self.convert_value(row.get('like_count', '0'), 'int') or 0,
'like_rate': min(float(self.convert_value(row.get('like_rate', '0'), 'float') or 0.0), 9.9999),
'favorite_count': self.convert_value(row.get('favorite_count', '0'), 'int') or 0,
'favorite_rate': min(float(self.convert_value(row.get('favorite_rate', '0'), 'float') or 0.0), 9.9999),
'share_count': self.convert_value(row.get('share_count', '0'), 'int') or 0,
'share_rate': min(float(self.convert_value(row.get('share_rate', '0'), 'float') or 0.0), 9.9999),
'slide_ratio': slide_ratio_value,
'baidu_search_volume': self.convert_value(row.get('baidu_search_volume', '0'), 'int') or 0,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), # 添加更新时间戳,强制更新
}
# 第一条记录时构建SQL
if sql_template is None:
first_record_keys = list(record.keys())
columns = ', '.join(first_record_keys)
placeholders = ', '.join(['%s'] * len(first_record_keys))
update_parts = [f"{key} = VALUES({key})" for key in first_record_keys if key not in ['author_name', 'channel', 'date']]
sql_template = f"""
INSERT INTO ai_statistics ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
# 添加到批量参数确保first_record_keys已初始化
if first_record_keys is not None:
batch_params.append(tuple(record[key] for key in first_record_keys))
# 当达到批量大小或最后一条记录时,执行批量插入
if len(batch_params) >= batch_size or idx == len(rows):
try:
# 使用 execute_many 批量提交
result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
success_count += result_count
print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)}")
self.logger.info(f"ai_statistics表批量提交: {result_count}/{len(batch_params)} 条记录")
batch_params = [] # 清空批次
except Exception as batch_error:
failed_count += len(batch_params)
print(f" [X] 批次提交失败: {batch_error}")
self.logger.error(f"ai_statistics表批量提交失败: {batch_error}")
batch_params = []
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"ai_statistics表处理失败: {author_name}, 错误: {e}")
continue
print("\n" + "="*70)
print(f"[OK] ai_statistics 表数据导入完成")
print(f" 成功: {success_count} 条记录")
if failed_count > 0:
print(f" 失败: {failed_count} 条记录")
print("="*70)
self.logger.info(f"ai_statistics表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
return success_count > 0
def import_ai_statistics_day(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics_day 表数据(使用批量提交)
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics_day 表数据")
print("="*70)
csv_file = self.csv_files['ai_statistics_day']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics_day表没有数据可导入")
return False
self.logger.info(f"开始导入ai_statistics_day表数据{len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
success_count = 0
failed_count = 0
batch_params = []
first_record_keys = None
sql_template = None
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
avg_slide_ratio_value = float(self.convert_value(row.get('avg_slide_ratio', '0'), 'float') or 0.0)
if avg_slide_ratio_value > 10:
avg_slide_ratio_value = avg_slide_ratio_value / 100
avg_slide_ratio_value = min(avg_slide_ratio_value, 9.9999)
# 获取channel并查询author_id
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
record = {
'author_id': author_id,
'author_name': author_name,
'channel': channel,
'stat_date': row.get('stat_date', '').strip(),
'total_submission_count': self.convert_value(row.get('total_submission_count', '0'), 'int') or 0,
'total_read_count': self.convert_value(row.get('total_read_count', '0'), 'int') or 0,
'total_comment_count': self.convert_value(row.get('total_comment_count', '0'), 'int') or 0,
'total_like_count': self.convert_value(row.get('total_like_count', '0'), 'int') or 0,
'total_favorite_count': self.convert_value(row.get('total_favorite_count', '0'), 'int') or 0,
'total_share_count': self.convert_value(row.get('total_share_count', '0'), 'int') or 0,
'avg_comment_rate': min(float(self.convert_value(row.get('avg_comment_rate', '0'), 'float') or 0.0), 9.9999),
'avg_like_rate': min(float(self.convert_value(row.get('avg_like_rate', '0'), 'float') or 0.0), 9.9999),
'avg_favorite_rate': min(float(self.convert_value(row.get('avg_favorite_rate', '0'), 'float') or 0.0), 9.9999),
'avg_share_rate': min(float(self.convert_value(row.get('avg_share_rate', '0'), 'float') or 0.0), 9.9999),
'avg_slide_ratio': avg_slide_ratio_value,
'total_baidu_search_volume': self.convert_value(row.get('total_baidu_search_volume', '0'), 'int') or 0,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), # 添加更新时间戳,强制更新
}
if sql_template is None:
first_record_keys = list(record.keys())
columns = ', '.join(first_record_keys)
placeholders = ', '.join(['%s'] * len(first_record_keys))
update_parts = [f"{key} = VALUES({key})" for key in first_record_keys if key not in ['author_name', 'channel', 'stat_date']]
sql_template = f"""
INSERT INTO ai_statistics_day ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
if first_record_keys is not None:
batch_params.append(tuple(record[key] for key in first_record_keys))
if len(batch_params) >= batch_size or idx == len(rows):
try:
result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
success_count += result_count
print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)}")
self.logger.info(f"ai_statistics_day表批量提交: {result_count}/{len(batch_params)} 条记录")
batch_params = []
except Exception as batch_error:
failed_count += len(batch_params)
print(f" [X] 批次提交失败: {batch_error}")
self.logger.error(f"ai_statistics_day表批量提交失败: {batch_error}")
batch_params = []
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"ai_statistics_day表处理失败: {author_name}, 错误: {e}")
continue
print("\n" + "="*70)
print(f"[OK] ai_statistics_day 表数据导入完成")
print(f" 成功: {success_count} 条记录")
if failed_count > 0:
print(f" 失败: {failed_count} 条记录")
print("="*70)
self.logger.info(f"ai_statistics_day表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
return success_count > 0
def import_ai_statistics_days(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics_days 表数据(使用批量提交)
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics_days 表数据")
print("="*70)
csv_file = self.csv_files['ai_statistics_days']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics_days表没有数据可导入")
return False
self.logger.info(f"开始导入ai_statistics_days表数据{len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
success_count = 0
failed_count = 0
batch_params = []
first_record_keys = None
sql_template = None
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
# 获取channel并查询author_id
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
# 处理day_revenue字段每日收益
day_revenue_value = self.convert_value(row.get('day_revenue', '0'), 'decimal')
if day_revenue_value is None:
day_revenue_value = Decimal('0')
record = {
'author_id': author_id,
'author_name': author_name,
'channel': channel,
'stat_date': row.get('stat_date', '').strip(),
'daily_published_count': self.convert_value(row.get('daily_published_count', '0'), 'int') or 0,
'cumulative_published_count': self.convert_value(row.get('cumulative_published_count', '0'), 'int') or 0,
'day_revenue': day_revenue_value, # 每日收益
'monthly_revenue': self.convert_value(row.get('monthly_revenue', '0'), 'decimal') or Decimal('0'),
'weekly_revenue': self.convert_value(row.get('weekly_revenue', '0'), 'decimal') or Decimal('0'),
'revenue_mom_growth_rate': self.convert_value(row.get('revenue_mom_growth_rate', '0'), 'decimal') or Decimal('0'),
'revenue_wow_growth_rate': self.convert_value(row.get('revenue_wow_growth_rate', '0'), 'decimal') or Decimal('0'),
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), # 添加更新时间戳,强制更新
}
if sql_template is None:
first_record_keys = list(record.keys())
columns = ', '.join(first_record_keys)
placeholders = ', '.join(['%s'] * len(first_record_keys))
update_parts = [f"{key} = VALUES({key})" for key in first_record_keys if key not in ['author_name', 'channel', 'stat_date']]
sql_template = f"""
INSERT INTO ai_statistics_days ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
if first_record_keys is not None:
batch_params.append(tuple(record[key] for key in first_record_keys))
if len(batch_params) >= batch_size or idx == len(rows):
try:
result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
success_count += result_count
print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)}")
self.logger.info(f"ai_statistics_days表批量提交: {result_count}/{len(batch_params)} 条记录")
batch_params = []
except Exception as batch_error:
failed_count += len(batch_params)
print(f" [X] 批次提交失败: {batch_error}")
self.logger.error(f"ai_statistics_days表批量提交失败: {batch_error}")
batch_params = []
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"ai_statistics_days表处理失败: {author_name}, 错误: {e}")
continue
print("\n" + "="*70)
print(f"[OK] ai_statistics_days 表数据导入完成")
print(f" 成功: {success_count} 条记录")
if failed_count > 0:
print(f" 失败: {failed_count} 条记录")
print("="*70)
self.logger.info(f"ai_statistics_days表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
return success_count > 0
def import_all(self) -> bool:
"""导入所有CSV文件"""
print("\n" + "="*70)
print("CSV数据导入数据库")
print("="*70)
print("\n将导入以下三个CSV文件到数据库")
print(" 1. ai_statistics.csv -> ai_statistics表")
print(" 2. ai_statistics_day.csv -> ai_statistics_day表")
print(" 3. ai_statistics_days.csv -> ai_statistics_days表")
print("\n唯一键:(author_name, channel, date/stat_date)")
print("="*70)
self.logger.info("开始执行CSV数据导入任务")
# 导入三个表
result1 = self.import_ai_statistics()
result2 = self.import_ai_statistics_day()
result3 = self.import_ai_statistics_days()
if result1 and result2 and result3:
print("\n" + "="*70)
print("✓ 所有CSV文件导入完成")
print("="*70)
self.logger.info("所有CSV文件导入成功")
return True
else:
print("\n" + "="*70)
print("✗ 部分CSV文件导入失败")
print("="*70)
self.logger.warning("部分CSV文件导入失败")
return False
def main():
    """Script entry point: run the full import and return a process exit code."""
    logger = None
    try:
        importer = CSVImporter()
        # Keep a reference so the except-branch can log even after a failure.
        logger = importer.logger
        return 0 if importer.import_all() else 1
    except Exception as e:
        print(f"\n[X] 程序执行出错: {e}")
        if logger is not None:
            logger.error(f"程序执行出错: {e}", exc_info=True)
        return 1
if __name__ == '__main__':
    # Propagate the import result (0 = success, 1 = failure) as the exit code.
    sys.exit(main())