Skip to content

Commit

Permalink
feat: dask expr name namespace (#735)
Browse files Browse the repository at this point in the history
* feat: dask expr name namespace

* returns_scalar from upstream
  • Loading branch information
FBruzzesi authored Aug 7, 2024
1 parent b36176c commit 935e9f4
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 54 deletions.
158 changes: 158 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,10 @@ def str(self: Self) -> DaskExprStringNamespace:
def dt(self: Self) -> DaskExprDateTimeNamespace:
return DaskExprDateTimeNamespace(self)

@property
def name(self: Self) -> DaskExprNameNamespace:
return DaskExprNameNamespace(self)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down Expand Up @@ -591,3 +595,157 @@ def ordinal_day(self) -> DaskExpr:
"ordinal_day",
returns_scalar=False,
)


class DaskExprNameNamespace:
def __init__(self: Self, expr: DaskExpr) -> None:
self._expr = expr

def keep(self: Self) -> DaskExpr:
root_names = self._expr._root_names

if root_names is None:
msg = (
"Anonymous expressions are not supported in `.name.keep`.\n"
"Instead of `nw.all()`, try using a named expression, such as "
"`nw.col('a', 'b')`\n"
)
raise ValueError(msg)

return self._expr.__class__(
lambda df: [
series.rename(name)
for series, name in zip(self._expr._call(df), root_names)
],
depth=self._expr._depth,
function_name=self._expr._function_name,
root_names=root_names,
output_names=root_names,
returns_scalar=self._expr._returns_scalar,
backend_version=self._expr._backend_version,
)

def map(self: Self, function: Callable[[str], str]) -> DaskExpr:
root_names = self._expr._root_names

if root_names is None:
msg = (
"Anonymous expressions are not supported in `.name.map`.\n"
"Instead of `nw.all()`, try using a named expression, such as "
"`nw.col('a', 'b')`\n"
)
raise ValueError(msg)

output_names = [function(str(name)) for name in root_names]

return self._expr.__class__(
lambda df: [
series.rename(name)
for series, name in zip(self._expr._call(df), output_names)
],
depth=self._expr._depth,
function_name=self._expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._expr._returns_scalar,
backend_version=self._expr._backend_version,
)

def prefix(self: Self, prefix: str) -> DaskExpr:
root_names = self._expr._root_names
if root_names is None:
msg = (
"Anonymous expressions are not supported in `.name.prefix`.\n"
"Instead of `nw.all()`, try using a named expression, such as "
"`nw.col('a', 'b')`\n"
)
raise ValueError(msg)

output_names = [prefix + str(name) for name in root_names]
return self._expr.__class__(
lambda df: [
series.rename(name)
for series, name in zip(self._expr._call(df), output_names)
],
depth=self._expr._depth,
function_name=self._expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._expr._returns_scalar,
backend_version=self._expr._backend_version,
)

def suffix(self: Self, suffix: str) -> DaskExpr:
root_names = self._expr._root_names
if root_names is None:
msg = (
"Anonymous expressions are not supported in `.name.suffix`.\n"
"Instead of `nw.all()`, try using a named expression, such as "
"`nw.col('a', 'b')`\n"
)
raise ValueError(msg)

output_names = [str(name) + suffix for name in root_names]

return self._expr.__class__(
lambda df: [
series.rename(name)
for series, name in zip(self._expr._call(df), output_names)
],
depth=self._expr._depth,
function_name=self._expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._expr._returns_scalar,
backend_version=self._expr._backend_version,
)

def to_lowercase(self: Self) -> DaskExpr:
root_names = self._expr._root_names

if root_names is None:
msg = (
"Anonymous expressions are not supported in `.name.to_lowercase`.\n"
"Instead of `nw.all()`, try using a named expression, such as "
"`nw.col('a', 'b')`\n"
)
raise ValueError(msg)
output_names = [str(name).lower() for name in root_names]

return self._expr.__class__(
lambda df: [
series.rename(name)
for series, name in zip(self._expr._call(df), output_names)
],
depth=self._expr._depth,
function_name=self._expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._expr._returns_scalar,
backend_version=self._expr._backend_version,
)

def to_uppercase(self: Self) -> DaskExpr:
root_names = self._expr._root_names

if root_names is None:
msg = (
"Anonymous expressions are not supported in `.name.to_uppercase`.\n"
"Instead of `nw.all()`, try using a named expression, such as "
"`nw.col('a', 'b')`\n"
)
raise ValueError(msg)
output_names = [str(name).upper() for name in root_names]

