2026-1-6
This commit is contained in:
18
venv/Lib/site-packages/dashvector/common/__init__.py
Normal file
18
venv/Lib/site-packages/dashvector/common/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
##
|
||||
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
##
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
414
venv/Lib/site-packages/dashvector/common/common_validator.py
Normal file
414
venv/Lib/site-packages/dashvector/common/common_validator.py
Normal file
@@ -0,0 +1,414 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, NewType
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from dashvector.common.constants import *
|
||||
from dashvector.common.types import *
|
||||
from dashvector.common.error import DashVectorCode, DashVectorException
|
||||
from dashvector.core.doc import Doc
|
||||
from dashvector.core.models.collection_meta_status import CollectionMeta
|
||||
from dashvector.util.convertor import to_sorted_sparse_vector, to_sorted_sparse_vectors
|
||||
from dashvector.core.proto import dashvector_pb2
|
||||
|
||||
class Validator():
|
||||
|
||||
@staticmethod
|
||||
def validate_dense_vector(vector, dimension: int, dtype: VectorDataType, doc_op: str):
|
||||
if isinstance(vector, list):
|
||||
if len(vector) != dimension:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.MismatchedDimension,
|
||||
reason=f"DashVectorSDK {doc_op} vector list length({len(vector)}) is invalid and must be same with collection dimension({dimension})",
|
||||
)
|
||||
vector_data_type = VectorType.get_vector_data_type(type(vector[0]))
|
||||
if vector_data_type != dtype:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.MismatchedDataType,
|
||||
reason=f"DashVectorSDK {doc_op} vector type({type(vector[0])}) is invalid and must be {VectorType.get_python_type(dtype)}",
|
||||
)
|
||||
if vector_data_type == VectorType.INT:
|
||||
try:
|
||||
vector = VectorType.convert_to_bytes(vector, dtype, dimension)
|
||||
except Exception as e:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorFormat,
|
||||
reason=f"DashVectorSDK {doc_op} vector value({vector}) is invalid and int value must be in [-128, 127]",
|
||||
)
|
||||
elif isinstance(vector, np.ndarray):
|
||||
if vector.ndim != 1:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} vector numpy dimension({vector.ndim}) is invalid and must be 1",
|
||||
)
|
||||
if vector.shape[0] != dimension:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.MismatchedDimension,
|
||||
reason=f"DashVectorSDK {doc_op} vector numpy shape[0]({vector.shape[0]}) is invalid and must be same with collection dimension({dimension})",
|
||||
)
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} vector type({type(vector)}) is invalid and must be [list, numpy.ndarray]",
|
||||
)
|
||||
return vector
|
||||
|
||||
@staticmethod
|
||||
def validate_collection_name(name: str, doc_op: str):
|
||||
if not isinstance(name, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} name type({name}) is invalid and must be str",
|
||||
)
|
||||
|
||||
if re.search(COLLECTION_AND_PARTITION_NAME_PATTERN, name) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidCollectionName,
|
||||
reason=f"DashVectorSDK {doc_op} name characters({name}) is invalid and "
|
||||
+ COLLECTION_AND_PARTITION_NAME_PATTERN_MSG,
|
||||
)
|
||||
return name
|
||||
|
||||
@staticmethod
|
||||
def validate_partition_name(partition_name: str, doc_op: str):
|
||||
if not isinstance(partition_name, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} partition name type({partition_name}) is invalid and must be str",
|
||||
)
|
||||
|
||||
if re.search(COLLECTION_AND_PARTITION_NAME_PATTERN, partition_name) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidPartitionName,
|
||||
reason=f"DashVectorSDK {doc_op} partition characters({partition_name}) is invalid and "
|
||||
+ COLLECTION_AND_PARTITION_NAME_PATTERN_MSG,
|
||||
)
|
||||
return partition_name
|
||||
|
||||
@staticmethod
|
||||
def validate_vector_name(vector_name: str, doc_op: str):
|
||||
if not isinstance(vector_name, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} vector_name type({type(vector_name)}) is invalid and must be str",
|
||||
)
|
||||
if re.search(FIELD_NAME_PATTERN, vector_name) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidFieldName,
|
||||
reason=f"DashVectorSDK {doc_op} vector_name characters({vector_name}) is invalid and "
|
||||
+ FIELD_NAME_PATTERN_MSG,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_sparse_vectors(sparse_vectors: Union[None, VectorParam, Dict[str, VectorParam]], doc_op: str):
|
||||
if sparse_vectors is None:
|
||||
sparse_vectors = dict()
|
||||
if isinstance(sparse_vectors, dict):
|
||||
for vector_name in sparse_vectors.keys():
|
||||
Validator.validate_vector_name(vector_name, doc_op)
|
||||
if not isinstance(sparse_vectors, dict):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} sparse_vectors type({type(sparse_vectors)}) is invalid and must be dict"
|
||||
)
|
||||
for vector_name, vector_param in sparse_vectors.items():
|
||||
if not isinstance(vector_param, VectorParam):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} vector_param type({type(vector_param)}) is invalid and must be VectorParam",
|
||||
)
|
||||
vector_param.validate()
|
||||
if vector_param.dimension != 0:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidDimension,
|
||||
reason=f"DashVectorSDK VectorParam dimension value({vector_param.dimension}) for sparse vector is invalid and must be 0",
|
||||
)
|
||||
if(len(vector_param.quantize_type) > 0):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} vector_name({vector_name}), quantize_type({vector_param.quantize_type}) for sparse vector is invalid and must be empty",
|
||||
)
|
||||
return sparse_vectors
|
||||
|
||||
@staticmethod
|
||||
def validate_fields_schema(fields_schema: Optional[FieldSchemaDict], doc_op: str):
|
||||
returned_fields_schema = dict()
|
||||
if fields_schema is not None:
|
||||
if not isinstance(fields_schema, dict):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} fields_schema type({type(fields_schema)}) is invalid and must be dict",
|
||||
)
|
||||
|
||||
if len(fields_schema) > 1024:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidField,
|
||||
reason=f"DashVectorSDK {doc_op} fields_schema length({len(fields_schema)}) is invalid and must be in [0, 1024]",
|
||||
)
|
||||
|
||||
|
||||
for field_name, field_dtype in fields_schema.items():
|
||||
if not isinstance(field_name, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} field_name in fields_schema type({type(field_name)}) is invalid and must be str",
|
||||
)
|
||||
|
||||
if re.search(FIELD_NAME_PATTERN, field_name) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidFieldName,
|
||||
reason=f"DashVectorSDK {doc_op} field_name in fields_schema characters({field_name}) is invalid and "
|
||||
+ FIELD_NAME_PATTERN_MSG,
|
||||
)
|
||||
|
||||
if field_name == DASHVECTOR_VECTOR_NAME:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidFieldName,
|
||||
reason=f"DashVectorSDK {doc_op} field_name in fields_schema value({DASHVECTOR_VECTOR_NAME}) is reserved",
|
||||
)
|
||||
|
||||
ftype = FieldType.get_field_data_type(field_dtype)
|
||||
returned_fields_schema[field_name] = ftype
|
||||
return returned_fields_schema
|
||||
|
||||
@staticmethod
|
||||
def validate_extra_params(extra_params: Optional[Dict[str, Any]], doc_op: str):
|
||||
returned_extra_params = dict()
|
||||
if extra_params is not None:
|
||||
if not isinstance(extra_params, dict):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} extra_params type({type(extra_params)}) is invalid and must be dict",
|
||||
)
|
||||
|
||||
extra_params_is_empty = True
|
||||
for extra_param_key, extra_param_value in extra_params.items():
|
||||
extra_params_is_empty = False
|
||||
|
||||
if not isinstance(extra_param_key, str) or not isinstance(extra_param_value, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} extra_param key/value type is invalid and must be str.",
|
||||
)
|
||||
|
||||
if len(extra_param_key) <= 0:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidExtraParam,
|
||||
reason=f"DashVectorSDK {doc_op} extra_param key is empty",
|
||||
)
|
||||
|
||||
if not extra_params_is_empty:
|
||||
returned_extra_params = extra_params
|
||||
return returned_extra_params
|
||||
|
||||
@staticmethod
|
||||
def validate_doc_ids(ids: IdsType, doc_op: str):
|
||||
returned_ids = list()
|
||||
returned_ids_is_single = False
|
||||
if isinstance(ids, list):
|
||||
if len(ids) < 1 or len(ids) > 1024:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.ExceedIdsLimit,
|
||||
reason=f"DashVectorSDK {doc_op} ids list length({len(ids)}) is invalid and must be in [1, 1024]",
|
||||
)
|
||||
for id in ids:
|
||||
if isinstance(id, str):
|
||||
if re.search(DOC_ID_PATTERN, id) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidPrimaryKey,
|
||||
reason=f"DashVectorSDK {doc_op} id in ids list characters({id}) is invalid and "
|
||||
+ DOC_ID_PATTERN_MSG,
|
||||
)
|
||||
returned_ids.append(id)
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} id in ids list type({type(id)}) is invalid and must be str",
|
||||
)
|
||||
|
||||
elif isinstance(ids, str):
|
||||
if re.search(DOC_ID_PATTERN, ids) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidPrimaryKey,
|
||||
reason=f"DashVectorSDK {doc_op} ids str characters({ids}) is invalid and "
|
||||
+ DOC_ID_PATTERN_MSG,
|
||||
)
|
||||
|
||||
returned_ids.append(ids)
|
||||
returned_ids_is_single = True
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} ids type({type(ids)}) is invalid and must be [str, List[str]]",
|
||||
)
|
||||
return returned_ids, returned_ids_is_single
|
||||
|
||||
@staticmethod
|
||||
def validate_id(id: str, doc_op: str):
|
||||
if isinstance(id, str):
|
||||
if re.search(DOC_ID_PATTERN, id) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidPrimaryKey,
|
||||
reason=f"DashVectorSDK {doc_op} id str characters({id}) is invalid and "
|
||||
+ DOC_ID_PATTERN_MSG,
|
||||
)
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} expect id to be <str> but actual type is ({type(id)})",
|
||||
)
|
||||
return id
|
||||
|
||||
@staticmethod
|
||||
def validate_output_fields(output_fields: Optional[List[str]], doc_op: str):
|
||||
returned_output_fields = list()
|
||||
if output_fields is not None:
|
||||
if isinstance(output_fields, list):
|
||||
for output_field in output_fields:
|
||||
if not isinstance(output_field, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} output_field in output_fields type({type(output_field)}) is invalid and must be list[str]",
|
||||
)
|
||||
|
||||
if re.search(FIELD_NAME_PATTERN, output_field) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidField,
|
||||
reason=f"DashVectorSDK {doc_op} output_field in output_fields characters({output_field}) is invalid and "
|
||||
+ FIELD_NAME_PATTERN_MSG,
|
||||
)
|
||||
|
||||
returned_output_fields = output_fields
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} output_fields type({type(output_fields)}) is invalid and must be List[str]",
|
||||
)
|
||||
return returned_output_fields
|
||||
|
||||
@staticmethod
|
||||
def validate_rerank(rerank: BaseRanker, query_request: dashvector_pb2.QueryDocRequest, doc_op: str):
|
||||
if isinstance(rerank, WeightedRanker):
|
||||
weight_keys = sorted(rerank.weights.keys())
|
||||
query_vectors_keys = sorted(sorted(query_request.vectors.keys()) + sorted(query_request.sparse_vectors.keys()))
|
||||
|
||||
if weight_keys is not None and len(weight_keys) > 0 and weight_keys != query_vectors_keys:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} expect WeightedRanker.weights({rerank.weights}) to exactly match all vector names({query_vectors_keys})"
|
||||
)
|
||||
elif isinstance(rerank, RrfRanker):
|
||||
rank_constant = rerank.rank_constant
|
||||
if not isinstance(rank_constant, int) or rank_constant < 0 or rank_constant >= 2**31:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} expect RrfRanker.rank_constant({rank_constant}) to be positive int32"
|
||||
)
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} expect rerank type to be WeightedRanker or RrfRanker, actual type({type(rerank)})"
|
||||
)
|
||||
return rerank
|
||||
|
||||
@staticmethod
|
||||
def validate_include_vector(include_vector: bool, doc_op: str):
|
||||
if not isinstance(include_vector, bool):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} include_vector type({type(include_vector)}) is invalid and must be bool",
|
||||
)
|
||||
return include_vector
|
||||
|
||||
@staticmethod
|
||||
def validate_topk(topk: int, include_vector: bool, doc_op: str):
|
||||
include_vector = Validator.validate_include_vector(include_vector, doc_op)
|
||||
if not isinstance(topk, int):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} topk type({type(topk)}) is invalid and must be int",
|
||||
)
|
||||
if topk < 1 or (include_vector and topk > 1024):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidTopk,
|
||||
reason=f"DashVectorSDK {doc_op} topk value({topk}) is invalid and must be in [1, 1024] "
|
||||
f"when include_vector is True",
|
||||
)
|
||||
return topk
|
||||
|
||||
@staticmethod
|
||||
def validate_filter(filter: str, doc_op: str):
|
||||
if filter is not None:
|
||||
if not isinstance(filter, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} filter type({type(filter)}) is invalid and must be str",
|
||||
)
|
||||
|
||||
if len(filter) > 40960:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidFilter,
|
||||
reason=f"DashVectorSDK {doc_op} filter length({len(filter)}) is invalid and must be in [0, 40960]",
|
||||
)
|
||||
if len(filter) > 0:
|
||||
return filter
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def validate_doc(doc: Doc, meta: CollectionMeta, action: str, doc_op: str):
|
||||
if doc.id is not None:
|
||||
if not isinstance(doc.id, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK {doc_op} id type({type(doc.id)}) is invalid and must be str",
|
||||
)
|
||||
if re.search(DOC_ID_PATTERN, doc.id) is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidPrimaryKey,
|
||||
reason=f"DashVectorSDK {doc_op} id characters({doc.id}) is invalid and "
|
||||
+ DOC_ID_PATTERN_MSG,
|
||||
)
|
||||
else:
|
||||
if action == "update":
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidPrimaryKey,
|
||||
reason=f"DashVectorSDK {doc_op} id({doc.id}) is required when the action is update",
|
||||
)
|
||||
returned_id = doc.id
|
||||
returned_vector = None
|
||||
if doc.vector is not None:
|
||||
returned_vector = Validator.validate_dense_vector(doc.vector, meta.get_dimension(), VectorType.get(meta.get_dtype()), doc_op)
|
||||
returned_vectors = dict()
|
||||
if doc.vectors is not None:
|
||||
for key, value in doc.vectors.items():
|
||||
returned_vectors[key] = Validator.validate_dense_vector(value, meta.get_dimension(vector_name=key), VectorType.get(meta.get_dtype(vector_name=key)), doc_op)
|
||||
if action != "update" and doc.vector is None and doc.vectors is None and doc.sparse_vectors is None:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorFormat,
|
||||
reason=f"DashVectorSDK {doc_op} vector is required and must be in [list, numpy.ndarray] when request in [insert, upsert]",
|
||||
)
|
||||
|
||||
# check fields
|
||||
if doc.fields is None:
|
||||
pass
|
||||
elif not isinstance(doc.fields, dict):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidField,
|
||||
reason=f"DashVectorSDK {doc_op} fields type({type(doc.fields)}) is invalid",
|
||||
)
|
||||
elif len(doc.fields) > 1024:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidField,
|
||||
reason=f"DashVectorSDK {doc_op} fields length({len(doc.fields)}) is invalid and must be in [1, 1024]",
|
||||
)
|
||||
|
||||
returned_sparse_vector = to_sorted_sparse_vector(doc.sparse_vector)
|
||||
|
||||
returned_sparse_vectors = to_sorted_sparse_vectors(doc.sparse_vectors)
|
||||
|
||||
return Doc(id=returned_id, vector=returned_vector, vectors=returned_vectors, fields=doc.fields, sparse_vector=returned_sparse_vector, sparse_vectors=returned_sparse_vectors)
|
||||
31
venv/Lib/site-packages/dashvector/common/constants.py
Normal file
31
venv/Lib/site-packages/dashvector/common/constants.py
Normal file
@@ -0,0 +1,31 @@
|
||||
##
|
||||
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
##
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
DASHVECTOR_VECTOR_NAME = "proxima_vector"
|
||||
DASHVECTOR_LOGGING_LEVEL_ENV = "DASHVECTOR_LOGGING_LEVEL"
|
||||
FIELD_NAME_PATTERN = "^[a-zA-Z0-9_-]{1,32}$"
|
||||
FIELD_NAME_PATTERN_MSG = "character must be in [a-zA-Z0-9] and symbols[_, -] and length must be in [1,32]"
|
||||
COLLECTION_AND_PARTITION_NAME_PATTERN = "^[a-zA-Z0-9_-]{3,32}$"
|
||||
COLLECTION_AND_PARTITION_NAME_PATTERN_MSG = (
|
||||
"character must be in [a-zA-Z0-9] and symbols[_, -] and length must be in [3,32]"
|
||||
)
|
||||
DOC_ID_PATTERN = "^[a-zA-Z0-9_\\-!@#$%+=.]{1,64}$"
|
||||
DOC_ID_PATTERN_MSG = "character must be in [a-zA-Z0-9] and symbols[_-!@#$%+=.] and length must be in [1, 64]"
|
||||
GRPC_MAX_MSG_SIZE = 128 * 1024 * 1024
|
||||
MAX_INT_VALUE = 2 ** 31 - 1
|
||||
MIN_INT_VALUE = -2 ** 31
|
||||
133
venv/Lib/site-packages/dashvector/common/error.py
Normal file
133
venv/Lib/site-packages/dashvector/common/error.py
Normal file
@@ -0,0 +1,133 @@
|
||||
##
|
||||
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
##
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import http
|
||||
from enum import IntEnum
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class DashVectorCode(IntEnum):
|
||||
Timeout = 408
|
||||
Success = 0
|
||||
Closed = -998
|
||||
Unknown = -999
|
||||
EmptyCollectionName = -2000
|
||||
EmptyColumnName = -2001
|
||||
EmptyPartitionName = -2002
|
||||
EmptyColumns = -2003
|
||||
EmptyPrimaryKey = -2004
|
||||
EmptyDocList = -2005
|
||||
EmptyDocFields = -2006
|
||||
EmptyIndexField = -2007
|
||||
InvalidRecord = -2008
|
||||
InvalidQuery = -2009
|
||||
InvalidWriteRequest = -2010
|
||||
InvalidVectorFormat = -2011
|
||||
InvalidDataType = -2012
|
||||
InvalidIndexType = -2013
|
||||
InvalidFeature = -2014
|
||||
InvalidFilter = -2015
|
||||
InvalidPrimaryKey = -2016
|
||||
InvalidField = -2017
|
||||
MismatchedIndexColumn = -2018
|
||||
MismatchedDimension = -2019
|
||||
MismatchedDataType = -2020
|
||||
InexistentCollection = -2021
|
||||
InexistentPartition = -2022
|
||||
InexistentColumn = -2023
|
||||
InexistentKey = -2024
|
||||
DuplicateCollection = -2025
|
||||
DuplicatePartition = -2026
|
||||
DuplicateKey = -2027
|
||||
DuplicateField = -2028
|
||||
UnreadyPartition = -2029
|
||||
UnreadyCollection = -2030
|
||||
UnsupportedCondition = -2031
|
||||
OrderbyNotInSelectItems = -2032
|
||||
PbToSqlInfoError = -2033
|
||||
ExceedRateLimit = -2034
|
||||
InvalidSparseValues = -2035
|
||||
InvalidBatchSize = -2036
|
||||
InvalidDimension = -2037
|
||||
InvalidExtraParam = -2038
|
||||
InvalidRadius = -2039
|
||||
InvalidLinear = -2040
|
||||
InvalidTopk = -2041
|
||||
InvalidCollectionName = -2042
|
||||
InvalidPartitionName = -2043
|
||||
InvalidFieldName = -2044
|
||||
InvalidChannelCount = -2045
|
||||
InvalidReplicaCount = -2046
|
||||
InvalidJson = -2047
|
||||
InvalidGroupBy = -2053,
|
||||
InvalidSparseIndices = -2951
|
||||
InvalidEndpoint = -2952
|
||||
ExceedIdsLimit = -2967
|
||||
InvalidVectorType = -2968
|
||||
ExceedRequestSize = -2970
|
||||
ExistVectorAndId = -2973
|
||||
InvalidArgument = -2999
|
||||
|
||||
|
||||
class DashVectorException(Exception):
|
||||
"""
|
||||
DashVector Exception
|
||||
"""
|
||||
|
||||
def __init__(self, code=DashVectorCode.Unknown, reason=None, request_id=None):
|
||||
self._code = code
|
||||
self._reason = "DashVectorSDK unknown exception" if reason is None else reason
|
||||
self._request_id = request_id
|
||||
super().__init__(f"{self._reason}({self._code})")
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
return self._code
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self._reason
|
||||
|
||||
@property
|
||||
def request_id(self):
|
||||
if self._request_id is None:
|
||||
return ""
|
||||
return self._request_id
|
||||
|
||||
|
||||
class DashVectorHTTPException(DashVectorException):
|
||||
def __new__(cls, code, reason=None, request_id=None):
|
||||
exception_code = code
|
||||
exception_reason = reason
|
||||
if isinstance(code, http.HTTPStatus):
|
||||
exception_code = code.value
|
||||
exception_reason = f"DashVectorSDK http rpc error: {code.phrase}"
|
||||
|
||||
return DashVectorException(code=exception_code, reason=exception_reason, request_id=request_id)
|
||||
|
||||
|
||||
class DashVectorGRPCException(DashVectorException):
|
||||
def __new__(cls, code, reason=None, request_id=None):
|
||||
exception_code = code
|
||||
exception_reason = reason
|
||||
if isinstance(code, grpc.StatusCode):
|
||||
exception_code = code.value[0]
|
||||
exception_reason = f"DashVectorSDK grpc rpc error: {code.value[1]}"
|
||||
return DashVectorException(code=exception_code, reason=exception_reason, request_id=request_id)
|
||||
212
venv/Lib/site-packages/dashvector/common/handler.py
Normal file
212
venv/Lib/site-packages/dashvector/common/handler.py
Normal file
@@ -0,0 +1,212 @@
|
||||
##
|
||||
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
##
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
|
||||
from google.protobuf.json_format import MessageToJson
|
||||
from google.protobuf.message import Message
|
||||
|
||||
from dashvector.common.error import DashVectorCode, DashVectorException
|
||||
from dashvector.util.validator import verify_endpoint
|
||||
from dashvector.version import __version__
|
||||
|
||||
|
||||
class RPCRequest(object):
|
||||
def __init__(self, *, request: Message):
|
||||
self.request = request
|
||||
self.request_str = request.SerializeToString()
|
||||
self.request_len = len(self.request_str)
|
||||
if self.request_len > (2 * 1024 * 1024):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.ExceedRequestSize,
|
||||
reason=f"DashVectorSDK request length({self.request_len}) exceeds maximum length(2MiB) limit",
|
||||
)
|
||||
|
||||
def to_json(self):
|
||||
return MessageToJson(self.request, always_print_fields_with_no_presence=True, preserving_proto_field_name=True)
|
||||
|
||||
def to_proto(self):
|
||||
return self.request
|
||||
|
||||
def to_string(self):
|
||||
return self.request_str
|
||||
|
||||
|
||||
class RPCResponse(object):
|
||||
def __init__(self, *, async_req):
|
||||
self._async_req = async_req
|
||||
self._request_id = None
|
||||
self._code = DashVectorCode.Unknown
|
||||
self._message = None
|
||||
self._output = None
|
||||
self._usage = None
|
||||
|
||||
@property
|
||||
def async_req(self):
|
||||
return self._async_req
|
||||
|
||||
@property
|
||||
def request_id(self):
|
||||
return self._request_id
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
return self._code
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self._message
|
||||
|
||||
@property
|
||||
def output(self):
|
||||
return self._output
|
||||
|
||||
@property
|
||||
def usage(self):
|
||||
return self._usage
|
||||
|
||||
@abstractmethod
|
||||
def get(self):
|
||||
pass
|
||||
|
||||
|
||||
class RPCHandler(object):
|
||||
def __init__(self, *, endpoint: str = "", api_key: str = "", timeout: float = 10.0):
|
||||
"""
|
||||
endpoint: str
|
||||
"""
|
||||
if not isinstance(endpoint, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK RPCHandler endpoint({endpoint}) is invalid and must be str",
|
||||
)
|
||||
|
||||
if not verify_endpoint(endpoint):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidEndpoint,
|
||||
reason=f"DashVectorSDK RPCHandler endpoint({endpoint}) is invalid and cannot contain protocol header and [_]",
|
||||
)
|
||||
|
||||
"""
|
||||
api_key: str
|
||||
"""
|
||||
if not isinstance(api_key, str):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument, reason=f"DashVectorSDK RPCHandler api_key({api_key}) is invalid"
|
||||
)
|
||||
|
||||
"""
|
||||
timeout: float
|
||||
"""
|
||||
if isinstance(timeout, float):
|
||||
pass
|
||||
elif isinstance(timeout, int):
|
||||
timeout = float(timeout)
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument, reason=f"DashVectorSDK RPCHandler timeout({timeout}) is invalid"
|
||||
)
|
||||
if timeout <= 0.000001:
|
||||
timeout = 365.5 * 86400
|
||||
|
||||
self._endpoint = endpoint
|
||||
self._timeout = timeout
|
||||
self._insecure_mode = os.getenv("DASHVECTOR_INSECURE_MODE", "False").lower() in ("true", "1")
|
||||
self._headers = {
|
||||
"dashvector-auth-token": api_key,
|
||||
"x-user-agent": f"{__version__};{platform.python_version()};{platform.platform()}",
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
def create_collection(self, create_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(self, delete_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def describe_collection(self, describe_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_collections(self, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stats_collection(self, stats_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_partition(self, create_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_partition(self, delete_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def describe_partition(self, describe_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_partitions(self, list_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stats_partition(self, stats_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert_doc(self, insert_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_doc(self, update_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def upsert_doc(self, upsert_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_doc(self, delete_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_doc(self, query_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_doc_group_by(self, query_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fetch_doc(self, fetch_request, *, async_req=False) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_version(self, *, async_req) -> RPCResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
pass
|
||||
50
venv/Lib/site-packages/dashvector/common/logging.py
Normal file
50
venv/Lib/site-packages/dashvector/common/logging.py
Normal file
@@ -0,0 +1,50 @@
|
||||
##
|
||||
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
##
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dashvector.common.constants import DASHVECTOR_LOGGING_LEVEL_ENV
|
||||
|
||||
logger = logging.getLogger("dashvector")
|
||||
|
||||
|
||||
def enable_logging():
|
||||
level = os.environ.get(DASHVECTOR_LOGGING_LEVEL_ENV, None)
|
||||
if level is not None: # set logging level.
|
||||
if level not in ["info", "debug"]:
|
||||
# set logging level env, but invalid value, use default.
|
||||
level = "info"
|
||||
if level == "info":
|
||||
logger.setLevel(logging.INFO)
|
||||
else:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
# set default logging handler
|
||||
console_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(levelname)s - %(message)s"
|
||||
# noqa E501
|
||||
)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
# in release disable dashscope log
|
||||
# you can enable dashscope log for debugger.
|
||||
enable_logging()
|
||||
61
venv/Lib/site-packages/dashvector/common/status.py
Normal file
61
venv/Lib/site-packages/dashvector/common/status.py
Normal file
@@ -0,0 +1,61 @@
|
||||
##
|
||||
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
##
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
from typing import Union
|
||||
|
||||
from dashvector.common.error import DashVectorCode, DashVectorException
|
||||
|
||||
|
||||
class StrStatus(str, Enum):
|
||||
INITIALIZED = "INITIALIZED"
|
||||
SERVING = "SERVING"
|
||||
DROPPING = "DROPPING"
|
||||
ERROR = "ERROR"
|
||||
|
||||
|
||||
class Status(IntEnum):
|
||||
INITIALIZED = 0
|
||||
SERVING = 1
|
||||
DROPPING = 2
|
||||
ERROR = 3
|
||||
|
||||
@staticmethod
|
||||
def get(cs: Union[str, StrStatus]) -> IntEnum:
|
||||
if cs == StrStatus.INITIALIZED:
|
||||
return Status.INITIALIZED
|
||||
elif cs == StrStatus.SERVING:
|
||||
return Status.SERVING
|
||||
elif cs == StrStatus.DROPPING:
|
||||
return Status.DROPPING
|
||||
elif cs == StrStatus.ERROR:
|
||||
return Status.ERROR
|
||||
raise DashVectorException(code=DashVectorCode.InvalidArgument, reason=f"DashVectorSDK get invalid status {cs}")
|
||||
|
||||
@staticmethod
|
||||
def str(cs: Union[int, IntEnum]) -> str:
|
||||
if cs == Status.INITIALIZED:
|
||||
return StrStatus.INITIALIZED.value
|
||||
elif cs == Status.SERVING:
|
||||
return StrStatus.SERVING.value
|
||||
elif cs == Status.DROPPING:
|
||||
return StrStatus.DROPPING.value
|
||||
elif cs == Status.ERROR:
|
||||
return StrStatus.ERROR.value
|
||||
raise DashVectorException(code=DashVectorCode.InvalidArgument, reason=f"DashVectorSDK get invalid Status {cs}")
|
||||
837
venv/Lib/site-packages/dashvector/common/types.py
Normal file
837
venv/Lib/site-packages/dashvector/common/types.py
Normal file
@@ -0,0 +1,837 @@
|
||||
##
|
||||
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
##
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import struct
|
||||
from enum import Enum, IntEnum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, NewType
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from dashvector.common.error import DashVectorCode, DashVectorException
|
||||
from dashvector.common.handler import RPCResponse
|
||||
from dashvector.common.status import Status
|
||||
from dashvector.util.convertor import to_json_without_ascii
|
||||
from dashvector.core.proto import dashvector_pb2
|
||||
|
||||
long = NewType("long", int)
|
||||
|
||||
VectorDataType = Union[
|
||||
Type[int],
|
||||
Type[float],
|
||||
Type[bool],
|
||||
Type[np.int8],
|
||||
Type[np.int16],
|
||||
Type[np.float16],
|
||||
Type[np.bool_],
|
||||
Type[np.float32],
|
||||
Type[np.float64],
|
||||
]
|
||||
VectorValueType = Union[List[int], List[float], np.ndarray]
|
||||
SparseValueType = Dict[int, float]
|
||||
|
||||
supported_type_msg = ("bool | str | int | float | long | "
|
||||
"typing.List[str] | typing.List[int] | typing.List[float] | typing.List[long]")
|
||||
# used to define schema
|
||||
FieldSchemaType = Union[
|
||||
Type[long], Type[str], Type[bool], Type[int], Type[float],
|
||||
Type[List[long]], Type[List[str]], Type[List[int]], Type[List[float]]
|
||||
]
|
||||
# used to insert field data
|
||||
FieldDataType = Union[long, str, int, float, bool, List[long], List[str], List[int], List[float]]
|
||||
|
||||
FieldSchemaDict = Dict[str, FieldSchemaType]
|
||||
FieldDataDict = Dict[str, FieldDataType]
|
||||
IdsType = Union[str, List[str]]
|
||||
|
||||
|
||||
class DashVectorProtocol(IntEnum):
|
||||
GRPC = 0
|
||||
HTTP = 1
|
||||
|
||||
|
||||
class DocOp(IntEnum):
|
||||
insert = 0
|
||||
update = 1
|
||||
upsert = 2
|
||||
delete = 3
|
||||
|
||||
|
||||
class MetricStrType(str, Enum):
|
||||
EUCLIDEAN = "euclidean"
|
||||
DOTPRODUCT = "dotproduct"
|
||||
COSINE = "cosine"
|
||||
|
||||
|
||||
class MetricType(IntEnum):
|
||||
EUCLIDEAN = 0
|
||||
DOTPRODUCT = 1
|
||||
COSINE = 2
|
||||
|
||||
@staticmethod
|
||||
def get(mtype: Union[str, MetricStrType]) -> IntEnum:
|
||||
if mtype == MetricStrType.EUCLIDEAN:
|
||||
return MetricType.EUCLIDEAN
|
||||
elif mtype == MetricStrType.DOTPRODUCT:
|
||||
return MetricType.DOTPRODUCT
|
||||
elif mtype == MetricStrType.COSINE:
|
||||
return MetricType.COSINE
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK get invalid metrictype {mtype} and must be in [cosine, dotproduct, euclidean]",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def str(mtype: Union[int, IntEnum]) -> str:
|
||||
if mtype == MetricType.EUCLIDEAN:
|
||||
return MetricStrType.EUCLIDEAN.value
|
||||
elif mtype == MetricType.DOTPRODUCT:
|
||||
return MetricStrType.DOTPRODUCT.value
|
||||
elif mtype == MetricType.COSINE:
|
||||
return MetricStrType.COSINE.value
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK get invalid metrictype {mtype} and must be in [cosine, dotproduct, euclidean]",
|
||||
)
|
||||
|
||||
|
||||
class VectorStrType(str, Enum):
|
||||
FLOAT = "FLOAT"
|
||||
INT = "INT"
|
||||
|
||||
|
||||
class VectorType(IntEnum):
|
||||
FLOAT = 0
|
||||
INT = 1
|
||||
|
||||
@staticmethod
|
||||
def get(vtype: Union[str, VectorStrType]) -> 'VectorType':
|
||||
if vtype == VectorStrType.FLOAT:
|
||||
return VectorType.FLOAT
|
||||
elif vtype == VectorStrType.INT:
|
||||
return VectorType.INT
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorType,
|
||||
reason=f"DashVectorSDK get invalid vectortype {vtype} and must be in [int, float]",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def str(vtype: Union[int, IntEnum]) -> str:
|
||||
if vtype == VectorType.FLOAT:
|
||||
return VectorStrType.FLOAT.value
|
||||
elif vtype == VectorType.INT:
|
||||
return VectorStrType.INT.value
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorType,
|
||||
reason=f"DashVectorSDK get invalid vectortype {vtype} and must be in [int, float]]",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_vector_data_type(vtype: Type):
|
||||
if not isinstance(vtype, type):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorType,
|
||||
reason=f"DashVectorSDK does not support vector data type {vtype} and must be in [int, float]",
|
||||
)
|
||||
if vtype not in _vector_dtype_map:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorType,
|
||||
reason=f"DashVectorSDK does not support vector data type {vtype} and must be in [int, float]",
|
||||
)
|
||||
return _vector_dtype_map[vtype]
|
||||
|
||||
@staticmethod
|
||||
def get_vector_data_format(data_type):
|
||||
if data_type not in (VectorType.INT, VectorType.FLOAT):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorType,
|
||||
reason=f"DashVectorSDK does not support vector({data_type}) to convert bytes",
|
||||
)
|
||||
return _vector_type_to_format[data_type]
|
||||
|
||||
@staticmethod
|
||||
def convert_to_bytes(feature, data_type, dimension):
|
||||
if data_type not in (VectorType.INT, VectorType.FLOAT):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorType,
|
||||
reason=f"DashVectorSDK does not support auto pack feature type({data_type})",
|
||||
)
|
||||
return struct.pack(f"<{dimension}{_vector_type_to_format[data_type]}", *feature)
|
||||
|
||||
@staticmethod
|
||||
def convert_to_dtype(feature, data_type, dimension):
|
||||
if data_type not in (VectorType.INT, VectorType.FLOAT):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidVectorType,
|
||||
reason=f"DashVectorSDK does not support auto unpack feature type({data_type})",
|
||||
)
|
||||
return struct.unpack(f"<{dimension}{_vector_type_to_format[data_type]}", feature)
|
||||
|
||||
@property
|
||||
def indices(self):
|
||||
return self._indices
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
return self._values
|
||||
|
||||
def __dict__(self):
|
||||
return {"indices": self.indices, "values": self.values}
|
||||
|
||||
def get_python_type(self):
|
||||
return _reverse_vector_dtype_map[self]
|
||||
|
||||
|
||||
class FieldStrType(str, Enum):
|
||||
BOOL = "BOOL"
|
||||
STRING = "STRING"
|
||||
INT = "INT"
|
||||
FLOAT = "FLOAT"
|
||||
LONG = "LONG"
|
||||
|
||||
ARRAY_STRING = "ARRAY_STRING"
|
||||
ARRAY_INT = "ARRAY_INT"
|
||||
ARRAY_FLOAT = "ARRAY_FLOAT"
|
||||
ARRAY_LONG = "ARRAY_LONG"
|
||||
|
||||
class FieldType(IntEnum):
|
||||
BOOL = 0
|
||||
STRING = 1
|
||||
INT = 2
|
||||
FLOAT = 3
|
||||
LONG = 4
|
||||
|
||||
ARRAY_STRING = 11
|
||||
ARRAY_INT = 12
|
||||
ARRAY_FLOAT = 13
|
||||
ARRAY_LONG = 14
|
||||
|
||||
@staticmethod
|
||||
def get(ftype: Union[str, FieldStrType]) -> IntEnum:
|
||||
if ftype == FieldStrType.BOOL:
|
||||
return FieldType.BOOL
|
||||
elif ftype == FieldStrType.STRING:
|
||||
return FieldType.STRING
|
||||
elif ftype == FieldStrType.INT:
|
||||
return FieldType.INT
|
||||
elif ftype == FieldStrType.FLOAT:
|
||||
return FieldType.FLOAT
|
||||
elif ftype == FieldStrType.LONG:
|
||||
return FieldType.LONG
|
||||
elif ftype == FieldStrType.ARRAY_STRING:
|
||||
return FieldType.ARRAY_STRING
|
||||
elif ftype == FieldStrType.ARRAY_INT:
|
||||
return FieldType.ARRAY_INT
|
||||
elif ftype == FieldStrType.ARRAY_FLOAT:
|
||||
return FieldType.ARRAY_FLOAT
|
||||
elif ftype == FieldStrType.ARRAY_LONG:
|
||||
return FieldType.ARRAY_LONG
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidField,
|
||||
reason=f"DashVectorSDK does not support field value type {ftype} and must be in {supported_type_msg}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def str(ftype: Union[int, IntEnum]) -> str:
|
||||
if ftype == FieldType.BOOL:
|
||||
return FieldStrType.BOOL.value
|
||||
elif ftype == FieldType.STRING:
|
||||
return FieldStrType.STRING.value
|
||||
elif ftype == FieldType.INT:
|
||||
return FieldStrType.INT.value
|
||||
elif ftype == FieldType.FLOAT:
|
||||
return FieldStrType.FLOAT.value
|
||||
elif ftype == FieldType.LONG:
|
||||
return FieldStrType.LONG.value
|
||||
elif ftype == FieldType.ARRAY_STRING:
|
||||
return FieldStrType.ARRAY_STRING.value
|
||||
elif ftype == FieldType.ARRAY_INT:
|
||||
return FieldStrType.ARRAY_INT.value
|
||||
elif ftype == FieldType.ARRAY_FLOAT:
|
||||
return FieldStrType.ARRAY_FLOAT.value
|
||||
elif ftype == FieldType.ARRAY_LONG:
|
||||
return FieldStrType.ARRAY_LONG.value
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidField,
|
||||
reason=f"DashVectorSDK does not support field value type {ftype} and must be in {supported_type_msg}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_field_data_type(dtype: FieldSchemaType):
|
||||
if dtype not in _attr_dtype_map:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidField,
|
||||
reason=f"DashVectorSDK does not support field value type {dtype} and must be in {supported_type_msg}"
|
||||
)
|
||||
return _attr_dtype_map[dtype]
|
||||
|
||||
|
||||
class IndexStrType(str, Enum):
|
||||
UNDEFINED = "IT_UNDEFINED"
|
||||
HNSW = "IT_HNSW"
|
||||
INVERT = "IT_INVERT"
|
||||
|
||||
|
||||
class IndexType(IntEnum):
|
||||
UNDEFINED = 0
|
||||
HNSW = 1
|
||||
INVERT = 10
|
||||
|
||||
@staticmethod
|
||||
def get(itype: Union[str, IndexStrType]):
|
||||
if itype == IndexStrType.UNDEFINED:
|
||||
return IndexType.UNDEFINED
|
||||
elif itype == IndexStrType.HNSW:
|
||||
return IndexType.HNSW
|
||||
elif itype == IndexStrType.INVERT:
|
||||
return IndexType.INVERT
|
||||
else:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidIndexType, reason=f"DashVectorSDK does not support indextype {itype}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def str(itype: Union[int, IntEnum]) -> str:
|
||||
if itype == IndexType.UNDEFINED:
|
||||
return IndexStrType.UNDEFINED.value
|
||||
elif itype == IndexType.HNSW:
|
||||
return IndexStrType.HNSW.value
|
||||
elif itype == IndexType.INVERT:
|
||||
return IndexStrType.INVERT.value
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidIndexType, reason=f"DashVectorSDK does not support indextype {itype}"
|
||||
)
|
||||
|
||||
_vector_dtype_map = {
|
||||
float: VectorType.FLOAT,
|
||||
int: VectorType.INT,
|
||||
}
|
||||
_reverse_vector_dtype_map = {v: k for k,v in _vector_dtype_map.items()}
|
||||
|
||||
_vector_type_to_format = {
|
||||
VectorType.FLOAT: "f",
|
||||
VectorType.INT: "b",
|
||||
}
|
||||
|
||||
_attr_dtype_map = {
|
||||
str: FieldType.STRING,
|
||||
bool: FieldType.BOOL,
|
||||
int: FieldType.INT,
|
||||
float: FieldType.FLOAT,
|
||||
long: FieldType.LONG,
|
||||
"long": FieldType.LONG,
|
||||
List[str]: FieldType.ARRAY_STRING,
|
||||
List[int]: FieldType.ARRAY_INT,
|
||||
List[float]: FieldType.ARRAY_FLOAT,
|
||||
List[long]: FieldType.ARRAY_LONG,
|
||||
}
|
||||
|
||||
|
||||
class DashVectorResponse(object):
|
||||
def __init__(self, response: Optional[RPCResponse] = None, *, exception: Optional[DashVectorException] = None):
|
||||
self._code = DashVectorCode.Unknown
|
||||
self._message = ""
|
||||
self._request_id = ""
|
||||
self._output = None
|
||||
self._usage = None
|
||||
|
||||
self.__response = response
|
||||
self.__exception = exception
|
||||
|
||||
if self.__response is None:
|
||||
self._code = DashVectorCode.Success
|
||||
|
||||
if self.__response is not None and not self.__response.async_req:
|
||||
self.get()
|
||||
|
||||
if self.__exception is not None:
|
||||
self._code = self.__exception.code
|
||||
self._message = self.__exception.message
|
||||
self._request_id = self.__exception.request_id
|
||||
|
||||
def get(self):
|
||||
if self._code != DashVectorCode.Unknown:
|
||||
return self
|
||||
|
||||
if self.__response is None:
|
||||
return self
|
||||
|
||||
try:
|
||||
result = self.__response.get()
|
||||
self._request_id = result.request_id
|
||||
self._code = result.code
|
||||
self._message = result.message
|
||||
self._output = result.output
|
||||
self._usage = result.usage
|
||||
except DashVectorException as e:
|
||||
self._code = e.code
|
||||
self._message = e.message
|
||||
self._request_id = e.request_id
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
return self._code
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self._message
|
||||
|
||||
@property
|
||||
def request_id(self):
|
||||
return self._request_id
|
||||
|
||||
@property
|
||||
def output(self):
|
||||
return self._output
|
||||
|
||||
@output.setter
|
||||
def output(self, value: Any):
|
||||
self._output = value
|
||||
|
||||
@property
|
||||
def usage(self):
|
||||
return self._usage
|
||||
|
||||
@property
|
||||
def response(self):
|
||||
return self.__response
|
||||
|
||||
def _decorate_output(self):
|
||||
if self._output is None:
|
||||
return {"code": self.code, "message": self.message, "requests_id": self.request_id}
|
||||
elif isinstance(self._output, Status):
|
||||
return {
|
||||
"code": self.code,
|
||||
"message": self.message,
|
||||
"requests_id": self.request_id,
|
||||
"output": Status.str(self._output),
|
||||
}
|
||||
elif isinstance(self._output, (str, int, float)):
|
||||
return {
|
||||
"code": self.code,
|
||||
"message": self.message,
|
||||
"requests_id": self.request_id,
|
||||
"output": str(self._output),
|
||||
}
|
||||
elif isinstance(self._output, list):
|
||||
output_list = []
|
||||
for output_value in self._output:
|
||||
if isinstance(output_value, (str, int, float)):
|
||||
output_list.append(str(output_value))
|
||||
elif hasattr(output_value, "__dict__"):
|
||||
output_list.append(output_value.__dict__())
|
||||
elif hasattr(output_value, "__str__"):
|
||||
output_list.append(output_value.__str__())
|
||||
else:
|
||||
output_list.append(str(type(output_value)))
|
||||
return {"code": self.code, "message": self.message, "requests_id": self.request_id, "output": output_list}
|
||||
elif isinstance(self._output, dict):
|
||||
output_dict = {}
|
||||
for output_key, output_value in self._output.items():
|
||||
if isinstance(output_value, (str, int, float)):
|
||||
output_dict[output_key] = str(output_value)
|
||||
elif hasattr(output_value, "__dict__"):
|
||||
output_dict[output_key] = output_value.__dict__()
|
||||
elif hasattr(output_value, "__str__"):
|
||||
output_dict[output_key] = output_value.__str__()
|
||||
else:
|
||||
output_dict[output_key] = str(type(output_value))
|
||||
return {"code": self.code, "message": self.message, "requests_id": self.request_id, "output": output_dict}
|
||||
elif hasattr(self._output, "__dict__"):
|
||||
return {
|
||||
"code": self.code,
|
||||
"message": self.message,
|
||||
"requests_id": self.request_id,
|
||||
"output": self._output.__dict__(),
|
||||
}
|
||||
elif hasattr(self._output, "__str__"):
|
||||
return {
|
||||
"code": self.code,
|
||||
"message": self.message,
|
||||
"requests_id": self.request_id,
|
||||
"output": self._output.__str__(),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"code": self.code,
|
||||
"message": self.message,
|
||||
"requests_id": self.request_id,
|
||||
"output": str(type(self._output)),
|
||||
}
|
||||
|
||||
def __dict__(self):
|
||||
obj = self._decorate_output()
|
||||
if self._usage is not None:
|
||||
obj["usage"] = self._usage.__dict__()
|
||||
return obj
|
||||
|
||||
def __str__(self):
|
||||
return to_json_without_ascii(self.__dict__())
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __bool__(self):
|
||||
return self.code == DashVectorCode.Success
|
||||
|
||||
def __len__(self):
|
||||
return len(self._output)
|
||||
|
||||
def __iter__(self):
|
||||
return self._output.__iter__()
|
||||
|
||||
def __contains__(self, item):
|
||||
if hasattr(self._output, "__contains__"):
|
||||
return self.output.__contains__(item)
|
||||
else:
|
||||
raise TypeError(f"DashVectorSDK Get argument of type '{type(self.output)}' is not iterable")
|
||||
|
||||
def __getitem__(self, item):
|
||||
if hasattr(self._output, "__getitem__"):
|
||||
return self.output.__getitem__(item)
|
||||
else:
|
||||
raise TypeError(f"DashVectorSDK Get '{type(self.output)}' object is not subscriptable")
|
||||
|
||||
|
||||
class RequestUsage(object):
|
||||
read_units: int
|
||||
write_units: int
|
||||
|
||||
def __init__(self, *, read_units=None, write_units=None):
|
||||
self.read_units = read_units
|
||||
self.write_units = write_units
|
||||
|
||||
@staticmethod
|
||||
def from_pb(usage: dashvector_pb2.RequestUsage):
|
||||
if usage.HasField("read_units"):
|
||||
return RequestUsage(read_units=usage.read_units)
|
||||
elif usage.HasField("write_units"):
|
||||
return RequestUsage(write_units=usage.write_units)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(usage: dict):
|
||||
if "read_units" in usage:
|
||||
return RequestUsage(read_units=usage["read_units"])
|
||||
elif "write_units" in usage:
|
||||
return RequestUsage(write_units=usage["write_units"])
|
||||
|
||||
def __dict__(self):
|
||||
if self.read_units is None:
|
||||
if self.write_units is None:
|
||||
return {}
|
||||
else:
|
||||
return {"write_units": self.write_units}
|
||||
else:
|
||||
if self.write_units is None:
|
||||
return {"read_units": self.read_units}
|
||||
else:
|
||||
return {"read_units": self.read_units, "write_units": self.write_units}
|
||||
|
||||
def __str__(self):
|
||||
return json.dumps(self.__dict__())
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class VectorParam:
|
||||
def __init__(self,
|
||||
dimension: int = 0,
|
||||
dtype: Union[Type[int], Type[float]] = float,
|
||||
metric: str = "cosine",
|
||||
quantize_type: str = "",
|
||||
):
|
||||
"""
|
||||
Vector param.
|
||||
|
||||
Args:
|
||||
dimension (int): vector dimension in collection
|
||||
dtype (Union[Type[int], Type[float]]): vector data type in collection
|
||||
metric (str): vector metric in collection, support 'cosine', 'dotproduct' and 'euclidean', default to 'cosine'
|
||||
quantize_type (str): vector quantize type in collection, refer to https://help.aliyun.com/document_detail/2663745.html for latest support types
|
||||
"""
|
||||
|
||||
self._exception = None
|
||||
|
||||
"""
|
||||
dim: int
|
||||
"""
|
||||
if not isinstance(dimension, int):
|
||||
self._exception = DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK VectorParam dimension type({type(dimension)}) is invalid and must be int",
|
||||
)
|
||||
return
|
||||
self.dimension = dimension
|
||||
|
||||
"""
|
||||
metric: MetricType
|
||||
"""
|
||||
if not isinstance(metric, str):
|
||||
self._exception = DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK VectorParam metric Type({type(metric)}) is invalid and must be str",
|
||||
)
|
||||
return
|
||||
try:
|
||||
self._metric = MetricType.get(metric)
|
||||
except Exception as e:
|
||||
self._exception = e
|
||||
return
|
||||
|
||||
"""
|
||||
dtype: VectorType
|
||||
"""
|
||||
if dtype is not float and metric == "cosine":
|
||||
self._exception = DashVectorException(
|
||||
code=DashVectorCode.MismatchedDataType,
|
||||
reason=f"DashVectorSDK VectorParam dtype value({dtype}) is invalid and must be [float] when metric is cosine",
|
||||
)
|
||||
return
|
||||
try:
|
||||
self._dtype = VectorType.get_vector_data_type(dtype)
|
||||
except Exception as e:
|
||||
self._exception = e
|
||||
return
|
||||
|
||||
"""
|
||||
quantize_type: str
|
||||
"""
|
||||
if not isinstance(quantize_type, str):
|
||||
self._exception = DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK VectorParam quantize_type Type({type(quantize_type)}) is invalid and must be str",
|
||||
)
|
||||
return
|
||||
self.quantize_type = quantize_type
|
||||
|
||||
def validate(self):
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
@staticmethod
|
||||
def from_pb(pb: dashvector_pb2.CollectionInfo.VectorParam):
|
||||
if pb.dtype == VectorType.FLOAT:
|
||||
dtype = float
|
||||
elif pb.dtype == VectorType.INT:
|
||||
dtype = int
|
||||
else:
|
||||
raise DashVectorException(f"DashVectorSDK VectorParam dtype value({pb.dtype}) is invalid")
|
||||
return VectorParam(
|
||||
dimension=pb.dimension,
|
||||
dtype=dtype,
|
||||
metric=MetricType.str(pb.metric),
|
||||
quantize_type=pb.quantize_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict[str, Any]):
|
||||
dimension = d.get("dimension")
|
||||
dtype = VectorType.get(d.get("dtype")).get_python_type()
|
||||
metric = MetricType.str(MetricType.get(d.get("metric")))
|
||||
quantize_type = d.get("quantize_type")
|
||||
return VectorParam(
|
||||
dimension=dimension,
|
||||
dtype=dtype,
|
||||
metric=metric,
|
||||
quantize_type=quantize_type
|
||||
)
|
||||
|
||||
@property
|
||||
def metric(self):
|
||||
return MetricType.str(self._metric)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return VectorType.str(self._dtype)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"dimension": self.dimension,
|
||||
"dtype": self.dtype,
|
||||
"metric": self.metric,
|
||||
'quantize_type': self.quantize_type
|
||||
}
|
||||
|
||||
|
||||
class VectorQuery:
|
||||
def __init__(self,
|
||||
vector: VectorValueType,
|
||||
num_candidates: int = 0,
|
||||
is_linear: bool = False,
|
||||
ef: int = 0,
|
||||
radius: float = 0.0):
|
||||
"""
|
||||
A vector query.
|
||||
|
||||
vector (Optional[Union[List[Union[int, float, bool]], np.ndarray]]): query vector
|
||||
num_candidate (int): number of candidates for this vector query, default to collection.query.topk
|
||||
is_linear (bool): whether perform linear(brute-force) search, default to False
|
||||
ef (int): ef_search for HNSW-like algorithm, default to adaptive ef
|
||||
radius (float): perform radius nearest neighbor if radius is not 0.0,
|
||||
i.e. return docs with score <= radius for euclidean/cosine and score >= radius for dotproduct
|
||||
"""
|
||||
self.vector = vector
|
||||
self.num_candidates = num_candidates
|
||||
self.is_linear = is_linear
|
||||
self.ef = ef
|
||||
self.radius = radius
|
||||
|
||||
def validate(self):
|
||||
num_candidates = self.num_candidates
|
||||
is_linear = self.is_linear
|
||||
ef = self.ef
|
||||
radius = self.radius
|
||||
if not isinstance(num_candidates, int):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest topk type({type(num_candidates)}) is invalid and must be int",
|
||||
)
|
||||
if num_candidates < 0 or num_candidates > 1024:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidTopk,
|
||||
reason=f"DashVectorSDK GetDocRequest topk value({num_candidates}) is invalid and must be in [1, 1024]",
|
||||
)
|
||||
if not isinstance(is_linear, bool):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest ls_linear type({type(is_linear)}) is invalid and must be bool",
|
||||
)
|
||||
if not isinstance(ef, int):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest ef type({type(ef)}) is invalid and must be int",
|
||||
)
|
||||
if not (0 <= ef <= 4294967295):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest ef value({ef}) is invalid and must be in [0, 4294967295]",
|
||||
)
|
||||
if not isinstance(radius, float):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest radius type({type(radius)}) is invalid and must be float",
|
||||
)
|
||||
|
||||
class SparseVectorQuery:
|
||||
def __init__(self,
|
||||
sparse_vector: SparseValueType,
|
||||
num_candidates: int = 0,
|
||||
is_linear: bool = False,
|
||||
ef: int = 0,
|
||||
radius: float = 0.0):
|
||||
"""
|
||||
A sparse_vector query.
|
||||
|
||||
sparse_vector (Dict[int, float]): query sparse_vector
|
||||
num_candidate (int): number of candidates for this sparse_vector query, default to collection.query.topk
|
||||
is_linear (bool): whether perform linear(brute-force) search, default to False
|
||||
ef (int): ef_search for HNSW-like algorithm, default to adaptive ef
|
||||
radius (float): perform radius nearest neighbor if radius is not 0.0,
|
||||
i.e. return docs with score <= radius for euclidean/cosine and score >= radius for dotproduct
|
||||
"""
|
||||
self.vector = sparse_vector
|
||||
self.num_candidates = num_candidates
|
||||
self.is_linear = is_linear
|
||||
self.ef = ef
|
||||
self.radius = radius
|
||||
|
||||
def validate(self):
|
||||
num_candidates = self.num_candidates
|
||||
is_linear = self.is_linear
|
||||
ef = self.ef
|
||||
radius = self.radius
|
||||
if not isinstance(num_candidates, int):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest topk type({type(num_candidates)}) is invalid and must be int",
|
||||
)
|
||||
if num_candidates < 0 or num_candidates > 1024:
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidTopk,
|
||||
reason=f"DashVectorSDK GetDocRequest topk value({num_candidates}) is invalid and must be in [1, 1024]",
|
||||
)
|
||||
if not isinstance(is_linear, bool):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest ls_linear type({type(is_linear)}) is invalid and must be bool",
|
||||
)
|
||||
if not isinstance(ef, int):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest ef type({type(ef)}) is invalid and must be int",
|
||||
)
|
||||
if not (0 <= ef <= 4294967295):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest ef value({ef}) is invalid and must be in [0, 4294967295]",
|
||||
)
|
||||
if not isinstance(radius, float):
|
||||
raise DashVectorException(
|
||||
code=DashVectorCode.InvalidArgument,
|
||||
reason=f"DashVectorSDK QueryDocRequest radius type({type(radius)}) is invalid and must be float",
|
||||
)
|
||||
|
||||
class BaseRanker:
|
||||
pass
|
||||
|
||||
class RrfRanker(BaseRanker):
    """Reciprocal Rank Fusion ranker.

    Fuses result lists using RRF with the configured rank constant.
    """

    def __init__(self, rank_constant: int = 60):
        # Constant added to each rank before taking the reciprocal.
        self.rank_constant = rank_constant

    def to_pb(self):
        """Serialize this ranker into a protobuf Ranker message."""
        pb = dashvector_pb2.Ranker()
        pb.ranker_name = "rrf"
        pb.ranker_params["rank_constant"] = str(self.rank_constant)
        return pb

    def to_dict(self):
        """Serialize this ranker into a plain dict for the REST path."""
        params = {'rank_constant': str(self.rank_constant)}
        return {'ranker_name': "rrf", 'ranker_params': params}
