Files
baijiahao_data_crawl/data_validation.py

770 lines
28 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据比对验证脚本
功能
1. 顺序验证验证不同数据源中记录的顺序一致性
2. 交叉验证对比数据内容识别缺失新增或不匹配的记录
支持的数据源
- JSON文件 (bjh_integrated_data.json)
- CSV文件 (ai_statistics_*.csv)
- MySQL数据库 (ai_statistics_* )
使用方法
# 验证JSON和CSV的一致性
python data_validation.py --source json csv --date 2025-12-29
# 验证CSV和数据库的一致性
python data_validation.py --source csv database --date 2025-12-29
# 完整验证(三个数据源)
python data_validation.py --source json csv database --date 2025-12-29
# 验证特定表
python data_validation.py --source csv database --table ai_statistics_day --date 2025-12-29
"""
import sys
import os
import json
import csv
import argparse
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional, Any, Set
from collections import OrderedDict
import hashlib

# Force UTF-8 on Windows consoles: the default code page (e.g. cp936/GBK)
# cannot encode the check-mark characters this script prints.
if sys.platform == 'win32':
    import io
    # Re-wrap stdout/stderr only when they are not already UTF-8 text wrappers,
    # so a second import / re-run does not double-wrap the streams.
    if not isinstance(sys.stdout, io.TextIOWrapper) or sys.stdout.encoding != 'utf-8':
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    if not isinstance(sys.stderr, io.TextIOWrapper) or sys.stderr.encoding != 'utf-8':
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

# Import the database layer. Database-backed validation is optional, so a
# missing database_config module degrades gracefully (DatabaseManager = None
# disables it) instead of crashing the whole script.
try:
    from database_config import DatabaseManager
except ImportError:
    print("[X] 无法导入 database_config.py数据库验证功能将不可用")
    DatabaseManager = None
class DataValidator:
    """Cross-source data validator.

    Compares the same logical records across up to three sources (the
    integrated JSON file, exported CSV files, and MySQL tables) and reports
    ordering differences, missing/extra records, and field-level mismatches.
    """

    def __init__(self, date_str: Optional[str] = None):
        """Initialize the validator.

        Args:
            date_str: Target date in YYYY-MM-DD form; defaults to yesterday.
        """
        self.script_dir = os.path.dirname(os.path.abspath(__file__))
        # Target date: explicit argument wins, otherwise default to yesterday.
        if date_str:
            self.target_date = datetime.strptime(date_str, '%Y-%m-%d')
        else:
            self.target_date = datetime.now() - timedelta(days=1)
        self.date_str = self.target_date.strftime('%Y-%m-%d')
        # Database manager is optional: stays None when database_config is
        # missing or the connection fails, which disables DB-backed checks.
        self.db_manager = None
        if DatabaseManager:
            try:
                self.db_manager = DatabaseManager()
                print("[OK] 数据库连接成功")
            except Exception as e:
                print(f"[!] 数据库连接失败: {e}")
        # Accumulated results grouped by validation kind. The keys are also
        # read by generate_report, so keep them stable.
        self.validation_results = {
            '顺序验证': [],
            '交叉验证': [],
            '差异统计': {}
        }

    def load_json_data(self, file_path: Optional[str] = None) -> Optional[Any]:
        """Load the integrated JSON data file.

        Args:
            file_path: JSON file path; defaults to bjh_integrated_data.json
                located next to this script.

        Returns:
            The parsed JSON data, or None on failure.
        """
        if not file_path:
            file_path = os.path.join(self.script_dir, 'bjh_integrated_data.json')
        try:
            if not os.path.exists(file_path):
                print(f"[X] JSON文件不存在: {file_path}")
                return None
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            print(f"[OK] 加载JSON文件: {file_path}")
            print(f" 账号数量: {len(data) if isinstance(data, list) else 1}")
            return data
        except Exception as e:
            print(f"[X] 加载JSON文件失败: {e}")
            return None

    def load_csv_data(self, csv_file: str) -> Optional[List[Dict]]:
        """Load a CSV file located next to this script.

        Args:
            csv_file: CSV file name.

        Returns:
            List of row dicts, or None on failure.
        """
        csv_path = os.path.join(self.script_dir, csv_file)
        try:
            if not os.path.exists(csv_path):
                print(f"[X] CSV文件不存在: {csv_path}")
                return None
            # utf-8-sig transparently strips the BOM that Excel exports write.
            with open(csv_path, 'r', encoding='utf-8-sig') as f:
                rows = list(csv.DictReader(f))
            print(f"[OK] 加载CSV文件: {csv_file}")
            print(f" 记录数量: {len(rows)}")
            return rows
        except Exception as e:
            print(f"[X] 加载CSV文件失败: {e}")
            return None

    def load_database_data(self, table_name: str, date_filter: Optional[str] = None) -> Optional[List[Dict]]:
        """Load records from a database table.

        Args:
            table_name: Table to query. NOTE: interpolated into the SQL text,
                so it must come from trusted code (callers pass literal names).
            date_filter: Name of the date column to filter on (e.g. 'date',
                'stat_date'); None loads the whole table.

        Returns:
            List of row dicts ([] when empty), or None on failure.
        """
        if not self.db_manager:
            print("[X] 数据库管理器未初始化")
            return None
        try:
            # ORDER BY mirrors the export order so order validation stays
            # meaningful; the date value itself is passed as a bound parameter.
            if date_filter:
                sql = f"SELECT * FROM {table_name} WHERE {date_filter} = %s ORDER BY author_name, channel"
                params = (self.date_str,)
            else:
                sql = f"SELECT * FROM {table_name} ORDER BY author_name, channel"
                params = None
            rows = self.db_manager.execute_query(sql, params)
            print(f"[OK] 加载数据库表: {table_name}")
            if date_filter:
                print(f" 过滤条件: {date_filter} = {self.date_str}")
            print(f" 记录数量: {len(rows) if rows else 0}")
            return rows if rows else []
        except Exception as e:
            print(f"[X] 加载数据库数据失败: {e}")
            import traceback
            traceback.print_exc()
            return None

    def generate_record_key(self, record: Dict, key_fields: List[str]) -> str:
        """Build a unique key for a record.

        Args:
            record: Data record.
            key_fields: Primary-key field names.

        Returns:
            Key string of the form "v1|v2|...", with each value stringified
            and stripped so int/str sources compare equal.
        """
        key_values = []
        for field in key_fields:
            value = record.get(field, '')
            # Normalize to a stripped string so e.g. 1 and '1' produce
            # the same key across JSON / CSV / database sources.
            key_values.append(str(value).strip())
        return '|'.join(key_values)

    def calculate_record_hash(self, record: Dict, exclude_fields: Optional[Set[str]] = None) -> str:
        """Compute a content hash of a record (for whole-record comparison).

        Args:
            record: Data record.
            exclude_fields: Field names to skip (defaults to volatile
                timestamp fields).

        Returns:
            MD5 hex digest of the sorted field=value pairs.
        """
        if exclude_fields is None:
            exclude_fields = {'updated_at', 'created_at', 'fetch_time'}
        # Sort keys to get a stable serialization regardless of dict order.
        sorted_items = []
        for key in sorted(record.keys()):
            if key not in exclude_fields:
                value = record.get(key, '')
                # Round floats to 4 decimals so storage precision differences
                # between sources don't change the hash.
                if isinstance(value, float):
                    value = f"{value:.4f}"
                sorted_items.append(f"{key}={value}")
        content = '|'.join(sorted_items)
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def validate_order(self, source1_data: List[Dict], source2_data: List[Dict],
                       source1_name: str, source2_name: str,
                       key_fields: List[str]) -> Dict:
        """Order validation: check that records appear in the same sequence.

        Args:
            source1_data: Records from source 1.
            source2_data: Records from source 2.
            source1_name: Display name of source 1.
            source2_name: Display name of source 2.
            key_fields: Primary-key field names.

        Returns:
            Result dict with counts, an 'order_match' flag, and per-position
            mismatches (compared only over the shorter source's length).
        """
        print(f"\n{'='*70}")
        print(f"顺序验证: {source1_name} vs {source2_name}")
        print(f"{'='*70}")
        result = {
            'source1': source1_name,
            'source2': source2_name,
            'source1_count': len(source1_data),
            'source2_count': len(source2_data),
            'order_match': True,
            'mismatches': []
        }
        # Build the positional key sequences for both sources.
        source1_keys = [self.generate_record_key(r, key_fields) for r in source1_data]
        source2_keys = [self.generate_record_key(r, key_fields) for r in source2_data]
        # Compare position by position up to the shorter length.
        min_len = min(len(source1_keys), len(source2_keys))
        for i in range(min_len):
            if source1_keys[i] != source2_keys[i]:
                result['order_match'] = False
                result['mismatches'].append({
                    'position': i,
                    'source1_key': source1_keys[i],
                    'source2_key': source2_keys[i]
                })
        # Report: "in order" only when sequences match AND lengths agree.
        if result['order_match'] and len(source1_keys) == len(source2_keys):
            print(f"[✓] 顺序一致,记录数相同: {len(source1_keys)}")
        else:
            print("[X] 顺序不一致")
            print(f" {source1_name} 记录数: {len(source1_keys)}")
            print(f" {source2_name} 记录数: {len(source2_keys)}")
            if result['mismatches']:
                print(f" 不匹配位置数: {len(result['mismatches'])}")
                # Show only the first 5 mismatches to keep output readable.
                for mismatch in result['mismatches'][:5]:
                    print(f" 位置{mismatch['position']}: {mismatch['source1_key']} != {mismatch['source2_key']}")
        return result

    def validate_cross(self, source1_data: List[Dict], source2_data: List[Dict],
                       source1_name: str, source2_name: str,
                       key_fields: List[str],
                       compare_fields: Optional[List[str]] = None) -> Dict:
        """Cross validation: find missing, extra, and mismatching records.

        Args:
            source1_data: Records from source 1.
            source2_data: Records from source 2.
            source1_name: Display name of source 1.
            source2_name: Display name of source 2.
            key_fields: Primary-key field names.
            compare_fields: Fields to compare; None compares all fields
                common to both records.

        Returns:
            Result dict with counts, keys only in either source, and
            per-record field mismatches.
        """
        print(f"\n{'='*70}")
        print(f"交叉验证: {source1_name} vs {source2_name}")
        print(f"{'='*70}")
        # Index both sources by record key. Duplicate keys keep the last
        # record seen, matching the original behavior.
        source1_dict = {}
        for record in source1_data:
            key = self.generate_record_key(record, key_fields)
            source1_dict[key] = record
        source2_dict = {}
        for record in source2_data:
            key = self.generate_record_key(record, key_fields)
            source2_dict[key] = record
        # Set arithmetic on the key spaces gives the three populations.
        only_in_source1 = set(source1_dict.keys()) - set(source2_dict.keys())
        only_in_source2 = set(source2_dict.keys()) - set(source1_dict.keys())
        common_keys = set(source1_dict.keys()) & set(source2_dict.keys())
        # Compare field values for records present in both sources.
        field_mismatches = []
        for key in common_keys:
            record1 = source1_dict[key]
            record2 = source2_dict[key]
            # Explicit field list wins; otherwise compare the intersection.
            if compare_fields:
                fields_to_compare = compare_fields
            else:
                fields_to_compare = set(record1.keys()) & set(record2.keys())
            mismatches_in_record = {}
            for field in fields_to_compare:
                val1 = record1.get(field, '')
                val2 = record2.get(field, '')
                # Normalize so numeric/string representations compare equal.
                val1_normalized = self._normalize_value(val1)
                val2_normalized = self._normalize_value(val2)
                if val1_normalized != val2_normalized:
                    mismatches_in_record[field] = {
                        source1_name: val1,
                        source2_name: val2
                    }
            if mismatches_in_record:
                field_mismatches.append({
                    'key': key,
                    'fields': mismatches_in_record
                })
        result = {
            'source1': source1_name,
            'source2': source2_name,
            'source1_count': len(source1_data),
            'source2_count': len(source2_data),
            'only_in_source1': list(only_in_source1),
            'only_in_source2': list(only_in_source2),
            'common_count': len(common_keys),
            'field_mismatches': field_mismatches
        }
        print("记录数统计:")
        print(f" {source1_name}: {len(source1_data)}")
        print(f" {source2_name}: {len(source2_data)}")
        print(f" 共同记录: {len(common_keys)}")
        print(f" 仅在{source1_name}: {len(only_in_source1)}")
        print(f" 仅在{source2_name}: {len(only_in_source2)}")
        print(f" 字段不匹配: {len(field_mismatches)}")
        # Detailed diff output, truncated to keep the console readable.
        if only_in_source1:
            print(f"\n仅在{source1_name}中的记录前5条:")
            for key in list(only_in_source1)[:5]:
                print(f" - {key}")
        if only_in_source2:
            print(f"\n仅在{source2_name}中的记录前5条:")
            for key in list(only_in_source2)[:5]:
                print(f" - {key}")
        if field_mismatches:
            print(f"\n字段值不匹配的记录前3条:")
            for mismatch in field_mismatches[:3]:
                print(f" 记录: {mismatch['key']}")
                # Show at most 5 differing fields per record.
                for field, values in list(mismatch['fields'].items())[:5]:
                    print(f" 字段 {field}:")
                    print(f" {source1_name}: {values[source1_name]}")
                    print(f" {source2_name}: {values[source2_name]}")
        return result

    def _normalize_value(self, value: Any) -> str:
        """Normalize a value to a canonical string for comparison.

        Args:
            value: Raw value.

        Returns:
            '' for None/empty; floats rounded to 4 decimals; ints as str;
            everything else stringified and stripped.
        """
        if value is None or value == '':
            return ''
        # Round floats so precision differences across sources don't count
        # as mismatches.
        if isinstance(value, float):
            return f"{value:.4f}"
        if isinstance(value, int):
            return str(value)
        return str(value).strip()

    def validate_ai_statistics(self, sources: List[str]) -> bool:
        """Validate the ai_statistics table across the requested sources.

        Args:
            sources: Source names, any of ['json', 'csv', 'database'].

        Returns:
            True when all pairwise validations pass.
        """
        print(f"\n{'#'*70}")
        print("# 验证 ai_statistics 表数据")
        print(f"# 日期: {self.date_str}")
        print(f"{'#'*70}")
        # Records are keyed by author + channel.
        key_fields = ['author_name', 'channel']
        # Metric fields whose values must agree across sources.
        compare_fields = [
            'submission_count', 'read_count', 'comment_count', 'comment_rate',
            'like_count', 'like_rate', 'favorite_count', 'favorite_rate',
            'share_count', 'share_rate', 'slide_ratio', 'baidu_search_volume'
        ]
        # Load every requested source; sources that fail to load are skipped.
        data_sources = {}
        if 'json' in sources:
            json_data = self.load_json_data()
            if json_data:
                # Single-account files parse as a dict; wrap for uniformity.
                if not isinstance(json_data, list):
                    json_data = [json_data]
                json_records = self._extract_ai_statistics_from_json(json_data)
                data_sources['json'] = json_records
        if 'csv' in sources:
            csv_data = self.load_csv_data('ai_statistics.csv')
            if csv_data:
                data_sources['csv'] = csv_data
        if 'database' in sources:
            db_data = self.load_database_data('ai_statistics', date_filter='date')
            if db_data:
                data_sources['database'] = db_data
        if len(data_sources) < 2:
            print("[X] 数据源不足至少需要2个数据源进行比对")
            return False
        # Validate every pair of successfully loaded sources.
        source_names = list(data_sources.keys())
        all_passed = True
        for i in range(len(source_names)):
            for j in range(i + 1, len(source_names)):
                source1_name = source_names[i]
                source2_name = source_names[j]
                # Order validation only makes sense for json vs csv (the
                # database result set has its own ORDER BY).
                if {source1_name, source2_name} == {'json', 'csv'}:
                    order_result = self.validate_order(
                        data_sources[source1_name],
                        data_sources[source2_name],
                        source1_name,
                        source2_name,
                        key_fields
                    )
                    self.validation_results['顺序验证'].append(order_result)
                    if not order_result['order_match']:
                        all_passed = False
                # Cross validation runs for every pair.
                cross_result = self.validate_cross(
                    data_sources[source1_name],
                    data_sources[source2_name],
                    source1_name,
                    source2_name,
                    key_fields,
                    compare_fields
                )
                self.validation_results['交叉验证'].append(cross_result)
                # Any missing/extra record or field mismatch fails the run.
                if cross_result['only_in_source1'] or \
                   cross_result['only_in_source2'] or \
                   cross_result['field_mismatches']:
                    all_passed = False
        return all_passed

    def validate_ai_statistics_day(self, sources: List[str]) -> bool:
        """Validate the ai_statistics_day table across the requested sources.

        Args:
            sources: Source names (only 'csv' and 'database' are supported
                for this table).

        Returns:
            True when all pairwise validations pass.
        """
        print(f"\n{'#'*70}")
        print("# 验证 ai_statistics_day 表数据")
        print(f"# 日期: {self.date_str}")
        print(f"{'#'*70}")
        key_fields = ['author_name', 'channel', 'stat_date']
        compare_fields = [
            'total_submission_count', 'total_read_count', 'total_comment_count',
            'total_like_count', 'total_favorite_count', 'total_share_count',
            'avg_comment_rate', 'avg_like_rate', 'avg_favorite_rate',
            'avg_share_rate', 'avg_slide_ratio', 'total_baidu_search_volume'
        ]
        data_sources = {}
        if 'csv' in sources:
            csv_data = self.load_csv_data('ai_statistics_day.csv')
            if csv_data:
                data_sources['csv'] = csv_data
        if 'database' in sources:
            db_data = self.load_database_data('ai_statistics_day', date_filter='stat_date')
            if db_data:
                data_sources['database'] = db_data
        if len(data_sources) < 2:
            print("[X] 数据源不足")
            return False
        source_names = list(data_sources.keys())
        all_passed = True
        for i in range(len(source_names)):
            for j in range(i + 1, len(source_names)):
                source1_name = source_names[i]
                source2_name = source_names[j]
                # Daily aggregates need no order validation; cross only.
                cross_result = self.validate_cross(
                    data_sources[source1_name],
                    data_sources[source2_name],
                    source1_name,
                    source2_name,
                    key_fields,
                    compare_fields
                )
                self.validation_results['交叉验证'].append(cross_result)
                if cross_result['only_in_source1'] or \
                   cross_result['only_in_source2'] or \
                   cross_result['field_mismatches']:
                    all_passed = False
        return all_passed

    def _extract_ai_statistics_from_json(self, json_data: List[Dict]) -> List[Dict]:
        """Flatten raw per-account JSON into ai_statistics-shaped records.

        Args:
            json_data: List of per-account dicts from the integrated JSON.

        Returns:
            List of records matching the ai_statistics schema.
        """
        records = []
        for account_data in json_data:
            account_id = account_data.get('account_id', '')
            if not account_id:
                continue
            analytics = account_data.get('analytics', {})
            apis = analytics.get('apis', [])
            # Only the first API payload is used, and only on success
            # (errno == 0 is the API's success code).
            if apis:
                api_data = apis[0].get('data', {})
                if api_data.get('errno') == 0:
                    total_info = api_data.get('data', {}).get('total_info', {})
                    # `or 0` guards against explicit nulls in the payload;
                    # rates arrive as percentages and are scaled to ratios.
                    record = {
                        'author_name': account_id,
                        'channel': 1,
                        'submission_count': int(total_info.get('publish_count', 0) or 0),
                        'read_count': int(total_info.get('view_count', 0) or 0),
                        'comment_count': int(total_info.get('comment_count', 0) or 0),
                        'comment_rate': float(total_info.get('comment_rate', 0) or 0) / 100,
                        'like_count': int(total_info.get('likes_count', 0) or 0),
                        'like_rate': float(total_info.get('likes_rate', 0) or 0) / 100,
                        'favorite_count': int(total_info.get('collect_count', 0) or 0),
                        'favorite_rate': float(total_info.get('collect_rate', 0) or 0) / 100,
                        'share_count': int(total_info.get('share_count', 0) or 0),
                        'share_rate': float(total_info.get('share_rate', 0) or 0) / 100,
                        'slide_ratio': float(total_info.get('pic_slide_rate', 0) or 0) / 100,
                        'baidu_search_volume': int(total_info.get('disp_pv', 0) or 0)
                    }
                    records.append(record)
        return records

    def generate_report(self, output_file: Optional[str] = None) -> None:
        """Write a plain-text validation report.

        Args:
            output_file: Output path; defaults to a timestamped file next
                to this script.
        """
        if not output_file:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            output_file = os.path.join(self.script_dir, f'validation_report_{timestamp}.txt')
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write("数据验证报告\n")
                f.write(f"{'='*70}\n")
                f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"目标日期: {self.date_str}\n\n")
                # Order-validation section.
                f.write("\n顺序验证结果\n")
                f.write(f"{'-'*70}\n")
                for result in self.validation_results['顺序验证']:
                    f.write(f"{result['source1']} vs {result['source2']}\n")
                    # BUG FIX: both branches used to be empty strings, so the
                    # verdict line was always blank; mirror the console marks.
                    f.write(f" 顺序匹配: {'✓' if result['order_match'] else '✗'}\n")
                    f.write(f" {result['source1']} 记录数: {result['source1_count']}\n")
                    f.write(f" {result['source2']} 记录数: {result['source2_count']}\n")
                    if result['mismatches']:
                        f.write(f" 不匹配数: {len(result['mismatches'])}\n")
                    f.write("\n")
                # Cross-validation section.
                f.write("\n交叉验证结果\n")
                f.write(f"{'-'*70}\n")
                for result in self.validation_results['交叉验证']:
                    f.write(f"{result['source1']} vs {result['source2']}\n")
                    f.write(f" 共同记录: {result['common_count']}\n")
                    f.write(f" 仅在{result['source1']}: {len(result['only_in_source1'])}\n")
                    f.write(f" 仅在{result['source2']}: {len(result['only_in_source2'])}\n")
                    f.write(f" 字段不匹配: {len(result['field_mismatches'])}\n")
                    f.write("\n")
            print(f"\n[OK] 验证报告已生成: {output_file}")
        except Exception as e:
            print(f"[X] 生成报告失败: {e}")
def main():
    """Command-line entry point: parse args, run validation, emit a report."""
    epilog_text = """
示例用法:
# 验证JSON和CSV
python data_validation.py --source json csv --date 2025-12-29
# 验证CSV和数据库
python data_validation.py --source csv database --date 2025-12-29
# 完整验证(三个数据源)
python data_validation.py --source json csv database --date 2025-12-29
# 验证特定表
python data_validation.py --source csv database --table ai_statistics_day --date 2025-12-29
"""
    parser = argparse.ArgumentParser(
        description='数据比对验证脚本',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog_text,
    )
    parser.add_argument('--source', nargs='+',
                        choices=['json', 'csv', 'database'],
                        default=['json', 'csv', 'database'],
                        help='数据源列表至少2个')
    yesterday = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
    parser.add_argument('--date', type=str, default=yesterday,
                        help='目标日期 (YYYY-MM-DD),默认为昨天')
    parser.add_argument('--table', type=str,
                        choices=['ai_statistics', 'ai_statistics_day', 'ai_statistics_days'],
                        default='ai_statistics',
                        help='要验证的表名')
    parser.add_argument('--report', type=str, help='输出报告文件路径')
    args = parser.parse_args()

    # A comparison needs at least two sides.
    if len(args.source) < 2:
        print("[X] 至少需要指定2个数据源进行比对")
        return 1

    validator = DataValidator(date_str=args.date)
    try:
        # Dispatch table validation by name; unknown tables fall through
        # to the "not implemented" branch and count as a failure.
        handlers = {
            'ai_statistics': validator.validate_ai_statistics,
            'ai_statistics_day': validator.validate_ai_statistics_day,
        }
        handler = handlers.get(args.table)
        if handler is not None:
            passed = handler(args.source)
        else:
            print(f"[!] 表 {args.table} 的验证功能暂未实现")
            passed = False

        # The report is written regardless of the outcome.
        validator.generate_report(args.report)

        print(f"\n{'='*70}")
        if passed:
            print(f"[✓] 验证通过:所有数据源数据一致")
        else:
            print(f"[X] 验证失败:发现数据差异")
        print(f"{'='*70}")
        return 0 if passed else 1
    except Exception as e:
        print(f"\n[X] 验证过程出错: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == '__main__':
    sys.exit(main())