Files
baijiahao_data_crawl/import_csv_to_database.py

605 lines
28 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CSV数据导入数据库脚本
功能将三个CSV文件的数据以author_name为唯一键导入到数据库
- ai_statistics.csv -> ai_statistics表
- ai_statistics_day.csv -> ai_statistics_day表
- ai_statistics_days.csv -> ai_statistics_days表
"""
import os
import sys
import csv
from decimal import Decimal
from typing import Dict, List, Optional
from datetime import datetime
from database_config import DatabaseManager, DB_CONFIG
from log_config import setup_logger
class CSVImporter:
    """Imports the three crawler CSV files into their database tables.

    Upserts are keyed on (author_name, channel, date/stat_date);
    ``author_id`` is always resolved from the ``ai_authors`` table
    instead of being trusted from the CSV files.
    """

    def __init__(self, db_config: Optional[Dict] = None):
        """Initialize the importer.

        Args:
            db_config: Database configuration; when None, DatabaseManager
                falls back to ``database_config.DB_CONFIG``.
        """
        self.script_dir = os.path.dirname(os.path.abspath(__file__))
        self.db_manager = DatabaseManager(db_config)
        # Dedicated log file under <script_dir>/logs/.
        self.logger = setup_logger('import_csv', os.path.join(self.script_dir, 'logs', 'import_csv.log'))
        # The three source CSV files live next to this script.
        self.csv_files = {
            'ai_statistics': os.path.join(self.script_dir, 'ai_statistics.csv'),
            'ai_statistics_day': os.path.join(self.script_dir, 'ai_statistics_day.csv'),
            'ai_statistics_days': os.path.join(self.script_dir, 'ai_statistics_days.csv'),
        }

    def read_csv(self, file_path: str) -> List[Dict]:
        """Read a CSV file into a list of row dicts.

        Args:
            file_path: Path to the CSV file.

        Returns:
            One dict per data row; an empty list if the file is missing
            or unreadable.
        """
        if not os.path.exists(file_path):
            print(f"[X] CSV文件不存在: {file_path}")
            self.logger.error(f"CSV文件不存在: {file_path}")
            return []
        try:
            # utf-8-sig transparently strips a BOM (e.g. written by Excel).
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                rows = list(csv.DictReader(f))
            print(f"[OK] 读取CSV文件成功: {file_path}, 共 {len(rows)}")
            self.logger.info(f"读取CSV文件成功: {file_path}, 共 {len(rows)}")
            return rows
        except Exception as e:
            print(f"[X] 读取CSV文件失败: {file_path}, 错误: {e}")
            self.logger.error(f"读取CSV文件失败: {file_path}, 错误: {e}")
            return []

    def convert_value(self, value: Optional[str], field_type: str):
        """Convert a raw CSV string into the requested field type.

        Args:
            value: Raw string value (may be '' or None).
            field_type: One of 'int', 'float', 'decimal', 'str'.

        Returns:
            The converted value, or None for empty input or on
            conversion failure.
        """
        if value == '' or value is None:
            return None
        try:
            if field_type == 'int':
                # Route through float() so values like "3.0" still parse.
                return int(float(value))
            elif field_type == 'float':
                return float(value)
            elif field_type == 'decimal':
                return Decimal(str(value))
            else:
                return str(value)
        except Exception as e:
            print(f"[!] 数据类型转换失败: {value} -> {field_type}, 错误: {e}")
            self.logger.warning(f"数据类型转换失败: {value} -> {field_type}, 错误: {e}")
            return None

    def get_author_id(self, author_name: str, channel: int = 1) -> int:
        """Look up the author_id in the ai_authors table.

        Args:
            author_name: Author name.
            channel: Channel id (default 1 = Baidu).

        Returns:
            The author id, or 0 when not found or on query failure.
        """
        try:
            sql = "SELECT id FROM ai_authors WHERE author_name = %s AND channel = %s"
            result = self.db_manager.execute_query(sql, (author_name, channel), fetch_one=True)
            if result:
                # execute_query may return a dict row or a plain tuple.
                author_id = result.get('id', 0) if isinstance(result, dict) else result[0]
                self.logger.debug(f"查询author_id: {author_name} -> {author_id}")
                return author_id
            self.logger.warning(f"未找到author_id: {author_name}, channel={channel}")
            return 0
        except Exception as e:
            self.logger.error(f"查询author_id失败: {author_name}, 错误: {e}")
            return 0

    def import_ai_statistics(self, batch_size: int = 50) -> bool:
        """Import ai_statistics.csv into the ai_statistics table (batched upserts).

        Args:
            batch_size: Rows per batch commit (default 50).

        Returns:
            True if at least one row was imported successfully.
        """
        print("\n" + "="*70)
        print("开始导入 ai_statistics 表数据")
        print("="*70)
        rows = self.read_csv(self.csv_files['ai_statistics'])
        if not rows:
            print("[X] 没有数据可导入")
            self.logger.warning("ai_statistics表没有数据可导入")
            return False
        self.logger.info(f"开始导入ai_statistics表数据{len(rows)} 条记录,批量大小: {batch_size}")
        print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
        success_count = 0
        failed_count = 0
        batch_params = []          # pending parameter tuples for execute_many
        record_keys = None         # column order, fixed by the first record
        sql_template = None        # INSERT ... ON DUPLICATE KEY UPDATE, built once

        def flush() -> None:
            """Commit the pending batch and update the counters."""
            nonlocal success_count, failed_count, batch_params
            if not batch_params:
                return
            try:
                result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
                success_count += result_count
                print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)})")
                self.logger.info(f"ai_statistics表批量提交: {result_count}/{len(batch_params)} 条记录")
            except Exception as batch_error:
                failed_count += len(batch_params)
                print(f" [X] 批次提交失败: {batch_error}")
                self.logger.error(f"ai_statistics表批量提交失败: {batch_error}")
            batch_params = []

        for row in rows:
            author_name = row.get('author_name', '').strip()
            if not author_name:
                continue
            try:
                # slide_ratio is already a decimal fraction in the CSV; clamp so it
                # fits the column precision. NOTE(review): assumes DECIMAL(5,4) — confirm.
                slide_ratio_value = min(float(self.convert_value(row.get('slide_ratio', '0'), 'float') or 0.0), 9.9999)
                channel = self.convert_value(row.get('channel', '1'), 'int') or 1
                # Resolve author_id from the database; the CSV value (0) is not trusted.
                author_id = self.get_author_id(author_name, int(channel))
                if author_id == 0:
                    self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
                    failed_count += 1
                    continue
                record = {
                    'author_id': author_id,
                    'author_name': author_name,
                    'channel': channel,
                    'date': row.get('date', '').strip(),
                    'submission_count': self.convert_value(row.get('submission_count', '0'), 'int') or 0,
                    'read_count': self.convert_value(row.get('read_count', '0'), 'int') or 0,
                    'comment_count': self.convert_value(row.get('comment_count', '0'), 'int') or 0,
                    'comment_rate': min(float(self.convert_value(row.get('comment_rate', '0'), 'float') or 0.0), 9.9999),
                    'like_count': self.convert_value(row.get('like_count', '0'), 'int') or 0,
                    'like_rate': min(float(self.convert_value(row.get('like_rate', '0'), 'float') or 0.0), 9.9999),
                    'favorite_count': self.convert_value(row.get('favorite_count', '0'), 'int') or 0,
                    'favorite_rate': min(float(self.convert_value(row.get('favorite_rate', '0'), 'float') or 0.0), 9.9999),
                    'share_count': self.convert_value(row.get('share_count', '0'), 'int') or 0,
                    'share_rate': min(float(self.convert_value(row.get('share_rate', '0'), 'float') or 0.0), 9.9999),
                    'slide_ratio': slide_ratio_value,
                    'baidu_search_volume': self.convert_value(row.get('baidu_search_volume', '0'), 'int') or 0,
                    # Fresh timestamp so the ON DUPLICATE KEY UPDATE branch always writes.
                    'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                }
                if sql_template is None:
                    record_keys = list(record.keys())
                    columns = ', '.join(record_keys)
                    placeholders = ', '.join(['%s'] * len(record_keys))
                    # Unique-key columns must not be overwritten on conflict.
                    update_parts = [f"{key} = VALUES({key})" for key in record_keys if key not in ['author_name', 'channel', 'date']]
                    sql_template = f"""
                INSERT INTO ai_statistics ({columns})
                VALUES ({placeholders})
                ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
                """
                batch_params.append(tuple(record[key] for key in record_keys))
                if len(batch_params) >= batch_size:
                    flush()
            except Exception as e:
                failed_count += 1
                print(f" [X] 处理失败 ({author_name}): {e}")
                self.logger.error(f"ai_statistics表处理失败: {author_name}, 错误: {e}")
                continue
        # BUGFIX: flush the remainder after the loop. The old `idx == len(rows)`
        # trigger silently dropped the final partial batch whenever the trailing
        # CSV rows were skipped (empty author_name or unresolved author_id).
        flush()
        print("\n" + "="*70)
        print(f"[OK] ai_statistics 表数据导入完成")
        print(f" 成功: {success_count} 条记录")
        if failed_count > 0:
            print(f" 失败: {failed_count} 条记录")
        print("="*70)
        self.logger.info(f"ai_statistics表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
        return success_count > 0

    def import_ai_statistics_day(self, batch_size: int = 50) -> bool:
        """Import ai_statistics_day.csv into the ai_statistics_day table (batched upserts).

        Args:
            batch_size: Rows per batch commit (default 50).

        Returns:
            True if at least one row was imported successfully.
        """
        print("\n" + "="*70)
        print("开始导入 ai_statistics_day 表数据")
        print("="*70)
        rows = self.read_csv(self.csv_files['ai_statistics_day'])
        if not rows:
            print("[X] 没有数据可导入")
            self.logger.warning("ai_statistics_day表没有数据可导入")
            return False
        self.logger.info(f"开始导入ai_statistics_day表数据{len(rows)} 条记录,批量大小: {batch_size}")
        print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
        success_count = 0
        failed_count = 0
        batch_params = []
        record_keys = None
        sql_template = None

        def flush() -> None:
            """Commit the pending batch and update the counters."""
            nonlocal success_count, failed_count, batch_params
            if not batch_params:
                return
            try:
                result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
                success_count += result_count
                print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)})")
                self.logger.info(f"ai_statistics_day表批量提交: {result_count}/{len(batch_params)} 条记录")
            except Exception as batch_error:
                failed_count += len(batch_params)
                print(f" [X] 批次提交失败: {batch_error}")
                self.logger.error(f"ai_statistics_day表批量提交失败: {batch_error}")
            batch_params = []

        for row in rows:
            author_name = row.get('author_name', '').strip()
            if not author_name:
                continue
            try:
                # avg_slide_ratio is already a decimal fraction; clamp to column precision.
                avg_slide_ratio_value = min(float(self.convert_value(row.get('avg_slide_ratio', '0'), 'float') or 0.0), 9.9999)
                channel = self.convert_value(row.get('channel', '1'), 'int') or 1
                author_id = self.get_author_id(author_name, int(channel))
                if author_id == 0:
                    self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
                    failed_count += 1
                    continue
                record = {
                    'author_id': author_id,
                    'author_name': author_name,
                    'channel': channel,
                    'stat_date': row.get('stat_date', '').strip(),
                    'total_submission_count': self.convert_value(row.get('total_submission_count', '0'), 'int') or 0,
                    'total_read_count': self.convert_value(row.get('total_read_count', '0'), 'int') or 0,
                    'total_comment_count': self.convert_value(row.get('total_comment_count', '0'), 'int') or 0,
                    'total_like_count': self.convert_value(row.get('total_like_count', '0'), 'int') or 0,
                    'total_favorite_count': self.convert_value(row.get('total_favorite_count', '0'), 'int') or 0,
                    'total_share_count': self.convert_value(row.get('total_share_count', '0'), 'int') or 0,
                    'avg_comment_rate': min(float(self.convert_value(row.get('avg_comment_rate', '0'), 'float') or 0.0), 9.9999),
                    'avg_like_rate': min(float(self.convert_value(row.get('avg_like_rate', '0'), 'float') or 0.0), 9.9999),
                    'avg_favorite_rate': min(float(self.convert_value(row.get('avg_favorite_rate', '0'), 'float') or 0.0), 9.9999),
                    'avg_share_rate': min(float(self.convert_value(row.get('avg_share_rate', '0'), 'float') or 0.0), 9.9999),
                    'avg_slide_ratio': avg_slide_ratio_value,
                    'total_baidu_search_volume': self.convert_value(row.get('total_baidu_search_volume', '0'), 'int') or 0,
                    # Fresh timestamp so the ON DUPLICATE KEY UPDATE branch always writes.
                    'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                }
                if sql_template is None:
                    record_keys = list(record.keys())
                    columns = ', '.join(record_keys)
                    placeholders = ', '.join(['%s'] * len(record_keys))
                    update_parts = [f"{key} = VALUES({key})" for key in record_keys if key not in ['author_name', 'channel', 'stat_date']]
                    sql_template = f"""
                INSERT INTO ai_statistics_day ({columns})
                VALUES ({placeholders})
                ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
                """
                batch_params.append(tuple(record[key] for key in record_keys))
                if len(batch_params) >= batch_size:
                    flush()
            except Exception as e:
                failed_count += 1
                print(f" [X] 处理失败 ({author_name}): {e}")
                self.logger.error(f"ai_statistics_day表处理失败: {author_name}, 错误: {e}")
                continue
        # BUGFIX: commit any remaining partial batch (see import_ai_statistics).
        flush()
        print("\n" + "="*70)
        print(f"[OK] ai_statistics_day 表数据导入完成")
        print(f" 成功: {success_count} 条记录")
        if failed_count > 0:
            print(f" 失败: {failed_count} 条记录")
        print("="*70)
        self.logger.info(f"ai_statistics_day表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
        return success_count > 0

    def import_ai_statistics_days(self, batch_size: int = 50) -> bool:
        """Import ai_statistics_days.csv, fanning each row out to three tables.

        Every CSV row produces one record each for ai_statistics_days
        (daily revenue), ai_statistics_weekly (ISO week) and
        ai_statistics_monthly (YYYY-MM).

        Args:
            batch_size: Rows per batch commit (default 50).

        Returns:
            True if at least one ai_statistics_days row was imported.
        """
        print("\n" + "="*70)
        print("开始导入 ai_statistics_days 表数据拆分到3个表")
        print("="*70)
        rows = self.read_csv(self.csv_files['ai_statistics_days'])
        if not rows:
            print("[X] 没有数据可导入")
            self.logger.warning("ai_statistics_days表没有数据可导入")
            return False
        self.logger.info(f"开始导入数据,共 {len(rows)} 条记录,批量大小: {batch_size}")
        print(f"\n总计 {len(rows)} 条记录将拆分到3个表\n")
        # Per-table success counters and pending batches.
        days_success = 0
        weekly_success = 0
        monthly_success = 0
        failed_count = 0
        days_batch = []
        weekly_batch = []
        monthly_batch = []
        # SQL templates and column orders, built from the first record.
        days_sql = weekly_sql = monthly_sql = None
        days_keys = weekly_keys = monthly_keys = None

        def flush_all() -> None:
            """Commit all three pending batches (best effort per table)."""
            nonlocal days_success, weekly_success, monthly_success, failed_count
            nonlocal days_batch, weekly_batch, monthly_batch
            if days_batch:
                try:
                    days_success += self.db_manager.execute_many(days_sql, days_batch, autocommit=True)
                    print(f"[days] 已导入 {days_success}")
                except Exception as e:
                    print(f" [X] days表提交失败: {e}")
                    self.logger.error(f"ai_statistics_days批量提交失败: {e}")
                    failed_count += len(days_batch)
                days_batch = []
            if weekly_batch:
                try:
                    weekly_success += self.db_manager.execute_many(weekly_sql, weekly_batch, autocommit=True)
                    print(f"[weekly] 已导入 {weekly_success}")
                except Exception as e:
                    # Mirrors original behaviour: weekly failures are logged
                    # but not counted in failed_count.
                    print(f" [X] weekly表提交失败: {e}")
                    self.logger.error(f"ai_statistics_weekly批量提交失败: {e}")
                weekly_batch = []
            if monthly_batch:
                try:
                    monthly_success += self.db_manager.execute_many(monthly_sql, monthly_batch, autocommit=True)
                    print(f"[monthly] 已导入 {monthly_success}")
                except Exception as e:
                    # Mirrors original behaviour: monthly failures are logged
                    # but not counted in failed_count.
                    print(f" [X] monthly表提交失败: {e}")
                    self.logger.error(f"ai_statistics_monthly批量提交失败: {e}")
                monthly_batch = []

        for row in rows:
            author_name = row.get('author_name', '').strip()
            if not author_name:
                continue
            try:
                channel = self.convert_value(row.get('channel', '1'), 'int') or 1
                author_id = self.get_author_id(author_name, int(channel))
                if author_id == 0:
                    self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
                    failed_count += 1
                    continue
                stat_date = row.get('stat_date', '').strip()
                now_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                # 1. ai_statistics_days: daily revenue and publish count.
                days_record = {
                    'author_id': author_id,
                    'author_name': author_name,
                    'channel': channel,
                    'stat_date': stat_date,
                    'daily_published_count': self.convert_value(row.get('daily_published_count', '0'), 'int') or 0,
                    'day_revenue': self.convert_value(row.get('day_revenue', '0'), 'decimal') or Decimal('0'),
                    'updated_at': now_str,
                }
                # 2. ai_statistics_weekly: keyed by ISO week number (Monday-first).
                # NOTE(review): stat_weekly stores only the week number, so the same
                # week of different years collides — confirm against table schema.
                date_obj = datetime.strptime(stat_date, '%Y-%m-%d')
                _, week_num, _ = date_obj.isocalendar()
                weekly_record = {
                    'author_id': author_id,
                    'author_name': author_name,
                    'channel': channel,
                    'stat_weekly': week_num,
                    'weekly_revenue': self.convert_value(row.get('weekly_revenue', '0'), 'decimal') or Decimal('0'),
                    'revenue_wow_growth_rate': self.convert_value(row.get('revenue_wow_growth_rate', '0'), 'decimal') or Decimal('0'),
                    'updated_at': now_str,
                }
                # 3. ai_statistics_monthly: keyed by 'YYYY-MM'.
                monthly_record = {
                    'author_id': author_id,
                    'author_name': author_name,
                    'channel': channel,
                    'stat_monthly': date_obj.strftime('%Y-%m'),
                    'monthly_revenue': self.convert_value(row.get('monthly_revenue', '0'), 'decimal') or Decimal('0'),
                    'revenue_mom_growth_rate': self.convert_value(row.get('revenue_mom_growth_rate', '0'), 'decimal') or Decimal('0'),
                    'updated_at': now_str,
                }
                # Build the three upsert templates once, from the first record.
                if days_sql is None:
                    days_keys = list(days_record.keys())
                    update_parts = [f"{key} = VALUES({key})" for key in days_keys if key not in ['author_name', 'channel', 'stat_date']]
                    days_sql = f"""
                INSERT INTO ai_statistics_days ({', '.join(days_keys)})
                VALUES ({', '.join(['%s'] * len(days_keys))})
                ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
                """
                if weekly_sql is None:
                    weekly_keys = list(weekly_record.keys())
                    update_parts = [f"{key} = VALUES({key})" for key in weekly_keys if key not in ['author_name', 'channel', 'stat_weekly']]
                    weekly_sql = f"""
                INSERT INTO ai_statistics_weekly ({', '.join(weekly_keys)})
                VALUES ({', '.join(['%s'] * len(weekly_keys))})
                ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
                """
                if monthly_sql is None:
                    monthly_keys = list(monthly_record.keys())
                    update_parts = [f"{key} = VALUES({key})" for key in monthly_keys if key not in ['author_name', 'channel', 'stat_monthly']]
                    monthly_sql = f"""
                INSERT INTO ai_statistics_monthly ({', '.join(monthly_keys)})
                VALUES ({', '.join(['%s'] * len(monthly_keys))})
                ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
                """
                days_batch.append(tuple(days_record[key] for key in days_keys))
                weekly_batch.append(tuple(weekly_record[key] for key in weekly_keys))
                monthly_batch.append(tuple(monthly_record[key] for key in monthly_keys))
                if len(days_batch) >= batch_size:
                    flush_all()
            except Exception as e:
                failed_count += 1
                print(f" [X] 处理失败 ({author_name}): {e}")
                self.logger.error(f"数据处理失败: {author_name}, 错误: {e}")
                continue
        # BUGFIX: commit remaining partial batches; the old `idx == len(rows)`
        # check missed them when the final CSV rows were skipped as invalid.
        flush_all()
        print("\n" + "="*70)
        print(f"[OK] 数据导入完成拆分到3个表")
        print(f" ai_statistics_days: {days_success}")
        print(f" ai_statistics_weekly: {weekly_success}")
        print(f" ai_statistics_monthly: {monthly_success}")
        if failed_count > 0:
            print(f" 失败: {failed_count}")
        print("="*70)
        self.logger.info(f"数据导入完成: days={days_success}, weekly={weekly_success}, monthly={monthly_success}, failed={failed_count}")
        return days_success > 0

    def import_all(self) -> bool:
        """Run all three imports; returns True only if every one succeeded."""
        print("\n" + "="*70)
        print("CSV数据导入数据库")
        print("="*70)
        print("\n将导入以下三个CSV文件到数据库")
        print(" 1. ai_statistics.csv -> ai_statistics表")
        print(" 2. ai_statistics_day.csv -> ai_statistics_day表")
        print(" 3. ai_statistics_days.csv -> ai_statistics_days表")
        print("\n唯一键:(author_name, channel, date/stat_date)")
        print("="*70)
        self.logger.info("开始执行CSV数据导入任务")
        # Run each import; all three must succeed for an overall success.
        result1 = self.import_ai_statistics()
        result2 = self.import_ai_statistics_day()
        result3 = self.import_ai_statistics_days()
        if result1 and result2 and result3:
            print("\n" + "="*70)
            print("✓ 所有CSV文件导入完成")
            print("="*70)
            self.logger.info("所有CSV文件导入成功")
            return True
        else:
            print("\n" + "="*70)
            print("✗ 部分CSV文件导入失败")
            print("="*70)
            self.logger.warning("部分CSV文件导入失败")
            return False
def main():
    """Script entry point: run the full CSV import.

    Returns:
        0 on full success, 1 on any failure or unexpected error.
    """
    logger = None
    try:
        importer = CSVImporter()
        logger = importer.logger
        return 0 if importer.import_all() else 1
    except Exception as e:
        print(f"\n[X] 程序执行出错: {e}")
        if logger is not None:
            logger.error(f"程序执行出错: {e}", exc_info=True)
        return 1


if __name__ == '__main__':
    sys.exit(main())