|
||||
|
||||
class WeightedRanker(BaseRanker):
    """Weighted ranker: combines per-vector scores with caller-supplied weights."""

    def __init__(self, weights: Optional[Dict[str, float]] = None):
        # Mapping of vector name -> weight; None means server-side defaults.
        self.weights = weights

    def to_pb(self):
        """Serialize this ranker into a protobuf Ranker message."""
        pb = dashvector_pb2.Ranker()
        pb.ranker_name = "weighted"
        if self.weights is not None:
            # NOTE(review): `json` is not visibly imported in this module;
            # presumably it arrives via a wildcard import — confirm.
            pb.ranker_params["weights"] = json.dumps(self.weights)
        return pb

    def to_dict(self):
        """Serialize this ranker into a plain dict for the REST path."""
        result = {'ranker_name': "weighted"}
        if self.weights is not None:
            result['ranker_params'] = {'weights': json.dumps(self.weights)}
        return result
|
||||
235
venv/Lib/site-packages/dashvector/common/vector_validator.py
Normal file
235
venv/Lib/site-packages/dashvector/common/vector_validator.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dashvector.common.common_validator import *
|
||||
from dashvector.common.types import *
|
||||
from dashvector.common.error import DashVectorCode, DashVectorException
|
||||
from dashvector.core.models.collection_meta_status import CollectionMeta
|
||||
from dashvector.util.convertor import to_sorted_sparse_vector
|
||||
from dashvector.core.proto import dashvector_pb2
|
||||
from abc import abstractmethod, ABC
|
||||
|
||||
def convert_vector_query(vector_query, type: VectorType):
    """Convert a client-side VectorQuery's dense vector into a protobuf VectorQuery.

    Note: mutates ``vector_query.vector`` in place for ndarray inputs
    (replaced by packed bytes or a plain list), so callers observe the
    converted value afterwards.

    Args:
        vector_query: object whose ``.vector`` is a list, bytes, or np.ndarray.
        type: collection vector element type (parameter shadows the ``type``
            builtin — pre-existing naming kept for interface compatibility).

    Returns:
        dashvector_pb2.VectorQuery with either float_vector or byte_vector set;
        an unrecognized ``.vector`` type yields an empty message.
    """
    returned_vector_query = dashvector_pb2.VectorQuery()
    if isinstance(vector_query.vector, list):
        returned_vector_query.vector.float_vector.values.extend(vector_query.vector)
    elif isinstance(vector_query.vector, bytes):
        # Already packed by the caller; pass through untouched.
        returned_vector_query.vector.byte_vector = vector_query.vector
    elif isinstance(vector_query.vector, np.ndarray):
        if type == VectorType.INT:
            # Pack int ndarrays as little-endian raw bytes in the collection's
            # int format (e.g. "<b" for int8).
            data_format_type = VectorType.get_vector_data_format(type)
            vector_query.vector = np.ascontiguousarray(vector_query.vector, dtype=f"<{data_format_type}").tobytes()
            returned_vector_query.vector.byte_vector = vector_query.vector
        else:
            # Non-int ndarrays are flattened to a Python list of floats.
            vector_query.vector = list(vector_query.vector)
            returned_vector_query.vector.float_vector.values.extend(vector_query.vector)
    return returned_vector_query
