Files
ai_Image_review/image_similarity_check.py

391 lines
15 KiB
Python
Raw Normal View History

2026-02-04 14:36:13 +08:00
# -*- coding: utf-8 -*-
"""
图片去重审核脚本 - DashScope 多模态版
采用: pHash预筛 + DashScope多模态Embedding + 异步批量处理
"""
import configparser
import logging
import asyncio
import aiohttp
import imagehash
import base64
import time
import dashscope
from dashscope import MultiModalEmbedding
from io import BytesIO
from typing import Optional, Tuple, List, Dict
import pymysql
from dashvector import Client, Doc
from PIL import Image
class ImageSimilarityChecker:
    """Image similarity checker - DashScope multimodal edition."""
    def __init__(self, config_path: str = 'config.ini'):
        """Read config.ini, set up logging, and record tuning parameters.

        Connections are NOT opened here; call connect_db() and
        connect_dashvector() before processing.
        """
        self.config = configparser.ConfigParser()
        self.config.read(config_path, encoding='utf-8')
        self._setup_logging()
        # Connection handles, opened lazily by connect_db()/connect_dashvector().
        self.db_conn = None
        self.dashvector_client = None
        self.collection = None
        # DashScope API key is installed process-wide on the dashscope module.
        self.dashscope_api_key = self.config.get('dashscope', 'api_key')
        dashscope.api_key = self.dashscope_api_key
        # pHash cache {phash_str: image_tag_id}
        self.phash_cache: Dict[str, int] = {}
        # Tuning parameters read from config.ini.
        self.image_cdn_base = self.config.get('image', 'cdn_base')
        self.phash_threshold = self.config.getint('similarity', 'phash_threshold')
        self.vector_threshold = self.config.getfloat('similarity', 'vector_threshold')
        self.batch_size = self.config.getint('process', 'batch_size')
        self.concurrent_downloads = self.config.getint('process', 'concurrent_downloads')
def _setup_logging(self):
    """Configure a module logger writing to both a file and the console.

    Idempotent: if the logger already has handlers (e.g. a second checker
    instance in the same process), nothing is added again.
    """
    level_name = self.config.get('process', 'log_level', fallback='INFO')
    log_path = self.config.get('process', 'log_file', fallback='image_similarity.log')
    self.logger = logging.getLogger(__name__)
    # Guard against duplicate handlers on repeated construction.
    if self.logger.handlers:
        return
    self.logger.setLevel(getattr(logging, level_name))
    fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    for handler in (logging.FileHandler(log_path, encoding='utf-8'),
                    logging.StreamHandler()):
        handler.setFormatter(fmt)
        self.logger.addHandler(handler)
def connect_db(self):
    """Open the MySQL connection described by the [database] config section."""
    section = 'database'
    params = {
        'host': self.config.get(section, 'host'),
        'port': self.config.getint(section, 'port'),
        'user': self.config.get(section, 'user'),
        'password': self.config.get(section, 'password'),
        'database': self.config.get(section, 'database'),
        'charset': self.config.get(section, 'charset'),
    }
    # DictCursor makes every fetched row a dict keyed by column name.
    self.db_conn = pymysql.connect(cursorclass=pymysql.cursors.DictCursor, **params)
    self.logger.info("数据库连接成功")
def connect_dashvector(self):
    """Connect to DashVector and ensure the configured collection exists."""
    cfg = self.config
    api_key = cfg.get('dashvector', 'api_key')
    endpoint = cfg.get('dashvector', 'endpoint')
    name = cfg.get('dashvector', 'collection_name')
    dim = cfg.getint('dashvector', 'vector_dimension')
    self.dashvector_client = Client(api_key=api_key, endpoint=endpoint)
    # NOTE(review): assumes Client.get() returns None for a missing
    # collection -- confirm against the DashVector SDK.
    if self.dashvector_client.get(name) is None:
        self.logger.info(f"创建集合 {name},维度 {dim}")
        self.dashvector_client.create(name, dimension=dim)
    else:
        self.logger.info(f"集合 {name} 已存在,直接复用")
    self.collection = self.dashvector_client.get(name)
    self.logger.info("DashVector 连接成功")
