Skip to content

Commit

Permalink
Refactor PyWindow to be moved to Expr submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
jdye64 committed Oct 18, 2023
1 parent 9161fc6 commit d130409
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 128 deletions.
3 changes: 3 additions & 0 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,5 +662,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_class::<drop_table::PyDropTable>()?;
m.add_class::<repartition::PyPartitioning>()?;
m.add_class::<repartition::PyRepartition>()?;
m.add_class::<window::PyWindow>()?;
m.add_class::<window::PyWindowFrame>()?;
m.add_class::<window::PyWindowFrameBound>()?;
Ok(())
}
103 changes: 88 additions & 15 deletions src/expr/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,22 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{Expr, Window, WindowFrame, WindowFrameBound};
use datafusion_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};

use crate::common::df_schema::PyDFSchema;
use crate::errors::{py_type_err, DataFusionError};
use crate::errors::py_type_err;
use crate::expr::logical_node::LogicalNode;
use crate::expr::PyExpr;
use crate::sql::logical::PyLogicalPlan;

use super::py_expr_list;

use crate::errors::py_datafusion_err;

#[pyclass(name = "Window", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindow {
Expand All @@ -41,6 +43,18 @@ pub struct PyWindowFrame {
window_frame: WindowFrame,
}

impl From<PyWindowFrame> for WindowFrame {
fn from(window_frame: PyWindowFrame) -> Self {
window_frame.window_frame
}
}

impl From<WindowFrame> for PyWindowFrame {
fn from(window_frame: WindowFrame) -> PyWindowFrame {
PyWindowFrame { window_frame }
}
}

#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrameBound {
Expand All @@ -59,12 +73,6 @@ impl From<Window> for PyWindow {
}
}

impl From<WindowFrame> for PyWindowFrame {
fn from(window_frame: WindowFrame) -> Self {
PyWindowFrame { window_frame }
}
}

impl From<WindowFrameBound> for PyWindowFrameBound {
fn from(frame_bound: WindowFrameBound) -> Self {
PyWindowFrameBound { frame_bound }
Expand All @@ -83,6 +91,16 @@ impl Display for PyWindow {
}
}

/// Formats the frame as `OVER (<units> BETWEEN <start_bound> AND <end_bound>)`,
/// using the `Display` output of each of the three frame fields.
impl Display for PyWindowFrame {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let frame = &self.window_frame;
        write!(
            f,
            "OVER ({} BETWEEN {} AND {})",
            frame.units, frame.start_bound, frame.end_bound
        )
    }
}

#[pymethods]
impl PyWindow {
/// Returns the schema of the Window
Expand Down Expand Up @@ -130,7 +148,7 @@ impl PyWindow {
}

/// Returns the `PyWindowFrame` for a given window function expression, or `None` if the expression is not a window function
pub fn get_window_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
match expr.expr.unalias() {
Expr::WindowFunction(WindowFunction { window_frame, .. }) => Some(window_frame.into()),
_ => None,
Expand All @@ -148,6 +166,57 @@ fn not_window_function_err(expr: Expr) -> PyErr {

#[pymethods]
impl PyWindowFrame {
#[new(unit, start_bound, end_bound)]
pub fn new(units: &str, start_bound: Option<u64>, end_bound: Option<u64>) -> PyResult<Self> {
let units = units.to_ascii_lowercase();
let units = match units.as_str() {
"rows" => WindowFrameUnits::Rows,
"range" => WindowFrameUnits::Range,
"groups" => WindowFrameUnits::Groups,
_ => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
};
let start_bound = match start_bound {
Some(start_bound) => {
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(start_bound)))
}
None => match units {
WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Groups => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
},
};
let end_bound = match end_bound {
Some(end_bound) => WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound))),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Groups => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
},
};
Ok(PyWindowFrame {
window_frame: WindowFrame {
units,
start_bound,
end_bound,
},
})
}

/// Returns the window frame units for the bounds
pub fn get_frame_units(&self) -> PyResult<String> {
Ok(self.window_frame.units.to_string())
Expand All @@ -160,6 +229,11 @@ impl PyWindowFrame {
pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
Ok(self.window_frame.end_bound.clone().into())
}

/// Returns the text produced by this window frame's `Display` implementation
fn __repr__(&self) -> String {
    self.to_string()
}
}

#[pymethods]
Expand Down Expand Up @@ -190,16 +264,15 @@ impl PyWindowFrameBound {
let s = v.clone().unwrap();
match s.parse::<u64>() {
Ok(s) => Ok(Some(s)),
Err(_e) => Err(DataFusionError::Common(format!(
Err(_e) => Err(DataFusionError::Plan(format!(
"Unable to parse u64 from Utf8 value '{s}'"
))
.into()),
}
}
ref x => Err(DataFusionError::Common(format!(
"Unexpected window frame bound: {x}"
))
.into()),
ref x => {
Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
}
},
WindowFrameBound::CurrentRow => Ok(None),
}
Expand Down
2 changes: 1 addition & 1 deletion src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use pyo3::{prelude::*, wrap_pyfunction};
use crate::context::PySessionContext;
use crate::errors::DataFusionError;
use crate::expr::conditional_expr::PyCaseBuilder;
use crate::expr::window::PyWindowFrame;
use crate::expr::PyExpr;
use crate::window_frame::PyWindowFrame;
use datafusion::execution::FunctionRegistry;
use datafusion_common::Column;
use datafusion_expr::expr::Alias;
Expand Down
2 changes: 0 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ mod udaf;
#[allow(clippy::borrow_deref_ref)]
mod udf;
pub mod utils;
mod window_frame;

#[cfg(feature = "mimalloc")]
#[global_allocator]
Expand Down Expand Up @@ -84,7 +83,6 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<context::PySessionContext>()?;
m.add_class::<dataframe::PyDataFrame>()?;
m.add_class::<udf::PyScalarUDF>()?;
m.add_class::<window_frame::PyWindowFrame>()?;
m.add_class::<udaf::PyAggregateUDF>()?;
m.add_class::<config::PyConfig>()?;
m.add_class::<sql::logical::PyLogicalPlan>()?;
Expand Down
110 changes: 0 additions & 110 deletions src/window_frame.rs

This file was deleted.

0 comments on commit d130409

Please sign in to comment.