Skip to content

Commit ba37f11

Browse files
authored
feat: Improved support for KeyboardInterrupts (#20961)
1 parent a7b933a commit ba37f11

File tree

45 files changed

+713
-916
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+713
-916
lines changed

Cargo.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/polars-core/src/chunked_array/object/registry.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,6 @@ pub fn register_object_builder(
123123
})
124124
}
125125

126-
pub fn is_object_builder_registered() -> bool {
127-
let reg = GLOBAL_OBJECT_REGISTRY.deref();
128-
let reg = reg.read().unwrap();
129-
reg.is_some()
130-
}
131-
132126
#[cold]
133127
pub fn get_object_physical_type() -> ArrowDataType {
134128
let reg = GLOBAL_OBJECT_REGISTRY.read().unwrap();

crates/polars-core/src/frame/group_by/into_groups.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_to_groups};
2-
use polars_error::check_signals;
2+
use polars_error::signals::try_raise_keyboard_interrupt;
33
use polars_utils::total_ord::{ToTotalOrd, TotalHash};
44

55
use super::*;
@@ -235,7 +235,7 @@ where
235235
num_groups_proxy(ca, multithreaded, sorted)
236236
},
237237
};
238-
check_signals()?;
238+
try_raise_keyboard_interrupt();
239239
Ok(out)
240240
}
241241
}
@@ -287,7 +287,7 @@ impl IntoGroupsType for BinaryChunked {
287287
} else {
288288
group_by(bh[0].iter(), sorted)
289289
};
290-
check_signals()?;
290+
try_raise_keyboard_interrupt();
291291
Ok(out)
292292
}
293293
}

crates/polars-core/src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pub use crate::chunked_array::StructChunked;
3737
#[cfg(feature = "dtype-categorical")]
3838
pub use crate::datatypes::string_cache::StringCacheHolder;
3939
pub use crate::datatypes::{ArrayCollectIterExt, *};
40+
pub use crate::error::signals::try_raise_keyboard_interrupt;
4041
pub use crate::error::{
4142
polars_bail, polars_ensure, polars_err, polars_warn, PolarsError, PolarsResult,
4243
};

crates/polars-error/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,8 @@ regex = { workspace = true, optional = true }
1616
simdutf8 = { workspace = true }
1717
thiserror = { workspace = true }
1818

19+
[target.'cfg(not(target_family = "wasm"))'.dependencies]
20+
signal-hook = "0.3"
21+
1922
[features]
2023
python = []

crates/polars-error/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ use std::fmt::{self, Display, Formatter, Write};
99
use std::ops::Deref;
1010
use std::sync::{Arc, LazyLock};
1111
use std::{env, io};
12-
mod signals;
12+
pub mod signals;
1313

14-
pub use signals::{check_signals, set_signals_function};
1514
pub use warning::*;
1615