|
||||
|
||||
def convert_vector_query_from_pb(vector_query: dashvector_pb2.VectorQuery):
    """Convert a protobuf VectorQuery back into a client-side VectorQuery.

    Returns:
        VectorQuery whose ``.vector`` is an ndarray (float case) or a list of
        ints (byte case).

    Raises:
        DashVectorException: when neither float_vector nor byte_vector is set.
    """
    if vector_query.vector.HasField("float_vector"):
        vector = np.array(vector_query.vector.float_vector.values)
        returned_vector_query = VectorQuery(vector=vector)
        return returned_vector_query
    elif vector_query.vector.HasField("byte_vector"):
        # NOTE(review): dtype is hard-coded to "b" (int8) — this assumes byte
        # vectors are always INT collections packed as int8; confirm against
        # VectorType.get_vector_data_format before relying on it for other dtypes.
        data_format_type = "b"
        vector = np.frombuffer(vector_query.vector.byte_vector, dtype=f"<{data_format_type}")
        returned_vector_query = VectorQuery(vector=vector.tolist())
        return returned_vector_query
    else:
        raise DashVectorException(
            code=DashVectorCode.InvalidArgument,
            reason=f"DashVectorSDK vector_query.vector type is invalid.")
|
||||
|
||||
class BasicVectorValidator(ABC):
    """Abstract base for vector validators bound to a collection's schema.

    Subclasses implement validation for collection creation and for query
    requests; the shared state is the collection meta used for dimension
    and dtype lookups.
    """

    def __init__(self, collection_meta: CollectionMeta):
        # Schema source consulted by subclasses for per-vector dimension/dtype.
        self._collection_meta = collection_meta

    @abstractmethod
    def validate_collection_vectors(
            self,
            vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            sparse_vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            *,
            dimension: int = 0,
            dtype: VectorDataType = None,
            metric: str = "cosine",
            doc_op: str):
        """Validate vector/sparse-vector params at collection-create time.

        Returns a (vectors, sparse_vectors) pair of normalized dicts.
        """
        pass

    @abstractmethod
    def validate_query_vectors(self, vector, top_k: int, query_request, doc_op: str):
        """Validate query vectors and fill them into the protobuf request."""
        pass
