From b75ba500e4ea9f4ab59b64ee9dc9c855f8dd54f2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 02:01:08 +0000 Subject: [PATCH 1/5] Initial plan From 1aa74e76268f388723f8ca87930e3f4caad59138 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 3 Feb 2026 19:04:49 +0100 Subject: [PATCH 2/5] Support for multi valued dense vector fields (through nested vectors and diversifying children query) (#4051) --- changelog/unreleased/SOLR-18074.yml | 8 + .../transform/ChildDocTransformer.java | 120 +++- .../apache/solr/schema/DenseVectorField.java | 5 - .../org/apache/solr/schema/IndexSchema.java | 1 + .../NestedUpdateProcessorFactory.java | 124 ++-- .../collection1/conf/schema-densevector.xml | 20 +- .../solr/schema/DenseVectorFieldTest.java | 14 +- .../join/BlockJoinMultiValuedVectorsTest.java | 330 +++++++++++ ...ockJoinNestedVectorsParentQParserTest.java | 556 ++++++++++++++++++ .../join/BlockJoinNestedVectorsTest.java | 254 ++++++++ .../pages/dense-vector-search.adoc | 228 ++++++- 11 files changed, 1597 insertions(+), 63 deletions(-) create mode 100644 changelog/unreleased/SOLR-18074.yml create mode 100644 solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java create mode 100644 solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java create mode 100644 solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java diff --git a/changelog/unreleased/SOLR-18074.yml b/changelog/unreleased/SOLR-18074.yml new file mode 100644 index 000000000000..69dba9c2966b --- /dev/null +++ b/changelog/unreleased/SOLR-18074.yml @@ -0,0 +1,8 @@ +# See https://github.com/apache/solr/blob/main/dev-docs/changelog.adoc +title: Introducing support for multi valued dense vector representation in documents through nested vectors +type: added # added, changed, fixed, deprecated, removed, dependency_update, security, other +authors: + - name: Alessandro Benedetti +links: + - name: SOLR-18074 + url: https://issues.apache.org/jira/browse/SOLR-18074 diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index d8a4f0842264..789314a905b7 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -24,9 +24,13 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -35,6 +39,7 @@ import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.BitSet; @@ -42,7 +47,9 @@ import org.apache.lucene.util.BytesRef; import org.apache.solr.common.SolrDocument; import org.apache.solr.common.SolrException; +import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; import org.apache.solr.search.BitsFilteredPostingsEnum; import org.apache.solr.search.DocSet; import org.apache.solr.search.SolrDocumentFetcher; @@ -138,6 +145,20 @@ public void transform(SolrDocument rootDoc, int rootDocId) { final Bits liveDocs = leafReaderContext.reader().getLiveDocs(); final int segBaseId = leafReaderContext.docBase; final int segRootId = rootDocId - segBaseId; + Set multiValuedFLoatVectorFields = + this.getMultiValuedVectorFields( + searcher.getSchema(), childReturnFields, VectorEncoding.FLOAT32); + Set multiValuedByteVectorFields = + this.getMultiValuedVectorFields( + searcher.getSchema(), childReturnFields, VectorEncoding.BYTE); + if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) > 0 + && (multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) + != childReturnFields.getExplicitlyRequestedFieldNames().size()) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "When using the Child transformer to flatten nested vectors, all 'fl' must be " + + "multivalued vector fields"); + } // can return be -1 and that's okay (happens for very first block) final int segPrevRootId; @@ -219,8 +240,21 @@ public void transform(SolrDocument rootDoc, int rootDocId) { if (isAncestor) { // if this path has pending child docs, add them. - addChildrenToParent( - doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending + if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) { + addFlatMultiValuedVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedFLoatVectorFields, + VectorEncoding.FLOAT32); + addFlatMultiValuedVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedByteVectorFields, + VectorEncoding.BYTE); + } else { + addChildrenToParent( + doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending + } } // get parent path @@ -248,7 +282,20 @@ public void transform(SolrDocument rootDoc, int rootDocId) { assert pendingParentPathsToChildren.keySet().size() == 1; // size == 1, so get the last remaining entry - addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); + if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) { + addFlatMultiValuedVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedFLoatVectorFields, + VectorEncoding.FLOAT32); + addFlatMultiValuedVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedByteVectorFields, + VectorEncoding.BYTE); + } else { + addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); + } } catch (IOException e) { // TODO DWS: reconsider this unusual error handling approach; shouldn't we rethrow? @@ -257,6 +304,25 @@ public void transform(SolrDocument rootDoc, int rootDocId) { } } + private Set getMultiValuedVectorFields( + IndexSchema schema, SolrReturnFields childReturnFields, VectorEncoding encoding) { + Set multiValuedVectorsFields = new HashSet<>(); + Set explicitlyRequestedFieldNames = + childReturnFields.getExplicitlyRequestedFieldNames(); + if (explicitlyRequestedFieldNames != null) { + for (String fieldName : explicitlyRequestedFieldNames) { + SchemaField sfield = schema.getFieldOrNull(fieldName); + if (sfield != null + && sfield.getType() instanceof DenseVectorField + && sfield.multiValued() + && ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) { + multiValuedVectorsFields.add(fieldName); + } + } + } + return multiValuedVectorsFields; + } + private static void addChildrenToParent( SolrDocument parent, Map> children) { for (Map.Entry> entry : children.entrySet()) { @@ -285,6 +351,54 @@ private static void addChildrenToParent( parent.setField(trimmedPath, children.get(0)); } + private void addFlatMultiValuedVectorsToParent( + SolrDocument parent, + Map> children, + Set multiValuedVectorFields, + VectorEncoding encoding) { + for (String multiValuedVectorField : multiValuedVectorFields) { + List solrDocuments = children.get(multiValuedVectorField); + List> multiValuedVectors = new ArrayList<>(solrDocuments.size()); + for (SolrDocument singleVector : solrDocuments) { + List extractedVectors; + switch (encoding) { + case FLOAT32: + extractedVectors = + this.extractFloatVector(singleVector.getFieldValues(multiValuedVectorField)); + break; + case BYTE: + extractedVectors = + this.extractByteVector(singleVector.getFieldValues(multiValuedVectorField)); + break; + default: + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, "Unsupported vector encoding: " + encoding); + } + multiValuedVectors.add(extractedVectors); + } + parent.setField(multiValuedVectorField, multiValuedVectors); + } + } + + private List extractFloatVector(Collection fieldValues) { + List vector = new ArrayList<>(fieldValues.size()); + for (Object fieldValue : fieldValues) { + StoredField storedVectorValue = (StoredField) fieldValue; + vector.add(storedVectorValue.numericValue()); + } + return vector; + } + + private List extractByteVector(Collection singleVector) { + StoredField vector = (StoredField) singleVector.iterator().next(); + BytesRef byteVector = vector.binaryValue(); + List extractedVector = new ArrayList<>(byteVector.length); + for (Byte element : byteVector.bytes) { + extractedVector.add(element.byteValue()); + } + return extractedVector; + } + private static String getLastPath(String path) { int lastIndexOfPathSepChar = path.lastIndexOf(PATH_SEP_CHAR); if (lastIndexOfPathSepChar == -1) { diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 773c1e6337d1..f331d391cdfe 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -188,11 +188,6 @@ protected boolean enableDocValuesByDefault() { @Override public void checkSchemaField(final SchemaField field) throws SolrException { super.checkSchemaField(field); - if (field.multiValued()) { - throw new SolrException( - SolrException.ErrorCode.SERVER_ERROR, - getClass().getSimpleName() + " fields can not be multiValued: " + field.getName()); - } if (field.hasDocValues()) { throw new SolrException( diff --git a/solr/core/src/java/org/apache/solr/schema/IndexSchema.java b/solr/core/src/java/org/apache/solr/schema/IndexSchema.java index ecaa9d24e7c2..a20c17867b1f 100644 --- a/solr/core/src/java/org/apache/solr/schema/IndexSchema.java +++ b/solr/core/src/java/org/apache/solr/schema/IndexSchema.java @@ -106,6 +106,7 @@ public class IndexSchema { public static final String NAME = "name"; public static final String NEST_PARENT_FIELD_NAME = "_nest_parent_"; public static final String NEST_PATH_FIELD_NAME = "_nest_path_"; + public static final String NESTED_VECTORS_PSEUDO_FIELD_NAME = "_nested_vectors_"; public static final String REQUIRED = "required"; public static final String SCHEMA = "schema"; public static final String SIMILARITY = "similarity"; diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index 243890b80fb3..14fb44009f9f 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -17,14 +17,20 @@ package org.apache.solr.update.processor; +import static org.apache.solr.schema.IndexSchema.NESTED_VECTORS_PSEUDO_FIELD_NAME; + import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.List; import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.SolrInputField; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; import org.apache.solr.update.AddUpdateCommand; /** @@ -63,6 +69,7 @@ private static class NestedUpdateProcessor extends UpdateRequestProcessor { private boolean storePath; private boolean storeParent; private String uniqueKeyFieldName; + private IndexSchema schema; NestedUpdateProcessor( SolrQueryRequest req, boolean storeParent, boolean storePath, UpdateRequestProcessor next) { @@ -70,6 +77,7 @@ private static class NestedUpdateProcessor extends UpdateRequestProcessor { this.storeParent = storeParent; this.storePath = storePath; this.uniqueKeyFieldName = req.getSchema().getUniqueKeyField().getName(); + this.schema = req.getSchema(); } @Override @@ -81,54 +89,98 @@ public void processAdd(AddUpdateCommand cmd) throws IOException { private boolean processDocChildren(SolrInputDocument doc, String fullPath) { boolean isNested = false; + List originalVectorFieldsToRemove = new ArrayList<>(); + ArrayList vectors = new ArrayList<>(); for (SolrInputField field : doc.values()) { + SchemaField sfield = schema.getFieldOrNull(field.getName()); int childNum = 0; boolean isSingleVal = !(field.getValue() instanceof Collection); - for (Object val : field) { - if (!(val instanceof SolrInputDocument)) { - // either all collection items are child docs or none are. - break; - } - final String fieldName = field.getName(); - - if (fieldName.contains(PATH_SEP_CHAR)) { - throw new SolrException( - SolrException.ErrorCode.BAD_REQUEST, - "Field name: '" - + fieldName - + "' contains: '" - + PATH_SEP_CHAR - + "' , which is reserved for the nested URP"); - } - final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); - SolrInputDocument cDoc = (SolrInputDocument) val; - if (!cDoc.containsKey(uniqueKeyFieldName)) { + boolean firstLevelChildren = fullPath == null; + if (firstLevelChildren && sfield != null && isMultiValuedVectorField(sfield)) { + for (Object vectorValue : field.getValues()) { + SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); + singleVectorNestedDoc.setField(field.getName(), vectorValue); + final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); - cDoc.setField( - uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum)); + singleVectorNestedDoc.setField( + uniqueKeyFieldName, generateChildUniqueId(parentDocId, field.getName(), sChildNum)); + + if (!isNested) { + isNested = true; + } + final String lastKeyPath = PATH_SEP_CHAR + field.getName() + NUM_SEP_CHAR + sChildNum; + final String childDocPath = firstLevelChildren ? lastKeyPath : fullPath + lastKeyPath; + if (storePath) { + setPathField(singleVectorNestedDoc, childDocPath); + } + if (storeParent) { + setParentKey(singleVectorNestedDoc, doc); + } + ++childNum; + vectors.add(singleVectorNestedDoc); } - if (!isNested) { - isNested = true; + originalVectorFieldsToRemove.add(field.getName()); + } else { + for (Object val : field) { + if (!(val instanceof SolrInputDocument cDoc)) { + // either all collection items are child docs or none are. + break; + } + final String fieldName = field.getName(); + + if (fieldName.contains(PATH_SEP_CHAR)) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "Field name: '" + + fieldName + + "' contains: '" + + PATH_SEP_CHAR + + "' , which is reserved for the nested URP"); + } + final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); + if (!cDoc.containsKey(uniqueKeyFieldName)) { + String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); + cDoc.setField( + uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum)); + } + if (!isNested) { + isNested = true; + } + final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; + // concat of all paths children.grandChild => /children#1/grandChild# + final String childDocPath = firstLevelChildren ? lastKeyPath : fullPath + lastKeyPath; + processChildDoc(cDoc, doc, childDocPath); + ++childNum; } - final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; - // concat of all paths children.grandChild => /children#1/grandChild# - final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; - processChildDoc(cDoc, doc, childDocPath); - ++childNum; } } + this.cleanOriginalVectorFields(doc, originalVectorFieldsToRemove); + if (vectors.size() > 0) { + doc.setField(NESTED_VECTORS_PSEUDO_FIELD_NAME, vectors); + } return isNested; } + private void cleanOriginalVectorFields( + SolrInputDocument doc, List originalVectorFieldsToRemove) { + for (String fieldName : originalVectorFieldsToRemove) { + doc.removeField(fieldName); + } + } + + private static boolean isMultiValuedVectorField(SchemaField sfield) { + return sfield.getType() instanceof DenseVectorField && sfield.multiValued(); + } + private void processChildDoc( - SolrInputDocument sdoc, SolrInputDocument parent, String fullPath) { + SolrInputDocument child, SolrInputDocument parent, String fullPath) { if (storePath) { - setPathField(sdoc, fullPath); + setPathField(child, fullPath); } if (storeParent) { - setParentKey(sdoc, parent); + setParentKey(child, parent); } - processDocChildren(sdoc, fullPath); + processDocChildren(child, fullPath); } private String generateChildUniqueId(String parentId, String childKey, String childNum) { @@ -136,12 +188,12 @@ private String generateChildUniqueId(String parentId, String childKey, String ch return parentId + PATH_SEP_CHAR + childKey + NUM_SEP_CHAR + childNum; } - private void setParentKey(SolrInputDocument sdoc, SolrInputDocument parent) { - sdoc.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); + private void setParentKey(SolrInputDocument child, SolrInputDocument parent) { + child.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); } - private void setPathField(SolrInputDocument sdoc, String fullPath) { - sdoc.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); + private void setPathField(SolrInputDocument child, String fullPath) { + child.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); } } } diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml index 42db078a6e20..fd7702ea3b9b 100644 --- a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml +++ b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml @@ -18,22 +18,36 @@ - - + + + - + + + + + + + + + + + + + + diff --git a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java index 4b7533985213..c777ff43fc5d 100644 --- a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java +++ b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java @@ -74,14 +74,6 @@ public void fieldDefinition_docValues_shouldThrowException() throws Exception { "DenseVectorField fields can not have docValues: vector"); } - @Test - public void fieldDefinition_multiValued_shouldThrowException() throws Exception { - assertConfigs( - "solrconfig-basic.xml", - "bad-schema-densevector-multivalued.xml", - "DenseVectorField fields can not be multiValued: vector"); - } - @Test public void fieldTypeDefinition_nullSimilarityDistance_shouldUseDefaultSimilarityEuclidean() throws Exception { @@ -699,7 +691,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithValuesOutsideBoundar assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[128, 6, 7, 8]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value '[128, 6, 7, 8]'")); assertThat( thrown.getCause().getCause().getMessage(), @@ -721,7 +713,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithValuesOutsideBoundar assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[1, -129, 7, 8]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value '[1, -129, 7, 8]'")); assertThat( thrown.getCause().getCause().getMessage(), is( @@ -750,7 +742,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithFloatValues() throws assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[14.3, 6.2, 7.2, 8.1]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value '[14.3, 6.2, 7.2, 8.1]'")); assertThat( thrown.getCause().getCause().getMessage(), diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java new file mode 100644 index 000000000000..dcf8d9f34a98 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.util.RandomNoReverseMergePolicyFactory; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TestRule; + +public class BlockJoinMultiValuedVectorsTest extends BlockJoinNestedVectorsParentQParserTest { + + protected static String VECTOR_FIELD = "vector_multivalued"; + protected static String VECTOR_BYTE_FIELD = "vector_byte_multivalued"; + + @ClassRule + public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); + + @BeforeClass + public static void beforeClass() throws Exception { + /* vectorDimension="4" similarityFunction="cosine" */ + initCore("solrconfig_codec.xml", "schema-densevector.xml"); + prepareIndex(); + } + + protected static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); + for (SolrInputDocument doc : docsToIndex) { + updateJ(jsonAdd(doc), null); + } + assertU(commit()); + } + + /** + * The documents in the index are 10 parents, with some parent level metadata and 30 nested + * documents (with vectors and children level metadata) Each parent document has 3 nested + * documents with vectors. + * + *

This allows to run knn queries both at parent/children level and using various pre-filters + * both for parent metadata and children. + * + * @return a list of documents to index + */ + protected static List prepareDocs() { + int totalParentDocuments = 10; + int totalNestedVectors = 30; + int perParentChildren = totalNestedVectors / totalParentDocuments; + + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + + List docs = new ArrayList<>(totalParentDocuments); + for (int i = 1; i < totalParentDocuments + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", i); + doc.setField("parent_b", true); + doc.setField("parent_s", abcdef[i % abcdef.length]); + List> floatVectors = new ArrayList<>(perParentChildren); + List> byteVectors = new ArrayList<>(perParentChildren); + // nested vector documents have a distance from the query vector inversely proportional to + // their id + for (int j = 0; j < perParentChildren; j++) { + floatVectors.add(outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + byteVectors.add(outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + totalNestedVectors--; // the higher the id of the nested document, lower the distance with + } + doc.setField(VECTOR_FIELD, floatVectors); + doc.setField(VECTOR_BYTE_FIELD, byteVectors); + + docs.add(doc); + } + + return docs; + } + + @Test + public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { + super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + VECTOR_FIELD); + } + + @Test + public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + VECTOR_FIELD + " topK=5}" + FLOAT_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='8/vector_multivalued#2']", + "//result/doc[2]/str[@name='id'][.='8/vector_multivalued#1']", + "//result/doc[3]/str[@name='id'][.='8/vector_multivalued#0']", + "//result/doc[4]/str[@name='id'][.='6/vector_multivalued#2']", + "//result/doc[5]/str[@name='id'][.='6/vector_multivalued#1']"); + } + + @Test + public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + VECTOR_BYTE_FIELD + " topK=5}" + BYTE_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='8/vector_byte_multivalued#2']", + "//result/doc[2]/str[@name='id'][.='8/vector_byte_multivalued#1']", + "//result/doc[3]/str[@name='id'][.='8/vector_byte_multivalued#0']", + "//result/doc[4]/str[@name='id'][.='6/vector_byte_multivalued#2']", + "//result/doc[5]/str[@name='id'][.='6/vector_byte_multivalued#1']"); + } + + @Test + public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent() { + super.parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void + parentRetrievalFloat_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_FIELD + ", [child fl=" + VECTOR_FIELD + " childFilter=$children.q]", + "children.q", + "{!knn f=" + VECTOR_FIELD + " topK=3 childrenOf=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='2.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='5.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='8.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']"); + } + + @Test + public void parentRetrievalFloat_ChildTransformer_shouldFlattenAndReturnAllChildren() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_FIELD + ", [child fl=" + VECTOR_FIELD + " ]", + "children.q", + "{!knn f=" + VECTOR_FIELD + " topK=3 childrenOf=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='4.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='3.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='2.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='7.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='6.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='5.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='10.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='9.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='8.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']"); + } + + @Test + public void + parentRetrievalByte_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + + VECTOR_BYTE_FIELD + + ", [child fl=" + + VECTOR_BYTE_FIELD + + " childFilter=$children.q]", + "children.q", + "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 childrenOf=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='2']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='5']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']"); + } + + @Test + public void parentRetrievalByte_ChildTransformer_shouldFlattenAndReturnAllChildren() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_BYTE_FIELD + ", [child fl=" + VECTOR_BYTE_FIELD + " ]", + "children.q", + "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 childrenOf=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='4']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='3']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='2']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='7']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='6']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='5']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='10']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='9']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']"); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java new file mode 100644 index 000000000000..8e374ba01c9f --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java @@ -0,0 +1,556 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.solr.SolrTestCaseJ4; + +public class BlockJoinNestedVectorsParentQParserTest extends SolrTestCaseJ4 { + protected static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); + protected static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); + + protected static String VECTORS_PSEUDOFIELD = "vectors"; + + /** + * Generate a resulting float vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value + */ + protected static List outDistanceFloat(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + /** + * Generate a resulting byte vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value + */ + protected static List outDistanceByte(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + String vectorField) { + assertQEx( + "When running a diversifying children KNN query, 'allParents' parameter is required", + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score", + "children.q", + "{!knn f=" + + vectorField + + " topK=3 parents.preFilter=$someParents}" + + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(a c)"), + 400); + } + + protected void childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren( + String vectorField) { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + vectorField + " topK=5}" + BYTE_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='82']", + "//result/doc[2]/str[@name='id'][.='81']", + "//result/doc[3]/str[@name='id'][.='80']", + "//result/doc[4]/str[@name='id'][.='62']", + "//result/doc[5]/str[@name='id'][.='61']"); + } + + protected void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent( + String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=" + vectorField + " topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='1']", + "//result/doc[1]/str[@name='id'][.='10']"); + } + + protected void parentRetrieval_knnChildren_shouldReturnKnnParents(String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); + } + + protected void parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents( + String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[3]/str[@name='id'][.='2']"); + } + + protected void + parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 preFilter=child_s:m parents.preFilter=$someParents childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//*[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='2']"); + } + + protected void parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren( + String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorField + + ",[child limit=2 fl=vector]", + "children.q", + "{!knn f=" + + vectorField + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='10.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='9.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='16.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='15.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='28.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='27.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']"); + } + + protected void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren( + String vectorByteField) { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorByteField + + ",[child limit=2 fl=" + + vectorByteField + + "]", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='10']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[1][.='9']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='13']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[1][.='12']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='28']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[1][.='27']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']"); + } + + protected void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( + String vectorByteField) { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorByteField + + ",[child fl=" + + vectorByteField + + " childFilter=$children.q]", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='11']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='26']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']"); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java new file mode 100644 index 000000000000..81303600e7da --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.util.RandomNoReverseMergePolicyFactory; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TestRule; + +public class BlockJoinNestedVectorsTest extends BlockJoinNestedVectorsParentQParserTest { + protected static String VECTOR_FIELD = "vector"; + protected static String VECTOR_BYTE_FIELD = "vector_byte_encoding"; + + @ClassRule + public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); + + @BeforeClass + public static void beforeClass() throws Exception { + /* vectorDimension="4" similarityFunction="cosine" */ + initCore("solrconfig_codec.xml", "schema-densevector.xml"); + prepareIndex(); + } + + protected static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); + for (SolrInputDocument doc : docsToIndex) { + assertU(adoc(doc)); + } + assertU(commit()); + } + + /** + * The documents in the index are 10 parents, with some parent level metadata and 30 nested + * documents (with vectors and children level metadata) Each parent document has 3 nested + * documents with vectors. + * + *

