Skip to content

Commit

Permalink
feat: Add negative slicing to new streaming multiscan (#21219)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Feb 20, 2025
1 parent 9164d2c commit d2a6b8b
Show file tree
Hide file tree
Showing 8 changed files with 681 additions and 516 deletions.
10 changes: 5 additions & 5 deletions crates/polars-stream/src/nodes/io_sources/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use polars_utils::mmap::MemSlice;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::IdxSize;

use super::multi_scan::{MultiScanable, RowRestrication};
use super::{SourceNode, SourceOutput};
use super::multi_scan::MultiScanable;
use super::{RowRestriction, SourceNode, SourceOutput};
use crate::async_executor::{self, spawn};
use crate::async_primitives::connector::{connector, Receiver};
use crate::async_primitives::distributor_channel::distributor_channel;
Expand Down Expand Up @@ -590,14 +590,14 @@ impl MultiScanable for CsvSourceNode {
.collect()
});
}
fn with_row_restriction(&mut self, row_restriction: Option<RowRestrication>) {
fn with_row_restriction(&mut self, row_restriction: Option<RowRestriction>) {
self.file_options.slice = None;
match row_restriction {
None => {},
Some(RowRestrication::Slice(rng)) => {
Some(RowRestriction::Slice(rng)) => {
self.file_options.slice = Some((rng.start as i64, rng.end - rng.start))
},
Some(RowRestrication::Predicate(_)) => unreachable!(),
Some(RowRestriction::Predicate(_)) => unreachable!(),
}
}

Expand Down
10 changes: 5 additions & 5 deletions crates/polars-stream/src/nodes/io_sources/ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use polars_utils::pl_str::PlSmallStr;
use polars_utils::priority::Priority;
use polars_utils::IdxSize;

use super::multi_scan::{MultiScanable, RowRestrication};
use super::{SourceNode, SourceOutput};
use super::multi_scan::MultiScanable;
use super::{RowRestriction, SourceNode, SourceOutput};
use crate::async_executor::spawn;
use crate::async_primitives::connector::Receiver;
use crate::async_primitives::distributor_channel::distributor_channel;
Expand Down Expand Up @@ -535,12 +535,12 @@ impl MultiScanable for IpcSourceNode {
prepare_projection(&self.metadata.schema, p)
});
}
fn with_row_restriction(&mut self, row_restriction: Option<RowRestrication>) {
fn with_row_restriction(&mut self, row_restriction: Option<RowRestriction>) {
self.slice = 0..usize::MAX;
if let Some(row_restriction) = row_restriction {
match row_restriction {
RowRestrication::Slice(slice) => self.slice = slice,
RowRestrication::Predicate(_) => unreachable!(),
RowRestriction::Slice(slice) => self.slice = slice,
RowRestriction::Predicate(_) => unreachable!(),
}
}
}
Expand Down
42 changes: 25 additions & 17 deletions crates/polars-stream/src/nodes/io_sources/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::ops::Range;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

Expand All @@ -6,6 +7,7 @@ use futures::StreamExt;
use polars_core::config;
use polars_error::PolarsResult;
use polars_expr::state::ExecutionState;
use polars_mem_engine::ScanPredicate;
use polars_utils::index::AtomicIdxSize;

use super::{ComputeNode, JoinHandle, Morsel, PortState, RecvPort, SendPort, TaskPriority};
Expand All @@ -21,10 +23,17 @@ pub mod multi_scan;
#[cfg(feature = "parquet")]
pub mod parquet;

/// A restriction on the rows produced by a source, handed to
/// `with_row_restriction` as an `Option<RowRestriction>`.
///
/// NOTE(review): the CSV and IPC source nodes only handle `Slice` and treat
/// `Predicate` as `unreachable!()`; confirm whether other sources differ.
#[derive(Clone, Debug)]
pub enum RowRestriction {
/// A half-open range of row indices (`start..end`). The CSV source turns
/// this into an `(offset, length)` pair via `(start, end - start)`.
Slice(Range<usize>),
/// A scan predicate. The `#[expect(dead_code)]` attribute suggests this
/// variant is not constructed anywhere yet — TODO confirm.
#[expect(dead_code)]
Predicate(ScanPredicate),
}

/// The state needed to manage a spawned [`SourceNode`].
struct StartedSourceComputeNode {
output_send: Sender<SourceOutput>,
join_handles: Vec<JoinHandle<PolarsResult<()>>>,
join_handles: FuturesUnordered<AbortOnDropHandle<PolarsResult<()>>>,
}

/// A [`ComputeNode`] to wrap a [`SourceNode`].
Expand Down Expand Up @@ -94,6 +103,10 @@ impl<T: SourceNode> ComputeNode for SourceComputeNode<T> {

self.source
.spawn_source(self.num_pipelines, rx, state, &mut join_handles, None);
// One of the tasks might throw an error. In which case, we need to cancel all
// handles and find the error.
let join_handles: FuturesUnordered<_> =
join_handles.drain(..).map(AbortOnDropHandle::new).collect();

StartedSourceComputeNode {
output_send: tx,
Expand All @@ -112,27 +125,22 @@ impl<T: SourceNode> ComputeNode for SourceComputeNode<T> {
};
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
let (outcome, wait_group, source_output) = SourceOutput::from_port(source_output);
if started.output_send.send(source_output).await.is_err() {
return Ok(());
};

// Wait for the phase to finish.
wait_group.wait().await;
if outcome.did_finish() {
if started.output_send.send(source_output).await.is_ok() {
// Wait for the phase to finish.
wait_group.wait().await;
if !outcome.did_finish() {
return Ok(());
}

if config::verbose() {
eprintln!("[{name}]: Last data received.");
}
};

// One of the tasks might throw an error. In which case, we need to cancel all
// handles and find the error.
let mut join_handles: FuturesUnordered<_> = started
.join_handles
.drain(..)
.map(AbortOnDropHandle::new)
.collect();
while let Some(ret) = join_handles.next().await {
ret?;
}
// Either the task finished or some error occurred.
while let Some(ret) = started.join_handles.next().await {
ret?;
}

Ok(())
Expand Down
Loading

0 comments on commit d2a6b8b

Please sign in to comment.