This commit is contained in:
“shengyudong”
2026-01-06 14:18:39 +08:00
commit 5a384b694e
10345 changed files with 2050918 additions and 0 deletions

View File

@@ -0,0 +1,18 @@
##
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,414 @@
# -*- coding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple, Type, Union, NewType
import re
import numpy as np
from dashvector.common.constants import *
from dashvector.common.types import *
from dashvector.common.error import DashVectorCode, DashVectorException
from dashvector.core.doc import Doc
from dashvector.core.models.collection_meta_status import CollectionMeta
from dashvector.util.convertor import to_sorted_sparse_vector, to_sorted_sparse_vectors
from dashvector.core.proto import dashvector_pb2
class Validator():
    """Client-side argument validators used by the DashVector SDK.

    Every validator is a static method that raises ``DashVectorException`` on
    the first problem it finds. ``doc_op`` is a caller-supplied label that is
    only interpolated into the error message.
    """

    @staticmethod
    def validate_dense_vector(vector, dimension: int, dtype: VectorDataType, doc_op: str):
        """Check a dense vector against the collection's dimension and dtype.

        A ``list`` must have exactly ``dimension`` items and its first element
        must map to ``dtype``; INT lists are converted with
        ``VectorType.convert_to_bytes``. A ``numpy.ndarray`` must be
        one-dimensional with ``shape[0] == dimension`` and is returned as is.

        Returns:
            The (possibly converted) vector.

        Raises:
            DashVectorException: on any mismatch or unsupported container type.
        """
        if isinstance(vector, list):
            if len(vector) != dimension:
                raise DashVectorException(
                    code=DashVectorCode.MismatchedDimension,
                    reason=f"DashVectorSDK {doc_op} vector list length({len(vector)}) is invalid and must be same with collection dimension({dimension})",
                )
            # Only the first element is inspected to determine the list's type.
            vector_data_type = VectorType.get_vector_data_type(type(vector[0]))
            if vector_data_type != dtype:
                raise DashVectorException(
                    code=DashVectorCode.MismatchedDataType,
                    reason=f"DashVectorSDK {doc_op} vector type({type(vector[0])}) is invalid and must be {VectorType.get_python_type(dtype)}",
                )
            if vector_data_type == VectorType.INT:
                try:
                    vector = VectorType.convert_to_bytes(vector, dtype, dimension)
                except Exception as e:
                    # Keep the original failure as __cause__ for debugging.
                    raise DashVectorException(
                        code=DashVectorCode.InvalidVectorFormat,
                        reason=f"DashVectorSDK {doc_op} vector value({vector}) is invalid and int value must be in [-128, 127]",
                    ) from e
        elif isinstance(vector, np.ndarray):
            if vector.ndim != 1:
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} vector numpy dimension({vector.ndim}) is invalid and must be 1",
                )
            if vector.shape[0] != dimension:
                raise DashVectorException(
                    code=DashVectorCode.MismatchedDimension,
                    reason=f"DashVectorSDK {doc_op} vector numpy shape[0]({vector.shape[0]}) is invalid and must be same with collection dimension({dimension})",
                )
        else:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} vector type({type(vector)}) is invalid and must be [list, numpy.ndarray]",
            )
        return vector

    @staticmethod
    def validate_collection_name(name: str, doc_op: str):
        """Return ``name`` if it is a str matching the collection name pattern.

        Raises:
            DashVectorException: if ``name`` is not a str or does not match
                ``COLLECTION_AND_PARTITION_NAME_PATTERN``.
        """
        if not isinstance(name, str):
            # Report the offending type, as the other validators do.
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} name type({type(name)}) is invalid and must be str",
            )
        if re.search(COLLECTION_AND_PARTITION_NAME_PATTERN, name) is None:
            raise DashVectorException(
                code=DashVectorCode.InvalidCollectionName,
                reason=f"DashVectorSDK {doc_op} name characters({name}) is invalid and "
                + COLLECTION_AND_PARTITION_NAME_PATTERN_MSG,
            )
        return name

    @staticmethod
    def validate_partition_name(partition_name: str, doc_op: str):
        """Return ``partition_name`` if it is a str matching the name pattern.

        Raises:
            DashVectorException: if ``partition_name`` is not a str or does
                not match ``COLLECTION_AND_PARTITION_NAME_PATTERN``.
        """
        if not isinstance(partition_name, str):
            # Report the offending type, as the other validators do.
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} partition name type({type(partition_name)}) is invalid and must be str",
            )
        if re.search(COLLECTION_AND_PARTITION_NAME_PATTERN, partition_name) is None:
            raise DashVectorException(
                code=DashVectorCode.InvalidPartitionName,
                reason=f"DashVectorSDK {doc_op} partition characters({partition_name}) is invalid and "
                + COLLECTION_AND_PARTITION_NAME_PATTERN_MSG,
            )
        return partition_name

    @staticmethod
    def validate_vector_name(vector_name: str, doc_op: str):
        """Return ``vector_name`` if it is a str matching ``FIELD_NAME_PATTERN``.

        Raises:
            DashVectorException: if ``vector_name`` is not a str or does not
                match the pattern.
        """
        if not isinstance(vector_name, str):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} vector_name type({type(vector_name)}) is invalid and must be str",
            )
        if re.search(FIELD_NAME_PATTERN, vector_name) is None:
            raise DashVectorException(
                code=DashVectorCode.InvalidFieldName,
                reason=f"DashVectorSDK {doc_op} vector_name characters({vector_name}) is invalid and "
                + FIELD_NAME_PATTERN_MSG,
            )
        # Returned for parity with the sibling name validators.
        return vector_name

    @staticmethod
    def validate_sparse_vectors(sparse_vectors: Union[None, VectorParam, Dict[str, VectorParam]], doc_op: str):
        """Validate a mapping of sparse-vector names to ``VectorParam``.

        ``None`` is treated as an empty mapping. Each key must be a valid
        vector name; each value must be a ``VectorParam`` with dimension 0
        and an empty ``quantize_type``.

        Returns:
            The validated dict (a new empty dict when ``None`` was passed).

        Raises:
            DashVectorException: on a non-dict argument or an invalid entry.
        """
        if sparse_vectors is None:
            sparse_vectors = dict()
        if not isinstance(sparse_vectors, dict):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} sparse_vectors type({type(sparse_vectors)}) is invalid and must be dict"
            )
        # All names are checked before any parameter is inspected.
        for vector_name in sparse_vectors.keys():
            Validator.validate_vector_name(vector_name, doc_op)
        for vector_name, vector_param in sparse_vectors.items():
            if not isinstance(vector_param, VectorParam):
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} vector_param type({type(vector_param)}) is invalid and must be VectorParam",
                )
            vector_param.validate()
            if vector_param.dimension != 0:
                raise DashVectorException(
                    code=DashVectorCode.InvalidDimension,
                    reason=f"DashVectorSDK VectorParam dimension value({vector_param.dimension}) for sparse vector is invalid and must be 0",
                )
            if len(vector_param.quantize_type) > 0:
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} vector_name({vector_name}), quantize_type({vector_param.quantize_type}) for sparse vector is invalid and must be empty",
                )
        return sparse_vectors

    @staticmethod
    def validate_fields_schema(fields_schema: Optional[FieldSchemaDict], doc_op: str):
        """Validate a field schema and map each declared type via ``FieldType``.

        Returns:
            A new dict of field name -> ``FieldType.get_field_data_type(dtype)``;
            empty when ``fields_schema`` is ``None``.

        Raises:
            DashVectorException: on a non-dict schema, more than 1024 fields,
                an invalid field name, or the reserved vector field name.
        """
        returned_fields_schema = dict()
        if fields_schema is None:
            return returned_fields_schema
        if not isinstance(fields_schema, dict):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} fields_schema type({type(fields_schema)}) is invalid and must be dict",
            )
        if len(fields_schema) > 1024:
            raise DashVectorException(
                code=DashVectorCode.InvalidField,
                reason=f"DashVectorSDK {doc_op} fields_schema length({len(fields_schema)}) is invalid and must be in [0, 1024]",
            )
        for field_name, field_dtype in fields_schema.items():
            if not isinstance(field_name, str):
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} field_name in fields_schema type({type(field_name)}) is invalid and must be str",
                )
            if re.search(FIELD_NAME_PATTERN, field_name) is None:
                raise DashVectorException(
                    code=DashVectorCode.InvalidFieldName,
                    reason=f"DashVectorSDK {doc_op} field_name in fields_schema characters({field_name}) is invalid and "
                    + FIELD_NAME_PATTERN_MSG,
                )
            if field_name == DASHVECTOR_VECTOR_NAME:
                raise DashVectorException(
                    code=DashVectorCode.InvalidFieldName,
                    reason=f"DashVectorSDK {doc_op} field_name in fields_schema value({DASHVECTOR_VECTOR_NAME}) is reserved",
                )
            returned_fields_schema[field_name] = FieldType.get_field_data_type(field_dtype)
        return returned_fields_schema

    @staticmethod
    def validate_extra_params(extra_params: Optional[Dict[str, Any]], doc_op: str):
        """Validate that ``extra_params`` maps non-empty str keys to str values.

        Returns:
            ``extra_params`` itself when it has entries, otherwise a new
            empty dict (also for ``None``).

        Raises:
            DashVectorException: on a non-dict argument, a non-str key or
                value, or an empty key.
        """
        if extra_params is None:
            return dict()
        if not isinstance(extra_params, dict):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} extra_params type({type(extra_params)}) is invalid and must be dict",
            )
        for extra_param_key, extra_param_value in extra_params.items():
            if not isinstance(extra_param_key, str) or not isinstance(extra_param_value, str):
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} extra_param key/value type is invalid and must be str.",
                )
            if len(extra_param_key) <= 0:
                raise DashVectorException(
                    code=DashVectorCode.InvalidExtraParam,
                    reason=f"DashVectorSDK {doc_op} extra_param key is empty",
                )
        return extra_params if len(extra_params) > 0 else dict()

    @staticmethod
    def validate_doc_ids(ids: IdsType, doc_op: str):
        """Validate a single doc id or a list of 1..1024 doc ids.

        Returns:
            ``(id_list, is_single)`` where ``is_single`` is True only when a
            bare str was passed.

        Raises:
            DashVectorException: on a wrong type, a list length outside
                [1, 1024], or an id not matching ``DOC_ID_PATTERN``.
        """
        returned_ids = list()
        returned_ids_is_single = False
        if isinstance(ids, list):
            if len(ids) < 1 or len(ids) > 1024:
                raise DashVectorException(
                    code=DashVectorCode.ExceedIdsLimit,
                    reason=f"DashVectorSDK {doc_op} ids list length({len(ids)}) is invalid and must be in [1, 1024]",
                )
            for id in ids:
                if not isinstance(id, str):
                    raise DashVectorException(
                        code=DashVectorCode.InvalidArgument,
                        reason=f"DashVectorSDK {doc_op} id in ids list type({type(id)}) is invalid and must be str",
                    )
                if re.search(DOC_ID_PATTERN, id) is None:
                    raise DashVectorException(
                        code=DashVectorCode.InvalidPrimaryKey,
                        reason=f"DashVectorSDK {doc_op} id in ids list characters({id}) is invalid and "
                        + DOC_ID_PATTERN_MSG,
                    )
                returned_ids.append(id)
        elif isinstance(ids, str):
            if re.search(DOC_ID_PATTERN, ids) is None:
                raise DashVectorException(
                    code=DashVectorCode.InvalidPrimaryKey,
                    reason=f"DashVectorSDK {doc_op} ids str characters({ids}) is invalid and "
                    + DOC_ID_PATTERN_MSG,
                )
            returned_ids.append(ids)
            returned_ids_is_single = True
        else:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} ids type({type(ids)}) is invalid and must be [str, List[str]]",
            )
        return returned_ids, returned_ids_is_single

    @staticmethod
    def validate_id(id: str, doc_op: str):
        """Return ``id`` if it is a str matching ``DOC_ID_PATTERN``.

        Raises:
            DashVectorException: if ``id`` is not a str or does not match.
        """
        if not isinstance(id, str):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} expect id to be <str> but actual type is ({type(id)})",
            )
        if re.search(DOC_ID_PATTERN, id) is None:
            raise DashVectorException(
                code=DashVectorCode.InvalidPrimaryKey,
                reason=f"DashVectorSDK {doc_op} id str characters({id}) is invalid and "
                + DOC_ID_PATTERN_MSG,
            )
        return id

    @staticmethod
    def validate_output_fields(output_fields: Optional[List[str]], doc_op: str):
        """Validate a list of output field names.

        Returns:
            ``output_fields`` itself when given, otherwise a new empty list.

        Raises:
            DashVectorException: on a non-list argument, a non-str item, or
                an item not matching ``FIELD_NAME_PATTERN``.
        """
        if output_fields is None:
            return list()
        if not isinstance(output_fields, list):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} output_fields type({type(output_fields)}) is invalid and must be List[str]",
            )
        for output_field in output_fields:
            if not isinstance(output_field, str):
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} output_field in output_fields type({type(output_field)}) is invalid and must be list[str]",
                )
            if re.search(FIELD_NAME_PATTERN, output_field) is None:
                raise DashVectorException(
                    code=DashVectorCode.InvalidField,
                    reason=f"DashVectorSDK {doc_op} output_field in output_fields characters({output_field}) is invalid and "
                    + FIELD_NAME_PATTERN_MSG,
                )
        return output_fields

    @staticmethod
    def validate_rerank(rerank: BaseRanker, query_request: dashvector_pb2.QueryDocRequest, doc_op: str):
        """Validate a ranker against the vectors present in ``query_request``.

        A ``WeightedRanker`` with non-empty weights must name exactly the
        dense and sparse vector keys of the request. An ``RrfRanker`` must
        carry an int ``rank_constant`` in ``[0, 2**31)``.

        Returns:
            ``rerank`` unchanged.

        Raises:
            DashVectorException: on a mismatch or an unsupported ranker type.
        """
        if isinstance(rerank, WeightedRanker):
            weight_keys = sorted(rerank.weights.keys())
            query_vectors_keys = sorted(
                [*query_request.vectors.keys(), *query_request.sparse_vectors.keys()]
            )
            # Empty weights are accepted without comparing against the request.
            if len(weight_keys) > 0 and weight_keys != query_vectors_keys:
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} expect WeightedRanker.weights({rerank.weights}) to exactly match all vector names({query_vectors_keys})"
                )
        elif isinstance(rerank, RrfRanker):
            rank_constant = rerank.rank_constant
            if not isinstance(rank_constant, int) or rank_constant < 0 or rank_constant >= 2**31:
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} expect RrfRanker.rank_constant({rank_constant}) to be positive int32"
                )
        else:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} expect rerank type to be WeightedRanker or RrfRanker, actual type({type(rerank)})"
            )
        return rerank

    @staticmethod
    def validate_include_vector(include_vector: bool, doc_op: str):
        """Return ``include_vector`` if it is a bool.

        Raises:
            DashVectorException: if ``include_vector`` is not a bool.
        """
        if not isinstance(include_vector, bool):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} include_vector type({type(include_vector)}) is invalid and must be bool",
            )
        return include_vector

    @staticmethod
    def validate_topk(topk: int, include_vector: bool, doc_op: str):
        """Return ``topk`` if it is an int >= 1.

        The upper bound of 1024 is only enforced when ``include_vector`` is
        True.

        Raises:
            DashVectorException: on a wrong type for either argument or an
                out-of-range ``topk``.
        """
        include_vector = Validator.validate_include_vector(include_vector, doc_op)
        if not isinstance(topk, int):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} topk type({type(topk)}) is invalid and must be int",
            )
        if topk < 1 or (include_vector and topk > 1024):
            raise DashVectorException(
                code=DashVectorCode.InvalidTopk,
                reason=f"DashVectorSDK {doc_op} topk value({topk}) is invalid and must be in [1, 1024] "
                f"when include_vector is True",
            )
        return topk

    @staticmethod
    def validate_filter(filter: str, doc_op: str):
        """Validate a filter expression string.

        Returns:
            ``filter`` when it is a non-empty str; ``None`` for ``None`` or
            an empty string.

        Raises:
            DashVectorException: if ``filter`` is not a str or is longer
                than 40960 characters.
        """
        if filter is None:
            return None
        if not isinstance(filter, str):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} filter type({type(filter)}) is invalid and must be str",
            )
        if len(filter) > 40960:
            raise DashVectorException(
                code=DashVectorCode.InvalidFilter,
                reason=f"DashVectorSDK {doc_op} filter length({len(filter)}) is invalid and must be in [0, 40960]",
            )
        return filter if len(filter) > 0 else None

    @staticmethod
    def validate_doc(doc: Doc, meta: CollectionMeta, action: str, doc_op: str):
        """Validate a ``Doc`` against collection metadata and normalise it.

        Checks the id (required when ``action == "update"``), the dense
        vector(s) against ``meta``, that some vector is present for
        non-update actions, and the type and size of ``doc.fields``.

        Returns:
            A new ``Doc`` with validated dense vectors and sparse vectors
            passed through ``to_sorted_sparse_vector(s)``.

        Raises:
            DashVectorException: on the first failed check.
        """
        if doc.id is not None:
            if not isinstance(doc.id, str):
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} id type({type(doc.id)}) is invalid and must be str",
                )
            if re.search(DOC_ID_PATTERN, doc.id) is None:
                raise DashVectorException(
                    code=DashVectorCode.InvalidPrimaryKey,
                    reason=f"DashVectorSDK {doc_op} id characters({doc.id}) is invalid and "
                    + DOC_ID_PATTERN_MSG,
                )
        elif action == "update":
            raise DashVectorException(
                code=DashVectorCode.InvalidPrimaryKey,
                reason=f"DashVectorSDK {doc_op} id({doc.id}) is required when the action is update",
            )
        returned_id = doc.id

        returned_vector = None
        if doc.vector is not None:
            returned_vector = Validator.validate_dense_vector(
                doc.vector, meta.get_dimension(), VectorType.get(meta.get_dtype()), doc_op
            )
        returned_vectors = dict()
        if doc.vectors is not None:
            for key, value in doc.vectors.items():
                returned_vectors[key] = Validator.validate_dense_vector(
                    value,
                    meta.get_dimension(vector_name=key),
                    VectorType.get(meta.get_dtype(vector_name=key)),
                    doc_op,
                )
        if action != "update" and doc.vector is None and doc.vectors is None and doc.sparse_vectors is None:
            raise DashVectorException(
                code=DashVectorCode.InvalidVectorFormat,
                reason=f"DashVectorSDK {doc_op} vector is required and must be in [list, numpy.ndarray] when request in [insert, upsert]",
            )

        # check fields
        if doc.fields is not None:
            if not isinstance(doc.fields, dict):
                raise DashVectorException(
                    code=DashVectorCode.InvalidField,
                    reason=f"DashVectorSDK {doc_op} fields type({type(doc.fields)}) is invalid",
                )
            if len(doc.fields) > 1024:
                # An empty dict is accepted, so the valid range starts at 0.
                raise DashVectorException(
                    code=DashVectorCode.InvalidField,
                    reason=f"DashVectorSDK {doc_op} fields length({len(doc.fields)}) is invalid and must be in [0, 1024]",
                )

        returned_sparse_vector = to_sorted_sparse_vector(doc.sparse_vector)
        returned_sparse_vectors = to_sorted_sparse_vectors(doc.sparse_vectors)
        return Doc(
            id=returned_id,
            vector=returned_vector,
            vectors=returned_vectors,
            fields=doc.fields,
            sparse_vector=returned_sparse_vector,
            sparse_vectors=returned_sparse_vectors,
        )