def get_image_embedding(self, image_url: Optional[str] = None, image_base64: Optional[str] = None,
                        max_retries: int = 5) -> Optional[List[float]]:
    """Fetch an image embedding via the DashScope multimodal SDK.

    Accepts either a public URL or a base64-encoded JPEG. Throttling
    responses (429/403) are retried with a linear backoff; other API
    errors fail immediately.

    Returns the embedding vector, or None when no input was supplied or
    every attempt failed.
    """
    # The request payload is loop-invariant; build it once, not per retry.
    if image_url:
        input_data = [{'image': image_url}]
    elif image_base64:
        input_data = [{'image': f'data:image/jpeg;base64,{image_base64}'}]
    else:
        return None
    for attempt in range(max_retries):
        try:
            resp = MultiModalEmbedding.call(
                model='multimodal-embedding-v1',
                input=input_data
            )
            if resp.status_code == 200:
                return resp.output['embeddings'][0]['embedding']
            if resp.status_code in (429, 403):
                wait_time = 3 + attempt * 3
                self.logger.warning(f"API 限流,等待 {wait_time} 秒后重试 ({attempt + 1}/{max_retries})...")
                # Don't waste time sleeping after the final attempt.
                if attempt < max_retries - 1:
                    time.sleep(wait_time)
            else:
                # Non-retryable API error: give up immediately.
                self.logger.warning(f"Embedding API 错误: {resp.status_code} - {resp.message}")
                return None
        except Exception as e:
            self.logger.warning(f"Embedding API 异常: {e}")
            if attempt < max_retries - 1:
                time.sleep(2)
    return None
def load_phash_cache(self):
    """Initialize the pHash cache.

    NOTE(review): currently a stub -- it logs completion but loads nothing,
    so self.phash_cache stays empty and check_phash_duplicate() can never
    report a hit. Confirm whether existing hashes should be loaded from the
    database here.
    """
    self.logger.info("pHash 缓存初始化完成")
def compute_phash(self, image: Image.Image) -> str:
    """Return the perceptual hash of *image* as a hex string."""
    hash_value = imagehash.phash(image)
    return str(hash_value)
def check_phash_duplicate(self, phash: str) -> Tuple[bool, Optional[int], Optional[int]]:
    """Scan the in-memory pHash cache for a near-duplicate.

    A hit is any cached hash within self.phash_threshold Hamming distance.
    Returns (True, matching_image_id, distance) on a hit, otherwise
    (False, None, None).
    """
    candidate = imagehash.hex_to_hash(phash)
    for known_hex, known_id in self.phash_cache.items():
        gap = candidate - imagehash.hex_to_hash(known_hex)
        if gap <= self.phash_threshold:
            return True, known_id, gap
    return False, None, None
async def download_image_async(self, session: aiohttp.ClientSession,
                               image_id: int, url: str) -> Tuple[int, Optional[Image.Image], Optional[bytes]]:
    """Download one image and decode it to RGB.

    Returns (image_id, PIL image, raw bytes); image and bytes are None on
    any failure so the caller can tell which downloads succeeded.
    """
    try:
        async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
            if response.status == 200:
                data = await response.read()
                image = Image.open(BytesIO(data)).convert('RGB')
                return image_id, image, data
            # Previously a non-200 response failed silently; log it so
            # broken URLs are diagnosable like exceptions are.
            self.logger.warning(f"下载失败 ID={image_id}: HTTP {response.status}")
    except Exception as e:
        self.logger.warning(f"下载失败 ID={image_id}: {e}")
    return image_id, None, None
async def download_images_batch(self, image_records: List[dict]) -> Dict[int, Tuple[Image.Image, bytes, str]]:
    """Download a batch of images concurrently.

    Returns {image_id: (pil_image, raw_bytes, url)} containing only the
    downloads that succeeded.
    """
    downloaded = {}
    connector = aiohttp.TCPConnector(limit=self.concurrent_downloads)
    async with aiohttp.ClientSession(connector=connector) as session:
        tasks = [self.download_image_async(session, rec['id'], rec['image_url'])
                 for rec in image_records]
        # gather() preserves task order, so results line up with image_records.
        results = await asyncio.gather(*tasks)
    for rec, (image_id, image, data) in zip(image_records, results):
        if image is not None:
            downloaded[image_id] = (image, data, rec['image_url'])
    return downloaded
