Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add few missing SparkLikeExpr methods #1721

Merged
Merged
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ where `YOUR-GITHUB-USERNAME` will be your GitHub user name.

Here's how you can set up your local development environment to contribute.

#### Prerequisites for PySpark tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should narwhals suggest (and maintain :)) a guide to install Java for pyspark?
Or should we just add a note to say that pyspark needs Java and add a link to the pyspark documentation?

There may be different ways one wants to install Java on their machine.
For example, on macOS I prefer using openjdk installed via homebrew.

What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be simpler to just say that PySpark needs Java installed and add a link to the PySpark documentation.


If you want to run PySpark-related tests, you'll need to have Java installed. Refer to the [Spark documentation](https://spark.apache.org/docs/latest/#downloading) for more information.

#### Option 1: Use UV (recommended)

1. Make sure you have Python3.12 installed, create a virtual environment,
Expand Down
206 changes: 183 additions & 23 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def __gt__(self, other: SparkLikeExpr) -> Self:
returns_scalar=False,
)

def abs(self) -> Self:
    """Return an expression that applies ``pyspark.sql.functions.abs``.

    Returns:
        A new expression created through ``_from_call``; its scalar-ness
        is inherited from this expression.
    """
    from pyspark.sql import functions as F  # noqa: N812

    keeps_scalar = self._returns_scalar
    return self._from_call(F.abs, "abs", returns_scalar=keeps_scalar)

def alias(self, name: str) -> Self:
def _alias(df: SparkLikeLazyFrame) -> list[Column]:
return [col.alias(name) for col in self._call(df)]
Expand All @@ -179,44 +184,42 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]:
)