1716
enum ErrorStrategy {

crates/polars-error/src/signals.rs

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,95 @@
1-
use crate::PolarsResult;
2-
3-
type SignalsFunction = fn() -> PolarsResult<()>;
4-
static mut SIGNALS_FUNCTION: Option<SignalsFunction> = None;
5-
6-
/// Set the function that will be called check_signals.
7-
/// This can be set on startup to enable stopping a query when user input like `ctrl-c` is called.
8-
///
9-
/// # Safety
10-
/// The caller must ensure there is no other thread accessing this function
11-
/// or calling `check_signals`.
12-
pub unsafe fn set_signals_function(function: SignalsFunction) {
13-
SIGNALS_FUNCTION = Some(function)
1+
use std::panic::{catch_unwind, UnwindSafe};
2+
use std::sync::atomic::{AtomicU64, Ordering};
3+
4+
/// Python hooks SIGINT to instead generate a KeyboardInterrupt exception.
5+
/// So we do the same to try and abort long-running computations and return to
6+
/// Python so that the Python exception can be generated.
7+
pub struct KeyboardInterrupt;
8+
9+
// Bottom bit: interrupt flag.
10+
// Top 63 bits: number of alive interrupt catchers.
11+
static INTERRUPT_STATE: AtomicU64 = AtomicU64::new(0);
12+
13+
pub fn register_polars_keyboard_interrupt_hook() {
14+
let default_hook = std::panic::take_hook();
15+
std::panic::set_hook(Box::new(move |p| {
16+
// Suppress panic output on KeyboardInterrupt.
17+
if p.payload().downcast_ref::<KeyboardInterrupt>().is_none() {
18+
default_hook(p);
19+
}
20+
}));
21+
22+
// WASM doesn't support signals, so we just skip installing the hook there.
23+
#[cfg(not(target_family = "wasm"))]
24+
unsafe {
25+
// SAFETY: we only do an atomic op in the signal handler, which is allowed.
26+
signal_hook::low_level::register(signal_hook::consts::signal::SIGINT, move || {
27+
// Set the interrupt flag, but only if there are active catchers.
28+
INTERRUPT_STATE
29+
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
30+
let num_catchers = state >> 1;
31+
if num_catchers > 0 {
32+
Some(state | 1)
33+
} else {
34+
None
35+
}
36+
})
37+
.ok();
38+
})
39+
.unwrap();
40+
}
41+
}
42+
43+
/// Checks if the keyboard interrupt flag is set, and if yes panics with a
44+
/// KeyboardInterrupt. This function is very cheap.
45+
#[inline(always)]
46+
pub fn try_raise_keyboard_interrupt() {
47+
if INTERRUPT_STATE.load(Ordering::Relaxed) & 1 != 0 {
48+
try_raise_keyboard_interrupt_slow()
49+
}
50+
}
51+
52+
#[inline(never)]
53+
#[cold]
54+
fn try_raise_keyboard_interrupt_slow() {
55+
std::panic::panic_any(KeyboardInterrupt);
56+
}
57+
58+
/// Runs the passed function, catching any KeyboardInterrupts if they occur
59+
/// while running the function.
60+
pub fn catch_keyboard_interrupt<R, F: FnOnce() -> R + UnwindSafe>(
61+
try_fn: F,
62+
) -> Result<R, KeyboardInterrupt> {
63+
// Try to register this catcher (or immediately return if there is an
64+
// uncaught interrupt).
65+
try_register_catcher()?;
66+
let ret = catch_unwind(try_fn);
67+
unregister_catcher();
68+
ret.map_err(|p| match p.downcast::<KeyboardInterrupt>() {
69+
Ok(_) => KeyboardInterrupt,
70+
Err(p) => std::panic::resume_unwind(p),
71+
})
1472
}
1573

16-
fn default() -> PolarsResult<()> {
74+
fn try_register_catcher() -> Result<(), KeyboardInterrupt> {
75+
let old_state = INTERRUPT_STATE.fetch_add(2, Ordering::Relaxed);
76+
if old_state & 1 != 0 {
77+
unregister_catcher();
78+
return Err(KeyboardInterrupt);
79+
}
1780
Ok(())
1881
}
1982

20-
pub fn check_signals() -> PolarsResult<()> {
21-
let f = unsafe { SIGNALS_FUNCTION.unwrap_or(default) };
22-
f()
83+
fn unregister_catcher() {
84+
INTERRUPT_STATE
85+
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
86+
let num_catchers = state >> 1;
87+
if num_catchers > 1 {
88+
Some(state - 2)
89+
} else {
90+
// Last catcher, clear interrupt flag.
91+
Some(0)
92+
}
93+
})
94+
.ok();
2395
}

crates/polars-expr/src/state/execution_state.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use std::sync::{Mutex, RwLock};
55
use bitflags::bitflags;
66
use once_cell::sync::OnceCell;
77
use polars_core::config::verbose;
8-
use polars_core::error::check_signals;
98
use polars_core::prelude::*;
109
use polars_ops::prelude::ChunkJoinOptIds;
1110

@@ -150,7 +149,7 @@ impl ExecutionState {
150149

151150
// This is wrong when the U64 overflows which will never happen.
152151
pub fn should_stop(&self) -> PolarsResult<()> {
153-
check_signals()?;
152+
try_raise_keyboard_interrupt();
154153
polars_ensure!(!self.stop.load(Ordering::Relaxed), ComputeError: "query interrupted");
155154
Ok(())
156155
}

crates/polars-ops/src/frame/join/asof/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use std::cmp::Ordering;
66
use default::*;
77
pub use groups::AsofJoinBy;
88
use polars_core::prelude::*;
9-
use polars_error::check_signals;
109
use polars_utils::pl_str::PlSmallStr;
1110
#[cfg(feature = "serde")]
1211
use serde::{Deserialize, Serialize};
@@ -334,7 +333,7 @@ pub trait AsofJoin: IntoDf {
334333
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
335334
},
336335
}?;
337-
check_signals()?;
336+
try_raise_keyboard_interrupt();
338337

