
# -*- coding: utf-8 -*-
"""
图片去重审核脚本 - 重新计算版
专门处理 status='draft' AND similarity='recalc' 的数据
"""
import configparser
import logging
import time
import dashscope
from dashscope import MultiModalEmbedding
from typing import Optional, Tuple, List, Dict
import pymysql
from dashvector import Client, Doc
class ImageSimilarityRecalc:
    """Image de-duplication reviewer — re-calculation pass.

    Processes rows in ``ai_image_tags`` with ``status='draft' AND
    similarity='recalc'``: recomputes the image embedding via the DashScope
    multimodal SDK, searches DashVector for near-duplicates, and marks each
    row as duplicate, unique, or failed.
    """

    def __init__(self, config_path: str = 'config.ini'):
        """Load configuration from *config_path* and prepare connections.

        Connections are opened lazily via :meth:`connect_db` and
        :meth:`connect_dashvector`.
        """
        self.config = configparser.ConfigParser()
        self.config.read(config_path, encoding='utf-8')
        self._setup_logging()
        # Connections are opened later by connect_db() / connect_dashvector().
        self.db_conn = None
        self.dashvector_client = None
        self.collection = None
        # The DashScope SDK reads its API key from a module-level global.
        self.dashscope_api_key = self.config.get('dashscope', 'api_key')
        dashscope.api_key = self.dashscope_api_key
        # Processing parameters.
        self.image_cdn_base = self.config.get('image', 'cdn_base')
        self.vector_threshold = self.config.getfloat('similarity', 'vector_threshold')
        self.batch_size = self.config.getint('process', 'batch_size')

    def _setup_logging(self):
        """Configure a dedicated 'recalc' logger with file + console handlers.

        Handlers are added only once so repeated instantiation does not
        duplicate log lines.
        """
        log_level = self.config.get('process', 'log_level', fallback='INFO')
        log_file = self.config.get('process', 'log_file', fallback='image_similarity.log')
        self.logger = logging.getLogger('recalc')
        if not self.logger.handlers:
            self.logger.setLevel(getattr(logging, log_level))
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            fh = logging.FileHandler(log_file, encoding='utf-8')
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)
            sh = logging.StreamHandler()
            sh.setFormatter(formatter)
            self.logger.addHandler(sh)

    def connect_db(self):
        """Open the MySQL connection using the [database] config section."""
        self.db_conn = pymysql.connect(
            host=self.config.get('database', 'host'),
            port=self.config.getint('database', 'port'),
            user=self.config.get('database', 'user'),
            password=self.config.get('database', 'password'),
            database=self.config.get('database', 'database'),
            charset=self.config.get('database', 'charset'),
            cursorclass=pymysql.cursors.DictCursor
        )
        self.logger.info("数据库连接成功")

    def connect_dashvector(self):
        """Open the DashVector client and resolve the target collection."""
        api_key = self.config.get('dashvector', 'api_key')
        endpoint = self.config.get('dashvector', 'endpoint')
        collection_name = self.config.get('dashvector', 'collection_name')
        self.dashvector_client = Client(api_key=api_key, endpoint=endpoint)
        self.collection = self.dashvector_client.get(collection_name)
        self.logger.info("DashVector 连接成功")

    def get_image_embedding(self, image_url: str, max_retries: int = 5) -> Optional[List[float]]:
        """Fetch the image embedding from the DashScope multimodal API.

        Retries with linear backoff on throttling (429/403) and on SDK
        exceptions; returns ``None`` after *max_retries* attempts or on a
        non-retryable API error.
        """
        for attempt in range(max_retries):
            try:
                input_data = [{'image': image_url}]
                resp = MultiModalEmbedding.call(
                    model='multimodal-embedding-v1',
                    input=input_data
                )
                if resp.status_code == 200:
                    return resp.output['embeddings'][0]['embedding']
                elif resp.status_code in (429, 403):
                    # Throttled: back off harder on each successive attempt.
                    wait_time = 3 + attempt * 3
                    self.logger.warning(f"API 限流,等待 {wait_time} 秒后重试 ({attempt + 1}/{max_retries})...")
                    time.sleep(wait_time)
                else:
                    # Non-retryable API error — give up immediately.
                    self.logger.warning(f"Embedding API 错误: {resp.status_code} - {resp.message}")
                    return None
            except Exception as e:
                self.logger.warning(f"Embedding API 异常: {e}")
                time.sleep(2)
        return None

    def get_recalc_images(self) -> List[dict]:
        """Fetch up to ``batch_size`` rows pending re-calculation.

        Selects rows with ``status='draft' AND similarity='recalc'`` that
        have a non-empty image URL, ordered by id.
        """
        with self.db_conn.cursor() as cursor:
            sql = """
                SELECT id, image_id, image_url, image_thumb_url, image_name
                FROM ai_image_tags
                WHERE status = 'draft' AND similarity = 'recalc'
                AND image_url != '' AND image_url IS NOT NULL
                ORDER BY id ASC
                LIMIT %s
            """
            cursor.execute(sql, (self.batch_size,))
            return cursor.fetchall()

    def search_similar(self, features: List[float], exclude_id: int) -> Tuple[bool, Optional[int], Optional[float]]:
        """Search DashVector for a near-duplicate of *features*.

        Skips the record's own vector (*exclude_id*). Similarity is computed
        as ``1 - distance`` — assumes the collection uses a distance metric
        where 0 means identical (TODO confirm against collection config).
        Returns ``(is_duplicate, similar_id, similarity)``.
        """
        try:
            results = self.collection.query(features, topk=3)
            if results and results.output:
                for doc in results.output:
                    similar_id = int(doc.id)
                    if similar_id == exclude_id:
                        continue
                    similarity = 1.0 - doc.score
                    self.logger.info(f"搜索到: {similar_id}, 距离={doc.score:.4f}, 相似度={similarity:.4f}")
                    if similarity >= self.vector_threshold:
                        return True, similar_id, similarity
            # BUG FIX: this return must also cover the "no/empty results"
            # path; previously an empty result set fell off the end of the
            # try block and the method implicitly returned None, crashing
            # the caller's 3-tuple unpack.
            return False, None, None
        except Exception as e:
            self.logger.warning(f"搜索失败: {e}")
            return False, None, None

    def upsert_to_dashvector(self, image_id: int, features: List[float]):
        """Upsert the vector into DashVector, keyed by the row id as string."""
        try:
            doc = Doc(id=str(image_id), vector=features)
            result = self.collection.upsert([doc])
            if result.code == 0:
                self.logger.info(f"向量入库成功: {image_id}")
            else:
                self.logger.warning(f"向量入库失败 ID={image_id}: code={result.code}, msg={result.message}")
        except Exception as e:
            self.logger.warning(f"存入 DashVector 异常 ID={image_id}: {e}")

    def update_as_duplicate(self, image_id: int, similar_id: int, score: float):
        """Mark the row as a duplicate of *similar_id* with the given score."""
        with self.db_conn.cursor() as cursor:
            sql = """
                UPDATE ai_image_tags
                SET status = 'similarity',
                    similarity = 'yes',
                    similarity_image_tags_id = %s,
                    similarity_score = %s,
                    updated_at = NOW()
                WHERE id = %s
            """
            cursor.execute(sql, (similar_id, score, image_id))
        self.db_conn.commit()
        self.logger.info(f"重复: {image_id} -> {similar_id} (分数={score:.4f})")

    def update_as_unique(self, image_id: int):
        """Mark the row as unique and advance it to the tag-extension stage."""
        with self.db_conn.cursor() as cursor:
            sql = """
                UPDATE ai_image_tags
                SET status = 'tag_extension',
                    similarity = 'calc',
                    updated_at = NOW()
                WHERE id = %s
            """
            cursor.execute(sql, (image_id,))
        self.db_conn.commit()
        self.logger.info(f"不重复: {image_id} -> tag_extension")

    def update_as_failed(self, image_id: int, reason: str):
        """Record a processing failure, leaving the row in 'recalc' state.

        Only ``updated_at`` is touched, so the row stays eligible for a
        future re-calculation run. NOTE: because of this, run() must detect
        zero-progress batches to avoid retrying the same rows forever.
        """
        with self.db_conn.cursor() as cursor:
            sql = """
                UPDATE ai_image_tags
                SET updated_at = NOW()
                WHERE id = %s
            """
            cursor.execute(sql, (image_id,))
        self.db_conn.commit()
        self.logger.warning(f"处理失败 {image_id}: {reason}")

    def process_batch(self, image_records: List[dict]) -> Tuple[int, int, int]:
        """Process one batch of rows; return (duplicates, unique, failed)."""
        if not image_records:
            return 0, 0, 0
        duplicates = 0
        unique = 0
        failed = 0
        for rec in image_records:
            image_id = rec['id']
            if not rec['image_url'] or rec['image_url'].strip() == '':
                self.logger.warning(f"图像URL为空跳过处理: {image_id}")
                self.update_as_failed(image_id, "图像URL为空")
                failed += 1
                continue
            full_url = f"{self.image_cdn_base}{rec['image_url']}"
            try:
                # Small pause to stay under the embedding API rate limit.
                time.sleep(0.5)
                self.logger.info(f"重新计算 Embedding: {image_id} -> {full_url}")
                features = self.get_image_embedding(image_url=full_url)
                if features is None:
                    self.logger.warning(f"Embedding 获取失败: {image_id}")
                    self.update_as_failed(image_id, "Embedding API 失败")
                    failed += 1
                    continue
                is_dup, similar_id, score = self.search_similar(features, image_id)
                if is_dup:
                    self.update_as_duplicate(image_id, similar_id, score)
                    duplicates += 1
                else:
                    # Unique: index its vector, then advance its status.
                    self.upsert_to_dashvector(image_id, features)
                    self.update_as_unique(image_id)
                    unique += 1
            except Exception as e:
                self.logger.error(f"处理失败 {image_id}: {e}")
                self.update_as_failed(image_id, str(e)[:200])
                failed += 1
                continue
        return duplicates, unique, failed

    def run(self):
        """Main loop: fetch and process batches until none remain.

        Stops when the queue is empty, or when a batch makes zero progress
        (all rows failed) — failed rows keep ``similarity='recalc'`` and
        would otherwise be re-selected forever.
        """
        self.logger.info("=" * 60)
        self.logger.info("图片去重审核 - 重新计算版 (recalc)")
        self.logger.info("=" * 60)
        self.connect_db()
        self.connect_dashvector()
        total_duplicates = 0
        total_unique = 0
        total_failed = 0
        batch_num = 0
        try:
            while True:
                images = self.get_recalc_images()
                if not images:
                    self.logger.info("没有需要重新计算的图片")
                    break
                batch_num += 1
                self.logger.info(f"\n--- 批次 {batch_num}: {len(images)} 张 (recalc) ---")
                dup, uniq, fail = self.process_batch(images)
                total_duplicates += dup
                total_unique += uniq
                total_failed += fail
                self.logger.info(f"批次结果: 重复={dup}, 不重复={uniq}, 失败={fail}")
                # BUG FIX: failed rows remain in 'recalc' state, so a batch
                # in which every row fails would be re-selected on the next
                # iteration, looping forever. Stop on zero progress.
                if dup == 0 and uniq == 0:
                    self.logger.warning("本批次无任何进展(全部失败),停止处理以避免死循环")
                    break
                # Brief pause between batches to avoid DB connection issues.
                time.sleep(1)
        finally:
            if self.db_conn:
                self.db_conn.close()
            self.logger.info("=" * 60)
            self.logger.info(f"完成! 总重复: {total_duplicates}, 总不重复: {total_unique}, 总失败: {total_failed}")
            self.logger.info("=" * 60)
if __name__ == '__main__':
    # Script entry point: build the recalculator from the local config
    # file and run the full processing loop.
    job = ImageSimilarityRecalc('config.ini')
    job.run()