This allows to run knn queries both at parent/children level and using various pre-filters + * both for parent metadata and children. + * + * @return a list of documents to index + */ + private static List prepareDocs() { + int totalParentDocuments = 10; + int totalNestedVectors = 30; + int perParentChildren = totalNestedVectors / totalParentDocuments; + + final String[] klm = new String[] {"k", "l", "m"}; + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + + List docs = new ArrayList<>(totalParentDocuments); + for (int i = 1; i < totalParentDocuments + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", i); + doc.setField("parent_b", true); + + doc.setField("parent_s", abcdef[i % abcdef.length]); + List children = new ArrayList<>(perParentChildren); + + // nested vector documents have a distance from the query vector inversely proportional to + // their id + for (int j = 0; j < perParentChildren; j++) { + SolrInputDocument child = new SolrInputDocument(); + child.setField("id", i + "" + j); + child.setField("child_s", klm[i % klm.length]); + child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + child.setField( + "vector_byte_encoding", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + totalNestedVectors--; // the higher the id of the nested document, lower the distance with + // the query vector + children.add(child); + } + doc.setField("vectors", children); + docs.add(doc); + } + + return docs; + } + + @Test + public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { + super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + VECTOR_FIELD); + } + + @Test + public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { + super.childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_FIELD); + } + + @Test + public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { + super.childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent() { + super.parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void + parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + VECTOR_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void + parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnAllChildren() { + super.parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + VECTOR_FIELD + + ",[child fl=vector childFilter=$children.q]", + "children.q", + "{!knn f=" + + VECTOR_FIELD + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']"); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { + super.parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( + VECTOR_BYTE_FIELD); + } +} diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index 439d464b9c17..9e7c42eb8747 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -171,12 +171,9 @@ It has the same meaning as `efConstruction` from the 2018 paper. Accepted values: Any integer. -`DenseVectorField` supports the attributes: `indexed`, `stored`. +`DenseVectorField` supports the attributes: `indexed`, `stored`, `multivalued`. -[NOTE] -currently multivalue is not supported - -Here's how a `DenseVectorField` should be indexed: +Here's how a `DenseVectorField` should be indexed when single valued: [tabs#densevectorfield-index] ====== @@ -240,6 +237,154 @@ client.add(Arrays.asList(d1, d2)); ==== ====== +Here's how a `DenseVectorField` should be indexed when multi-valued: + +[tabs#densevectorfield-index] +====== +JSON:: ++ +==== +[source,json] +---- +[{ "id": "1", + "vector_multivalued": [[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]] +}, +{ "id": "2", + "vector_multivalued": [[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]] +} +] +---- +==== + +SolrJ:: ++ +==== +[source,java,indent=0] +---- +final SolrClient client = getSolrClient(); + +final SolrInputDocument d1 = new SolrInputDocument(); +d1.setField("id", "1"); +List> floatVectors1 = new ArrayList<>(2); +floatVectors1.add(Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f)); +floatVectors1.add(Arrays.asList(5.0f, 6.0f, 7.0f, 8.0f)); +d1.setField("vector_multivalued", floatVectors1); + + +final SolrInputDocument d2 = new SolrInputDocument(); +d2.setField("id", "2"); +List> floatVectors2 = new ArrayList<>(2); +floatVectors2.add(Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f)); +floatVectors2.add(Arrays.asList(5.0f, 6.0f, 7.0f, 8.0f)); +d2.setField("vector_multivalued", floatVectors2); + +client.add(Arrays.asList(d1, d2)); + +---- +==== +====== + +=== ScalarQuantizedDenseVectorField +Because dense vectors can have a costly size, it may be worthwhile to use a technique called "quantization" +which creates a compressed representation of the original vectors. This allows more of the index to be stored in faster memory +at the cost of some precision. + +This dense vector type uses a conversion that projects a 32 bit float precision feature down to an 8 bit int (or smaller) +by linearly mapping the float range of each dimension down to evenly sized "buckets" of values that fit into an int. For example: +with 8 bits we can store up to 256 discrete values, so a float dimension with values from 0.0 to 1.0 may be mapped as + +[0.0, 0.0039) => 0, [0.0039, 0.0078) => 1 ... etc + +As a specific type of DenseVectorField, this field type supports all the same configurable properties outlined above as well +as some additional ones. + +Here is how a ScalarQuantizedDenseVectorField can be defined in the schema: + +[source,xml] + + + +`bits`:: ++ +[%autowidth,frame=none] +|=== +s|Optional |Default: `7` +|=== ++ +The number of bits to use for each quantized dimension value ++ +Accepted values: 4 (half byte) or 7 (unsigned byte). + +`confidenceInterval`:: ++ +[%autowidth,frame=none] +|=== +s|Optional |Default: `dimension-scaled` +|=== ++ +Statistically, outlier values are rarely meaningfully relevant to searches, so to increase the size of each bucket for +quantization (and therefore information gain) we can scale the quantization intervals to the middle n % of values and place the remaining +outliers in the outermost intervals. ++ +For example: 0.9 means scale interval sizes to the middle 90% of values ++ +If this param is omitted a default is used; scaled to the number of dimensions according to `1-1/(vector_dimensions + 1)` ++ +Accepted values: `FLOAT32` (within 0.9 and 1.0) + +`dynamicConfidenceInterval`:: ++ +[%autowidth,frame=none] +|=== +s|Optional |Default: `false` +|=== ++ +If set to true, enables dynamically determining confidence interval (per dimension) by sampling values each time a merge occurs. ++ +`NOTE: when this is enabled, it will take precedence over any value configured for confidenceInterval` ++ +Accepted values: `BOOLEAN` + +`compress`:: ++ +[%autowidth,frame=none] +|=== +s|Optional |Default: `false` +|=== ++ +If set to true, this will further pack multiple dimension values within a one byte alignment. This further decreases the +quantized vector disk storage size by 50% at some decode penalty. This does not affect the raw vector which is always +preserved when `stored` is true. ++ +`NOTE: this can only be enabled when bits=4` ++ +Accepted values: `BOOLEAN` + +=== BinaryQuantizedDenseVectorField + +Binary quantization is a quantization technique that extends scalar quantization, and is even more aggressive in its compression; +able to reduce in-memory representation of each vector dimension from a 32 bit float down to a single bit. +This is done by normalizing each dimension of a vector relative to a centroid (mid-point pre-calculated against all vectors in the index) +with the stored bit representing whether the actual value is "above" or "below" the centroid's value. A further "corrective factor" is also computed +and stored to help compensate accuracy in the estimated distance. At query time asymmetric quantization is applied to the query +vector (reducing its dimension values down to 4 bits each), but allowing comparison with the stored binary quantized vector via bit arithmetic. + +This implementation comprises of LVQ, proposed in https://arxiv.org/abs/2304.04759[Similarity Search in the Blink of an Eye With Compressed Indices] +by Cecilia Aguerrebere et al., previous work on globally optimized scalar quantization in Apache Lucene, and ideas from +https://arxiv.org/abs/1908.10396[Accelerating Large-Scale Inference with Anisotropic Vector Quantization] by Ruiqi Guo et al. + +This vector type is best utilized for data sets consisting of large amounts of high dimensionality vectors. + +Here is how a BinaryQuantizedDenseVectorField can be defined in the schema: + +[source,xml] + + + +BinaryQuantizedDenseVectorField accepts the same parameters as `DenseVectorField` with the only notable exception being +`similarityFunction`. Bit quantization uses its own distance calculation and so does not require nor use the `similarityFunction` +param. + == Query Time Apache Solr provides three query parsers that work with dense vector fields, that each support different ways of matching documents based on vector similarity: The `knn` query parser, the `vectorSimilarity` query parser and the `knn_text_to_vector` query parser. @@ -342,6 +487,79 @@ The search results retrieved are the k=10 nearest documents to the vector encode For more details on how to work with vectorise text in Apache Solr, please refer to the dedicated page: xref:text-to-vector.adoc[Text to Vector] +=== Handle multivalued vector fields at query time +Behind the scenes a multivalued vector field is handled by Solr as nested documents with a single vector each (see the parameters for the knn query parser that deal with nested vectors 'parents.preFilter' and 'childrenOf'). + +So you should query a multivalued vector fields following the same syntax: +[source,text] +?q={!parent which=$allParents score=max v=$children.q} +&children.q={!knn f=vector_multivalued topK=3 parents.preFilter=$someParents childrenOf=$allParents}[1.0, 2.0, 3.0, 4.0] +&allParents=*:* -_nest_path_:* +&someParents=color_s:RED + +In terms of rendering the results, you need the child transformer if you want to output them flat (you can choose to only return the best vector per result or all vectors): + +All Children +[source,text] +fl=id,vector_multivalued,[child fl="vector_multivalued"] + +==== +[source,json] +---- +"docs":[ + { + "id":"1", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ], + [ + 5.0,6.0, 7.0, 8.0 + ] + ] + }, + { + "id":"2", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ], + [ + 5.0,6.0, 7.0, 8.0 + ] + ] + }] +---- +==== + +Best Child +[source,text] +fl=id,vector_multivalued,[child fl="vector_multivalued" childFilter=$children.q] + +==== +[source,json] +---- +"docs":[ + { + "id":"1", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ] + ] + }, + { + "id":"2", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ] + ] + }] +---- +==== + + === vectorSimilarity Query Parser The `vectorSimilarity` vector similarity query parser matches documents whose similarity with the target vector is a above a minimum threshold. From 01871d69ac8848f2a8c52594e943c9a01142d92e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 02:12:31 +0000 Subject: [PATCH 3/5] Fix Java 11 compatibility: replace instanceof pattern matching in NestedUpdateProcessorFactory --- .../solr/update/processor/NestedUpdateProcessorFactory.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index 14fb44009f9f..714bed470eb1 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -122,10 +122,11 @@ private boolean processDocChildren(SolrInputDocument doc, String fullPath) { originalVectorFieldsToRemove.add(field.getName()); } else { for (Object val : field) { - if (!(val instanceof SolrInputDocument cDoc)) { + if (!(val instanceof SolrInputDocument)) { // either all collection items are child docs or none are. break; } + SolrInputDocument cDoc = (SolrInputDocument) val; final String fieldName = field.getName(); if (fieldName.contains(PATH_SEP_CHAR)) { From 759983f47f65ed04d8fd19cda1fdbc22a332b6fd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 02:25:01 +0000 Subject: [PATCH 4/5] Add childrenOf and parents.preFilter support to KnnQParser for SOLR-18074 (diversifying children KNN queries) --- .../apache/solr/search/neural/KnnQParser.java | 91 ++++++++++++++++++- 1 file changed, 90 insertions(+), 1 deletion(-) diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index b6d9f2541cd0..fe6a84f18d75 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -16,12 +16,23 @@ */ package org.apache.solr.search.neural; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.apache.solr.common.SolrException; import org.apache.solr.common.params.SolrParams; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.SchemaField; +import org.apache.solr.search.QParser; import org.apache.solr.search.SyntaxError; +import org.apache.solr.search.join.BlockJoinParentQParser; +import org.apache.solr.util.vector.DenseVectorParser; public class KnnQParser extends AbstractVectorQParserBase { @@ -29,18 +40,96 @@ public class KnnQParser extends AbstractVectorQParserBase { protected static final String TOP_K = "topK"; protected static final int DEFAULT_TOP_K = 10; + public static final String PARENTS_PRE_FILTER = "parents.preFilter"; + public static final String CHILDREN_OF = "childrenOf"; + public KnnQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) { super(qstr, localParams, params, req); } @Override public Query parse() throws SyntaxError { - final SchemaField schemaField = req.getCore().getLatestSchema().getField(getFieldName()); + final String vectorField = getFieldName(); + final SchemaField schemaField = req.getCore().getLatestSchema().getField(vectorField); final DenseVectorField denseVectorType = getCheckedFieldType(schemaField); final String vectorToSearch = getVectorToSearch(); final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K); + // check for parent diversification logic... + final String[] parentsFilterQueries = localParams.getParams(PARENTS_PRE_FILTER); + final String allParentsQuery = localParams.get(CHILDREN_OF); + + boolean isDiversifyingChildrenKnnQuery = + null != parentsFilterQueries || null != allParentsQuery; + if (isDiversifyingChildrenKnnQuery) { + if (null == allParentsQuery) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "When running a diversifying children KNN query, 'childrenOf' parameter is required"); + } + final DenseVectorParser vectorBuilder = + denseVectorType.getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); + final VectorEncoding vectorEncoding = denseVectorType.getVectorEncoding(); + + final BitSetProducer allParentsBitSet = + BlockJoinParentQParser.getCachedBitSetProducer( + req, subQuery(allParentsQuery, null).getQuery()); + final BooleanQuery acceptedParents = getParentsFilter(parentsFilterQueries); + + Query acceptedChildren = + getChildrenFilter(getFilterQuery(), acceptedParents, allParentsBitSet); + switch (vectorEncoding) { + case FLOAT32: + return new DiversifyingChildrenFloatKnnVectorQuery( + vectorField, + vectorBuilder.getFloatVector(), + acceptedChildren, + topK, + allParentsBitSet); + case BYTE: + return new DiversifyingChildrenByteKnnVectorQuery( + vectorField, vectorBuilder.getByteVector(), acceptedChildren, topK, allParentsBitSet); + default: + throw new SolrException( + SolrException.ErrorCode.SERVER_ERROR, + "Unexpected encoding. Vector Encoding: " + vectorEncoding); + } + } + return denseVectorType.getKnnVectorQuery( schemaField.getName(), vectorToSearch, topK, getFilterQuery()); } + + private BooleanQuery getParentsFilter(String[] parentsFilterQueries) throws SyntaxError { + BooleanQuery.Builder acceptedParentsBuilder = new BooleanQuery.Builder(); + if (parentsFilterQueries != null) { + for (String parentsFilterQuery : parentsFilterQueries) { + final QParser parser = subQuery(parentsFilterQuery, null); + parser.setIsFilter(true); + final Query parentsFilter = parser.getQuery(); + if (parentsFilter != null) { + acceptedParentsBuilder.add(parentsFilter, BooleanClause.Occur.FILTER); + } + } + } + return acceptedParentsBuilder.build(); + } + + private Query getChildrenFilter( + Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { + Query childrenFilter = childrenKnnPreFilter; + + if (!parentsFilter.clauses().isEmpty()) { + Query acceptedChildrenBasedOnParentsFilter = + new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); + BooleanQuery.Builder acceptedChildrenBuilder = new BooleanQuery.Builder(); + if (childrenFilter != null) { + acceptedChildrenBuilder.add(childrenFilter, BooleanClause.Occur.FILTER); + } + acceptedChildrenBuilder.add(acceptedChildrenBasedOnParentsFilter, BooleanClause.Occur.FILTER); + + childrenFilter = acceptedChildrenBuilder.build(); + } + return childrenFilter; + } } From 9e8867185cf586f2f56cb0ff9da3a3203b46480a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 02:30:23 +0000 Subject: [PATCH 5/5] Fix byte vector extraction to correctly handle BytesRef offset and length --- .../apache/solr/response/transform/ChildDocTransformer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 789314a905b7..9f9ca0fb457c 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -393,8 +393,8 @@ private List extractByteVector(Collection singleVector) { StoredField vector = (StoredField) singleVector.iterator().next(); BytesRef byteVector = vector.binaryValue(); List extractedVector = new ArrayList<>(byteVector.length); - for (Byte element : byteVector.bytes) { - extractedVector.add(element.byteValue()); + for (int i = byteVector.offset; i < byteVector.offset + byteVector.length; i++) { + extractedVector.add(byteVector.bytes[i]); } return extractedVector; }