339338
// Drop right join column.
340339
let other = if coalesce && left_key.name() == right_key.name() {

crates/polars-ops/src/frame/join/cross_join.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ fn cross_join_dfs(
133133
}
134134
};
135135
let (l_df, r_df) = if parallel {
136-
check_signals()?;
136+
try_raise_keyboard_interrupt();
137137
POOL.install(|| rayon::join(create_left_df, create_right_df))
138138
} else {
139139
(create_left_df(), create_right_df())

crates/polars-ops/src/frame/join/dispatch_left_right.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ pub fn materialize_left_join_from_series(
8787
} else {
8888
right.drop(s_right.name()).unwrap()
8989
};
90-
check_signals()?;
90+
try_raise_keyboard_interrupt();
9191

9292
#[cfg(feature = "chunked_ids")]
9393
match (left_idx, right_idx) {

crates/polars-ops/src/frame/join/hash_join/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ pub trait JoinDispatch: IntoDf {
155155
let (mut join_idx_l, mut join_idx_r) =
156156
s_left.hash_join_outer(s_right, args.validation, args.join_nulls)?;
157157

158-
check_signals()?;
158+
try_raise_keyboard_interrupt();
159159
if let Some((offset, len)) = args.slice {
160160
let (offset, len) = slice_offsets(offset, len, join_idx_l.len());
161161
join_idx_l.slice(offset, len);

crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ where
6969
} else {
7070
build_tables(build, join_nulls)
7171
};
72-
check_signals()?;
72+
try_raise_keyboard_interrupt();
7373

7474
let n_tables = hash_tbls.len();
7575
let offsets = probe_to_offsets(&probe);

crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ where
136136
} else {
137137
build_tables(build, join_nulls)
138138
};
139-
check_signals()?;
139+
try_raise_keyboard_interrupt();
140140
let n_tables = hash_tbls.len();
141141

142142
// we determine the offset so that we later know which index to store in the join tuples

crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ where
231231
let (probe_hashes, _) = create_hash_and_keys_threaded_vectorized(probe, Some(random_state));
232232

233233
let n_tables = hash_tbls.len();
234-
check_signals()?;
234+
try_raise_keyboard_interrupt();
235235

236236
// probe the hash table.
237237
// Note: indexes from b that are not matched will be None, Some(idx_b)

crates/polars-ops/src/frame/join/iejoin/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use polars_core::prelude::*;
1212
use polars_core::series::IsSorted;
1313
use polars_core::utils::{_set_partition_size, split};
1414
use polars_core::{with_match_physical_numeric_polars_type, POOL};
15-
use polars_error::{check_signals, polars_err, PolarsResult};
15+
use polars_error::{polars_err, PolarsResult};
1616
use polars_utils::binary_search::ExponentialSearch;
1717
use polars_utils::itertools::Itertools;
1818
use polars_utils::total_ord::{TotalEq, TotalOrd};
@@ -362,7 +362,7 @@ unsafe fn materialize_join(
362362
right_row_idx: &IdxCa,
363363
suffix: Option<PlSmallStr>,
364364
) -> PolarsResult<DataFrame> {
365-
check_signals()?;
365+
try_raise_keyboard_interrupt();
366366
let (join_left, join_right) = {
367367
POOL.join(
368368
|| left.take_unchecked(left_row_idx),

crates/polars-ops/src/frame/join/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ use polars_core::utils::slice_offsets;
4545
#[allow(unused_imports)]
4646
use polars_core::utils::slice_slice;
4747
use polars_core::POOL;
48-
use polars_error::check_signals;
4948
use polars_utils::hashing::BytesHash;
5049
use rayon::prelude::*;
5150

@@ -565,7 +564,7 @@ trait DataFrameJoinOpsPrivate: IntoDf {
565564
args.maintain_order,
566565
MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight
567566
);
568-
check_signals()?;
567+
try_raise_keyboard_interrupt();
569568
let (df_left, df_right) =
570569
if args.maintain_order != MaintainOrderJoin::None && !already_left_sorted {
571570
let mut df =

crates/polars-python/src/batched_csv.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use pyo3::prelude::*;
1010
use pyo3::pybacked::PyBackedStr;
1111

1212
use crate::error::PyPolarsErr;
13+
use crate::utils::EnterPolarsExt;
1314
use crate::{PyDataFrame, Wrap};
1415

1516
#[pyclass]
@@ -136,13 +137,7 @@ impl PyBatchedCsv {
136137

137138
fn next_batches(&self, py: Python, n: usize) -> PyResult<Option<Vec<PyDataFrame>>> {
138139
let reader = &self.reader;
139-
let batches = py.allow_threads(move || {
140-
reader
141-
.lock()
142-
.map_err(|e| PyPolarsErr::Other(e.to_string()))?
143-
.next_batches(n)
144-
.map_err(PyPolarsErr::from)
145-
})?;
140+
let batches = py.enter_polars(move || reader.lock().unwrap().next_batches(n))?;
146141

147142
// SAFETY: same memory layout
148143
let batches = unsafe {

0 commit comments

Comments
 (0)