View File

@@ -0,0 +1,31 @@
##
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##
# -*- coding: utf-8 -*-
# Field name that user schemas may not use (rejected as reserved by validation).
DASHVECTOR_VECTOR_NAME = "proxima_vector"

# Environment variable read by the SDK logger setup.
DASHVECTOR_LOGGING_LEVEL_ENV = "DASHVECTOR_LOGGING_LEVEL"

# Field / vector names: 1-32 characters from [a-zA-Z0-9_-].
FIELD_NAME_PATTERN = r"^[a-zA-Z0-9_-]{1,32}$"
FIELD_NAME_PATTERN_MSG = (
    "character must be in [a-zA-Z0-9] and symbols[_, -] and length must be in [1,32]"
)

# Collection / partition names: 3-32 characters from [a-zA-Z0-9_-].
COLLECTION_AND_PARTITION_NAME_PATTERN = r"^[a-zA-Z0-9_-]{3,32}$"
COLLECTION_AND_PARTITION_NAME_PATTERN_MSG = "character must be in [a-zA-Z0-9] and symbols[_, -] and length must be in [3,32]"

# Document ids: 1-64 characters from [a-zA-Z0-9] plus _-!@#$%+=.
DOC_ID_PATTERN = r"^[a-zA-Z0-9_\-!@#$%+=.]{1,64}$"
DOC_ID_PATTERN_MSG = (
    "character must be in [a-zA-Z0-9] and symbols[_-!@#$%+=.] and length must be in [1, 64]"
)