|
||||
|
||||
class ReserveVectorValidator(BasicVectorValidator):
    """Validator for collections that use the single reserved (unnamed) vector."""

    def __init__(self, collection_meta: CollectionMeta, vector_name: str = DASHVECTOR_VECTOR_NAME):
        # NOTE(review): when collection_meta is None (the
        # ValidatorFactory.input_type_create path), _collection_meta,
        # _dimension and _dtype are never set — only
        # validate_collection_vectors appears safe to call then; confirm.
        if collection_meta is not None:
            super().__init__(collection_meta)
            if(vector_name != DASHVECTOR_VECTOR_NAME):
                # Named vector: look up its own dimension/dtype in the schema.
                self._dimension = self._collection_meta.get_dimension(vector_name)
                self._dtype = VectorType.get(self._collection_meta.get_dtype(vector_name))
            else:
                # Reserved vector: the collection-level dimension/dtype apply.
                self._dimension = self._collection_meta.dimension
                self._dtype = VectorType.get(self._collection_meta.dtype)

    def validate_collection_vectors(
            self,
            vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            sparse_vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            *,
            dimension: int = 0,
            dtype: VectorDataType = None,
            metric: str = "cosine",
            doc_op: str):
        """Normalize the input into ``{"": VectorParam}`` and validate it.

        sparse_vectors must be None for single-vector collections.

        Returns:
            (vectors, {}) — the normalized dense-vector dict and an empty
            sparse-vector dict.

        Raises:
            DashVectorException: on type/range violations.
        """
        if sparse_vectors is not None:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} single vector sparse_vectors type({type(sparse_vectors)}) is invalid and must be None",
            )
        if vectors is None:
            # Fall back to the flat dimension/dtype/metric arguments.
            vectors = {"": VectorParam(dimension=dimension, dtype=dtype, metric=metric)}
        elif isinstance(vectors, VectorParam):
            vectors = {"": vectors}

        if not isinstance(vectors, dict):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} vectors type({type(vectors)}) is invalid and must be dict"
            )
        # NOTE(review): only the reserved "" entry is validated; a dict that
        # lacks the "" key raises KeyError here — confirm callers guarantee it.
        vectors[""].validate()
        if vectors[""].dimension <= 1 or vectors[""].dimension > 20000:
            raise DashVectorException(
                code=DashVectorCode.InvalidDimension,
                reason=f"DashVectorSDK VectorParam dimension value({vectors[''].dimension}) is invalid and must be in (1, 20000]",
            )
        return vectors, dict()

    def validate_query_vectors(self, vector, top_k: int, query_request, doc_op: str):
        """Validate a single dense query vector and copy it into query_request.

        Accepts a VectorQuery, list or ndarray; raw values are wrapped in a
        VectorQuery with ``num_candidates=top_k``.

        Returns:
            The same query_request object, with the vector filled in.

        Raises:
            DashVectorException: when the vector type is unsupported or fails
                dense-vector validation.
        """
        if isinstance(vector, VectorQuery):
            vector.validate()
        elif isinstance(vector, list) or isinstance(vector, np.ndarray):
            vector = VectorQuery(vector=vector, num_candidates=top_k)
            vector.validate()
        else:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} single vector type({type(vector)}) is invalid and must be [VectorQuery, VectorValueType]",
            )
        vector.vector = Validator.validate_dense_vector(vector.vector, self._dimension, self._dtype, doc_op)
        converted_vector_query = convert_vector_query(vector, self._dtype)
        returned_query_request = query_request
        if isinstance(returned_query_request, dashvector_pb2.QueryDocRequest):
            # Regular queries keep a name->VectorQuery map; the reserved name keys it.
            returned_query_request.vectors[DASHVECTOR_VECTOR_NAME].CopyFrom(converted_vector_query)
        elif isinstance(returned_query_request, dashvector_pb2.QueryDocGroupByRequest):
            # Group-by requests carry a single vector field rather than a map.
            returned_query_request.vector.CopyFrom(converted_vector_query.vector)
        return returned_query_request
