Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions datafusion/physical-plan/src/spill/spill_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use futures::{Stream, StreamExt};
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::Waker;

use parking_lot::Mutex;
Expand Down Expand Up @@ -88,6 +89,33 @@ impl SpillPoolShared {
}
}

/// Reference counter for live [`SpillPoolWriter`] clones.
///
/// Each clone of the writer shares this counter; `clone` bumps it and
/// [`WriterCount::decrement`] lowers it, reporting whether the caller was
/// the final holder. This lets `Drop` finalize the pool exactly once.
struct WriterCount(Arc<AtomicUsize>);

impl WriterCount {
    /// Creates a counter starting at one (the initial writer).
    fn new() -> Self {
        WriterCount(Arc::new(AtomicUsize::new(1)))
    }

    /// Lowers the count by one; `true` means the caller held the last clone.
    ///
    /// The read-modify-write is a single atomic `fetch_sub`, so two clones
    /// dropped concurrently can never both see the pre-decrement value `1`.
    fn decrement(&self) -> bool {
        let previous = self.0.fetch_sub(1, Ordering::SeqCst);
        previous == 1
    }
}

impl Clone for WriterCount {
    /// Bumps the shared count and hands back another handle to it.
    fn clone(&self) -> Self {
        let counter = Arc::clone(&self.0);
        counter.fetch_add(1, Ordering::SeqCst);
        WriterCount(counter)
    }
}

/// Writer for a spill pool. Provides coordinated write access with FIFO semantics.
///
/// Created by [`channel`]. See that function for architecture diagrams and usage examples.
Expand All @@ -104,6 +132,9 @@ pub struct SpillPoolWriter {
max_file_size_bytes: usize,
/// Shared state with readers (includes current_write_file for coordination)
shared: Arc<Mutex<SpillPoolShared>>,
/// Tracks how many writer clones are alive. The pool is only finalized
/// when the last clone is dropped.
writer_count: WriterCount,
}

impl SpillPoolWriter {
Expand Down Expand Up @@ -231,6 +262,12 @@ impl SpillPoolWriter {

impl Drop for SpillPoolWriter {
fn drop(&mut self) {
if !self.writer_count.decrement() {
// Other writer clones are still active; do not finalize or
// signal EOF to readers.
return;
}

let mut shared = self.shared.lock();

// Finalize the current file when the last writer is dropped
Expand Down Expand Up @@ -443,6 +480,7 @@ pub fn channel(
let writer = SpillPoolWriter {
max_file_size_bytes,
shared: Arc::clone(&shared),
writer_count: WriterCount::new(),
};

let reader = SpillPoolReader::new(shared, schema);
Expand Down Expand Up @@ -1343,6 +1381,81 @@ mod tests {
Ok(())
}

/// Verifies that the reader stays alive as long as any writer clone exists.
///
/// `SpillPoolWriter` is `Clone`, and in non-preserve-order repartitioning
/// mode multiple input partition tasks share clones of the same writer.
/// The reader must not see EOF until **all** clones have been dropped,
/// even if the queue is temporarily empty between writes from different
/// clones.
///
/// The test sequence is:
///
/// 1. writer1 writes a batch, then is dropped.
/// 2. The reader consumes that batch (queue is now empty).
/// 3. writer2 (still alive) writes a batch.
/// 4. The reader must see that batch.
/// 5. EOF is only signalled after writer2 is also dropped.
#[tokio::test]
async fn test_clone_drop_does_not_signal_eof_prematurely() -> Result<()> {
let (writer1, mut reader) = create_spill_channel(1024 * 1024);
let writer2 = writer1.clone();

// Synchronization: tell writer2 when it may proceed.
let (proceed_tx, proceed_rx) = tokio::sync::oneshot::channel::<()>();

// Spawn writer2 — it waits for the signal before writing.
let writer2_handle = SpawnedTask::spawn(async move {
proceed_rx.await.unwrap();
writer2.push_batch(&create_test_batch(10, 10)).unwrap();
// writer2 is dropped here (last clone → true EOF)
});

// Writer1 writes one batch, then drops.
writer1.push_batch(&create_test_batch(0, 10))?;
drop(writer1);

// Read writer1's batch.
let batch1 = reader.next().await.unwrap()?;
assert_eq!(batch1.num_rows(), 10);
let col = batch1
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_eq!(col.value(0), 0);

// Signal writer2 to write its batch. It will execute when the
// current task yields (i.e. when reader.next() returns Pending).
proceed_tx.send(()).unwrap();

// The reader should wait (Pending) for writer2's data, not EOF.
let batch2 =
tokio::time::timeout(std::time::Duration::from_secs(5), reader.next())
.await
.expect("Reader timed out — should not hang");

assert!(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this fix we fail here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also verified that this test fails without the code fix

andrewlamb@Andrews-MacBook-Pro-3:~/Software/datafusion2$ cargo test -p datafusion-physical-plan test_clone_drop_does_not_signal_eof_prematurely
    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.13s
     Running unittests src/lib.rs (target/debug/deps/datafusion_physical_plan-33977765615826e4)

running 1 test
test spill::spill_pool::tests::test_clone_drop_does_not_signal_eof_prematurely ... FAILED

failures:

---- spill::spill_pool::tests::test_clone_drop_does_not_signal_eof_prematurely stdout ----

thread 'spill::spill_pool::tests::test_clone_drop_does_not_signal_eof_prematurely' (2314602) panicked at datafusion/physical-plan/src/spill/spill_pool.rs:1400:9:
Reader must not return EOF while a writer clone is still alive
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace


failures:
    spill::spill_pool::tests::test_clone_drop_does_not_signal_eof_prematurely

test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 1266 filtered out; finished in 0.00s

error: test failed, to rerun pass `-p datafusion-physical-plan --lib`

batch2.is_some(),
"Reader must not return EOF while a writer clone is still alive"
);
let batch2 = batch2.unwrap()?;
assert_eq!(batch2.num_rows(), 10);
let col = batch2
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_eq!(col.value(0), 10);

writer2_handle.await.unwrap();

// All writers dropped — reader should see real EOF now.
assert!(reader.next().await.is_none());

Ok(())
}

#[tokio::test]
async fn test_disk_usage_decreases_as_files_consumed() -> Result<()> {
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
Expand Down
Loading