Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions integration/test_collection_diversity_hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Integration tests for hybrid search + MMR diversity selection.

``DiversitySelection`` passed inside ``HybridVector.near_vector`` /
``HybridVector.near_text`` is applied by the server as a post-fusion MMR pass
(Weaviate >= 1.39.0). These tests assert that ``balance=0`` (pure diversity)
produces a different ordering than ``balance=1`` (pure relevance), and that
``mmr.limit`` caps the result count.

The equivalent ``near_vector`` behaviour is covered in
``test_collection_diversity.py``.
"""

import pytest

from integration.conftest import CollectionFactory
from weaviate.classes.query import Diversity, HybridVector
from weaviate.collections.classes.config import Configure, DataType, Property
from weaviate.collections.classes.data import DataObject

MIN_VERSION = (1, 39, 0)


def _skip_if_unsupported(collection) -> None:
if collection._connection._weaviate_version.is_lower_than(*MIN_VERSION):
pytest.skip("Hybrid diversity selection requires Weaviate >= 1.39.0")


def _create_clustered_collection(collection_factory: CollectionFactory):
"""Create a collection with 3 tight clusters (a, b, c) of vectors in 3D."""
collection = collection_factory(
properties=[Property(name="text", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.none(),
)
_skip_if_unsupported(collection)
collection.data.insert_many(
[
DataObject(properties={"text": "a1"}, vector=[1.0, 0.0, 0.0]),
DataObject(properties={"text": "a2"}, vector=[0.95, 0.05, 0.0]),
DataObject(properties={"text": "a3"}, vector=[0.9, 0.1, 0.0]),
DataObject(properties={"text": "b1"}, vector=[0.0, 1.0, 0.0]),
DataObject(properties={"text": "b2"}, vector=[0.05, 0.95, 0.0]),
DataObject(properties={"text": "c1"}, vector=[0.0, 0.0, 1.0]),
]
)
return collection


def _create_large_collection(collection_factory: CollectionFactory, n_items: int = 50):
"""Create a collection with enough items (>25) that a small mmr.limit is distinguishable from the server's default limit."""
collection = collection_factory(
properties=[Property(name="text", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.none(),
)
_skip_if_unsupported(collection)
collection.data.insert_many(
[
DataObject(properties={"text": f"t{i}"}, vector=[1.0 - 0.001 * i, 0.0, 0.0])
for i in range(n_items)
]
)
return collection


def test_hybrid_near_vector_balance_0_differs_from_balance_1(
collection_factory: CollectionFactory,
) -> None:
"""Hybrid near-vector: balance=0 (diversity) must reorder vs balance=1 (relevance)."""
collection = _create_clustered_collection(collection_factory)
balance_0 = collection.query.hybrid(
query=None,
vector=HybridVector.near_vector(
vector=[1.0, 0.0, 0.0],
diversity_selection=Diversity.mmr(limit=3, balance=0.0),
),
limit=3,
).objects
balance_1 = collection.query.hybrid(
query=None,
vector=HybridVector.near_vector(
vector=[1.0, 0.0, 0.0],
diversity_selection=Diversity.mmr(limit=3, balance=1.0),
),
limit=3,
).objects
assert [o.uuid for o in balance_0] != [o.uuid for o in balance_1]


def test_hybrid_near_vector_balance_1_matches_baseline(
collection_factory: CollectionFactory,
) -> None:
"""Hybrid near-vector with MMR balance=1 (pure relevance) matches the plain baseline."""
collection = _create_clustered_collection(collection_factory)
baseline = collection.query.hybrid(
query=None,
vector=HybridVector.near_vector(vector=[1.0, 0.0, 0.0]),
limit=3,
).objects
mmr_balance_1 = collection.query.hybrid(
query=None,
vector=HybridVector.near_vector(
vector=[1.0, 0.0, 0.0],
diversity_selection=Diversity.mmr(limit=3, balance=1.0),
),
limit=3,
).objects
assert [o.uuid for o in baseline] == [o.uuid for o in mmr_balance_1]


def test_hybrid_alpha_1_balance_0_differs_from_balance_1(
collection_factory: CollectionFactory,
) -> None:
"""Hybrid with explicit alpha=1.0 (pure vector) applies MMR like near_vector."""
collection = _create_clustered_collection(collection_factory)
balance_0 = collection.query.hybrid(
query="irrelevant",
alpha=1.0,
vector=HybridVector.near_vector(
vector=[1.0, 0.0, 0.0],
diversity_selection=Diversity.mmr(limit=3, balance=0.0),
),
limit=3,
).objects
balance_1 = collection.query.hybrid(
query="irrelevant",
alpha=1.0,
vector=HybridVector.near_vector(
vector=[1.0, 0.0, 0.0],
diversity_selection=Diversity.mmr(limit=3, balance=1.0),
),
limit=3,
).objects
assert [o.uuid for o in balance_0] != [o.uuid for o in balance_1]


def test_hybrid_respects_mmr_limit(
collection_factory: CollectionFactory,
) -> None:
"""Hybrid respects mmr.limit as the result-count cap when no outer limit is set."""
mmr_limit = 5
collection = _create_large_collection(collection_factory, n_items=50)

result = collection.query.hybrid(
query=None,
vector=HybridVector.near_vector(
vector=[1.0, 0.0, 0.0],
diversity_selection=Diversity.mmr(limit=mmr_limit, balance=0.5),
),
).objects
assert len(result) == mmr_limit
63 changes: 63 additions & 0 deletions test/collection/test_hybrid_diversity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Unit tests: hybrid search wires diversity_selection into the gRPC request.

Hybrid diversity is a post-fusion, hybrid-level operation, so the
``HybridVector.near_vector`` / ``HybridVector.near_text`` ``diversity_selection``
argument must populate the top-level ``Hybrid.selection.mmr`` in the
SearchRequest proto (not the nested ``near_vector`` / ``near_text`` selection).
"""

from weaviate.collections.grpc.query import _QueryGRPC
from weaviate.classes.query import Diversity, HybridVector
from weaviate.util import _ServerVersion


def _builder() -> _QueryGRPC:
return _QueryGRPC(
weaviate_version=_ServerVersion(1, 39, 0),
name="Dummy",
tenant=None,
consistency_level=None,
validate_arguments=True,
uses_125_api=True,
uses_127_api=True,
)


def test_hybrid_near_vector_sets_top_level_selection() -> None:
req = _builder().hybrid(
query=None,
vector=HybridVector.near_vector(
vector=[1.0, 0.0, 0.0],
diversity_selection=Diversity.mmr(limit=7, balance=0.0),
),
limit=7,
)
# Canonical location: top-level Hybrid.selection, not the nested near_vector.
mmr = req.hybrid_search.selection.mmr
assert mmr.limit == 7
assert mmr.balance == 0.0
assert not req.hybrid_search.near_vector.HasField("selection")


def test_hybrid_near_text_sets_top_level_selection() -> None:
req = _builder().hybrid(
query=None,
vector=HybridVector.near_text(
query="cats",
diversity_selection=Diversity.mmr(limit=3, balance=0.5),
),
limit=3,
)
mmr = req.hybrid_search.selection.mmr
assert mmr.limit == 3
assert mmr.balance == 0.5
assert not req.hybrid_search.near_text.HasField("selection")


def test_hybrid_without_selection_leaves_it_unset() -> None:
req = _builder().hybrid(
query=None,
vector=HybridVector.near_vector(vector=[1.0, 0.0, 0.0]),
limit=5,
)
assert not req.hybrid_search.HasField("selection")
10 changes: 10 additions & 0 deletions weaviate/collections/classes/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ class _HybridNearBase(_WeaviateInput):

distance: Optional[float] = None
certainty: Optional[float] = None
diversity_selection: Optional[MMR] = None


class _HybridNearText(_HybridNearBase):
Expand All @@ -772,17 +773,20 @@ class _HybridNearVector: # can't be a Pydantic model because of validation issu
vector: NearVectorInputType
distance: Optional[float]
certainty: Optional[float]
diversity_selection: Optional[MMR]

def __init__(
self,
*,
vector: NearVectorInputType,
distance: Optional[float] = None,
certainty: Optional[float] = None,
diversity_selection: Optional[MMR] = None,
) -> None:
self.vector = vector
self.distance = distance
self.certainty = certainty
self.diversity_selection = diversity_selection


HybridVectorType = Union[NearVectorInputType, _HybridNearText, _HybridNearVector]
Expand Down Expand Up @@ -897,6 +901,7 @@ def near_text(
distance: Optional[float] = None,
move_to: Optional[Move] = None,
move_away: Optional[Move] = None,
diversity_selection: Optional[MMR] = None,
) -> _HybridNearText:
"""Define a near text search to be used within a hybrid query.

Expand All @@ -906,6 +911,7 @@ def near_text(
distance: The maximum distance to search. If not specified, the default distance specified by the server is used.
move_to: Define the concepts that should be moved towards in the vector space during the search.
move_away: Define the concepts that should be moved away from in the vector space during the search.
diversity_selection: Apply diversity selection (e.g. MMR) to the hybrid results. Requires Weaviate >= 1.39.0.

Returns:
A `_HybridNearText` object to be used in the `vector` parameter of the `query.hybrid` and `generate.hybrid` search methods.
Expand All @@ -916,6 +922,7 @@ def near_text(
certainty=certainty,
move_to=move_to,
move_away=move_away,
diversity_selection=diversity_selection,
)

@staticmethod
Expand All @@ -924,12 +931,14 @@ def near_vector(
*,
certainty: Optional[float] = None,
distance: Optional[float] = None,
diversity_selection: Optional[MMR] = None,
) -> _HybridNearVector:
"""Define a near vector search to be used within a hybrid query.

Args:
certainty: The minimum similarity score to return. If not specified, the default certainty specified by the server is used.
distance: The maximum distance to search. If not specified, the default distance specified by the server is used.
diversity_selection: Apply diversity selection (e.g. MMR) to the hybrid results. Requires Weaviate >= 1.39.0.

Returns:
A `_HybridNearVector` object to be used in the `vector` parameter of the `query.hybrid` and `generate.hybrid` search methods.
Expand All @@ -938,6 +947,7 @@ def near_vector(
vector=vector,
distance=distance,
certainty=certainty,
diversity_selection=diversity_selection,
)


Expand Down
10 changes: 10 additions & 0 deletions weaviate/collections/grpc/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,15 @@ def _parse_hybrid(

near_text, near_vector, vector_bytes, vectors = None, None, None, None

# Hybrid diversity selection is a post-fusion, hybrid-level operation, so
# it is carried on the top-level Hybrid.selection field rather than on the
# near_text / near_vector sub-query.
hybrid_selection = (
vector.diversity_selection
if isinstance(vector, (_HybridNearText, _HybridNearVector))
else None
)

if vector is None:
pass
elif isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], float):
Expand Down Expand Up @@ -739,6 +748,7 @@ def _parse_hybrid(
vector_bytes=vector_bytes,
vector_distance=distance,
vectors=vectors,
selection=self._diversity_selection_to_grpc(hybrid_selection),
bm25_search_operator=base_search_pb2.SearchOperatorOptions(
operator=bm25_operator.operator,
minimum_or_tokens_match=bm25_operator.minimum_should_match
Expand Down
Loading