# 128 MiB.
GRPC_MAX_MSG_SIZE = 128 << 20

# Signed 32-bit integer bounds.
MAX_INT_VALUE = (1 << 31) - 1
MIN_INT_VALUE = -(1 << 31)

View File

@@ -0,0 +1,133 @@
##
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##
# -*- coding: utf-8 -*-
import http
from enum import IntEnum
import grpc
class DashVectorCode(IntEnum):
    """Integer status codes carried by SDK responses and exceptions.

    ``Success`` is 0; every other member is a distinct error code.
    """

    Timeout = 408
    Success = 0
    Closed = -998
    Unknown = -999
    EmptyCollectionName = -2000
    EmptyColumnName = -2001
    EmptyPartitionName = -2002
    EmptyColumns = -2003
    EmptyPrimaryKey = -2004
    EmptyDocList = -2005
    EmptyDocFields = -2006
    EmptyIndexField = -2007
    InvalidRecord = -2008
    InvalidQuery = -2009
    InvalidWriteRequest = -2010
    InvalidVectorFormat = -2011
    InvalidDataType = -2012
    InvalidIndexType = -2013
    InvalidFeature = -2014
    InvalidFilter = -2015
    InvalidPrimaryKey = -2016
    InvalidField = -2017
    MismatchedIndexColumn = -2018
    MismatchedDimension = -2019
    MismatchedDataType = -2020
    InexistentCollection = -2021
    InexistentPartition = -2022
    InexistentColumn = -2023
    InexistentKey = -2024
    DuplicateCollection = -2025
    DuplicatePartition = -2026
    DuplicateKey = -2027
    DuplicateField = -2028
    UnreadyPartition = -2029
    UnreadyCollection = -2030
    UnsupportedCondition = -2031
    OrderbyNotInSelectItems = -2032
    PbToSqlInfoError = -2033
    ExceedRateLimit = -2034
    InvalidSparseValues = -2035
    InvalidBatchSize = -2036
    InvalidDimension = -2037
    InvalidExtraParam = -2038
    InvalidRadius = -2039
    InvalidLinear = -2040
    InvalidTopk = -2041
    InvalidCollectionName = -2042
    InvalidPartitionName = -2043
    InvalidFieldName = -2044
    InvalidChannelCount = -2045
    InvalidReplicaCount = -2046
    InvalidJson = -2047
    # Plain int: a trailing comma here would assign the tuple (-2053,), which
    # only works because IntEnum unpacks tuple values into int().
    InvalidGroupBy = -2053
    InvalidSparseIndices = -2951
    InvalidEndpoint = -2952
    ExceedIdsLimit = -2967
    InvalidVectorType = -2968
    ExceedRequestSize = -2970
    ExistVectorAndId = -2973
    InvalidArgument = -2999
class DashVectorException(Exception):
    """Exception raised by the DashVector SDK.

    Carries a status ``code``, a human-readable ``message`` and an optional
    ``request_id``. The text passed to ``Exception`` is ``"<reason>(<code>)"``.
    """

    def __init__(self, code=DashVectorCode.Unknown, reason=None, request_id=None):
        if reason is None:
            reason = "DashVectorSDK unknown exception"
        self._code = code
        self._reason = reason
        self._request_id = request_id
        super().__init__(f"{self._reason}({self._code})")

    @property
    def code(self):
        """Status code given at construction."""
        return self._code

    @property
    def message(self):
        """Reason text, or the default text when none was given."""
        return self._reason

    @property
    def request_id(self):
        """Request id, or an empty string when none was given."""
        return "" if self._request_id is None else self._request_id
class DashVectorHTTPException(DashVectorException):
    """Factory that builds a ``DashVectorException`` from an HTTP status.

    ``__new__`` returns a plain ``DashVectorException``, never an instance of
    this subclass, so ``except DashVectorHTTPException`` will not match the
    object it produces.
    """

    def __new__(cls, code, reason=None, request_id=None):
        if isinstance(code, http.HTTPStatus):
            # For an HTTPStatus, the numeric value and standard phrase
            # replace whatever reason the caller passed.
            return DashVectorException(
                code=code.value,
                reason=f"DashVectorSDK http rpc error: {code.phrase}",
                request_id=request_id,
            )
        return DashVectorException(code=code, reason=reason, request_id=request_id)
class DashVectorGRPCException(DashVectorException):
    """Factory that builds a ``DashVectorException`` from a gRPC status.

    ``__new__`` returns a plain ``DashVectorException``, never an instance of
    this subclass, so ``except DashVectorGRPCException`` will not match the
    object it produces.
    """

    def __new__(cls, code, reason=None, request_id=None):
        if isinstance(code, grpc.StatusCode):
            # For a StatusCode, value[0] is used as the code and value[1] as
            # the description; the caller's reason is ignored.
            status_value = code.value
            return DashVectorException(
                code=status_value[0],
                reason=f"DashVectorSDK grpc rpc error: {status_value[1]}",
                request_id=request_id,
            )
        return DashVectorException(code=code, reason=reason, request_id=request_id)

View File

@@ -0,0 +1,212 @@
##
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##
# -*- coding: utf-8 -*-
import os
import platform
import re
from abc import abstractmethod
from google.protobuf.json_format import MessageToJson
from google.protobuf.message import Message
from dashvector.common.error import DashVectorCode, DashVectorException
from dashvector.util.validator import verify_endpoint
from dashvector.version import __version__
class RPCRequest(object):
    """Wrap a protobuf request together with its serialized form.

    The message is serialized once in the constructor; requests whose
    serialized size exceeds 2 MiB are rejected.
    """

    def __init__(self, *, request: Message):
        serialized = request.SerializeToString()
        max_request_bytes = 2 * 1024 * 1024

        self.request = request
        self.request_str = serialized
        self.request_len = len(serialized)
        if self.request_len > max_request_bytes:
            raise DashVectorException(
                code=DashVectorCode.ExceedRequestSize,
                reason=f"DashVectorSDK request length({self.request_len}) exceeds maximum length(2MiB) limit",
            )

    def to_json(self):
        """Render the message as JSON, keeping proto field names and unset fields."""
        return MessageToJson(
            self.request,
            always_print_fields_with_no_presence=True,
            preserving_proto_field_name=True,
        )

    def to_proto(self):
        """Return the wrapped protobuf message."""
        return self.request

    def to_string(self):
        """Return the serialized form produced in the constructor."""
        return self.request_str
class RPCResponse:
    """Base container for the outcome of an RPC call.

    Every field except ``async_req`` starts unset (``code`` starts as
    ``DashVectorCode.Unknown``). ``get`` is marked abstract, but this class
    does not use ``ABCMeta``, so the marker is not enforced at instantiation.
    """

    def __init__(self, *, async_req):
        self._async_req = async_req
        self._code = DashVectorCode.Unknown
        # Not populated by this base class.
        self._request_id = None
        self._message = None
        self._output = None
        self._usage = None

    @property
    def async_req(self):
        """The ``async_req`` flag given at construction."""
        return self._async_req

    @property
    def request_id(self):
        """Request id; ``None`` until set."""
        return self._request_id

    @property
    def code(self):
        """Status code; ``DashVectorCode.Unknown`` until set."""
        return self._code

    @property
    def message(self):
        """Status message; ``None`` until set."""
        return self._message

    @property
    def output(self):
        """Response payload; ``None`` until set."""
        return self._output

    @property
    def usage(self):
        """Usage information; ``None`` until set."""
        return self._usage

    @abstractmethod
    def get(self):
        """Hook for subclasses; the base implementation returns ``None``."""
        pass
