# Source file: baijiahao_data_crawl/import_csv_to_database.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CSV数据导入数据库脚本
功能将三个CSV文件的数据以author_name为唯一键导入到数据库
- ai_statistics.csv -> ai_statistics表
- ai_statistics_day.csv -> ai_statistics_day表
- ai_statistics_days.csv -> ai_statistics_days表
"""
import os
import sys
import csv
from decimal import Decimal
from typing import Dict, List, Optional
from datetime import datetime
from database_config import DatabaseManager, DB_CONFIG
from log_config import setup_logger
class CSVImporter:
    """Imports the crawler's CSV exports into the MySQL statistics tables."""

    def __init__(self, db_config: Optional[Dict] = None):
        """Set up paths, database access and logging.

        Args:
            db_config: database settings; when omitted, DatabaseManager
                falls back to database_config.DB_CONFIG.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        self.script_dir = base_dir
        self.db_manager = DatabaseManager(db_config)
        # Dedicated logger writing next to this script.
        self.logger = setup_logger('import_csv', os.path.join(base_dir, 'logs', 'import_csv.log'))
        # Map logical table name -> CSV file expected beside this script.
        self.csv_files = {
            name: os.path.join(base_dir, f'{name}.csv')
            for name in ('ai_statistics', 'ai_statistics_day', 'ai_statistics_days')
        }
def read_csv(self, file_path: str) -> List[Dict]:
"""读取CSV文件
Args:
file_path: CSV文件路径
Returns:
数据行列表
"""
if not os.path.exists(file_path):
print(f"[X] CSV文件不存在: {file_path}")
self.logger.error(f"CSV文件不存在: {file_path}")
return []
rows = []
try:
with open(file_path, 'r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
print(f"[OK] 读取CSV文件成功: {file_path}, 共 {len(rows)}")
self.logger.info(f"读取CSV文件成功: {file_path}, 共 {len(rows)}")
return rows
except Exception as e:
print(f"[X] 读取CSV文件失败: {file_path}, 错误: {e}")
self.logger.error(f"读取CSV文件失败: {file_path}, 错误: {e}")
return []
def convert_value(self, value: str, field_type: str):
"""转换数据类型
Args:
value: 字符串值
field_type: 字段类型 (int, float, decimal, str)
Returns:
转换后的值
"""
if value == '' or value is None:
return None
try:
if field_type == 'int':
return int(float(value))
elif field_type == 'float':
return float(value)
elif field_type == 'decimal':
return Decimal(str(value))
else:
return str(value)
except Exception as e:
print(f"[!] 数据类型转换失败: {value} -> {field_type}, 错误: {e}")
self.logger.warning(f"数据类型转换失败: {value} -> {field_type}, 错误: {e}")
return None
def get_author_id(self, author_name: str, channel: int = 1) -> int:
"""从 ai_authors 表查询 author_id
Args:
author_name: 作者名称
channel: 渠道默认1=百度
Returns:
author_id如果找不到返回0
"""
try:
sql = "SELECT id FROM ai_authors WHERE author_name = %s AND channel = %s"
result = self.db_manager.execute_query(sql, (author_name, channel), fetch_one=True)
if result:
author_id = result.get('id', 0) if isinstance(result, dict) else result[0]
self.logger.debug(f"查询author_id: {author_name} -> {author_id}")
return author_id
else:
self.logger.warning(f"未找到author_id: {author_name}, channel={channel}")
return 0
except Exception as e:
self.logger.error(f"查询author_id失败: {author_name}, 错误: {e}")
return 0
def import_ai_statistics(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics 表数据(使用批量提交)
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics 表数据")
print("="*70)
csv_file = self.csv_files['ai_statistics']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics表没有数据可导入")
return False
self.logger.info(f"开始导入ai_statistics表数据{len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
success_count = 0
failed_count = 0
batch_params = [] # 批量参数列表
# 构建SQL所有记录使用相同SQL
first_record_keys = None
sql_template = None
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
# 处理slide_ratio值CSV中已是小数格式
2025-12-25 11:16:59 +08:00
slide_ratio_value = float(self.convert_value(row.get('slide_ratio', '0'), 'float') or 0.0)
slide_ratio_value = min(slide_ratio_value, 9.9999)
# 获取channel
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
# 从数据库查询正确的 author_id不使用CSV中的0
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
# 准备数据
record = {
'author_id': author_id, # 使用查询到的author_id
'author_name': author_name,
'channel': channel,
'date': row.get('date', '').strip(),
'submission_count': self.convert_value(row.get('submission_count', '0'), 'int') or 0,
'read_count': self.convert_value(row.get('read_count', '0'), 'int') or 0,
'comment_count': self.convert_value(row.get('comment_count', '0'), 'int') or 0,
'comment_rate': min(float(self.convert_value(row.get('comment_rate', '0'), 'float') or 0.0), 9.9999),
'like_count': self.convert_value(row.get('like_count', '0'), 'int') or 0,
'like_rate': min(float(self.convert_value(row.get('like_rate', '0'), 'float') or 0.0), 9.9999),
'favorite_count': self.convert_value(row.get('favorite_count', '0'), 'int') or 0,
'favorite_rate': min(float(self.convert_value(row.get('favorite_rate', '0'), 'float') or 0.0), 9.9999),
'share_count': self.convert_value(row.get('share_count', '0'), 'int') or 0,
'share_rate': min(float(self.convert_value(row.get('share_rate', '0'), 'float') or 0.0), 9.9999),
'slide_ratio': slide_ratio_value,
'baidu_search_volume': self.convert_value(row.get('baidu_search_volume', '0'), 'int') or 0,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), # 添加更新时间戳,强制更新
}
# 第一条记录时构建SQL
if sql_template is None:
first_record_keys = list(record.keys())
columns = ', '.join(first_record_keys)
placeholders = ', '.join(['%s'] * len(first_record_keys))
update_parts = [f"{key} = VALUES({key})" for key in first_record_keys if key not in ['author_name', 'channel', 'date']]
sql_template = f"""
INSERT INTO ai_statistics ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
# 添加到批量参数确保first_record_keys已初始化
if first_record_keys is not None:
batch_params.append(tuple(record[key] for key in first_record_keys))
# 当达到批量大小或最后一条记录时,执行批量插入
if len(batch_params) >= batch_size or idx == len(rows):
try:
# 使用 execute_many 批量提交
result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
success_count += result_count
print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)}")
self.logger.info(f"ai_statistics表批量提交: {result_count}/{len(batch_params)} 条记录")
batch_params = [] # 清空批次
except Exception as batch_error:
failed_count += len(batch_params)
print(f" [X] 批次提交失败: {batch_error}")
self.logger.error(f"ai_statistics表批量提交失败: {batch_error}")
batch_params = []
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"ai_statistics表处理失败: {author_name}, 错误: {e}")
continue
print("\n" + "="*70)
print(f"[OK] ai_statistics 表数据导入完成")
print(f" 成功: {success_count} 条记录")
if failed_count > 0:
print(f" 失败: {failed_count} 条记录")
print("="*70)
self.logger.info(f"ai_statistics表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
return success_count > 0
def import_ai_statistics_day(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics_day 表数据(使用批量提交)
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics_day 表数据")
print("="*70)
csv_file = self.csv_files['ai_statistics_day']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics_day表没有数据可导入")
return False
self.logger.info(f"开始导入ai_statistics_day表数据{len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
success_count = 0
failed_count = 0
batch_params = []
first_record_keys = None
sql_template = None
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
# 处理avg_slide_ratio值CSV中已是小数格式
2025-12-25 11:16:59 +08:00
avg_slide_ratio_value = float(self.convert_value(row.get('avg_slide_ratio', '0'), 'float') or 0.0)
avg_slide_ratio_value = min(avg_slide_ratio_value, 9.9999)
# 获取channel并查询author_id
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
record = {
'author_id': author_id,
'author_name': author_name,
'channel': channel,
'stat_date': row.get('stat_date', '').strip(),
'total_submission_count': self.convert_value(row.get('total_submission_count', '0'), 'int') or 0,
'total_read_count': self.convert_value(row.get('total_read_count', '0'), 'int') or 0,
'total_comment_count': self.convert_value(row.get('total_comment_count', '0'), 'int') or 0,
'total_like_count': self.convert_value(row.get('total_like_count', '0'), 'int') or 0,
'total_favorite_count': self.convert_value(row.get('total_favorite_count', '0'), 'int') or 0,
'total_share_count': self.convert_value(row.get('total_share_count', '0'), 'int') or 0,
'avg_comment_rate': min(float(self.convert_value(row.get('avg_comment_rate', '0'), 'float') or 0.0), 9.9999),
'avg_like_rate': min(float(self.convert_value(row.get('avg_like_rate', '0'), 'float') or 0.0), 9.9999),
'avg_favorite_rate': min(float(self.convert_value(row.get('avg_favorite_rate', '0'), 'float') or 0.0), 9.9999),
'avg_share_rate': min(float(self.convert_value(row.get('avg_share_rate', '0'), 'float') or 0.0), 9.9999),
'avg_slide_ratio': avg_slide_ratio_value,
'total_baidu_search_volume': self.convert_value(row.get('total_baidu_search_volume', '0'), 'int') or 0,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), # 添加更新时间戳,强制更新
}
if sql_template is None:
first_record_keys = list(record.keys())
columns = ', '.join(first_record_keys)
placeholders = ', '.join(['%s'] * len(first_record_keys))
update_parts = [f"{key} = VALUES({key})" for key in first_record_keys if key not in ['author_name', 'channel', 'stat_date']]
sql_template = f"""
INSERT INTO ai_statistics_day ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
if first_record_keys is not None:
batch_params.append(tuple(record[key] for key in first_record_keys))
if len(batch_params) >= batch_size or idx == len(rows):
try:
result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
success_count += result_count
print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)}")
self.logger.info(f"ai_statistics_day表批量提交: {result_count}/{len(batch_params)} 条记录")
batch_params = []
except Exception as batch_error:
failed_count += len(batch_params)
print(f" [X] 批次提交失败: {batch_error}")
self.logger.error(f"ai_statistics_day表批量提交失败: {batch_error}")
batch_params = []
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"ai_statistics_day表处理失败: {author_name}, 错误: {e}")
continue
print("\n" + "="*70)
print(f"[OK] ai_statistics_day 表数据导入完成")
print(f" 成功: {success_count} 条记录")
if failed_count > 0:
print(f" 失败: {failed_count} 条记录")
print("="*70)
self.logger.info(f"ai_statistics_day表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
return success_count > 0
def import_ai_statistics_days(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics_days 表数据仅当日数据day_revenue
同时自动拆分数据到 ai_statistics_weekly ai_statistics_monthly
2025-12-25 11:16:59 +08:00
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics_days 表数据拆分到3个表")
2025-12-25 11:16:59 +08:00
print("="*70)
csv_file = self.csv_files['ai_statistics_days']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics_days表没有数据可导入")
return False
self.logger.info(f"开始导入数据,共 {len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录将拆分到3个表\n")
2025-12-25 11:16:59 +08:00
# 三个表的统计
days_success = 0
weekly_success = 0
monthly_success = 0
2025-12-25 11:16:59 +08:00
failed_count = 0
# 批量参数
days_batch = []
weekly_batch = []
monthly_batch = []
# SQL模板
days_sql = None
weekly_sql = None
monthly_sql = None
days_keys = None
weekly_keys = None
monthly_keys = None
2025-12-25 11:16:59 +08:00
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
# 获取channel并查询author_id
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
stat_date = row.get('stat_date', '').strip()
2025-12-25 11:16:59 +08:00
# 1. ai_statistics_days 表数据(仅当日数据)
day_revenue = self.convert_value(row.get('day_revenue', '0'), 'decimal') or Decimal('0')
daily_published_count = self.convert_value(row.get('daily_published_count', '0'), 'int') or 0
cumulative_published_count = self.convert_value(row.get('cumulative_published_count', '0'), 'int') or 0
days_record = {
2025-12-25 11:16:59 +08:00
'author_id': author_id,
'author_name': author_name,
'channel': channel,
'stat_date': stat_date,
'daily_published_count': daily_published_count,
'day_revenue': day_revenue,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
2025-12-25 11:16:59 +08:00
}
# 2. ai_statistics_weekly 表数据
weekly_revenue = self.convert_value(row.get('weekly_revenue', '0'), 'decimal') or Decimal('0')
revenue_wow_growth_rate = self.convert_value(row.get('revenue_wow_growth_rate', '0'), 'decimal') or Decimal('0')
# 计算该日期所在周次格式WW如51
from datetime import datetime as dt, timedelta
date_obj = dt.strptime(stat_date, '%Y-%m-%d')
# 使用isocalendar()获取ISO周数周一为一周开始
year, week_num, _ = date_obj.isocalendar()
stat_weekly = week_num # 直接使用数字
weekly_record = {
'author_id': author_id,
'author_name': author_name,
'channel': channel,
'stat_weekly': stat_weekly,
'weekly_revenue': weekly_revenue,
'revenue_wow_growth_rate': revenue_wow_growth_rate,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
}
# 3. ai_statistics_monthly 表数据
monthly_revenue = self.convert_value(row.get('monthly_revenue', '0'), 'decimal') or Decimal('0')
revenue_mom_growth_rate = self.convert_value(row.get('revenue_mom_growth_rate', '0'), 'decimal') or Decimal('0')
# 计算该日期所在月份格式YYYY-MM如2025-12
stat_monthly = date_obj.strftime('%Y-%m')
monthly_record = {
'author_id': author_id,
'author_name': author_name,
'channel': channel,
'stat_monthly': stat_monthly,
'monthly_revenue': monthly_revenue,
'revenue_mom_growth_rate': revenue_mom_growth_rate,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
}
# 构建SQL模板首次
if days_sql is None:
days_keys = list(days_record.keys())
columns = ', '.join(days_keys)
placeholders = ', '.join(['%s'] * len(days_keys))
update_parts = [f"{key} = VALUES({key})" for key in days_keys if key not in ['author_name', 'channel', 'stat_date']]
days_sql = f"""
2025-12-25 11:16:59 +08:00
INSERT INTO ai_statistics_days ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
if weekly_sql is None:
weekly_keys = list(weekly_record.keys())
columns = ', '.join(weekly_keys)
placeholders = ', '.join(['%s'] * len(weekly_keys))
update_parts = [f"{key} = VALUES({key})" for key in weekly_keys if key not in ['author_name', 'channel', 'stat_weekly']]
weekly_sql = f"""
INSERT INTO ai_statistics_weekly ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
2025-12-25 11:16:59 +08:00
if monthly_sql is None:
monthly_keys = list(monthly_record.keys())
columns = ', '.join(monthly_keys)
placeholders = ', '.join(['%s'] * len(monthly_keys))
update_parts = [f"{key} = VALUES({key})" for key in monthly_keys if key not in ['author_name', 'channel', 'stat_monthly']]
monthly_sql = f"""
INSERT INTO ai_statistics_monthly ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
# 添加到批量参数
days_batch.append(tuple(days_record[key] for key in days_keys))
weekly_batch.append(tuple(weekly_record[key] for key in weekly_keys))
monthly_batch.append(tuple(monthly_record[key] for key in monthly_keys))
# 批量提交
if len(days_batch) >= batch_size or idx == len(rows):
2025-12-25 11:16:59 +08:00
try:
# 提交 ai_statistics_days
result = self.db_manager.execute_many(days_sql, days_batch, autocommit=True)
days_success += result
print(f"[days] 已导入 {days_success}")
days_batch = []
except Exception as e:
print(f" [X] days表提交失败: {e}")
self.logger.error(f"ai_statistics_days批量提交失败: {e}")
failed_count += len(days_batch)
days_batch = []
try:
# 提交 ai_statistics_weekly
result = self.db_manager.execute_many(weekly_sql, weekly_batch, autocommit=True)
weekly_success += result
print(f"[weekly] 已导入 {weekly_success}")
weekly_batch = []
except Exception as e:
print(f" [X] weekly表提交失败: {e}")
self.logger.error(f"ai_statistics_weekly批量提交失败: {e}")
weekly_batch = []
try:
# 提交 ai_statistics_monthly
result = self.db_manager.execute_many(monthly_sql, monthly_batch, autocommit=True)
monthly_success += result
print(f"[monthly] 已导入 {monthly_success}")
monthly_batch = []
except Exception as e:
print(f" [X] monthly表提交失败: {e}")
self.logger.error(f"ai_statistics_monthly批量提交失败: {e}")
monthly_batch = []
2025-12-25 11:16:59 +08:00
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"数据处理失败: {author_name}, 错误: {e}")
2025-12-25 11:16:59 +08:00
continue
print("\n" + "="*70)
print(f"[OK] 数据导入完成拆分到3个表")
print(f" ai_statistics_days: {days_success}")
print(f" ai_statistics_weekly: {weekly_success}")
print(f" ai_statistics_monthly: {monthly_success}")
2025-12-25 11:16:59 +08:00
if failed_count > 0:
print(f" 失败: {failed_count}")
2025-12-25 11:16:59 +08:00
print("="*70)
self.logger.info(f"数据导入完成: days={days_success}, weekly={weekly_success}, monthly={monthly_success}, failed={failed_count}")
return days_success > 0
2025-12-25 11:16:59 +08:00
def import_all(self) -> bool:
"""导入所有CSV文件"""
print("\n" + "="*70)
print("CSV数据导入数据库")
print("="*70)
print("\n将导入以下三个CSV文件到数据库")
print(" 1. ai_statistics.csv -> ai_statistics表")
print(" 2. ai_statistics_day.csv -> ai_statistics_day表")
print(" 3. ai_statistics_days.csv -> ai_statistics_days表")
print("\n唯一键:(author_name, channel, date/stat_date)")
print("="*70)
self.logger.info("开始执行CSV数据导入任务")
# 导入三个表
result1 = self.import_ai_statistics()
result2 = self.import_ai_statistics_day()
result3 = self.import_ai_statistics_days()
if result1 and result2 and result3:
print("\n" + "="*70)
print("✓ 所有CSV文件导入完成")
print("="*70)
self.logger.info("所有CSV文件导入成功")
return True
else:
print("\n" + "="*70)
print("✗ 部分CSV文件导入失败")
print("="*70)
self.logger.warning("部分CSV文件导入失败")
return False
def main():
    """Entry point: run the full import and return a process exit code.

    Returns:
        0 when every import succeeded, 1 on partial failure or any error.
    """
    logger = None
    try:
        importer = CSVImporter()
        # Keep a handle on the logger so the except clause can use it even
        # if the failure happens after construction.
        logger = importer.logger
        return 0 if importer.import_all() else 1
    except Exception as exc:
        print(f"\n[X] 程序执行出错: {exc}")
        if logger:
            logger.error(f"程序执行出错: {exc}", exc_info=True)
        return 1
# Script entry point: propagate main()'s status as the process exit code.
if __name__ == '__main__':
    sys.exit(main())