# -*- coding: utf-8 -*-
"""
图片去重审核脚本 - 重新计算版
专门处理 status='draft' AND similarity='recalc' 的数据
"""

import configparser
import logging
import time
from typing import Optional, Tuple, List, Dict

import dashscope
from dashscope import MultiModalEmbedding

import pymysql
from dashvector import Client, Doc

class ImageSimilarityRecalc:
    """Image similarity recalculator.

    Re-runs the de-duplication pipeline for rows flagged
    ``status='draft' AND similarity='recalc'``: fetches a multimodal
    embedding for each image via DashScope, searches DashVector for a
    near-duplicate, and updates the MySQL row accordingly
    (duplicate / unique / failed).
    """

    def __init__(self, config_path: str = 'config.ini'):
        """Load configuration and prepare (but do not open) connections.

        :param config_path: path to an INI file providing the
            [dashscope], [dashvector], [database], [image],
            [similarity] and [process] sections read below.
        """
        self.config = configparser.ConfigParser()
        self.config.read(config_path, encoding='utf-8')

        self._setup_logging()

        # Connections are opened lazily by connect_db()/connect_dashvector().
        self.db_conn = None
        self.dashvector_client = None
        self.collection = None

        # DashScope SDK uses a module-global API key.
        self.dashscope_api_key = self.config.get('dashscope', 'api_key')
        dashscope.api_key = self.dashscope_api_key

        # Processing parameters.
        self.image_cdn_base = self.config.get('image', 'cdn_base')
        self.vector_threshold = self.config.getfloat('similarity', 'vector_threshold')
        self.batch_size = self.config.getint('process', 'batch_size')

    def _setup_logging(self):
        """Configure the shared 'recalc' logger with file + console handlers.

        Handlers are added only once (the logger is a process-wide
        singleton), but the level is applied on every call so a changed
        config takes effect even when the logger already exists —
        previously the level was frozen inside the no-handlers guard.
        """
        log_level = self.config.get('process', 'log_level', fallback='INFO')
        log_file = self.config.get('process', 'log_file', fallback='image_similarity.log')

        self.logger = logging.getLogger('recalc')
        self.logger.setLevel(getattr(logging, log_level))

        if not self.logger.handlers:
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

            fh = logging.FileHandler(log_file, encoding='utf-8')
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

            sh = logging.StreamHandler()
            sh.setFormatter(formatter)
            self.logger.addHandler(sh)

    def connect_db(self):
        """Open the MySQL connection (rows returned as dicts)."""
        self.db_conn = pymysql.connect(
            host=self.config.get('database', 'host'),
            port=self.config.getint('database', 'port'),
            user=self.config.get('database', 'user'),
            password=self.config.get('database', 'password'),
            database=self.config.get('database', 'database'),
            charset=self.config.get('database', 'charset'),
            cursorclass=pymysql.cursors.DictCursor
        )
        self.logger.info("数据库连接成功")

    def connect_dashvector(self):
        """Connect to DashVector and bind the configured collection."""
        api_key = self.config.get('dashvector', 'api_key')
        endpoint = self.config.get('dashvector', 'endpoint')
        collection_name = self.config.get('dashvector', 'collection_name')

        self.dashvector_client = Client(api_key=api_key, endpoint=endpoint)
        self.collection = self.dashvector_client.get(collection_name)
        self.logger.info("DashVector 连接成功")

    def get_image_embedding(self, image_url: str, max_retries: int = 5) -> Optional[List[float]]:
        """Fetch the image embedding from the DashScope multimodal SDK.

        Retries on throttling (HTTP 429/403) with a linear backoff and on
        transport exceptions; any other API error aborts immediately.

        :param image_url: full, publicly reachable image URL.
        :param max_retries: total number of attempts.
        :return: the embedding vector, or None when all attempts failed.
        """
        for attempt in range(max_retries):
            try:
                resp = MultiModalEmbedding.call(
                    model='multimodal-embedding-v1',
                    input=[{'image': image_url}]
                )
            except Exception as e:
                self.logger.warning(f"Embedding API 异常: {e}")
                # Fix: don't sleep after the final attempt — it only
                # delayed the failure result by 2 seconds.
                if attempt < max_retries - 1:
                    time.sleep(2)
                continue

            if resp.status_code == 200:
                return resp.output['embeddings'][0]['embedding']
            elif resp.status_code in (429, 403):
                # Throttled: linear backoff (3s, 6s, 9s, ...).
                wait_time = 3 + attempt * 3
                self.logger.warning(f"API 限流,等待 {wait_time} 秒后重试 ({attempt + 1}/{max_retries})...")
                time.sleep(wait_time)
            else:
                # Non-retryable API error.
                self.logger.warning(f"Embedding API 错误: {resp.status_code} - {resp.message}")
                return None

        return None

    def get_recalc_images(self) -> List[dict]:
        """Fetch the next batch of rows flagged for recalculation.

        Selects ``status='draft' AND similarity='recalc'`` rows with a
        non-empty image_url, oldest first, limited to ``batch_size``.
        """
        with self.db_conn.cursor() as cursor:
            sql = """
                SELECT id, image_id, image_url, image_thumb_url, image_name
                FROM ai_image_tags
                WHERE status = 'draft' AND similarity = 'recalc'
                AND image_url != '' AND image_url IS NOT NULL
                ORDER BY id ASC
                LIMIT %s
            """
            cursor.execute(sql, (self.batch_size,))
            return cursor.fetchall()

    def search_similar(self, features: List[float], exclude_id: int) -> Tuple[bool, Optional[int], Optional[float]]:
        """Search DashVector for an image similar to the given vector.

        ``doc.score`` is treated as a distance, so similarity is
        ``1 - score``; the row's own id is skipped.

        :param features: embedding vector to search with.
        :param exclude_id: id of the image being processed (self-match).
        :return: (is_duplicate, matched_id, similarity) — the last two
            are None when no match reaches ``vector_threshold``.
        """
        try:
            results = self.collection.query(features, topk=3)

            if results and results.output:
                for doc in results.output:
                    similar_id = int(doc.id)
                    if similar_id == exclude_id:
                        continue

                    similarity = 1.0 - doc.score
                    self.logger.info(f"搜索到: {similar_id}, 距离={doc.score:.4f}, 相似度={similarity:.4f}")

                    if similarity >= self.vector_threshold:
                        return True, similar_id, similarity

            return False, None, None
        except Exception as e:
            # Best-effort: a search failure is treated as "no duplicate".
            self.logger.warning(f"搜索失败: {e}")
            return False, None, None

    def upsert_to_dashvector(self, image_id: int, features: List[float]):
        """Store the vector in DashVector keyed by the image row id."""
        try:
            doc = Doc(id=str(image_id), vector=features)
            result = self.collection.upsert([doc])
            if result.code == 0:
                self.logger.info(f"向量入库成功: {image_id}")
            else:
                self.logger.warning(f"向量入库失败 ID={image_id}: code={result.code}, msg={result.message}")
        except Exception as e:
            self.logger.warning(f"存入 DashVector 异常 ID={image_id}: {e}")

    def update_as_duplicate(self, image_id: int, similar_id: int, score: float):
        """Mark the row as a duplicate of ``similar_id`` with its score."""
        with self.db_conn.cursor() as cursor:
            sql = """
                UPDATE ai_image_tags
                SET status = 'similarity',
                    similarity = 'yes',
                    similarity_image_tags_id = %s,
                    similarity_score = %s,
                    updated_at = NOW()
                WHERE id = %s
            """
            cursor.execute(sql, (similar_id, score, image_id))
        self.db_conn.commit()
        self.logger.info(f"重复: {image_id} -> {similar_id} (分数={score:.4f})")

    def update_as_unique(self, image_id: int):
        """Mark the row as unique and advance it to tag extension."""
        with self.db_conn.cursor() as cursor:
            sql = """
                UPDATE ai_image_tags
                SET status = 'tag_extension',
                    similarity = 'calc',
                    updated_at = NOW()
                WHERE id = %s
            """
            cursor.execute(sql, (image_id,))
        self.db_conn.commit()
        self.logger.info(f"不重复: {image_id} -> tag_extension")

    def update_as_failed(self, image_id: int, reason: str):
        """Record a failed attempt (row keeps its 'recalc' flag).

        Only ``updated_at`` is touched, so the row will be retried on a
        later pass; the reason is logged, not persisted.
        """
        with self.db_conn.cursor() as cursor:
            sql = """
                UPDATE ai_image_tags
                SET updated_at = NOW()
                WHERE id = %s
            """
            cursor.execute(sql, (image_id,))
        self.db_conn.commit()
        self.logger.warning(f"处理失败 {image_id}: {reason}")

    def process_batch(self, image_records: List[dict]) -> Tuple[int, int, int]:
        """Process one batch of rows.

        For each row: build the CDN URL, fetch the embedding, search for
        a duplicate, then persist the outcome. Per-row failures are
        logged and counted without aborting the batch.

        :param image_records: rows from :meth:`get_recalc_images`.
        :return: ``(duplicates, unique, failed)`` counts.
        """
        if not image_records:
            return 0, 0, 0

        duplicates = 0
        unique = 0
        failed = 0

        for rec in image_records:
            image_id = rec['id']

            # Defensive re-check; the SQL already filters empty URLs.
            if not rec['image_url'] or rec['image_url'].strip() == '':
                self.logger.warning(f"图像URL为空,跳过处理: {image_id}")
                self.update_as_failed(image_id, "图像URL为空")
                failed += 1
                continue

            full_url = f"{self.image_cdn_base}{rec['image_url']}"

            try:
                # Pace embedding calls to stay under the API rate limit.
                time.sleep(0.5)
                self.logger.info(f"重新计算 Embedding: {image_id} -> {full_url}")

                features = self.get_image_embedding(image_url=full_url)

                if features is None:
                    self.logger.warning(f"Embedding 获取失败: {image_id}")
                    self.update_as_failed(image_id, "Embedding API 失败")
                    failed += 1
                    continue

                is_dup, similar_id, score = self.search_similar(features, image_id)

                if is_dup:
                    self.update_as_duplicate(image_id, similar_id, score)
                    duplicates += 1
                else:
                    # Unique images become future search candidates.
                    self.upsert_to_dashvector(image_id, features)
                    self.update_as_unique(image_id)
                    unique += 1

            except Exception as e:
                self.logger.error(f"处理失败 {image_id}: {e}")
                self.update_as_failed(image_id, str(e)[:200])
                failed += 1
                continue

        return duplicates, unique, failed

    def run(self):
        """Main loop: poll for 'recalc' rows and process them forever.

        Sleeps 10s when the queue is empty and 1s between batches.
        Ctrl-C now shuts down cleanly instead of dumping a traceback;
        the DB connection is always closed and totals always logged.
        """
        self.logger.info("=" * 60)
        self.logger.info("图片去重审核 - 重新计算版 (recalc)")
        self.logger.info("=" * 60)

        self.connect_db()
        self.connect_dashvector()

        total_duplicates = 0
        total_unique = 0
        total_failed = 0
        batch_num = 0

        try:
            while True:
                images = self.get_recalc_images()

                if not images:
                    self.logger.info("没有需要重新计算的图片,等待 10 秒后继续检查...")
                    time.sleep(10)
                    continue

                batch_num += 1
                self.logger.info(f"\n--- 批次 {batch_num}: {len(images)} 张 (recalc) ---")

                dup, uniq, fail = self.process_batch(images)
                total_duplicates += dup
                total_unique += uniq
                total_failed += fail

                self.logger.info(f"批次结果: 重复={dup}, 不重复={uniq}, 失败={fail}")

                # Short pause between batches to avoid hammering the DB.
                time.sleep(1)

        except KeyboardInterrupt:
            # Fix: graceful shutdown on Ctrl-C (previously an unhandled
            # traceback); cleanup below still runs.
            self.logger.info("Interrupted by user, shutting down")

        finally:
            if self.db_conn:
                self.db_conn.close()

            self.logger.info("=" * 60)
            self.logger.info(f"完成! 总重复: {total_duplicates}, 总不重复: {total_unique}, 总失败: {total_failed}")
            self.logger.info("=" * 60)
if __name__ == '__main__':
    # Script entry point: build the recalculator from the default
    # config file and start the polling loop.
    worker = ImageSimilarityRecalc('config.ini')
    worker.run()