205 lines
6.6 KiB
Python
205 lines
6.6 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
DashVector向量检索模块
|
||
用于查找相似topic
|
||
"""
|
||
|
||
import dashscope
|
||
from dashvector import Client
|
||
import logging
|
||
from typing import List, Dict, Optional
|
||
import sys
|
||
import os
|
||
|
||
# Add this file's own directory to the Python import path so that sibling modules can be imported
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from log_config import setup_logger
|
||
|
||
# Configure the module logger.
logger = setup_logger(
    name='dashvector_similar',
    log_file='logs/dashvector_similar.log',
    error_log_file='logs/dashvector_similar_error.log',
    level=logging.INFO,
    console_output=True
)

# === Configuration ===
# SECURITY NOTE(review): these credentials are committed in source control.
# Each value can be overridden with the environment variable named below; the
# literals are kept only as backward-compatible fallbacks. They should be
# rotated and removed once the environment variables are deployed.
# `or` (rather than a .get() default) also falls back when a variable is set
# but empty.
DASHVECTOR_API_KEY = (
    os.environ.get('DASHVECTOR_API_KEY')
    or 'sk-55x6oBXypSlPHQ8NvPHfyBABcMIMUE0407A0FCC2A11F0B9C802831A608ABB'
)
DASHVECTOR_ENDPOINT = (
    os.environ.get('DASHVECTOR_ENDPOINT')
    or 'vrs-cn-2ml4jm42o0001r.dashvector.cn-hangzhou.aliyuncs.com'
)
dashscope.api_key = (
    os.environ.get('DASHSCOPE_API_KEY')
    or 'sk-6d22dd845a624d9c92a821d24a50e2e8'
)

collection_name = "ai_articles_collection"  # collection that carries the full metadata
|
||
|
||
def get_embedding(text: str) -> List[float]:
    """
    Generate the embedding vector for a piece of text.

    Calls the DashScope ``text-embedding-v2`` model and returns the first
    embedding contained in the response.

    Args:
        text: Input text to embed. Must not be empty or whitespace-only.

    Returns:
        List[float]: The embedding vector.

    Raises:
        ValueError: If ``text`` is empty or whitespace-only. No API call is
            made in that case.
        RuntimeError: If the API answers with a non-200 status, the response
            lacks the expected ``embeddings`` structure, or the call itself
            raises. The underlying exception is chained as ``__cause__``.
    """
    # Reject blank input up front instead of spending a network round trip on it.
    if not text or (isinstance(text, str) and not text.strip()):
        raise ValueError("生成 embedding 失败: 输入文本为空")

    try:
        logger.info(f"开始生成embedding,文本: {text}")
        resp = dashscope.TextEmbedding.call(
            model='text-embedding-v2',
            input=text
        )
        if resp.status_code != 200:
            raise RuntimeError(f"DashScope API 错误: {resp.code} - {resp.message}")
        embedding = resp.output['embeddings'][0]['embedding']
    except Exception as e:
        # Single place where failures are logged and wrapped, so an API error
        # is reported once (not once per layer) and keeps its original cause.
        error_msg = f"生成 embedding 失败: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e

    logger.info(f"生成embedding成功,向量维度: {len(embedding)}")
    return embedding
|
||
|
||
# === Client state ===
# Module-level handles; both stay None until init_dashvector_client() assigns them.
client, collection = None, None
|
||
|
||
def init_dashvector_client():
    """
    Initialise the DashVector client and open the configured collection.

    The module-level ``client`` and ``collection`` are only assigned after
    both steps succeed, so a failed attempt never leaves a half-initialised
    or failed handle behind for later callers to pick up.

    Returns:
        bool: True when the client was created and the collection opened,
        False otherwise (the failure is logged, not raised).
    """
    global client, collection

    try:
        logger.info("开始初始化DashVector客户端")
        new_client = Client(api_key=DASHVECTOR_API_KEY, endpoint=DASHVECTOR_ENDPOINT)
        # Truthiness covers both None and a handle that reports failure through
        # its boolean value — NOTE(review): confirm against the DashVector SDK
        # reference that Client/Collection handles are falsy on failure.
        if not new_client:
            detail = getattr(new_client, "message", new_client)
            raise RuntimeError(f"DashVector客户端创建失败: {detail}")
        logger.info("DashVector客户端创建成功")

        # Open the already existing collection.
        new_collection = new_client.get(collection_name)
        if not new_collection:
            error_msg = f"集合 '{collection_name}' 不存在!请先确保它已创建并插入数据。"
            detail = getattr(new_collection, "message", None)
            if detail:
                error_msg = f"{error_msg} ({detail})"
            logger.error(error_msg)
            raise RuntimeError(error_msg)
    except Exception as e:
        logger.error(f"初始化DashVector客户端失败: {e}")
        return False

    # Publish the handles only once everything succeeded.
    client, collection = new_client, new_collection
    logger.info(f"成功连接到集合: {collection_name}")
    return True
|
||
|
||
# === Query functions ===
def _filter_by_similarity(docs, similarity_threshold: float) -> List[Dict]:
    """Log every returned doc and keep those whose score is within the threshold."""
    kept = []
    for i, doc in enumerate(docs, 1):
        score = doc.score
        # A doc without stored fields must not abort the whole search.
        fields = doc.fields or {}
        article_id = fields.get('id', '')
        title = fields.get('title', '')
        created_at = fields.get('created_at', '')

        logger.info(f" 结果[{i}] 相似度: {score:.4f} | ID: {article_id} | 标题: {title} | 创建时间: {created_at}")

        # The score is treated as a distance: smaller means more similar,
        # 0 being the closest match.
        if score <= similarity_threshold:
            kept.append({
                "id": article_id,
                "title": title,
                "created_at": created_at,
                "similar": score
            })
            logger.info(f" 结果[{i}] 通过相似度过滤({score:.4f} <= {similarity_threshold})")
        else:
            logger.info(f" 结果[{i}] 未通过相似度过滤({score:.4f} > {similarity_threshold})")
    return kept


def search_chinese(query_text: str, topk: int = 3, similarity_threshold: float = 0.5,
                   max_results: Optional[int] = 2) -> List[Dict]:
    """
    Search the collection for topics similar to ``query_text``.

    Args:
        query_text: Text to search for.
        topk: Number of nearest neighbours requested from the collection.
        similarity_threshold: Maximum accepted score (default 0.5). Scores are
            treated as distances, so lower values are closer matches.
        max_results: Upper bound on the number of returned items after
            filtering (default 2, the previously hard-coded limit). Pass
            ``None`` to return every match that passes the threshold.

    Returns:
        List[Dict]: Matches in the order returned by the collection, each of
        the form ``{"id": str, "title": str, "created_at": str, "similar": float}``.
        An empty list is returned when initialisation, embedding or the query
        fails, or when nothing passes the threshold.
    """
    global collection

    # Initialise lazily. Truthiness (instead of ``is None``) also retries when
    # the stored handle is a falsy, failed one — NOTE(review): relies on the
    # SDK marking failed handles as falsy; confirm against the SDK reference.
    if not collection:
        logger.warning("DashVector客户端未初始化,尝试初始化")
        if not init_dashvector_client():
            logger.error("初始化失败,无法执行检索")
            return []

    logger.info(f"开始检索相似topic,查询文本: '{query_text}', topk={topk}, 相似度阈值={similarity_threshold}")

    # Build the query vector.
    try:
        query_vec = get_embedding(query_text)
    except Exception as e:
        logger.error(f"生成查询向量失败: {e}")
        return []

    # Run the vector search, asking for the fields needed in the result.
    try:
        logger.info(f"开始向量检索,topk={topk}")
        rets = collection.query(vector=query_vec, topk=topk, output_fields=["id", "title", "created_at"])
    except Exception as e:
        logger.error(f"向量检索异常: {e}")
        return []

    if rets.code != 0:
        logger.error(f"向量检索失败: {rets.message}")
        return []

    results = rets.output
    if not results:
        logger.warning("未找到相关结果")
        return []

    logger.info(f"检索成功,原始返回 {len(results)} 条结果")

    # Drop everything that is further away than the threshold allows.
    filtered_results = _filter_by_similarity(results, similarity_threshold)

    # Apply the optional cap on the number of returned items.
    if max_results is not None:
        limit = max(max_results, 0)
        if len(filtered_results) > limit:
            filtered_results = filtered_results[:limit]
            logger.info(f"限制返回结果数量为{limit}条")

    logger.info(f"最终返回 {len(filtered_results)} 条相似topic")
    return filtered_results
|
||
|
||
# === Manual smoke test ===
if __name__ == "__main__":
    # Connect first; exit with status 1 when the collection cannot be opened.
    if not init_dashvector_client():
        logger.error("初始化失败,退出")
        sys.exit(1)

    sample_queries = (
        "早泄怎么治疗?",
        "肾虚吃什么药?",
        "前列腺肥大的原因",
        "阳痿能治好吗?",
        "早泄可以吃哪些中药?",
        "前列腺增生肥大能治好吗",
    )

    for query in sample_queries:
        print(f"\n{'='*60}")
        print(f"🔍 测试查询: {query}")
        print(f"{'='*60}")

        hits = search_chinese(query, topk=3, similarity_threshold=0.5)

        # Nothing passed the filter: report it and move on to the next query.
        if not hits:
            print("⚠️ 未找到符合条件的相关结果")
            continue

        print(f"\n✅ 找到 {len(hits)} 条相关结果:")
        for rank, hit in enumerate(hits, 1):
            print(f" [{rank}] 相似度: {hit['similar']:.4f} | ID: {hit.get('id', 'N/A')} | 标题: {hit.get('title', 'N/A')} | 创建时间: {hit.get('created_at', 'N/A')}")

    print(f"\n{'='*60}")
    print("测试完成")
    print(f"{'='*60}")