Files
baijiahao_data_crawl/import_csv_to_database.py
“shengyudong” 322ac74336 2025-12-25 upload
2025-12-25 11:16:59 +08:00

509 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CSV数据导入数据库脚本
功能将三个CSV文件的数据以author_name为唯一键导入到数据库
- ai_statistics.csv -> ai_statistics表
- ai_statistics_day.csv -> ai_statistics_day表
- ai_statistics_days.csv -> ai_statistics_days表
"""
import os
import sys
import csv
from decimal import Decimal
from typing import Dict, List, Optional
from datetime import datetime
from database_config import DatabaseManager, DB_CONFIG
from log_config import setup_logger
class CSVImporter:
    """Importer that upserts statistics CSV files into their database tables."""

    # One CSV file per target table, all located next to this script.
    _TABLES = ('ai_statistics', 'ai_statistics_day', 'ai_statistics_days')

    def __init__(self, db_config: Optional[Dict] = None):
        """Initialize the importer.

        Args:
            db_config: database configuration; when None the default
                database_config.DB_CONFIG is used by DatabaseManager.
        """
        self.script_dir = os.path.dirname(os.path.abspath(__file__))
        self.db_manager = DatabaseManager(db_config)
        # Dedicated logger writing to <script_dir>/logs/import_csv.log.
        self.logger = setup_logger(
            'import_csv', os.path.join(self.script_dir, 'logs', 'import_csv.log'))
        # Map each target table to its source CSV path (same basename).
        self.csv_files = {
            table: os.path.join(self.script_dir, f'{table}.csv')
            for table in self._TABLES
        }
def read_csv(self, file_path: str) -> List[Dict]:
"""读取CSV文件
Args:
file_path: CSV文件路径
Returns:
数据行列表
"""
if not os.path.exists(file_path):
print(f"[X] CSV文件不存在: {file_path}")
self.logger.error(f"CSV文件不存在: {file_path}")
return []
rows = []
try:
with open(file_path, 'r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
print(f"[OK] 读取CSV文件成功: {file_path}, 共 {len(rows)}")
self.logger.info(f"读取CSV文件成功: {file_path}, 共 {len(rows)}")
return rows
except Exception as e:
print(f"[X] 读取CSV文件失败: {file_path}, 错误: {e}")
self.logger.error(f"读取CSV文件失败: {file_path}, 错误: {e}")
return []
def convert_value(self, value: str, field_type: str):
"""转换数据类型
Args:
value: 字符串值
field_type: 字段类型 (int, float, decimal, str)
Returns:
转换后的值
"""
if value == '' or value is None:
return None
try:
if field_type == 'int':
return int(float(value))
elif field_type == 'float':
return float(value)
elif field_type == 'decimal':
return Decimal(str(value))
else:
return str(value)
except Exception as e:
print(f"[!] 数据类型转换失败: {value} -> {field_type}, 错误: {e}")
self.logger.warning(f"数据类型转换失败: {value} -> {field_type}, 错误: {e}")
return None
def get_author_id(self, author_name: str, channel: int = 1) -> int:
"""从 ai_authors 表查询 author_id
Args:
author_name: 作者名称
channel: 渠道默认1=百度
Returns:
author_id如果找不到返回0
"""
try:
sql = "SELECT id FROM ai_authors WHERE author_name = %s AND channel = %s"
result = self.db_manager.execute_query(sql, (author_name, channel), fetch_one=True)
if result:
author_id = result.get('id', 0) if isinstance(result, dict) else result[0]
self.logger.debug(f"查询author_id: {author_name} -> {author_id}")
return author_id
else:
self.logger.warning(f"未找到author_id: {author_name}, channel={channel}")
return 0
except Exception as e:
self.logger.error(f"查询author_id失败: {author_name}, 错误: {e}")
return 0
def import_ai_statistics(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics 表数据(使用批量提交)
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics 表数据")
print("="*70)
csv_file = self.csv_files['ai_statistics']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics表没有数据可导入")
return False
self.logger.info(f"开始导入ai_statistics表数据{len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
success_count = 0
failed_count = 0
batch_params = [] # 批量参数列表
# 构建SQL所有记录使用相同SQL
first_record_keys = None
sql_template = None
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
# 处理slide_ratio值
slide_ratio_value = float(self.convert_value(row.get('slide_ratio', '0'), 'float') or 0.0)
if slide_ratio_value > 10:
slide_ratio_value = slide_ratio_value / 100
slide_ratio_value = min(slide_ratio_value, 9.9999)
# 获取channel
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
# 从数据库查询正确的 author_id不使用CSV中的0
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
# 准备数据
record = {
'author_id': author_id, # 使用查询到的author_id
'author_name': author_name,
'channel': channel,
'date': row.get('date', '').strip(),
'submission_count': self.convert_value(row.get('submission_count', '0'), 'int') or 0,
'read_count': self.convert_value(row.get('read_count', '0'), 'int') or 0,
'comment_count': self.convert_value(row.get('comment_count', '0'), 'int') or 0,
'comment_rate': min(float(self.convert_value(row.get('comment_rate', '0'), 'float') or 0.0), 9.9999),
'like_count': self.convert_value(row.get('like_count', '0'), 'int') or 0,
'like_rate': min(float(self.convert_value(row.get('like_rate', '0'), 'float') or 0.0), 9.9999),
'favorite_count': self.convert_value(row.get('favorite_count', '0'), 'int') or 0,
'favorite_rate': min(float(self.convert_value(row.get('favorite_rate', '0'), 'float') or 0.0), 9.9999),
'share_count': self.convert_value(row.get('share_count', '0'), 'int') or 0,
'share_rate': min(float(self.convert_value(row.get('share_rate', '0'), 'float') or 0.0), 9.9999),
'slide_ratio': slide_ratio_value,
'baidu_search_volume': self.convert_value(row.get('baidu_search_volume', '0'), 'int') or 0,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), # 添加更新时间戳,强制更新
}
# 第一条记录时构建SQL
if sql_template is None:
first_record_keys = list(record.keys())
columns = ', '.join(first_record_keys)
placeholders = ', '.join(['%s'] * len(first_record_keys))
update_parts = [f"{key} = VALUES({key})" for key in first_record_keys if key not in ['author_name', 'channel', 'date']]
sql_template = f"""
INSERT INTO ai_statistics ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
# 添加到批量参数确保first_record_keys已初始化
if first_record_keys is not None:
batch_params.append(tuple(record[key] for key in first_record_keys))
# 当达到批量大小或最后一条记录时,执行批量插入
if len(batch_params) >= batch_size or idx == len(rows):
try:
# 使用 execute_many 批量提交
result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
success_count += result_count
print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)}")
self.logger.info(f"ai_statistics表批量提交: {result_count}/{len(batch_params)} 条记录")
batch_params = [] # 清空批次
except Exception as batch_error:
failed_count += len(batch_params)
print(f" [X] 批次提交失败: {batch_error}")
self.logger.error(f"ai_statistics表批量提交失败: {batch_error}")
batch_params = []
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"ai_statistics表处理失败: {author_name}, 错误: {e}")
continue
print("\n" + "="*70)
print(f"[OK] ai_statistics 表数据导入完成")
print(f" 成功: {success_count} 条记录")
if failed_count > 0:
print(f" 失败: {failed_count} 条记录")
print("="*70)
self.logger.info(f"ai_statistics表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
return success_count > 0
def import_ai_statistics_day(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics_day 表数据(使用批量提交)
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics_day 表数据")
print("="*70)
csv_file = self.csv_files['ai_statistics_day']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics_day表没有数据可导入")
return False
self.logger.info(f"开始导入ai_statistics_day表数据{len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
success_count = 0
failed_count = 0
batch_params = []
first_record_keys = None
sql_template = None
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
avg_slide_ratio_value = float(self.convert_value(row.get('avg_slide_ratio', '0'), 'float') or 0.0)
if avg_slide_ratio_value > 10:
avg_slide_ratio_value = avg_slide_ratio_value / 100
avg_slide_ratio_value = min(avg_slide_ratio_value, 9.9999)
# 获取channel并查询author_id
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
record = {
'author_id': author_id,
'author_name': author_name,
'channel': channel,
'stat_date': row.get('stat_date', '').strip(),
'total_submission_count': self.convert_value(row.get('total_submission_count', '0'), 'int') or 0,
'total_read_count': self.convert_value(row.get('total_read_count', '0'), 'int') or 0,
'total_comment_count': self.convert_value(row.get('total_comment_count', '0'), 'int') or 0,
'total_like_count': self.convert_value(row.get('total_like_count', '0'), 'int') or 0,
'total_favorite_count': self.convert_value(row.get('total_favorite_count', '0'), 'int') or 0,
'total_share_count': self.convert_value(row.get('total_share_count', '0'), 'int') or 0,
'avg_comment_rate': min(float(self.convert_value(row.get('avg_comment_rate', '0'), 'float') or 0.0), 9.9999),
'avg_like_rate': min(float(self.convert_value(row.get('avg_like_rate', '0'), 'float') or 0.0), 9.9999),
'avg_favorite_rate': min(float(self.convert_value(row.get('avg_favorite_rate', '0'), 'float') or 0.0), 9.9999),
'avg_share_rate': min(float(self.convert_value(row.get('avg_share_rate', '0'), 'float') or 0.0), 9.9999),
'avg_slide_ratio': avg_slide_ratio_value,
'total_baidu_search_volume': self.convert_value(row.get('total_baidu_search_volume', '0'), 'int') or 0,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), # 添加更新时间戳,强制更新
}
if sql_template is None:
first_record_keys = list(record.keys())
columns = ', '.join(first_record_keys)
placeholders = ', '.join(['%s'] * len(first_record_keys))
update_parts = [f"{key} = VALUES({key})" for key in first_record_keys if key not in ['author_name', 'channel', 'stat_date']]
sql_template = f"""
INSERT INTO ai_statistics_day ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
if first_record_keys is not None:
batch_params.append(tuple(record[key] for key in first_record_keys))
if len(batch_params) >= batch_size or idx == len(rows):
try:
result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
success_count += result_count
print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)}")
self.logger.info(f"ai_statistics_day表批量提交: {result_count}/{len(batch_params)} 条记录")
batch_params = []
except Exception as batch_error:
failed_count += len(batch_params)
print(f" [X] 批次提交失败: {batch_error}")
self.logger.error(f"ai_statistics_day表批量提交失败: {batch_error}")
batch_params = []
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"ai_statistics_day表处理失败: {author_name}, 错误: {e}")
continue
print("\n" + "="*70)
print(f"[OK] ai_statistics_day 表数据导入完成")
print(f" 成功: {success_count} 条记录")
if failed_count > 0:
print(f" 失败: {failed_count} 条记录")
print("="*70)
self.logger.info(f"ai_statistics_day表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
return success_count > 0
def import_ai_statistics_days(self, batch_size: int = 50) -> bool:
"""导入 ai_statistics_days 表数据(使用批量提交)
Args:
batch_size: 批量提交大小默认50条
"""
print("\n" + "="*70)
print("开始导入 ai_statistics_days 表数据")
print("="*70)
csv_file = self.csv_files['ai_statistics_days']
rows = self.read_csv(csv_file)
if not rows:
print("[X] 没有数据可导入")
self.logger.warning("ai_statistics_days表没有数据可导入")
return False
self.logger.info(f"开始导入ai_statistics_days表数据{len(rows)} 条记录,批量大小: {batch_size}")
print(f"\n总计 {len(rows)} 条记录,分批导入(每批 {batch_size} 条)\n")
success_count = 0
failed_count = 0
batch_params = []
first_record_keys = None
sql_template = None
for idx, row in enumerate(rows, 1):
author_name = row.get('author_name', '').strip()
if not author_name:
continue
try:
# 获取channel并查询author_id
channel = self.convert_value(row.get('channel', '1'), 'int') or 1
author_id = self.get_author_id(author_name, int(channel))
if author_id == 0:
self.logger.warning(f"跳过无效记录: {author_name}, author_id未找到")
failed_count += 1
continue
# 处理day_revenue字段每日收益
day_revenue_value = self.convert_value(row.get('day_revenue', '0'), 'decimal')
if day_revenue_value is None:
day_revenue_value = Decimal('0')
record = {
'author_id': author_id,
'author_name': author_name,
'channel': channel,
'stat_date': row.get('stat_date', '').strip(),
'daily_published_count': self.convert_value(row.get('daily_published_count', '0'), 'int') or 0,
'cumulative_published_count': self.convert_value(row.get('cumulative_published_count', '0'), 'int') or 0,
'day_revenue': day_revenue_value, # 每日收益
'monthly_revenue': self.convert_value(row.get('monthly_revenue', '0'), 'decimal') or Decimal('0'),
'weekly_revenue': self.convert_value(row.get('weekly_revenue', '0'), 'decimal') or Decimal('0'),
'revenue_mom_growth_rate': self.convert_value(row.get('revenue_mom_growth_rate', '0'), 'decimal') or Decimal('0'),
'revenue_wow_growth_rate': self.convert_value(row.get('revenue_wow_growth_rate', '0'), 'decimal') or Decimal('0'),
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), # 添加更新时间戳,强制更新
}
if sql_template is None:
first_record_keys = list(record.keys())
columns = ', '.join(first_record_keys)
placeholders = ', '.join(['%s'] * len(first_record_keys))
update_parts = [f"{key} = VALUES({key})" for key in first_record_keys if key not in ['author_name', 'channel', 'stat_date']]
sql_template = f"""
INSERT INTO ai_statistics_days ({columns})
VALUES ({placeholders})
ON DUPLICATE KEY UPDATE {', '.join(update_parts)}
"""
if first_record_keys is not None:
batch_params.append(tuple(record[key] for key in first_record_keys))
if len(batch_params) >= batch_size or idx == len(rows):
try:
result_count = self.db_manager.execute_many(sql_template, batch_params, autocommit=True)
success_count += result_count
print(f"[批次提交] 已导入 {success_count} 条记录(本批: {result_count}/{len(batch_params)}")
self.logger.info(f"ai_statistics_days表批量提交: {result_count}/{len(batch_params)} 条记录")
batch_params = []
except Exception as batch_error:
failed_count += len(batch_params)
print(f" [X] 批次提交失败: {batch_error}")
self.logger.error(f"ai_statistics_days表批量提交失败: {batch_error}")
batch_params = []
except Exception as e:
failed_count += 1
print(f" [X] 处理失败 ({author_name}): {e}")
self.logger.error(f"ai_statistics_days表处理失败: {author_name}, 错误: {e}")
continue
print("\n" + "="*70)
print(f"[OK] ai_statistics_days 表数据导入完成")
print(f" 成功: {success_count} 条记录")
if failed_count > 0:
print(f" 失败: {failed_count} 条记录")
print("="*70)
self.logger.info(f"ai_statistics_days表数据导入完成: 成功 {success_count} 条,失败 {failed_count}")
return success_count > 0
def import_all(self) -> bool:
"""导入所有CSV文件"""
print("\n" + "="*70)
print("CSV数据导入数据库")
print("="*70)
print("\n将导入以下三个CSV文件到数据库")
print(" 1. ai_statistics.csv -> ai_statistics表")
print(" 2. ai_statistics_day.csv -> ai_statistics_day表")
print(" 3. ai_statistics_days.csv -> ai_statistics_days表")
print("\n唯一键:(author_name, channel, date/stat_date)")
print("="*70)
self.logger.info("开始执行CSV数据导入任务")
# 导入三个表
result1 = self.import_ai_statistics()
result2 = self.import_ai_statistics_day()
result3 = self.import_ai_statistics_days()
if result1 and result2 and result3:
print("\n" + "="*70)
print("✓ 所有CSV文件导入完成")
print("="*70)
self.logger.info("所有CSV文件导入成功")
return True
else:
print("\n" + "="*70)
print("✗ 部分CSV文件导入失败")
print("="*70)
self.logger.warning("部分CSV文件导入失败")
return False
def main():
    """Script entry point: import all CSV files and return an exit code."""
    logger = None
    try:
        importer = CSVImporter()
        # Keep a handle on the logger so the except branch can use it
        # even though the importer itself may be mid-construction.
        logger = importer.logger
        return 0 if importer.import_all() else 1
    except Exception as exc:
        print(f"\n[X] 程序执行出错: {exc}")
        if logger:
            logger.error(f"程序执行出错: {exc}", exc_info=True)
        return 1


if __name__ == '__main__':
    sys.exit(main())