2026-02-04 14:36:13 +08:00
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
"""
|
|
|
|
|
|
图片去重审核脚本 - DashScope 多模态版
|
|
|
|
|
|
采用: pHash预筛 + DashScope多模态Embedding + 异步批量处理
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import configparser
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import aiohttp
|
|
|
|
|
|
import imagehash
|
|
|
|
|
|
import base64
|
|
|
|
|
|
import time
|
|
|
|
|
|
import dashscope
|
|
|
|
|
|
from dashscope import MultiModalEmbedding
|
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
|
from typing import Optional, Tuple, List, Dict
|
|
|
|
|
|
|
|
|
|
|
|
import pymysql
|
|
|
|
|
|
from dashvector import Client, Doc
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageSimilarityChecker:
    """Image de-duplication checker backed by DashScope multimodal embeddings.

    Workflow (see ``run``): poll ``ai_image_tags`` for rows with
    ``status='draft' AND similarity='draft'``, embed each image with the
    DashScope ``multimodal-embedding-v1`` model, query a DashVector collection
    for a close neighbour, and write the verdict (duplicate / unique / failed)
    back to MySQL. Unique images have their vectors upserted so later images
    are compared against them.

    The pHash helpers and the async download helpers are kept available but
    are not used by ``process_batch``; ``phash_cache`` is never populated here.
    """

    def __init__(self, config_path: str = 'config.ini'):
        """Load configuration and set up logging; connections are opened in ``run``.

        Args:
            config_path: Path to the INI file with the ``dashscope``,
                ``dashvector``, ``database``, ``image``, ``similarity`` and
                ``process`` sections.

        Raises:
            FileNotFoundError: if ``config_path`` could not be read.
            configparser.Error: if a required option is missing.
        """
        self.config = configparser.ConfigParser()
        # ConfigParser.read silently skips unreadable files and returns the
        # list of files it parsed; fail fast instead of surfacing a confusing
        # NoSectionError on the first lookup below.
        if not self.config.read(config_path, encoding='utf-8'):
            raise FileNotFoundError(f"配置文件不存在或无法读取: {config_path}")

        self._setup_logging()

        # Connections (opened by connect_db / connect_dashvector)
        self.db_conn = None
        self.dashvector_client = None
        self.collection = None

        # DashScope API key, set on the SDK module for MultiModalEmbedding.call
        self.dashscope_api_key = self.config.get('dashscope', 'api_key')
        dashscope.api_key = self.dashscope_api_key

        # pHash cache {phash_str: image_tag_id}
        self.phash_cache: Dict[str, int] = {}

        # Tunables
        self.image_cdn_base = self.config.get('image', 'cdn_base')
        self.phash_threshold = self.config.getint('similarity', 'phash_threshold')
        self.vector_threshold = self.config.getfloat('similarity', 'vector_threshold')
        self.batch_size = self.config.getint('process', 'batch_size')
        self.concurrent_downloads = self.config.getint('process', 'concurrent_downloads')

    def _setup_logging(self):
        """Configure ``self.logger`` with a UTF-8 file handler and a console handler.

        Reads ``[process] log_level`` (default ``INFO``) and
        ``[process] log_file`` (default ``image_similarity.log``). Handlers are
        attached only if the logger has none yet, so repeated construction in
        one process does not duplicate log lines.
        """
        log_level = self.config.get('process', 'log_level', fallback='INFO')
        log_file = self.config.get('process', 'log_file', fallback='image_similarity.log')

        self.logger = logging.getLogger(__name__)

        # Avoid adding duplicate handlers.
        if not self.logger.handlers:
            # Accept lower-case names ("info") and fall back to INFO for
            # unknown values instead of raising AttributeError.
            level = getattr(logging, log_level.strip().upper(), None)
            if not isinstance(level, int):
                level = logging.INFO
            self.logger.setLevel(level)
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

            fh = logging.FileHandler(log_file, encoding='utf-8')
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

            sh = logging.StreamHandler()
            sh.setFormatter(formatter)
            self.logger.addHandler(sh)

    def connect_db(self):
        """Open the MySQL connection described by the ``[database]`` section.

        Rows are returned as dicts (``DictCursor``).
        """
        self.db_conn = pymysql.connect(
            host=self.config.get('database', 'host'),
            port=self.config.getint('database', 'port'),
            user=self.config.get('database', 'user'),
            password=self.config.get('database', 'password'),
            database=self.config.get('database', 'database'),
            charset=self.config.get('database', 'charset'),
            cursorclass=pymysql.cursors.DictCursor,
        )
        self.logger.info("数据库连接成功")

    def connect_dashvector(self):
        """Connect to DashVector and bind ``self.collection``, creating it if absent.

        Raises:
            RuntimeError: if the collection cannot be created or opened.
        """
        api_key = self.config.get('dashvector', 'api_key')
        endpoint = self.config.get('dashvector', 'endpoint')
        collection_name = self.config.get('dashvector', 'collection_name')
        dimension = self.config.getint('dashvector', 'vector_dimension')

        self.dashvector_client = Client(api_key=api_key, endpoint=endpoint)

        # Per the DashVector SDK docs, Client.get returns a response object
        # whose truthiness reflects success; it is not None for a missing
        # collection. `not existing` handles both that and a None return.
        existing = self.dashvector_client.get(collection_name)
        if not existing:
            self.logger.info(f"创建集合 {collection_name},维度 {dimension}")
            # search_similar converts score to similarity as 1 - score, which
            # relies on the cosine metric; state it explicitly.
            created = self.dashvector_client.create(
                collection_name, dimension=dimension, metric='cosine'
            )
            if not created:
                raise RuntimeError(
                    f"创建集合 {collection_name} 失败: "
                    f"code={getattr(created, 'code', None)}, msg={getattr(created, 'message', None)}"
                )
        else:
            self.logger.info(f"集合 {collection_name} 已存在,直接复用")

        self.collection = self.dashvector_client.get(collection_name)
        if not self.collection:
            raise RuntimeError(
                f"获取集合 {collection_name} 失败: "
                f"code={getattr(self.collection, 'code', None)}, "
                f"msg={getattr(self.collection, 'message', None)}"
            )
        self.logger.info("DashVector 连接成功")

    def get_image_embedding(self, image_url: Optional[str] = None, image_base64: Optional[str] = None,
                            max_retries: int = 5) -> Optional[List[float]]:
        """Fetch an image embedding from the DashScope multimodal embedding SDK.

        Args:
            image_url: Public URL of the image (takes precedence).
            image_base64: Base64 image payload, sent as a JPEG data URI.
            max_retries: Maximum number of API attempts.

        Returns:
            The embedding vector, or None if no input was given, the API
            returned a non-retryable error, or all attempts failed.

        Status 429/403 is treated as rate limiting and retried with a linearly
        growing wait; exceptions are retried after 2 seconds.
        """
        # The request payload does not change between attempts.
        if image_url:
            input_data = [{'image': image_url}]
        elif image_base64:
            input_data = [{'image': f'data:image/jpeg;base64,{image_base64}'}]
        else:
            return None

        for attempt in range(max_retries):
            has_next = attempt + 1 < max_retries
            try:
                resp = MultiModalEmbedding.call(
                    model='multimodal-embedding-v1',
                    input=input_data,
                )

                if resp.status_code == 200:
                    return resp.output['embeddings'][0]['embedding']

                if resp.status_code in (429, 403):
                    # Don't sleep after the final attempt; nothing follows it.
                    if has_next:
                        wait_time = 3 + attempt * 3
                        self.logger.warning(f"API 限流,等待 {wait_time} 秒后重试 ({attempt + 1}/{max_retries})...")
                        time.sleep(wait_time)
                    continue

                self.logger.warning(f"Embedding API 错误: {resp.status_code} - {resp.message}")
                return None
            except Exception as e:
                self.logger.warning(f"Embedding API 异常: {e}")
                if has_next:
                    time.sleep(2)

        self.logger.warning(f"Embedding API 重试 {max_retries} 次后仍失败")
        return None

    def load_phash_cache(self):
        """Initialise the pHash cache.

        Currently a no-op apart from the log line: ``self.phash_cache`` is left
        as is.
        """
        self.logger.info("pHash 缓存初始化完成")

    def compute_phash(self, image: Image.Image) -> str:
        """Return the perceptual hash (``imagehash.phash``) of *image* as a hex string."""
        return str(imagehash.phash(image))

    def check_phash_duplicate(self, phash: str) -> Tuple[bool, Optional[int], Optional[int]]:
        """Compare *phash* against ``self.phash_cache`` by Hamming distance.

        Returns:
            ``(True, cached_id, distance)`` for the first cached hash within
            ``phash_threshold``, otherwise ``(False, None, None)``.
        """
        phash_obj = imagehash.hex_to_hash(phash)

        for cached_phash, image_id in self.phash_cache.items():
            distance = phash_obj - imagehash.hex_to_hash(cached_phash)
            if distance <= self.phash_threshold:
                return True, image_id, distance

        return False, None, None

    async def download_image_async(self, session: aiohttp.ClientSession,
                                   image_id: int, url: str) -> Tuple[int, Optional[Image.Image], Optional[bytes]]:
        """Download one image and decode it as RGB.

        Returns:
            ``(image_id, image, raw_bytes)`` on success, or
            ``(image_id, None, None)`` on a non-200 status or any error.
        """
        try:
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status == 200:
                    data = await response.read()
                    image = Image.open(BytesIO(data)).convert('RGB')
                    return image_id, image, data
                self.logger.warning(f"下载失败 ID={image_id}: HTTP {response.status}")
        except Exception as e:
            self.logger.warning(f"下载失败 ID={image_id}: {e}")

        return image_id, None, None

    async def download_images_batch(self, image_records: List[dict]) -> Dict[int, Tuple[Image.Image, bytes, str]]:
        """Download all records concurrently (at most ``concurrent_downloads`` connections).

        Returns:
            ``{id: (image, raw_bytes, image_url)}`` for successful downloads only.
        """
        images = {}

        connector = aiohttp.TCPConnector(limit=self.concurrent_downloads)
        async with aiohttp.ClientSession(connector=connector) as session:
            tasks = [
                self.download_image_async(session, rec['id'], rec['image_url'])
                for rec in image_records
            ]
            # gather preserves input order, so results line up with records.
            results = await asyncio.gather(*tasks)

        for rec, (image_id, image, data) in zip(image_records, results):
            if image is not None:
                images[image_id] = (image, data, rec['image_url'])

        return images

    def search_similar(self, features: List[float], exclude_id: int) -> Tuple[bool, Optional[int], Optional[float]]:
        """Search DashVector for a stored image similar to *features*.

        The collection uses the cosine metric, where ``score`` is a distance
        (smaller = more similar); it is converted to ``1 - score``.

        Returns:
            ``(True, similar_id, similarity)`` for a neighbour (other than
            *exclude_id*) with similarity >= ``vector_threshold``, else
            ``(False, None, None)``.

        Raises:
            RuntimeError / SDK exceptions: if the query fails. A failed search
            must not be reported as "no duplicate", because the caller would
            then mark the image unique and upsert its vector.
        """
        try:
            results = self.collection.query(features, topk=3)
        except Exception as e:
            self.logger.warning(f"搜索失败: {e}")
            raise

        if not results:
            message = f"code={getattr(results, 'code', None)}, msg={getattr(results, 'message', None)}"
            self.logger.warning(f"搜索失败: {message}")
            raise RuntimeError(f"搜索失败: {message}")

        for doc in results.output or []:
            similar_id = int(doc.id)
            if similar_id == exclude_id:
                continue

            # score is a distance; convert to similarity.
            similarity = 1.0 - doc.score
            self.logger.info(f"搜索到: {similar_id}, 距离={doc.score:.4f}, 相似度={similarity:.4f}")

            if similarity >= self.vector_threshold:
                return True, similar_id, similarity

        return False, None, None

    def upsert_to_dashvector(self, image_id: int, features: List[float]) -> bool:
        """Store *features* in DashVector under ``str(image_id)``.

        Returns:
            True if the upsert succeeded, False otherwise (failures are logged,
            not raised).
        """
        try:
            doc = Doc(id=str(image_id), vector=features)
            result = self.collection.upsert([doc])
            if result.code == 0:
                self.logger.info(f"向量入库成功: {image_id}")
                return True
            self.logger.warning(f"向量入库失败 ID={image_id}: code={result.code}, msg={result.message}")
        except Exception as e:
            self.logger.warning(f"存入 DashVector 异常 ID={image_id}: {e}")
        return False

    def get_draft_images(self) -> List[dict]:
        """Fetch up to ``batch_size`` unprocessed rows, oldest id first."""
        with self.db_conn.cursor() as cursor:
            sql = """
                SELECT id, image_id, image_url, image_thumb_url, image_name
                FROM ai_image_tags
                WHERE status = 'draft' AND similarity = 'draft'
                  AND image_url != '' AND image_url IS NOT NULL
                ORDER BY id ASC
                LIMIT %s
            """
            cursor.execute(sql, (self.batch_size,))
            rows = cursor.fetchall()
        # End the implicit read transaction (autocommit is off). Otherwise,
        # under InnoDB REPEATABLE READ, repeated polls while idle keep reading
        # the same snapshot and never see newly inserted draft rows.
        self.db_conn.commit()
        return rows

    def _execute_write(self, sql: str, params: tuple):
        """Execute one write statement and commit; roll back (best effort) and re-raise on error."""
        try:
            with self.db_conn.cursor() as cursor:
                cursor.execute(sql, params)
            self.db_conn.commit()
        except Exception:
            try:
                self.db_conn.rollback()
            except Exception as rollback_err:
                # The connection is probably gone; the original error matters more.
                self.logger.warning(f"回滚失败: {rollback_err}")
            raise

    def update_as_duplicate(self, image_id: int, similar_id: int, score: float):
        """Mark a row as a duplicate of *similar_id* with similarity *score*."""
        sql = """
            UPDATE ai_image_tags
            SET status = 'similarity',
                similarity = 'yes',
                similarity_image_tags_id = %s,
                similarity_score = %s,
                updated_at = NOW()
            WHERE id = %s
        """
        self._execute_write(sql, (similar_id, score, image_id))
        self.logger.info(f"重复: {image_id} -> {similar_id} (分数={score:.4f})")

    def update_as_unique(self, image_id: int):
        """Mark a row as unique and advance it to ``tag_extension``."""
        sql = """
            UPDATE ai_image_tags
            SET status = 'tag_extension',
                similarity = 'calc',
                updated_at = NOW()
            WHERE id = %s
        """
        self._execute_write(sql, (image_id,))
        self.logger.info(f"不重复: {image_id} -> tag_extension")

    def update_as_failed(self, image_id: int, reason: str):
        """Reset a row to ``status='draft', similarity='recalc'`` after a failure.

        ``get_draft_images`` only selects ``similarity='draft'``, so these rows
        are not picked up again by this loop. *reason* is logged only; the
        table update does not store it.
        """
        sql = """
            UPDATE ai_image_tags
            SET status = 'draft',
                similarity = 'recalc',
                updated_at = NOW()
            WHERE id = %s
        """
        self._execute_write(sql, (image_id,))
        self.logger.info(f"标记为 recalc: {image_id}, 原因: {reason}")

    def _mark_failed(self, image_id: int, reason: str):
        """Best-effort ``update_as_failed``: log DB errors instead of aborting the batch."""
        try:
            self.update_as_failed(image_id, reason)
        except Exception as e:
            self.logger.error(f"标记失败状态出错 {image_id}: {e}")

    def process_batch(self, image_records: List[dict]) -> Tuple[int, int, int]:
        """Classify each record as duplicate / unique / failed and persist the verdict.

        For each record: build the CDN URL of the original image, fetch its
        embedding, and search DashVector. Duplicates are written back with the
        matched id. Unique images are upserted into DashVector and advanced.
        Anything that fails is marked ``recalc``.

        Returns:
            ``(duplicates, unique, failed)`` counts for this batch.
        """
        if not image_records:
            return 0, 0, 0

        duplicates = 0
        unique = 0
        failed = 0

        for rec in image_records:
            image_id = rec['id']
            raw_url = rec['image_url']

            # Defensive: the SELECT already filters empty URLs.
            if not raw_url or raw_url.strip() == '':
                self.logger.warning(f"图像URL为空,跳过处理: {image_id}")
                self._mark_failed(image_id, "图像URL为空")
                failed += 1
                continue

            # Full CDN URL of the original image (not the thumbnail).
            full_url = f"{self.image_cdn_base}{raw_url}"

            try:
                # Rate limit: free tier allows 2 QPS.
                time.sleep(0.5)
                self.logger.info(f"获取 Embedding: {image_id} -> {full_url}")

                # Pass the URL directly to DashScope.
                features = self.get_image_embedding(image_url=full_url)
                if features is None:
                    self.logger.warning(f"Embedding 获取失败: {image_id}")
                    self._mark_failed(image_id, "Embedding API 失败")
                    failed += 1
                    continue

                is_dup, similar_id, score = self.search_similar(features, image_id)

                if is_dup:
                    self.update_as_duplicate(image_id, similar_id, score)
                    duplicates += 1
                elif self.upsert_to_dashvector(image_id, features):
                    self.update_as_unique(image_id)
                    unique += 1
                else:
                    # Without its vector in the index, later copies of this
                    # image could not be detected, so don't advance it.
                    self._mark_failed(image_id, "向量入库失败")
                    failed += 1
            except Exception as e:
                self.logger.error(f"处理失败 {image_id}: {e}")
                self._mark_failed(image_id, str(e)[:200])
                failed += 1

        return duplicates, unique, failed

    def run(self):
        """Main loop: poll for draft images indefinitely and process them in batches.

        Runs until interrupted (Ctrl+C) or an unrecoverable error. In both
        cases the DB connection is closed and a summary is logged.
        """
        self.logger.info("=" * 60)
        self.logger.info("图片去重审核 - DashScope 多模态版")
        self.logger.info("=" * 60)

        total_duplicates = 0
        total_unique = 0
        total_failed = 0
        batch_num = 0

        try:
            # Connect inside the try so the DB connection is released even if
            # DashVector setup fails.
            self.connect_db()
            self.connect_dashvector()
            self.load_phash_cache()

            while True:
                # The loop can idle for a long time. MySQL may drop the idle
                # connection, so probe it and reconnect if needed.
                self.db_conn.ping(reconnect=True)
                images = self.get_draft_images()

                if not images:
                    self.logger.info("没有待处理的图片,等待 10 秒后继续检查...")
                    time.sleep(10)
                    continue

                batch_num += 1
                self.logger.info(f"\n--- 批次 {batch_num}: {len(images)} 张 ---")

                dup, uniq, fail = self.process_batch(images)
                total_duplicates += dup
                total_unique += uniq
                total_failed += fail

                self.logger.info(f"批次结果: 重复={dup}, 不重复={uniq}, 失败={fail}")

                # Short pause between batches to avoid DB connection issues.
                time.sleep(1)
        except KeyboardInterrupt:
            self.logger.info("收到中断信号,停止处理")
        finally:
            # close() raises on an already-closed connection; only close an open one.
            if self.db_conn and self.db_conn.open:
                self.db_conn.close()

            self.logger.info("=" * 60)
            self.logger.info(f"完成! 总重复: {total_duplicates}, 总不重复: {total_unique}, 总失败: {total_failed}")
            self.logger.info("=" * 60)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the checker from the local config file and
    # hand control to its polling loop.
    ImageSimilarityChecker('config.ini').run()
|