Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust): Fix height validation in hstack_mut was bypassed when adding to empty frame #21335

Merged
merged 25 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 60 additions & 73 deletions crates/polars-core/src/frame/horizontal.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,8 @@
use polars_error::{polars_ensure, polars_err, PolarsResult};
use polars_utils::aliases::PlHashSet;
use polars_error::{polars_err, PolarsResult};

use super::Column;
use crate::datatypes::AnyValue;
use crate::frame::DataFrame;
use crate::prelude::PlSmallStr;

fn check_hstack(
col: &Column,
names: &mut PlHashSet<PlSmallStr>,
height: usize,
is_empty: bool,
) -> PolarsResult<()> {
polars_ensure!(
col.len() == height || is_empty,
ShapeMismatch: "unable to hstack Series of length {} and DataFrame of height {}",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug is here - || (self.)is_empty

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems correct for this function, but not correct on the callsite.

col.len(), height,
);
polars_ensure!(
names.insert(col.name().clone()),
Duplicate: "unable to hstack, column with name {:?} already exists", col.name().as_str(),
);
Ok(())
}

impl DataFrame {
/// Add columns horizontally.
Expand All @@ -31,28 +11,25 @@ impl DataFrame {
/// The caller must ensure:
/// - the length of all [`Column`] is equal to the height of this [`DataFrame`]
/// - the columns names are unique
///
/// Note that on a debug build this will panic on duplicates / height mismatch.
pub unsafe fn hstack_mut_unchecked(&mut self, columns: &[Column]) -> &mut Self {
// If we don't have any columns yet, copy the height from the given columns.
if let Some(fst) = columns.first() {
if self.width() == 0 {
// SAFETY: The functions invariants asks for all columns to be the same length so
// that makes that a valid height.
unsafe { self.set_height(fst.len()) };
}
}
self.clear_schema();
self.columns.extend_from_slice(columns);

if cfg!(debug_assertions) {
// It is an impl error if this fails.
self._validate_hstack(columns).unwrap();
DataFrame::validate_columns_slice(&self.columns).unwrap();
}

if let Some(c) = self.columns.first() {
unsafe { self.set_height(c.len()) };
}

self.clear_schema();
self.columns.extend_from_slice(columns);
self
}

/// Add multiple [`Column`] to a [`DataFrame`].
/// The added `Series` are required to have the same length.
/// Errors if the resulting DataFrame columns have duplicate names or unequal heights.
///
/// # Example
///
Expand All @@ -63,28 +40,19 @@ impl DataFrame {
/// }
/// ```
pub fn hstack_mut(&mut self, columns: &[Column]) -> PolarsResult<&mut Self> {
self._validate_hstack(columns)?;
Ok(unsafe { self.hstack_mut_unchecked(columns) })
}
self.clear_schema();
self.columns.extend_from_slice(columns);

fn _validate_hstack(&self, columns: &[Column]) -> PolarsResult<()> {
let mut names = self
.columns
.iter()
.map(|c| c.name().clone())
.collect::<PlHashSet<_>>();

let height = self.height();
let is_empty = self.is_empty();
// first loop check validity. We don't do this in a single pass otherwise
// this DataFrame is already modified when an error occurs.
for col in columns {
check_hstack(col, &mut names, height, is_empty)?;
DataFrame::validate_columns_slice(&self.columns)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And by that, I mean here. After the first column is added is_empty should be false as it isn't empty anymore.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - we needed to add a column first.


if let Some(c) = self.columns.first() {
unsafe { self.set_height(c.len()) };
}

Ok(())
Ok(self)
}
}

/// Concat [`DataFrame`]s horizontally.
/// Concat horizontally and extend with null values if lengths don't match
pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> PolarsResult<DataFrame> {
Expand All @@ -96,12 +64,23 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars

let owned_df;

let mut out_width = 0;

let all_equal_height = dfs.iter().all(|df| {
out_width += df.width();
df.height() == output_height
});

// if not all equal length, extend the DataFrame with nulls
let dfs = if !dfs.iter().all(|df| df.height() == output_height) {
let dfs = if !all_equal_height {
out_width = 0;

owned_df = dfs
.iter()
.cloned()
.map(|mut df| {
out_width += df.width();

if df.height() != output_height {
let diff = output_height - df.height();

Expand All @@ -123,30 +102,38 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars
dfs
};

let mut first_df = dfs[0].clone();
let height = first_df.height();
let is_empty = first_df.is_empty();
let mut acc_cols = Vec::with_capacity(out_width);

let mut names = if check_duplicates {
first_df
.columns
.iter()
.map(|s| s.name().clone())
.collect::<PlHashSet<_>>()
} else {
Default::default()
};
for df in dfs {
acc_cols.extend(df.get_columns().iter().cloned());
}

if check_duplicates {
DataFrame::validate_columns_slice(&acc_cols)?;
}

for df in &dfs[1..] {
let cols = df.get_columns();
let df = unsafe { DataFrame::new_no_checks_height_from_first(acc_cols) };

if check_duplicates {
for col in cols {
check_hstack(col, &mut names, height, is_empty)?;
}
}
Ok(df)
}

unsafe { first_df.hstack_mut_unchecked(cols) };
#[cfg(test)]
mod tests {
use polars_error::PolarsError;

#[test]
fn test_hstack_mut_empty_frame_height_validation() {
use crate::frame::DataFrame;
use crate::prelude::{Column, DataType};
let mut df = DataFrame::empty();
let result = df.hstack_mut(&[
Column::full_null("a".into(), 1, &DataType::Null),
Column::full_null("b".into(), 3, &DataType::Null),
]);

assert!(
matches!(result, Err(PolarsError::ShapeMismatch(_))),
"expected shape mismatch error"
);
}
Ok(first_df)
}
53 changes: 11 additions & 42 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{mem, ops};

use arrow::datatypes::ArrowSchemaRef;
use polars_row::ArrayRef;
use polars_schema::schema::debug_ensure_matching_schema_names;
use polars_schema::schema::ensure_matching_schema_names;
use polars_utils::itertools::Itertools;
use rayon::prelude::*;

Expand All @@ -31,6 +31,7 @@ pub(crate) mod horizontal;
pub mod row;
mod top_k;
mod upstream_traits;
mod validation;

use arrow::record_batch::{RecordBatch, RecordBatchT};
use polars_utils::pl_str::PlSmallStr;
Expand Down Expand Up @@ -260,6 +261,8 @@ impl DataFrame {

/// Create a DataFrame from a Vector of Series.
///
/// Errors if a column names are not unique, or if heights are not all equal.
///
/// # Example
///
/// ```
Expand All @@ -271,17 +274,9 @@ impl DataFrame {
/// # Ok::<(), PolarsError>(())
/// ```
pub fn new(columns: Vec<Column>) -> PolarsResult<Self> {
ensure_names_unique(&columns, |s| s.name().as_str())?;

let Some(fst) = columns.first() else {
return Ok(DataFrame {
height: 0,
columns,
cached_schema: OnceLock::new(),
});
};

Self::new_with_height(fst.len(), columns)
DataFrame::validate_columns_slice(&columns)
.map_err(|e| e.wrap_msg(|e| format!("could not create a new DataFrame: {}", e)))?;
Ok(unsafe { Self::new_no_checks_height_from_first(columns) })
}

pub fn new_with_height(height: usize, columns: Vec<Column>) -> PolarsResult<Self> {
Expand Down Expand Up @@ -522,11 +517,7 @@ impl DataFrame {
/// having an equal length and a unique name, if not this may panic down the line.
pub unsafe fn new_no_checks(height: usize, columns: Vec<Column>) -> DataFrame {
if cfg!(debug_assertions) {
ensure_names_unique(&columns, |s| s.name().as_str()).unwrap();

for col in &columns {
assert_eq!(col.len(), height);
}
DataFrame::validate_columns_slice(&columns).unwrap();
}

unsafe { Self::_new_no_checks_impl(height, columns) }
Expand All @@ -544,30 +535,6 @@ impl DataFrame {
}
}

/// Create a new `DataFrame` but does not check the length of the `Series`,
/// only check for duplicates.
///
/// It is advised to use [DataFrame::new] in favor of this method.
///
/// # Safety
///
/// It is the callers responsibility to uphold the contract of all `Series`
/// having an equal length, if not this may panic down the line.
pub unsafe fn new_no_length_checks(columns: Vec<Column>) -> PolarsResult<DataFrame> {
ensure_names_unique(&columns, |s| s.name().as_str())?;
Copy link
Collaborator Author

@nameexhaustion nameexhaustion Feb 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed - new_no_length_checks doesn't really save any perf since it's already traversing the columns for duplicate name checking - doing a height check on top is practically free.


Ok(if cfg!(debug_assertions) {
Self::new(columns).unwrap()
} else {
let height = Self::infer_height(&columns);
DataFrame {
height,
columns,
cached_schema: OnceLock::new(),
}
})
}

/// Shrink the capacity of this DataFrame to fit its length.
pub fn shrink_to_fit(&mut self) {
// Don't parallelize this. Memory overhead
Expand Down Expand Up @@ -1845,7 +1812,9 @@ impl DataFrame {
cols: &[PlSmallStr],
schema: &Schema,
) -> PolarsResult<Vec<Column>> {
debug_ensure_matching_schema_names(schema, self.schema())?;
if cfg!(debug_assertions) {
ensure_matching_schema_names(schema, self.schema())?;
}

cols.iter()
.map(|name| {
Expand Down
67 changes: 67 additions & 0 deletions crates/polars-core/src/frame/validation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use polars_error::{polars_bail, PolarsResult};
use polars_utils::aliases::{InitHashMaps, PlHashSet};

use super::column::Column;
use super::DataFrame;

impl DataFrame {
/// Ensure all equal height and names are unique.
///
/// An Ok() result indicates `columns` is a valid state for a DataFrame.
pub fn validate_columns_slice(columns: &[Column]) -> PolarsResult<()> {
if columns.len() <= 1 {
return Ok(());
}

if columns.len() <= 4 {
// Too small to be worth spawning a hashmap for, this is at most 6 comparisons.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copied and adjusted from

if items.len() <= 4 {
// Too small to be worth spawning a hashmap for, this is at most 6 comparisons.
for i in 0..items.len() - 1 {
let name = get_name(&items[i]);
for other in items.iter().skip(i + 1) {
if name == get_name(other) {
polars_bail!(duplicate = name);
}
}
}
} else {

for i in 0..columns.len() - 1 {
let name = columns[i].name();
let height = columns[i].len();

for other in columns.iter().skip(i + 1) {
if other.name() == name {
polars_bail!(duplicate = name);
}

if other.len() != height {
polars_bail!(
ShapeMismatch:
"height of column '{}' ({}) does not match height of column '{}' ({})",
other.name(), other.len(), name, height
)
}
}
}
} else {
let first = &columns[0];

let first_len = first.len();
let first_name = first.name();

let mut names = PlHashSet::with_capacity(columns.len());
names.insert(first_name);

for col in &columns[1..] {
let col_name = col.name();
let col_len = col.len();

if col_len != first_len {
polars_bail!(
ShapeMismatch:
"height of column '{}' ({}) does not match height of column '{}' ({})",
col_name, col_len, first_name, first_len
)
}

if names.contains(col_name) {
polars_bail!(duplicate = col_name)
}

names.insert(col_name);
}
}

Ok(())
}
}
3 changes: 1 addition & 2 deletions crates/polars-ops/src/frame/pivot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,5 @@ fn pivot_impl_single_column(
});
out?;

// SAFETY: length has already been checked.
unsafe { DataFrame::new_no_length_checks(final_cols) }
DataFrame::new(final_cols)
}
3 changes: 1 addition & 2 deletions crates/polars-ops/src/series/ops/to_dummies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ impl ToDummies for Series {
})
.collect::<Vec<_>>();

// SAFETY: `dummies_helper` functions preserve `self.len()` length
unsafe { DataFrame::new_no_length_checks(sort_columns(columns)) }
DataFrame::new(sort_columns(columns))
}
}

Expand Down
Loading
Loading