def search_similar(self, features: List[float], exclude_id: int) -> Tuple[bool, Optional[int], Optional[float]]:
    """Query DashVector for near neighbours of *features*.

    The raw score is a distance (smaller = more similar) and is converted
    to similarity = 1 - score before comparison with the threshold.
    Returns (True, matched_id, similarity) on a hit, else (False, None, None).
    """
    try:
        hits = self.collection.query(features, topk=3)
        if hits and hits.output:
            for doc in hits.output:
                candidate_id = int(doc.id)
                # Never match the record against itself.
                if candidate_id == exclude_id:
                    continue
                similarity = 1.0 - doc.score
                self.logger.info(f"搜索到: {candidate_id}, 距离={doc.score:.4f}, 相似度={similarity:.4f}")
                if similarity >= self.vector_threshold:
                    return True, candidate_id, similarity
        return False, None, None
    except Exception as e:
        self.logger.warning(f"搜索失败: {e}")
        return False, None, None
def upsert_to_dashvector(self, image_id: int, features: List[float]):
    """Store the embedding in DashVector, keyed by the image-tag id."""
    try:
        outcome = self.collection.upsert([Doc(id=str(image_id), vector=features)])
        if outcome.code != 0:
            self.logger.warning(f"向量入库失败 ID={image_id}: code={outcome.code}, msg={outcome.message}")
        else:
            self.logger.info(f"向量入库成功: {image_id}")
    except Exception as e:
        self.logger.warning(f"存入 DashVector 异常 ID={image_id}: {e}")
def get_draft_images(self) -> List[dict]:
    """Fetch the next batch of draft rows that still need a similarity check."""
    query = """
        SELECT id, image_id, image_url, image_thumb_url, image_name
        FROM ai_image_tags
        WHERE status = 'draft' AND similarity = 'draft'
        AND image_url != '' AND image_url IS NOT NULL
        ORDER BY id ASC
        LIMIT %s
    """
    with self.db_conn.cursor() as cursor:
        # batch_size is bound as a parameter, never interpolated.
        cursor.execute(query, (self.batch_size,))
        return cursor.fetchall()
def update_as_duplicate(self, image_id: int, similar_id: int, score: float):
    """Mark *image_id* as a duplicate of *similar_id* with the given score.

    Fix: a stray pasted timestamp line inside the SQL string literal
    corrupted the UPDATE statement; the statement is restored here.
    """
    sql = """
        UPDATE ai_image_tags
        SET status = 'similarity',
            similarity = 'yes',
            similarity_image_tags_id = %s,
            similarity_score = %s,
            updated_at = NOW()
        WHERE id = %s
    """
    with self.db_conn.cursor() as cursor:
        cursor.execute(sql, (similar_id, score, image_id))
    self.db_conn.commit()
    self.logger.info(f"重复: {image_id} -> {similar_id} (分数={score:.4f})")
def update_as_unique(self, image_id: int):
    """Promote a non-duplicate image to the tag_extension stage."""
    sql = """
        UPDATE ai_image_tags
        SET status = 'tag_extension',
            similarity = 'calc',
            updated_at = NOW()
        WHERE id = %s
    """
    with self.db_conn.cursor() as cursor:
        cursor.execute(sql, (image_id,))
    self.db_conn.commit()
    self.logger.info(f"不重复: {image_id} -> tag_extension")
def update_as_failed(self, image_id: int, reason: str):
    """Reset the row to draft with similarity='recalc' after a failure.

    Note: get_draft_images() filters on similarity = 'draft', so rows
    flagged 'recalc' are not retried by this script's own loop.
    """
    sql = """
        UPDATE ai_image_tags
        SET status = 'draft',
            similarity = 'recalc',
            updated_at = NOW()
        WHERE id = %s
    """
    with self.db_conn.cursor() as cursor:
        cursor.execute(sql, (image_id,))
    self.db_conn.commit()
    # Fix: *reason* was accepted but silently discarded; log it for diagnosis.
    self.logger.warning(f"处理失败 ID={image_id}: {reason}")
