feat: 新增重算脚本和统计脚本,更新README

This commit is contained in:
2026-02-05 19:01:38 +08:00
parent d373a073e4
commit 5a6fbcbf28
7 changed files with 469 additions and 20 deletions

293
image_similarity_recalc.py Normal file
View File

@@ -0,0 +1,293 @@
# -*- coding: utf-8 -*-
"""
图片去重审核脚本 - 重新计算版
专门处理 status='draft' AND similarity='recalc' 的数据
"""
import configparser
import logging
import time
import dashscope
from dashscope import MultiModalEmbedding
from typing import Optional, Tuple, List, Dict
import pymysql
from dashvector import Client, Doc
class ImageSimilarityRecalc:
"""图片相似度重新计算器"""
def __init__(self, config_path: str = 'config.ini'):
self.config = configparser.ConfigParser()
self.config.read(config_path, encoding='utf-8')
self._setup_logging()
# 连接
self.db_conn = None
self.dashvector_client = None
self.collection = None
# DashScope API
self.dashscope_api_key = self.config.get('dashscope', 'api_key')
dashscope.api_key = self.dashscope_api_key
# 配置参数
self.image_cdn_base = self.config.get('image', 'cdn_base')
self.vector_threshold = self.config.getfloat('similarity', 'vector_threshold')
self.batch_size = self.config.getint('process', 'batch_size')
def _setup_logging(self):
log_level = self.config.get('process', 'log_level', fallback='INFO')
log_file = self.config.get('process', 'log_file', fallback='image_similarity.log')
self.logger = logging.getLogger('recalc')
if not self.logger.handlers:
self.logger.setLevel(getattr(logging, log_level))
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_file, encoding='utf-8')
fh.setFormatter(formatter)
self.logger.addHandler(fh)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
self.logger.addHandler(sh)
def connect_db(self):
"""连接数据库"""
self.db_conn = pymysql.connect(
host=self.config.get('database', 'host'),
port=self.config.getint('database', 'port'),
user=self.config.get('database', 'user'),
password=self.config.get('database', 'password'),
database=self.config.get('database', 'database'),
charset=self.config.get('database', 'charset'),
cursorclass=pymysql.cursors.DictCursor
)
self.logger.info("数据库连接成功")
def connect_dashvector(self):
"""连接 DashVector"""
api_key = self.config.get('dashvector', 'api_key')
endpoint = self.config.get('dashvector', 'endpoint')
collection_name = self.config.get('dashvector', 'collection_name')
self.dashvector_client = Client(api_key=api_key, endpoint=endpoint)
self.collection = self.dashvector_client.get(collection_name)
self.logger.info("DashVector 连接成功")
def get_image_embedding(self, image_url: str, max_retries: int = 5) -> Optional[List[float]]:
"""调用 DashScope 多模态 Embedding SDK 获取图片向量"""
for attempt in range(max_retries):
try:
input_data = [{'image': image_url}]
resp = MultiModalEmbedding.call(
model='multimodal-embedding-v1',
input=input_data
)
if resp.status_code == 200:
return resp.output['embeddings'][0]['embedding']
elif resp.status_code in (429, 403):
wait_time = 3 + attempt * 3
self.logger.warning(f"API 限流,等待 {wait_time} 秒后重试 ({attempt + 1}/{max_retries})...")
time.sleep(wait_time)
else:
self.logger.warning(f"Embedding API 错误: {resp.status_code} - {resp.message}")
return None
except Exception as e:
self.logger.warning(f"Embedding API 异常: {e}")
time.sleep(2)
return None
def get_recalc_images(self) -> List[dict]:
"""获取需要重新计算的图片 (status='draft' AND similarity='recalc')"""
with self.db_conn.cursor() as cursor:
sql = """
SELECT id, image_id, image_url, image_thumb_url, image_name
FROM ai_image_tags
WHERE status = 'draft' AND similarity = 'recalc'
AND image_url != '' AND image_url IS NOT NULL
ORDER BY id ASC
LIMIT %s
"""
cursor.execute(sql, (self.batch_size,))
return cursor.fetchall()
def search_similar(self, features: List[float], exclude_id: int) -> Tuple[bool, Optional[int], Optional[float]]:
"""在 DashVector 中搜索相似图片"""
try:
results = self.collection.query(features, topk=3)
if results and results.output:
for doc in results.output:
similar_id = int(doc.id)
if similar_id == exclude_id:
continue
similarity = 1.0 - doc.score
self.logger.info(f"搜索到: {similar_id}, 距离={doc.score:.4f}, 相似度={similarity:.4f}")
if similarity >= self.vector_threshold:
return True, similar_id, similarity
return False, None, None
except Exception as e:
self.logger.warning(f"搜索失败: {e}")
return False, None, None
def upsert_to_dashvector(self, image_id: int, features: List[float]):
"""存入 DashVector"""
try:
doc = Doc(id=str(image_id), vector=features)
result = self.collection.upsert([doc])
if result.code == 0:
self.logger.info(f"向量入库成功: {image_id}")
else:
self.logger.warning(f"向量入库失败 ID={image_id}: code={result.code}, msg={result.message}")
except Exception as e:
self.logger.warning(f"存入 DashVector 异常 ID={image_id}: {e}")
def update_as_duplicate(self, image_id: int, similar_id: int, score: float):
"""更新为重复图片"""
with self.db_conn.cursor() as cursor:
sql = """
UPDATE ai_image_tags
SET status = 'similarity',
similarity = 'yes',
similarity_image_tags_id = %s,
similarity_score = %s,
updated_at = NOW()
WHERE id = %s
"""
cursor.execute(sql, (similar_id, score, image_id))
self.db_conn.commit()
self.logger.info(f"重复: {image_id} -> {similar_id} (分数={score:.4f})")
def update_as_unique(self, image_id: int):
"""更新为不重复图片"""
with self.db_conn.cursor() as cursor:
sql = """
UPDATE ai_image_tags
SET status = 'tag_extension',
similarity = 'calc',
updated_at = NOW()
WHERE id = %s
"""
cursor.execute(sql, (image_id,))
self.db_conn.commit()
self.logger.info(f"不重复: {image_id} -> tag_extension")
def update_as_failed(self, image_id: int, reason: str):
"""标记为处理失败(保持 recalc 状态)"""
with self.db_conn.cursor() as cursor:
sql = """
UPDATE ai_image_tags
SET updated_at = NOW()
WHERE id = %s
"""
cursor.execute(sql, (image_id,))
self.db_conn.commit()
self.logger.warning(f"处理失败 {image_id}: {reason}")
def process_batch(self, image_records: List[dict]) -> Tuple[int, int, int]:
"""处理一批图片,返回 (重复数, 不重复数, 失败数)"""
if not image_records:
return 0, 0, 0
duplicates = 0
unique = 0
failed = 0
for rec in image_records:
image_id = rec['id']
if not rec['image_url'] or rec['image_url'].strip() == '':
self.logger.warning(f"图像URL为空跳过处理: {image_id}")
self.update_as_failed(image_id, "图像URL为空")
failed += 1
continue
full_url = f"{self.image_cdn_base}{rec['image_url']}"
try:
time.sleep(0.5)
self.logger.info(f"重新计算 Embedding: {image_id} -> {full_url}")
features = self.get_image_embedding(image_url=full_url)
if features is None:
self.logger.warning(f"Embedding 获取失败: {image_id}")
self.update_as_failed(image_id, "Embedding API 失败")
failed += 1
continue
is_dup, similar_id, score = self.search_similar(features, image_id)
if is_dup:
self.update_as_duplicate(image_id, similar_id, score)
duplicates += 1
else:
self.upsert_to_dashvector(image_id, features)
self.update_as_unique(image_id)
unique += 1
except Exception as e:
self.logger.error(f"处理失败 {image_id}: {e}")
self.update_as_failed(image_id, str(e)[:200])
failed += 1
continue
return duplicates, unique, failed
def run(self):
"""运行主流程"""
self.logger.info("=" * 60)
self.logger.info("图片去重审核 - 重新计算版 (recalc)")
self.logger.info("=" * 60)
self.connect_db()
self.connect_dashvector()
total_duplicates = 0
total_unique = 0
total_failed = 0
batch_num = 0
try:
while True:
images = self.get_recalc_images()
if not images:
self.logger.info("没有需要重新计算的图片")
break
batch_num += 1
self.logger.info(f"\n--- 批次 {batch_num}: {len(images)} 张 (recalc) ---")
dup, uniq, fail = self.process_batch(images)
total_duplicates += dup
total_unique += uniq
total_failed += fail
self.logger.info(f"批次结果: 重复={dup}, 不重复={uniq}, 失败={fail}")
# 批次间休息,避免数据库连接问题
time.sleep(1)
finally:
if self.db_conn:
self.db_conn.close()
self.logger.info("=" * 60)
self.logger.info(f"完成! 总重复: {total_duplicates}, 总不重复: {total_unique}, 总失败: {total_failed}")
self.logger.info("=" * 60)
if __name__ == '__main__':
recalc = ImageSimilarityRecalc('config.ini')
recalc.run()