diff --git a/src/nested_dask/core.py b/src/nested_dask/core.py index c1659b4..bce5bc1 100644 --- a/src/nested_dask/core.py +++ b/src/nested_dask/core.py @@ -1,17 +1,20 @@ from __future__ import annotations import os +from collections.abc import Callable, Mapping +from typing import Any, Literal import dask.dataframe as dd import dask.dataframe.dask_expr as dx import nested_pandas as npd +import numpy as np import pandas as pd import pyarrow as pa from dask.dataframe.dask_expr._collection import new_collection from nested_pandas.series.dtype import NestedDtype from nested_pandas.series.packer import pack, pack_flat, pack_lists from pandas._libs import lib -from pandas._typing import AnyAll, Axis, IndexLabel +from pandas._typing import Axis, IndexLabel from pandas.api.extensions import no_default # need this for the base _Frame class @@ -540,7 +543,7 @@ def dropna( self, *, axis: Axis = 0, - how: AnyAll | lib.NoDefault = no_default, + how: str | lib.NoDefault = no_default, thresh: int | lib.NoDefault = no_default, on_nested: bool = False, subset: IndexLabel | None = None, @@ -616,6 +619,118 @@ def dropna( meta=self._meta, ) + def sort_values( + self, + by: str | list[str], + npartitions: int | None = None, + ascending: bool | list[bool] = True, + na_position: Literal["first"] | Literal["last"] = "last", + partition_size: float = 128e6, + sort_function: Callable[[pd.DataFrame], pd.DataFrame] | None = None, + sort_function_kwargs: Mapping[str, Any] | None = None, + upsample: float = 1.0, + ignore_index: bool | None = False, + shuffle_method: str | None = None, + **options, + ) -> Self: # type: ignore[name-defined] # noqa: F821: + """ + Sort the dataset by a single column. + + Sorting a parallel dataset requires expensive shuffles and is generally + not recommended. See ‘set_index‘ for implementation details. + + Parameters: + ----------- + by: str or list[str] + Column(s) to sort by. + npartitions: int, None, or ‘auto’ + The ideal number of output partitions. If None, use the same as the + input. If ‘auto’ then decide by memory use. Not used when sorting + nested layers. + ascending: bool or list[bool], optional + Sort ascending vs. descending. Defaults to True. Specify list for + multiple sort orders. If this is a list of bools, must match the + length of the by. + na_position: {‘last’, ‘first’}, optional + Puts NaNs at the beginning if ‘first’, puts NaN at the end if + ‘last’. Defaults to ‘last’. + partition_size: float, optional + The desired size of each partition in bytes. Defaults to 128e6 + (128 MB). Not used in nested sorting. + sort_function: function, optional + Sorting function to use when sorting underlying partitions. If + None, defaults to M.sort_values (the partition library’s + implementation of sort_values). Not used when sorting nested + layers. + sort_function_kwargs: dict, optional + Additional keyword arguments to pass to the partition sorting + function. By default, by, ascending, and na_position are provided. + upsample: float, optional + Used to increase the number of samples for quantiles. Not used + in nested sorting + ignore_index: bool, optional + If True, the resulting axis will be labeled 0, 1, …, n - 1. + Defaults to False. + shuffle_method: str, optional + The method to use for shuffling data. Defaults to None. Not used + in nested sorting + **options: keyword arguments, optional + Additional options to pass to the sorting function. + Returns: + -------- + DataFrame + DataFrame with sorted values. + + """ + + # Resolve target layer + targets = [] + if isinstance(by, str): + by = [by] + # Check "by" columns for hierarchical references + for col in by: + if self._is_known_hierarchical_column(col): + targets.append(col.split(".")[0]) + else: + targets.append("base") + + # Ensure one target layer, preventing multi-layer operations + unq_targets = np.unique(targets).tolist() + if len(unq_targets) > 1: + raise ValueError("Queries cannot target multiple structs/layers, write a separate query for each") + target_layer = unq_targets[0] + + # Just use dask's sort_values if the target is the base layer + # Drops divisions, but this is expected behavior of a sorting operation + if target_layer == "base": + return super().sort_values( + by=by, + npartitions=npartitions, + ascending=ascending, + na_position=na_position, + partition_size=partition_size, + sort_function=sort_function, + sort_function_kwargs=sort_function_kwargs, + upsample=upsample, + ignore_index=ignore_index, + shuffle_method=shuffle_method, + **options, + ) + + # If nested target layer, go through nested-pandas API + # apply via map_partitions, meta is propagated + # does preserve divisions + return self.map_partitions( + lambda x: npd.NestedFrame(x).sort_values( + by=by, + ascending=ascending, + na_position=na_position, + ignore_index=ignore_index, + **options, + ), + meta=self._meta, + ) + def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame: """ Takes a function and applies it to each top-level row of the NestedFrame. diff --git a/tests/nested_dask/test_nestedframe.py b/tests/nested_dask/test_nestedframe.py index 5698b89..5c25284 100644 --- a/tests/nested_dask/test_nestedframe.py +++ b/tests/nested_dask/test_nestedframe.py @@ -299,6 +299,26 @@ def test_dropna(test_dataset_with_nans): assert len(flat_nested_nan_free) == len(flat_nested) - 1 +def test_sort_values(test_dataset): + """test the sort_values function""" + + # test sorting on base columns + sorted_base = test_dataset.sort_values(by="a") + assert sorted_base["a"].values.compute().tolist() == sorted(test_dataset["a"].values.compute().tolist()) + + # test sorting on nested columns + sorted_nested = test_dataset.sort_values(by="nested.flux", ascending=False) + assert sorted_nested.compute().iloc[0]["nested"]["flux"].values.tolist() == sorted( + test_dataset.compute().iloc[0]["nested"]["flux"].values.tolist(), + reverse=True, + ) + assert sorted_nested.known_divisions # Divisions should be known + + # Make sure we trigger multi-target exception + with pytest.raises(ValueError): + test_dataset.sort_values(by=["a", "nested.flux"]) + + def test_reduce(test_dataset): """test the reduce function"""