# Demo script: embed Chinese medical topics with DashScope and search them in DashVector.
import os

import requests
from dashvector import Client, Doc

# === Configuration ===
# Credentials are read from the environment so that secrets never live in the
# source file. NOTE(review): any key that was previously committed here should
# be treated as leaked and rotated.
DASHVECTOR_API_KEY = os.environ.get("DASHVECTOR_API_KEY", "")
# The endpoint is not a secret, so the original value stays as the default.
DASHVECTOR_ENDPOINT = os.environ.get(
    "DASHVECTOR_ENDPOINT",
    'vrs-cn-2ml4jm42o0001r.dashvector.cn-hangzhou.aliyuncs.com',
)

# Obtain this key from the DashScope console (not Bailian Model Studio!)
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY", "")

def get_embedding(text, timeout=30):
    """Return the embedding vector for *text* from the DashScope embedding API.

    Sends a single-element ``texts`` request for the ``text-embedding-v1``
    model, authenticated with the module-level ``DASHSCOPE_API_KEY``.

    Args:
        text: The string to embed.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout ``requests`` may block indefinitely on a
            stalled connection.

    Returns:
        The ``embedding`` entry of the first item in the response's
        ``output.embeddings`` list.

    Raises:
        RuntimeError: If the API answers with a non-200 status code.
        requests.RequestException: On network failures or timeout.
    """
    url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding-v1"
    headers = {
        "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "input": {"texts": [text]},
        "model": "text-embedding-v1",
    }
    resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    # Guard clause: surface the status code and body so failures are diagnosable.
    if resp.status_code != 200:
        raise RuntimeError(f"❌ Embedding API 错误: {resp.status_code} - {resp.text}")
    # Only one text was sent, so the first embedding is the one for *text*.
    return resp.json()["output"]["embeddings"][0]["embedding"]

# === Initialize the DashVector client ===
client = Client(api_key=DASHVECTOR_API_KEY, endpoint=DASHVECTOR_ENDPOINT)

# === Create the collection (note: the dimension is 1536!) ===
collection_name = "medical_topics"
try:
    # Best-effort cleanup so every run starts from a fresh collection.
    client.delete(collection_name)
except Exception:
    # Deliberately ignored: a failed delete must not stop the demo, but
    # KeyboardInterrupt / SystemExit are no longer swallowed.
    pass

# The source states text-embedding-v1 outputs 1536-dimensional vectors.
create_resp = client.create(name=collection_name, dimension=1536)
# DashVector's documented examples test the response's truthiness for success.
if not create_resp:
    print("❌ 集合创建失败:", create_resp)
    raise SystemExit(1)

collection = client.get(collection_name)
print("✅ 集合已创建并获取")

# === Insert data ===
topics = [
    "如何治疗阳痿、早泄和肾虚?",
    "早泄可以吃哪些中药?",
    "该如何治疗早泄?",
    "前列腺肥大是什么原因引起的?",
]

# Embed every topic and wrap it in a Doc; ids run topic_1 .. topic_N and the
# original text is stored in the "content" field.
docs = [
    Doc(id=f"topic_{i}", vector=get_embedding(text), fields={"content": text})
    for i, text in enumerate(topics, 1)
]

resp = collection.insert(docs)
# DashVector's documented examples test the response's truthiness for success.
if resp:
    # Report the real number of documents instead of a hard-coded count.
    print(f"✅ {len(docs)} 条中文话题已成功插入!")
else:
    print("❌ 插入失败:", resp)
    # `exit` is a helper injected by the `site` module for interactive use;
    # raising SystemExit works everywhere and needs no import.
    raise SystemExit(1)

# === Query test ===
query_text = "早泄的治疗方法有哪些?"
query_vec = get_embedding(query_text)

# Ask for the 3 nearest documents and have the stored "content" field returned.
rets = collection.query(vector=query_vec, topk=3, output_fields=["content"])
# DashVector's documented examples test the response's truthiness for success
# and iterate the response directly to get the matched docs.
if rets:
    print(f"\n🔍 查询 '{query_text}' 的结果:")
    for doc in rets:
        print(f" ID: {doc.id} | 相似度: {doc.score:.4f} | 内容: {doc.fields['content']}")
else:
    print("❌ 查询失败:", rets)