Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog/unreleased/SOLR-18074.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# See https://github.com/apache/solr/blob/main/dev-docs/changelog.adoc
title: Introducing support for multi valued dense vector representation in documents through nested vectors
type: added # added, changed, fixed, deprecated, removed, dependency_update, security, other
authors:
- name: Alessandro Benedetti
links:
- name: SOLR-18074
url: https://issues.apache.org/jira/browse/SOLR-18074
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -35,14 +39,17 @@
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrException;
import org.apache.solr.schema.DenseVectorField;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.BitsFilteredPostingsEnum;
import org.apache.solr.search.DocSet;
import org.apache.solr.search.SolrDocumentFetcher;
Expand Down Expand Up @@ -138,6 +145,20 @@ public void transform(SolrDocument rootDoc, int rootDocId) {
final Bits liveDocs = leafReaderContext.reader().getLiveDocs();
final int segBaseId = leafReaderContext.docBase;
final int segRootId = rootDocId - segBaseId;
Set<String> multiValuedFLoatVectorFields =
this.getMultiValuedVectorFields(
searcher.getSchema(), childReturnFields, VectorEncoding.FLOAT32);
Set<String> multiValuedByteVectorFields =
this.getMultiValuedVectorFields(
searcher.getSchema(), childReturnFields, VectorEncoding.BYTE);
if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) > 0
&& (multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size())
!= childReturnFields.getExplicitlyRequestedFieldNames().size()) {
throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST,
"When using the Child transformer to flatten nested vectors, all 'fl' must be "
+ "multivalued vector fields");
}

// can return be -1 and that's okay (happens for very first block)
final int segPrevRootId;
Expand Down Expand Up @@ -219,8 +240,21 @@ public void transform(SolrDocument rootDoc, int rootDocId) {

if (isAncestor) {
// if this path has pending child docs, add them.
addChildrenToParent(
doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending
if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) {
addFlatMultiValuedVectorsToParent(
rootDoc,
pendingParentPathsToChildren.values().iterator().next(),
multiValuedFLoatVectorFields,
VectorEncoding.FLOAT32);
addFlatMultiValuedVectorsToParent(
rootDoc,
pendingParentPathsToChildren.values().iterator().next(),
multiValuedByteVectorFields,
VectorEncoding.BYTE);
} else {
addChildrenToParent(
doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending
}
}

// get parent path
Expand Down Expand Up @@ -248,7 +282,20 @@ public void transform(SolrDocument rootDoc, int rootDocId) {
assert pendingParentPathsToChildren.keySet().size() == 1;

// size == 1, so get the last remaining entry
addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next());
if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) {
addFlatMultiValuedVectorsToParent(
rootDoc,
pendingParentPathsToChildren.values().iterator().next(),
multiValuedFLoatVectorFields,
VectorEncoding.FLOAT32);
addFlatMultiValuedVectorsToParent(
rootDoc,
pendingParentPathsToChildren.values().iterator().next(),
multiValuedByteVectorFields,
VectorEncoding.BYTE);
} else {
addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next());
}

} catch (IOException e) {
// TODO DWS: reconsider this unusual error handling approach; shouldn't we rethrow?
Expand All @@ -257,6 +304,25 @@ public void transform(SolrDocument rootDoc, int rootDocId) {
}
}

private Set<String> getMultiValuedVectorFields(
IndexSchema schema, SolrReturnFields childReturnFields, VectorEncoding encoding) {
Set<String> multiValuedVectorsFields = new HashSet<>();
Set<String> explicitlyRequestedFieldNames =
childReturnFields.getExplicitlyRequestedFieldNames();
if (explicitlyRequestedFieldNames != null) {
for (String fieldName : explicitlyRequestedFieldNames) {
SchemaField sfield = schema.getFieldOrNull(fieldName);
if (sfield != null
&& sfield.getType() instanceof DenseVectorField
&& sfield.multiValued()
&& ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) {
multiValuedVectorsFields.add(fieldName);
}
}
}
return multiValuedVectorsFields;
}