class RPCHandler(object):
    """Base class for transport handlers.

    The constructor validates and stores the endpoint, timeout, insecure-mode
    flag and request headers. The operation methods are hooks marked with
    ``@abstractmethod``; this class does not use ``ABCMeta``, so the marker
    is not enforced and the base implementations simply return ``None``.
    """

    def __init__(self, *, endpoint: str = "", api_key: str = "", timeout: float = 10.0):
        """Validate the connection settings and build the request headers.

        Args:
            endpoint: Service endpoint; must be a str accepted by
                ``verify_endpoint``.
            api_key: Token sent in the ``dashvector-auth-token`` header.
            timeout: Timeout as float or int. Values at or below 0.000001
                are replaced by ``365.5 * 86400``.

        Raises:
            DashVectorException: if any argument has the wrong type or the
                endpoint fails ``verify_endpoint``.
        """
        # endpoint: str
        if not isinstance(endpoint, str):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK RPCHandler endpoint({endpoint}) is invalid and must be str",
            )
        if not verify_endpoint(endpoint):
            raise DashVectorException(
                code=DashVectorCode.InvalidEndpoint,
                reason=f"DashVectorSDK RPCHandler endpoint({endpoint}) is invalid and cannot contain protocol header and [_]",
            )

        # api_key: str
        if not isinstance(api_key, str):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument, reason=f"DashVectorSDK RPCHandler api_key({api_key}) is invalid"
            )

        # timeout: float
        if isinstance(timeout, float):
            pass
        elif isinstance(timeout, int):
            timeout = float(timeout)
        else:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument, reason=f"DashVectorSDK RPCHandler timeout({timeout}) is invalid"
            )
        if timeout <= 0.000001:
            # A non-positive (or effectively zero) timeout is replaced by a
            # very long one.
            timeout = 365.5 * 86400

        self._endpoint = endpoint
        self._timeout = timeout
        self._insecure_mode = os.getenv("DASHVECTOR_INSECURE_MODE", "False").lower() in ("true", "1")
        self._headers = {
            "dashvector-auth-token": api_key,
            "x-user-agent": f"{__version__};{platform.python_version()};{platform.platform()}",
        }

    # ---- collection hooks ----

    @abstractmethod
    def create_collection(self, create_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def delete_collection(self, delete_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def describe_collection(self, describe_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def list_collections(self, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def stats_collection(self, stats_request, *, async_req=False) -> RPCResponse:
        pass

    # ---- partition hooks ----

    @abstractmethod
    def create_partition(self, create_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def delete_partition(self, delete_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def describe_partition(self, describe_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def list_partitions(self, list_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def stats_partition(self, stats_request, *, async_req=False) -> RPCResponse:
        pass

    # ---- document hooks ----

    @abstractmethod
    def insert_doc(self, insert_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def update_doc(self, update_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def upsert_doc(self, upsert_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def delete_doc(self, delete_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def query_doc(self, query_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def query_doc_group_by(self, query_request, *, async_req=False) -> RPCResponse:
        pass

    @abstractmethod
    def fetch_doc(self, fetch_request, *, async_req=False) -> RPCResponse:
        pass

    # ---- misc hooks ----

    @abstractmethod
    def get_version(self, *, async_req=False) -> RPCResponse:
        # async_req defaults to False like every other hook on this class.
        pass

    @abstractmethod
    def close(self) -> None:
        pass

View File

@@ -0,0 +1,50 @@
##
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##
# -*- coding: utf-8 -*-
import logging
import os
from dashvector.common.constants import DASHVECTOR_LOGGING_LEVEL_ENV
logger = logging.getLogger("dashvector")
def enable_logging():
    """Configure the SDK logger from the logging-level environment variable.

    Nothing happens when the variable is unset. ``"debug"`` selects DEBUG;
    any other value (including unrecognised ones) selects INFO. A console
    handler with a detailed format is then attached to the logger.
    """
    # NOTE(review): SOURCE has lost its indentation; the handler setup is
    # assumed to sit inside the "level is set" branch - confirm.
    level = os.environ.get(DASHVECTOR_LOGGING_LEVEL_ENV, None)
    if level is None:
        return

    # set logging level; an invalid value falls back to the default (info).
    logger.setLevel(logging.DEBUG if level == "debug" else logging.INFO)

    # set default logging handler
    log_format = "%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(levelname)s - %(message)s"  # noqa E501
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(stream_handler)
# in release the dashvector log is disabled
# you can enable the dashvector log for debugging.
enable_logging()

View File

@@ -0,0 +1,61 @@
##
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##
# -*- coding: utf-8 -*-
from enum import Enum, IntEnum
from typing import Union
from dashvector.common.error import DashVectorCode, DashVectorException
class StrStatus(str, Enum):
    """String form of a status; each member's value equals its name."""

    INITIALIZED = "INITIALIZED"
    SERVING = "SERVING"
    DROPPING = "DROPPING"
    ERROR = "ERROR"
class Status(IntEnum):
    """Integer form of a status, with converters to and from ``StrStatus``."""

    INITIALIZED = 0
    SERVING = 1
    DROPPING = 2
    ERROR = 3

    @staticmethod
    def get(cs: Union[str, StrStatus]) -> IntEnum:
        """Return the ``Status`` whose name matches the string status ``cs``.

        Raises:
            DashVectorException: if ``cs`` equals no ``StrStatus`` member.
        """
        # Both enums declare the same member names in the same order.
        for candidate in StrStatus:
            if cs == candidate:
                return Status[candidate.name]
        raise DashVectorException(code=DashVectorCode.InvalidArgument, reason=f"DashVectorSDK get invalid status {cs}")

    @staticmethod
    def str(cs: Union[int, IntEnum]) -> str:
        """Return the string value for the integer status ``cs``.

        Raises:
            DashVectorException: if ``cs`` equals no ``Status`` member.
        """
        for candidate in Status:
            if cs == candidate:
                return StrStatus[candidate.name].value
        raise DashVectorException(code=DashVectorCode.InvalidArgument, reason=f"DashVectorSDK get invalid Status {cs}")

View File

@@ -0,0 +1,837 @@
##
# Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##
# -*- coding: utf-8 -*-
import json
import struct
from enum import Enum, IntEnum
from typing import Any, Dict, List, Optional, Tuple, Type, Union, NewType
import warnings
import numpy as np
from dashvector.common.error import DashVectorCode, DashVectorException
from dashvector.common.handler import RPCResponse
from dashvector.common.status import Status
from dashvector.util.convertor import to_json_without_ascii
from dashvector.core.proto import dashvector_pb2
# Distinct alias of int, used to mark ``long`` fields.
long = NewType("long", int)

# Element types accepted when describing a vector's data type.
VectorDataType = Union[
    Type[int], Type[float], Type[bool],
    Type[np.int8], Type[np.int16], Type[np.float16],
    Type[np.bool_], Type[np.float32], Type[np.float64],
]

# Dense vector payloads and sparse vector payloads (index -> value).
VectorValueType = Union[List[int], List[float], np.ndarray]
SparseValueType = Dict[int, float]

# Human-readable list of the supported field types.
supported_type_msg = " | ".join(
    (
        "bool",
        "str",
        "int",
        "float",
        "long",
        "typing.List[str]",
        "typing.List[int]",
        "typing.List[float]",
        "typing.List[long]",
    )
)

# used to define schema
FieldSchemaType = Union[
    Type[long],
    Type[str],
    Type[bool],
    Type[int],
    Type[float],
    Type[List[long]],
    Type[List[str]],
    Type[List[int]],
    Type[List[float]],
]

# used to insert field data
FieldDataType = Union[
    long, str, int, float, bool,
    List[long], List[str], List[int], List[float],
]

# Field name -> declared type / field name -> value.
FieldSchemaDict = Dict[str, FieldSchemaType]
FieldDataDict = Dict[str, FieldDataType]

# A single document id or a list of ids.
IdsType = Union[str, List[str]]
class DashVectorProtocol(IntEnum):
    """Transport protocol selector."""

    GRPC = 0
    HTTP = 1
class DocOp(IntEnum):
    """Document write operation kinds."""

    insert = 0
    update = 1
    upsert = 2
    delete = 3
class MetricStrType(str, Enum):
    """String names of the supported distance metrics."""

    EUCLIDEAN = "euclidean"
    DOTPRODUCT = "dotproduct"
    COSINE = "cosine"
class MetricType(IntEnum):
    """Integer form of a distance metric, with converters to and from
    ``MetricStrType``."""

    EUCLIDEAN = 0
    DOTPRODUCT = 1
    COSINE = 2

    @staticmethod
    def get(mtype: Union[str, MetricStrType]) -> IntEnum:
        """Return the ``MetricType`` matching the metric name ``mtype``.

        Raises:
            DashVectorException: if ``mtype`` equals no ``MetricStrType`` member.
        """
        # Both enums declare the same member names in the same order.
        for candidate in MetricStrType:
            if mtype == candidate:
                return MetricType[candidate.name]
        raise DashVectorException(
            code=DashVectorCode.InvalidArgument,
            reason=f"DashVectorSDK get invalid metrictype {mtype} and must be in [cosine, dotproduct, euclidean]",
        )

    @staticmethod
    def str(mtype: Union[int, IntEnum]) -> str:
        """Return the metric name for the integer metric ``mtype``.

        Raises:
            DashVectorException: if ``mtype`` equals no ``MetricType`` member.
        """
        for candidate in MetricType:
            if mtype == candidate:
                return MetricStrType[candidate.name].value
        raise DashVectorException(
            code=DashVectorCode.InvalidArgument,
            reason=f"DashVectorSDK get invalid metrictype {mtype} and must be in [cosine, dotproduct, euclidean]",
        )
class VectorStrType(str, Enum):
    """String names of the vector element types ("FLOAT", "INT").

    Because the class mixes in ``str``, each member compares equal to its plain
    string value. Member names match those of VectorType.
    """

    FLOAT = "FLOAT"
    INT = "INT"
class VectorType(IntEnum):
    """Integer codes of the vector element types, plus conversion helpers.

    The helpers translate between string names, Python scalar classes and the
    ``struct`` format characters used to pack/unpack raw vector bytes.
    """

    FLOAT = 0
    INT = 1

    @staticmethod
    def get(vtype: Union[str, VectorStrType]) -> 'VectorType':
        """Return the ``VectorType`` member for a ``VectorStrType``/string name.

        Raises:
            DashVectorException: ``InvalidVectorType`` when nothing matches.
        """
        if vtype == VectorStrType.FLOAT:
            return VectorType.FLOAT
        if vtype == VectorStrType.INT:
            return VectorType.INT
        raise DashVectorException(
            code=DashVectorCode.InvalidVectorType,
            reason=f"DashVectorSDK get invalid vectortype {vtype} and must be in [int, float]",
        )

    @staticmethod
    def str(vtype: Union[int, IntEnum]) -> str:
        """Return the string name ("FLOAT"/"INT") for a vector type code.

        Raises:
            DashVectorException: ``InvalidVectorType`` when nothing matches.
        """
        if vtype == VectorType.FLOAT:
            return VectorStrType.FLOAT.value
        if vtype == VectorType.INT:
            return VectorStrType.INT.value
        # Message previously ended with a stray "]]".
        raise DashVectorException(
            code=DashVectorCode.InvalidVectorType,
            reason=f"DashVectorSDK get invalid vectortype {vtype} and must be in [int, float]",
        )

    @staticmethod
    def get_vector_data_type(vtype: Type):
        """Map a Python scalar class (``int``/``float``) to its ``VectorType``.

        Raises:
            DashVectorException: ``InvalidVectorType`` when ``vtype`` is not a
                class or is not a key of ``_vector_dtype_map``.
        """
        # The isinstance test runs first so non-class values are never hashed.
        if not isinstance(vtype, type) or vtype not in _vector_dtype_map:
            raise DashVectorException(
                code=DashVectorCode.InvalidVectorType,
                reason=f"DashVectorSDK does not support vector data type {vtype} and must be in [int, float]",
            )
        return _vector_dtype_map[vtype]

    @staticmethod
    def get_vector_data_format(data_type):
        """Return the ``struct`` format character registered for ``data_type``.

        Raises:
            DashVectorException: ``InvalidVectorType`` for any other value.
        """
        if data_type not in (VectorType.INT, VectorType.FLOAT):
            raise DashVectorException(
                code=DashVectorCode.InvalidVectorType,
                reason=f"DashVectorSDK does not support vector({data_type}) to convert bytes",
            )
        return _vector_type_to_format[data_type]

    @staticmethod
    def convert_to_bytes(feature, data_type, dimension):
        """Pack ``dimension`` values of ``feature`` into little-endian bytes.

        Args:
            feature: iterable of values, unpacked as ``struct.pack`` arguments.
            data_type: ``VectorType.INT`` or ``VectorType.FLOAT``.
            dimension: number of values to pack.

        Raises:
            DashVectorException: ``InvalidVectorType`` for other data types.
        """
        if data_type not in (VectorType.INT, VectorType.FLOAT):
            raise DashVectorException(
                code=DashVectorCode.InvalidVectorType,
                reason=f"DashVectorSDK does not support auto pack feature type({data_type})",
            )
        return struct.pack(f"<{dimension}{_vector_type_to_format[data_type]}", *feature)

    @staticmethod
    def convert_to_dtype(feature, data_type, dimension):
        """Unpack little-endian bytes into a tuple of ``dimension`` values.

        Args:
            feature: bytes-like buffer handed to ``struct.unpack``.
            data_type: ``VectorType.INT`` or ``VectorType.FLOAT``.
            dimension: number of values to unpack.

        Raises:
            DashVectorException: ``InvalidVectorType`` for other data types.
        """
        if data_type not in (VectorType.INT, VectorType.FLOAT):
            raise DashVectorException(
                code=DashVectorCode.InvalidVectorType,
                reason=f"DashVectorSDK does not support auto unpack feature type({data_type})",
            )
        return struct.unpack(f"<{dimension}{_vector_type_to_format[data_type]}", feature)

    # NOTE(review): nothing in this class assigns ``_indices`` or ``_values``,
    # so the three accessors below raise AttributeError on a member. They look
    # like leftovers from a sparse-vector type; kept unchanged for
    # compatibility, confirm before relying on them.
    @property
    def indices(self):
        return self._indices

    @property
    def values(self):
        return self._values

    def __dict__(self):
        return {"indices": self.indices, "values": self.values}

    def get_python_type(self):
        """Return the Python scalar class registered for this vector type."""
        return _reverse_vector_dtype_map[self]
class FieldStrType(str, Enum):
    """String names of the field types.

    Because the class mixes in ``str``, each member compares equal to its plain
    string value. Member names match those of FieldType.
    """

    BOOL = "BOOL"
    STRING = "STRING"
    INT = "INT"
    FLOAT = "FLOAT"
    LONG = "LONG"
    ARRAY_STRING = "ARRAY_STRING"
    ARRAY_INT = "ARRAY_INT"
    ARRAY_FLOAT = "ARRAY_FLOAT"
    ARRAY_LONG = "ARRAY_LONG"
class FieldType(IntEnum):
    """Integer codes of the field types.

    Scalar types use codes 0-4 and array types use codes 11-14. Every member
    shares its name with a ``FieldStrType`` member.
    """

    BOOL = 0
    STRING = 1
    INT = 2
    FLOAT = 3
    LONG = 4
    ARRAY_STRING = 11
    ARRAY_INT = 12
    ARRAY_FLOAT = 13
    ARRAY_LONG = 14

    @staticmethod
    def get(ftype: Union[str, FieldStrType]) -> IntEnum:
        """Return the ``FieldType`` member for a ``FieldStrType``/string name.

        Raises:
            DashVectorException: ``InvalidField`` when nothing matches.
        """
        # Compare against the string enum in declaration order, map by name.
        for named in FieldStrType:
            if ftype == named:
                return FieldType[named.name]
        raise DashVectorException(
            code=DashVectorCode.InvalidField,
            reason=f"DashVectorSDK does not support field value type {ftype} and must be in {supported_type_msg}"
        )

    @staticmethod
    def str(ftype: Union[int, IntEnum]) -> str:
        """Return the string name for a field type code.

        Raises:
            DashVectorException: ``InvalidField`` when nothing matches.
        """
        for code in FieldType:
            if ftype == code:
                return FieldStrType[code.name].value
        raise DashVectorException(
            code=DashVectorCode.InvalidField,
            reason=f"DashVectorSDK does not support field value type {ftype} and must be in {supported_type_msg}"
        )

    @staticmethod
    def get_field_data_type(dtype: FieldSchemaType):
        """Map a schema annotation to its ``FieldType`` via ``_attr_dtype_map``.

        Raises:
            DashVectorException: ``InvalidField`` when ``dtype`` is not a key
                of the lookup table.
        """
        # Table values are enum members, never None, so None means "absent".
        resolved = _attr_dtype_map.get(dtype)
        if resolved is None:
            raise DashVectorException(
                code=DashVectorCode.InvalidField,
                reason=f"DashVectorSDK does not support field value type {dtype} and must be in {supported_type_msg}"
            )
        return resolved
class IndexStrType(str, Enum):
    """String names of the index types ("IT_UNDEFINED", "IT_HNSW", "IT_INVERT").

    Because the class mixes in ``str``, each member compares equal to its plain
    string value. Member names match those of IndexType.
    """

    UNDEFINED = "IT_UNDEFINED"
    HNSW = "IT_HNSW"
    INVERT = "IT_INVERT"
class IndexType(IntEnum):
    """Integer codes of the index types.

    Every member shares its name with an ``IndexStrType`` member.
    """

    UNDEFINED = 0
    HNSW = 1
    INVERT = 10

    @staticmethod
    def get(itype: Union[str, IndexStrType]):
        """Return the ``IndexType`` member for an ``IndexStrType``/string name.

        Raises:
            DashVectorException: ``InvalidIndexType`` when nothing matches.
        """
        for named in IndexStrType:
            if itype == named:
                return IndexType[named.name]
        raise DashVectorException(
            code=DashVectorCode.InvalidIndexType, reason=f"DashVectorSDK does not support indextype {itype}"
        )

    @staticmethod
    def str(itype: Union[int, IntEnum]) -> str:
        """Return the string name for an index type code.

        Raises:
            DashVectorException: ``InvalidIndexType`` when nothing matches.
        """
        for code in IndexType:
            if itype == code:
                return IndexStrType[code.name].value
        raise DashVectorException(
            code=DashVectorCode.InvalidIndexType, reason=f"DashVectorSDK does not support indextype {itype}"
        )
# Python scalar class -> VectorType code; consulted by
# VectorType.get_vector_data_type.
_vector_dtype_map = {
    float: VectorType.FLOAT,
    int: VectorType.INT,
}
# Inverse of the table above (VectorType code -> Python scalar class);
# consulted by VectorType.get_python_type.
_reverse_vector_dtype_map = {v: k for k,v in _vector_dtype_map.items()}
# ``struct`` format character per VectorType: "f" is a 4-byte float and "b" a
# signed 1-byte integer, so INT vectors are packed one byte per element.
_vector_type_to_format = {
    VectorType.FLOAT: "f",
    VectorType.INT: "b",
}
# Schema annotation -> FieldType code; consulted by
# FieldType.get_field_data_type. Both the ``long`` NewType and the literal
# string "long" resolve to FieldType.LONG.
_attr_dtype_map = {
    str: FieldType.STRING,
    bool: FieldType.BOOL,
    int: FieldType.INT,
    float: FieldType.FLOAT,
    long: FieldType.LONG,
    "long": FieldType.LONG,
    List[str]: FieldType.ARRAY_STRING,
    List[int]: FieldType.ARRAY_INT,
    List[float]: FieldType.ARRAY_FLOAT,
    List[long]: FieldType.ARRAY_LONG,
}
class DashVectorResponse(object):
    """Result wrapper exposing code, message, request id, output and usage.

    An instance is built from an ``RPCResponse``, from a
    ``DashVectorException``, or from neither (which yields a bare success).
    It also proxies ``len``/iteration/``in``/indexing to ``output`` and is
    truthy only when ``code`` equals ``DashVectorCode.Success``.
    """

    def __init__(self, response: Optional[RPCResponse] = None, *, exception: Optional[DashVectorException] = None):
        """Initialise the wrapper.

        Args:
            response: RPC response handle. When given and its ``async_req``
                flag is falsy, the result is fetched immediately via ``get()``.
            exception: when given, its code/message/request_id override
                whatever was derived from ``response``.
        """
        self._code = DashVectorCode.Unknown
        self._message = ""
        self._request_id = ""
        self._output = None
        self._usage = None
        self.__response = response
        self.__exception = exception

        # No response at all counts as success unless an exception says otherwise.
        if self.__response is None:
            self._code = DashVectorCode.Success
        if self.__response is not None and not self.__response.async_req:
            self.get()
        if self.__exception is not None:
            self._code = self.__exception.code
            self._message = self.__exception.message
            self._request_id = self.__exception.request_id

    def get(self):
        """Resolve the underlying response once and return ``self``.

        Does nothing when a code is already known or when there is no
        response. A ``DashVectorException`` raised while fetching is captured
        into code/message/request_id rather than propagated.
        """
        if self._code != DashVectorCode.Unknown:
            return self
        if self.__response is None:
            return self
        try:
            result = self.__response.get()
            self._request_id = result.request_id
            self._code = result.code
            self._message = result.message
            self._output = result.output
            self._usage = result.usage
        except DashVectorException as e:
            self._code = e.code
            self._message = e.message
            self._request_id = e.request_id
        return self

    @property
    def code(self):
        return self._code

    @property
    def message(self):
        return self._message

    @property
    def request_id(self):
        return self._request_id

    @property
    def output(self):
        return self._output

    @output.setter
    def output(self, value: Any):
        self._output = value

    @property
    def usage(self):
        return self._usage

    @property
    def response(self):
        return self.__response

    @staticmethod
    def _serialize_value(value):
        """Render one output value for the dictionary form.

        ``str``/``int``/``float`` values are stringified. Objects following
        this SDK's convention of defining ``__dict__`` as a method contribute
        that method's result. Everything else falls back to ``__str__``.
        """
        if isinstance(value, (str, int, float)):
            return str(value)
        # Ordinary objects expose ``__dict__`` as a mapping, not a method.
        # Calling it blindly (as the previous code did) raised TypeError.
        to_dict = getattr(value, "__dict__", None)
        if callable(to_dict):
            return to_dict()
        return value.__str__()

    def _decorate_output(self):
        """Build the dictionary form of the response (without usage).

        The ``output`` key is omitted when the output is ``None``. Lists and
        dicts are rendered element by element / value by value.
        """
        # The key is spelled "requests_id"; kept as-is for compatibility.
        decorated = {"code": self.code, "message": self.message, "requests_id": self.request_id}
        output = self._output
        if output is None:
            return decorated
        # Status is tested first so it is never treated as a plain scalar.
        if isinstance(output, Status):
            decorated["output"] = Status.str(output)
        elif isinstance(output, (str, int, float)):
            decorated["output"] = str(output)
        elif isinstance(output, list):
            decorated["output"] = [self._serialize_value(item) for item in output]
        elif isinstance(output, dict):
            decorated["output"] = {key: self._serialize_value(item) for key, item in output.items()}
        else:
            decorated["output"] = self._serialize_value(output)
        return decorated

    def __dict__(self):
        """Return the dictionary form, adding ``usage`` when it is set."""
        obj = self._decorate_output()
        if self._usage is not None:
            obj["usage"] = self._usage.__dict__()
        return obj

    def __str__(self):
        return to_json_without_ascii(self.__dict__())

    def __repr__(self):
        return self.__str__()

    def __bool__(self):
        return self.code == DashVectorCode.Success

    def __len__(self):
        return len(self._output)

    def __iter__(self):
        return self._output.__iter__()

    def __contains__(self, item):
        if hasattr(self._output, "__contains__"):
            return self.output.__contains__(item)
        else:
            raise TypeError(f"DashVectorSDK Get argument of type '{type(self.output)}' is not iterable")

    def __getitem__(self, item):
        if hasattr(self._output, "__getitem__"):
            return self.output.__getitem__(item)
        else:
            raise TypeError(f"DashVectorSDK Get '{type(self.output)}' object is not subscriptable")
class RequestUsage(object):
    """Read/write unit counters reported for a request.

    Either counter may be ``None``, in which case it is left out of the
    dictionary and string forms.
    """

    read_units: int
    write_units: int

    def __init__(self, *, read_units=None, write_units=None):
        self.read_units = read_units
        self.write_units = write_units

    @staticmethod
    def from_pb(usage: dashvector_pb2.RequestUsage):
        """Build from a protobuf message.

        ``read_units`` is checked first. Returns ``None`` when neither field
        is set.
        """
        if usage.HasField("read_units"):
            return RequestUsage(read_units=usage.read_units)
        if usage.HasField("write_units"):
            return RequestUsage(write_units=usage.write_units)
        return None

    @staticmethod
    def from_dict(usage: dict):
        """Build from a dict.

        ``read_units`` is checked first. Returns ``None`` when neither key is
        present.
        """
        for unit_key in ("read_units", "write_units"):
            if unit_key in usage:
                return RequestUsage(**{unit_key: usage[unit_key]})
        return None

    def __dict__(self):
        """Return only the counters that are not ``None``, read before write."""
        counters = (("read_units", self.read_units), ("write_units", self.write_units))
        return {label: amount for label, amount in counters if amount is not None}

    def __str__(self):
        return json.dumps(self.__dict__())

    def __repr__(self):
        return self.__str__()
class VectorParam:
    """Schema parameters of one vector field: dimension, dtype, metric, quantize type.

    The constructor never raises. The first problem it finds is stored and
    re-raised by ``validate()``. After a failed check the remaining attributes
    are left unset, so call ``validate()`` before reading them.
    """

    def __init__(self,
                 dimension: int = 0,
                 dtype: Union[Type[int], Type[float]] = float,
                 metric: str = "cosine",
                 quantize_type: str = "",
                 ):
        """
        Vector param.

        Args:
            dimension (int): vector dimension in collection
            dtype (Union[Type[int], Type[float]]): vector data type in collection
            metric (str): vector metric in collection, support 'cosine', 'dotproduct' and 'euclidean', default to 'cosine'
            quantize_type (str): vector quantize type in collection, refer to https://help.aliyun.com/document_detail/2663745.html for latest support types
        """
        self._exception = None

        # dimension: int
        if not isinstance(dimension, int):
            self._exception = DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK VectorParam dimension type({type(dimension)}) is invalid and must be int",
            )
            return
        self.dimension = dimension

        # metric: MetricType
        if not isinstance(metric, str):
            self._exception = DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK VectorParam metric Type({type(metric)}) is invalid and must be str",
            )
            return
        try:
            self._metric = MetricType.get(metric)
        except Exception as e:
            self._exception = e
            return

        # dtype: VectorType (the cosine metric is only accepted with float vectors)
        if dtype is not float and metric == "cosine":
            self._exception = DashVectorException(
                code=DashVectorCode.MismatchedDataType,
                reason=f"DashVectorSDK VectorParam dtype value({dtype}) is invalid and must be [float] when metric is cosine",
            )
            return
        try:
            self._dtype = VectorType.get_vector_data_type(dtype)
        except Exception as e:
            self._exception = e
            return

        # quantize_type: str
        if not isinstance(quantize_type, str):
            self._exception = DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK VectorParam quantize_type Type({type(quantize_type)}) is invalid and must be str",
            )
            return
        self.quantize_type = quantize_type

    def validate(self):
        """Raise the problem recorded by the constructor, if any."""
        if self._exception is not None:
            raise self._exception

    @staticmethod
    def from_pb(pb: dashvector_pb2.CollectionInfo.VectorParam):
        """Build a ``VectorParam`` from its protobuf message.

        Raises:
            DashVectorException: ``InvalidVectorType`` when ``pb.dtype`` is
                neither FLOAT nor INT.
        """
        if pb.dtype == VectorType.FLOAT:
            dtype = float
        elif pb.dtype == VectorType.INT:
            dtype = int
        else:
            # Use keywords like every other raise in this module; the message
            # was previously passed as the first positional argument.
            raise DashVectorException(
                code=DashVectorCode.InvalidVectorType,
                reason=f"DashVectorSDK VectorParam dtype value({pb.dtype}) is invalid",
            )
        return VectorParam(
            dimension=pb.dimension,
            dtype=dtype,
            metric=MetricType.str(pb.metric),
            quantize_type=pb.quantize_type,
        )

    @staticmethod
    def from_dict(d: Dict[str, Any]):
        """Build a ``VectorParam`` from the dictionary shape ``to_dict`` emits.

        ``dtype`` and ``metric`` are required and raise ``DashVectorException``
        when missing or unknown. A missing ``quantize_type`` falls back to the
        constructor default ``""``.
        """
        dimension = d.get("dimension")
        dtype = VectorType.get(d.get("dtype")).get_python_type()
        metric = MetricType.str(MetricType.get(d.get("metric")))
        quantize_type = d.get("quantize_type")
        if quantize_type is None:
            # An absent key must not be recorded as a type error.
            quantize_type = ""
        return VectorParam(
            dimension=dimension,
            dtype=dtype,
            metric=metric,
            quantize_type=quantize_type
        )

    @property
    def metric(self):
        """Metric as its string spelling (e.g. "cosine")."""
        return MetricType.str(self._metric)

    @property
    def dtype(self):
        """Vector element type as its string name ("FLOAT"/"INT")."""
        return VectorType.str(self._dtype)

    def to_dict(self):
        """Return dimension, dtype, metric and quantize_type as a dict."""
        return {
            "dimension": self.dimension,
            "dtype": self.dtype,
            "metric": self.metric,
            'quantize_type': self.quantize_type
        }
def _validate_query_params(num_candidates, is_linear, ef, radius):
    """Check the search parameters shared by VectorQuery and SparseVectorQuery.

    Raises:
        DashVectorException: ``InvalidArgument`` for a wrongly typed value or
            an ``ef`` outside [0, 4294967295]; ``InvalidTopk`` for a
            ``num_candidates`` outside [0, 1024].
    """
    if not isinstance(num_candidates, int):
        raise DashVectorException(
            code=DashVectorCode.InvalidArgument,
            reason=f"DashVectorSDK QueryDocRequest topk type({type(num_candidates)}) is invalid and must be int",
        )
    # 0 is accepted: it is the default (see the class docstrings).
    if num_candidates < 0 or num_candidates > 1024:
        raise DashVectorException(
            code=DashVectorCode.InvalidTopk,
            reason=f"DashVectorSDK QueryDocRequest topk value({num_candidates}) is invalid and must be in [0, 1024]",
        )
    if not isinstance(is_linear, bool):
        raise DashVectorException(
            code=DashVectorCode.InvalidArgument,
            reason=f"DashVectorSDK QueryDocRequest is_linear type({type(is_linear)}) is invalid and must be bool",
        )
    if not isinstance(ef, int):
        raise DashVectorException(
            code=DashVectorCode.InvalidArgument,
            reason=f"DashVectorSDK QueryDocRequest ef type({type(ef)}) is invalid and must be int",
        )
    if not (0 <= ef <= 4294967295):
        raise DashVectorException(
            code=DashVectorCode.InvalidArgument,
            reason=f"DashVectorSDK QueryDocRequest ef value({ef}) is invalid and must be in [0, 4294967295]",
        )
    if not isinstance(radius, float):
        raise DashVectorException(
            code=DashVectorCode.InvalidArgument,
            reason=f"DashVectorSDK QueryDocRequest radius type({type(radius)}) is invalid and must be float",
        )


class VectorQuery:
    def __init__(self,
                 vector: VectorValueType,
                 num_candidates: int = 0,
                 is_linear: bool = False,
                 ef: int = 0,
                 radius: float = 0.0):
        """
        A vector query.

        vector (Optional[Union[List[Union[int, float, bool]], np.ndarray]]): query vector
        num_candidates (int): number of candidates for this vector query, default to collection.query.topk
        is_linear (bool): whether perform linear(brute-force) search, default to False
        ef (int): ef_search for HNSW-like algorithm, default to adaptive ef
        radius (float): perform radius nearest neighbor if radius is not 0.0,
                        i.e. return docs with score <= radius for euclidean/cosine and score >= radius for dotproduct
        """
        self.vector = vector
        self.num_candidates = num_candidates
        self.is_linear = is_linear
        self.ef = ef
        self.radius = radius

    def validate(self):
        """Validate the search parameters; the vector itself is not checked here.

        Raises:
            DashVectorException: see ``_validate_query_params``.
        """
        _validate_query_params(self.num_candidates, self.is_linear, self.ef, self.radius)


class SparseVectorQuery:
    def __init__(self,
                 sparse_vector: SparseValueType,
                 num_candidates: int = 0,
                 is_linear: bool = False,
                 ef: int = 0,
                 radius: float = 0.0):
        """
        A sparse_vector query.

        sparse_vector (Dict[int, float]): query sparse_vector
        num_candidates (int): number of candidates for this sparse_vector query, default to collection.query.topk
        is_linear (bool): whether perform linear(brute-force) search, default to False
        ef (int): ef_search for HNSW-like algorithm, default to adaptive ef
        radius (float): perform radius nearest neighbor if radius is not 0.0,
                        i.e. return docs with score <= radius for euclidean/cosine and score >= radius for dotproduct
        """
        # Stored under ``vector`` so both query kinds expose the same attribute.
        self.vector = sparse_vector
        self.num_candidates = num_candidates
        self.is_linear = is_linear
        self.ef = ef
        self.radius = radius

    def validate(self):
        """Validate the search parameters; the sparse vector itself is not checked here.

        Raises:
            DashVectorException: see ``_validate_query_params``.
        """
        _validate_query_params(self.num_candidates, self.is_linear, self.ef, self.radius)
class BaseRanker:
    """Empty base class that the ranker configurations in this module inherit from."""
    pass
class RrfRanker(BaseRanker):
    """Ranker configuration named "rrf" carrying a single ``rank_constant``."""

    def __init__(self, rank_constant: int = 60):
        self.rank_constant = rank_constant

    def to_pb(self):
        """Return a ``dashvector_pb2.Ranker`` with the constant stored as a string."""
        pb_ranker = dashvector_pb2.Ranker()
        pb_ranker.ranker_name = "rrf"
        pb_ranker.ranker_params["rank_constant"] = str(self.rank_constant)
        return pb_ranker

    def to_dict(self):
        """Return the same information as ``to_pb`` in plain-dict form."""
        params = {'rank_constant': str(self.rank_constant)}
        return {'ranker_name': "rrf", 'ranker_params': params}
class WeightedRanker(BaseRanker):
    """Ranker configuration named "weighted" with optional weights.

    When ``weights`` is ``None`` no ranker parameters are emitted at all.
    """

    def __init__(self, weights: Optional[Dict[str, float]] = None):
        self.weights = weights

    def to_pb(self):
        """Return a ``dashvector_pb2.Ranker``; weights are JSON-encoded when set."""
        pb_ranker = dashvector_pb2.Ranker()
        pb_ranker.ranker_name = "weighted"
        if self.weights is None:
            return pb_ranker
        pb_ranker.ranker_params["weights"] = json.dumps(self.weights)
        return pb_ranker

    def to_dict(self):
        """Return the plain-dict form; ``ranker_params`` is present only with weights."""
        if self.weights is None:
            return {'ranker_name': "weighted"}
        return {
            'ranker_name': "weighted",
            'ranker_params': {'weights': json.dumps(self.weights)},
        }

View File

@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
from dashvector.common.common_validator import *
from dashvector.common.types import *
from dashvector.common.error import DashVectorCode, DashVectorException
from dashvector.core.models.collection_meta_status import CollectionMeta
from dashvector.util.convertor import to_sorted_sparse_vector
from dashvector.core.proto import dashvector_pb2
from abc import abstractmethod, ABC
def convert_vector_query(vector_query, type: VectorType):
    """Copy the vector of ``vector_query`` into a new ``dashvector_pb2.VectorQuery``.

    A list goes into ``float_vector`` and bytes into ``byte_vector``. A NumPy
    array is packed to bytes when ``type`` is ``VectorType.INT`` and turned
    into a list otherwise; in both array cases the converted value is also
    written back to ``vector_query.vector``. Any other vector type leaves the
    returned message empty.
    """
    pb_query = dashvector_pb2.VectorQuery()
    payload = vector_query.vector

    if isinstance(payload, list):
        pb_query.vector.float_vector.values.extend(payload)
        return pb_query

    if isinstance(payload, bytes):
        pb_query.vector.byte_vector = payload
        return pb_query

    if isinstance(payload, np.ndarray):
        if type == VectorType.INT:
            element_format = VectorType.get_vector_data_format(type)
            vector_query.vector = np.ascontiguousarray(payload, dtype=f"<{element_format}").tobytes()
            pb_query.vector.byte_vector = vector_query.vector
        else:
            vector_query.vector = list(payload)
            pb_query.vector.float_vector.values.extend(vector_query.vector)

    return pb_query
def convert_vector_query_from_pb(vector_query: dashvector_pb2.VectorQuery):
    """Build a ``VectorQuery`` from a protobuf ``VectorQuery`` message.

    ``float_vector`` content becomes a NumPy array. ``byte_vector`` content is
    decoded as signed bytes and returned as a list.

    Raises:
        DashVectorException: ``InvalidArgument`` when neither field is set.
    """
    pb_vector = vector_query.vector

    if pb_vector.HasField("float_vector"):
        return VectorQuery(vector=np.array(pb_vector.float_vector.values))

    if pb_vector.HasField("byte_vector"):
        byte_format = "b"
        decoded = np.frombuffer(pb_vector.byte_vector, dtype=f"<{byte_format}")
        return VectorQuery(vector=decoded.tolist())

    raise DashVectorException(
        code=DashVectorCode.InvalidArgument,
        reason=f"DashVectorSDK vector_query.vector type is invalid.")
class BasicVectorValidator(ABC):
    """Abstract interface for validating collection and query vectors.

    Concrete subclasses in this module: ``ReserveVectorValidator`` and
    ``MultiVectorValidator``.
    """

    def __init__(self, collection_meta: CollectionMeta):
        self._collection_meta = collection_meta

    @abstractmethod
    def validate_collection_vectors(
            self,
            vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            sparse_vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            *,
            dimension: int = 0,
            dtype: VectorDataType = None,
            metric: str = "cosine",
            doc_op: str):
        """Validate the vector schema arguments.

        The implementations in this module return a
        ``(vectors, sparse_vectors)`` pair.
        """

    @abstractmethod
    def validate_query_vectors(self, vector, top_k: int, query_request, doc_op: str):
        """Validate query vector(s) and copy them into ``query_request``.

        The implementations in this module return the populated request.
        """
class ReserveVectorValidator(BasicVectorValidator):
    """Validator for a single dense vector per collection.

    The vector is stored under the ``""`` key by ``validate_collection_vectors``.
    """

    def __init__(self, collection_meta: CollectionMeta, vector_name: str = DASHVECTOR_VECTOR_NAME):
        """Cache the dimension and dtype of the targeted vector.

        Args:
            collection_meta: collection metadata, or ``None`` when only
                ``validate_collection_vectors`` will be used (it reads no
                metadata).
            vector_name: vector to read from the metadata; the default reads
                the collection-level ``dimension``/``dtype``.
        """
        # Always run the base initializer so every attribute exists, even
        # without metadata. This path was previously skipped for None.
        super().__init__(collection_meta)
        self._dimension = None
        self._dtype = None
        if collection_meta is None:
            return
        if vector_name != DASHVECTOR_VECTOR_NAME:
            self._dimension = self._collection_meta.get_dimension(vector_name)
            self._dtype = VectorType.get(self._collection_meta.get_dtype(vector_name))
        else:
            self._dimension = self._collection_meta.dimension
            self._dtype = VectorType.get(self._collection_meta.dtype)

    def validate_collection_vectors(
            self,
            vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            sparse_vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            *,
            dimension: int = 0,
            dtype: VectorDataType = None,
            metric: str = "cosine",
            doc_op: str):
        """Normalise the schema arguments to ``{"": VectorParam}`` and validate it.

        When ``vectors`` is ``None`` a ``VectorParam`` is built from
        ``dimension``/``dtype``/``metric``.

        Returns:
            ``(vectors, {})``: the normalised dict and an empty sparse-vector dict.

        Raises:
            DashVectorException: when ``sparse_vectors`` is given, ``vectors``
                has an unsupported type, the ``VectorParam`` is invalid, or
                its dimension is outside (1, 20000].
        """
        if sparse_vectors is not None:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} single vector sparse_vectors type({type(sparse_vectors)}) is invalid and must be None",
            )
        if vectors is None:
            vectors = {"": VectorParam(dimension=dimension, dtype=dtype, metric=metric)}
        elif isinstance(vectors, VectorParam):
            vectors = {"": vectors}
        if not isinstance(vectors, dict):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} vectors type({type(vectors)}) is invalid and must be dict"
            )
        vectors[""].validate()
        if vectors[""].dimension <= 1 or vectors[""].dimension > 20000:
            raise DashVectorException(
                code=DashVectorCode.InvalidDimension,
                reason=f"DashVectorSDK VectorParam dimension value({vectors[''].dimension}) is invalid and must be in (1, 20000]",
            )
        return vectors, dict()

    def validate_query_vectors(self, vector, top_k: int, query_request, doc_op: str):
        """Validate one query vector and copy it into ``query_request``.

        A bare list/ndarray is wrapped in a ``VectorQuery`` whose
        ``num_candidates`` is ``top_k``. A ``QueryDocRequest`` receives the
        query under ``DASHVECTOR_VECTOR_NAME``; a ``QueryDocGroupByRequest``
        receives only the vector. Other request types are returned untouched.

        Raises:
            DashVectorException: ``InvalidArgument`` for an unsupported
                ``vector`` type, plus whatever the query/dense-vector
                validation raises.
        """
        if isinstance(vector, VectorQuery):
            vector.validate()
        elif isinstance(vector, (list, np.ndarray)):
            vector = VectorQuery(vector=vector, num_candidates=top_k)
            vector.validate()
        else:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} single vector type({type(vector)}) is invalid and must be [VectorQuery, VectorValueType]",
            )
        vector.vector = Validator.validate_dense_vector(vector.vector, self._dimension, self._dtype, doc_op)
        converted_vector_query = convert_vector_query(vector, self._dtype)
        returned_query_request = query_request
        if isinstance(returned_query_request, dashvector_pb2.QueryDocRequest):
            returned_query_request.vectors[DASHVECTOR_VECTOR_NAME].CopyFrom(converted_vector_query)
        elif isinstance(returned_query_request, dashvector_pb2.QueryDocGroupByRequest):
            returned_query_request.vector.CopyFrom(converted_vector_query.vector)
        return returned_query_request
class MultiVectorValidator(BasicVectorValidator):
    """Validator for collections with named (multiple) dense/sparse vectors."""

    def validate_collection_vectors(
            self,
            vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            sparse_vectors: Union[None, VectorParam, Dict[str, VectorParam]],
            *,
            dimension: int = 0,
            dtype: VectorDataType = None,
            metric: str = "cosine",
            doc_op: str):
        """Validate the named vector schemas.

        ``dimension``, ``dtype`` and ``metric`` are accepted for interface
        compatibility but not used here.

        Returns:
            ``(vectors, sparse_vectors)``: the dense dict (empty when ``None``
            was given) and the result of ``Validator.validate_sparse_vectors``.

        Raises:
            DashVectorException: when both arguments are ``None``, ``vectors``
                is not a dict, a name or ``VectorParam`` is invalid, a
                dimension is outside (1, 20000], or more than 4 vectors are
                declared in total.
        """
        if vectors is None and sparse_vectors is None:
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} vectors and sparse_vectors are all empty",
            )
        if vectors is None:
            vectors = dict()
        if not isinstance(vectors, dict):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} vectors type({type(vectors)}) is invalid and must be dict"
            )
        # All names are checked before any parameter, as before.
        for vector_name in vectors:
            Validator.validate_vector_name(vector_name, doc_op)
        for vector_param in vectors.values():
            if not isinstance(vector_param, VectorParam):
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} vector_param type({type(vector_param)}) is invalid and must be VectorParam",
                )
            vector_param.validate()
            if vector_param.dimension <= 1 or vector_param.dimension > 20000:
                raise DashVectorException(
                    code=DashVectorCode.InvalidDimension,
                    reason=f"DashVectorSDK VectorParam dimension value({vector_param.dimension}) is invalid and must be in (1, 20000]",
                )
        sparse_vectors = Validator.validate_sparse_vectors(sparse_vectors, doc_op)
        if len(vectors) + len(sparse_vectors) > 4:
            raise DashVectorException(
                code=DashVectorCode.InvalidField,
                reason=f"DashVectorSDK {doc_op} vectors length({len(vectors) + len(sparse_vectors)}) is invalid and must be in [0, 4]",
            )
        return vectors, sparse_vectors

    def validate_query_vectors(self, vector, top_k: int, query_request, doc_op: str):
        """Validate named query vectors and copy each into ``query_request.vectors``.

        Args:
            vector: dict mapping a vector name to a ``VectorQuery``, list or
                ndarray. Bare lists/arrays are wrapped in a default
                ``VectorQuery``.
            top_k: accepted for interface compatibility; not used here.
            query_request: request whose ``vectors`` map is populated.
            doc_op: operation name interpolated into error messages.

        Raises:
            DashVectorException: ``InvalidArgument`` when ``vector`` is not a
                dict or a value has an unsupported type, plus whatever the
                query/dense-vector validation raises.
        """
        if not isinstance(vector, dict):
            raise DashVectorException(
                code=DashVectorCode.InvalidArgument,
                reason=f"DashVectorSDK {doc_op} multi vector type({type(vector)}) is invalid and must be dict",
            )
        returned_query_request = query_request
        for vector_name, vector_query in vector.items():
            if isinstance(vector_query, (list, np.ndarray)):
                vector_query = VectorQuery(vector=vector_query)
            elif not isinstance(vector_query, VectorQuery):
                # Match ReserveVectorValidator: report a bad value type as a
                # DashVectorException rather than an AttributeError on
                # ``.validate()``.
                raise DashVectorException(
                    code=DashVectorCode.InvalidArgument,
                    reason=f"DashVectorSDK {doc_op} multi vector query type({type(vector_query)}) is invalid and must be [VectorQuery, VectorValueType]",
                )
            vector_query.validate()
            # Look the schema up once per vector.
            dimension = self._collection_meta.get_dimension(vector_name)
            dtype = VectorType.get(self._collection_meta.get_dtype(vector_name))
            vector_query.vector = Validator.validate_dense_vector(vector_query.vector, dimension, dtype, doc_op)
            converted_vector_query = convert_vector_query(vector_query, dtype)
            returned_query_request.vectors[vector_name].CopyFrom(converted_vector_query)
        return returned_query_request