def count(self) -> Self:
def _count(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql import functions as F # noqa: N812

return F.count(_input)

return self._from_call(_count, "count", returns_scalar=True)
return self._from_call(F.count, "count", returns_scalar=True)

def max(self) -> Self:
def _max(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql import functions as F # noqa: N812

return F.max(_input)

return self._from_call(_max, "max", returns_scalar=True)
return self._from_call(F.max, "max", returns_scalar=True)

def mean(self) -> Self:
def _mean(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.mean, "mean", returns_scalar=True)

def median(self) -> Self:
def _median(_input: Column) -> Column:
import pyspark # ignore-banned-import
from pyspark.sql import functions as F # noqa: N812

return F.mean(_input)
if parse_version(pyspark.__version__) < (3, 4):
# Use percentile_approx with default accuracy parameter (10000)
return F.percentile_approx(_input.cast("double"), 0.5)

return self._from_call(_mean, "mean", returns_scalar=True)
return F.median(_input)

def min(self) -> Self:
def _min(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812
return self._from_call(_median, "median", returns_scalar=True)

return F.min(_input)
def min(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(_min, "min", returns_scalar=True)
return self._from_call(F.min, "min", returns_scalar=True)

def sum(self) -> Self:
def _sum(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql import functions as F # noqa: N812

return F.sum(_input)

return self._from_call(_sum, "sum", returns_scalar=True)
return self._from_call(F.sum, "sum", returns_scalar=True)

def std(self: Self, ddof: int) -> Self:
from functools import partial
Expand Down Expand Up @@ -249,3 +252,160 @@ def var(self: Self, ddof: int) -> Self:
)

return self._from_call(func, "var", returns_scalar=True, ddof=ddof)

def clip(
self,
lower_bound: Any | None = None,
upper_bound: Any | None = None,
) -> Self:
Comment on lines +246 to +250
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We recently introduced support for `lower_bound` and `upper_bound` to be other `Expr`s.
I am OK with keeping that as a follow-up, but it is definitely something we would look forward to.

def _clip(_input: Column, lower_bound: Any, upper_bound: Any) -> Column:
from pyspark.sql import functions as F # noqa: N812

result = _input
if lower_bound is not None:
# Convert lower_bound to a literal Column
result = F.when(result < lower_bound, F.lit(lower_bound)).otherwise(
result
)
if upper_bound is not None:
# Convert upper_bound to a literal Column
result = F.when(result > upper_bound, F.lit(upper_bound)).otherwise(
result
)
return result

return self._from_call(
_clip,
"clip",
lower_bound=lower_bound,
upper_bound=upper_bound,
returns_scalar=self._returns_scalar,
)

def is_between(
    self,
    lower_bound: Any,
    upper_bound: Any,
    closed: str,
) -> Self:
    """Build a boolean expression testing whether values lie in an interval.

    Arguments:
        lower_bound: Lower end of the interval.
        upper_bound: Upper end of the interval.
        closed: Which ends are inclusive. ``"both"``, ``"none"`` and
            ``"left"`` are matched explicitly; any other value is treated
            as right-closed.

    Returns:
        A new expression created through ``_from_call``; its scalar-ness
        is inherited from this expression.
    """
    # The lower end is inclusive for "both"/"left"; the upper end is
    # inclusive for everything except "none"/"left".
    include_lower = closed in ("both", "left")
    include_upper = closed not in ("none", "left")

    def _is_between(_input: Column, lower_bound: Any, upper_bound: Any) -> Column:
        above = _input >= lower_bound if include_lower else _input > lower_bound
        below = _input <= upper_bound if include_upper else _input < upper_bound
        return above & below

    return self._from_call(
        _is_between,
        "is_between",
        lower_bound=lower_bound,
        upper_bound=upper_bound,
        returns_scalar=self._returns_scalar,
    )

def is_duplicated(self) -> Self:
    """Flag every row whose value occurs more than once in the column.

    Returns:
        A new boolean expression created through ``_from_call``; its
        scalar-ness is inherited from this expression.
    """

    def _is_duplicated(_input: Column) -> Column:
        from pyspark.sql import Window
        from pyspark.sql import functions as F  # noqa: N812

        # Partition on the column itself instead of a CASE WHEN key with
        # "NULL"/"NAN" string sentinels. The sentinel key collided with
        # genuine "NULL"/"NAN" strings, coerced the key to string and applied
        # isnan to non-numeric columns. Spark's partitioning already puts
        # all NULLs in one partition and all NaNs in one partition.
        window = Window.partitionBy(_input)

        # A value is duplicated when its partition holds more than one row.
        return F.count(F.lit(1)).over(window) > 1

    return self._from_call(
        _is_duplicated, "is_duplicated", returns_scalar=self._returns_scalar
    )

def is_finite(self) -> Self:
    """Build a boolean expression that is true only for finite values.

    Returns:
        A new expression created through ``_from_call``; its scalar-ness
        is inherited from this expression.
    """

    def _is_finite(_input: Column) -> Column:
        from pyspark.sql import functions as F  # noqa: N812

        # Finite means: not NaN, not NULL, and neither +inf nor -inf.
        not_nan = ~F.isnan(_input)
        not_null = ~F.isnull(_input)
        not_pos_inf = _input != float("inf")
        not_neg_inf = _input != float("-inf")
        return not_nan & not_null & not_pos_inf & not_neg_inf

    return self._from_call(
        _is_finite, "is_finite", returns_scalar=self._returns_scalar
    )

def is_in(self, values: Sequence[Any]) -> Self:
    """Build a boolean expression testing membership in ``values``.

    Arguments:
        values: Candidate values, forwarded unchanged to ``Column.isin``.

    Returns:
        A new expression created through ``_from_call``; its scalar-ness
        is inherited from this expression.
    """

    def _is_in(_input: Column, values: Sequence[Any]) -> Column:
        membership = _input.isin(values)
        return membership

    call_kwargs = {"values": values, "returns_scalar": self._returns_scalar}
    return self._from_call(_is_in, "is_in", **call_kwargs)

def is_unique(self) -> Self:
    """Flag every row whose value occurs exactly once in the column.

    Returns:
        A new boolean expression created through ``_from_call``; its
        scalar-ness is inherited from this expression.
    """

    def _is_unique(_input: Column) -> Column:
        from pyspark.sql import Window
        from pyspark.sql import functions as F  # noqa: N812

        # Partition on the column itself instead of a CASE WHEN key with
        # "NULL"/"NAN" string sentinels. The sentinel key collided with
        # genuine "NULL"/"NAN" strings, coerced the key to string and applied
        # isnan to non-numeric columns. Spark's partitioning already puts
        # all NULLs in one partition and all NaNs in one partition.
        window = Window.partitionBy(_input)

        # A value is unique when its partition holds exactly one row.
        return F.count(F.lit(1)).over(window) == 1

    return self._from_call(
        _is_unique, "is_unique", returns_scalar=self._returns_scalar
    )

def len(self) -> Self:
    """Build a scalar expression holding the number of rows.

    Returns:
        A new expression created through ``_from_call`` with
        ``returns_scalar=True``.
    """

    def _len(_input: Column) -> Column:
        from pyspark.sql import functions as F  # noqa: N812

        # count(*) counts every row, nulls included; the input column
        # itself is deliberately not referenced.
        row_count = F.count("*")
        return row_count

    return self._from_call(_len, "len", returns_scalar=True)

def n_unique(self) -> Self:
    """Build a scalar expression counting distinct values, NULL included.

    Returns:
        A new expression created through ``_from_call`` with
        ``returns_scalar=True``.
    """

    def _n_unique(_input: Column) -> Column:
        from pyspark.sql import functions as F  # noqa: N812

        # countDistinct skips NULLs, so count the non-null distinct values
        # directly and add one when at least one NULL is present. This
        # replaces a "NULL"/"NaN" string sentinel key, which merged genuine
        # "NULL" strings with null rows and applied isnan to non-numeric
        # columns.
        distinct_non_null = F.countDistinct(_input)
        # max over zero rows is NULL; coalesce keeps the empty-input
        # result at 0 instead of NULL.
        null_present = F.coalesce(
            F.max(F.isnull(_input).cast("int")), F.lit(0)
        )
        return distinct_non_null + null_present

    return self._from_call(_n_unique, "n_unique", returns_scalar=True)

def round(self, decimals: int) -> Self:
    """Build an expression rounding values via ``pyspark.sql.functions.round``.

    Arguments:
        decimals: Scale forwarded as the second argument of ``F.round``.

    Returns:
        A new expression created through ``_from_call``; its scalar-ness
        is inherited from this expression.
    """

    def _round(_input: Column, decimals: int) -> Column:
        from pyspark.sql import functions as F  # noqa: N812

        rounded = F.round(_input, decimals)
        return rounded

    call_kwargs = {"decimals": decimals, "returns_scalar": self._returns_scalar}
    return self._from_call(_round, "round", **call_kwargs)

def skew(self) -> Self:
    """Build a scalar expression applying ``pyspark.sql.functions.skewness``.

    Returns:
        A new expression created through ``_from_call`` with
        ``returns_scalar=True``.
    """
    from pyspark.sql import functions as F  # noqa: N812

    aggregate = F.skewness
    return self._from_call(aggregate, "skew", returns_scalar=True)
6 changes: 5 additions & 1 deletion narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def is_pandas_like(self) -> bool:
>>> df.implementation.is_pandas_like()
True
"""
return self in {Implementation.PANDAS, Implementation.MODIN, Implementation.CUDF}
return self in {
Implementation.PANDAS,
Implementation.MODIN,
Implementation.CUDF,
}

def is_polars(self) -> bool:
"""Return whether implementation is Polars.
Expand Down
Loading
Loading