return self._expr.__class__(
lambda df: [
series.rename(name)
for series, name in zip(self._expr._call(df), output_names)
],
depth=self._expr._depth,
function_name=self._expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._expr._returns_scalar,
backend_version=self._expr._backend_version,
)
12 changes: 3 additions & 9 deletions tests/expr_and_series/name/keep_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,21 @@
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}


def test_keep(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_keep(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo", "BAR") * 2).name.keep())
expected = {k: [e * 2 for e in v] for k, v in data.items()}
compare_dicts(result, expected)


def test_keep_after_alias(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_keep_after_alias(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo")).alias("alias_for_foo").name.keep())
expected = {"foo": data["foo"]}
compare_dicts(result, expected)


def test_keep_raise_anonymous(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_keep_raise_anonymous(constructor: Any) -> None:
df_raw = constructor(data)
df = nw.from_native(df_raw)

Expand Down
12 changes: 3 additions & 9 deletions tests/expr_and_series/name/map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,21 @@ def map_func(s: str | None) -> str:
return str(s)[::-1].lower()


def test_map(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_map(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo", "BAR") * 2).name.map(function=map_func))
expected = {map_func(k): [e * 2 for e in v] for k, v in data.items()}
compare_dicts(result, expected)


def test_map_after_alias(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_map_after_alias(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo")).alias("alias_for_foo").name.map(function=map_func))
expected = {map_func("foo"): data["foo"]}
compare_dicts(result, expected)


def test_map_raise_anonymous(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_map_raise_anonymous(constructor: Any) -> None:
df_raw = constructor(data)
df = nw.from_native(df_raw)

Expand Down
12 changes: 3 additions & 9 deletions tests/expr_and_series/name/prefix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,21 @@
prefix = "with_prefix_"


def test_prefix(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_prefix(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo", "BAR") * 2).name.prefix(prefix))
expected = {prefix + str(k): [e * 2 for e in v] for k, v in data.items()}
compare_dicts(result, expected)


def test_suffix_after_alias(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_suffix_after_alias(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo")).alias("alias_for_foo").name.prefix(prefix))
expected = {prefix + "foo": data["foo"]}
compare_dicts(result, expected)


def test_prefix_raise_anonymous(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_prefix_raise_anonymous(constructor: Any) -> None:
df_raw = constructor(data)
df = nw.from_native(df_raw)

Expand Down
12 changes: 3 additions & 9 deletions tests/expr_and_series/name/suffix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,21 @@
suffix = "_with_suffix"


def test_suffix(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_suffix(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo", "BAR") * 2).name.suffix(suffix))
expected = {str(k) + suffix: [e * 2 for e in v] for k, v in data.items()}
compare_dicts(result, expected)


def test_suffix_after_alias(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_suffix_after_alias(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo")).alias("alias_for_foo").name.suffix(suffix))
expected = {"foo" + suffix: data["foo"]}
compare_dicts(result, expected)


def test_suffix_raise_anonymous(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_suffix_raise_anonymous(constructor: Any) -> None:
df_raw = constructor(data)
df = nw.from_native(df_raw)

Expand Down
12 changes: 3 additions & 9 deletions tests/expr_and_series/name/to_lowercase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,21 @@
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}


def test_to_lowercase(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_to_lowercase(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo", "BAR") * 2).name.to_lowercase())
expected = {k.lower(): [e * 2 for e in v] for k, v in data.items()}
compare_dicts(result, expected)


def test_to_lowercase_after_alias(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_to_lowercase_after_alias(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("BAR")).alias("ALIAS_FOR_BAR").name.to_lowercase())
expected = {"bar": data["BAR"]}
compare_dicts(result, expected)


def test_to_lowercase_raise_anonymous(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_to_lowercase_raise_anonymous(constructor: Any) -> None:
df_raw = constructor(data)
df = nw.from_native(df_raw)

Expand Down
12 changes: 3 additions & 9 deletions tests/expr_and_series/name/to_uppercase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,21 @@
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}


def test_to_uppercase(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_to_uppercase(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo", "BAR") * 2).name.to_uppercase())
expected = {k.upper(): [e * 2 for e in v] for k, v in data.items()}
compare_dicts(result, expected)


def test_to_uppercase_after_alias(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_to_uppercase_after_alias(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo")).alias("alias_for_foo").name.to_uppercase())
expected = {"FOO": data["foo"]}
compare_dicts(result, expected)


def test_to_uppercase_raise_anonymous(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_to_uppercase_raise_anonymous(constructor: Any) -> None:
df_raw = constructor(data)
df = nw.from_native(df_raw)

Expand Down

0 comments on commit 935e9f4

Please sign in to comment.