From ae3a47ddf6c6e3eda7cec718d6ace462c7be3f49 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 14 Jan 2026 15:59:43 -0600 Subject: [PATCH 1/4] feat: add AstraDB client support Add DataStax AstraDB as a supported vector database client: - Implement AstraDB client with vector search capabilities - Add CLI interface with configuration options for API endpoint, token, and namespace - Support cosine, euclidean, and dot_product distance metrics - Integrate with VectorDBBench UI and configuration system - Add astrapy dependency to project requirements --- README.md | 1 + pyproject.toml | 1 + vectordb_bench/backend/clients/__init__.py | 16 ++ .../backend/clients/astradb/astradb.py | 169 ++++++++++++++++++ vectordb_bench/backend/clients/astradb/cli.py | 86 +++++++++ .../backend/clients/astradb/config.py | 38 ++++ vectordb_bench/cli/vectordbbench.py | 2 + .../frontend/config/dbCaseConfigs.py | 7 + vectordb_bench/frontend/config/styles.py | 1 + 9 files changed, 321 insertions(+) create mode 100644 vectordb_bench/backend/clients/astradb/astradb.py create mode 100644 vectordb_bench/backend/clients/astradb/cli.py create mode 100644 vectordb_bench/backend/clients/astradb/config.py diff --git a/README.md b/README.md index bd7568da1..64701f353 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ All the database client supported | awsopensearch | `pip install vectordb-bench[opensearch]` | | aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` | | mongodb | `pip install vectordb-bench[mongodb]` | +| astradb | `pip install vectordb-bench[astradb]` | | tidb | `pip install vectordb-bench[tidb]` | | vespa | `pip install vectordb-bench[vespa]` | | oceanbase | `pip install vectordb-bench[oceanbase]` | diff --git a/pyproject.toml b/pyproject.toml index 63c585c55..281f8d543 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ chromadb = [ "chromadb" ] opensearch = [ "opensearch-py" ] aliyun_opensearch = [ "alibabacloud_ha3engine_vector" ] 
mongodb = [ "pymongo" ] +astradb = [ "astrapy" ] mariadb = [ "mariadb" ] tidb = [ "PyMySQL" ] cockroachdb = [ "psycopg[binary,pool]", "pgvector" ] diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index d69c54504..6a3b52d9f 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -44,6 +44,7 @@ class DB(Enum): Test = "test" AliyunOpenSearch = "AliyunOpenSearch" MongoDB = "MongoDB" + AstraDB = "AstraDB" TiDB = "TiDB" CockroachDB = "CockroachDB" Clickhouse = "Clickhouse" @@ -165,6 +166,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 return MongoDB + if self == DB.AstraDB: + from .astradb.astradb import AstraDB + + return AstraDB + if self == DB.OceanBase: from .oceanbase.oceanbase import OceanBase @@ -339,6 +345,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915 return MongoDBConfig + if self == DB.AstraDB: + from .astradb.config import AstraDBConfig + + return AstraDBConfig + if self == DB.OceanBase: from .oceanbase.config import OceanBaseConfig @@ -494,6 +505,11 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912, PLR0915 return MongoDBIndexConfig + if self == DB.AstraDB: + from .astradb.config import AstraDBIndexConfig + + return AstraDBIndexConfig + if self == DB.OceanBase: from .oceanbase.config import _oceanbase_case_config diff --git a/vectordb_bench/backend/clients/astradb/astradb.py b/vectordb_bench/backend/clients/astradb/astradb.py new file mode 100644 index 000000000..df7eb0049 --- /dev/null +++ b/vectordb_bench/backend/clients/astradb/astradb.py @@ -0,0 +1,169 @@ +import logging +import time +from contextlib import contextmanager + +from astrapy import DataAPIClient +from astrapy.constants import VectorMetric +from astrapy.info import CollectionDefinition + +from ..api import VectorDB +from .config import AstraDBIndexConfig + +log = logging.getLogger(__name__) + + +class 
AstraDBError(Exception): + """Custom exception class for AstraDB client errors.""" + + +class AstraDB(VectorDB): + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: AstraDBIndexConfig, + collection_name: str = "vdb_bench_collection", + id_field: str = "id", + vector_field: str = "vector", + drop_old: bool = False, + **kwargs, + ): + self.dim = dim + self.db_config = db_config + self.case_config = db_case_config + self.collection_name = collection_name + self.id_field = id_field + self.vector_field = vector_field + self.drop_old = drop_old + + # Get index parameters + index_params = self.case_config.index_param() + log.info(f"index params: {index_params}") + self.index_params = index_params + + # Initialize client - will be properly set in init() + self.client = None + self.db = None + self.collection = None + + # Initialize and drop collection if needed + temp_client = DataAPIClient(self.db_config["token"]) + temp_db = temp_client.get_database( + api_endpoint=self.db_config["api_endpoint"], + keyspace=self.db_config["namespace"] + ) + + if self.drop_old: + try: + temp_db.drop_collection(self.collection_name) + log.info(f"AstraDB client dropped old collection: {self.collection_name}") + except Exception: + log.info(f"Collection {self.collection_name} does not exist, skipping drop") + + @contextmanager + def init(self): + """Initialize AstraDB client and cleanup when done""" + try: + self.client = DataAPIClient(self.db_config["token"]) + self.db = self.client.get_database( + api_endpoint=self.db_config["api_endpoint"], + keyspace=self.db_config["namespace"] + ) + + # Create or get collection with vector configuration + metric_str = self.case_config.parse_metric() + + # Map metric string to VectorMetric constant + metric_map = { + "euclidean": VectorMetric.EUCLIDEAN, + "dot_product": VectorMetric.DOT_PRODUCT, + "cosine": VectorMetric.COSINE, + } + metric = metric_map.get(metric_str, VectorMetric.COSINE) + + # Create collection with new API + # 
Note: check_exists is no longer needed - API handles conflicts automatically + self.collection = self.db.create_collection( + name=self.collection_name, + definition=( + CollectionDefinition.builder() + .set_vector_dimension(self.dim) + .set_vector_metric(metric) + .build() + ), + ) + log.info(f"Created/accessed collection: {self.collection_name} with metric: {metric_str}") + + yield + finally: + if self.client is not None: + self.client = None + self.db = None + self.collection = None + + def need_normalize_cosine(self) -> bool: + return False + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs, + ) -> (int, Exception | None): + """Insert embeddings into AstraDB""" + + # Prepare documents in bulk + documents = [ + { + "_id": str(id_), + "$vector": embedding, + } + for id_, embedding in zip(metadata, embeddings, strict=False) + ] + + # Insert documents in batches + try: + result = self.collection.insert_many(documents, ordered=False) + return len(result.inserted_ids), None + except Exception as e: + log.exception("Error inserting embeddings") + return 0, e + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + **kwargs, + ) -> list[int]: + """Search for similar vectors""" + + # Build filter if specified + search_filter = None + if filters: + log.info(f"Applying filter: {filters}") + search_filter = { + self.id_field: {"$gte": filters["id"]}, + } + + # Perform vector search + try: + results = self.collection.find( + filter=search_filter, + sort={"$vector": query}, + limit=k, + include_similarity=True, + ) + + # Extract IDs from results + return [int(doc["_id"]) for doc in results] + except Exception: + log.exception("Error searching embeddings") + return [] + + def optimize(self, data_size: int | None = None) -> None: + """AstraDB vector indexes are automatically managed""" + log.info("optimize for search - AstraDB manages indexes automatically") + + def 
ready_to_load(self) -> None: + """AstraDB is always ready to load""" diff --git a/vectordb_bench/backend/clients/astradb/cli.py b/vectordb_bench/backend/clients/astradb/cli.py new file mode 100644 index 000000000..01ea6a31f --- /dev/null +++ b/vectordb_bench/backend/clients/astradb/cli.py @@ -0,0 +1,86 @@ +from typing import Annotated, TypedDict, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB +from ..api import MetricType +from .config import AstraDBIndexConfig + + +class AstraDBTypedDict(TypedDict): + api_endpoint: Annotated[ + str, + click.option( + "--api-endpoint", + type=str, + help="AstraDB API endpoint (e.g., https://-.apps.astra.datastax.com)", + required=True, + ), + ] + token: Annotated[ + str, + click.option( + "--token", + type=str, + help="AstraDB authentication token", + required=True, + ), + ] + namespace: Annotated[ + str, + click.option( + "--namespace", + type=str, + help="AstraDB namespace (keyspace)", + default="default_keyspace", + show_default=True, + ), + ] + metric: Annotated[ + str, + click.option( + "--metric", + type=click.Choice(["cosine", "euclidean", "dot_product"], case_sensitive=False), + help="Distance metric for vector similarity", + default="cosine", + show_default=True, + ), + ] + + +class AstraDBIndexTypedDict(CommonTypedDict, AstraDBTypedDict): ... 
+ + +@cli.command() +@click_parameter_decorators_from_typed_dict(AstraDBIndexTypedDict) +def AstraDB(**parameters: Unpack[AstraDBIndexTypedDict]): + from .config import AstraDBConfig + + # Convert metric string to MetricType enum + metric_map = { + "cosine": MetricType.COSINE, + "euclidean": MetricType.L2, + "dot_product": MetricType.IP, + } + metric_type = metric_map.get(parameters["metric"].lower(), MetricType.COSINE) + + run( + db=DB.AstraDB, + db_config=AstraDBConfig( + db_label=parameters["db_label"], + api_endpoint=parameters["api_endpoint"], + token=SecretStr(parameters["token"]), + namespace=parameters["namespace"], + ), + db_case_config=AstraDBIndexConfig( + metric_type=metric_type, + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/astradb/config.py b/vectordb_bench/backend/clients/astradb/config.py new file mode 100644 index 000000000..01efa6dd9 --- /dev/null +++ b/vectordb_bench/backend/clients/astradb/config.py @@ -0,0 +1,38 @@ +from enum import Enum + +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + + +class AstraDBConfig(DBConfig, BaseModel): + api_endpoint: str = "https://-.apps.astra.datastax.com" + token: SecretStr = "" + namespace: str = "default_keyspace" + + def to_dict(self) -> dict: + return { + "api_endpoint": self.api_endpoint, + "token": self.token.get_secret_value(), + "namespace": self.namespace, + } + + +class AstraDBIndexConfig(BaseModel, DBCaseConfig): + index: IndexType = IndexType.HNSW # AstraDB uses vector search + metric_type: MetricType = MetricType.COSINE + + def parse_metric(self) -> str: + if self.metric_type == MetricType.L2: + return "euclidean" + if self.metric_type == MetricType.IP: + return "dot_product" + return "cosine" # Default to cosine similarity + + def index_param(self) -> dict: + return { + "metric": self.parse_metric(), + } + + def search_param(self) -> dict: + return {} diff --git a/vectordb_bench/cli/vectordbbench.py 
b/vectordb_bench/cli/vectordbbench.py index 76e9534f9..4cd0a2b5e 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,5 +1,6 @@ from ..backend.clients.alisql.cli import AliSQLHNSW from ..backend.clients.alloydb.cli import AlloyDBScaNN +from ..backend.clients.astradb.cli import AstraDB from ..backend.clients.aws_opensearch.cli import AWSOpenSearch from ..backend.clients.chroma.cli import Chroma from ..backend.clients.clickhouse.cli import Clickhouse @@ -50,6 +51,7 @@ cli.add_command(PgVectorScaleDiskAnn) cli.add_command(PgDiskAnn) cli.add_command(AlloyDBScaNN) +cli.add_command(AstraDB) cli.add_command(OceanBaseHNSW) cli.add_command(OceanBaseIVF) cli.add_command(MariaDBHNSW) diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 6a32e5ff1..c46ba0cef 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -2373,6 +2373,9 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_MongoDBNumCandidatesRatio, ] +AstraDBLoadingConfig = [] +AstraDBPerformanceConfig = [] + CockroachDBLoadingConfig = [ CaseConfigParamInput_IndexType_CockroachDB, CaseConfigParamInput_MinPartitionSize_CockroachDB, @@ -2691,6 +2694,10 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: MongoDBLoadingConfig, CaseLabel.Performance: MongoDBPerformanceConfig, }, + DB.AstraDB: { + CaseLabel.Load: AstraDBLoadingConfig, + CaseLabel.Performance: AstraDBPerformanceConfig, + }, DB.MariaDB: { CaseLabel.Load: MariaDBLoadingConfig, CaseLabel.Performance: MariaDBPerformanceConfig, diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index bce4561fd..8cd0f3c9f 100644 --- a/vectordb_bench/frontend/config/styles.py +++ b/vectordb_bench/frontend/config/styles.py @@ -60,6 +60,7 @@ def getPatternShape(i): DB.AliyunOpenSearch: 
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAOAAAADgCAMAAAAt85rTAAAA51BMVEX/////agD/ZwD/XQD/7ubykWXy8Oz/7Ob/k0r/kUn08vD/h1DVgVL/ZAD/YQD/bQD/8uj/+vX/9+//ei//lV7/eRP/3cn/cgD/jEz/8OT/hDn/s4L/WgD/pW3/gTH/5NP/vIz/poH/g0X/nWH/t5L+j1THuavd1Mr/r4r8xKv/toj/59z/dyb/x573m3T/eAD/z7f/wJP/3cv/up/u5t/uhE7/zKv/chf/l1X/iz7/fyP/w6L/o3f/17zvwaH/q3vVyr//nG7inHLZr5jw08XgpYXaj2D/rXb/omT/nVjjwqnTuKP0roxxvdl9AAAG60lEQVR4nO2d/VuiShiGYaAy9wwIKn5LKm6mVua6fpFZu3uO2+7+/3/PAdu2rVQGZki6ruf+uXi5ZQZeEJ+RJAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAYF8Ypqmq6hEXKkshzhpHqpoxjdB29qA2nc0LVOGhWiwFlyoVq1xFNLkwn3YGthlCz3avy5QSQmQ+aIOlWoNylpEJoXL9zlmw2Y1uWlmd8ro94LIUdIWUIpRMcjenQcWMhaxoguxkkrVZBO2soHqepELHO6ejsRQlt67XZJr7RlNkTVrbMRmNpbhKfrE2i58ktQUKeiy3f6zHQgvJtMUm2OI+yzxntvkYGnZfcCG6YhNcia47sTZUGV3rYkeKLGs3bII3muDCRF8evqqyVARX8QRv2QRvRQvKslJ5WcQUPEzWVUZsgiPxn61MXsxDsy56fHpU2fwkqSq+Nik/MzSuY/CTKaug+CHqGf59tTAqkzhKHLMKzmL4eMnEeSpQ4W6rN6ExdaI+biyHUD5/3H4phjngoTHdDfqoMZxlPKqP07AWwxnU+wTzGVbBTD6OEfSn0TgU2ew+Qmh1ynwLak6rIpv8P/vQTK83v+gL3SqhlGpKvjIMcYttWpW8d5vm32SL3Jf+Q8vmCNsoIUTPtmsnH0/T7HKPpE8/ntTaWZ3/QcIT6xOpcSdmCnpHrp7ruEz3uNux3U6uTgUNWHrnXwttEU0MoVq+dRbqwc9WTPuslRfyWIH0h972BpwH0Jt0mlbv9tTQj+62Y6i9bt1/dsJpSQfrpxQcboTK/WbrnOHpYHhK561mX+Z6uEe9fi0zj/b/3pQr1IsHF44l8Mi9xLCci4NivRB1UhLv5v4oygEk3qCsnVl26TDC6TIc6cOebZ3V6poWxZGqUi9km+RPOWXm9jIxHrjXGJmeO4swKRVLGoRpdL0ply3WxkJOluExz2vFrBxqtNIr6YpZkGh6+c7hvM7xYjt3ZZ19tNKKVGGYg377pZCpa9lvOi43Y9jWp6mssLV1tCbdB/2Z1zZfpr58/abuaWBuwlS/ff2Sugxu0clUOg74G5KtJMjsOWYl6GsNkpJSu/+ETvc86XZjt3fPsGBBwnzXuh8yAbufChii2tW+DYLo7jyEwYLMz1X2xe7nOYFDlCRfcPf+Bx3BdyC4a/chCMH9A0EIJhwIQjDhQBCCCQeCEEw4EIRgwoEgBBMOBCGYcN5e0Gx8iE479Hc9by+YyfmvtkSDXoautw/B6C/uhHjHFIIQhOBWwegnGfIeTjLmv/9E57/Q777t4UKf5iF0NXQyEEw4EIRgwoEgBBMOBCGYcCAIwYQDQQgmHAhCMOFAEIIJB4IQTDgQfPc/r+MTpEf7FghC5fsFqNLbt0AQu6MqAgXpqxDEpLE76CD4V9hKN7G/ovcxu7uzRoIFZVocxBKHI4LSoBiQxMEgKBO96CwSeBjNhVOUA3c+JU2DX/ogcr85vUrU6aZ3NW32A/VkP8viO9NbLYRoynw1TEJgh6kOV3NFY4vOod9DJAgTqhWOP1mWPYo952gz6ZFtWZ+OCyEiyegqZPQloXq+/PmgM37zCA973Dn4XM6HDKnXXGkYNvpy/UbSJPthx
ZidL4LF6kN2si4ccme9RkWNFl5KCFX0+3Ev9rCSTG98rytRI+S8XjpqrpoP1ejcscQE/m3CtIfO3CsSeQdJ3ZQMvvhUb1KW75xFDFlPxsDPp+KLOPST8aQBb0qiv9JDs706+RG42AMrpz9OVu3mhHAHOJKBxPd+59OW/ARHRZlXhnwZjsb6Iqdo/LmN673KrWdPR1xgqXfmuVwOrIjNa8kaLAuKwDBc8pBE7erCtuhvlNJ8u+GGdiy5jXZeVHLqb/SHIOxboRm/8nq86tl2N0QvYHfb/jozooOM+7frradjSWkmSpk9pbmsxBFlToq/e8pFLDnbMmG+/+jFErMtkz/raizjiGKXtTGr4Die+k+PW9RCLIsV1FgFY4lqJ/O/2siB8OVQ1hVYBeNIgifPFn4xBF4Ln9BYBePI8ifOsxImS8xoWJTXK8ts5DAGQeq8PIenxBvucVGbDUs+jSbCRyn9ySb4U/hnS/obvlOxyqINtRM2wRPRgqS4sYuyZ4KHyr5WztKmW74TM1JVoQdxP2ufUaW2/Y7NbjM8TGWGNNkEhS7Pp+d2doiGWxZ3x8K4wKLVF1eRNAP7Q9udCRsxE6ZudCxs0Sc6PWNp8M1zUatkvugmtiBmSRZCFW338ph/c54r62S9xgwX2jVDReNa4yzj7+ikWQv3IDozdhqNi4sDPn4x3POavziLXFw0Ok60BQbS6fQhF0zf0vDWiPC7WAAAAAAAAAAAAAAAAAAAAAAAAAAAAMD74n+lRrmptwLV8gAAAABJRU5ErkJggg==", DB.AWSOpenSearch: "https://assets.zilliz.com/opensearch_1eee37584e.jpeg", DB.OSSOpenSearch: "https://images.seeklogo.com/logo-png/50/1/opensearch-icon-logo-png_seeklogo-500356.png", + DB.AstraDB: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAABGdBTUEAALGPC/xhBQAAACBjSFJNAAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAARGVYSWZNTQAqAAAACAABh2kABAAAAAEAAAAaAAAAAAADoAEAAwAAAAEAAQAAoAIABAAAAAEAAABkoAMABAAAAAEAAABkAAAAAC+73kEAAAHLaVRYdFhNTDpjb20uYWRvYmUueG1wAAAAAAA8eDp4bXBtZXRhIHhtbG5zOng9ImFkb2JlOm5zOm1ldGEvIiB4OnhtcHRrPSJYTVAgQ29yZSA2LjAuMCI+CiAgIDxyZGY6UkRGIHhtbG5zOnJkZj0iaHR0cDovL3d3dy53My5vcmcvMTk5OS8wMi8yMi1yZGYtc3ludGF4LW5zIyI+CiAgICAgIDxyZGY6RGVzY3JpcHRpb24gcmRmOmFib3V0PSIiCiAgICAgICAgICAgIHhtbG5zOmV4aWY9Imh0dHA6Ly9ucy5hZG9iZS5jb20vZXhpZi8xLjAvIj4KICAgICAgICAgPGV4aWY6Q29sb3JTcGFjZT4xPC9leGlmOkNvbG9yU3BhY2U+CiAgICAgICAgIDxleGlmOlBpeGVsWERpbWVuc2lvbj4yNTY8L2V4aWY6UGl4ZWxYRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFlEaW1lbnNpb24+MjU2PC9leGlmOlBpeGVsWURpbWVuc2lvbj4KICAgICAgPC9yZGY6RGVzY3JpcHRpb24+CiAgIDwvcmRmOlJERj4KPC94OnhtcG1ldGE+CuYattQAAAgJSURBVHgB7VxpaBVXFD5Z3GI0SKhLYkBjEpeoqAGN4I+IilLFgKKBKkFxLUYFf0kRRBBE8I8Q6Q/9FTXGihJBRVEk9VcUakEtRatRaGzQmrjhkqCZnu/W1+Q+3iRvlvf
u9fUcOG/mvpm558z3zbnLmSWNvojjOOm8+i3ratYK1gJWkcQh0M5Vn2etTUtL+zViJg0rTMZAXnzHWs1ayPoNaxarSOIQ6OSq/2L9nfU0awMT05XJK5Aq1h9Yi1EQSQoCg9jK+C9axMt0Doy6DP4p40INazmriBkEctlsDms7+o3vWUtZRcwiMIHNr07jCHnBK9msCCERcwi8Z9MdIMQx54NYjkYATZaIRQgIIRaRAVeEECHEMgQsc0ciRAixDAHL3JEIEUIsQ8AydyRChBDLELDMHYkQIcQyBCxzRyJECLEMAcvckQgRQixDwDJ3JEKEEMsQsMwdiRAhxDIELHNHIkQIsQwBy9yRCBFCLEPAMnckQoQQyxCwzB2JEMsIibwfErdbDx48oGPHjsW9f5Ads7KyqLy8nBYtWkQZGRlBqqLPnz/T1atXqbm5md6/x3PNiZH09HRas2YNTZ06lfgFHO9G8LC1F7l06RIezk6azp4922loaHAYUC9uavu+fv3aOX78uDNnzpyk+H3q1Cnf/nqOkKBXqtdL5tatW7Rv3z5qbW2lpUuX0qRJkzxV8fjxY2IyqK6ujh49euTpWD87I0IyMz3D2mNKu5TiKFy5ciUpVxl7qNkZOnSos27dOuf27dtxePnvLtw8OZs3b3ZycnK0uqLrDrPMhDhnzpzxHSF44dOTmCIEoPGV5yxZssQ5d+6c09nZ6er327dvnbNnzzoVFRUOt+NJIwM+BiUkQGyx+STLp0+f6PLly9TS0kLcL1B1dTWhiegt3d3ddOLECTp06FBSmqjetrEO+1C/Ehoh6Fvy8/NpxIgR/kYXMc6gq6uLnj17Ru3teKW7RzDSO3jwoBrNRBOC0dThw4djkpGbm0ujRo2igQPxFnhiBCMrYOBXQiNk8ODBtHbtWlqwYAENGDDArz//HQeg29ra1NV+/jzer9cFZLmJ27Z58+YpH8eMGRPoKnazG/m/pKTE90UZGiEgobS0lHDSYV2BT58+paampsh5asu+xvhu2woKCmju3LkqkrXKLCroDXBAx7iXxSAhYC09h9teX4+n4a2FSkh4biWmJkROdJ+TGEv+aw2tyfLvQvKO/PjxI7169Yqys7MDRzKiF+Ty/ChwWqc3AilJiFsfcu/ePaqvryeMtoIMTQEgCEG/OW3aNJo5c2agkVVKEwIypk+frlItiIjecvPmTUIqxo2w3vv2tw5CUA9I2blzJ23ZsoXGj8e3ZIJJykUI8kgHDhxQzci1a9dUEwXwIFhG1oPB1nM05j21tbVqvrRr1y6aOHFioFFmSnbqxcXFdPToUdqzZw+NHTu2B70Erb179444w0urVq2i69ev04cPH3xbSklC0JRwQpFqampUtMyYMSOUyWpfKHNuje7fv0/bt2+nI0eOqEmtn34qJQmJAIcJ6sqVK+nkyZO0YsUKNSLCyAiEhaURW5ElUvz79+9XtwwwsfUqKdeHRAOAlM6UKVMUSIWFhcQ32Ojly5ehzEfQH3FmWdUX6ZuwROIT91+qqqoI2QEvkvKERMAoKiqi3bt307Zt29Tt3Mj/QZbImV24cIH27t1Lb9680apCP+KWU9N2jCr8bwjBeQ8fPlxpFAa+ixhhIXvsNvtHs+hVUroP8QqG1/1xfwakhClCSJhohlBXqIQgRN3C14+vkRGRn2OTcUwi/AutD8E4HDPj58+fB3vq4guSIBeJQOSfYklfY3xsgy8PHz5UTYqftjyWzej/0GTduXOHcO6xJDLyirXN7b9QCcGdPdzzDgsAtM+YBccS3PVzswNC4EdjYyMhn+W2X6x6vfwHwEFGLEJGjx5Nw4YN81Kd2jc0QgBCR0eHZwe8HoB5xfLly2nr1q2uaW8AhejCxMzP0NOrT9H7z5o1S2UJJk+eHL2p33JohPRrKYQdkDZfv369Ukz2+hK078l+qA/+LFu2TGV+Fy9e7Ctd89UQgvv1SHPjhNFc2SZonjZt2qQeTcJzvX4vBs+E+OmogoC
HJmr+/Pm0ceNGqqysjPtE0f+g002GlJWVqZzZhg0baOTIkYFMeiYEtz+R3g70/GocLoN4kMEPW6v+Anfl4hV04uPGjVN384Kkwvuzh2YxLy+PduzYQQsXLlT+9ndMf9s9f2ocD67duHEj7iu1PwfctmOQMGTIEPVw9YQJ+E59/IJj7969S0+ePFFRkohRFi4YEILmk5+qD20k55mQ+GGRPf0gEOpM3Y8DcoyOgBCi42G8JIQYp0B3QAjR8TBeEkKMU6A7IIToeBgvCSHGKdAdEEJ0PIyXhBDjFOgOCCE6HsZLQohxCnQHhBAdD+MlIcQ4BboDQoiOh/GSEGKcAt0BIUTHw3hJCDFOge6AEKLjYbwkhBinQHdACNHxMF4SQoxToDsghOh4GC8JIcYp0B0QQnQ8jJeEEOMU6A4IIToexktCiHEKdAdASCtr4j6GrtuTkjsCeFGxHYQ0sb5gFTGLQBubbwQhP7Em/qPoZk/2a7D+Gzv5I17Yuciay5rPWsIqknwE/mCTp/k9ll8y+aebXz6p5z/wfewqVrw6msc6iFUkcQig3/6btYW1jvU0K2lfR2Fi8N5YDWslK6JGJHEI/MlV/8yKLuMiAgOm/gGwGa1bA01QmQAAAABJRU5ErkJggg==", DB.MongoDB: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAOEAAADhCAMAAAAJbSJIAAAAolBMVEUAHisA7WQA8GUAHSsA8mYAACcAHCsAGioA9GYAGCoAACYADikAGSoAFyoADCkAFSoAESkA5GIABigAkEkAzlsA32AAt1QAqFAA1F0AJS0AMTAArFEAUjgA6mQANTEAok4Ax1kAaz4AnE0AXDoAOzIAQDMAKi4AiEcAdUEAzFsAv1cAgEQAYTsATDYAVzkAckAAACEAlksA/2oAf0YATjcARTQMLnchAAAHqklEQVR4nO2dWXuyPBCGIQQIoCzuG261q3b73vb//7WPgFRUYFDbiwlX7oMecZCnk0xmJpOoKBKJRCKRSCQSiUQikUgkEolEIpFIJBKJRCKRXIfma3UP4W/R3K7bbIl6979Hve5B/CXuTiXqnVP3MP6O1mREVDKatOoeyF+h6QOqqiod6E1divojFxhJbOpS9F8TgZHEV7/uwfwF1jxahAnRUrTqHs7vo7kDpqawgWPUPaBfx/6k6gHatese0G/TWanHrDp1D+l38TYBORJIgo1X96B+E9Po0RMb0p5h1j2s38Pwp0w9hQ0bFIPr3XOBkcRuYzZ+/YnkCIx4b4hE/y3IV0iCWSNiG9PsF5hQJf1WA7yN4S5P3WjGob644sc2+jrPy/x4m7XwmVTnvURf7G0Ej228+8JFmC7FTbvuQd6CoZcswv1SXAo9T/O3+pOlKHLG789AfdyK4u6K5sOofBHul2JvLOiuqOlDeI7G8/RT0HTYv6ukLyIUc54abqU5Gs/TkSdiaKN/QhvFAfopoD9tb6paMLbiXLh9X/MH1U3IQ3DhzhXdXTU/msJ2gp1IVdwKM9N0NBbL2VQJ146hYlVtWnMgpcgxYn8j0rGivrjEzSREyXDdw66ONb90jsYo4oSnnenlJoxW4lqYdN+aX7oIE8IHUYzYGV5jQp5jCGJE6z64SqCqBnMx3Onle+GPEcVwp+bkwnDmgCCBjfN03Srk0H8CRKeafrUJIyMuBZim3tvV+jgz/Gff+pVbRQIdojeipl+7VSSE6Cvg/oWZ7ylshbzsprkXFS/OoVPk/Qvm5LZJGsU1E9zBqb+6bZLin6b6VXlTFjpF7U0NuAAF/QeiyA3zNPXuwHimC1rxDfOm7z6Cy/Af9AF7dOuWUYL+Ai7DFfQBeUG8EDUDLiKCx8Kkb+DdEb23EBq/eg9+ESJeiO4W3ivG4Bd0izdJ1NewQh1WuMC7EO0lnPza4BeY0+AqmdMH/EmItnHBGIMmJP2PCmc2D1jrUd4MVjj6gMs45AvrkbcPV9nI8
8czrPAda3rhdmGFy48XUCHtYq3ud+BTQzr4gIsAdIFWIZwc0ulHhY+GSBVWqdHQhV3F0EhrNZoDZxZ0XUmhg1ShDzsR9ml/gikkwdo+pPlw0Ma6Npwkkxe0NuzBCh+rKHxGa0N4M2dbewsrXGJV6FTYzJ8Kr3llFOKdpbBCdaVDV0wwK6xyZrHTd2ClA+1+qLhwuBK+uXAth06x1hMr9NEEM3cGpsloozalA5dpgnsf7rbB2/7lwKW2/sab90GFaDsy4FML0n9oT8AyBkFbMG2B7fmkr7QUuFCD9pBUA2uhpG8ZPqwQb7eCDVWZyKhjuJBC0kNbTYR7acjIhv8NmHtqwONDbh4bykAo4gNED4rIeMFeh7LI8A6rK+XH+MAao4NIIRC9RjsK1pJ3BGQf3mgBLVayxOto4IXID86guxi4z/G9XbkNea96BypFoY1oOFA/Deu6ilveBk5GmJehArkRPgOBmcydEWa88uZL9uQozlP5J4j3Cg6wX7AnX/HfyxTi3isUXqsprWTwxsPy9kU6xZr9pniz0udodh4wkdk31vPfH+yyaRorLMuTeWiOnfKD4GivK72ugDnqTintgw5e20r7q+wDES7olbUJB6+WYn0XK8ScGh7wZoUpVHybuVXysFL4it7PcIrjGjKam4o5L4zs6KCDtkKTxfsqMiIZTQzFKL69R5DHMymaU2RE0htHCsdFZQxxHsdo3xfs6eTZjBQqRQrJtxgm5HXTAndKlp5WfNyP/KbFEdY8fyWSZeRJtE5RqUOk1+gLmoWTNvyCYo5Y7wwZSq6/LFMoyC3nH9xVnopkpeWvUrbCH5FmiXaMHH+a9KjnFtvY1BZkp0hpnz7gHcuID3fzim0kmAv30r6bU47hpbb8oqpoczTGPu9UZNtY4XlXFHvBn/ieY1hnOQQvtfF28FOFovnRFGd1Phd5i7p3VooKRXvKLEU/LWjwMg1vaDhRKO5PQZjWyVJk8ZGE93askL1YQs5RTuukK4HFKbz1faSQ9DsC1GaK8O7DrER2zzc96yi5inZCISoXRdhP2SyDxY/OtTZZhaHoD3rrjxkjhnM+H8151oRbkTKKPLTsq/PBhLsUc3KwK1uLFo6eYxwk7s+VMudTbK0L60YPGD/Z0j50Mcdp8kiHTRAYSWzva2+kp8Q2TEtRdOA3QmDkPK0kWSRLI1aoJUk+G1gCb4THmH78djnp+bEvTe5lsIXbGIF8LS5Y7FgiUaYbux620BskUOGyCOEd3G7b5XU4QtYN+P2VIwx9ywM4uvB50xcJt83wolk0fRdE2mjy5w5vG/ANOP5+HyQjRdCMFyItI2J++OI20hsneG+M3EpaKE2Kik3E2e7XIdo7MbfirfYpBfJH564n7dIIBXik9Dpam71CoV7QvwRNSXqFArNx8cye9PXWoJEBTUxyIQjzvaZbSZ7ko5gf1buR5DKJGO1519GJLyI0N6RJz0ZZY0MaZd+mz9A+dHU7ycuRmF9+vBXrO6CUBt/CtV1UxngYDgaDIfI7Izeh+Y7jNKXQnY/GqXsQEolEIpFIJBKJRCKRSCQSiUQikUgkEolE0lT+B4h2dnif2MTUAAAAAElFTkSuQmCC", DB.TiDB: "https://img2.pingcap.com/forms/3/d/3d7fd5f9767323d6f037795704211ac44b4923d6.png", DB.Clickhouse: 
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAO4AAADUCAMAAACs0e/bAAAANlBMVEX/////zAD/yQD/AAD//fX/8Mf/9uD/ygD/6a7/0C3/CAj/rq7/0QD/hQz/67T/zhv/fwD/za9VhqZUAAABG0lEQVR4nO3auQ3DQBRDQa1un7L7b1YVUJkB4++8nMHkHFpq24fUPKbRuMTRmkdrHC15NMfRVbi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLj/xn0ft1RF7h5HrSR3w8XFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXF/Rm3rytKZ0ejq3BxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXGLcZ+pVpL7Sn2+Fbn3KfXAxcXFxcXFxcXFxcXFxcXFxcXFxcXFxe2ZewKjx49mqHXf2AAAAABJRU5ErkJggg==", From 2b49e56192c55f201ee287be3b868ec95e4bb7f1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 14 Jan 2026 17:07:53 -0600 Subject: [PATCH 2/4] git commit -m "feat: add Apache Cassandra vector database support Adds complete implementation of Apache Cassandra (5.0+) and DataStax Astra DB support with Storage-Attached Indexes (SAI) for vector search. Includes cosine, L2 (Euclidean), and dot product similarity functions. --- pyproject.toml | 1 + vectordb_bench/backend/clients/__init__.py | 18 + .../backend/clients/cassandra/__init__.py | 1 + .../backend/clients/cassandra/cassandra.py | 337 ++++++++++++++++++ .../backend/clients/cassandra/cli.py | 84 +++++ .../backend/clients/cassandra/config.py | 127 +++++++ vectordb_bench/cli/vectordbbench.py | 2 + vectordb_bench/frontend/config/styles.py | 1 + 8 files changed, 571 insertions(+) create mode 100644 vectordb_bench/backend/clients/cassandra/__init__.py create mode 100644 vectordb_bench/backend/clients/cassandra/cassandra.py create mode 100644 vectordb_bench/backend/clients/cassandra/cli.py create mode 100644 vectordb_bench/backend/clients/cassandra/config.py diff --git a/pyproject.toml b/pyproject.toml index 63c585c55..a1d720bc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,6 +99,7 @@ mongodb = [ "pymongo" ] mariadb = [ "mariadb" ] tidb = [ "PyMySQL" ] cockroachdb = [ "psycopg[binary,pool]", "pgvector" ] +cassandra = [ "cassandra-driver" ] 
clickhouse = [ "clickhouse-connect" ] vespa = [ "pyvespa" ] lancedb = [ "lancedb" ] diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index d69c54504..4058b134a 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -46,6 +46,7 @@ class DB(Enum): MongoDB = "MongoDB" TiDB = "TiDB" CockroachDB = "CockroachDB" + Cassandra = "Cassandra" Clickhouse = "Clickhouse" Vespa = "Vespa" LanceDB = "LanceDB" @@ -184,6 +185,12 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 from .cockroachdb.cockroachdb import CockroachDB return CockroachDB + + if self == DB.Cassandra: + from .cassandra.cassandra import Cassandra + + return Cassandra + if self == DB.Doris: from .doris.doris import Doris @@ -358,6 +365,12 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915 from .cockroachdb.config import CockroachDBConfig return CockroachDBConfig + + if self == DB.Cassandra: + from .cassandra.config import CassandraConfig + + return CassandraConfig + if self == DB.Doris: from .doris.config import DorisConfig @@ -514,6 +527,11 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912, PLR0915 return _cockroachdb_case_config.get(index_type) + if self == DB.Cassandra: + from .cassandra.config import CassandraIndexConfig + + return CassandraIndexConfig + if self == DB.Vespa: from .vespa.config import VespaHNSWConfig diff --git a/vectordb_bench/backend/clients/cassandra/__init__.py b/vectordb_bench/backend/clients/cassandra/__init__.py new file mode 100644 index 000000000..e3b882ebc --- /dev/null +++ b/vectordb_bench/backend/clients/cassandra/__init__.py @@ -0,0 +1 @@ +# Cassandra vector database client diff --git a/vectordb_bench/backend/clients/cassandra/cassandra.py b/vectordb_bench/backend/clients/cassandra/cassandra.py new file mode 100644 index 000000000..97906da94 --- /dev/null +++ 
b/vectordb_bench/backend/clients/cassandra/cassandra.py @@ -0,0 +1,337 @@ +import logging +from contextlib import contextmanager + +from cassandra.auth import PlainTextAuthProvider +from cassandra.cluster import Cluster +from cassandra.query import BatchStatement, BatchType + +from ..api import VectorDB +from .config import CassandraIndexConfig + +log = logging.getLogger(__name__) + + +class Cassandra(VectorDB): + """Cassandra vector database client. + + Supports both regular Cassandra (5.0+) and DataStax Astra DB + with vector search capabilities using Storage-Attached Indexes (SAI). + """ + + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: CassandraIndexConfig, + collection_name: str = "vdb_bench_collection", + drop_old: bool = False, + **kwargs, + ): + """Initialize Cassandra client. + + Args: + dim: Vector dimension + db_config: Database configuration dictionary from CassandraConfig.to_dict() + db_case_config: Index configuration + collection_name: Table name for vector storage + drop_old: Whether to drop existing table + """ + self.dim = dim + self.db_config = db_config + self.case_config = db_case_config + self.table_name = collection_name + self.keyspace = db_config["keyspace"] + + # Field names + self.id_field = "id" + self.vector_field = "vector" + + # Initialize connection to setup keyspace/table and drop if needed + cluster, session = self._create_cluster_and_session() + + # Create keyspace if not exists (must be done before dropping/creating tables) + self._create_keyspace(session) + + if drop_old: + log.info(f"Dropping old table: {self.keyspace}.{self.table_name}") + session.execute(f"DROP TABLE IF EXISTS {self.keyspace}.{self.table_name}") + + # Create table + self._create_table(session, dim) + + # Create index immediately after table creation + self._create_index(session) + + # Close initial connection + cluster.shutdown() + self.cluster = None + self.session = None + + def _create_cluster_and_session(self): + """Create 
Cassandra cluster and session based on configuration. + + Returns: + Tuple of (Cluster, Session) + """ + config = self.db_config + + if "cloud" in config: + # Astra DB with Secure Connect Bundle + cloud_config = config["cloud"] + + # Setup authentication + if "auth_provider_token" in config: + auth_provider = PlainTextAuthProvider("token", config["auth_provider_token"]) + elif "auth_provider_username" in config: + auth_provider = PlainTextAuthProvider( + config["auth_provider_username"], + config["auth_provider_password"] + ) + else: + auth_provider = None + + cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider) + else: + # Regular Cassandra + contact_points = config.get("contact_points", ["localhost"]) + port = config.get("port", 9042) + + if "auth_provider_username" in config: + auth_provider = PlainTextAuthProvider( + config["auth_provider_username"], + config["auth_provider_password"] + ) + else: + auth_provider = None + + cluster = Cluster(contact_points=contact_points, port=port, auth_provider=auth_provider) + + session = cluster.connect() + return cluster, session + + def _create_keyspace(self, session): + """Create keyspace if it doesn't exist. 
+ + Args: + session: Cassandra session + """ + # First try to use the keyspace if it already exists + try: + session.set_keyspace(self.keyspace) + log.info(f"Using existing keyspace: {self.keyspace}") + return + except Exception: + # Keyspace doesn't exist, try to create it + log.info(f"Keyspace {self.keyspace} does not exist, attempting to create it") + + # Try to create the keyspace + try: + replication_strategy = self.db_config.get("replication_strategy", "NetworkTopologyStrategy") + replication_factor = self.db_config.get("replication_factor", 3) + datacenter_name = self.db_config.get("datacenter_name", "datacenter1") + + # Build replication settings based on strategy + if replication_strategy == "NetworkTopologyStrategy": + replication_settings = f"{{'class': '{replication_strategy}', '{datacenter_name}': {replication_factor}}}" + else: + replication_settings = f"{{'class': '{replication_strategy}', 'replication_factor': {replication_factor}}}" + + cql = f""" + CREATE KEYSPACE IF NOT EXISTS {self.keyspace} + WITH REPLICATION = {replication_settings} + """ + print(cql) + session.execute(cql) + session.set_keyspace(self.keyspace) + log.info(f"Created and using keyspace: {self.keyspace} with replication strategy: {replication_strategy}, datacenter: {datacenter_name}, factor: {replication_factor}") + except Exception as e: + log.error(f"Failed to create keyspace {self.keyspace}: {e}") + # Try to use it anyway in case it was created by another process + try: + session.set_keyspace(self.keyspace) + log.info(f"Using keyspace: {self.keyspace}") + except Exception as e2: + log.error(f"Failed to use keyspace {self.keyspace}: {e2}") + raise + + def _create_table(self, session, dim: int): + """Create table with vector column. 
+ + Args: + session: Cassandra session + dim: Vector dimension + """ + cql = f""" + CREATE TABLE IF NOT EXISTS {self.keyspace}.{self.table_name} ( + {self.id_field} bigint PRIMARY KEY, + {self.vector_field} VECTOR + ) + """ + session.execute(cql) + log.info(f"Created table {self.keyspace}.{self.table_name} with vector dimension {dim}") + + def _create_index(self, session): + """Create SAI vector index for optimized vector search. + + Args: + session: Cassandra session + """ + index_name = f"{self.table_name}_vector_idx" + index_params = self.case_config.index_param() + similarity_function = index_params["similarity_function"] + + # Drop existing index if present + cql_drop = f"DROP INDEX IF EXISTS {self.keyspace}.{index_name}" + session.execute(cql_drop) + log.info(f"Dropped existing index {index_name} if present") + + # Create SAI vector index + cql = f""" + CREATE CUSTOM INDEX {index_name} + ON {self.keyspace}.{self.table_name} ({self.vector_field}) + USING 'StorageAttachedIndex' + WITH OPTIONS = {{'similarity_function': '{similarity_function}'}} + """ + session.execute(cql) + log.info(f"Created vector index {index_name} with similarity function {similarity_function}") + + @contextmanager + def init(self): + """Initialize Cassandra client and cleanup when done. + + Yields control to execute operations within the context. + """ + try: + log.debug("Initializing Cassandra connection") + self.cluster, self.session = self._create_cluster_and_session() + self.session.set_keyspace(self.keyspace) + log.debug(f"Successfully connected to keyspace: {self.keyspace}") + yield + except Exception as e: + log.error(f"Failed to initialize Cassandra connection: {e}") + raise + finally: + if self.cluster is not None: + log.debug("Shutting down Cassandra connection") + self.cluster.shutdown() + self.cluster = None + self.session = None + + def need_normalize_cosine(self) -> bool: + """Whether database requires normalized vectors for cosine similarity. 
+ + Cassandra handles cosine normalization internally. + + Returns: + False - Cassandra handles normalization + """ + return False + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs, + ) -> tuple[int, Exception | None]: + """Insert embeddings using batch statements. + + Args: + embeddings: List of vector embeddings + metadata: List of IDs for each embedding + **kwargs: Additional parameters (unused) + + Returns: + Tuple of (count of inserted records, exception if any) + """ + if self.session is None: + log.error("Cannot insert: session is None. Make sure insert is called within init() context manager.") + return 0, RuntimeError("Session not initialized") + + try: + # Cassandra batch statements have size limits, so batch in chunks + batch_size = self.case_config.batch_size + total_inserted = 0 + + insert_cql = f""" + INSERT INTO {self.keyspace}.{self.table_name} ({self.id_field}, {self.vector_field}) + VALUES (?, ?) + """ + prepared = self.session.prepare(insert_cql) + + for i in range(0, len(embeddings), batch_size): + batch = BatchStatement(batch_type=BatchType.UNLOGGED) + end_idx = min(i + batch_size, len(embeddings)) + + for id_, embedding in zip(metadata[i:end_idx], embeddings[i:end_idx], strict=False): + batch.add(prepared, (id_, embedding)) + + self.session.execute(batch) + total_inserted += (end_idx - i) + + if (i // batch_size) % 10 == 0 and i > 0: + log.debug(f"Inserted {total_inserted} embeddings so far...") + + log.info(f"Successfully inserted {total_inserted} embeddings") + return total_inserted, None + except Exception as e: + log.error(f"Error inserting embeddings: {e}") + return total_inserted if 'total_inserted' in locals() else 0, e + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + **kwargs, + ) -> list[int]: + """Search for similar vectors using ANN (Approximate Nearest Neighbor). 
+ + Args: + query: Query vector + k: Number of results to return + filters: Optional filters (not implemented for Cassandra) + **kwargs: Additional parameters (unused) + + Returns: + List of IDs ordered by similarity (most similar first) + """ + if self.session is None: + log.error("Cannot search: session is None. Make sure search is called within init() context manager.") + raise RuntimeError("Session not initialized. Call search within init() context manager.") + + try: + # Cassandra uses ANN OF for vector similarity search + # The similarity function is determined by the index + cql = f""" + SELECT {self.id_field} + FROM {self.keyspace}.{self.table_name} + ORDER BY {self.vector_field} ANN OF %s + LIMIT %s + """ + + results = self.session.execute(cql, (query, k)) + result_list = [row[0] for row in results] + log.debug(f"Search returned {len(result_list)} results") + return result_list + except Exception as e: + log.error(f"Search query failed: {e}. This usually indicates the vector index hasn't been created. Query: {cql[:100]}") + raise + + def optimize(self, data_size: int | None = None) -> None: + """Optimize operation - no action needed since index is created during initialization. + + The index is now created immediately after table creation in __init__, + before any data is inserted. This method is kept for API compatibility. + + Args: + data_size: Size of data (unused, kept for API compatibility) + """ + log.info("Index already created during initialization - no optimization needed") + pass + + def ready_to_load(self) -> None: + """Prepare for data loading. + + Cassandra is always ready to load data. 
+ """ + pass diff --git a/vectordb_bench/backend/clients/cassandra/cli.py b/vectordb_bench/backend/clients/cassandra/cli.py new file mode 100644 index 000000000..99d1303ec --- /dev/null +++ b/vectordb_bench/backend/clients/cassandra/cli.py @@ -0,0 +1,84 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.backend.clients import DB +from vectordb_bench.cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) + +DBTYPE = DB.Cassandra + + +class CassandraTypeDict(CommonTypedDict): + """CLI parameters for Cassandra vector database.""" + # Connection parameters for regular Cassandra + host: Annotated[ + str | None, + click.option("--host", type=str, help="Cassandra host (for regular Cassandra)", required=False), + ] + port: Annotated[ + int, + click.option("--port", type=int, help="Cassandra port", default=9042), + ] + + # Connection parameter for Astra DB + secure_connect_bundle: Annotated[ + str | None, + click.option( + "--secure-connect-bundle", + type=str, + help="Path to Secure Connect Bundle zip file (for Astra DB)", + required=False, + ), + ] + + # Authentication parameters + username: Annotated[ + str | None, + click.option("--username", type=str, help="Cassandra username", required=False), + ] + password: Annotated[ + str | None, + click.option("--password", type=str, help="Cassandra password", required=False), + ] + token: Annotated[ + str | None, + click.option("--token", type=str, help="Astra DB token", required=False), + ] + + # Keyspace parameter + keyspace: Annotated[ + str, + click.option("--keyspace", type=str, help="Cassandra keyspace", default="vdb_bench"), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(CassandraTypeDict) +def Cassandra(**parameters: Unpack[CassandraTypeDict]): + """Run VectorDB benchmark with Cassandra. 
+ + Supports both regular Cassandra (use --host and --port) and + DataStax Astra DB (use --secure-connect-bundle). + """ + from .config import CassandraConfig, CassandraIndexConfig + + run( + db=DBTYPE, + db_config=CassandraConfig( + host=parameters["host"], + port=parameters["port"], + secure_connect_bundle=parameters["secure_connect_bundle"], + username=parameters["username"], + password=SecretStr(parameters["password"]) if parameters["password"] else None, + token=SecretStr(parameters["token"]) if parameters["token"] else None, + keyspace=parameters["keyspace"], + ), + db_case_config=CassandraIndexConfig(), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/cassandra/config.py b/vectordb_bench/backend/clients/cassandra/config.py new file mode 100644 index 000000000..07a05801c --- /dev/null +++ b/vectordb_bench/backend/clients/cassandra/config.py @@ -0,0 +1,127 @@ +import os +from pathlib import Path + +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, MetricType + + +class CassandraConfig(DBConfig, BaseModel): + """Configuration for Cassandra vector database connections. + + Supports two connection modes: + 1. Regular Cassandra: Use host and port parameters + 2. 
DataStax Astra DB: Use secure_connect_bundle parameter + """ + # Regular Cassandra connection parameters + host: str | None = None + port: int = 9042 + + # DataStax Astra DB connection (mutually exclusive with host/port) + secure_connect_bundle: str | None = None + + # Authentication + username: str | None = None + password: SecretStr | None = None + token: SecretStr | None = None # For Astra DB token authentication + + # Keyspace + keyspace: str = "vdb_bench" + + # Table name + table_name: str | None = None # Custom table name (defaults to collection_name if not specified) + + # Replication settings + replication_strategy: str = "NetworkTopologyStrategy" # SimpleStrategy or NetworkTopologyStrategy + replication_factor: int = 3 # Replication factor (use 1 for single-node) + datacenter_name: str = "datacenter1" # Datacenter name for NetworkTopologyStrategy + + def to_dict(self) -> dict: + """Convert configuration to dictionary for cassandra-driver. + + Returns connection parameters formatted for Cluster initialization. + """ + config = {} + + if self.secure_connect_bundle: + # Resolve relative paths to absolute paths + bundle_path = self.secure_connect_bundle + if not os.path.isabs(bundle_path): + # Convert relative path to absolute + bundle_path = os.path.abspath(bundle_path) + + # Verify the bundle file exists + if not os.path.exists(bundle_path): + raise FileNotFoundError( + f"Secure connect bundle not found: {bundle_path}. 
" + f"Original path: {self.secure_connect_bundle}" + ) + + # Astra DB mode with Secure Connect Bundle + config["cloud"] = { + "secure_connect_bundle": bundle_path + } + # Astra DB uses token-based auth or username/password + if self.token: + config["auth_provider_token"] = self.token.get_secret_value() + elif self.username and self.password: + config["auth_provider_username"] = self.username + config["auth_provider_password"] = self.password.get_secret_value() + else: + # Regular Cassandra mode + config["contact_points"] = [self.host] if self.host else ["localhost"] + config["port"] = self.port + if self.username and self.password: + config["auth_provider_username"] = self.username + config["auth_provider_password"] = self.password.get_secret_value() + + config["keyspace"] = self.keyspace + config["replication_strategy"] = self.replication_strategy + config["replication_factor"] = self.replication_factor + config["datacenter_name"] = self.datacenter_name + return config + + +class CassandraIndexConfig(BaseModel, DBCaseConfig): + """Index configuration for Cassandra vector search. + + Cassandra 5.0+ uses Storage-Attached Indexes (SAI) for vector search + with support for multiple similarity functions. + """ + metric_type: MetricType = MetricType.COSINE + batch_size: int = 1000 # Batch size for insert operations (default: 1000) + + def parse_metric(self) -> str: + """Map VectorDBBench metric types to Cassandra similarity functions. + + Returns: + Cassandra similarity function name: EUCLIDEAN, DOT_PRODUCT, or COSINE + """ + if self.metric_type == MetricType.L2: + return "EUCLIDEAN" + if self.metric_type == MetricType.IP: + return "DOT_PRODUCT" + if self.metric_type == MetricType.COSINE: + return "COSINE" + raise ValueError(f"Unsupported metric type: {self.metric_type}") + + def index_param(self) -> dict: + """Return parameters for creating the SAI vector index. 
+ + Returns: + Dictionary with similarity_function and index_type + """ + return { + "similarity_function": self.parse_metric(), + "index_type": "SAI" # Storage-Attached Index + } + + def search_param(self) -> dict: + """Return parameters for vector search queries. + + Returns: + Dictionary with metric for search operations + """ + return { + "metric": self.parse_metric() + } diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 76e9534f9..c555b494a 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,6 +1,7 @@ from ..backend.clients.alisql.cli import AliSQLHNSW from ..backend.clients.alloydb.cli import AlloyDBScaNN from ..backend.clients.aws_opensearch.cli import AWSOpenSearch +from ..backend.clients.cassandra.cli import Cassandra from ..backend.clients.chroma.cli import Chroma from ..backend.clients.clickhouse.cli import Clickhouse from ..backend.clients.cockroachdb.cli import CockroachDB as CockroachDBCli @@ -55,6 +56,7 @@ cli.add_command(MariaDBHNSW) cli.add_command(TiDB) cli.add_command(CockroachDBCli) +cli.add_command(Cassandra) cli.add_command(Clickhouse) cli.add_command(Vespa) cli.add_command(LanceDB) diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index bce4561fd..48db60dfc 100644 --- a/vectordb_bench/frontend/config/styles.py +++ b/vectordb_bench/frontend/config/styles.py @@ -71,6 +71,7 @@ def getPatternShape(i): DB.Doris: "https://doris.apache.org/images/logo.svg", DB.TurboPuffer: "https://turbopuffer.com/logo2.png", DB.CockroachDB: "https://raw.githubusercontent.com/cockroachdb/cockroach/master/docs/media/cockroach_db.png", + DB.Cassandra: "https://upload.wikimedia.org/wikipedia/commons/1/1e/Apache-cassandra-icon.png", } # RedisCloud color: #0D6EFD From ea4b40079e4c68191198640cf37f58398a7f659e Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 3 Mar 2026 13:07:22 -0600 Subject: [PATCH 3/4] - --- 
.../backend/clients/oss_opensearch/cli.py | 29 ++++++++++ .../backend/clients/oss_opensearch/config.py | 9 ++- .../clients/oss_opensearch/oss_opensearch.py | 56 +++++++++++++++++-- vectordb_bench/backend/runner/mp_runner.py | 8 +++ vectordb_bench/backend/task_runner.py | 8 +++ 5 files changed, 102 insertions(+), 8 deletions(-) diff --git a/vectordb_bench/backend/clients/oss_opensearch/cli.py b/vectordb_bench/backend/clients/oss_opensearch/cli.py index 0c4b694ea..844e3f399 100644 --- a/vectordb_bench/backend/clients/oss_opensearch/cli.py +++ b/vectordb_bench/backend/clients/oss_opensearch/cli.py @@ -22,6 +22,10 @@ class OSSOpenSearchTypedDict(TypedDict): port: Annotated[int, click.option("--port", type=int, default=80, help="Db Port")] user: Annotated[str, click.option("--user", type=str, help="Db User")] password: Annotated[str, click.option("--password", type=str, help="Db password")] + use_ssl: Annotated[ + bool | None, + click.option("--use-ssl", type=bool, default=None, help="Use SSL (defaults to True when port=443)", required=False), + ] number_of_shards: Annotated[ int, click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1), @@ -140,6 +144,26 @@ class OSSOpenSearchTypedDict(TypedDict): ), ] + number_of_indexing_clients: Annotated[ + int, + click.option( + "--number-of-indexing-clients", + type=int, + help="Number of concurrent clients for data insertion", + default=1, + ), + ] + + use_local_preference: Annotated[ + bool, + click.option( + "--use-local-preference", + type=bool, + help="Use _only_local search preference for single-shard indices (disable for managed/cloud deployments)", + default=True, + ), + ] + class OSSOpenSearchHNSWTypedDict(CommonTypedDict, OSSOpenSearchTypedDict, HNSWFlavor1): ... 
@@ -156,6 +180,8 @@ def OSSOpenSearch(**parameters: Unpack[OSSOpenSearchHNSWTypedDict]): port=parameters["port"], user=parameters["user"], password=SecretStr(parameters["password"]), + use_ssl=parameters.get("use_ssl"), + db_label=parameters["db_label"], ), db_case_config=OSSOpenSearchIndexConfig( number_of_shards=parameters["number_of_shards"], @@ -174,6 +200,9 @@ def OSSOpenSearch(**parameters: Unpack[OSSOpenSearchHNSWTypedDict]): quantization_type=OSSOpenSearchQuantization(parameters["quantization_type"]), confidence_interval=parameters["confidence_interval"], clip=parameters["clip"], + metric_type_name=parameters["metric_type"], + number_of_indexing_clients=parameters["number_of_indexing_clients"], + use_local_preference=parameters["use_local_preference"], ), **parameters, ) diff --git a/vectordb_bench/backend/clients/oss_opensearch/config.py b/vectordb_bench/backend/clients/oss_opensearch/config.py index 83fed3d58..7af606939 100644 --- a/vectordb_bench/backend/clients/oss_opensearch/config.py +++ b/vectordb_bench/backend/clients/oss_opensearch/config.py @@ -13,9 +13,10 @@ class OSSOpenSearchConfig(DBConfig, BaseModel): port: int = 80 user: str | None = None password: SecretStr | None = None + use_ssl: bool | None = None def to_dict(self) -> dict: - use_ssl = self.port == 443 + use_ssl = self.use_ssl if self.use_ssl is not None else (self.port == 443) http_auth = ( (self.user, self.password.get_secret_value()) if self.user is not None and self.password is not None and len(self.user) != 0 and len(self.password) != 0 @@ -110,6 +111,7 @@ class OSSOpenSearchIndexConfig(BaseModel, DBCaseConfig): on_disk: bool = False compression_level: str = CompressionLevel.LEVEL_32X oversample_factor: float = 1.0 + use_local_preference: bool = True @validator("quantization_type", pre=True, always=True) def validate_quantization_type(cls, value: any): @@ -160,6 +162,7 @@ def __eq__(self, obj: any): and self.on_disk == obj.on_disk and self.compression_level == obj.compression_level 
and self.oversample_factor == obj.oversample_factor + and self.use_local_preference == obj.use_local_preference ) def __hash__(self) -> int: @@ -181,12 +184,14 @@ def __hash__(self) -> int: self.on_disk, self.compression_level, self.oversample_factor, + self.use_local_preference, ) ) def parse_metric(self) -> str: log.info(f"User specified metric_type: {self.metric_type_name}") - self.metric_type = MetricType[self.metric_type_name.upper()] + if self.metric_type_name is not None: + self.metric_type = MetricType[self.metric_type_name.upper()] if self.metric_type == MetricType.IP: return "innerproduct" if self.metric_type == MetricType.COSINE: diff --git a/vectordb_bench/backend/clients/oss_opensearch/oss_opensearch.py b/vectordb_bench/backend/clients/oss_opensearch/oss_opensearch.py index f71850a17..910aebf83 100644 --- a/vectordb_bench/backend/clients/oss_opensearch/oss_opensearch.py +++ b/vectordb_bench/backend/clients/oss_opensearch/oss_opensearch.py @@ -5,6 +5,7 @@ from typing import Any, Final from opensearchpy import OpenSearch +from opensearchpy.exceptions import TransportError from packaging.version import Version from packaging.version import parse as parse_version @@ -168,7 +169,7 @@ def build_search_kwargs( "body": body, "size": k, "_source": False, - "preference": "_only_local" if self.case_config.number_of_shards == 1 else None, + "preference": "_only_local" if (self.case_config.number_of_shards == 1 and self.case_config.use_local_preference) else None, "routing": routing_key, } @@ -448,7 +449,7 @@ def _insert_with_multiple_clients( for i in range(0, len(embeddings_list), chunk_size): end = min(i + chunk_size, len(embeddings_list)) - chunks.append((embeddings_list[i:end], metadata[i:end], labels_data[i:end])) + chunks.append((embeddings_list[i:end], metadata[i:end], labels_data[i:end] if labels_data is not None else None)) clients = [OpenSearch(**self.db_config) for _ in range(min(num_clients, len(chunks)))] log.info(f"OSS_OpenSearch using {len(clients)} 
parallel clients for data insertion") @@ -485,7 +486,7 @@ def insert_chunk(client_idx: int, chunk_idx: int): time.sleep(10) return self._insert_with_single_client(embeddings, metadata, labels_data) - response = self.client.indices.stats(self.index_name) + response = self.client.indices.stats(index=self.index_name) log.info( f"""Total document count in index after parallel insertion: {response['_all']['primaries']['indexing']['index_total']}""", @@ -516,6 +517,25 @@ def _update_ef_search_before_search(self, client: OpenSearch): except Exception as e: log.warning(f"Failed to update ef_search parameter before search: {e}") + def _build_curl_search(self, body: dict) -> str: + """Build a curl command to reproduce a search request.""" + import json + host_entry = self.db_config.get("hosts", [{}])[0] + host = host_entry.get("host", "localhost") + port = host_entry.get("port", 9200) + use_ssl = self.db_config.get("use_ssl", False) + scheme = "https" if use_ssl else "http" + url = f"{scheme}://{host}:{port}/{self.index_name}/_search" + + auth = self.db_config.get("http_auth", ()) + auth_flag = f" -u '{auth[0]}:{auth[1]}'" if len(auth) == 2 else "" + + body_json = json.dumps(body, indent=2) + return ( + f"curl -s{auth_flag} -H 'Content-Type: application/json' " + f"'{url}' -d '\n{body_json}'" + ) + def search_embedding( self, query: list[float], @@ -542,11 +562,15 @@ def search_embedding( search_kwargs = search_query_builder.build_search_kwargs( self.index_name, body, k, self.id_col_name, self.routing_key ) + if log.isEnabledFor(logging.DEBUG): + import copy + debug_kwargs = copy.deepcopy(search_kwargs) + with suppress(Exception): + debug_kwargs["body"]["query"]["knn"][self.vector_col_name]["vector"] = "[...]" + log.debug(f"Search kwargs (index={self.index_name}, k={k}, id_col={self.id_col_name}, routing={self.routing_key}): {debug_kwargs}") response = self.client.search(**search_kwargs) - log.debug(f"Search took: {response['took']}") - log.debug(f"Search shards: 
{response['_shards']}") - log.debug(f"Search hits total: {response['hits']['total']}") + log.debug(f"Search response: {response}") try: if self.id_col_name == "_id": # Get _id directly from hit metadata @@ -565,6 +589,26 @@ def search_embedding( return [] else: return result_ids + except TransportError as e: + detail = e.info if isinstance(e.info, dict) else {} + root_causes = detail.get("error", {}).get("root_cause", []) + reason = detail.get("error", {}).get("reason") or e.error + log.error( + f"Failed to search: {self.index_name} " + f"status={e.status_code} error={reason!r} " + f"root_causes={root_causes!r}" + ) + log.error(f"Failed search full response: {e.info!r}") + import copy + redacted_body = copy.deepcopy(body) + with suppress(Exception): + redacted_body["query"]["knn"][self.vector_col_name]["vector"] = "[...]" + log.error(f"Failed search body was: {redacted_body}") + log.error(f"Reproduce with (vector redacted):\n{self._build_curl_search(redacted_body)}") + log.error(f"Reproduce with (full vector):\n{self._build_curl_search(body)}") + raise OpenSearchError( + f"Search failed with status {e.status_code}: {reason}" + ) from e except Exception as e: log.warning(f"Failed to search: {self.index_name} error: {e!s}") raise e from None diff --git a/vectordb_bench/backend/runner/mp_runner.py b/vectordb_bench/backend/runner/mp_runner.py index 9133e407a..1749c1ab9 100644 --- a/vectordb_bench/backend/runner/mp_runner.py +++ b/vectordb_bench/backend/runner/mp_runner.py @@ -73,6 +73,7 @@ def search( start_time = time.perf_counter() count = 0 latencies = [] + first_error = None while time.perf_counter() < start_time + self.duration: s = time.perf_counter() try: @@ -81,6 +82,8 @@ def search( latencies.append(time.perf_counter() - s) except Exception as e: log.warning(f"VectorDB search_embedding error: {e}") + if first_error is None: + first_error = e # loop through the test data idx = idx + 1 if idx < num - 1 else 0 @@ -97,6 +100,9 @@ def search( 
f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}" ) + if count == 0 and first_error is not None: + raise first_error + return (count, total_dur, latencies) @staticmethod @@ -132,6 +138,8 @@ def _run_all_concurrencies_mem_efficient(self): start = time.perf_counter() all_count = sum([r.result()[0] for r in future_iter]) latencies = sum([r.result()[2] for r in future_iter], start=[]) + if not latencies: + raise RuntimeError(f"No successful searches at concurrency {conc}") latency_p99 = np.percentile(latencies, 99) latency_p95 = np.percentile(latencies, 95) latency_avg = np.mean(latencies) diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 8224a0415..fe59fb14a 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -122,6 +122,13 @@ def init_db(self, drop_old: bool = True) -> None: if "collection_name" in db_config_dict and not collection_name: collection_name = db_config_dict.pop("collection_name") + # For OSSOpenSearch, use db_label as index_name if set + index_name = None + if self.config.db == DB.OSSOpenSearch: + db_label = self.config.db_config.db_label + if db_label: + index_name = re.sub(r"[^a-z0-9_\-]+", "_", db_label.lower()).strip("_-") + self.db = db_cls( dim=self.ca.dataset.data.dim, db_config=db_config_dict, @@ -129,6 +136,7 @@ def init_db(self, drop_old: bool = True) -> None: drop_old=drop_old, with_scalar_labels=self.ca.with_scalar_labels, **({"collection_name": collection_name} if collection_name else {}), + **({"index_name": index_name} if index_name else {}), ) def _pre_run(self, drop_old: bool = True): From 968b8fa74d9ff758f8e7ef7e633718274e9bfd1b Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 4 Mar 2026 15:24:11 -0600 Subject: [PATCH 4/4] - --- .../backend/clients/astradb/astradb.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/vectordb_bench/backend/clients/astradb/astradb.py 
b/vectordb_bench/backend/clients/astradb/astradb.py index df7eb0049..a912c23b3 100644 --- a/vectordb_bench/backend/clients/astradb/astradb.py +++ b/vectordb_bench/backend/clients/astradb/astradb.py @@ -6,8 +6,11 @@ from astrapy.constants import VectorMetric from astrapy.info import CollectionDefinition +from astrapy.exceptions import CollectionInsertManyException + from ..api import VectorDB from .config import AstraDBIndexConfig +from vectordb_bench.backend.filter import FilterOp log = logging.getLogger(__name__) @@ -17,6 +20,11 @@ class AstraDBError(Exception): class AstraDB(VectorDB): + supported_filter_types: list[FilterOp] = [ + FilterOp.NonFilter, + FilterOp.NumGE, + ] + def __init__( self, dim: int, @@ -125,6 +133,21 @@ def insert_embeddings( try: result = self.collection.insert_many(documents, ordered=False) return len(result.inserted_ids), None + except CollectionInsertManyException as e: + # Check if all failures are due to already-existing documents + all_duplicate = all( + any( + d.error_code == "DOCUMENT_ALREADY_EXISTS" + for d in getattr(exc, "error_descriptors", []) + ) + for exc in e.exceptions + ) + if all_duplicate or not e.exceptions: + skipped = len(documents) - len(e.inserted_ids) + log.warning(f"Skipping {skipped} already-existing document(s), continuing load") + return len(documents), None + log.warning(f"InsertMany partial failure (non-duplicate errors): {e}") + return 0, e except Exception as e: log.exception("Error inserting embeddings") return 0, e