diff --git a/datafusion/physical-plan/src/spill/spill_pool.rs b/datafusion/physical-plan/src/spill/spill_pool.rs
index 1b9d82eaf4506..841183d254aec 100644
--- a/datafusion/physical-plan/src/spill/spill_pool.rs
+++ b/datafusion/physical-plan/src/spill/spill_pool.rs
@@ -18,6 +18,7 @@
 use futures::{Stream, StreamExt};
 use std::collections::VecDeque;
 use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::task::Waker;
 
 use parking_lot::Mutex;
@@ -88,6 +89,33 @@ impl SpillPoolShared {
     }
 }
 
+/// Tracks the number of live [`SpillPoolWriter`] clones.
+///
+/// Cloning increments the count. [`WriterCount::decrement`] atomically
+/// decrements the count and reports whether the caller was the last clone.
+struct WriterCount(Arc<AtomicUsize>);
+
+impl WriterCount {
+    fn new() -> Self {
+        Self(Arc::new(AtomicUsize::new(1)))
+    }
+
+    /// Decrements the count and returns `true` if this was the last clone.
+    ///
+    /// This is a single atomic operation, so concurrent drops cannot both
+    /// observe themselves as "last".
+    fn decrement(&self) -> bool {
+        self.0.fetch_sub(1, Ordering::SeqCst) == 1
+    }
+}
+
+impl Clone for WriterCount {
+    fn clone(&self) -> Self {
+        self.0.fetch_add(1, Ordering::SeqCst);
+        Self(Arc::clone(&self.0))
+    }
+}
+
 /// Writer for a spill pool. Provides coordinated write access with FIFO semantics.
 ///
 /// Created by [`channel`]. See that function for architecture diagrams and usage examples.
@@ -104,6 +132,9 @@ pub struct SpillPoolWriter {
     max_file_size_bytes: usize,
     /// Shared state with readers (includes current_write_file for coordination)
     shared: Arc<Mutex<SpillPoolShared>>,
+    /// Tracks how many writer clones are alive. The pool is only finalized
+    /// when the last clone is dropped.
+    writer_count: WriterCount,
 }
 
 impl SpillPoolWriter {
@@ -231,6 +262,12 @@ impl SpillPoolWriter {
 
 impl Drop for SpillPoolWriter {
     fn drop(&mut self) {
+        if !self.writer_count.decrement() {
+            // Other writer clones are still active; do not finalize or
+            // signal EOF to readers.
+            return;
+        }
+
         let mut shared = self.shared.lock();
 
         // Finalize the current file when the last writer is dropped
@@ -443,6 +480,7 @@ pub fn channel(
     let writer = SpillPoolWriter {
         max_file_size_bytes,
         shared: Arc::clone(&shared),
+        writer_count: WriterCount::new(),
     };
 
     let reader = SpillPoolReader::new(shared, schema);
@@ -1343,6 +1381,81 @@ mod tests {
         Ok(())
     }
 
+    /// Verifies that the reader stays alive as long as any writer clone exists.
+    ///
+    /// `SpillPoolWriter` is `Clone`, and in non-preserve-order repartitioning
+    /// mode multiple input partition tasks share clones of the same writer.
+    /// The reader must not see EOF until **all** clones have been dropped,
+    /// even if the queue is temporarily empty between writes from different
+    /// clones.
+    ///
+    /// The test sequence is:
+    ///
+    /// 1. writer1 writes a batch, then is dropped.
+    /// 2. The reader consumes that batch (queue is now empty).
+    /// 3. writer2 (still alive) writes a batch.
+    /// 4. The reader must see that batch.
+    /// 5. EOF is only signalled after writer2 is also dropped.
+    #[tokio::test]
+    async fn test_clone_drop_does_not_signal_eof_prematurely() -> Result<()> {
+        let (writer1, mut reader) = create_spill_channel(1024 * 1024);
+        let writer2 = writer1.clone();
+
+        // Synchronization: tell writer2 when it may proceed.
+        let (proceed_tx, proceed_rx) = tokio::sync::oneshot::channel::<()>();
+
+        // Spawn writer2 — it waits for the signal before writing.
+        let writer2_handle = SpawnedTask::spawn(async move {
+            proceed_rx.await.unwrap();
+            writer2.push_batch(&create_test_batch(10, 10)).unwrap();
+            // writer2 is dropped here (last clone → true EOF)
+        });
+
+        // Writer1 writes one batch, then drops.
+        writer1.push_batch(&create_test_batch(0, 10))?;
+        drop(writer1);
+
+        // Read writer1's batch.
+        let batch1 = reader.next().await.unwrap()?;
+        assert_eq!(batch1.num_rows(), 10);
+        let col = batch1
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(col.value(0), 0);
+
+        // Signal writer2 to write its batch. It will execute when the
+        // current task yields (i.e. when reader.next() returns Pending).
+        proceed_tx.send(()).unwrap();
+
+        // The reader should wait (Pending) for writer2's data, not EOF.
+        let batch2 =
+            tokio::time::timeout(std::time::Duration::from_secs(5), reader.next())
+                .await
+                .expect("Reader timed out — should not hang");
+
+        assert!(
+            batch2.is_some(),
+            "Reader must not return EOF while a writer clone is still alive"
+        );
+        let batch2 = batch2.unwrap()?;
+        assert_eq!(batch2.num_rows(), 10);
+        let col = batch2
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(col.value(0), 10);
+
+        writer2_handle.await.unwrap();
+
+        // All writers dropped — reader should see real EOF now.
+        assert!(reader.next().await.is_none());
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_disk_usage_decreases_as_files_consumed() -> Result<()> {
         use datafusion_execution::runtime_env::RuntimeEnvBuilder;