diff --git a/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java b/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java index e868a84cc1..2abd998828 100644 --- a/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java +++ b/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java @@ -37,10 +37,17 @@ public class BucketOffsetsRetrieverImpl implements OffsetsInitializer.BucketOffsetsRetriever { private final Admin flussAdmin; private final TablePath tablePath; + private final boolean fetchEarliestOffset; public BucketOffsetsRetrieverImpl(Admin flussAdmin, TablePath tablePath) { + this(flussAdmin, tablePath, false); + } + + public BucketOffsetsRetrieverImpl( + Admin flussAdmin, TablePath tablePath, boolean fetchEarliestOffset) { this.flussAdmin = flussAdmin; this.tablePath = tablePath; + this.fetchEarliestOffset = fetchEarliestOffset; } @Override @@ -52,11 +59,15 @@ public Map latestOffsets( @Override public Map earliestOffsets( @Nullable String partitionName, Collection buckets) { - Map bucketWithOffset = new HashMap<>(buckets.size()); - for (Integer bucket : buckets) { - bucketWithOffset.put(bucket, EARLIEST_OFFSET); + if (!fetchEarliestOffset) { + Map bucketWithOffset = new HashMap<>(buckets.size()); + for (Integer bucket : buckets) { + bucketWithOffset.put(bucket, EARLIEST_OFFSET); + } + return bucketWithOffset; + } else { + return listOffsets(partitionName, buckets, new OffsetSpec.EarliestSpec()); } - return bucketWithOffset; } @Override diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala index 28fb633b52..aac6a698da 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala @@ -50,4 +50,14 @@ object SparkFlussConf { .durationType() .defaultValue(Duration.ofMillis(10000L)) .withDescription("The timeout for log scanner to poll records.") + + val SCAN_MAX_RECORDS_PER_PARTITION: ConfigOption[java.lang.Long] = + ConfigBuilder + .key("scan.maxRecordsPerPartition") + .longType() + .noDefaultValue() + .withDescription( + "The maximum number of records per Spark input partition when reading a log table. " + + "When set, each Fluss bucket whose offset range exceeds this value will be split " + + "into multiple partitions. Disabled by default (one partition per bucket).") } diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala index 87f2fdad0f..d876abae59 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala @@ -25,6 +25,7 @@ import org.apache.fluss.client.table.scanner.log.LogScanner import org.apache.fluss.config.Configuration import org.apache.fluss.metadata.{PartitionInfo, TableBucket, TableInfo, TablePath} import org.apache.fluss.predicate.Predicate +import org.apache.fluss.spark.SparkFlussConf import org.apache.fluss.spark.utils.SparkPartitionPredicate import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory} @@ -127,26 +128,58 @@ class FlussAppendBatch( } override def planInputPartitions(): Array[InputPartition] = { - val bucketOffsetsRetrieverImpl = new BucketOffsetsRetrieverImpl(admin, tablePath) + val maxRecordsPerPartition: Option[Long] = { + val value = flussConfig.getLong(SparkFlussConf.SCAN_MAX_RECORDS_PER_PARTITION, 0) + if (value > 0) Some(value) else None + } + + val bucketOffsetsRetrieverImpl = maxRecordsPerPartition match { + case Some(_) => new BucketOffsetsRetrieverImpl(admin, tablePath, true) + case _ => new BucketOffsetsRetrieverImpl(admin, tablePath) + } val buckets = (0 until tableInfo.getNumBuckets).toSeq + def splitOffsetRange( + tableBucket: TableBucket, + startOffset: Long, + stopOffset: Long, + maxRecords: Long): Seq[InputPartition] = { + if ( + startOffset < 0 || stopOffset <= startOffset || stopOffset <= (startOffset + maxRecords) + ) { + return Seq(FlussAppendInputPartition(tableBucket, startOffset, stopOffset)) + } + val rangeSize = stopOffset - startOffset + val numSplits = ((rangeSize + maxRecords - 1) / maxRecords).toInt + val step = (rangeSize + numSplits - 1) / numSplits + + Iterator + .from(0) + .take(numSplits) + .map(i => startOffset + i * step) + .map { + from => FlussAppendInputPartition(tableBucket, from, math.min(from + step, stopOffset)) + } + .toSeq + } + def createPartitions( partitionId: Option[Long], startBucketOffsets: Map[Integer, Long], stoppingBucketOffsets: Map[Integer, Long]): Array[InputPartition] = { - buckets.map { + buckets.flatMap { bucketId => - val (startBucketOffset, stoppingBucketOffset) = + val (startOffset, stopOffset) = (startBucketOffsets(bucketId), stoppingBucketOffsets(bucketId)) - partitionId match { - case Some(partitionId) => - val tableBucket = new TableBucket(tableInfo.getTableId, partitionId, bucketId) - FlussAppendInputPartition(tableBucket, startBucketOffset, stoppingBucketOffset) - .asInstanceOf[InputPartition] - case None => - val tableBucket = new TableBucket(tableInfo.getTableId, bucketId) - FlussAppendInputPartition(tableBucket, startBucketOffset, stoppingBucketOffset) - .asInstanceOf[InputPartition] + val tableBucket = partitionId match { + case Some(pid) => new TableBucket(tableInfo.getTableId, pid, bucketId) + case None => new TableBucket(tableInfo.getTableId, bucketId) + } + maxRecordsPerPartition match { + case Some(maxRecs) => + splitOffsetRange(tableBucket, startOffset, stopOffset, maxRecs) + case _ => + Seq(FlussAppendInputPartition(tableBucket, startOffset, stopOffset)) } }.toArray } diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala index 42b0aa62d0..578ce8be35 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala @@ -17,12 +17,12 @@ package org.apache.fluss.spark -import org.apache.fluss.spark.read.{FlussMetrics, FlussScan} -import org.apache.fluss.spark.read.FlussAppendScan +import org.apache.fluss.spark.read.{FlussAppendScan, FlussMetrics, FlussScan} import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.Row import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} import org.assertj.core.api.Assertions.assertThat @@ -603,4 +603,67 @@ class SparkLogTableReadTest extends FlussSparkTestBase { assert(numRowsRead == 5L, s"Expected 5 rows read, got $numRowsRead") } } + + test("Spark Read: split partition by config") { + withSampleTable { + withSQLConf( + s"${SparkFlussConf.SPARK_FLUSS_CONF_PREFIX}${SparkFlussConf.SCAN_MAX_RECORDS_PER_PARTITION.key()}" + -> "2") { + val df = sql(s"SELECT amount FROM $DEFAULT_DATABASE.t ORDER BY orderId") + checkAnswer(df, Row(601) :: Row(602) :: Row(603) :: Row(604) :: Row(605) :: Nil) + + val partitions = getInputPartitions(df) + assertThat(partitions.length).isEqualTo(3) + } + } + + withTable("t_partition") { + sql( + s""" + |CREATE TABLE $DEFAULT_DATABASE.t_partition (orderId BIGINT, itemId BIGINT, amount INT, address STRING, dt STRING) + |PARTITIONED BY (dt) + |""".stripMargin + ) + + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_partition VALUES + |(600L, 21L, 601, "addr1", "2026-01-01"), (700L, 22L, 602, "addr2", "2026-01-01"), + |(800L, 23L, 603, "addr3", "2026-01-02"), (900L, 24L, 604, "addr4", "2026-01-02"), + |(1000L, 25L, 605, "addr5", "2026-01-03") + |""".stripMargin) + Seq((0, 3), (1, 5), (2, 3)).foreach { + case (maxRecords, expectedPartitions) => + withClue(s"maxRecords = $maxRecords, expectedPartitions = $expectedPartitions") { + withSQLConf( + s"${SparkFlussConf.SPARK_FLUSS_CONF_PREFIX}${SparkFlussConf.SCAN_MAX_RECORDS_PER_PARTITION.key()}" + -> maxRecords.toString) { + val df = sql(s"SELECT * FROM $DEFAULT_DATABASE.t_partition ORDER BY orderId") + checkAnswer( + df, + Row(600L, 21L, 601, "addr1", "2026-01-01") :: + Row(700L, 22L, 602, "addr2", "2026-01-01") :: + Row(800L, 23L, 603, "addr3", "2026-01-02") :: + Row(900L, 24L, 604, "addr4", "2026-01-02") :: + Row(1000L, 25L, 605, "addr5", "2026-01-03") :: Nil + ) + + val partitions = getInputPartitions(df) + assertThat(partitions.length).isEqualTo(expectedPartitions) + } + } + } + } + } + + private def getInputPartitions(df: DataFrame): Seq[InputPartition] = { + df.queryExecution.executedPlan match { + case aeq: AdaptiveSparkPlanExec => + aeq.inputPlan.collect { case b: BatchScanExec => b.inputPartitions }.flatten + case e => + e.collect { + case b: BatchScanExec => b.inputPartitions + case _ => Seq.empty[InputPartition] + }.flatten + } + } }