#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DashVector vector-retrieval module.

Embeds a query text with DashScope and looks up similar topics in a
DashVector collection.
"""

import logging
import os
import sys
from typing import List, Dict, Optional

import dashscope
from dashvector import Client

# Make the directory containing this file importable so that the local
# `log_config` module resolves regardless of the current working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from log_config import setup_logger

# Module logger.
logger = setup_logger(
    name='dashvector_similar',
    log_file='logs/dashvector_similar.log',
    error_log_file='logs/dashvector_similar_error.log',
    level=logging.INFO,
    console_output=True
)

# === Configuration ===
# NOTE(review): these credentials are committed in source. They should be
# rotated and supplied only through the environment; the literals are kept
# as fallbacks purely so existing deployments keep working unchanged.
DASHVECTOR_API_KEY = os.environ.get(
    'DASHVECTOR_API_KEY',
    'sk-55x6oBXypSlPHQ8NvPHfyBABcMIMUE0407A0FCC2A11F0B9C802831A608ABB',
)
DASHVECTOR_ENDPOINT = os.environ.get(
    'DASHVECTOR_ENDPOINT',
    'vrs-cn-2ml4jm42o0001r.dashvector.cn-hangzhou.aliyuncs.com',
)
dashscope.api_key = os.environ.get(
    'DASHSCOPE_API_KEY',
    'sk-6d22dd845a624d9c92a821d24a50e2e8',
)
collection_name = "ai_articles_collection"  # collection that carries the full metadata


def get_embedding(text: str) -> List[float]:
    """
    Generate the embedding vector for a piece of text.

    Calls the DashScope ``text-embedding-v2`` model.

    Args:
        text: Input text.

    Returns:
        List[float]: The embedding vector of ``text``.

    Raises:
        Exception: If the API answers with a non-200 status, or if the call
            or the response parsing fails. The original error is attached as
            the exception's cause.
    """
    try:
        logger.info(f"开始生成embedding,文本: {text}")
        resp = dashscope.TextEmbedding.call(
            model='text-embedding-v2',
            input=text
        )

        if resp.status_code == 200:
            embedding = resp.output['embeddings'][0]['embedding']
            logger.info(f"生成embedding成功,向量维度: {len(embedding)}")
            return embedding

        error_msg = f"DashScope API 错误: {resp.code} - {resp.message}"
        logger.error(error_msg)
        raise Exception(error_msg)
    except Exception as e:
        error_msg = f"生成 embedding 失败: {str(e)}"
        logger.error(error_msg)
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise Exception(error_msg) from e


# === Client initialisation ===
client = None
collection = None


def init_dashvector_client():
    """
    Initialise the DashVector client and open the configured collection.

    On success the module-level ``client`` and ``collection`` globals are
    populated. If the collection cannot be opened, ``collection`` is reset to
    ``None`` so that a later call retries instead of reusing a failed handle.

    Returns:
        bool: True when the collection was opened, False otherwise (the
        failure is logged, never raised).
    """
    global client, collection

    try:
        logger.info("开始初始化DashVector客户端")
        client = Client(api_key=DASHVECTOR_API_KEY, endpoint=DASHVECTOR_ENDPOINT)
        logger.info("DashVector客户端创建成功")

        # Open the already-existing collection. Test truthiness rather than
        # `is None`: the SDK documentation checks the returned object with
        # `if collection:`, i.e. a failed lookup may come back as a falsy
        # response object instead of None.
        fetched = client.get(collection_name)
        if not fetched:
            collection = None
            error_msg = f"集合 '{collection_name}' 不存在!请先确保它已创建并插入数据。"
            logger.error(error_msg)
            raise Exception(error_msg)

        collection = fetched
        logger.info(f"成功连接到集合: {collection_name}")
        return True

    except Exception as e:
        logger.error(f"初始化DashVector客户端失败: {e}")
        return False


# === Query function ===
def search_chinese(query_text: str, topk: int = 3, similarity_threshold: float = 0.5,
                   max_results: Optional[int] = 2) -> List[Dict]:
    """
    Retrieve topics similar to ``query_text``.

    The query is embedded, the collection is searched for the ``topk``
    nearest documents, hits whose score exceeds ``similarity_threshold`` are
    dropped, and at most ``max_results`` hits are returned. Every failure
    (initialisation, embedding, query) is logged and yields an empty list.

    Args:
        query_text: Query text.
        topk: Number of candidates requested from the vector search.
        similarity_threshold: Maximum accepted score, default 0.5. A hit is
            kept when ``score <= similarity_threshold``.
        max_results: Upper bound on the number of returned hits, default 2
            (the previously hard-coded limit). ``None`` disables the cap.

    Returns:
        List[Dict]: Similar topics, each shaped as
        ``{"id": str, "title": str, "created_at": str, "similar": float}``.
    """
    global collection

    # Nothing to embed: skip the remote calls altogether.
    if not query_text or not query_text.strip():
        logger.warning("查询文本为空,跳过检索")
        return []

    # Lazily initialise the client on first use.
    if collection is None:
        logger.warning("DashVector客户端未初始化,尝试初始化")
        if not init_dashvector_client():
            logger.error("初始化失败,无法执行检索")
            return []

    logger.info(f"开始检索相似topic,查询文本: '{query_text}', topk={topk}, 相似度阈值={similarity_threshold}")

    try:
        # Build the query vector.
        query_vec = get_embedding(query_text)
    except Exception as e:
        logger.error(f"生成查询向量失败: {e}")
        return []

    # Run the vector search, requesting the id, title and created_at fields.
    try:
        logger.info(f"开始向量检索,topk={topk}")
        rets = collection.query(vector=query_vec, topk=topk, output_fields=["id", "title", "created_at"])
    except Exception as e:
        logger.error(f"向量检索异常: {e}")
        return []

    if rets.code != 0:
        logger.error(f"向量检索失败: {rets.message}")
        return []

    results = rets.output
    if not results:
        logger.warning("未找到相关结果")
        return []

    logger.info(f"检索成功,原始返回 {len(results)} 条结果")

    # Drop hits that fall outside the similarity threshold.
    filtered_results = []
    for i, doc in enumerate(results, 1):
        score = doc.score
        fields = doc.fields or {}  # tolerate a hit that carries no fields
        article_id = fields.get('id', '')
        title = fields.get('title', '')
        created_at = fields.get('created_at', '')

        logger.info(f" 结果[{i}] 相似度: {score:.4f} | ID: {article_id} | 标题: {title} | 创建时间: {created_at}")

        # Per the original author's note, a smaller score means more similar
        # (0 is the best match). NOTE(review): this depends on the metric the
        # collection was created with — confirm.
        if score <= similarity_threshold:
            filtered_results.append({
                "id": article_id,
                "title": title,
                "created_at": created_at,
                "similar": score
            })
            logger.info(f" 结果[{i}] 通过相似度过滤({score:.4f} <= {similarity_threshold})")
        else:
            logger.info(f" 结果[{i}] 未通过相似度过滤({score:.4f} > {similarity_threshold})")

    # Cap the number of returned hits.
    if max_results is not None and len(filtered_results) > max_results:
        filtered_results = filtered_results[:max(max_results, 0)]
        logger.info(f"限制返回结果数量为{max_results}条")

    logger.info(f"最终返回 {len(filtered_results)} 条相似topic")
    return filtered_results


# === Manual smoke test ===
if __name__ == "__main__":
    # Initialise the client up front so a bad configuration fails fast.
    if not init_dashvector_client():
        logger.error("初始化失败,退出")
        sys.exit(1)

    test_queries = [
        "早泄怎么治疗?",
        "肾虚吃什么药?",
        "前列腺肥大的原因",
        "阳痿能治好吗?",
        "早泄可以吃哪些中药?",
        "前列腺增生肥大能治好吗"
    ]

    for q in test_queries:
        print(f"\n{'='*60}")
        print(f"🔍 测试查询: {q}")
        print(f"{'='*60}")

        results = search_chinese(q, topk=3, similarity_threshold=0.5)

        if results:
            print(f"\n✅ 找到 {len(results)} 条相关结果:")
            for i, item in enumerate(results, 1):
                print(f" [{i}] 相似度: {item['similar']:.4f} | ID: {item.get('id', 'N/A')} | 标题: {item.get('title', 'N/A')} | 创建时间: {item.get('created_at', 'N/A')}")
        else:
            print("⚠️ 未找到符合条件的相关结果")

    print(f"\n{'='*60}")
    print("测试完成")
    print(f"{'='*60}")