# Files
# ai_wht_B/dashvector_get_similar_topic.py
# “shengyudong” 5a384b694e 2026-1-6
# 2026-01-06 14:18:39 +08:00
#
# 205 lines
# 6.6 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DashVector向量检索模块
用于查找相似topic
"""
import dashscope
from dashvector import Client
import logging
from typing import List, Dict, Optional
import sys
import os
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from log_config import setup_logger
# Module-level logger built by the project's setup_logger helper with a main
# log file, a separate error log file, INFO level and console output enabled.
logger = setup_logger(
    name='dashvector_similar',
    log_file='logs/dashvector_similar.log',
    error_log_file='logs/dashvector_similar_error.log',
    level=logging.INFO,
    console_output=True
)

# === Configuration ===
# SECURITY NOTE: credentials must not live in source control. Each value below
# can now be supplied through an environment variable; the hard-coded
# fallbacks are kept only so existing deployments keep working. They have been
# exposed in the repository and should be rotated and then removed.
DASHVECTOR_API_KEY = os.environ.get(
    'DASHVECTOR_API_KEY',
    'sk-55x6oBXypSlPHQ8NvPHfyBABcMIMUE0407A0FCC2A11F0B9C802831A608ABB'
)
DASHVECTOR_ENDPOINT = os.environ.get(
    'DASHVECTOR_ENDPOINT',
    'vrs-cn-2ml4jm42o0001r.dashvector.cn-hangzhou.aliyuncs.com'
)
# Same variable name the DashScope SDK itself reads by default.
dashscope.api_key = os.environ.get(
    'DASHSCOPE_API_KEY',
    'sk-6d22dd845a624d9c92a821d24a50e2e8'
)
# Collection queried by this module (per the original note, the one that
# holds the complete article metadata).
collection_name = "ai_articles_collection"
def get_embedding(text: str) -> List[float]:
    """
    Generate an embedding vector for ``text`` with DashScope's
    ``text-embedding-v2`` model.

    Args:
        text: Input text to embed.

    Returns:
        List[float]: The embedding vector taken from
        ``resp.output['embeddings'][0]['embedding']``.

    Raises:
        Exception: If ``text`` is empty/blank, the DashScope call raises,
            the response status is not 200, or the response does not have the
            expected structure. The message always starts with
            "生成 embedding 失败: " and, where there is an underlying error,
            it is chained as ``__cause__``. Each failure is logged exactly once.
    """
    # Fail fast on empty input instead of spending an API round-trip on it.
    if not text or not text.strip():
        error_msg = "生成 embedding 失败: 输入文本为空"
        logger.error(error_msg)
        raise Exception(error_msg)

    try:
        logger.info(f"开始生成embedding文本: {text}")
        resp = dashscope.TextEmbedding.call(
            model='text-embedding-v2',
            input=text
        )
        if resp.status_code == 200:
            embedding = resp.output['embeddings'][0]['embedding']
            logger.info(f"生成embedding成功向量维度: {len(embedding)}")
            return embedding
        # Record the API error here and raise it *outside* the try, so it is
        # not caught and re-wrapped (and logged a second time) by our own
        # except clause below.
        api_error = f"DashScope API 错误: {resp.code} - {resp.message}"
    except Exception as e:
        error_msg = f"生成 embedding 失败: {str(e)}"
        logger.error(error_msg)
        raise Exception(error_msg) from e

    # Same final message the caller always saw, now logged only once.
    error_msg = f"生成 embedding 失败: {api_error}"
    logger.error(error_msg)
    raise Exception(error_msg)
# === Client initialisation ===
# Module-level handles populated by init_dashvector_client() and read by
# search_chinese(); None means "not (successfully) initialised yet".
client = None
collection = None


def init_dashvector_client():
    """
    Initialise the DashVector client and open the ``collection_name`` collection.

    On success the module-level ``client`` and ``collection`` globals are set.
    On any failure both are reset to None, so a later call to search_chinese()
    retries initialisation instead of reusing a broken handle.

    Returns:
        bool: True if both the client and the collection are usable, else False.
    """
    global client, collection
    try:
        logger.info("开始初始化DashVector客户端")
        new_client = Client(api_key=DASHVECTOR_API_KEY, endpoint=DASHVECTOR_ENDPOINT)
        # Per the DashVector SDK docs, Client/Collection objects signal failure
        # through their truthiness (and carry a ``message``) rather than being
        # None, so test ``not obj`` instead of ``obj is None``.
        if not new_client:
            logger.error(
                f"初始化DashVector客户端失败: {getattr(new_client, 'message', '')}"
            )
            client = collection = None
            return False
        logger.info("DashVector客户端创建成功")

        # Open the existing collection; this module never creates it.
        new_collection = new_client.get(collection_name)
        if not new_collection:
            error_msg = f"集合 '{collection_name}' 不存在!请先确保它已创建并插入数据。"
            logger.error(f"{error_msg} ({getattr(new_collection, 'message', '')})")
            client = collection = None
            return False
    except Exception as e:
        logger.error(f"初始化DashVector客户端失败: {e}")
        client = collection = None
        return False

    # Publish both handles only once everything succeeded.
    client, collection = new_client, new_collection
    logger.info(f"成功连接到集合: {collection_name}")
    return True
# === Query ===
def search_chinese(query_text: str, topk: int = 3, similarity_threshold: float = 0.5,
                   max_results: Optional[int] = 2) -> List[Dict]:
    """
    Find stored topics similar to ``query_text``.

    The query is embedded with get_embedding(), the ``topk`` nearest documents
    are fetched from the DashVector collection, and only documents whose score
    is <= ``similarity_threshold`` are kept. The score is treated as a distance:
    smaller means more similar and 0 is the closest match (presumably the
    collection uses the cosine metric; confirm against the collection schema).

    Args:
        query_text: Text to search for.
        topk: Number of nearest neighbours requested from DashVector.
        similarity_threshold: Maximum score (distance) a result may have to be
            kept. Defaults to 0.5.
        max_results: Maximum number of results returned after filtering.
            Defaults to 2, the previously hard-coded cap; None disables the cap.

    Returns:
        List[Dict]: ``[{"id": str, "title": str, "created_at": str, "similar": float}]``.
        Returns an empty list when initialisation, embedding or the query fails,
        or when no result passes the threshold.
    """
    global collection
    # Lazily initialise on first use (or after a failed initialisation).
    # Truthiness also catches an SDK collection object that reports failure.
    if not collection:
        logger.warning("DashVector客户端未初始化尝试初始化")
        if not init_dashvector_client():
            logger.error("初始化失败,无法执行检索")
            return []

    logger.info(f"开始检索相似topic查询文本: '{query_text}', topk={topk}, 相似度阈值={similarity_threshold}")

    try:
        query_vec = get_embedding(query_text)
    except Exception as e:
        logger.error(f"生成查询向量失败: {e}")
        return []

    # Vector search, requesting the id, title and created_at fields.
    try:
        logger.info(f"开始向量检索topk={topk}")
        rets = collection.query(vector=query_vec, topk=topk, output_fields=["id", "title", "created_at"])
    except Exception as e:
        logger.error(f"向量检索异常: {e}")
        return []

    if rets.code != 0:
        logger.error(f"向量检索失败: {rets.message}")
        return []

    results = rets.output
    if not results:
        logger.warning("未找到相关结果")
        return []
    logger.info(f"检索成功,原始返回 {len(results)} 条结果")

    filtered_results = []
    for i, doc in enumerate(results, 1):
        score = doc.score
        # Tolerate documents returned without a fields mapping.
        fields = doc.fields or {}
        # NOTE(review): 'id' is read from the stored fields, not doc.id (the
        # primary key); confirm the collection stores an 'id' field.
        article_id = fields.get('id', '')
        title = fields.get('title', '')
        created_at = fields.get('created_at', '')
        logger.info(f" 结果[{i}] 相似度: {score:.4f} | ID: {article_id} | 标题: {title} | 创建时间: {created_at}")
        # Lower score = more similar, so keep scores at or below the threshold.
        if score <= similarity_threshold:
            filtered_results.append({
                "id": article_id,
                "title": title,
                "created_at": created_at,
                "similar": score
            })
            logger.info(f" 结果[{i}] 通过相似度过滤({score:.4f} <= {similarity_threshold})")
        else:
            logger.info(f" 结果[{i}] 未通过相似度过滤({score:.4f} > {similarity_threshold})")

    # Cap the number of returned results (previously fixed at 2).
    if max_results is not None and len(filtered_results) > max_results:
        filtered_results = filtered_results[:max_results]
        logger.info(f"限制返回结果数量为{max_results}条")

    logger.info(f"最终返回 {len(filtered_results)} 条相似topic")
    return filtered_results
# === Manual smoke test ===
# Runs a fixed set of sample questions through search_chinese() and prints
# whatever passes the similarity filter.
if __name__ == "__main__":
    # Abort early if the DashVector client/collection cannot be opened.
    if not init_dashvector_client():
        logger.error("初始化失败,退出")
        sys.exit(1)

    separator = '=' * 60
    sample_questions = (
        "早泄怎么治疗?",
        "肾虚吃什么药?",
        "前列腺肥大的原因",
        "阳痿能治好吗?",
        "早泄可以吃哪些中药?",
        "前列腺增生肥大能治好吗",
    )

    for question in sample_questions:
        print(f"\n{separator}")
        print(f"🔍 测试查询: {question}")
        print(separator)

        hits = search_chinese(question, topk=3, similarity_threshold=0.5)
        if not hits:
            print("⚠️ 未找到符合条件的相关结果")
            continue

        print(f"\n✅ 找到 {len(hits)} 条相关结果:")
        for rank, hit in enumerate(hits, start=1):
            print(f" [{rank}] 相似度: {hit['similar']:.4f} | ID: {hit.get('id', 'N/A')} | 标题: {hit.get('title', 'N/A')} | 创建时间: {hit.get('created_at', 'N/A')}")

    print(f"\n{separator}")
    print("测试完成")
    print(separator)