class ValidatorFactory:
    """Chooses between ``ReserveVectorValidator`` and ``MultiVectorValidator``."""

    @staticmethod
    def meta_create(collection_meta) -> BasicVectorValidator:
        """Pick a validator from existing collection metadata.

        The single-vector validator is used only when the schema holds exactly
        one vector and it is named ``DASHVECTOR_VECTOR_NAME``.
        """
        only_reserved = (
            len(collection_meta.vectors_schema) == 1
            and DASHVECTOR_VECTOR_NAME in collection_meta.vectors_schema.keys()
        )
        validator_cls = ReserveVectorValidator if only_reserved else MultiVectorValidator
        return validator_cls(collection_meta)

    @staticmethod
    def input_type_create(vectors, sparse_vectors) -> BasicVectorValidator:
        """Pick a metadata-less validator from the shape of the arguments.

        A dict in either argument selects the multi-vector validator.
        """
        wants_multi = isinstance(vectors, dict) or isinstance(sparse_vectors, dict)
        validator_cls = MultiVectorValidator if wants_multi else ReserveVectorValidator
        return validator_cls(collection_meta=None)
class SparseVectorChecker:
    """Validates a ``SparseVectorQuery`` and builds its protobuf counterpart.

    After construction, ``vector_query`` holds the populated
    ``dashvector_pb2.SparseVectorQuery``.
    """

    def __init__(self, collection_meta: CollectionMeta, vector_name: str, vector_query: SparseVectorQuery, doc_op: str):
        """Validate ``vector_query`` and convert it.

        Args:
            collection_meta: metadata supplying dtype and dimension.
            vector_name: vector to look up; a falsy name uses the
                collection-level dtype/dimension.
            vector_query: the sparse query to validate and convert.
            doc_op: operation name interpolated into the error message.

        Raises:
            DashVectorException: ``InvalidSparseValues`` when the converted
                sparse vector is empty, plus whatever ``validate()`` raises.
        """
        if vector_name:
            self._dtype = VectorType.get(collection_meta.get_dtype(vector_name))
            self._dimension = collection_meta.get_dimension(vector_name)
        else:
            self._dtype = VectorType.get(collection_meta.dtype)
            self._dimension = collection_meta.dimension

        vector_query.validate()
        self._sparse_vector = vector_query

        pb_query = dashvector_pb2.SparseVectorQuery()
        self.vector_query = pb_query
        sorted_entries = to_sorted_sparse_vector(vector_query.vector)
        for index, weight in sorted_entries.items():
            pb_query.sparse_vector.sparse_vector[index] = weight
        if len(pb_query.sparse_vector.sparse_vector) == 0:
            raise DashVectorException(
                code=DashVectorCode.InvalidSparseValues,
                reason=f"DashVectorSDK {doc_op} not supports query with empty sparse_vector",
            )

        # Copy the search parameters onto the protobuf message.
        pb_param = pb_query.param
        pb_param.num_candidates = vector_query.num_candidates
        pb_param.ef = vector_query.ef
        pb_param.is_linear = vector_query.is_linear
        pb_param.radius = vector_query.radius