|
||||
|
||||
class MultiVectorValidator(BasicVectorValidator):
    """Validator for collections with multiple named dense/sparse vectors.

    NOTE(review): this class defines no __init__, so construction with
    ``collection_meta=None`` (ValidatorFactory.input_type_create) sets
    ``_collection_meta = None`` and validate_query_vectors would fail on the
    schema lookups — confirm that path only uses validate_collection_vectors.
    """

    def validate_collection_vectors(
            self,
            vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            sparse_vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            *,
            dimension: int = 0,
            dtype: VectorDataType = None,
            metric: str = "cosine",
            doc_op: str):
        """Validate named dense and sparse vector params at create time.

        Returns:
            (vectors, sparse_vectors) — validated dicts; at most 4 vectors
            total across both.

        Raises:
            DashVectorException: on type/range/count violations.
        """
        if vectors is None and sparse_vectors is None:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} vectors and sparse_vectors are all empty",
            )
        if vectors is None:
            vectors = dict()

        # Vector names are validated first (only reachable when vectors is a dict).
        if isinstance(vectors, dict):
            for vector_name in vectors.keys():
                Validator.validate_vector_name(vector_name, doc_op)

        if not isinstance(vectors, dict):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} vectors type({type(vectors)}) is invalid and must be dict"
            )
        for vector_name, vector_param in vectors.items():
            if not isinstance(vector_param, VectorParam):
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} vector_param type({type(vector_param)}) is invalid and must be VectorParam",
                )
            vector_param.validate()
            if vector_param.dimension <= 1 or vector_param.dimension > 20000:
                raise DashVectorException(
                    code=DashVectorCode.InvalidDimension,
                    reason=f"DashVectorSDK VectorParam dimension value({vector_param.dimension}) is invalid and must be in (1, 20000]",
                )
        sparse_vectors = Validator.validate_sparse_vectors(sparse_vectors, doc_op)
        # Server-side limit: at most 4 vector fields per collection.
        if len(vectors) + len(sparse_vectors) > 4:
            raise DashVectorException(
                code=DashVectorCode.InvalidField,
                reason=f"DashVectorSDK {doc_op} vectors length({len(vectors) + len(sparse_vectors)}) is invalid and must be in [0, 4]",
            )
        return vectors, sparse_vectors

    def validate_query_vectors(self, vector, top_k: int, query_request, doc_op: str):
        """Validate a dict of named query vectors and fill them into the request.

        Each value may be a VectorQuery, list or ndarray; raw values are
        wrapped in a default VectorQuery. Dimension/dtype are looked up per
        vector name from the collection schema.

        Returns:
            The same query_request object, with all vectors filled in.

        Raises:
            DashVectorException: when ``vector`` is not a dict or a value fails
                dense-vector validation.
        """
        vector_queries = None
        if isinstance(vector, dict):
            vector_queries = vector
        else:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} multi vector type({type(vector)}) is invalid and must be dict",
            )

        returned_query_request = query_request
        for vector_name, vector_query in vector_queries.items():
            if isinstance(vector_query, list) or isinstance(vector_query, np.ndarray):
                # NOTE(review): unlike the single-vector path, top_k is not
                # propagated into num_candidates here — confirm intended.
                vector_query = VectorQuery(vector=vector_query)
            vector_query.validate()
            vector_query.vector = Validator.validate_dense_vector(vector_query.vector, self._collection_meta.get_dimension(vector_name),
                                                                  VectorType.get(self._collection_meta.get_dtype(vector_name)), doc_op)
            converted_vector_query = convert_vector_query(vector_query, VectorType.get(self._collection_meta.get_dtype(vector_name)))
            returned_query_request.vectors[vector_name].CopyFrom(converted_vector_query)

        return returned_query_request
