Commit 81e76c6
nameexhaustion committed Feb 19, 2025
1 parent 40330c5 commit 81e76c6
Showing 5 changed files with 63 additions and 43 deletions.
crates/polars-io/src/hive.rs: 2 additions & 1 deletion
@@ -18,8 +18,9 @@ pub(crate) fn materialize_hive_partitions<D>(
     df: &mut DataFrame,
     reader_schema: &polars_schema::Schema<D>,
     hive_partition_columns: Option<&[Series]>,
-    num_rows: usize,
 ) {
+    let num_rows = df.height();
+
     if let Some(hive_columns) = hive_partition_columns {
         // Insert these hive columns in the order they are stored in the file.
         if hive_columns.is_empty() {
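The row count is now taken from the frame itself (df.height()) inside the helper rather than being threaded through every call site, so the materialized hive columns always match whatever height the reader actually produced. A minimal Python sketch of the user-visible behaviour, using a hypothetical file layout that is not part of this commit:

import polars as pl

# Hypothetical hive-partitioned layout: data/date=2025-01-01/0.parquet
lf = pl.scan_parquet("data/**/*.parquet", hive_partitioning=True)

# head(1) limits what the reader yields; the "date" hive column is
# broadcast to that reduced height, not to the file's total row count.
print(lf.head(1).collect())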
crates/polars-io/src/ipc/ipc_file.rs: 7 additions & 15 deletions
@@ -252,24 +252,22 @@ impl<R: MmapBytesReader> SerReader<R> for IpcReader<R> {

         // In case only hive columns are projected, the df would be empty, but we need the row count
         // of the file in order to project the correct number of rows for the hive columns.
-        let (mut df, row_count) = (|| {
+        let mut df = (|| {
             if self.projection.as_ref().is_some_and(|x| x.is_empty()) {
                 let row_count = if let Some(v) = self.n_rows {
                     v
                 } else {
                     get_row_count(&mut self.reader)? as usize
                 };
-                let mut df = DataFrame::empty();
-                unsafe { df.set_height(row_count) };
+                let df = DataFrame::empty_with_height(row_count);

-                return PolarsResult::Ok((df, row_count));
+                return PolarsResult::Ok(df);
             }

             if self.memory_map.is_some() && self.reader.to_file().is_some() {
                 match self.finish_memmapped(None) {
                     Ok(df) => {
-                        let n = df.height();
-                        return Ok((df, n));
+                        return Ok(df);
                     },
                     Err(err) => check_mmap_err(err)?,
                 }
@@ -293,17 +291,11 @@ impl<R: MmapBytesReader> SerReader<R> for IpcReader<R> {
             let ipc_reader =
                 read::FileReader::new(self.reader, metadata, self.projection, self.n_rows);
             let df = finish_reader(ipc_reader, rechunk, None, None, &schema, self.row_index)?;
-            let n = df.height();
-            Ok((df, n))
+            Ok(df)
         })()?;

         if let Some(hive_cols) = hive_partition_columns {
-            materialize_hive_partitions(
-                &mut df,
-                reader_schema,
-                Some(hive_cols.as_slice()),
-                row_count,
-            );
+            materialize_hive_partitions(&mut df, reader_schema, Some(hive_cols.as_slice()));
         };

         if let Some((col, value)) = include_file_path {
@@ -314,7 +306,7 @@ impl<R: MmapBytesReader> SerReader<R> for IpcReader<R> {
                     DataType::String,
                     AnyValue::StringOwned(value.as_ref().into()),
                 ),
-                row_count,
+                df.height(),
             ))
         };
     }
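DataFrame::empty_with_height replaces the manual DataFrame::empty() plus unsafe set_height() pair for the case where only hive columns are projected: the otherwise column-less frame carries the file's row count, and materialize_hive_partitions can read that height directly. A rough Python illustration of that path, assuming scan_ipc supports hive_partitioning the same way scan_parquet does (hypothetical layout, not from this commit):

import polars as pl

# Hypothetical layout: data/date=2025-01-01/0.ipc
lf = pl.scan_ipc("data/**/*.ipc", hive_partitioning=True)

# Only the hive column is selected, so the file-side projection is empty;
# the row count still has to come from the file itself.
print(lf.select("date").collect())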
crates/polars-io/src/parquet/read/read_impl.rs: 5 additions & 25 deletions
@@ -440,12 +440,7 @@ fn rg_to_dfs_prefiltered(
            } else {
                df = unsafe { DataFrame::new_no_checks(md.num_rows(), live_columns.clone()) };

-                materialize_hive_partitions(
-                    &mut df,
-                    schema.as_ref(),
-                    hive_partition_columns,
-                    md.num_rows(),
-                );
+                materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns);
                let s = predicate.predicate.evaluate_io(&df)?;
                let mask = s.bool().expect("filter predicates was not of type boolean");

@@ -489,12 +484,7 @@ fn rg_to_dfs_prefiltered(

            // We don't need to do any further work if there are no dead columns
            if dead_idx_to_col_idx.is_empty() {
-                materialize_hive_partitions(
-                    &mut df,
-                    schema.as_ref(),
-                    hive_partition_columns,
-                    md.num_rows(),
-                );
+                materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns);

                return Ok(Some(df));
            }
@@ -606,12 +596,7 @@ fn rg_to_dfs_prefiltered(
            // and the length is given by the parquet file which should always be the same.
            let mut df = unsafe { DataFrame::new_no_checks(height, merged) };

-            materialize_hive_partitions(
-                &mut df,
-                schema.as_ref(),
-                hive_partition_columns,
-                md.num_rows(),
-            );
+            materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns);

            PolarsResult::Ok(Some(df))
        })
@@ -713,7 +698,7 @@ fn rg_to_dfs_optionally_par_over_columns(
                );
            }

-            materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns, rg_slice.1);
+            materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns);
            apply_predicate(
                &mut df,
                predicate.as_ref().map(|p| p.predicate.as_ref()),
@@ -850,12 +835,7 @@ fn rg_to_dfs_par_over_rg(
                );
            }

-            materialize_hive_partitions(
-                &mut df,
-                schema.as_ref(),
-                hive_partition_columns,
-                slice.1,
-            );
+            materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns);
            apply_predicate(
                &mut df,
                predicate.as_ref().map(|p| p.predicate.as_ref()),
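In the prefiltered path the frame handed to materialize_hive_partitions has already been filtered, so its height can be smaller than md.num_rows(); passing the row-group count produced hive columns of the wrong length (issues 20894/21327). A small Python reproduction in the spirit of the regression test added below, with an illustrative on-disk path rather than the test's tmp_path:

import polars as pl
from pathlib import Path

# Illustrative hive layout; the regression test below builds the same shape.
part = Path("data/date=2025-01-01")
part.mkdir(parents=True, exist_ok=True)
pl.DataFrame({"value": ["1", "2"]}).write_parquet(part / "0.parquet")

out = (
    pl.scan_parquet("data/**/*.parquet", hive_partitioning=True)
    .filter(pl.col("value") == "1")  # prefiltering keeps 1 of the 2 rows
    .collect()
)
# The "date" hive column must have length 1 (the filtered height),
# not 2 (the row group's num_rows).
print(out)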
crates/polars-io/src/parquet/read/utils.rs: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ pub fn materialize_empty_df(
            .unwrap();
    }

-    materialize_hive_partitions(&mut df, reader_schema, hive_partition_columns, 0);
+    materialize_hive_partitions(&mut df, reader_schema, hive_partition_columns);

    df
}
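materialize_empty_df builds a zero-row frame, so the explicit 0 argument was redundant once the helper reads the height from the frame itself. At the user level the hive columns still appear in the output schema even when nothing is read, as in this sketch (hypothetical dataset path):

import polars as pl

lf = pl.scan_parquet("data/**/*.parquet", hive_partitioning=True)

# A predicate that matches nothing still yields the hive column "date"
# in the schema, just with zero rows (the empty-df path).
out = lf.filter(pl.col("value") == "no-such-value").collect()
print(out.schema, out.height)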
py-polars/tests/unit/io/test_hive.py: 48 additions & 1 deletion
@@ -4,7 +4,7 @@
 import urllib.parse
 import warnings
 from collections import OrderedDict
-from datetime import datetime
+from datetime import date, datetime
 from functools import partial
 from pathlib import Path
 from typing import Any, Callable
@@ -888,3 +888,50 @@ def test_hive_auto_enables_when_unspecified_and_hive_schema_passed(
             pl.Series("a", [1], dtype=pl.UInt8),
         ),
     )
+
+
+@pytest.mark.write_disk
+def test_hive_parquet_prefiltered_20894_21327(tmp_path: Path) -> None:
+    file_path = tmp_path / "date=2025-01-01/00000000.parquet"
+    file_path.parent.mkdir(exist_ok=True, parents=True)
+
+    data = pl.DataFrame(
+        {
+            "date": [date(2025, 1, 1), date(2025, 1, 1)],
+            "value": ["1", "2"],
+        }
+    )
+
+    data.write_parquet(file_path)
+
+    import base64
+    import subprocess
+
+    # For security: avoid embedding the raw path directly in the generated source.
+    scan_path_b64 = base64.b64encode(str(file_path).encode()).decode()
+
+    # This is the easiest way to control the threadpool size in the test suite.
+    out = subprocess.check_output(
+        [
+            sys.executable,
+            "-c",
+            f"""\
+import os
+os.environ["POLARS_MAX_THREADS"] = "1"
+import polars as pl
+import base64
+assert pl.thread_pool_size() == 1
+tmp_path = base64.b64decode("{scan_path_b64}").decode()
+# We need the str() to trigger a panic on invalid state
+str(pl.scan_parquet(tmp_path, hive_partitioning=True).filter(pl.col("value") == "1").collect())
+print("OK", end="")
+""",
+        ]
+    )
+
+    assert out == b"OK"
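The thread-pool size is fixed when polars initializes, so POLARS_MAX_THREADS has to be set before the library is imported; that is why the test pins it in a fresh interpreter via subprocess. A standalone sketch of the same setup step:

import os
os.environ["POLARS_MAX_THREADS"] = "1"  # must be set before polars is imported

import polars as pl
assert pl.thread_pool_size() == 1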