private static void addChildrenToParent(
SolrDocument parent, Map<String, List<SolrDocument>> children) {
for (Map.Entry<String, List<SolrDocument>> entry : children.entrySet()) {
Expand Down Expand Up @@ -285,6 +351,54 @@ private static void addChildrenToParent(
parent.setField(trimmedPath, children.get(0));
}

private void addFlatMultiValuedVectorsToParent(
SolrDocument parent,
Map<String, List<SolrDocument>> children,
Set<String> multiValuedVectorFields,
VectorEncoding encoding) {
for (String multiValuedVectorField : multiValuedVectorFields) {
List<SolrDocument> solrDocuments = children.get(multiValuedVectorField);
List<List<Number>> multiValuedVectors = new ArrayList<>(solrDocuments.size());
for (SolrDocument singleVector : solrDocuments) {
List<Number> extractedVectors;
switch (encoding) {
case FLOAT32:
extractedVectors =
this.extractFloatVector(singleVector.getFieldValues(multiValuedVectorField));
break;
case BYTE:
extractedVectors =
this.extractByteVector(singleVector.getFieldValues(multiValuedVectorField));
break;
default:
throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST, "Unsupported vector encoding: " + encoding);
}
multiValuedVectors.add(extractedVectors);
}
parent.setField(multiValuedVectorField, multiValuedVectors);
}
}

private List<Number> extractFloatVector(Collection<Object> fieldValues) {
List<Number> vector = new ArrayList<>(fieldValues.size());
for (Object fieldValue : fieldValues) {
StoredField storedVectorValue = (StoredField) fieldValue;
vector.add(storedVectorValue.numericValue());
}
return vector;
}

private List<Number> extractByteVector(Collection<Object> singleVector) {
StoredField vector = (StoredField) singleVector.iterator().next();
BytesRef byteVector = vector.binaryValue();
List<Number> extractedVector = new ArrayList<>(byteVector.length);
for (int i = byteVector.offset; i < byteVector.offset + byteVector.length; i++) {
extractedVector.add(byteVector.bytes[i]);
}
return extractedVector;
}

private static String getLastPath(String path) {
int lastIndexOfPathSepChar = path.lastIndexOf(PATH_SEP_CHAR);
if (lastIndexOfPathSepChar == -1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,6 @@ protected boolean enableDocValuesByDefault() {
@Override
public void checkSchemaField(final SchemaField field) throws SolrException {
super.checkSchemaField(field);
if (field.multiValued()) {
throw new SolrException(
SolrException.ErrorCode.SERVER_ERROR,
getClass().getSimpleName() + " fields can not be multiValued: " + field.getName());
}

if (field.hasDocValues()) {
throw new SolrException(
Expand Down
1 change: 1 addition & 0 deletions solr/core/src/java/org/apache/solr/schema/IndexSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ public class IndexSchema {
public static final String NAME = "name";
public static final String NEST_PARENT_FIELD_NAME = "_nest_parent_";
public static final String NEST_PATH_FIELD_NAME = "_nest_path_";
public static final String NESTED_VECTORS_PSEUDO_FIELD_NAME = "_nested_vectors_";
public static final String REQUIRED = "required";
public static final String SCHEMA = "schema";
public static final String SIMILARITY = "similarity";
Expand Down
91 changes: 90 additions & 1 deletion solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,120 @@
*/
package org.apache.solr.search.neural;

import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.schema.DenseVectorField;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.QParser;
import org.apache.solr.search.SyntaxError;
import org.apache.solr.search.join.BlockJoinParentQParser;
import org.apache.solr.util.vector.DenseVectorParser;

public class KnnQParser extends AbstractVectorQParserBase {

// retrieve the top K results based on the distance similarity function
protected static final String TOP_K = "topK";
protected static final int DEFAULT_TOP_K = 10;

public static final String PARENTS_PRE_FILTER = "parents.preFilter";
public static final String CHILDREN_OF = "childrenOf";

public KnnQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
super(qstr, localParams, params, req);
}

@Override
public Query parse() throws SyntaxError {
final SchemaField schemaField = req.getCore().getLatestSchema().getField(getFieldName());
final String vectorField = getFieldName();
final SchemaField schemaField = req.getCore().getLatestSchema().getField(vectorField);
final DenseVectorField denseVectorType = getCheckedFieldType(schemaField);
final String vectorToSearch = getVectorToSearch();
final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K);

// check for parent diversification logic...
final String[] parentsFilterQueries = localParams.getParams(PARENTS_PRE_FILTER);
final String allParentsQuery = localParams.get(CHILDREN_OF);

boolean isDiversifyingChildrenKnnQuery =
null != parentsFilterQueries || null != allParentsQuery;
if (isDiversifyingChildrenKnnQuery) {
if (null == allParentsQuery) {
throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST,
"When running a diversifying children KNN query, 'childrenOf' parameter is required");
}
final DenseVectorParser vectorBuilder =
denseVectorType.getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY);
final VectorEncoding vectorEncoding = denseVectorType.getVectorEncoding();

final BitSetProducer allParentsBitSet =
BlockJoinParentQParser.getCachedBitSetProducer(
req, subQuery(allParentsQuery, null).getQuery());
final BooleanQuery acceptedParents = getParentsFilter(parentsFilterQueries);

Query acceptedChildren =
getChildrenFilter(getFilterQuery(), acceptedParents, allParentsBitSet);
switch (vectorEncoding) {
case FLOAT32:
return new DiversifyingChildrenFloatKnnVectorQuery(
vectorField,
vectorBuilder.getFloatVector(),
acceptedChildren,
topK,
allParentsBitSet);
case BYTE:
return new DiversifyingChildrenByteKnnVectorQuery(
vectorField, vectorBuilder.getByteVector(), acceptedChildren, topK, allParentsBitSet);
default:
throw new SolrException(
SolrException.ErrorCode.SERVER_ERROR,
"Unexpected encoding. Vector Encoding: " + vectorEncoding);
}
}

return denseVectorType.getKnnVectorQuery(
schemaField.getName(), vectorToSearch, topK, getFilterQuery());
}

private BooleanQuery getParentsFilter(String[] parentsFilterQueries) throws SyntaxError {
BooleanQuery.Builder acceptedParentsBuilder = new BooleanQuery.Builder();
if (parentsFilterQueries != null) {
for (String parentsFilterQuery : parentsFilterQueries) {
final QParser parser = subQuery(parentsFilterQuery, null);
parser.setIsFilter(true);
final Query parentsFilter = parser.getQuery();
if (parentsFilter != null) {
acceptedParentsBuilder.add(parentsFilter, BooleanClause.Occur.FILTER);
}
}
}
return acceptedParentsBuilder.build();
}

private Query getChildrenFilter(
Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) {
Query childrenFilter = childrenKnnPreFilter;

if (!parentsFilter.clauses().isEmpty()) {
Query acceptedChildrenBasedOnParentsFilter =
new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet);
BooleanQuery.Builder acceptedChildrenBuilder = new BooleanQuery.Builder();
if (childrenFilter != null) {
acceptedChildrenBuilder.add(childrenFilter, BooleanClause.Occur.FILTER);
}
acceptedChildrenBuilder.add(acceptedChildrenBasedOnParentsFilter, BooleanClause.Occur.FILTER);

childrenFilter = acceptedChildrenBuilder.build();
}
return childrenFilter;
}
}
Loading
Loading