Files
ai_wht_B/dashvector_get_similar_topic.py

205 lines
6.6 KiB
Python
Raw Normal View History

2026-01-06 14:18:39 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DashVector向量检索模块
用于查找相似topic
"""
import dashscope
from dashvector import Client
import logging
from typing import List, Dict, Optional
import sys
import os
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from log_config import setup_logger
# Configure the module logger through the project's setup_logger helper.
logger = setup_logger(
    name="dashvector_similar",
    log_file="logs/dashvector_similar.log",
    error_log_file="logs/dashvector_similar_error.log",
    level=logging.INFO,
    console_output=True,
)
# === Configuration ===
# SECURITY: credentials must not live in source control. Each value is read
# from the environment first; the literals below are kept only as a
# backward-compatible fallback so existing deployments keep working.
# TODO(security): rotate these keys and delete the hard-coded fallbacks.
DASHVECTOR_API_KEY = (
    os.environ.get('DASHVECTOR_API_KEY')
    or 'sk-55x6oBXypSlPHQ8NvPHfyBABcMIMUE0407A0FCC2A11F0B9C802831A608ABB'
)
DASHVECTOR_ENDPOINT = (
    os.environ.get('DASHVECTOR_ENDPOINT')
    or 'vrs-cn-2ml4jm42o0001r.dashvector.cn-hangzhou.aliyuncs.com'
)
dashscope.api_key = (
    os.environ.get('DASHSCOPE_API_KEY')
    or 'sk-6d22dd845a624d9c92a821d24a50e2e8'
)
# Use the collection that contains the full article metadata.
collection_name = "ai_articles_collection"
def get_embedding(text: str) -> List[float]:
    """
    Generate the embedding vector for a piece of text.

    Calls the DashScope ``text-embedding-v2`` model and returns the first
    embedding in the response.

    Args:
        text: Input text to embed. Must not be ``None`` or blank.

    Returns:
        List[float]: The embedding vector.

    Raises:
        ValueError: If ``text`` is ``None`` or contains only whitespace.
        Exception: If the DashScope call fails or returns a non-200 status.
            The original error is attached as ``__cause__``.
    """
    # Reject empty input before spending a remote API call on it.
    if text is None or (isinstance(text, str) and not text.strip()):
        raise ValueError("生成 embedding 失败: 输入文本为空")

    try:
        logger.info(f"开始生成embedding文本: {text}")
        resp = dashscope.TextEmbedding.call(
            model='text-embedding-v2',
            input=text,
        )
        if resp.status_code != 200:
            error_msg = f"DashScope API 错误: {resp.code} - {resp.message}"
            logger.error(error_msg)
            raise Exception(error_msg)

        embedding = resp.output['embeddings'][0]['embedding']
        logger.info(f"生成embedding成功向量维度: {len(embedding)}")
        return embedding
    except Exception as e:
        # Wrap every failure in one uniform message, but chain the original
        # exception so its traceback is preserved for debugging.
        error_msg = f"生成 embedding 失败: {str(e)}"
        logger.error(error_msg)
        raise Exception(error_msg) from e
# === Client initialisation ===
# Module-level handles; populated by init_dashvector_client() on success.
client = None
collection = None


def init_dashvector_client() -> bool:
    """
    Initialise the DashVector client and open the configured collection.

    The module-level ``client`` and ``collection`` are only replaced after
    both handles have been validated, so a failed attempt never leaves a
    half-initialised or invalid handle behind.

    Returns:
        bool: ``True`` on success, ``False`` if the client could not be
        created or the collection could not be opened (the error is logged).
    """
    global client, collection
    try:
        logger.info("开始初始化DashVector客户端")
        new_client = Client(api_key=DASHVECTOR_API_KEY, endpoint=DASHVECTOR_ENDPOINT)
        # The DashVector SDK documents a truthiness check on the client to
        # detect a failed creation (e.g. bad key or endpoint).
        if not new_client:
            error_msg = f"DashVector客户端创建失败: {getattr(new_client, 'message', '')}"
            logger.error(error_msg)
            raise Exception(error_msg)
        logger.info("DashVector客户端创建成功")

        # Fetch the existing collection. On failure the SDK returns a falsy
        # response object rather than None, so test truthiness instead of
        # identity; `not` still covers a None return.
        new_collection = new_client.get(collection_name)
        if not new_collection:
            error_msg = f"集合 '{collection_name}' 不存在！请先确保它已创建并插入数据。"
            logger.error(error_msg)
            raise Exception(error_msg)

        client, collection = new_client, new_collection
        logger.info(f"成功连接到集合: {collection_name}")
        return True
    except Exception as e:
        logger.error(f"初始化DashVector客户端失败: {e}")
        return False
# === Query ===
def search_chinese(
    query_text: str,
    topk: int = 3,
    similarity_threshold: float = 0.5,
    max_results: Optional[int] = 2,
) -> List[Dict]:
    """
    Retrieve topics similar to ``query_text`` from the DashVector collection.

    The query text is embedded, the collection is searched for the ``topk``
    nearest documents, and hits whose score exceeds ``similarity_threshold``
    are dropped. Every failure path is logged and yields an empty list; this
    function does not raise for embedding or query errors.

    Args:
        query_text: Text to search for.
        topk: Number of candidates requested from the vector index.
        similarity_threshold: Maximum accepted score (default 0.5). A hit is
            kept when ``score <= similarity_threshold``; per the original
            author's note a smaller score means more similar, 0 being the
            closest match.
        max_results: Upper bound on the number of returned items (default 2,
            the previously hard-coded limit). ``None`` disables the cap;
            negative values are treated as 0.

    Returns:
        List[Dict]: Matching topics in the order returned by the index, each
        shaped as ``{"id": str, "title": str, "created_at": str,
        "similar": float}``. Empty when nothing matches or on any error.
    """
    global collection

    # Lazily initialise on first use. Truthiness (not `is None`) is used so a
    # failed, falsy collection handle also triggers re-initialisation.
    if not collection:
        logger.warning("DashVector客户端未初始化尝试初始化")
        if not init_dashvector_client():
            logger.error("初始化失败,无法执行检索")
            return []

    logger.info(f"开始检索相似topic查询文本: '{query_text}', topk={topk}, 相似度阈值={similarity_threshold}")

    try:
        # Build the query vector.
        query_vec = get_embedding(query_text)
    except Exception as e:
        logger.error(f"生成查询向量失败: {e}")
        return []

    # Run the vector search, requesting the id, title and created_at fields.
    try:
        logger.info(f"开始向量检索topk={topk}")
        rets = collection.query(vector=query_vec, topk=topk, output_fields=["id", "title", "created_at"])
    except Exception as e:
        logger.error(f"向量检索异常: {e}")
        return []

    if rets.code != 0:
        logger.error(f"向量检索失败: {rets.message}")
        return []

    results = rets.output
    if not results:
        logger.warning("未找到相关结果")
        return []

    logger.info(f"检索成功,原始返回 {len(results)} 条结果")

    # Keep only the hits whose score is within the threshold.
    filtered_results = []
    for i, doc in enumerate(results, 1):
        score = doc.score
        # Tolerate a document that carries no fields at all.
        fields = doc.fields or {}
        article_id = fields.get('id', '')
        title = fields.get('title', '')
        created_at = fields.get('created_at', '')

        logger.info(f" 结果[{i}] 相似度: {score:.4f} | ID: {article_id} | 标题: {title} | 创建时间: {created_at}")

        # Threshold filter: a smaller score means more similar.
        if score <= similarity_threshold:
            filtered_results.append({
                "id": article_id,
                "title": title,
                "created_at": created_at,
                "similar": score,
            })
            logger.info(f" 结果[{i}] 通过相似度过滤({score:.4f} <= {similarity_threshold}")
        else:
            logger.info(f" 结果[{i}] 未通过相似度过滤({score:.4f} > {similarity_threshold}")

    # Cap the number of returned items (2 by default).
    if max_results is not None:
        limit = max(max_results, 0)
        if len(filtered_results) > limit:
            filtered_results = filtered_results[:limit]
            logger.info(f"限制返回结果数量为{limit}条")

    logger.info(f"最终返回 {len(filtered_results)} 条相似topic")
    return filtered_results
# === Manual smoke test ===
if __name__ == "__main__":
    # Nothing can be queried without a working client, so exit on failure.
    if not init_dashvector_client():
        logger.error("初始化失败,退出")
        sys.exit(1)

    separator = '=' * 60
    sample_queries = (
        "早泄怎么治疗?",
        "肾虚吃什么药?",
        "前列腺肥大的原因",
        "阳痿能治好吗?",
        "早泄可以吃哪些中药?",
        "前列腺增生肥大能治好吗",
    )

    for query in sample_queries:
        print(f"\n{separator}")
        print(f"🔍 测试查询: {query}")
        print(f"{separator}")

        hits = search_chinese(query, topk=3, similarity_threshold=0.5)
        if not hits:
            print("⚠️ 未找到符合条件的相关结果")
            continue

        print(f"\n✅ 找到 {len(hits)} 条相关结果:")
        for rank, hit in enumerate(hits, 1):
            print(f" [{rank}] 相似度: {hit['similar']:.4f} | ID: {hit.get('id', 'N/A')} | 标题: {hit.get('title', 'N/A')} | 创建时间: {hit.get('created_at', 'N/A')}")

    print(f"\n{separator}")
    print("测试完成")
    print(f"{separator}")