def process_batch(self, image_records: List[dict]) -> Tuple[int, int, int]:
    """Process one batch of images; returns (duplicates, uniques, failures).

    Per image: rate-limit, fetch a DashScope embedding by CDN URL, search
    DashVector for a near neighbour, then persist the verdict to MySQL.
    NOTE(review): the pHash prescreen advertised in the module header is
    never invoked here -- compute_phash()/check_phash_duplicate() are
    currently unused by this flow.
    """
    if not image_records:
        return 0, 0, 0
    duplicates = 0
    unique = 0
    failed = 0
    for rec in image_records:
        image_id = rec['id']
        # Guard against an empty URL (also filtered in SQL; belt-and-braces).
        if not rec['image_url'] or rec['image_url'].strip() == '':
            self.logger.warning(f"图像URL为空跳过处理: {image_id}")
            self.update_as_failed(image_id, "图像URL为空")
            failed += 1
            continue
        # Build the full CDN URL (original-size image).
        full_url = f"{self.image_cdn_base}{rec['image_url']}"
        try:
            # Rate limiting: the free tier allows ~2 QPS.
            time.sleep(0.5)
            self.logger.info(f"获取 Embedding: {image_id} -> {full_url}")
            # Hand the URL straight to DashScope; no local download needed.
            features = self.get_image_embedding(image_url=full_url)
            if features is None:
                self.logger.warning(f"Embedding 获取失败: {image_id}")
                self.update_as_failed(image_id, "Embedding API 失败")
                failed += 1
                continue
            # Search DashVector for an already-indexed similar image.
            is_dup, similar_id, score = self.search_similar(features, image_id)
            if is_dup:
                self.update_as_duplicate(image_id, similar_id, score)
                duplicates += 1
            else:
                # New image: index its vector, then promote the row.
                self.upsert_to_dashvector(image_id, features)
                self.update_as_unique(image_id)
                unique += 1
        except Exception as e:
            self.logger.error(f"处理失败 {image_id}: {e}")
            self.update_as_failed(image_id, str(e)[:200])
            failed += 1
            continue
    return duplicates, unique, failed
def run(self):
    """Main loop: poll for draft images and process them in batches.

    Runs forever; the summary in ``finally`` is only reached via an
    exception (e.g. KeyboardInterrupt). Fixes: stray pasted timestamp
    lines removed from the body, and the per-batch failure count --
    previously computed but discarded -- is now accumulated and reported.
    """
    self.logger.info("=" * 60)
    self.logger.info("图片去重审核 - DashScope 多模态版")
    self.logger.info("=" * 60)
    # Open all backing connections up front.
    self.connect_db()
    self.connect_dashvector()
    self.load_phash_cache()
    total_duplicates = 0
    total_unique = 0
    total_failed = 0
    batch_num = 0
    try:
        while True:
            images = self.get_draft_images()
            if not images:
                self.logger.info("没有待处理的图片,等待 10 秒后继续检查...")
                time.sleep(10)
                continue
            batch_num += 1
            self.logger.info(f"\n--- 批次 {batch_num}: {len(images)} 张 ---")
            dup, uniq, fail = self.process_batch(images)
            total_duplicates += dup
            total_unique += uniq
            total_failed += fail
            self.logger.info(f"批次结果: 重复={dup}, 不重复={uniq}, 失败={fail}")
            # Short pause between batches to avoid hammering the database.
            time.sleep(1)
    finally:
        if self.db_conn:
            self.db_conn.close()
        self.logger.info("=" * 60)
        self.logger.info(f"完成! 总重复: {total_duplicates}, 总不重复: {total_unique}, 总失败: {total_failed}")
        self.logger.info("=" * 60)
if __name__ == '__main__':
    # Script entry point: runs the polling loop forever using config.ini
    # from the current working directory.
    checker = ImageSimilarityChecker('config.ini')
    checker.run()