|
||||
|
||||
class ValidatorFactory:
    """Chooses between the reserved single-vector and multi-vector validators."""

    @staticmethod
    def meta_create(collection_meta) -> BasicVectorValidator:
        """Pick a validator from an existing collection's vector schema."""
        schema = collection_meta.vectors_schema
        # A schema holding only the reserved (unnamed) vector means the
        # legacy single-vector layout.
        is_reserved = len(schema) == 1 and DASHVECTOR_VECTOR_NAME in schema.keys()
        if is_reserved:
            return ReserveVectorValidator(collection_meta)
        return MultiVectorValidator(collection_meta)

    @staticmethod
    def input_type_create(vectors, sparse_vectors) -> BasicVectorValidator:
        """Pick a validator from the shape of caller-supplied create arguments."""
        wants_multi = isinstance(vectors, dict) or isinstance(sparse_vectors, dict)
        if wants_multi:
            return MultiVectorValidator(collection_meta=None)
        return ReserveVectorValidator(collection_meta=None)
|
||||
|
||||
class SparseVectorChecker:
    """Validates a SparseVectorQuery and builds its protobuf representation.

    On success, ``self.vector_query`` holds a populated
    dashvector_pb2.SparseVectorQuery (sorted sparse values plus search params).
    """

    def __init__(self, collection_meta: CollectionMeta, vector_name: str, vector_query: SparseVectorQuery, doc_op: str):
        # Empty vector_name means the reserved single-vector collection layout.
        if not vector_name:
            self._dtype = VectorType.get(collection_meta.dtype)
            self._dimension = collection_meta.dimension
        else:
            self._dtype = VectorType.get(collection_meta.get_dtype(vector_name))
            self._dimension = collection_meta.get_dimension(vector_name)

        # Validates num_candidates/is_linear/ef/radius; raises on violation.
        vector_query.validate()
        self._sparse_vector = vector_query
        self.vector_query = dashvector_pb2.SparseVectorQuery()
        # Sparse values must be key-sorted before serialization.
        for key, value in to_sorted_sparse_vector(self._sparse_vector.vector).items():
            self.vector_query.sparse_vector.sparse_vector[key] = value
        if len(self.vector_query.sparse_vector.sparse_vector) == 0:
            raise DashVectorException(
                code=DashVectorCode.InvalidSparseValues,
                reason=f"DashVectorSDK {doc_op} not supports query with empty sparse_vector",
            )

        # Copy the search parameters alongside the sparse values.
        param = self.vector_query.param
        param.num_candidates = vector_query.num_candidates
        param.ef = vector_query.ef
        param.is_linear = vector_query.is_linear
        param.radius = vector_query.radius
|
||||
Reference in New Issue
Block a user