From e8d769450713f7ee87794e1cdfebc099284e0b5c Mon Sep 17 00:00:00 2001 From: Stefanietry Date: Wed, 6 May 2026 18:09:22 +0800 Subject: [PATCH] [spark] support distributed execution of vector search on spark --- docs/generated/core_configuration.html | 6 + .../java/org/apache/paimon/CoreOptions.java | 10 + .../GlobalIndexResultSerializer.java | 17 ++ .../paimon/utils/RoaringNavigableMap64.java | 5 +- .../org/apache/paimon/table/InnerTable.java | 4 +- .../paimon/table/source/VectorReadImpl.java | 17 +- .../source/VectorSearchBuilderFactory.java | 64 ++++++ .../table/source/VectorSearchBuilderImpl.java | 12 +- .../source/VectorSearchBuilderProvider.java | 30 +++ .../paimon/spark/read/SparkEngineContext.java | 63 ++++++ .../spark/read/SparkVectorReadImpl.java | 182 ++++++++++++++++++ .../read/SparkVectorSearchBuilderImpl.java | 41 ++++ .../SparkVectorSearchBuilderProvider.java | 46 +++++ ...n.table.source.VectorSearchBuilderProvider | 16 ++ .../paimon/spark/SparkMultimodalITCase.java | 66 ++++++- 15 files changed, 555 insertions(+), 24 deletions(-) create mode 100644 paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderFactory.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderProvider.java create mode 100644 paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkEngineContext.java create mode 100644 paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorReadImpl.java create mode 100644 paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderImpl.java create mode 100644 paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderProvider.java create mode 100644 paimon-spark/paimon-spark-common/src/main/resources/META-INF/services/org.apache.paimon.table.source.VectorSearchBuilderProvider diff --git a/docs/generated/core_configuration.html 
b/docs/generated/core_configuration.html index 8f7ea2eaaea9..e8a6bc4e661c 100644 --- a/docs/generated/core_configuration.html +++ b/docs/generated/core_configuration.html @@ -1620,6 +1620,12 @@ String Specifies column names that should be stored as vector type. This is used when you want to treat a ARRAY column as a VECTOR. + +
vector-search.distribute.enabled
+ false + Boolean + Whether to process distributed vector search. +
vector.file.format
(none) diff --git a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java index 03140e9ecc42..5c2af37a59bc 100644 --- a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java +++ b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java @@ -2519,6 +2519,12 @@ public InlineElement getDescription() { + " Default is the same as TARGET_FILE_SIZE.") .build()); + public static final ConfigOption VECTOR_SEARCH_DISTRIBUTE_ENABLED = + key("vector-search.distribute.enabled") + .booleanType() + .defaultValue(false) + .withDescription("Whether to process distributed vector search."); + @Immutable public static final ConfigOption PK_CLUSTERING_OVERRIDE = key("pk-clustering-override") @@ -3978,6 +3984,10 @@ public long vectorTargetFileSize() { .orElse(targetFileSize(false)); } + public boolean vectorSearchDistributeEnabled() { + return options.get(VECTOR_SEARCH_DISTRIBUTE_ENABLED); + } + /** Specifies the merge engine for table with primary key. 
*/ public enum MergeEngine implements DescribedEnum { DEDUPLICATE("deduplicate", "De-duplicate and keep the last row."), diff --git a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java index 66a43b082d4a..5559c59c4f64 100644 --- a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java +++ b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java @@ -23,6 +23,7 @@ import org.apache.paimon.io.DataInputView; import org.apache.paimon.io.DataOutputSerializer; import org.apache.paimon.io.DataOutputView; +import org.apache.paimon.utils.Preconditions; import org.apache.paimon.utils.RoaringNavigableMap64; import java.io.IOException; @@ -116,4 +117,20 @@ public GlobalIndexResult deserialize(DataInputView dataInput) throws IOException return ScoredGlobalIndexResult.create(roaringNavigableMap64, scoreMap::get); } + + public byte[] serialize(GlobalIndexResult globalIndexResult) throws IOException { + DataOutputSerializer dataOutputSerializer = new DataOutputSerializer(1024); + serialize(globalIndexResult, dataOutputSerializer); + return dataOutputSerializer.getCopyOfBuffer(); + } + + public ScoredGlobalIndexResult deserialize(byte[] data) throws IOException { + DataInputDeserializer dataInputDeserializer = new DataInputDeserializer(data); + GlobalIndexResult globalIndexResult = deserialize(dataInputDeserializer); + Preconditions.checkArgument( + globalIndexResult instanceof ScoredGlobalIndexResult, + "Expected ScoredGlobalIndexResult, but got %s", + globalIndexResult == null ? 
"null" : globalIndexResult.getClass().getName()); + return (ScoredGlobalIndexResult) globalIndexResult; + } } diff --git a/paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java b/paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java index bec44f3fb039..c70623817c8a 100644 --- a/paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java +++ b/paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java @@ -25,12 +25,15 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.io.Serializable; import java.util.Iterator; import java.util.List; import java.util.Objects; /** A compressed bitmap for 64-bit integer aggregated by tree. */ -public class RoaringNavigableMap64 implements Iterable { +public class RoaringNavigableMap64 implements Iterable, Serializable { + + private static final long serialVersionUID = 1L; private final Roaring64NavigableMap roaring64NavigableMap; diff --git a/paimon-core/src/main/java/org/apache/paimon/table/InnerTable.java b/paimon-core/src/main/java/org/apache/paimon/table/InnerTable.java index d360597744ee..2f0826f5eaa2 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/InnerTable.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/InnerTable.java @@ -33,7 +33,7 @@ import org.apache.paimon.table.source.ReadBuilderImpl; import org.apache.paimon.table.source.StreamDataTableScan; import org.apache.paimon.table.source.VectorSearchBuilder; -import org.apache.paimon.table.source.VectorSearchBuilderImpl; +import org.apache.paimon.table.source.VectorSearchBuilderFactory; import java.util.Optional; @@ -59,7 +59,7 @@ default ReadBuilder newReadBuilder() { @Override default VectorSearchBuilder newVectorSearchBuilder() { - return new VectorSearchBuilderImpl(this); + return VectorSearchBuilderFactory.create(this); } @Override diff --git 
a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java index a3402c3f1d66..2eae2d48779d 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java @@ -42,6 +42,7 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -55,13 +56,15 @@ import static org.apache.paimon.utils.Preconditions.checkNotNull; /** Implementation for {@link VectorRead}. */ -public class VectorReadImpl implements VectorRead { +public class VectorReadImpl implements VectorRead, Serializable { - private final FileStoreTable table; + private static final long serialVersionUID = 1L; + + protected final FileStoreTable table; private final Predicate filter; - private final int limit; - private final DataField vectorColumn; - private final float[] vector; + protected final int limit; + protected final DataField vectorColumn; + protected final float[] vector; public VectorReadImpl( FileStoreTable table, @@ -120,7 +123,7 @@ public GlobalIndexResult read(List splits) { return result.topK(limit); } - private Optional preFilter(List splits) { + protected Optional preFilter(List splits) { Set scalarIndexFiles = new TreeSet<>(Comparator.comparing(IndexFileMeta::fileName)); for (VectorSearchSplit split : splits) { @@ -139,7 +142,7 @@ private Optional preFilter(List splits } } - private CompletableFuture> eval( + protected CompletableFuture> eval( GlobalIndexer globalIndexer, IndexPathFactory indexPathFactory, long rowRangeStart, diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderFactory.java b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderFactory.java new file mode 100644 index 000000000000..19d407f964f8 --- 
/dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.table.source; + +import org.apache.paimon.factories.FactoryException; +import org.apache.paimon.factories.FactoryUtil; +import org.apache.paimon.table.InnerTable; + +import java.util.ArrayList; +import java.util.List; + +/** Factory for {@link VectorSearchBuilder}. 
*/
+public class VectorSearchBuilderFactory {
+
+    private VectorSearchBuilderFactory() {}
+
+    /**
+     * Creates the {@link VectorSearchBuilder} for the given table.
+     *
+     * <p>All {@link VectorSearchBuilderProvider} implementations discovered through {@link
+     * FactoryUtil#discoverFactories} are asked to create a builder; a provider that returns {@code
+     * null} is treated as not matching. Discovery uses the thread context class loader and falls
+     * back to the class loader of this factory when the context class loader is {@code null}.
+     *
+     * @param table table to search
+     * @return the builder of the single matching provider, or a plain {@link
+     *     VectorSearchBuilderImpl} when no provider matches
+     * @throws FactoryException if more than one provider returns a builder
+     */
+    public static VectorSearchBuilder create(InnerTable table) {
+        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
+        if (classLoader == null) {
+            classLoader = VectorSearchBuilderFactory.class.getClassLoader();
+        }
+
+        List<VectorSearchBuilderProvider> providers =
+                FactoryUtil.discoverFactories(classLoader, VectorSearchBuilderProvider.class);
+        List<VectorSearchBuilder> builders = new ArrayList<>();
+        // Provider class names are tracked in parallel only to build the ambiguity error message.
+        List<String> matchedProviders = new ArrayList<>();
+        for (VectorSearchBuilderProvider provider : providers) {
+            VectorSearchBuilder builder = provider.create(table);
+            if (builder != null) {
+                builders.add(builder);
+                matchedProviders.add(provider.getClass().getName());
+            }
+        }
+
+        if (builders.size() > 1) {
+            throw new FactoryException(
+                    String.format(
+                            "Multiple VectorSearchBuilder providers matched table '%s': %s",
+                            table.name(), String.join(", ", matchedProviders)));
+        }
+
+        if (builders.size() == 1) {
+            return builders.get(0);
+        }
+
+        return new VectorSearchBuilderImpl(table);
+    }
+}
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java
index beb7844e13bc..a0d11ff21f12 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java
@@ -32,13 +32,13 @@ public class VectorSearchBuilderImpl implements VectorSearchBuilder {
 
     private static final long serialVersionUID = 1L;
 
-    private final FileStoreTable table;
+    protected final FileStoreTable table;
 
-    private PartitionPredicate partitionFilter;
-    private Predicate filter;
-    private int limit;
-    private DataField vectorColumn;
-    private float[] vector;
+    protected PartitionPredicate partitionFilter;
+    protected Predicate filter;
+    protected int limit;
+    protected DataField vectorColumn;
+    protected
float[] vector; public VectorSearchBuilderImpl(InnerTable table) { this.table = (FileStoreTable) table; diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderProvider.java b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderProvider.java new file mode 100644 index 000000000000..dc8e16ee88d6 --- /dev/null +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderProvider.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.table.source; + +import org.apache.paimon.table.InnerTable; + +import javax.annotation.Nullable; + +/** SPI for engine specific {@link VectorSearchBuilder} creation. 
*/
+public interface VectorSearchBuilderProvider {
+
+    /**
+     * Creates an engine specific {@link VectorSearchBuilder} for the given table.
+     *
+     * @param table table the builder should search
+     * @return a builder, or {@code null} when this provider does not apply to the table or to the
+     *     current runtime
+     */
+    @Nullable
+    VectorSearchBuilder create(InnerTable table);
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkEngineContext.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkEngineContext.java
new file mode 100644
index 000000000000..f5a1abd6681b
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkEngineContext.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.read;
+
+import org.apache.paimon.utils.SerializableFunction;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.sql.SparkSession;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Stream;
+
+/**
+ * Tiny wrapper around the active {@link SparkSession} that exposes RDD style {@code map} / {@code
+ * flatMap} primitives over a Java {@link List}. Used by Paimon-on-Spark to dispatch
+ * embarrassingly-parallel work (e.g.
per-split vector search) to the cluster without forcing the
+ * caller to depend on Spark types directly.
+ */
+public class SparkEngineContext {
+
+    private final JavaSparkContext jsc;
+
+    /** Binds this context to the {@code SparkContext} of {@link SparkSession#active()}. */
+    public SparkEngineContext() {
+        this.jsc = JavaSparkContext.fromSparkContext(SparkSession.active().sparkContext());
+    }
+
+    /**
+     * Broadcasts {@code value} through the underlying Spark context.
+     *
+     * @param value value to broadcast
+     * @return the Spark broadcast handle; the caller decides when to release it
+     */
+    public <T> Broadcast<T> broadcast(T value) {
+        return jsc.broadcast(value);
+    }
+
+    /**
+     * Applies {@code func} to every element of {@code data} in a Spark job and collects the
+     * results to the caller.
+     *
+     * @param data input elements; an empty list returns immediately without submitting a job
+     * @param func function shipped to the executors
+     * @param parallelism requested number of slices; values below 1 are treated as 1
+     * @return one result per input element
+     */
+    public <T, R> List<R> map(List<T> data, SerializableFunction<T, R> func, int parallelism) {
+        if (data.isEmpty()) {
+            return Collections.emptyList();
+        }
+        // parallelize() rejects a non-positive slice count, so always request at least one.
+        return jsc.parallelize(data, Math.max(1, parallelism)).map(func::apply).collect();
+    }
+
+    /**
+     * Applies {@code func} to every element of {@code data} in a Spark job and collects the
+     * flattened contents of the returned streams.
+     *
+     * @param data input elements; an empty list returns immediately without submitting a job
+     * @param func function shipped to the executors, producing a stream per element
+     * @param parallelism requested number of slices; values below 1 are treated as 1
+     * @return all elements of all produced streams
+     */
+    public <T, R> List<R> flatMap(
+            List<T> data, SerializableFunction<T, Stream<R>> func, int parallelism) {
+        if (data.isEmpty()) {
+            return Collections.emptyList();
+        }
+        // parallelize() rejects a non-positive slice count, so always request at least one.
+        return jsc.parallelize(data, Math.max(1, parallelism))
+                .flatMap(x -> func.apply(x).iterator())
+                .collect();
+    }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorReadImpl.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorReadImpl.java
new file mode 100644
index 000000000000..df60afb4892d
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorReadImpl.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.read; + +import org.apache.paimon.globalindex.GlobalIndexReadThreadPool; +import org.apache.paimon.globalindex.GlobalIndexResult; +import org.apache.paimon.globalindex.GlobalIndexResultSerializer; +import org.apache.paimon.globalindex.GlobalIndexer; +import org.apache.paimon.globalindex.GlobalIndexerFactoryUtils; +import org.apache.paimon.globalindex.ScoredGlobalIndexResult; +import org.apache.paimon.index.IndexPathFactory; +import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.table.source.VectorReadImpl; +import org.apache.paimon.table.source.VectorSearchSplit; +import org.apache.paimon.types.DataField; +import org.apache.paimon.utils.InstantiationUtil; +import org.apache.paimon.utils.RoaringNavigableMap64; +import org.apache.paimon.utils.SerializableFunction; + +import org.apache.spark.broadcast.Broadcast; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; + +import static org.apache.paimon.CoreOptions.GLOBAL_INDEX_THREAD_NUM; + +/** + * Spark-aware {@link VectorReadImpl} that distributes grouped vector index evaluation across the + * Spark cluster instead of evaluating them with the local thread pool. 
+ */
+public class SparkVectorReadImpl extends VectorReadImpl {
+
+    private static final long serialVersionUID = 1L;
+
+    public SparkVectorReadImpl(
+            FileStoreTable table,
+            Predicate filter,
+            int limit,
+            DataField vectorColumn,
+            float[] vector) {
+        super(table, filter, limit, vectorColumn, vector);
+    }
+
+    /**
+     * Evaluates the vector index of the given splits and returns the top {@code limit} rows.
+     *
+     * <p>With fewer than {@code 2 * parallelism} splits (parallelism being {@code
+     * global-index.thread-num}, at least 1) the local implementation of the parent class is used.
+     * Otherwise the splits are cut into at most {@code parallelism} groups, each group is
+     * evaluated by one Spark task which reduces its own results to a top-k, and the per-task
+     * results are merged and reduced to the final top-k on the caller side.
+     *
+     * @param splits splits to evaluate
+     * @return the merged result, empty when {@code splits} is empty
+     */
+    @Override
+    public GlobalIndexResult read(List<VectorSearchSplit> splits) {
+        if (splits.isEmpty()) {
+            return GlobalIndexResult.createEmpty();
+        }
+
+        int parallelism =
+                Math.max(1, table.coreOptions().toConfiguration().get(GLOBAL_INDEX_THREAD_NUM));
+        if (splits.size() < parallelism * 2) {
+            return super.read(splits);
+        }
+
+        RoaringNavigableMap64 preFilter = preFilter(splits).orElse(null);
+        String indexType = splits.get(0).vectorIndexFiles().get(0).indexType();
+        // Splits travel to the executors as pre-serialized byte arrays.
+        List<byte[]> splitBytes = new ArrayList<>(splits.size());
+        for (VectorSearchSplit split : splits) {
+            try {
+                splitBytes.add(InstantiationUtil.serializeObject(split));
+            } catch (IOException e) {
+                throw new RuntimeException("Failed to serialize VectorSearchSplit", e);
+            }
+        }
+        List<List<byte[]>> splitGroups = splitGroups(splitBytes, parallelism);
+        SparkEngineContext engineContext = new SparkEngineContext();
+        Broadcast<RoaringNavigableMap64> preFilterBroadcast =
+                preFilter == null ? null : engineContext.broadcast(preFilter);
+
+        // Runs inside a Spark task. The lambda reads fields of this object and calls eval(), so
+        // this instance is serialized together with the function.
+        SerializableFunction<List<byte[]>, byte[]> task =
+                group -> {
+                    GlobalIndexer globalIndexer =
+                            GlobalIndexerFactoryUtils.load(indexType)
+                                    .create(vectorColumn, table.coreOptions().toConfiguration());
+                    IndexPathFactory indexPathFactory =
+                            table.store().pathFactory().globalIndexFileFactory();
+
+                    RoaringNavigableMap64 includeRowIds =
+                            preFilterBroadcast == null ? null : preFilterBroadcast.value();
+                    ExecutorService executor =
+                            GlobalIndexReadThreadPool.getExecutorService(
+                                    Math.min(parallelism, group.size()));
+                    List<CompletableFuture<Optional<ScoredGlobalIndexResult>>> futures =
+                            new ArrayList<>(group.size());
+                    for (byte[] bytes : group) {
+                        VectorSearchSplit split = deserializeSplit(bytes);
+                        futures.add(
+                                eval(
+                                        globalIndexer,
+                                        indexPathFactory,
+                                        split.rowRangeStart(),
+                                        split.rowRangeEnd(),
+                                        split.vectorIndexFiles(),
+                                        includeRowIds,
+                                        executor));
+                    }
+                    CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
+                    ScoredGlobalIndexResult result = ScoredGlobalIndexResult.createEmpty();
+                    for (CompletableFuture<Optional<ScoredGlobalIndexResult>> f : futures) {
+                        Optional<ScoredGlobalIndexResult> next = f.join();
+                        if (next.isPresent()) {
+                            result = result.or(next.get());
+                        }
+                    }
+                    // Reduce to top-k per task so only `limit` candidates are shipped back.
+                    result = result.topK(limit);
+                    if (result.results().isEmpty()) {
+                        // null marks "nothing found"; the caller skips null entries.
+                        return null;
+                    }
+                    try {
+                        return new GlobalIndexResultSerializer().serialize(result);
+                    } catch (IOException e) {
+                        throw new RuntimeException(
+                                "Failed to serialize ScoredGlobalIndexResult", e);
+                    }
+                };
+
+        List<byte[]> remoteResults;
+        try {
+            remoteResults = engineContext.map(splitGroups, task, splitGroups.size());
+        } finally {
+            if (preFilterBroadcast != null) {
+                preFilterBroadcast.unpersist(false);
+            }
+        }
+
+        ScoredGlobalIndexResult result = ScoredGlobalIndexResult.createEmpty();
+        GlobalIndexResultSerializer serializer = new GlobalIndexResultSerializer();
+        for (byte[] bytes : remoteResults) {
+            if (bytes != null) {
+                try {
+                    result = result.or(serializer.deserialize(bytes));
+                } catch (IOException e) {
+                    throw new RuntimeException("Failed to deserialize ScoredGlobalIndexResult", e);
+                }
+            }
+        }
+        return result.topK(limit);
+    }
+
+    /** Restores a split from its serialized form using the thread context class loader. */
+    private VectorSearchSplit deserializeSplit(byte[] bytes) {
+        try {
+            return InstantiationUtil.deserializeObject(
+                    bytes, Thread.currentThread().getContextClassLoader());
+        } catch (IOException | ClassNotFoundException e) {
+            throw new RuntimeException("Failed to deserialize VectorSearchSplit", e);
+        }
+    }
+
+    /**
+     * Cuts {@code splitBytes} into consecutive groups of {@code ceil(size / parallelism)}
+     * elements, which yields at most {@code parallelism} groups.
+     */
+    private List<List<byte[]>> splitGroups(List<byte[]> splitBytes, int parallelism) {
+        List<List<byte[]>> groups = new ArrayList<>(parallelism);
+        int groupSize = (splitBytes.size() + parallelism - 1) / parallelism;
+        for (int start = 0; start < splitBytes.size(); start += groupSize) {
+            // Copy the subList view so every group is an independent list.
+            groups.add(
+                    new ArrayList<>(
+                            splitBytes.subList(
+                                    start, Math.min(start + groupSize, splitBytes.size()))));
+        }
+        return groups;
+    }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderImpl.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderImpl.java
new file mode 100644
index 000000000000..0e4e347f6456
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderImpl.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.read;
+
+import org.apache.paimon.table.InnerTable;
+import org.apache.paimon.table.source.VectorRead;
+import org.apache.paimon.table.source.VectorSearchBuilderImpl;
+
+/**
+ * {@link VectorSearchBuilderImpl} variant for Spark: the reads it builds are {@link
+ * SparkVectorReadImpl} instances, so vector index evaluation can be dispatched through Spark
+ * rather than through the local thread pool.
+ */
+public class SparkVectorSearchBuilderImpl extends VectorSearchBuilderImpl {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Creates a builder bound to {@code table}; the argument is handed to the parent as is. */
+    public SparkVectorSearchBuilderImpl(InnerTable table) {
+        super(table);
+    }
+
+    /** Builds the Spark read from the table, filter, limit, vector column and query vector. */
+    @Override
+    public VectorRead newVectorRead() {
+        VectorRead sparkRead = new SparkVectorReadImpl(table, filter, limit, vectorColumn, vector);
+        return sparkRead;
+    }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderProvider.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderProvider.java
new file mode 100644
index 000000000000..a98085163955
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderProvider.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.read;
+
+import org.apache.paimon.CoreOptions;
+import org.apache.paimon.table.InnerTable;
+import org.apache.paimon.table.source.VectorSearchBuilder;
+import org.apache.paimon.table.source.VectorSearchBuilderProvider;
+
+import org.apache.spark.sql.SparkSession;
+
+import javax.annotation.Nullable;
+
+/**
+ * {@link VectorSearchBuilderProvider} that supplies the Spark builder when distributed vector
+ * search is enabled on the table and a Spark session is active on the calling thread.
+ */
+public class SparkVectorSearchBuilderProvider implements VectorSearchBuilderProvider {
+
+    /**
+     * Returns a {@link SparkVectorSearchBuilderImpl} when the table options enable distributed
+     * vector search and {@link SparkSession#getActiveSession()} is defined; otherwise {@code
+     * null}. The session is only looked up when the option is enabled.
+     */
+    @Nullable
+    @Override
+    public VectorSearchBuilder create(InnerTable table) {
+        boolean distributeEnabled =
+                CoreOptions.fromMap(table.options()).vectorSearchDistributeEnabled();
+        if (distributeEnabled && SparkSession.getActiveSession().isDefined()) {
+            return new SparkVectorSearchBuilderImpl(table);
+        }
+        return null;
+    }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/resources/META-INF/services/org.apache.paimon.table.source.VectorSearchBuilderProvider b/paimon-spark/paimon-spark-common/src/main/resources/META-INF/services/org.apache.paimon.table.source.VectorSearchBuilderProvider
new file mode 100644
index 000000000000..7046142ad433
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/resources/META-INF/services/org.apache.paimon.table.source.VectorSearchBuilderProvider
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.paimon.spark.read.SparkVectorSearchBuilderProvider diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java index 4100f54f61dd..c969c5070a96 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java @@ -18,12 +18,20 @@ package org.apache.paimon.spark; +import org.apache.paimon.data.BinaryString; import org.apache.paimon.fs.Path; import org.apache.paimon.hive.TestHiveMetastore; +import org.apache.paimon.partition.PartitionPredicate; +import org.apache.paimon.predicate.PredicateBuilder; +import org.apache.paimon.spark.read.SparkVectorReadImpl; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.utils.Pair; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.connector.catalog.Identifier; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -31,6 +39,7 @@ import java.io.IOException; import java.util.List; +import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -71,7 +80,8 @@ private SparkSession.Builder createSparkSessionBuilder(Path warehousePath) { } @Test - public void 
testVector(@TempDir java.nio.file.Path tempDir) throws IOException { + public void testVector(@TempDir java.nio.file.Path tempDir) + throws IOException, NoSuchTableException { Path warehousePath = new Path("file:" + tempDir.toString()); SparkSession.Builder builder = createSparkSessionBuilder(warehousePath); SparkSession spark = builder.getOrCreate(); @@ -95,13 +105,17 @@ public void testVector(@TempDir java.nio.file.Path tempDir) throws IOException { spark = builder.getOrCreate(); spark.sql( - "insert overwrite table my_db1.vector_test\n" - + "VALUES (1, '1', array(cast(1.0 as float), cast(2.0 as float), cast(3.0 as float), cast(4.0 as float)), '20260420'),\n" + "insert overwrite table my_db1.vector_test VALUES \n" + + "(1, '1', array(cast(1.0 as float), cast(2.0 as float), cast(3.0 as float), cast(4.0 as float)), '20260420'),\n" + "(2, '2', array(cast(2.0 as float), cast(3.0 as float), cast(4.0 as float), cast(5.0 as float)), '20260420'),\n" - + "(3, '3', array(cast(3.0 as float), cast(4.0 as float), cast(5.0 as float), cast(6.0 as float)), '20260420'),\n" + + "(3, '3', array(cast(3.0 as float), cast(4.0 as float), cast(5.0 as float), cast(6.0 as float)), '20260420');"); + spark.sql( + "insert overwrite table my_db1.vector_test VALUES\n" + "(4, '4', array(cast(4.0 as float), cast(5.0 as float), cast(6.0 as float), cast(7.0 as float)), '20260420'),\n" + "(5, '5', array(cast(5.0 as float), cast(6.0 as float), cast(7.0 as float), cast(8.0 as float)), '20260420'),\n" - + "(6, '6', array(cast(6.0 as float), cast(7.0 as float), cast(8.0 as float), cast(9.0 as float)), '20260420'),\n" + + "(6, '6', array(cast(6.0 as float), cast(7.0 as float), cast(8.0 as float), cast(9.0 as float)), '20260420');"); + spark.sql( + "insert overwrite table my_db1.vector_test VALUES\n" + "(7, '7', array(cast(7.0 as float), cast(8.0 as float), cast(9.0 as float), cast(10.0 as float)), '20260420'),\n" + "(8, '8', array(cast(8.0 as float), cast(9.0 as float), cast(10.0 as float), cast(11.0 as 
float)), '20260420');"); spark.close(); @@ -128,12 +142,48 @@ public void testVector(@TempDir java.nio.file.Path tempDir) throws IOException { "select gid, sid, embs from vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f), 5) where date = '20260420'") .collectAsList(); assertThat(rows).hasSize(5); - Dataset df = - spark.sql( - "select gid, sid, embs, __paimon_vector_search_score from vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f), 5) where date = '20260420'"); + String vectorSearchSql = + "select gid, sid, embs, __paimon_vector_search_score " + + "from vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f), 5) " + + "where date = '20260420'"; + Dataset df = spark.sql(vectorSearchSql); assertThat(df.columns()).hasSize(4); rows = df.collectAsList(); assertThat(rows).hasSize(5); + spark.sql("set spark.paimon.vector-search.distribute.enabled = true;"); + spark.sql("set spark.paimon.global-index.thread-num=1"); + SparkCatalog sparkCatalog = + (SparkCatalog) spark.sessionState().catalogManager().currentCatalog(); + SparkTable sparkTable = + (SparkTable) + sparkCatalog.loadTable( + Identifier.of(new String[] {"my_db1"}, "vector_test")); + FileStoreTable table = (FileStoreTable) sparkTable.getTable(); + PredicateBuilder partitionPredicateBuilder = + new PredicateBuilder(table.schema().logicalPartitionType()); + assertThat( + table.newVectorSearchBuilder() + .withPartitionFilter( + PartitionPredicate.fromPredicate( + table.schema().logicalPartitionType(), + partitionPredicateBuilder.equal( + partitionPredicateBuilder.indexOf("date"), + BinaryString.fromString("20260420")))) + .withVector(new float[] {1.0f, 2.0f, 3.0f, 4.0f}) + .withVectorColumn("embs") + .withLimit(5) + .newVectorRead()) + .isInstanceOf(SparkVectorReadImpl.class); + List compareRows = spark.sql(vectorSearchSql).collectAsList(); + assertThat(compareRows).hasSize(5); + assertThat( + compareRows.stream() + .map(row -> Pair.of(row.getLong(0), 
row.getString(1))) + .collect(Collectors.toList())) + .containsExactlyInAnyOrderElementsOf( + rows.stream() + .map(row -> Pair.of(row.getLong(0), row.getString(1))) + .collect(Collectors.toList())); spark.close(); spark = builder.getOrCreate();