diff --git a/.gitignore b/.gitignore index d28e4e3ea8..fd5b1bf8cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ studies/**/sample-* +studies/named_axis.* docs/demos/countries.geojson docs/demos/test-program docs/demos/test-program.cpp diff --git a/docs/_toc.yml b/docs/_toc.yml index 2f4ff6663f..12c406ab15 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -4,12 +4,11 @@ title: "Awkward Array" defaults: titlesonly: True - subtrees: - entries: - file: getting-started/index subtrees: - - entries: + - entries: - file: getting-started/what-is-an-awkward-array - file: getting-started/10-minutes-to-awkward-array - file: getting-started/uproot-awkward-columnar-hats @@ -18,7 +17,7 @@ subtrees: - file: getting-started/papers-and-talks - file: user-guide/index subtrees: - - entries: + - entries: - file: user-guide/how-to-convert title: "Converting arrays" subtrees: @@ -74,6 +73,13 @@ subtrees: - file: user-guide/how-to-examine-checking-validity title: "Checking validity" + - file: user-guide/how-to-array-properties + title: "Array properties" + subtrees: + - entries: + - file: user-guide/how-to-array-properties-named-axis + title: "Named axes" + - file: user-guide/how-to-math title: "Numerical math" subtrees: diff --git a/docs/conf.py b/docs/conf.py index 37faac9e39..f6ea6f5e64 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -140,7 +140,7 @@ html_js_files = ["js/awkward.js"] # MyST settings -myst_enable_extensions = ["colon_fence"] +myst_enable_extensions = ["colon_fence", "deflist"] nb_execution_mode = "cache" nb_execution_raise_on_error = True diff --git a/docs/user-guide/how-to-array-properties-named-axis.md b/docs/user-guide/how-to-array-properties-named-axis.md new file mode 100644 index 0000000000..9d5321b67f --- /dev/null +++ b/docs/user-guide/how-to-array-properties-named-axis.md @@ -0,0 +1,304 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.14.1 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +Named axes +========== + +Named axes are a feature in Awkward Array that allows you to give names to the axes of an array. +This can be useful for documentation, debugging, and for writing code that is more robust to changes in the structure of the data. +As argumented at [PyHEP.dev 2023](https://indico.cern.ch/event/1234156/) and by the Harvard NLP group in their ["Tensor Considered Harmful"](https://nlp.seas.harvard.edu/NamedTensor.html) write-up, named axes can be a powerful tool to make code more readable and less error-prone. + +Awkward array ensures that named axes are properly propagated to the result. +All highlevel, indexing, and broadcasting operations in awkward array support named axes. + +Other libraries that support named axes include: +- [hist](https://hist.readthedocs.io/en/latest/) +- [haliax](https://github.com/stanford-crfm/haliax) +- [Tensor Considered Harmful](https://nlp.seas.harvard.edu/NamedTensor.html) +- [PyTorch Named Tensors](https://pytorch.org/docs/stable/name_inference.html#name-inference-reference-doc) +- [Penzai Named Axis](https://penzai.readthedocs.io/en/stable/notebooks/named_axes.html) +- [xarray Named Axis](https://docs.xarray.dev/en/stable/user-guide/indexing.html#) + +Named axes in Awkward Array are inspired primarily by `hist` and `PyTorch Named Tensors`. + ++++ + +How to (de-)attach named axes? +------------------------- + +Named axes can be attached to an array using the high-level {func}`ak.with_named_axis` function. +Awkward Array allows strings as named axes and integers as positional axes. + +The `named_axis` argument of {func}`ak.with_named_axis` accepts either a `tuple` or `dict`: +- `tuple`: + - `named axis`: item + - `positional axis`: index of the item + - _additional_: `None` represents a wildcard for not specifying a name, e.g.: `("x", None)` means that the first axis is named "x" and the second is not named. +- `dict`: + - `named axis`: key + - `positional axis`: value + - _additional_: not specifying a name is not allowed, e.g.: `{"x": 0}` means that the first axis is named "x", all other existing dimensions are unnamed. The `dict` option also allows for renaming negative axes, e.g.: `{"x": -1}` means that the last axis is named "x". + + +```{code-cell} +import awkward as ak +import numpy as np +``` + +The axis names of an array can be attached through the constructor: +```{code-cell} +named_array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")) +# or +named_array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis={"x": 0, "y": 1}) +``` + +... or through `ak.with_named_axis`: +```{code-cell} +array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y")) +# or +named_array = ak.with_named_axis(array, named_axis={"x": 0, "y": 1}) +``` + +After attaching named axes, you can see the named axes comma-separated in the arrays representation and in `.show(named_axis=True)`: + +```{code-cell} +ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")) +``` + +```{code-cell} +ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")).show(named_axis=True) +``` + +Accessing the named axis mapping to positional axis can be done using the `named_axis` and `positional_axis` properties: + +```{code-cell} +named_array.named_axis +``` + +```{code-cell} +named_array.positional_axis +``` + +If you want to remove the named axes from an array, you can use the {func}`ak.without_named_axis` function: + +```{code-cell} +array = ak.without_named_axis(named_array) +array.named_axis +``` + + +Indexing with Named Axes +------------------------ + +Named axes can be used for indexing operations. +This is enabled throuhg a special syntax that allows you to index with a dictionary where keys refer to named (or positional) axes and the values to the slice or index. + +Simple examples: + +```{code-cell} +array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + +# named axes +named_array[{"x": 0}] # array[0, :, :] +named_array[{"z": 0}] # array[:, :, 0] + +named_array[{"x": 0, "y": 0}] # array[0, 0, :] +named_array[{"x": slice(0, 1), "y": 0}] # array[0:1, 0, :] + +named_array[named_array > 3] # array[array > 3] + + +# positional axes +named_array[{0: 0}] # array[0, :, :] +named_array[{2: 0}] # array[:, :, 0] + +named_array[{-3: 0}] # array[0, :, :] +named_array[{-1: 0}] # array[:, :, 0] +None +``` + +If multiple keys that point to the same positional axis are used, the last key will be used and all others will be ignored: + +```{code-cell} +array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + +assert ak.all(named_array[{0: 0, "x": slice(0, 2)}] == named_array[0:2]) +assert ak.all(named_array[{"x": slice(0, 2), 0: 0}] == named_array[0]) +``` + + +More detailed example: + +```{code-cell} +# create a Record Array that represents four events with a variable number of jets +events = ak.zip({ + "event_no": np.arange(4), + "jetpt": ak.Array([[50, 60], [45], [], [80, 30, 50]]), +}) +named_events = ak.with_named_axis(events, ("events", "jets")) + +print("classic indexing:", named_events[0, 0:1]) +print("named indexing :", named_events[{"events": 0, "jets": slice(0, 1)}]) +``` + +For syntatic suger, use `np.s_` to define slices more easily: + +```{code-cell} +array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + +assert ak.all(named_array[{"x": np.s_[0:2]}] == named_array[{"x": slice(0, 2)}]) +``` + +Highlevel Operations with Named Axes +------------------------------------ + +Named axes can be used for specifying the axis of a highlevel operation given that the operation is performed on an array that supports this named axis. + +For example, the `ak.sum` operation can be performed on an array with named axes: + +```{code-cell} +array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + +print("Sum over axis 'x':", ak.sum(named_array, axis="x")) # ak.sum(array, axis=0) +print("Sum over axis 'y':", ak.sum(named_array, axis="y")) # ak.sum(array, axis=1) +print("Sum over axis 'z':", ak.sum(named_array, axis="z")) # ak.sum(array, axis=2) +``` + + +Named Axes Propagation Strategies +--------------------------------- + + +Named axes are propagated through all operations in Awkward Array. +For this, specific strategies are defined for each operation to ensure that the named axes are properly propagated to the result. + +The possible strategies are: +- `keep all`: keep all named axes +- `keep one`: keep one named axis +- `keep up to`: keep all named axes up to a certain positional axis +- `remove all`: remove all named axis +- `remove one`: remove one named axis +- `add one`: add a new axis +- `unify`: unify named axes of two arrays. The named axes are unifiable if the have the same name (or `None`) and point to the same positional axis. + +Indexing operations +: The following table shows the strategy for indexing operations: + +| Operation | Strategy | +|----------------------|--------------| +| `array[:]` | `keep all` | +| `array[...]` | `keep all` | +| `array[()]` | `keep all` | +| `array[0:1]` | `keep all` | +| `array[[0, 1]]` | `keep all` | +| `array[array % 2]` | `keep all` | +| `array[0]` | `remove one` | +| `array[np.array(0)]` | `remove one` | +| `array[None]` | `add one` | +| `array[np.newaxis]` | `add one` | + +Universal functions (`ufuncs`) +: `ufuncs` with single argument signatures (i.e. unary operations, such as `__abs__`, `__neg__`, `__invert__`, ...) do not modify named axes (strategy: `keep all`). +: `ufuncs` with two argument signatures (i.e. binary operations, such as `__add__`, `__sub__`, `__mul__`, ...) try to merge named axis of the given arrays (strategy: `unify`). + This means that the named axes of the two arrays are merged if they have the same name (or either is `None`) and point to the same positional axis. + If there's a mismatch of named axes, e.g., the same named axis has different names or point to different positional axes, an exception is raised. + +```{code-cell} +array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) +named_array = ak.with_named_axis(array, named_axis=("x", "y")) + +# unary operations with named axes +assert (-named_array).named_axis == {"x": 0, "y": 1} +assert (+named_array).named_axis == {"x": 0, "y": 1} +assert (~named_array).named_axis == {"x": 0, "y": 1} +assert abs(named_array).named_axis == {"x": 0, "y": 1} + +# binary operations with named axes +named_array1 = ak.with_named_axis(array, named_axis=(None, "y")) +named_array2 = ak.with_named_axis(array, named_axis=("x", None)) +named_array3 = ak.with_named_axis(array, named_axis=("x", "y")) + +assert (array + array).named_axis == {} +assert (named_array1 + array).named_axis == {"y": 1} +assert (named_array2 + array).named_axis == {"x": 0} +assert (named_array3 + array).named_axis == {"x": 0, "y": 1} + +assert (named_array1 + named_array2).named_axis == {"x": 0, "y": 1} +assert (named_array3 + named_array3).named_axis == {"x": 0, "y": 1} +``` + +Reducers (`ak.sum`, `ak.any`, ...) +: If `axis=int` and `keepdims=False` (typical use-case) removes the named axis that is reduced (strategy: `remove one`). +: If `keepdims=True` is set, the named axis is kept (strategy: `keep all`). +: If `axis=None` is set, all named axes are removed (strategy: `remove all`). + +```{code-cell} +array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) +named_array = ak.with_named_axis(array, ("x", "y")) + +assert ak.sum(named_array, axis="x", keepdims=False).named_axis == {"y": 0} +assert ak.sum(named_array, axis="x", keepdims=True).named_axis == {"x": 0, "y": 1} +``` + +--- +A full list of operations and their strategies can be found in the following table. +If an operation is not listed, the strategy is either `keep all` or automatically inferred from the below listed operations. + + +| Operation | Strategy | +|-----------------------------------------------------|--------------------| +| `ak.all(..., axis=None)` | `remove all` | +| `ak.all(..., axis=int, keepdims=False)` | `remove one` | +| `ak.all(..., axis=int, keepdims=True)` | `keep all` | +| `ak.any(..., axis=None)` | `remove all` | +| `ak.any(..., axis=int, keepdims=False)` | `remove one` | +| `ak.any(..., axis=int, keepdims=True)` | `keep all` | +| `ak.[arg]cartesian` | `unify` | +| `ak.[arg]combinations` | `keep all` | +| `ak.[arg]max(..., axis=None)` | `remove all` | +| `ak.[arg]max(..., axis=int, keepdims=False)` | `remove one` | +| `ak.[arg]max(..., axis=int, keepdims=True)` | `keep all` | +| `ak.[arg]min(..., axis=None)` | `remove all` | +| `ak.[arg]min(..., axis=int, keepdims=False)` | `remove one` | +| `ak.[arg]min(..., axis=int, keepdims=True)` | `keep all` | +| `ak.[arg]sort` | `keep all` | +| `ak.broadcast_arrays` | `unify`, `add one` | +| `ak.broadcast_fields` | `unify`, `add one` | +| `ak.categories` | `remove all` | +| `ak.concatenate` | `unify` | +| `ak.count[_nonzero](..., axis=None)` | `remove all` | +| `ak.count[_nonzero](..., axis=int, keepdims=False)` | `remove one` | +| `ak.count[_nonzero](..., axis=int, keepdims=True)` | `keep all` | +| `ak.firsts` | `remove one` | +| `ak.flatten(..., axis=None)` | `remove all` | +| `ak.flatten(..., axis=0)` | `keep all` | +| `ak.flatten(..., axis=(!=0), keepdims=True)` | `remove one` | +| `ak.local_index` | `keep up to` | +| `ak.num` | `keep one` | +| `ak.prod(..., axis=None)` | `remove all` | +| `ak.prod(..., axis=int, keepdims=False)` | `remove one` | +| `ak.prod(..., axis=int, keepdims=True)` | `keep all` | +| `ak.ravel` | `remove all` | +| `ak.singletons` | `add one` | +| `ak.sum(..., axis=None)` | `remove all` | +| `ak.sum(..., axis=int, keepdims=False)` | `remove one` | +| `ak.sum(..., axis=int, keepdims=True)` | `keep all` | +| `ak.unflatten` | `remove all` | +| `ak.where` | `unify`, `add one` | +| `ak.with_field` | `unify`, `add one` | +| `ak.zip` | `unify`, `add one` | diff --git a/docs/user-guide/how-to-array-properties.md b/docs/user-guide/how-to-array-properties.md new file mode 100644 index 0000000000..be811e888e --- /dev/null +++ b/docs/user-guide/how-to-array-properties.md @@ -0,0 +1,23 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.10.3 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +Array properties +================ + +The user guide is a collection of "how to..." guides for common tasks. See the left side-bar (or bring it into view by clicking on the upper-left `≡`) to access the guides, grouped by topic. + +If you're looking for documentation on a specific function, see the API reference instead. + +You can test any examples in a new window/tab by clicking on [![Try It! ⭷](https://img.shields.io/badge/-Try%20It%21%20%E2%86%97-orange?style=for-the-badge)](https://awkward-array.org/doc/main/_static/try-it.html). + +




diff --git a/pyproject.toml b/pyproject.toml index f5d1424359..f10e745671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -232,6 +232,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ 'awkward._nplikes.*', + 'awkward._namedaxis', 'awkward._behavior.*', 'awkward._backends.*', 'awkward._meta.*', diff --git a/src/awkward/__init__.py b/src/awkward/__init__.py index c82e83777f..c84b655ccb 100644 --- a/src/awkward/__init__.py +++ b/src/awkward/__init__.py @@ -24,6 +24,7 @@ import awkward._errors import awkward._lookup import awkward._ext # strictly for unpickling from Awkward 1 +import awkward._namedaxis # third-party connectors import awkward._connect.numpy diff --git a/src/awkward/_broadcasting.py b/src/awkward/_broadcasting.py index 7eb2300372..8dab0af30e 100644 --- a/src/awkward/_broadcasting.py +++ b/src/awkward/_broadcasting.py @@ -11,6 +11,11 @@ import awkward as ak from awkward._backends.backend import Backend from awkward._backends.dispatch import backend_of +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + _add_named_axis, + _unify_named_axis, +) from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import ShapeItem, unknown_length @@ -319,10 +324,18 @@ def is_string_like(obj) -> bool: } -def left_broadcast_to(content: Content, depth: int) -> Content: - for _ in range(content.purelist_depth, depth): - content = RegularArray(content, 1, content.length) - return content +def _export_named_axis_from_depth_to_lateral( + idx: int, + depth_context: dict[str, Any], + lateral_context: dict[str, Any], +) -> None: + # set adjusted named axes to lateral (inplace) + named_axis, ndim = depth_context[NAMED_AXIS_KEY][idx] + seen_named_axis, _ = lateral_context[NAMED_AXIS_KEY][idx] + lateral_context[NAMED_AXIS_KEY][idx] = ( + _unify_named_axis(named_axis, seen_named_axis), + ndim, + ) def broadcast_regular_dim_size(contents: Sequence[ak.contents.Content]) -> ShapeItem: @@ -433,10 +446,32 @@ def apply_step( max_depth = max(x.purelist_depth for x in contents) if max_depth > 0 and all(x.purelist_isregular for x in contents): - nextinputs = [ - left_broadcast_to(o, max_depth) if isinstance(o, Content) else o - for o in inputs - ] + nextinputs = [] + + named_axes_with_ndims = depth_context[NAMED_AXIS_KEY] + seen_named_axes = lateral_context[NAMED_AXIS_KEY] + for i, ((named_axis, ndim), o) in enumerate( + zip(named_axes_with_ndims, inputs) + ): + if isinstance(o, Content): + # rightbroadcast + for _ in range(o.purelist_depth, max_depth): + o = RegularArray(o, 1, o.length) + # track new dimensions for named axis + # rightbroadcasting adds a new first(!) dimension at depth + seen_named_axis, seen_ndim = seen_named_axes[i] + named_axis = _add_named_axis(named_axis, depth, ndim) + depth_context[NAMED_AXIS_KEY][i] = ( + _unify_named_axis(named_axis, seen_named_axis), + ndim + 1, + ) + if o.is_leaf: + _export_named_axis_from_depth_to_lateral( + i, depth_context, lateral_context + ) + nextinputs.append(o) + else: + nextinputs.append(o) # Did a broadcast take place? if any(x is not y for x, y in zip(inputs, nextinputs)): return apply_step( @@ -538,6 +573,7 @@ def broadcast_any_list(): # Under the category of "is_list", we have both strings and non-strings # The strings should behave like non-lists within these routines. + named_axes_with_ndims = depth_context[NAMED_AXIS_KEY] # Are the non-string list types exclusively regular? if all(x.is_regular or (is_string_like(x) or not x.is_list) for x in contents): # Compute the expected dim size @@ -586,7 +622,9 @@ def broadcast_any_list(): # we don't left-broadcast nextinputs = [] nextparameters = [] - for x, x_is_string in zip(inputs, inputs_are_strings): + for i, ((named_axis, ndim), x, x_is_string) in enumerate( + zip(named_axes_with_ndims, inputs, inputs_are_strings) + ): if isinstance(x, RegularArray) and not x_is_string: content_size_maybe_one = ( x.size is not unknown_length and x.size == 1 @@ -603,6 +641,16 @@ def broadcast_any_list(): ) ) nextparameters.append(x._parameters) + # track new dimensions for named axis + # rightbroadcasting adds a new first(!) dimension as depth + depth_context[NAMED_AXIS_KEY][i] = ( + _add_named_axis(named_axis, depth, ndim), + ndim + 1, + ) + if x.is_leaf: + _export_named_axis_from_depth_to_lateral( + i, depth_context, lateral_context + ) # Any unknown values or sizes are assumed to be correct as-is elif ( dim_size is unknown_length @@ -667,7 +715,9 @@ def broadcast_any_list(): nextinputs = [] nextparameters = [] - for x, x_is_string in zip(inputs, input_is_string): + for i, ((named_axis, ndim), x, x_is_string) in enumerate( + zip(named_axes_with_ndims, inputs, input_is_string) + ): if isinstance(x, listtypes) and not x_is_string: next_content = broadcast_to_offsets_avoiding_carry(x, offsets) nextinputs.append(next_content) @@ -680,6 +730,16 @@ def broadcast_any_list(): .content ) nextparameters.append(NO_PARAMETERS) + # track new dimensions for named axis + # leftbroadcasting adds a new last dimension at depth + 1 + depth_context[NAMED_AXIS_KEY][i] = ( + _add_named_axis(named_axis, depth + 1, ndim), + ndim + 1, + ) + if x.is_leaf: + _export_named_axis_from_depth_to_lateral( + i, depth_context, lateral_context + ) else: nextinputs.append(x) nextparameters.append(NO_PARAMETERS) @@ -889,7 +949,7 @@ def action_logical_or(inputs, backend, **kwargs): (xy_mask, cond_mask), action_logical_or, 0, - None, + depth_context, lateral_context, simple_options, )[0] @@ -917,7 +977,7 @@ def apply_mask_action(inputs, backend, **kwargs): (xy_unmasked, mask), apply_mask_action, 0, - None, + depth_context, lateral_context, simple_options, ) diff --git a/src/awkward/_connect/numexpr.py b/src/awkward/_connect/numexpr.py index 85ab566c8c..50c3e0485f 100644 --- a/src/awkward/_connect/numexpr.py +++ b/src/awkward/_connect/numexpr.py @@ -4,12 +4,15 @@ import sys import warnings +from functools import reduce from packaging.version import parse as parse_version import awkward as ak -from awkward._behavior import behavior_of +from awkward._attrs import attrs_of_obj +from awkward._behavior import behavior_of, behavior_of_obj from awkward._layout import wrap_layout +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis _has_checked_version = False @@ -110,9 +113,26 @@ def action(inputs, **ignore): return None behavior = behavior_of(*arrays) - out = ak._broadcasting.broadcast_and_apply(arrays, action, allow_records=False) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arguments) + out = ak._broadcasting.broadcast_and_apply( + arrays, + action, + depth_context=depth_context, + lateral_context=lateral_context, + allow_records=False, + ) assert isinstance(out, tuple) and len(out) == 1 - return wrap_layout(out[0], behavior) + wrapped = wrap_layout(out[0], behavior) + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=True, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), + ) evaluate.evaluate = evaluate @@ -148,6 +168,24 @@ def action(inputs, **ignore): return None behavior = behavior_of(*arrays) - out = ak._broadcasting.broadcast_and_apply(arrays, action, allow_records=False) + + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arguments) + out = ak._broadcasting.broadcast_and_apply( + arrays, + action, + depth_context=depth_context, + lateral_context=lateral_context, + allow_records=False, + ) assert isinstance(out, tuple) and len(out) == 1 - return wrap_layout(out[0], behavior) + wrapped = wrap_layout(out[0], behavior) + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=True, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), + ) diff --git a/src/awkward/_connect/numpy.py b/src/awkward/_connect/numpy.py index f17ee98b36..7f5a7cdb08 100644 --- a/src/awkward/_connect/numpy.py +++ b/src/awkward/_connect/numpy.py @@ -22,6 +22,7 @@ ) from awkward._categorical import as_hashable from awkward._layout import wrap_layout +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes import to_nplike from awkward._parameters import parameters_intersect from awkward._regularize import is_non_string_like_iterable @@ -363,6 +364,8 @@ def array_ufunc(ufunc, method: str, inputs, kwargs: dict[str, Any]): attrs = attrs_of(*inputs) backend = backend_of(*inputs, coerce_to_common=True) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(inputs) + inputs = _array_ufunc_custom_cast(inputs, behavior, backend) def action(inputs, **ignore): @@ -464,13 +467,40 @@ def action(inputs, **ignore): return None out = ak._broadcasting.broadcast_and_apply( - inputs, action, allow_records=False, function_name=ufunc.__name__ + inputs, + action, + depth_context=depth_context, + lateral_context=lateral_context, + allow_records=False, + function_name=ufunc.__name__, ) + out_named_axis = functools.reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) if len(out) == 1: - return wrap_layout(out[0], behavior=behavior, attrs=attrs) + wrapped = wrap_layout(out[0], behavior=behavior, attrs=attrs) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=True, + behavior=None, + attrs=None, + ) else: - return tuple(wrap_layout(o, behavior=behavior, attrs=attrs) for o in out) + wrapped_out = [] + for o in out: + wrapped = wrap_layout(o, behavior=behavior, attrs=attrs) + wrapped_out.append( + ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=True, + behavior=None, + attrs=None, + ) + ) + return tuple(wrapped_out) def action_for_matmul(inputs): diff --git a/src/awkward/_layout.py b/src/awkward/_layout.py index 11cb4bcbe5..f642dd6ed4 100644 --- a/src/awkward/_layout.py +++ b/src/awkward/_layout.py @@ -56,9 +56,7 @@ def merge_mappings( class HighLevelContext: - def __init__( - self, behavior: Mapping | None = None, attrs: Mapping[str, Any] | None = None - ): + def __init__(self, behavior: Mapping | None = None, attrs: Mapping | None = None): self._behavior = behavior self._attrs = attrs self._is_finalized = False @@ -66,6 +64,22 @@ def __init__( self._attrs_from_objects = [] self._behavior_from_objects = [] + def with_attr(self, key, value) -> Self: + self._ensure_finalized() + return type(self)( + behavior=self.behavior, + attrs={**self.attrs, key: value}, + ).finalize() + + def without_attr(self, key) -> Self: + self._ensure_finalized() + attrs = dict(self.attrs) + attrs.pop(key, None) + return type(self)( + behavior=self.behavior, + attrs=attrs, + ).finalize() + def __enter__(self): return self @@ -81,8 +95,10 @@ def _ensure_not_finalized(self): raise RuntimeError("HighLevelContext has already been finalized") @property - def attrs(self) -> Mapping[str, Any] | None: + def attrs(self) -> Mapping: self._ensure_finalized() + if self._attrs is None: + self._attrs = {} return self._attrs @property @@ -154,7 +170,11 @@ def unwrap( ) def wrap( - self, obj: Any, *, highlevel: bool = True, allow_other: bool = False + self, + obj: Any, + *, + highlevel: bool = True, + allow_other: bool = False, ) -> Any: self._ensure_finalized() @@ -230,7 +250,7 @@ def maybe_highlevel_to_lowlevel(obj): Args: obj: an object - Calls #ak.to_layout and returns the result iff. the object is a high-level + Calls #ak.to_layout and returns the result if the object is a high-level Awkward object, otherwise the object is returned as-is. This function should be removed once scalars are properly handled by `to_layout`. @@ -372,6 +392,34 @@ def attach(x): return layout +def _neg2pos_axis( + axis: int, + total: int, +) -> int: + """ + Converts a negative axis index to a positive one. + + This function takes a negative axis index and the total number of axes and returns the corresponding positive axis index. + If the input axis index is already positive, it is returned as is. + + Args: + axis (int): The axis index to convert. Can be negative. + total (int): The total number of axes. + + Returns: + int: The positive axis index corresponding to the input axis index. + + Examples: + >>> _neg2pos_axis(-1, 3) + 2 + >>> _neg2pos_axis(1, 3) + 1 + """ + if axis < 0: + return total + axis + return axis + + def maybe_posaxis(layout: Content, axis: int, depth: int) -> int | None: from awkward.record import Record @@ -386,6 +434,6 @@ def maybe_posaxis(layout: Content, axis: int, depth: int) -> int | None: else: is_branching, additional_depth = layout.branch_depth if not is_branching: - return axis + depth + additional_depth - 1 + return _neg2pos_axis(axis, additional_depth) + depth - 1 else: return None diff --git a/src/awkward/_namedaxis.py b/src/awkward/_namedaxis.py new file mode 100644 index 0000000000..9f0b8d36cc --- /dev/null +++ b/src/awkward/_namedaxis.py @@ -0,0 +1,764 @@ +from __future__ import annotations + +import json +import re +from dataclasses import dataclass + +import awkward._typing as tp +from awkward._layout import _neg2pos_axis +from awkward._regularize import is_integer + +# axis names are hashables, mostly strings, +# except for integers, which are reserved for positional axis. +AxisName: tp.TypeAlias = tp.Hashable + +# e.g.: {"x": 0, "y": 1, "z": 2} +AxisMapping: tp.TypeAlias = tp.Mapping[AxisName, int] + +# e.g.: ("x", "y", None) where None is a wildcard +AxisTuple: tp.TypeAlias = tp.Tuple[AxisName, ...] + + +NAMED_AXIS_KEY: tp.Literal["__named_axis__"] = ( + "__named_axis__" # reserved for named axis +) + + +# just a class for inplace mutation +class NamedAxis: + mapping: AxisMapping + + +NamedAxis.mapping = {} + + +def _prettify_named_axes( + named_axis: AxisMapping, + delimiter: str = ", ", + maxlen: None | int = None, +) -> str: + """ + This function takes a named axis mapping and returns a string representation of the mapping. + The axis names are sorted in ascending order of their corresponding integer values. + If the axis name is a valid Python identifier, it is represented as is. + Otherwise, it is represented as a JSON string. + + Args: + named_axis (AxisMapping): The named axis mapping to prettify. + delimiter (str, optional): The delimiter to use between items in the output string. Defaults to ", ". + maxlen (None | int, optional): The maximum length of the output string. If the string exceeds this length, it is truncated and ends with "...". Defaults to None. + + Returns: + str: The prettified string representation of the named axis mapping. + + Examples: + >>> _prettify_named_axes({"x": 0, "y": 1, "z": 2}) + 'x:0, y:1, z:2' + >>> _prettify_named_axes({"x": 0, "y": 1, "$": 2}) + 'x:0, y:1, "$":2' + >>> _prettify_named_axes({"x": 0, "y": 1, "z": 2}, delimiter="; ") + 'x:0; y:1; z:2' + >>> _prettify_named_axes({"foo": 0, "bar": 1, "baz": 2}, maxlen=17) + 'foo:0, bar:1, ...' + """ + + def _prettify(ax: AxisName) -> str: + repr_ax = str(ax) + if re.match("[A-Za-z_][A-Za-z_0-9]*", repr_ax): + return repr_ax + return json.dumps(repr_ax) + + sorted_named_axis = sorted(named_axis.items(), key=lambda x: x[1]) + items = [ + f"{_prettify(named_ax)}:{pos_ax}" for named_ax, pos_ax in sorted_named_axis + ] + if maxlen is not None: + if len(delimiter.join(items)) > maxlen: + while ( + len(delimiter.join(items)) > maxlen - len(delimiter + "...") + ) and items: + items.pop(-1) + items.append("...") + return delimiter.join(items) + + +def _get_named_axis(ctx: tp.Any) -> AxisMapping: + """ + Retrieves the named axis from the provided context. + + Args: + ctx (Any): The context from which the named axis is to be retrieved. + + Returns: + AxisMapping: The named axis retrieved from the context. If the context does not include a named axis, + an empty dictionary is returned. + + Examples: + >>> _get_named_axis(ak.Array([1, 2, 3], named_axis={"x": 0})) + {"x": 0} + >>> _get_named_axis(np.array([1, 2, 3])) + {} + >>> _get_named_axis({NAMED_AXIS_KEY: {"x": 0, "y": 1, "z": 2}}) + {"x": 0, "y": 1, "z": 2} + >>> _get_named_axis({"other_key": "other_value"}) + {} + """ + if hasattr(ctx, "attrs"): + return _get_named_axis(ctx.attrs) + elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx: + return dict(ctx[NAMED_AXIS_KEY]) + else: + return {} + + +def _make_positional_axis_tuple(n: int) -> tuple[int, ...]: + """ + Generates a positional axis tuple of length n. + + Args: + n (int): The length of the positional axis tuple to generate. + + Returns: + tuple[int, ...]: The generated positional axis tuple. + + Examples: + >>> _make_positional_axis_tuple(3) + (0, 1, 2) + """ + return tuple(range(n)) + + +def _is_valid_named_axis(axis: AxisName) -> bool: + """ + Checks if the given axis is a valid named axis. A valid named axis is a hashable object that is not an integer or None. Currently it is restricted to strings. + + Args: + axis (AxisName): The axis to check. + + Returns: + bool: True if the axis is a valid named axis, False otherwise. + + Examples: + >>> _is_valid_named_axis("x") + True + >>> _is_valid_named_axis(1) + False + """ + return ( + # axis must be hashable + isinstance(axis, AxisName) + # ... but not an integer, otherwise we would confuse it with positional axis + and not is_integer(axis) + # we also prohibit None, which is reserved for wildcard + and axis is not None + # Let's only allow strings for now, in the future we can open up to more types + # by removing the isinstance(axis, str) check. + and isinstance(axis, str) + ) + + +def _check_valid_axis(axis: AxisName) -> AxisName: + """ + Checks if the given axis is a valid named axis. If not, raises a ValueError. + + Args: + axis (AxisName): The axis to check. + + Returns: + AxisName: The axis if it is a valid named axis. + + Raises: + ValueError: If the axis is not a valid named axis. + + Examples: + >>> _check_valid_axis("x") + "x" + >>> _check_valid_axis(1) + Traceback (most recent call last): + ... + ValueError: Axis names must be hashable and not int, got 1 [type(axis)=] + """ + if not _is_valid_named_axis(axis): + raise ValueError( + f"Axis names must be hashable and not int, got {axis!r} [{type(axis)=}]" + ) + return axis + + +def _check_valid_named_axis_mapping(named_axis: AxisMapping) -> AxisMapping: + """ + Checks if the given named axis mapping is valid. A valid named axis mapping is a dictionary where the keys are valid named axes + (hashable objects that are not integers) and the values are integers. + + Args: + named_axis (AxisMapping): The named axis mapping to check. + + Raises: + ValueError: If any of the keys in the named axis mapping is not a valid named axis or if any of the values is not an integer. + + Examples: + >>> _check_valid_named_axis_mapping({"x": 0, "y": 1, "z": 2}) # No exception is raised + >>> _check_valid_named_axis_mapping({"x": 0, "y": 1, "z": "2"}) + Traceback (most recent call last): + ... + ValueError: Named axis mapping values must be integers, got '2' [type(axis)=] + >>> _check_valid_named_axis_mapping({"x": 0, 1: 1, "z": 2}) + Traceback (most recent call last): + ... + ValueError: Axis names must be hashable and not int, got 1 [type(axis)=] + """ + for name, axis in named_axis.items(): + _check_valid_axis(name) + if not is_integer(axis): + raise ValueError( + f"Named axis mapping values must be integers, got {axis!r} [{type(axis)=}]" + ) + return named_axis + + +def _axis_tuple_to_mapping(axis_tuple: AxisTuple) -> AxisMapping: + """ + Converts a tuple of axis names to a dictionary mapping axis names to their positions. + + Args: + axis_tuple (AxisTuple): A tuple of axis names. Can include None as a wildcard. + + Returns: + AxisMapping: A dictionary mapping axis names to their positions. + + Examples: + >>> _axis_tuple_to_mapping(("x", None, "y")) + {"x": 0, "y": 2} + """ + return {axis: i for i, axis in enumerate(axis_tuple) if axis is not None} + + +def _prepare_named_axis_for_attrs( + named_axis: AxisMapping | AxisTuple, + ndim: int, +) -> AxisMapping: + """ + Prepares the named axis for attribute assignment. + + This function takes a named axis, which can either be a mapping or a tuple, and returns a dictionary mapping axis names to their positions. + The function checks if the named axis is valid and if the positional axes match the number of dimensions. If not, an error is raised. + + Args: + named_axis (AxisMapping | AxisTuple): The named axis to prepare. Can either be a mapping or a tuple. + ndim (int): The number of dimensions. + + Returns: + AxisMapping: The prepared named axis. + + Raises: + TypeError: If the named axis is not a mapping or a tuple. + ValueError: If the named axes do not point to positional axes matching the number of dimensions. + + Examples: + >>> _prepare_named_axis_for_attrs({"x": 0, "y": 1, "z": 2}, 3) + {"x": 0, "y": 1, "z": 2} + >>> _prepare_named_axis_for_attrs(("x", "y", "z"), 3) + {"x": 0, "y": 1, "z": 2} + >>> _prepare_named_axis_for_attrs({"x": 0, "y": 1, "z": 2}, 2) + Traceback (most recent call last): + ... + ValueError: Named axes must point to positional axes matching 2 dimensions, got named_axis={"x": 0, "y": 1, "z": 2}, ndim=2 + """ + if isinstance(named_axis, tuple): + _named_axis = _axis_tuple_to_mapping(named_axis) + elif isinstance(named_axis, dict): + _named_axis = named_axis + else: + raise TypeError( + f"named_axis must be a mapping or a tuple, got {named_axis=} [{type(named_axis)=}]" + ) + _check_valid_named_axis_mapping(_named_axis) + pos_axes = set(_named_axis.values()) + if max(pos_axes, default=0) >= ndim or min(pos_axes, default=0) < -ndim: + raise ValueError( + f"Named axes must point to positional axes matching {ndim} dimensions, got {named_axis=}, {ndim=}" + ) + return _named_axis + + +def _make_named_int_class(name: tp.Any) -> type[int]: + class NamedInt(int): + def __repr__(self): + value_repr = super().__repr__() + return f"{name!r} (named axis) -> {value_repr} (pos. axis)" + + __str__ = __repr__ + + return NamedInt + + +def _named_axis_to_positional_axis( + named_axis: AxisMapping, + axis: AxisName, +) -> int | None: + """ + Converts a single named axis to a positional axis. + + Args: + axis (AxisName): The named axis to convert. + named_axis (AxisMapping): The mapping from named axes to positional axes. + + Returns: + int | None: The positional axis corresponding to the given named axis. If the named axis is not found in the mapping, returns None. + + Raises: + ValueError: If the named axis is not found in the named axis mapping. + + Examples: + >>> _named_axis_to_positional_axis({"x": 0, "y": 1, "z": 2}, "x") + 0 + """ + if _is_valid_named_axis(axis): + if axis not in named_axis: + raise ValueError(f"{axis=} not found in {named_axis=} mapping.") + + # we wrap it to preserve the original name in its __repr__ and __str__ + # in order to properly display it in error messages. This is useful for cases + # where the positional axis is pointing to a non-existing axis. The error message + # will then show the original (named) axis together with the positional axis. + cls = _make_named_int_class(axis) + return cls(named_axis[axis]) + + if is_integer(axis): + # TODO: is_integer is an external helper function that doesn't specify types + return int(tp.cast(tp.Any, axis)) + elif axis is None: + return None + else: + raise ValueError(f"Invalid {axis=} [{type(axis)=}]") + + +# These are the strategies to handle named axis for the +# output array when performing operations along an axis. +# See studies/named_axis.md#named-axis-in-high-level-functions and +# https://pytorch.org/docs/stable/name_inference.html. +# +# The possible strategies are: +# - "keep all" (_keep_named_axis(..., None)): Keep all named axes in the output array, e.g.: `ak.drop_none` +# - "keep one" (_keep_named_axis(..., int)): Keep one named axes in the output array, e.g.: `ak.firsts` +# - "keep up to" (_keep_named_axis_up_to(..., int)): Keep all named axes up to a certain positional axis in the output array, e.g.: `ak.local_index` +# - "remove all" (_remove_all_named_axis): Removes all named axis, e.g.: `ak.categories` +# - "remove one" (_remove_named_axis): Remove the named axis from the output array, e.g.: `ak.sum` +# - "add one" (_add_named_axis): Add a new named axis to the output array, e.g.: `ak.concatenate` +# - "unify" (_unify_named_axis): Unify the named axis in the output array given two input arrays, e.g.: `ak.broadcast_arrays` + + +def _keep_named_axis( + named_axis: AxisMapping, + axis: int | None = None, +) -> AxisMapping: + """ + Determines the new named axis after keeping the specified axis. This function is useful when an operation + is applied that retains only one axis. + + Args: + named_axis (AxisMapping): The current named axis. + axis (int | None, optional): The index of the axis to keep. If None, all axes are kept. Default is None. + + Returns: + AxisMapping: The new named axis after keeping the specified axis. + + Examples: + >>> _keep_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"y": 0} + >>> _keep_named_axis({"x": 0, "y": 1, "z": 2}, None) + {"x": 0, "y": 1, "z": 2} + """ + if axis is None: + return named_axis + return {k: 0 for k, v in named_axis.items() if v == axis} + + +def _keep_named_axis_up_to( + named_axis: AxisMapping, + axis: int, + total: int, +) -> AxisMapping: + """ + Determines the new named axis after keeping all axes up to the specified axis. This function is useful when an operation + is applied that retains all axes up to a certain axis. + + Args: + named_axis (AxisMapping): The current named axis. + axis (int): The index of the axis up to which to keep. + total (int): The total number of axes. + + Returns: + AxisMapping: The new named axis after keeping all axes up to the specified axis. + + Examples: + >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, 1, 3) + {"x": 0, "y": 1} + >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, -1, 3) + {"x": 0, "y": 1, "z": 2} + >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, 0, 3) + {"x": 0} + """ + axis = _neg2pos_axis(axis, total) + out = {} + for k, v in named_axis.items(): + if v >= 0 and v <= axis: + out[k] = v + elif v < 0 and v >= -axis - 1: + out[k] = v + return out + + +def _remove_all_named_axis( + named_axis: AxisMapping, +) -> AxisMapping: + """ + Returns an empty named axis mapping after removing all axes from the given named axis mapping. + This function is typically used when an operation that eliminates all axes is applied. + + Args: + named_axis (AxisMapping): The current named axis mapping. + + Returns: + AxisMapping: An empty named axis mapping. + + Examples: + >>> _remove_all_named_axis({"x": 0, "y": 1, "z": 2}) + {} + """ + return _remove_named_axis(named_axis=named_axis, axis=None) + + +def _remove_named_axis( + named_axis: AxisMapping, + axis: int | None = None, + total: int | None = None, +) -> AxisMapping: + """ + Determines the new named axis after removing the specified axis. This is useful, for example, + when applying an operation that removes one axis. + + Args: + named_axis (AxisMapping): The current named axis. + axis (int | None, optional): The index of the axis to remove. If None, no axes are removed. Default is None. + total (int | None, optional): The total number of axes. If None, it is calculated as the length of the named axis. Default is None. + + Returns: + AxisMapping: The new named axis after removing the specified axis. + + Examples: + >>> _remove_named_axis({"x": 0, "y": 1}, None) + {} + >>> _remove_named_axis({"x": 0, "y": 1}, 0) + {"y": 0} + >>> _remove_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"x": 0, "z": 1} + >>> _remove_named_axis({"x": 0, "y": 1, "z": -1}, 1) + {"x": 0, "z": -1} + >>> _remove_named_axis({"x": 0, "y": 1, "z": -3}, 1) + {"x": 0, "z": -2} + """ + if axis is None: + return {} + + if total is None: + total = len(named_axis) + + # remove the specified axis + out = { + ax: pos + for ax, pos in named_axis.items() + if _neg2pos_axis(pos, total) != _neg2pos_axis(axis, total) + } + + return _adjust_pos_axis(out, axis, total, direction=-1) + + +def _adjust_pos_axis( + named_axis: AxisMapping, + axis: int, + total: int, + direction: int, +) -> AxisMapping: + """ + Adjusts the positions of the axes in the named axis mapping after an axis has been removed or added. + + Args: + named_axis (AxisMapping): The current named axis mapping. + axis (int): The position of the removed/added axis. + total (int): The total number of axes. + direction (int): The direction of the adjustment. -1 means axis is removed; +1 means axis is added. Default is +1. + + Returns: + AxisMapping: The adjusted named axis mapping. + + Examples: + # axis=1 removed + >>> _adjust_pos_axis({"x": 0, "z": 2}, 1, 3, -1) + {"x": 0, "z": 1} + # axis=1 added + >>> _adjust_pos_axis({"x": 0, "z": 2}, 1, 3, +1) + {"x": 0, "z": 3} + # axis=1 removed + >>> _adjust_pos_axis({"x": 0, "z": -1}, 1, 3, -1) + {"x": 0, "z": -1} + # axis=1 added + >>> _adjust_pos_axis({"x": 0, "z": -1}, 1, 3, +1) + {"x": 0, "z": -1} + """ + assert direction in (-1, +1), f"Invalid direction: {direction}" + + def _adjust(pos: int, axis: int, direction: int) -> int: + # positive axis + if axis >= 0: + # positive axis and position greater than or equal to the removed/added (positive) axis + # -> change position by direction + if pos >= axis: + return pos + direction + # positive axis and negative position + # -> change position by direction + elif pos < 0 and pos + total < axis: + return pos - direction + # positive axis and position smaller than the removed/added (positive) axis, but greater than 0 + # -> keep position + else: + return pos + # negative axis + else: + # negative axis and position smaller than the removed/added (negative) axis + # -> change position by inverse direction + if pos <= axis: + return pos - direction + # negative axis and positive position + # -> change position by inverse direction + elif pos > 0 and pos > axis + total: + return pos + direction + # negative axis and position greater than the removed/added (negative) axis, but smaller than 0 + # -> keep position + else: + return pos + + return {k: _adjust(v, axis, direction) for k, v in named_axis.items()} + + +def _add_named_axis( + named_axis: AxisMapping, + axis: int, + total: int | None = None, +) -> AxisMapping: + """ + Adds a new axis to the named_axis at the specified position. + + Args: + named_axis (AxisMapping): The current named axis mapping. + axis (int): The position at which to add the new axis. + total (int | None): The total number of axes. + + Returns: + AxisMapping: The updated named axis mapping after adding the new axis. + + Examples: + >>> _add_named_axis({"x": 0, "y": 1, "z": 2}, 0) + {"x": 1, "y": 2, "z": 3} + >>> _add_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"x": 0, "y": 2, "z": 3} + """ + if total is None: + total = len(named_axis) + + return _adjust_pos_axis(named_axis, axis, total, direction=+1) + + +def _unify_named_axis( + named_axis1: AxisMapping, + named_axis2: AxisMapping, +) -> AxisMapping: + """ + Unifies two named axes into a single named axis. The function iterates over all positional axes present in either of the input named axes. + For each positional axis, it checks the corresponding axis names in both input named axes. If the axis names are the same or if one of them is None, + the unified axis will be the non-None axis. If the axis names are different and neither of them is None, a ValueError is raised. + + Args: + named_axis1 (AxisMapping): The first named axis to unify. + named_axis2 (AxisMapping): The second named axis to unify. + + Returns: + AxisMapping: The unified named axis. + + Raises: + ValueError: If the axes are different and neither of them is None. + + Examples: + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"a": 0, "b": 1, "c": 2}) + Traceback (most recent call last): + ... + ValueError: The named axes are different. Got: 'x' and 'a' for positional axis 0 + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": 3}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({}, {}) + {} + """ + + def _get_axis_name( + axis_mapping: AxisMapping, positional_axis: int + ) -> AxisName | None: + for name, position in axis_mapping.items(): + if position == positional_axis: + return name + return None + + unified_named_axis = {} + all_positional_axes = set(named_axis1.values()) | set(named_axis2.values()) + for position in all_positional_axes: + axis_name1 = _get_axis_name(named_axis1, position) + axis_name2 = _get_axis_name(named_axis2, position) + if axis_name1 is not None and axis_name2 is not None: + if axis_name1 != axis_name2: + raise ValueError( + f"The named axes are incompatible. Got: {axis_name1} and {axis_name2} for positional axis {position}" + ) + unified_named_axis[axis_name1] = position + elif axis_name1 is not None: # axis_name2 is None + unified_named_axis[axis_name1] = position + elif axis_name2 is not None: # axis_name1 is None + unified_named_axis[axis_name2] = position + return unified_named_axis + + +@dataclass +class NamedAxesWithDims: + """ + A dataclass that stores the named axis and their corresponding dimensions. + This is a helper class to store the named axis mapping and the number of + dimensions of each named axis, which is useful for broadcasting. + + Attributes: + named_axis (AxisMapping): The named axis mapping. + ndims (Tuple[int]): The number of dimensions of the named axis. + """ + + named_axis: list[AxisMapping] + ndims: list[int] + + def __post_init__(self): + if len(self.named_axis) != len(self.ndims): + raise ValueError( + "The number of dimensions must match the number of named axis mappings." + ) + + def __iter__(self) -> tp.Iterator[tuple[AxisMapping, int]]: + yield from zip(self.named_axis, self.ndims) + + @classmethod + def prepare_contexts( + cls, arrays: tp.Sequence, unwrap_kwargs: dict | None = None + ) -> tuple[dict, dict]: + from awkward._layout import HighLevelContext + from awkward._typetracer import MaybeNone + + # unwrap options + arrays = [x.content if isinstance(x, MaybeNone) else x for x in arrays] + + _unwrap_kwargs = {"allow_unknown": True} + if unwrap_kwargs is not None: + _unwrap_kwargs.update(unwrap_kwargs) + + _named_axes = [] + _ndims = [] + for array in arrays: + with HighLevelContext() as ctx: + layout = ctx.unwrap(array, **_unwrap_kwargs) + _named_axes.append(_get_named_axis(array)) + _ndims.append(layout.minmax_depth[1]) + + depth_context = {NAMED_AXIS_KEY: cls(_named_axes, _ndims)} + lateral_context = {NAMED_AXIS_KEY: cls(_named_axes, _ndims)} + return depth_context, lateral_context + + def __setitem__(self, index: int, named_axis_with_ndim: tuple[AxisMapping, int]): + named_axis, ndim = named_axis_with_ndim + self.named_axis[index] = named_axis + self.ndims[index] = ndim + + def __getitem__(self, index: int) -> tuple[AxisMapping, int]: + return self.named_axis[index], self.ndims[index] + + def __len__(self) -> int: + return len(self.named_axis) + + +# Define a type alias for a slice or int (can be a single axis or a sequence of axes) +AxisSlice: tp.TypeAlias = tp.Union[tuple, slice, int, tp.EllipsisType, None] +NamedAxisSlice: tp.TypeAlias = tp.Dict[AxisName, AxisSlice] + + +def _normalize_named_slice( + named_axis: AxisMapping, + where: AxisSlice | NamedAxisSlice, + total: int, +) -> AxisSlice: + """ + Normalizes a named slice into a positional slice. + + This function takes a named slice (a dictionary mapping axis names to slices) and converts it into a positional slice + (a tuple of slices). The positional slice can then be used to index an array. + + Args: + named_axis (AxisMapping): The current named axis mapping. + where (AxisSlice | NamedAxisSlice): The slice to normalize. Can be a single slice, a tuple of slices, or a dictionary mapping axis names to slices. + total (int): The total number of axes. + + Returns: + AxisSlice: The normalized slice. + + Raises: + ValueError: If an invalid axis name is provided in the slice. + + Examples: + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {0: 0}, 3) + (0, slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {-1: 0}, 3) + (slice(None), slice(None), 0) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0}, 3) + (0, slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1}, 3) + (0, 1, slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": ...}, 3) + (0, 1, ...) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": slice(0, 1)}, 3) + (0, 1, slice(0, 1)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": (0, 1)}, 3) + ((0, 1), slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": [0, 1]}, 3) + ([0, 1], slice(None), slice(None)) + """ + if isinstance(where, dict): + out_where: list[AxisSlice] = [slice(None)] * total + for ax_name, ax_where in where.items(): + slice_ = ax_where if ax_where is not ... else slice(None) + if is_integer(ax_name): + # it's an integer, pyright doesn't get this + idx = tp.cast(int, ax_name) + out_where[idx] = slice_ + elif _is_valid_named_axis(ax_name): + # it's an integer, pyright doesn't get this + idx = tp.cast(int, _named_axis_to_positional_axis(named_axis, ax_name)) + out_where[idx] = slice_ + else: + raise ValueError(f"Invalid axis name: {ax_name} in slice {where}") + where = tuple(out_where) + return where diff --git a/src/awkward/_nplikes/array_like.py b/src/awkward/_nplikes/array_like.py index d82611ae5d..d75fe6cfcb 100644 --- a/src/awkward/_nplikes/array_like.py +++ b/src/awkward/_nplikes/array_like.py @@ -8,6 +8,7 @@ from awkward._typing import ( TYPE_CHECKING, DType, + EllipsisType, Protocol, Self, SupportsIndex, @@ -15,8 +16,6 @@ ) if TYPE_CHECKING: - from types import EllipsisType - from numpy.typing import DTypeLike diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py index 64b99abf01..548071c0ee 100644 --- a/src/awkward/_nplikes/typetracer.py +++ b/src/awkward/_nplikes/typetracer.py @@ -26,6 +26,7 @@ TYPE_CHECKING, Any, DType, + EllipsisType, Final, Literal, Self, @@ -36,8 +37,6 @@ ) if TYPE_CHECKING: - from types import EllipsisType - from numpy.typing import DTypeLike from awkward.contents.content import Content diff --git a/src/awkward/_operators.py b/src/awkward/_operators.py index 2c58330492..f5a2cf90da 100644 --- a/src/awkward/_operators.py +++ b/src/awkward/_operators.py @@ -50,6 +50,7 @@ def _binary_method(ufunc, name): def func(self, other): if _disables_array_ufunc(other): return NotImplemented + return ufunc(self, other) func.__name__ = f"__{name}__" diff --git a/src/awkward/_regularize.py b/src/awkward/_regularize.py index 663d9eb01a..6f78a18409 100644 --- a/src/awkward/_regularize.py +++ b/src/awkward/_regularize.py @@ -7,7 +7,7 @@ from collections.abc import Iterable, Sequence, Sized from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._typing import AxisMaybeNone, SupportsInt +from awkward._typing import Any np = NumpyMetadata.instance() @@ -51,8 +51,19 @@ def is_non_string_like_sequence(obj) -> bool: return not isinstance(obj, (str, bytes)) and isinstance(obj, Sequence) -def regularize_axis(axis: SupportsInt | None) -> AxisMaybeNone: - if axis is None: - return None +def regularize_axis(axis: Any, none_allowed: bool = True) -> int | None: + """ + This function's main purpose is to convert [np,cp,...].array(0) to 0. + """ + if is_integer_like(axis): + regularized_axis = int(axis) else: - return int(axis) + regularized_axis = axis + cond = is_integer(regularized_axis) + msg = f"'axis' must be an integer, not {axis!r}" + if none_allowed: + cond = cond or regularized_axis is None + msg = f"'axis' must be an integer or None, not {axis!r}" + if not cond: + raise TypeError(msg) + return regularized_axis diff --git a/src/awkward/_typing.py b/src/awkward/_typing.py index 0e987b4399..be474a37d9 100644 --- a/src/awkward/_typing.py +++ b/src/awkward/_typing.py @@ -26,6 +26,7 @@ "Literal", "SupportsIndex", "ParamSpec", + "EllipsisType", *typing.__all__, } ) @@ -46,7 +47,10 @@ TypeGuard, Unpack, ) + + EllipsisType = type(...) else: + from types import EllipsisType from typing import ( ClassVar, Final, diff --git a/src/awkward/contents/content.py b/src/awkward/contents/content.py index d0169ee2eb..f324d9dac5 100644 --- a/src/awkward/contents/content.py +++ b/src/awkward/contents/content.py @@ -16,8 +16,14 @@ ) from awkward._behavior import get_array_class, get_record_class from awkward._kernels import KernelError -from awkward._layout import wrap_layout +from awkward._layout import maybe_posaxis, wrap_layout from awkward._meta.meta import Meta +from awkward._namedaxis import ( + NamedAxis, + _add_named_axis, + _keep_named_axis, + _remove_named_axis, +) from awkward._nplikes import to_nplike from awkward._nplikes.dispatch import nplike_of_obj from awkward._nplikes.numpy import Numpy @@ -27,7 +33,12 @@ parameters_are_equal, type_parameters_equal, ) -from awkward._regularize import is_integer_like, is_sized_iterable +from awkward._regularize import ( + is_array_like, + is_integer, + is_integer_like, + is_sized_iterable, +) from awkward._slicing import normalize_slice from awkward._typing import ( TYPE_CHECKING, @@ -38,6 +49,7 @@ Protocol, Self, SupportsIndex, + Type, TypeAlias, TypedDict, ) @@ -509,10 +521,14 @@ def _getitem_next_missing( ) def __getitem__(self, where): - return self._getitem(where) + return self._getitem(where, NamedAxis) - def _getitem(self, where): + def _getitem(self, where, named_axis: Type[NamedAxis] = NamedAxis): if is_integer_like(where): + # propagate named_axis to output + named_axis.mapping = _remove_named_axis( + named_axis.mapping, where, self.purelist_depth + ) return self._getitem_at(ak._slicing.normalize_integer_like(where)) elif isinstance(where, slice) and where.step is None: @@ -523,21 +539,35 @@ def _getitem(self, where): return self._getitem_range(start, stop) elif isinstance(where, slice): - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif isinstance(where, str): return self._getitem_field(where) elif where is np.newaxis: - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif where is Ellipsis: - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif isinstance(where, tuple): if len(where) == 0: return self + # count number of ellipsis + # Need to use a little trick here: + # where.count(Ellipsis) does not work, because it will do a == comparison against Ellipsis, + # and this will fail in the case of typetracers where == is dispatched to np.equal ufunc. + # In this dispatch we encounter an assertion that the type of the Ellipsis is not allowed. + # ...but luckily we can use the fact that Ellipsis is a singleton and use the 'is' operator + n_ellipsis = 0 + for w in where: + if w is ...: + n_ellipsis += 1 + + if n_ellipsis > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Backend may change if index contains typetracers backend = backend_of(self, *where, coerce_to_common=True) this = self.to_backend(backend) @@ -547,6 +577,62 @@ def _getitem(self, where): # Prepare items for advanced indexing (e.g. via broadcasting) nextwhere = ak._slicing.prepare_advanced_indexing(items, backend) + # Handle named axis + # first expand the ellipsis to colons in nextwhere, + # copy nextwhere to not pollute the original + _nextwhere = tuple(nextwhere) + if n_ellipsis == 1: + # collect the ellipsis index + # same little trick as above for `nextwhere.index(...)` + (ellipsis_at,) = tuple(i for i, x in enumerate(nextwhere) if x is ...) + # calculate how many slice(None) we need to add + # same little trick as above for `nextwhere.count(None)` + n_newaxis = 0 + for x in nextwhere: + if x is np.newaxis or x is None: + n_newaxis += 1 + n_total = self.minmax_depth[1] + n_slice_none = n_total - (len(_nextwhere) - n_newaxis - 1) + # expand `[...]` to `[:]*n_slice_none` + _nextwhere = ( + _nextwhere[:ellipsis_at] + + (slice(None),) * n_slice_none + + _nextwhere[ellipsis_at + 1 :] + ) + + # now propagate named axis + _named_axis = _keep_named_axis(named_axis.mapping, None) + _adjust_dim = 0 + # this loop does the following: + # - remove a named axis for integer indices, e.g. `a[1, 2]` + # - add a named axis for None (or np.newaxis) indices, e.g. `a[..., None]` + # - keep named axis for any other index, e.g. `a[:]`, `a[0:1]`, or `a[a>0]` + # (these may only remove elements, but not dimensions) + for dim, nw in enumerate(_nextwhere): + dim_adjusted = dim + _adjust_dim + total_adjusted = self.minmax_depth[1] + _adjust_dim + for _, pos in _named_axis.items(): + if maybe_posaxis(self, pos, 0) == dim_adjusted: + break + + if is_integer(nw) or (is_array_like(nw) and nw.ndim == 0): + _named_axis = _remove_named_axis( + named_axis=_named_axis, + axis=dim_adjusted, + total=total_adjusted, + ) + _adjust_dim -= 1 + elif nw is None: + _named_axis = _add_named_axis( + named_axis=_named_axis, + axis=dim_adjusted, + total=total_adjusted, + ) + _adjust_dim += 1 + + # set propagated named axis + named_axis.mapping = _named_axis + next = ak.contents.RegularArray( this, this.length, @@ -562,7 +648,7 @@ def _getitem(self, where): return out._getitem_at(0) elif isinstance(where, ak.highlevel.Array): - return self._getitem(where.layout) + return self._getitem(where.layout, named_axis) # Convert between nplikes of different backends elif ( @@ -570,7 +656,9 @@ def _getitem(self, where): and where.backend is not self._backend ): backend = backend_of(self, where, coerce_to_common=True) - return self.to_backend(backend)._getitem(where.to_backend(backend)) + return self.to_backend(backend)._getitem( + where.to_backend(backend), named_axis + ) elif isinstance(where, ak.contents.NumpyArray): data_as_index = to_nplike( @@ -602,7 +690,7 @@ def _getitem(self, where): allow_lazy = "copied" # True, but also can be modified in-place else: wheres = self._backend.index_nplike.nonzero(data_as_index) - return self._getitem(wheres) + return self._getitem(wheres, named_axis) else: raise TypeError( "array slice must be an array of integers or booleans, not\n\n {}".format( @@ -621,9 +709,9 @@ def _getitem(self, where): elif isinstance(where, ak.contents.RegularArray): maybe_numpy = where.maybe_to_NumpyArray() if maybe_numpy is None: - return self._getitem((where,)) + return self._getitem((where,), named_axis) else: - return self._getitem(maybe_numpy) + return self._getitem(maybe_numpy, named_axis) # Awkward Array of strings elif ( @@ -637,7 +725,7 @@ def _getitem(self, where): return where.to_NumpyArray(np.int64) elif isinstance(where, Content): - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif is_sized_iterable(where): # Do we have an array @@ -654,7 +742,7 @@ def _getitem(self, where): primitive_policy="error", string_policy="as-characters", ) - return self._getitem(layout) + return self._getitem(layout, named_axis) elif len(where) == 0: return self._carry( @@ -682,7 +770,7 @@ def _getitem(self, where): ), self._backend, ) - return self._getitem(layout) + return self._getitem(layout, named_axis) else: raise TypeError( diff --git a/src/awkward/contents/numpyarray.py b/src/awkward/contents/numpyarray.py index 315d9383b7..5c90ca0141 100644 --- a/src/awkward/contents/numpyarray.py +++ b/src/awkward/contents/numpyarray.py @@ -175,7 +175,11 @@ def shape(self) -> tuple[ShapeItem, ...]: @property def inner_shape(self) -> tuple[ShapeItem, ...]: - return self._data.shape[1:] + if hasattr(self._data, "inner_shape"): + inner_shape = self._data.inner_shape + else: + inner_shape = self._data.shape[1:] + return inner_shape @property def strides(self) -> tuple[ShapeItem, ...]: @@ -189,14 +193,9 @@ def _raw(self, nplike=None): return to_nplike(self.data, nplike, from_nplike=self._backend.nplike) def _form_with_key(self, getkey: Callable[[Content], str | None]) -> NumpyForm: - if hasattr(self._data, "inner_shape"): - inner_shape = self._data.inner_shape - else: - inner_shape = self._data.shape[1:] - return self.form_cls( ak.types.numpytype.dtype_to_primitive(self._data.dtype), - inner_shape, + self.inner_shape, parameters=self._parameters, form_key=getkey(self), ) diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index f315945511..6d1d6649aa 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -23,6 +23,16 @@ from awkward._backends.numpy import NumpyBackend from awkward._behavior import behavior_of, get_array_class, get_record_class from awkward._layout import wrap_layout +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + AxisMapping, + NamedAxis, + _get_named_axis, + _make_positional_axis_tuple, + _normalize_named_slice, + _prepare_named_axis_for_attrs, + _prettify_named_axes, +) from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpy_like import NumpyMetadata from awkward._operators import NDArrayOperatorsMixin @@ -32,7 +42,7 @@ unpickle_record_schema_1, ) from awkward._regularize import is_non_string_like_iterable -from awkward._typing import Any, TypeVar +from awkward._typing import Any, MutableMapping, TypeVar from awkward._util import STDOUT from awkward.prettyprint import Formatter from awkward.prettyprint import valuestr as prettyprint_valuestr @@ -278,6 +288,7 @@ def __init__( check_valid=False, backend=None, attrs=None, + named_axis=None, ): self._cpp_type = None if isinstance(data, ak.contents.Content): @@ -326,9 +337,20 @@ def __init__( if behavior is not None and not isinstance(behavior, Mapping): raise TypeError("behavior must be None or a mapping") - if attrs is not None and not isinstance(attrs, Mapping): + if attrs is not None and not isinstance(attrs, MutableMapping): raise TypeError("attrs must be None or a mapping") + if named_axis: + _named_axis = _prepare_named_axis_for_attrs( + named_axis=named_axis, + ndim=layout.minmax_depth[1], + ) + # now we're good, set the named axis + if attrs is None: + attrs = {} + # if NAMED_AXIS_KEY is already in attrs, it will be overwritten + attrs[NAMED_AXIS_KEY] = _named_axis + self._layout = layout self._behavior = behavior self._attrs = attrs @@ -357,7 +379,7 @@ def _update_class(self): self.__class__ = get_array_class(self._layout, self._behavior) @property - def attrs(self) -> Mapping[str, Any]: + def attrs(self) -> Mapping: """ The mutable mapping containing top-level metadata, which is serialised with the array during pickling. @@ -455,6 +477,15 @@ def behavior(self, behavior): else: raise TypeError("behavior must be None or a dict") + @property + def positional_axis(self) -> tuple[int, ...]: + (_, ndim) = self._layout.minmax_depth + return _make_positional_axis_tuple(ndim) + + @property + def named_axis(self) -> AxisMapping: + return _get_named_axis(self) + class Mask: def __init__(self, array): self._array = array @@ -1062,12 +1093,30 @@ def __getitem__(self, where): have the same dimension as the array being indexed. """ with ak._errors.SlicingErrorContext(self, where): - return wrap_layout( - prepare_layout(self._layout[where]), - self._behavior, - allow_other=True, - attrs=self._attrs, - ) + # Handle named axis + (_, ndim) = self._layout.minmax_depth + named_axis = _get_named_axis(self) + where = _normalize_named_slice(named_axis, where, ndim) + + NamedAxis.mapping = named_axis + + indexed_layout = prepare_layout(self._layout._getitem(where, NamedAxis)) + + if NamedAxis.mapping: + return ak.operations.ak_with_named_axis._impl( + indexed_layout, + named_axis=NamedAxis.mapping, + highlevel=True, + behavior=self._behavior, + attrs=self._attrs, + ) + else: + return wrap_layout( + indexed_layout, + self._behavior, + allow_other=True, + attrs=self._attrs, + ) def __bytes__(self) -> bytes: if isinstance(self._layout, ak.contents.NumpyArray) and self._layout.parameter( @@ -1309,6 +1358,15 @@ def _repr(self, limit_cols): else: valuestr = "-typetracer" + # prepare named_axis str for repr + axisstr = "" + if self.named_axis: + # we reserve at maximum 20 characters for the named axis string + axisstr = _prettify_named_axes(self.named_axis, delimiter=",", maxlen=20) + axisstr = f" {axisstr}" + # subtract the reserved space from the limit_cols + limit_cols -= len(axisstr) + if len(typestr) + len(pytype) + len(" type=''") + 3 < limit_cols // 2: strwidth = limit_cols - (len(typestr) + len(pytype) + len(" type=''") + 3) else: @@ -1327,13 +1385,14 @@ def _repr(self, limit_cols): else: typestr = "'" + typestr + "'" - return f"<{pytype}{valuestr} type={typestr}>" + return f"<{pytype}{valuestr}{axisstr} type={typestr}>" def show( self, limit_rows=20, limit_cols=80, type=False, + named_axis=False, stream=STDOUT, *, formatter=None, @@ -1365,25 +1424,41 @@ def show( valuestr = prettyprint_valuestr( self, limit_rows, limit_cols, formatter=formatter_impl ) + + out_io = io.StringIO() if type: - tmp = io.StringIO() - self.type.show(stream=tmp) - out = "type: " + tmp.getvalue() + valuestr - else: - out = valuestr + out_io.write("type: ") + self.type.show(stream=out_io) + if named_axis and self.named_axis: + out_io.write("axes: ") + out_io.write( + _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) + ) + out_io.write("\n") + out_io.write(valuestr) if stream is None: - return out + return out_io else: if stream is STDOUT: stream = STDOUT.stream - stream.write(out + "\n") + stream.write(out_io.getvalue() + "\n") def _repr_mimebundle_(self, include=None, exclude=None): + # order: 1. array, 2. named_axis, 3. type value_buff = io.StringIO() - self.show(type=False, stream=value_buff) + self.show(type=False, named_axis=False, stream=value_buff) header_lines = value_buff.getvalue().splitlines() + named_axis_line = "" + if self.named_axis: + named_axis_buff = io.StringIO() + named_axis_buff.write("axes: ") + named_axis_buff.write( + _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) + ) + named_axis_line = named_axis_buff.getvalue() + type_buff = io.StringIO() self.type.show(stream=type_buff) footer_lines = type_buff.getvalue().splitlines() @@ -1393,8 +1468,16 @@ def _repr_mimebundle_(self, include=None, exclude=None): if header_lines[-1] == "": del header_lines[-1] - n_cols = max(len(line) for line in itertools.chain(header_lines, footer_lines)) - body = "\n".join([*header_lines, "-" * n_cols, *footer_lines]) + n_cols = max( + len(line) + for line in itertools.chain(header_lines, [named_axis_line], footer_lines) + ) + body_lines = header_lines + body_lines.append("-" * n_cols) + if named_axis_line: + body_lines.append(named_axis_line) + body_lines.extend(footer_lines) + body = "\n".join(body_lines) return { "text/html": f"
{html.escape(body)}
", @@ -1719,6 +1802,7 @@ def __init__( check_valid=False, backend=None, attrs=None, + named_axis=None, ): if isinstance(data, ak.record.Record): layout = data @@ -1762,6 +1846,20 @@ def __init__( if behavior is not None and not isinstance(behavior, Mapping): raise TypeError("behavior must be None or mapping") + if attrs is not None and not isinstance(attrs, MutableMapping): + raise TypeError("attrs must be None or a mapping") + + if named_axis: + _named_axis = _prepare_named_axis_for_attrs( + named_axis=named_axis, + ndim=layout.minmax_depth[1], + ) + # now we're good, set the named axis + if attrs is None: + attrs = {} + # if NAMED_AXIS_KEY is already in attrs, it will be overwritten + attrs[NAMED_AXIS_KEY] = _named_axis + self._layout = layout self._behavior = behavior self._attrs = attrs @@ -1877,6 +1975,15 @@ def behavior(self, behavior): else: raise TypeError("behavior must be None or a dict") + @property + def positional_axis(self) -> tuple[int, ...]: + (_, ndim) = self._layout.minmax_depth + return _make_positional_axis_tuple(ndim) + + @property + def named_axis(self) -> AxisMapping: + return _get_named_axis(self) + def tolist(self): """ Converts this Record into Python objects; same as #ak.to_list @@ -2170,6 +2277,15 @@ def _repr(self, limit_cols): else: valuestr = "-typetracer" + # prepare named_axis str for repr + axisstr = "" + if self.named_axis: + # we reserve at maximum 20 characters for the named axis string + axisstr = _prettify_named_axes(self.named_axis, delimiter=",", maxlen=20) + axisstr = f" {axisstr}" + # subtract the reserved space from the limit_cols + limit_cols -= len(axisstr) + if len(typestr) + len(pytype) + len(" type=''") + 3 < limit_cols // 2: strwidth = limit_cols - (len(typestr) + len(pytype) + len(" type=''") + 3) else: @@ -2188,13 +2304,14 @@ def _repr(self, limit_cols): else: typestr = "'" + typestr + "'" - return f"<{pytype}{valuestr} type={typestr}>" + return f"<{pytype}{valuestr}{axisstr} type={typestr}>" def show( self, limit_rows=20, limit_cols=80, type=False, + named_axis=False, stream=STDOUT, *, formatter=None, @@ -2224,25 +2341,41 @@ def show( valuestr = prettyprint_valuestr( self, limit_rows, limit_cols, formatter=formatter_impl ) + + out_io = io.StringIO() if type: - tmp = io.StringIO() - self.type.show(stream=tmp) - out = "type: " + tmp.getvalue() + valuestr - else: - out = valuestr + out_io.write("type: ") + self.type.show(stream=out_io) + if named_axis and self.named_axis: + out_io.write("axes: ") + out_io.write( + _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) + ) + out_io.write("\n") + out_io.write(valuestr) if stream is None: - return out + return out_io.getvalue() else: if stream is STDOUT: stream = STDOUT.stream - stream.write(out + "\n") + stream.write(out_io.getvalue() + "\n") def _repr_mimebundle_(self, include=None, exclude=None): + # order: 1. array, 2. named_axis, 3. type value_buff = io.StringIO() - self.show(type=False, stream=value_buff) + self.show(type=False, named_axis=False, stream=value_buff) header_lines = value_buff.getvalue().splitlines() + named_axis_line = "" + if self.named_axis: + named_axis_buff = io.StringIO() + named_axis_buff.write("axes: ") + named_axis_buff.write( + _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None) + ) + named_axis_line = named_axis_buff.getvalue() + type_buff = io.StringIO() self.type.show(stream=type_buff) footer_lines = type_buff.getvalue().splitlines() @@ -2252,8 +2385,16 @@ def _repr_mimebundle_(self, include=None, exclude=None): if header_lines[-1] == "": del header_lines[-1] - n_cols = max(len(line) for line in itertools.chain(header_lines, footer_lines)) - body = "\n".join([*header_lines, "-" * n_cols, *footer_lines]) + n_cols = max( + len(line) + for line in itertools.chain(header_lines, [named_axis_line], footer_lines) + ) + body_lines = header_lines + body_lines.append("-" * n_cols) + if named_axis_line: + body_lines.append(named_axis_line) + body_lines.extend(footer_lines) + body = "\n".join(body_lines) return { "text/html": f"
{html.escape(body)}
", diff --git a/src/awkward/operations/__init__.py b/src/awkward/operations/__init__.py index d76d8e2688..94dbd9ffac 100644 --- a/src/awkward/operations/__init__.py +++ b/src/awkward/operations/__init__.py @@ -114,8 +114,10 @@ from awkward.operations.ak_where import * from awkward.operations.ak_with_field import * from awkward.operations.ak_with_name import * +from awkward.operations.ak_with_named_axis import * from awkward.operations.ak_with_parameter import * from awkward.operations.ak_without_field import * +from awkward.operations.ak_without_named_axis import * from awkward.operations.ak_without_parameters import * from awkward.operations.ak_zeros_like import * from awkward.operations.ak_zip import * diff --git a/src/awkward/operations/ak_all.py b/src/awkward/operations/ak_all.py index 859bfd98cb..98a22520ba 100644 --- a/src/awkward/operations/ak_all.py +++ b/src/awkward/operations/ak_all.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -67,9 +73,26 @@ def all( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.All() out = ak._do.reduce( @@ -80,7 +103,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("all") diff --git a/src/awkward/operations/ak_almost_equal.py b/src/awkward/operations/ak_almost_equal.py index 66f67e4d8a..949f955a45 100644 --- a/src/awkward/operations/ak_almost_equal.py +++ b/src/awkward/operations/ak_almost_equal.py @@ -7,6 +7,7 @@ from awkward._behavior import behavior_of, get_array_class, get_record_class from awkward._dispatch import high_level_function from awkward._layout import ensure_same_backend +from awkward._namedaxis import _get_named_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._parameters import parameters_are_equal from awkward.operations.ak_to_layout import to_layout @@ -27,6 +28,7 @@ def almost_equal( dtype_exact: bool = True, check_parameters: bool = True, check_regular: bool = True, + check_named_axis: bool = True, ): """ Args: @@ -39,6 +41,7 @@ def almost_equal( check_parameters: whether to compare parameters. check_regular: whether to consider ragged and regular dimensions as unequal. + check_named_axis: bool (default=True) whether to consider named axes as unequal. Return True if the two array-like arguments are considered equal for the given options. Otherwise, return False. @@ -61,6 +64,7 @@ def almost_equal( dtype_exact=dtype_exact, check_parameters=check_parameters, check_regular=check_regular, + check_named_axis=check_named_axis, exact_eq=False, same_content_types=False, equal_nan=False, @@ -75,6 +79,7 @@ def _impl( dtype_exact: bool, check_parameters: bool, check_regular: bool, + check_named_axis: bool, exact_eq: bool, same_content_types: bool, equal_nan: bool, @@ -91,6 +96,10 @@ def _impl( right_layout = layouts[1].to_packed() backend = backend_of(left_layout) + if check_named_axis and _get_named_axis(left) and _get_named_axis(right): + if left.named_axis != right.named_axis: + return False + if not backend.nplike.known_data: raise NotImplementedError( "Awkward Arrays with typetracer backends cannot yet be compared with `ak.almost_equal`." diff --git a/src/awkward/operations/ak_any.py b/src/awkward/operations/ak_any.py index 79c9cc6b83..e99065d97c 100644 --- a/src/awkward/operations/ak_any.py +++ b/src/awkward/operations/ak_any.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -67,9 +73,26 @@ def any( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Any() out = ak._do.reduce( @@ -80,7 +103,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("any") diff --git a/src/awkward/operations/ak_argcartesian.py b/src/awkward/operations/ak_argcartesian.py index 12deed5749..f012290cbe 100644 --- a/src/awkward/operations/ak_argcartesian.py +++ b/src/awkward/operations/ak_argcartesian.py @@ -7,7 +7,6 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis __all__ = ("argcartesian",) @@ -107,8 +106,6 @@ def argcartesian( def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs): - axis = regularize_axis(axis) - if isinstance(arrays, Mapping): index_arrays = {n: ak.local_index(x, axis) for n, x in arrays.items()} else: diff --git a/src/awkward/operations/ak_argcombinations.py b/src/awkward/operations/ak_argcombinations.py index 98a2643855..337f77cec1 100644 --- a/src/awkward/operations/ak_argcombinations.py +++ b/src/awkward/operations/ak_argcombinations.py @@ -5,6 +5,7 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import _get_named_axis, _named_axis_to_positional_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -93,7 +94,6 @@ def _impl( behavior, attrs, ): - axis = regularize_axis(axis) if parameters is None: parameters = {} else: @@ -101,6 +101,13 @@ def _impl( if with_name is not None: parameters["__record__"] = with_name + # Handle named axis + named_axis = _get_named_axis(array) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + if axis < 0: raise ValueError("the 'axis' for argcombinations must be non-negative") else: diff --git a/src/awkward/operations/ak_argmax.py b/src/awkward/operations/ak_argmax.py index a4dbe947bd..ef9b37e57c 100644 --- a/src/awkward/operations/ak_argmax.py +++ b/src/awkward/operations/ak_argmax.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -132,9 +138,26 @@ def nanargmax( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.ArgMax() out = ak._do.reduce( @@ -145,7 +168,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("argmax") diff --git a/src/awkward/operations/ak_argmin.py b/src/awkward/operations/ak_argmin.py index 7f21fb3aa8..6982a4d407 100644 --- a/src/awkward/operations/ak_argmin.py +++ b/src/awkward/operations/ak_argmin.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -129,10 +135,26 @@ def nanargmin( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.ArgMin() out = ak._do.reduce( @@ -143,7 +165,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("argmin") diff --git a/src/awkward/operations/ak_argsort.py b/src/awkward/operations/ak_argsort.py index bade378b20..7c92d6a645 100644 --- a/src/awkward/operations/ak_argsort.py +++ b/src/awkward/operations/ak_argsort.py @@ -6,6 +6,10 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -70,11 +74,22 @@ def argsort( def _impl(array, axis, ascending, stable, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + out = ak._do.argsort(layout, axis, ascending, stable) - return ctx.wrap(out, highlevel=highlevel) + + return ctx.wrap( + out, + highlevel=highlevel, + ) @ak._connect.numpy.implements("argsort") diff --git a/src/awkward/operations/ak_array_equal.py b/src/awkward/operations/ak_array_equal.py index 398db6b2a6..1dabd60f31 100644 --- a/src/awkward/operations/ak_array_equal.py +++ b/src/awkward/operations/ak_array_equal.py @@ -18,6 +18,7 @@ def array_equal( same_content_types: bool = True, check_parameters: bool = True, check_regular: bool = True, + check_named_axis: bool = True, ): """ True if two arrays have the same shape and elements, False otherwise. @@ -34,6 +35,7 @@ def array_equal( check_parameters: bool (default=True) whether to compare parameters. check_regular: bool (default=True) whether to consider ragged and regular dimensions as unequal. + check_named_axis: bool (default=True) whether to consider named axes as unequal. TypeTracer arrays are not supported, as there is very little information to be compared. @@ -49,6 +51,7 @@ def array_equal( dtype_exact=dtype_exact, check_parameters=check_parameters, check_regular=check_regular, + check_named_axis=check_named_axis, exact_eq=True, same_content_types=same_content_types and check_regular, equal_nan=equal_nan, diff --git a/src/awkward/operations/ak_broadcast_arrays.py b/src/awkward/operations/ak_broadcast_arrays.py index 877c69f9c0..feef9b5138 100644 --- a/src/awkward/operations/ak_broadcast_arrays.py +++ b/src/awkward/operations/ak_broadcast_arrays.py @@ -2,6 +2,8 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._attrs import attrs_of_obj from awkward._backends.dispatch import backend_of @@ -10,6 +12,11 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import wrap_layout +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + NamedAxesWithDims, + _unify_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("broadcast_arrays",) @@ -243,24 +250,43 @@ def action(inputs, depth, **kwargs): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arrays) out = ak._broadcasting.broadcast_and_apply( inputs, action, + depth_context=depth_context, + lateral_context=lateral_context, left_broadcast=left_broadcast, right_broadcast=right_broadcast, broadcast_parameters_rule=broadcast_parameters_rule, numpy_to_regular=True, ) assert isinstance(out, tuple) - return [ - wrap_layout( + + # unify named axes + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = [] + for layout_out, array_in in zip(out, arrays): + _behavior = behavior_of_obj(array_in, behavior=behavior) + _attrs = attrs_of_obj(array_in, attrs=attrs) + wrapped = wrap_layout( layout_out, - behavior=behavior_of_obj(array_in, behavior=behavior), + behavior=_behavior, highlevel=highlevel, - attrs=attrs_of_obj(array_in, attrs=attrs), + attrs=_attrs, ) - for layout_out, array_in in zip(out, arrays) - ] + wrapped_out.append( + ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=_behavior, + attrs=_attrs, + ) + ) + return wrapped_out @ak._connect.numpy.implements("broadcast_arrays") diff --git a/src/awkward/operations/ak_cartesian.py b/src/awkward/operations/ak_cartesian.py index 91767d27d8..0f46f449c9 100644 --- a/src/awkward/operations/ak_cartesian.py +++ b/src/awkward/operations/ak_cartesian.py @@ -3,11 +3,20 @@ from __future__ import annotations from collections.abc import Mapping +from functools import reduce import awkward as ak from awkward._backends.numpy import NumpyBackend from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + NamedAxesWithDims, + _add_named_axis, + _get_named_axis, + _named_axis_to_positional_axis, + _unify_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -214,7 +223,6 @@ def cartesian( def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: if isinstance(arrays, Mapping): layouts = ensure_same_backend( @@ -226,6 +234,11 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr fields = list(arrays.keys()) array_layouts = dict(zip(fields, layouts)) + # propagate named axis from input to output, + # use strategy "unify" (see: awkward._namedaxis) + out_named_axis = reduce( + _unify_named_axis, map(_get_named_axis, arrays.values()) + ) else: layouts = array_layouts = ensure_same_backend( *( @@ -234,6 +247,15 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr ) ) fields = None + # propagate named axis from input to output, + # use strategy "unify" (see: awkward._namedaxis) + out_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays)) + + # Handle named axis + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(out_named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) + max_ndim = max(layout.minmax_depth[1] for layout in layouts) if with_name is not None: if parameters is None: @@ -262,6 +284,7 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr if nested is None or nested is False: nested = [] elif nested is True: + out_named_axis = _add_named_axis(out_named_axis, 0, max_ndim) if fields is not None: nested = list(fields)[:-1] else: @@ -287,6 +310,8 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr "the 'nested' parameter of cartesian must be integers in " "[0, len(arrays) - 1) for an iterable of arrays" ) + for n in nested: + out_named_axis = _add_named_axis(out_named_axis, n, max_ndim) backend = next((layout.backend for layout in layouts), cpu) if posaxis == 0: @@ -398,16 +423,48 @@ def apply_build_record(inputs, depth, **kwargs): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + list(arrays.values()) if isinstance(arrays, Mapping) else list(arrays) + ) out = ak._broadcasting.broadcast_and_apply( - new_layouts, apply_build_record, right_broadcast=False + new_layouts, + apply_build_record, + depth_context=depth_context, + lateral_context=lateral_context, + right_broadcast=False, ) assert isinstance(out, tuple) and len(out) == 1 result = out[0] + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(result, highlevel=highlevel) + # propagate named axis to output + result = ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) + # Remove surplus dimensions, iterating from smallest to greatest for axis_to_flatten in axes_to_flatten: result = ak.operations.flatten( - result, axis=axis_to_flatten, highlevel=False, behavior=behavior + result, axis=axis_to_flatten, highlevel=highlevel, behavior=behavior ) - return ctx.wrap(result, highlevel=highlevel) + return result + + wrapped_out = ctx.wrap(result, highlevel=highlevel) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_categories.py b/src/awkward/operations/ak_categories.py index cd7f6ccf4c..e723d098da 100644 --- a/src/awkward/operations/ak_categories.py +++ b/src/awkward/operations/ak_categories.py @@ -49,6 +49,16 @@ def action(layout, **kwargs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + ak._do.recursively_apply(layout, action) - return ctx.wrap(output, highlevel=highlevel) + wrapped_out = ctx.wrap(output, highlevel=highlevel) + + # propagate named axis from input to output, + # use strategy "drop all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_combinations.py b/src/awkward/operations/ak_combinations.py index d22708cb4a..284023f2cd 100644 --- a/src/awkward/operations/ak_combinations.py +++ b/src/awkward/operations/ak_combinations.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -214,7 +218,15 @@ def _impl( behavior, attrs, ): - axis = regularize_axis(axis) + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) if with_name is None: pass @@ -223,8 +235,6 @@ def _impl( else: parameters = {**parameters, "__record__": with_name} - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: - layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") out = ak._do.combinations( layout, n, diff --git a/src/awkward/operations/ak_concatenate.py b/src/awkward/operations/ak_concatenate.py index fb8fcf94ae..3e086f7e8c 100644 --- a/src/awkward/operations/ak_concatenate.py +++ b/src/awkward/operations/ak_concatenate.py @@ -2,6 +2,7 @@ from __future__ import annotations +from functools import reduce from itertools import permutations import awkward as ak @@ -9,6 +10,13 @@ from awkward._dispatch import high_level_function from awkward._do import mergeable from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + NamedAxesWithDims, + _get_named_axis, + _named_axis_to_positional_axis, + _unify_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length from awkward._parameters import type_parameters_equal @@ -92,7 +100,6 @@ def _merge_as_union( def _impl(arrays, axis, mergebool, highlevel, behavior, attrs): - axis = regularize_axis(axis) # Simple single-array, axis=0 fast-path if ( # Is an array with a known backend @@ -121,6 +128,15 @@ def _impl(arrays, axis, mergebool, highlevel, behavior, attrs): ) ) + # Handle named axis + merged_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays)) + # Step 1: normalize named axis to positional axis + axis = _named_axis_to_positional_axis(merged_named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) + # Step 2: propagate named axis from input to output, + # use strategy "unify" (see: awkward._namedaxis) + out_named_axis = merged_named_axis + contents = [x for x in content_or_others if isinstance(x, ak.contents.Content)] if len(contents) == 0: raise ValueError("need at least one array to concatenate") @@ -342,11 +358,35 @@ def action(inputs, depth, backend, **kwargs): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + list(arrays) + ) out = ak._broadcasting.broadcast_and_apply( - content_or_others, action, allow_records=True, right_broadcast=False + content_or_others, + action, + depth_context=depth_context, + lateral_context=lateral_context, + allow_records=True, + right_broadcast=False, )[0] + # Unify named axes + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) def _form_has_type(form, type_): diff --git a/src/awkward/operations/ak_corr.py b/src/awkward/operations/ak_corr.py index 74d148831d..e646a43b0f 100644 --- a/src/awkward/operations/ak_corr.py +++ b/src/awkward/operations/ak_corr.py @@ -3,12 +3,14 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, ensure_same_backend, maybe_highlevel_to_lowlevel, ) +from awkward._namedaxis import _get_named_axis, _is_valid_named_axis from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -86,7 +88,10 @@ def corr( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) + if _is_valid_named_axis(axis): + raise NotImplementedError("named axis not yet supported for ak.corr") + + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( @@ -110,7 +115,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr x, weight, axis, - False, + True, mask_identity, highlevel=True, behavior=ctx.behavior, @@ -120,7 +125,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr y, weight, axis, - False, + True, mask_identity, highlevel=True, behavior=ctx.behavior, @@ -184,8 +189,19 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr behavior=ctx.behavior, attrs=ctx.attrs, ) - return ctx.wrap( - maybe_highlevel_to_lowlevel(sumwxy / ufuncs.sqrt(sumwxx * sumwyy)), + + out = sumwxy / ufuncs.sqrt(sumwxx * sumwyy) + + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, + ) diff --git a/src/awkward/operations/ak_count.py b/src/awkward/operations/ak_count.py index 85f43a27ee..f9b8c48481 100644 --- a/src/awkward/operations/ak_count.py +++ b/src/awkward/operations/ak_count.py @@ -5,6 +5,12 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -109,9 +115,26 @@ def count( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Count() out = ak._do.reduce( @@ -122,4 +145,18 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_count_nonzero.py b/src/awkward/operations/ak_count_nonzero.py index 919a6abf22..74a8b23033 100644 --- a/src/awkward/operations/ak_count_nonzero.py +++ b/src/awkward/operations/ak_count_nonzero.py @@ -5,6 +5,12 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -68,7 +74,26 @@ def count_nonzero( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") reducer = ak._reducers.CountNonzero() @@ -81,7 +106,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("count_nonzero") diff --git a/src/awkward/operations/ak_covar.py b/src/awkward/operations/ak_covar.py index a070ac6895..7c8fe930fe 100644 --- a/src/awkward/operations/ak_covar.py +++ b/src/awkward/operations/ak_covar.py @@ -3,12 +3,14 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, ensure_same_backend, maybe_highlevel_to_lowlevel, ) +from awkward._namedaxis import _get_named_axis, _is_valid_named_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -83,7 +85,9 @@ def covar( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) + if _is_valid_named_axis(axis): + raise NotImplementedError("named axis not yet supported for ak.covar") + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( @@ -107,7 +111,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr x, weight, axis, - False, + True, mask_identity, highlevel=True, behavior=None, @@ -117,7 +121,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr y, weight, axis, - False, + True, mask_identity, highlevel=True, behavior=None, @@ -161,8 +165,18 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr behavior=None, attrs=None, ) - return ctx.wrap( - maybe_highlevel_to_lowlevel(sumwxy / sumw), + + out = sumwxy / sumw + + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, ) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, + ) diff --git a/src/awkward/operations/ak_drop_none.py b/src/awkward/operations/ak_drop_none.py index c6c06014db..d81770f78f 100644 --- a/src/awkward/operations/ak_drop_none.py +++ b/src/awkward/operations/ak_drop_none.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -65,10 +69,16 @@ def _drop_none_if_list(layout): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + if axis is None: # if the outer layout is_option, drop_nones without affecting offsets if layout.is_option: @@ -120,4 +130,7 @@ def action(layout, depth, **kwargs): if len(options["none_indexes"]) > 0: out = ak._do.recursively_apply(out, recompute_offsets, depth_context=options) - return ctx.wrap(out, highlevel=highlevel) + return ctx.wrap( + out, + highlevel=highlevel, + ) diff --git a/src/awkward/operations/ak_fill_none.py b/src/awkward/operations/ak_fill_none.py index 89834689cd..fb3dbfd019 100644 --- a/src/awkward/operations/ak_fill_none.py +++ b/src/awkward/operations/ak_fill_none.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -69,8 +73,6 @@ def fill_none(array, value, axis=-1, *, highlevel=True, behavior=None, attrs=Non def _impl(array, value, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: array_layout, value_layout = ensure_same_backend( ctx.unwrap(array, allow_record=True, allow_unknown=False), @@ -84,6 +86,13 @@ def _impl(array, value, axis, highlevel, behavior, attrs): ), ) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + if isinstance(value_layout, ak.record.Record): value_layout = value_layout.array[value_layout.at : value_layout.at + 1] elif isinstance(value_layout, ak.contents.Content): diff --git a/src/awkward/operations/ak_firsts.py b/src/awkward/operations/ak_firsts.py index f67da6dde1..79fba6eb51 100644 --- a/src/awkward/operations/ak_firsts.py +++ b/src/awkward/operations/ak_firsts.py @@ -5,8 +5,13 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("firsts",) @@ -58,10 +63,20 @@ def firsts(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False) - axis = regularize_axis(axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer, not {axis!r}") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _remove_named_axis( + named_axis=named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=False) if maybe_posaxis(layout, axis, 1) == 0: # specialized logic; it's tested in test_0582-propagate-context-in-broadcast_and_apply.py @@ -103,4 +118,17 @@ def action(layout, depth, backend, **kwargs): out = ak._do.recursively_apply(layout, action, numpy_to_regular=True) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_flatten.py b/src/awkward/operations/ak_flatten.py index b246870463..3805d28e71 100644 --- a/src/awkward/operations/ak_flatten.py +++ b/src/awkward/operations/ak_flatten.py @@ -5,6 +5,12 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -173,10 +179,25 @@ def flatten(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) + # Step 2: propagate named axis from input to output, + # if axis == None: use strategy "remove all" (see: awkward._namedaxis) + # if axis == 0: use strategy "keep all" (see: awkward._namedaxis) + # if axis != 0: use strategy "remove one" (see: awkward._namedaxis) + if axis is None: + pass + elif axis == 0 or maybe_posaxis(layout, axis, 1) == 0: + out_named_axis = _keep_named_axis(named_axis, None) + else: + out_named_axis = _remove_named_axis(named_axis, axis, layout.minmax_depth[1]) + if axis is None: out = ak._do.remove_structure(layout, function_name="ak.flatten") assert isinstance(out, tuple) and all( @@ -234,4 +255,27 @@ def apply(layout): out = apply(layout) else: out = ak._do.flatten(layout, axis) - return ctx.wrap(out, highlevel=highlevel) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + # if axis == None: use strategy "remove all" (see: awkward._namedaxis) + if axis is None: + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) + # if axis == 0: use strategy "keep all" (see: awkward._namedaxis) + # if axis != 0: use strategy "remove one" (see: awkward._namedaxis) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_from_regular.py b/src/awkward/operations/ak_from_regular.py index b3f840ef31..9fe2800a2b 100644 --- a/src/awkward/operations/ak_from_regular.py +++ b/src/awkward/operations/ak_from_regular.py @@ -55,7 +55,8 @@ def from_regular(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False) - axis = regularize_axis(axis) + + axis = regularize_axis(axis, none_allowed=True) if axis is None: diff --git a/src/awkward/operations/ak_is_none.py b/src/awkward/operations/ak_is_none.py index 078c86bde6..d9a58a5478 100644 --- a/src/awkward/operations/ak_is_none.py +++ b/src/awkward/operations/ak_is_none.py @@ -5,8 +5,13 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis_up_to, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("is_none",) @@ -41,12 +46,19 @@ def is_none(array, axis=0, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer, not {axis!r}") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + + # Step 2: propagate named axis from input to output, + # use strategy "keep up to" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis_up_to(named_axis, axis, layout.minmax_depth[1]) def action(layout, depth, backend, lateral_context, **kwargs): posaxis = maybe_posaxis(layout, axis, depth) @@ -68,4 +80,16 @@ def action(layout, depth, backend, lateral_context, **kwargs): out = ak._do.recursively_apply(layout, action, numpy_to_regular=True) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_isclose.py b/src/awkward/operations/ak_isclose.py index 8797c36752..d5ff825c61 100644 --- a/src/awkward/operations/ak_isclose.py +++ b/src/awkward/operations/ak_isclose.py @@ -2,9 +2,12 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("isclose",) @@ -70,10 +73,26 @@ def action(inputs, backend, **kwargs): ), ) - out = ak._broadcasting.broadcast_and_apply(layouts, action) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts([a, b]) + out = ak._broadcasting.broadcast_and_apply( + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + ) assert isinstance(out, tuple) and len(out) == 1 - return ctx.wrap(out[0], highlevel=highlevel) + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("isclose") diff --git a/src/awkward/operations/ak_linear_fit.py b/src/awkward/operations/ak_linear_fit.py index 971fea64fe..01ac0f3297 100644 --- a/src/awkward/operations/ak_linear_fit.py +++ b/src/awkward/operations/ak_linear_fit.py @@ -7,7 +7,6 @@ from awkward._layout import HighLevelContext, ensure_same_backend from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis __all__ = ("linear_fit",) @@ -95,8 +94,6 @@ def linear_fit( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -231,4 +228,13 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr if is_scalar: out = out[0] - return ctx.wrap(out, highlevel=highlevel, allow_other=is_scalar) + wrapped_out = ctx.wrap(out, highlevel=highlevel, allow_other=is_scalar) + + # propagate named axis + # use strategy "remove all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_local_index.py b/src/awkward/operations/ak_local_index.py index 2231ac229f..d5e7089dbc 100644 --- a/src/awkward/operations/ak_local_index.py +++ b/src/awkward/operations/ak_local_index.py @@ -5,6 +5,11 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis_up_to, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -88,8 +93,32 @@ def local_index(array, axis=-1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + + # Step 2: propagate named axis from input to output, + # use strategy "keep up to" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis_up_to(named_axis, axis, layout.minmax_depth[1]) + out = ak._do.local_index(layout, axis) - return ctx.wrap(out, highlevel=highlevel) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_mask.py b/src/awkward/operations/ak_mask.py index 54d9a5e04b..b18273047a 100644 --- a/src/awkward/operations/ak_mask.py +++ b/src/awkward/operations/ak_mask.py @@ -2,9 +2,12 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("mask",) @@ -124,8 +127,26 @@ def action(inputs, backend, **kwargs): ctx.unwrap(mask, allow_record=False, primitive_policy="error"), ) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts([array, mask]) out = ak._broadcasting.broadcast_and_apply( - layouts, action, numpy_to_regular=True, right_broadcast=False + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + numpy_to_regular=True, + right_broadcast=False, ) assert isinstance(out, tuple) and len(out) == 1 - return ctx.wrap(out[0], highlevel=highlevel) + + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_max.py b/src/awkward/operations/ak_max.py index a01a0d64c5..319b2c7bed 100644 --- a/src/awkward/operations/ak_max.py +++ b/src/awkward/operations/ak_max.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -142,9 +148,26 @@ def nanmax( def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Max(initial) out = ak._do.reduce( @@ -155,7 +178,21 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("amax") diff --git a/src/awkward/operations/ak_mean.py b/src/awkward/operations/ak_mean.py index fa74a89b61..a9b38ce1f0 100644 --- a/src/awkward/operations/ak_mean.py +++ b/src/awkward/operations/ak_mean.py @@ -3,6 +3,7 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import ( @@ -11,6 +12,10 @@ maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -174,8 +179,6 @@ def nanmean( def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -191,6 +194,13 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + with np.errstate(invalid="ignore", divide="ignore"): if weight is None: sumw = ak.operations.ak_count._impl( @@ -245,14 +255,25 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): if axis is None: if not keepdims: + # remove all dimensions out = out[(0,) * out.ndim] else: if not keepdims: + # remove reduced dimension posaxis = maybe_posaxis(out.layout, axis, 1) out = out[(slice(None, None),) * posaxis + (0,)] - return ctx.wrap( - maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), + highlevel=highlevel, + allow_other=True, + ) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out) or {}), + highlevel=highlevel, + behavior=None, + attrs=None, ) diff --git a/src/awkward/operations/ak_merge_option_of_records.py b/src/awkward/operations/ak_merge_option_of_records.py index c3e1095ba4..17402e77a6 100644 --- a/src/awkward/operations/ak_merge_option_of_records.py +++ b/src/awkward/operations/ak_merge_option_of_records.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -49,10 +53,15 @@ def merge_option_of_records( def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + # First, normalise type-invsible "index-of-records" to "record-of-index" def apply_displace_index(layout, backend, **kwargs): if (layout.is_indexed and not layout.is_option) and layout.content.is_record: diff --git a/src/awkward/operations/ak_merge_union_of_records.py b/src/awkward/operations/ak_merge_union_of_records.py index d523c0b5f8..0094203947 100644 --- a/src/awkward/operations/ak_merge_union_of_records.py +++ b/src/awkward/operations/ak_merge_union_of_records.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import ArrayLike, NumpyMetadata from awkward._regularize import regularize_axis from awkward.errors import AxisError @@ -59,10 +63,15 @@ def merge_union_of_records( def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + def invert_record_union( tags: ArrayLike, index: ArrayLike, contents ) -> ak.contents.RecordArray: diff --git a/src/awkward/operations/ak_min.py b/src/awkward/operations/ak_min.py index 05e583d430..1b9189f740 100644 --- a/src/awkward/operations/ak_min.py +++ b/src/awkward/operations/ak_min.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -142,9 +148,26 @@ def nanmin( def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Min(initial) out = ak._do.reduce( @@ -155,7 +178,21 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("amin") diff --git a/src/awkward/operations/ak_moment.py b/src/awkward/operations/ak_moment.py index 7cac2498ee..2c8e29adb1 100644 --- a/src/awkward/operations/ak_moment.py +++ b/src/awkward/operations/ak_moment.py @@ -3,14 +3,19 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, ensure_same_backend, maybe_highlevel_to_lowlevel, ) +from awkward._namedaxis import ( + AxisName, + _get_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis +from awkward._typing import Mapping __all__ = ("moment",) @@ -22,13 +27,13 @@ def moment( x, n, weight=None, - axis=None, + axis: AxisName = None, *, - keepdims=False, - mask_identity=False, - highlevel=True, - behavior=None, - attrs=None, + keepdims: bool = False, + mask_identity: bool = False, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping | None = None, ): """ Args: @@ -86,9 +91,17 @@ def moment( ) -def _impl(x, n, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - +def _impl( + x, + n, + weight, + axis: AxisName, + keepdims: bool, + mask_identity: bool, + highlevel: bool, + behavior: Mapping | None, + attrs: Mapping | None, +): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -143,8 +156,20 @@ def _impl(x, n, weight, axis, keepdims, mask_identity, highlevel, behavior, attr behavior=ctx.behavior, attrs=ctx.attrs, ) - return ctx.wrap( - maybe_highlevel_to_lowlevel(sumwxn / sumw), + + out = sumwxn / sumw + + # propagate named axis to output + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, + ) diff --git a/src/awkward/operations/ak_nan_to_none.py b/src/awkward/operations/ak_nan_to_none.py index 7dabbfe828..23ef938dbe 100644 --- a/src/awkward/operations/ak_nan_to_none.py +++ b/src/awkward/operations/ak_nan_to_none.py @@ -6,6 +6,7 @@ from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext from awkward._nplikes.numpy_like import NumpyMetadata +from awkward._typing import Mapping __all__ = ("nan_to_none",) @@ -13,7 +14,13 @@ @high_level_function() -def nan_to_none(array, *, highlevel=True, behavior=None, attrs=None): +def nan_to_none( + array, + *, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping | None = None, +): """ Args: array: Array-like data (anything #ak.to_layout recognizes). @@ -35,7 +42,7 @@ def nan_to_none(array, *, highlevel=True, behavior=None, attrs=None): return _impl(array, highlevel, behavior, attrs) -def _impl(array, highlevel, behavior, attrs): +def _impl(array, highlevel: bool, behavior: Mapping | None, attrs: Mapping | None): def action(layout, continuation, backend, **kwargs): if layout.is_numpy and np.issubdtype(layout.dtype, np.floating): mask = backend.nplike.isnan(layout.data) @@ -55,5 +62,6 @@ def action(layout, continuation, backend, **kwargs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + out = ak._do.recursively_apply(layout, action) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_nan_to_num.py b/src/awkward/operations/ak_nan_to_num.py index 4c7472a06f..69e2617c00 100644 --- a/src/awkward/operations/ak_nan_to_num.py +++ b/src/awkward/operations/ak_nan_to_num.py @@ -2,10 +2,14 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata +from awkward._typing import Mapping __all__ = ("nan_to_num",) @@ -15,14 +19,14 @@ @high_level_function() def nan_to_num( array, - copy=True, + copy: bool = True, nan=0.0, posinf=None, neginf=None, *, - highlevel=True, - behavior=None, - attrs=None, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping | None = None, ): """ Args: @@ -52,7 +56,16 @@ def nan_to_num( return _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs) -def _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs): +def _impl( + array, + copy: bool, + nan, + posinf, + neginf, + highlevel: bool, + behavior: Mapping | None, + attrs: Mapping | None, +): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout, nan_layout, posinf_layout, neginf_layout = ensure_same_backend( ctx.unwrap(array), @@ -81,15 +94,19 @@ def _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs): broadcasting_ids = {} broadcasting = [layout] + arrays_to_broadcast = [array] if isinstance(nan_layout, ak.contents.Content): broadcasting_ids[id(nan)] = len(broadcasting) broadcasting.append(nan_layout) + arrays_to_broadcast.append(nan) if isinstance(posinf_layout, ak.contents.Content): broadcasting_ids[id(posinf)] = len(broadcasting) broadcasting.append(posinf_layout) + arrays_to_broadcast.append(posinf) if isinstance(neginf_layout, ak.contents.Content): broadcasting_ids[id(neginf)] = len(broadcasting) broadcasting.append(neginf_layout) + arrays_to_broadcast.append(neginf) if len(broadcasting) == 1: @@ -138,9 +155,29 @@ def action(inputs, backend, **kwargs): else: return None - out = ak._broadcasting.broadcast_and_apply(broadcasting, action) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + arrays_to_broadcast + ) + out = ak._broadcasting.broadcast_and_apply( + broadcasting, + action, + depth_context=depth_context, + lateral_context=lateral_context, + ) assert isinstance(out, tuple) and len(out) == 1 - out = out[0] + + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_num.py b/src/awkward/operations/ak_num.py index ad9b4e746c..705a1e1c63 100644 --- a/src/awkward/operations/ak_num.py +++ b/src/awkward/operations/ak_num.py @@ -5,8 +5,14 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis +from awkward._typing import Mapping from awkward.errors import AxisError __all__ = ("num",) @@ -15,7 +21,14 @@ @high_level_function() -def num(array, axis=1, *, highlevel=True, behavior=None, attrs=None): +def num( + array, + axis=1, + *, + highlevel: bool = True, + behavior: Mapping | None = None, + attrs: Mapping | None = None, +): """ Args: array: Array-like data (anything #ak.to_layout recognizes). @@ -83,13 +96,25 @@ def num(array, axis=1, *, highlevel=True, behavior=None, attrs=None): return _impl(array, axis, highlevel, behavior, attrs) -def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) +def _impl( + array, + axis, + highlevel: bool, + behavior: Mapping | None, + attrs: Mapping | None, +): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer, not {axis!r}") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # use strategy "keep one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) if maybe_posaxis(layout, axis, 1) == 0: index_nplike = layout.backend.index_nplike @@ -109,4 +134,16 @@ def action(layout, depth, **kwargs): out = ak._do.recursively_apply(layout, action, numpy_to_regular=True) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_pad_none.py b/src/awkward/operations/ak_pad_none.py index 34355a8546..17bb3035ac 100644 --- a/src/awkward/operations/ak_pad_none.py +++ b/src/awkward/operations/ak_pad_none.py @@ -5,6 +5,10 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -113,9 +117,15 @@ def pad_none( def _impl(array, target, axis, clip, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + out = ak._do.pad_none(layout, target, axis, clip=clip) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_prod.py b/src/awkward/operations/ak_prod.py index cde898f174..d3d1a050c3 100644 --- a/src/awkward/operations/ak_prod.py +++ b/src/awkward/operations/ak_prod.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -119,9 +125,26 @@ def nanprod( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Prod() out = ak._do.reduce( @@ -132,7 +155,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("prod") diff --git a/src/awkward/operations/ak_ptp.py b/src/awkward/operations/ak_ptp.py index 56daaa6980..6d4beafbd5 100644 --- a/src/awkward/operations/ak_ptp.py +++ b/src/awkward/operations/ak_ptp.py @@ -3,6 +3,7 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import ( @@ -10,6 +11,10 @@ maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -83,10 +88,16 @@ def ptp( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + with np.errstate(invalid="ignore", divide="ignore"): maxi = ak.operations.ak_max._impl( layout, @@ -126,8 +137,18 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): posaxis = maybe_posaxis(out.layout, axis, 1) out = out[(slice(None, None),) * posaxis + (0,)] - return ctx.wrap( - maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), + highlevel=highlevel, + allow_other=True, + ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, ) diff --git a/src/awkward/operations/ak_ravel.py b/src/awkward/operations/ak_ravel.py index 66a3e3a55d..062601eff4 100644 --- a/src/awkward/operations/ak_ravel.py +++ b/src/awkward/operations/ak_ravel.py @@ -75,7 +75,16 @@ def _impl(array, highlevel, behavior, attrs): result = ak._do.mergemany(out) - return ctx.wrap(result, highlevel=highlevel) + wrapped_out = ctx.wrap(result, highlevel=highlevel) + + # propagate named axis to output + # use strategy "remove all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("ravel") diff --git a/src/awkward/operations/ak_real.py b/src/awkward/operations/ak_real.py index 6d52971dab..655e4e8007 100644 --- a/src/awkward/operations/ak_real.py +++ b/src/awkward/operations/ak_real.py @@ -14,10 +14,10 @@ @ak._connect.numpy.implements("real") @high_level_function() -def real(val, highlevel=True, behavior=None, attrs=None): +def real(array, highlevel=True, behavior=None, attrs=None): """ Args: - val : array_like + array : array_like Input array. highlevel (bool, default is True): If True, return an #ak.Array; otherwise, return a low-level #ak.contents.Content subclass. @@ -30,15 +30,15 @@ def real(val, highlevel=True, behavior=None, attrs=None): If the arrays have complex elements, the returned arrays are floats. """ # Dispatch - yield (val,) + yield (array,) # Implementation - return _impl_real(val, highlevel, behavior, attrs) + return _impl(array, highlevel, behavior, attrs) -def _impl_real(val, highlevel, behavior, attrs): +def _impl(array, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: - layout = ctx.unwrap(val, allow_record=False, primitive_policy="error") + layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") out = ak._do.recursively_apply(layout, _action_real) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_singletons.py b/src/awkward/operations/ak_singletons.py index 35f60d5c97..4de6a59151 100644 --- a/src/awkward/operations/ak_singletons.py +++ b/src/awkward/operations/ak_singletons.py @@ -5,8 +5,13 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis +from awkward._namedaxis import ( + _add_named_axis, + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("singletons",) @@ -56,12 +61,21 @@ def singletons(array, axis=0, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer, not {axis!r}") + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + + # Step 2: propagate named axis from input to output, + # use strategy "add one" (see: awkward._namedaxis) + out_named_axis = _add_named_axis( + named_axis, (axis + 1) if axis >= 0 else axis, layout.minmax_depth[1] + ) def action(layout, depth, backend, **kwargs): posaxis = maybe_posaxis(layout, axis, depth) @@ -90,4 +104,16 @@ def action(layout, depth, backend, **kwargs): out = ak._do.recursively_apply(layout, action, numpy_to_regular=True) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_softmax.py b/src/awkward/operations/ak_softmax.py index e86cbe9cf0..b2cb11bff0 100644 --- a/src/awkward/operations/ak_softmax.py +++ b/src/awkward/operations/ak_softmax.py @@ -3,12 +3,17 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -75,10 +80,16 @@ def softmax( def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs): original_axis = axis - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout = ctx.unwrap(x, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) + x = ctx.wrap(x_layout) if maybe_posaxis(x_layout, axis, 1) != maybe_posaxis(x_layout, -1, 1): @@ -97,8 +108,19 @@ def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs): behavior=ctx.behavior, attrs=ctx.attrs, ) - return ctx.wrap( - maybe_highlevel_to_lowlevel(expx / denom), + + out = expx / denom + + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, + ) diff --git a/src/awkward/operations/ak_sort.py b/src/awkward/operations/ak_sort.py index 5e82e91604..0864fc5d98 100644 --- a/src/awkward/operations/ak_sort.py +++ b/src/awkward/operations/ak_sort.py @@ -6,6 +6,10 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -59,11 +63,22 @@ def sort( def _impl(array, axis, ascending, stable, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + out = ak._do.sort(layout, axis, ascending, stable) - return ctx.wrap(out, highlevel=highlevel) + + return ctx.wrap( + out, + highlevel=highlevel, + ) @ak._connect.numpy.implements("sort") diff --git a/src/awkward/operations/ak_std.py b/src/awkward/operations/ak_std.py index 0385032440..7926b341fe 100644 --- a/src/awkward/operations/ak_std.py +++ b/src/awkward/operations/ak_std.py @@ -3,6 +3,7 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import ( @@ -11,6 +12,10 @@ maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -165,8 +170,6 @@ def nanstd( def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -182,6 +185,13 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=True) + with np.errstate(invalid="ignore", divide="ignore"): out = ufuncs.sqrt( ak.operations.ak_var._impl( @@ -215,8 +225,18 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a posaxis = maybe_posaxis(out.layout, axis, 1) out = out[(slice(None, None),) * posaxis + (0,)] - return ctx.wrap( - maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), + highlevel=highlevel, + allow_other=True, + ) + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, ) diff --git a/src/awkward/operations/ak_strings_astype.py b/src/awkward/operations/ak_strings_astype.py index b0834db3a6..479232cf01 100644 --- a/src/awkward/operations/ak_strings_astype.py +++ b/src/awkward/operations/ak_strings_astype.py @@ -82,5 +82,6 @@ def action(layout, **kwargs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + out = ak._do.recursively_apply(layout, action) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_sum.py b/src/awkward/operations/ak_sum.py index f00434083e..ae6a40aef8 100644 --- a/src/awkward/operations/ak_sum.py +++ b/src/awkward/operations/ak_sum.py @@ -6,6 +6,12 @@ from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _get_named_axis, + _keep_named_axis, + _named_axis_to_positional_axis, + _remove_named_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -269,9 +275,26 @@ def nansum( def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + # Step 2: propagate named axis from input to output, + # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) + # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) + out_named_axis = _keep_named_axis(named_axis, None) + if not keepdims: + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.minmax_depth[1], + ) + + axis = regularize_axis(axis, none_allowed=True) + reducer = ak._reducers.Sum() out = ak._do.reduce( @@ -282,7 +305,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): keepdims=keepdims, behavior=ctx.behavior, ) - return ctx.wrap(out, highlevel=highlevel, allow_other=True) + + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) @ak._connect.numpy.implements("sum") diff --git a/src/awkward/operations/ak_to_backend.py b/src/awkward/operations/ak_to_backend.py index f65a2c0a81..8d93e2de94 100644 --- a/src/awkward/operations/ak_to_backend.py +++ b/src/awkward/operations/ak_to_backend.py @@ -17,7 +17,7 @@ def to_backend(array, backend, *, highlevel=True, behavior=None, attrs=None): """ Args: array: Array-like data (anything #ak.to_layout recognizes). - backend (`"cpu"`, `"cuda"`, or `"jax"`): If `"cpu"`, the array structure is + backend (`"cpu"`, `"cuda"`, `"jax"`, or `"typetracer"`): If `"cpu"`, the array structure is recursively copied (if need be) to main memory for use with the default Numpy backend; if `"cuda"`, the structure is copied to the GPU(s) for use with CuPy. If `"jax"`, the structure is diff --git a/src/awkward/operations/ak_to_regular.py b/src/awkward/operations/ak_to_regular.py index b72e48d7c5..ae9f9cc3da 100644 --- a/src/awkward/operations/ak_to_regular.py +++ b/src/awkward/operations/ak_to_regular.py @@ -66,7 +66,7 @@ def to_regular(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") diff --git a/src/awkward/operations/ak_transform.py b/src/awkward/operations/ak_transform.py index 23b4dbfd4e..93a4911914 100644 --- a/src/awkward/operations/ak_transform.py +++ b/src/awkward/operations/ak_transform.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +from functools import reduce import awkward as ak from awkward._backends.numpy import NumpyBackend @@ -15,6 +16,7 @@ ) from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis __all__ = ("transform",) @@ -580,6 +582,17 @@ def action(inputs, **kwargs): f"transformation must return a Content, tuple of Contents, or None, not {type(out)}\n\n{out!r}" ) + if depth_context is None: + depth_context = {} + if lateral_context is None: + lateral_context = {} + assert NAMED_AXIS_KEY not in depth_context + assert NAMED_AXIS_KEY not in lateral_context + _depth_context, _lateral_context = NamedAxesWithDims.prepare_contexts( + [array, *more_arrays] + ) + depth_context.update(_depth_context) + lateral_context.update(_lateral_context) backend = next((layout.backend for layout in layouts), cpu) isscalar = [] out = apply_broadcasting_step( @@ -594,6 +607,11 @@ def action(inputs, **kwargs): assert isinstance(out, tuple) out = [broadcast_unpack(x, isscalar) for x in out] + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + if return_value == "none": return elif expect_return_value and not transformer_did_terminate: @@ -602,6 +620,25 @@ def action(inputs, **kwargs): "or tuple of Contents, but instead only returned None." ) elif len(out) == 1: - return ctx.wrap(out[0], highlevel=highlevel) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) else: - return tuple(ctx.wrap(x, highlevel=highlevel) for x in out) + wrapped_out = [] + for x in out: + wrapped = ctx.wrap(x, highlevel=highlevel) + wrapped_out.append( + ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) + ) + return tuple(wrapped_out) diff --git a/src/awkward/operations/ak_unflatten.py b/src/awkward/operations/ak_unflatten.py index 78c2631e31..83a3b8f2b4 100644 --- a/src/awkward/operations/ak_unflatten.py +++ b/src/awkward/operations/ak_unflatten.py @@ -6,6 +6,10 @@ from awkward._backends.numpy import NumpyBackend from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length from awkward._nplikes.typetracer import is_unknown_scalar @@ -91,8 +95,6 @@ def unflatten(array, counts, axis=0, *, highlevel=True, behavior=None, attrs=Non def _impl(array, counts, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout, maybe_counts_layout = ensure_same_backend( ctx.unwrap(array, allow_record=False, primitive_policy="error"), @@ -105,6 +107,13 @@ def _impl(array, counts, axis, highlevel, behavior, attrs): ), ) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + if is_integer_like(maybe_counts_layout): # Regularize unknown values to unknown lengths if ( @@ -292,4 +301,16 @@ def apply(layout, depth, backend, **kwargs): f"at axis={axis}" ) - return ctx.wrap(out, highlevel=highlevel) + wrapped_out = ctx.wrap( + out, + highlevel=highlevel, + ) + + # Step 2: propagate named axis from input to output, + # use strategy "remove all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( + wrapped_out, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_unzip.py b/src/awkward/operations/ak_unzip.py index 8d0bfc229a..8c19380133 100644 --- a/src/awkward/operations/ak_unzip.py +++ b/src/awkward/operations/ak_unzip.py @@ -51,6 +51,7 @@ def unzip(array, *, highlevel=True, behavior=None, attrs=None): def _impl(array, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=True, primitive_policy="error") + fields = ak.operations.fields(layout) def check_for_union(layout, **kwargs): @@ -70,5 +71,10 @@ def check_for_union(layout, **kwargs): return (ctx.wrap(layout, highlevel=highlevel, allow_other=True),) else: return tuple( - ctx.wrap(layout[n], highlevel=highlevel, allow_other=True) for n in fields + ctx.wrap( + layout[n], + highlevel=highlevel, + allow_other=True, + ) + for n in fields ) diff --git a/src/awkward/operations/ak_values_astype.py b/src/awkward/operations/ak_values_astype.py index 714a4320d9..fa25ca5a35 100644 --- a/src/awkward/operations/ak_values_astype.py +++ b/src/awkward/operations/ak_values_astype.py @@ -72,6 +72,7 @@ def values_astype( def _impl(array, to, including_unknown, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + to_str = ak.types.numpytype.dtype_to_primitive(np.dtype(to)) out = ak._do.numbers_to_type(layout, to_str, including_unknown) return ctx.wrap(out, highlevel=highlevel) diff --git a/src/awkward/operations/ak_var.py b/src/awkward/operations/ak_var.py index d1139d8b4c..759f5edf1c 100644 --- a/src/awkward/operations/ak_var.py +++ b/src/awkward/operations/ak_var.py @@ -3,6 +3,7 @@ from __future__ import annotations import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import ( @@ -11,6 +12,10 @@ maybe_highlevel_to_lowlevel, maybe_posaxis, ) +from awkward._namedaxis import ( + _get_named_axis, + _named_axis_to_positional_axis, +) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -170,8 +175,6 @@ def nanvar( def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), @@ -187,6 +190,12 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) + # Handle named axis + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) + with np.errstate(invalid="ignore", divide="ignore"): if weight is None: sumw = ak.operations.ak_count._impl( @@ -267,8 +276,19 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a posaxis = maybe_posaxis(out.layout, axis, 1) out = out[(slice(None, None),) * posaxis + (0,)] - return ctx.wrap( - maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True + wrapped = ctx.wrap( + maybe_highlevel_to_lowlevel(out), + highlevel=highlevel, + allow_other=True, + ) + + # propagate named axis to output + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=_get_named_axis(attrs_of_obj(out)), + highlevel=highlevel, + behavior=None, + attrs=None, ) diff --git a/src/awkward/operations/ak_where.py b/src/awkward/operations/ak_where.py index dda7d99f42..07f5f2f7bd 100644 --- a/src/awkward/operations/ak_where.py +++ b/src/awkward/operations/ak_where.py @@ -2,9 +2,12 @@ from __future__ import annotations +from functools import reduce + import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("where",) @@ -121,8 +124,26 @@ def action(inputs, backend, **kwargs): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + [x, y, condition] + ) out = ak._broadcasting.broadcast_and_apply( - layouts, action, numpy_to_regular=True, function_name="ak.where" + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + numpy_to_regular=True, + function_name="ak.where", + ) + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, ) - - return ctx.wrap(out[0], highlevel=highlevel) diff --git a/src/awkward/operations/ak_with_field.py b/src/awkward/operations/ak_with_field.py index 671a061978..3adb5c33a1 100644 --- a/src/awkward/operations/ak_with_field.py +++ b/src/awkward/operations/ak_with_field.py @@ -3,10 +3,12 @@ from __future__ import annotations import copy +from functools import reduce import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_non_string_like_sequence @@ -76,6 +78,11 @@ def _impl(base, what, where, highlevel, behavior, attrs): if is_non_string_like_sequence(where): where = where[0] + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + [base, what], + unwrap_kwargs={"none_policy": "promote"}, + ) + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: base, what = ensure_same_backend( ctx.unwrap(base, allow_record=True, primitive_policy="error"), @@ -156,9 +163,24 @@ def action(inputs, **kwargs): return None out = ak._broadcasting.broadcast_and_apply( - [base, what], action, right_broadcast=False + [base, what], + action, + depth_context=depth_context, + lateral_context=lateral_context, + right_broadcast=False, ) assert isinstance(out, tuple) and len(out) == 1 - return ctx.wrap(out[0], highlevel=highlevel) + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out[0], highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/ak_with_named_axis.py b/src/awkward/operations/ak_with_named_axis.py new file mode 100644 index 0000000000..507acc485c --- /dev/null +++ b/src/awkward/operations/ak_with_named_axis.py @@ -0,0 +1,72 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +from awkward._dispatch import high_level_function +from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + NAMED_AXIS_KEY, + AxisMapping, + AxisTuple, + _prepare_named_axis_for_attrs, +) +from awkward._nplikes.numpy_like import NumpyMetadata + +__all__ = ("with_named_axis",) + +np = NumpyMetadata.instance() + + +@high_level_function() +def with_named_axis( + array, + named_axis: AxisTuple | AxisMapping, + *, + highlevel=True, + behavior=None, + attrs=None, +): + """ + Args: + array: Array-like data (anything #ak.to_layout recognizes). + named_axis: AxisTuple | AxisMapping: Names to give to the array axis; this assigns + the `"__named_axis__"` attr. If None, any existing name is unset. + highlevel (bool): If True, return an #ak.Array; otherwise, return + a low-level #ak.contents.Content subclass. + behavior (None or dict): Custom #ak.behavior for the output array, if + high-level. + attrs (None or dict): Custom attributes for the output array, if + high-level. + + Returns an #ak.Array or #ak.Record (or low-level equivalent, if + `highlevel=False`) with a new name. This function does not change the + array in-place. If the new name is None, then the array is returned as it is. + """ + # Dispatch + yield (array,) + + # Implementation + return _impl(array, named_axis, highlevel, behavior, attrs) + + +def _impl(array, named_axis, highlevel, behavior, attrs): + # Named axis handling + if not named_axis: # no-op, e.g. named_axis is None, (), {} + return array + + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=True) + + _named_axis = _prepare_named_axis_for_attrs( + named_axis=named_axis, + ndim=layout.minmax_depth[1], + ) + # now we're good, set the named axis + return ctx.with_attr( + key=NAMED_AXIS_KEY, + value=_named_axis, + ).wrap( + layout, + highlevel=highlevel, + allow_other=True, + ) diff --git a/src/awkward/operations/ak_without_named_axis.py b/src/awkward/operations/ak_without_named_axis.py new file mode 100644 index 0000000000..3697344a4b --- /dev/null +++ b/src/awkward/operations/ak_without_named_axis.py @@ -0,0 +1,54 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +from awkward._dispatch import high_level_function +from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + NAMED_AXIS_KEY, +) +from awkward._nplikes.numpy_like import NumpyMetadata + +__all__ = ("without_named_axis",) + +np = NumpyMetadata.instance() + + +@high_level_function() +def without_named_axis( + array, + *, + highlevel=True, + behavior=None, + attrs=None, +): + """ + Args: + array: Array-like data (anything #ak.to_layout recognizes). + highlevel (bool): If True, return an #ak.Array; otherwise, return + a low-level #ak.contents.Content subclass. + behavior (None or dict): Custom #ak.behavior for the output array, if + high-level. + attrs (None or dict): Custom attributes for the output array, if + high-level. + + Returns an #ak.Array or #ak.Record (or low-level equivalent, if + `highlevel=False`) without named axes. This function does not change the + array in-place. + """ + # Dispatch + yield (array,) + + # Implementation + return _impl(array, highlevel, behavior, attrs) + + +def _impl(array, highlevel, behavior, attrs): + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=True) + + return ctx.without_attr(key=NAMED_AXIS_KEY).wrap( + layout, + highlevel=highlevel, + allow_other=True, + ) diff --git a/src/awkward/operations/ak_zip.py b/src/awkward/operations/ak_zip.py index bed5c233e5..5ce58f8b1a 100644 --- a/src/awkward/operations/ak_zip.py +++ b/src/awkward/operations/ak_zip.py @@ -3,10 +3,12 @@ from __future__ import annotations from collections.abc import Mapping +from functools import reduce import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("zip",) @@ -174,6 +176,7 @@ def _impl( ): if depth_limit is not None and depth_limit <= 0: raise ValueError("depth_limit must be None or at least 1") + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: if isinstance(arrays, Mapping): layouts = ensure_same_backend( @@ -238,8 +241,15 @@ def action(inputs, depth, backend, **ignore): else: return None + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + list(arrays.values()) if isinstance(arrays, Mapping) else list(arrays) + ) out = ak._broadcasting.broadcast_and_apply( - layouts, action, right_broadcast=right_broadcast + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + right_broadcast=right_broadcast, ) assert isinstance(out, tuple) and len(out) == 1 out = out[0] @@ -248,4 +258,15 @@ def action(inputs, depth, backend, **ignore): out = out[0] assert isinstance(out, ak.record.Record) - return ctx.wrap(out, highlevel=highlevel) + # Unify named axes propagated through the broadcast + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped_out = ctx.wrap(out, highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped_out, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=ctx.behavior, + attrs=ctx.attrs, + ) diff --git a/src/awkward/operations/str/akstr_join.py b/src/awkward/operations/str/akstr_join.py index a5ab638ba5..d18dc0174e 100644 --- a/src/awkward/operations/str/akstr_join.py +++ b/src/awkward/operations/str/akstr_join.py @@ -2,10 +2,15 @@ from __future__ import annotations +from functools import reduce + import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._backends.typetracer import TypeTracerBackend +from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis __all__ = ("join",) @@ -95,6 +100,7 @@ def apply_unary(layout, **kwargs): ) out = ak._do.recursively_apply(layout, apply_unary) + return ctx.wrap(out, highlevel=highlevel) else: def apply_binary(layouts, **kwargs): @@ -123,8 +129,24 @@ def apply_binary(layouts, **kwargs): ), ) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + [array, separator] + ) (out,) = ak._broadcasting.broadcast_and_apply( - (layout, maybe_separator_layout), apply_binary + (layout, maybe_separator_layout), + apply_binary, + depth_context=depth_context, + lateral_context=lateral_context, ) - return ctx.wrap(out, highlevel=highlevel) + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped = ctx.wrap(out, highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), + ) diff --git a/src/awkward/operations/str/akstr_join_element_wise.py b/src/awkward/operations/str/akstr_join_element_wise.py index 98f4e42f91..cd2ed0184f 100644 --- a/src/awkward/operations/str/akstr_join_element_wise.py +++ b/src/awkward/operations/str/akstr_join_element_wise.py @@ -2,10 +2,15 @@ from __future__ import annotations +from functools import reduce + import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._backends.typetracer import TypeTracerBackend +from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis __all__ = ("join_element_wise",) @@ -66,6 +71,22 @@ def action(layouts, **kwargs): ): return (_apply_through_arrow(pc.binary_join_element_wise, *layouts),) - (out,) = ak._broadcasting.broadcast_and_apply(layouts, action) - - return ctx.wrap(out, highlevel=highlevel) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arrays) + (out,) = ak._broadcasting.broadcast_and_apply( + layouts, + action, + depth_context=depth_context, + lateral_context=lateral_context, + ) + + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped = ctx.wrap(out, highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), + ) diff --git a/src/awkward/operations/str/akstr_repeat.py b/src/awkward/operations/str/akstr_repeat.py index de929c57b7..49bec96569 100644 --- a/src/awkward/operations/str/akstr_repeat.py +++ b/src/awkward/operations/str/akstr_repeat.py @@ -3,11 +3,15 @@ from __future__ import annotations import numbers +from functools import reduce import awkward as ak +from awkward._attrs import attrs_of_obj from awkward._backends.typetracer import TypeTracerBackend +from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, ensure_same_backend +from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("repeat",) @@ -79,8 +83,26 @@ def action(inputs, **kwargs): return (_apply_through_arrow(pc.binary_repeat, *inputs),) + depth_context, lateral_context = NamedAxesWithDims.prepare_contexts( + [array, num_repeats] + ) (out,) = ak._broadcasting.broadcast_and_apply( - (layout, num_repeats_layout), action + (layout, num_repeats_layout), + action, + depth_context=depth_context, + lateral_context=lateral_context, + ) + + out_named_axis = reduce( + _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis + ) + wrapped = ctx.wrap(out, highlevel=highlevel) + return ak.operations.ak_with_named_axis._impl( + wrapped, + named_axis=out_named_axis, + highlevel=highlevel, + behavior=behavior_of_obj(wrapped), + attrs=attrs_of_obj(wrapped), ) else: @@ -98,4 +120,4 @@ def action(layout, **kwargs): out = ak._do.recursively_apply(layout, action) - return ctx.wrap(out, highlevel=highlevel) + return ctx.wrap(out, highlevel=highlevel) diff --git a/tests/test_2596_named_axis.py b/tests/test_2596_named_axis.py new file mode 100644 index 0000000000..acfa1e9e34 --- /dev/null +++ b/tests/test_2596_named_axis.py @@ -0,0 +1,2243 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import sys + +import numpy as np +import pytest + +import awkward as ak +from awkward._namedaxis import _get_named_axis + + +def test_constructor(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} + assert array.positional_axis == (0, 1) + + array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis={"x": 0, "y": 1}) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} + assert array.positional_axis == (0, 1) + + +def test_with_named_axis(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + assert not _get_named_axis(array) + assert array.named_axis == {} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis=("x", "y")) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis=("x", None)) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis=(None, "x")) + assert _get_named_axis(array) + assert array.named_axis == {"x": 1} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis={"x": 0, "y": 1}) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis={"x": 1}) + assert _get_named_axis(array) + assert array.named_axis == {"x": 1} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis={"y": -1}) + assert _get_named_axis(array) + assert array.named_axis == {"y": -1} + assert array.positional_axis == (0, 1) + + # This is possible in a future version of named axis, but currently only strings are supported + # from dataclasses import dataclass + + # @dataclass(frozen=True) + # class exotic_axis: + # attr: str + + # ax1 = exotic_axis(attr="I'm not the type of axis that you're used to") + # ax2 = exotic_axis(attr="...me neither!") + + # array = ak.with_named_axis(array, named_axis=(ax1, ax2)) + # assert array.named_axis == (ax1, ax2) + # assert array.positional_axis == (0, 1) + + +def test_named_axis_indexing(): + array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + + # test indexing + assert ak.all(array[...] == named_array[...]) + assert ak.all(array[()] == named_array[()]) + + assert ak.all(array[None, :, :, :] == named_array[None, :, :, :]) + assert ak.all(array[:, None, :, :] == named_array[:, None, :, :]) + assert ak.all(array[:, :, None, :] == named_array[:, :, None, :]) + assert ak.all(array[:, :, :, None] == named_array[:, :, :, None]) + + assert ak.all(array[0, :, :] == named_array[{"x": 0}]) + assert ak.all(array[:, 0, :] == named_array[{"y": 0}]) + assert ak.all(array[:, :, 0] == named_array[{"z": 0}]) + + assert ak.all(array[0, :, :] == named_array[{0: 0}]) + assert ak.all(array[:, 0, :] == named_array[{1: 0}]) + assert ak.all(array[:, :, 0] == named_array[{2: 0}]) + + assert ak.all(array[0, :, :] == named_array[{-3: 0}]) + assert ak.all(array[:, 0, :] == named_array[{-2: 0}]) + assert ak.all(array[:, :, 0] == named_array[{-1: 0}]) + + assert ak.all(array[0, 0, :] == named_array[{"x": 0, "y": 0}]) + assert ak.all(array[0, :, 0] == named_array[{"x": 0, "z": 0}]) + assert ak.all(array[:, 0, 0] == named_array[{"y": 0, "z": 0}]) + assert array[0, 0, 0] == named_array[{"x": 0, "y": 0, "z": 0}] + + assert ak.all(array[slice(0, 1), :, :] == named_array[{"x": slice(0, 1)}]) + assert ak.all(array[:, slice(0, 1), :] == named_array[{"y": slice(0, 1)}]) + assert ak.all(array[:, :, slice(0, 1)] == named_array[{"z": slice(0, 1)}]) + + assert ak.all(array[0, :, slice(0, 1)] == named_array[{"x": 0, "z": slice(0, 1)}]) + assert ak.all(array[:, 0, slice(0, 1)] == named_array[{"y": 0, "z": slice(0, 1)}]) + assert ak.all(array[slice(0, 1), 0, :] == named_array[{"x": slice(0, 1), "y": 0}]) + + assert ak.all(array[array > 3] == named_array[named_array > 3]) + + # test naming propagation + assert ( + named_array[...].named_axis + == named_array.named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[()].named_axis == named_array.named_axis == {"x": 0, "y": 1, "z": 2} + ) + + assert named_array[None, :, :, :].named_axis == {"x": 1, "y": 2, "z": 3} + assert named_array[:, None, :, :].named_axis == {"x": 0, "y": 2, "z": 3} + assert named_array[:, :, None, :].named_axis == {"x": 0, "y": 1, "z": 3} + assert named_array[:, :, :, None].named_axis == {"x": 0, "y": 1, "z": 2} + + assert named_array[None, ...].named_axis == {"x": 1, "y": 2, "z": 3} + assert named_array[:, None, ...].named_axis == {"x": 0, "y": 2, "z": 3} + assert named_array[..., None, :].named_axis == {"x": 0, "y": 1, "z": 3} + assert named_array[..., None].named_axis == {"x": 0, "y": 1, "z": 2} + + assert ( + named_array[0, :, :].named_axis + == named_array[{"x": 0}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, :].named_axis + == named_array[{"y": 0}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[:, :, 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": 0, "y": 1} + ) + + assert ( + named_array[0, ...].named_axis + == named_array[{"x": 0}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, :].named_axis + == named_array[{"y": 0}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[..., 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": 0, "y": 1} + ) + + assert named_array[{0: 0}].named_axis == {"y": 0, "z": 1} + assert named_array[{1: 0}].named_axis == {"x": 0, "z": 1} + assert named_array[{2: 0}].named_axis == {"x": 0, "y": 1} + + assert named_array[{-3: 0}].named_axis == {"y": 0, "z": 1} + assert named_array[{-2: 0}].named_axis == {"x": 0, "z": 1} + assert named_array[{-1: 0}].named_axis == {"x": 0, "y": 1} + + assert ( + named_array[0, 0, :].named_axis + == named_array[{"x": 0, "y": 0}].named_axis + == {"z": 0} + ) + assert ( + named_array[0, :, 0].named_axis + == named_array[{"x": 0, "z": 0}].named_axis + == {"y": 0} + ) + assert ( + named_array[:, 0, 0].named_axis + == named_array[{"y": 0, "z": 0}].named_axis + == {"x": 0} + ) + assert not _get_named_axis(named_array[0, 0, 0]) + assert not _get_named_axis(named_array[{"x": 0, "y": 0, "z": 0}]) + + assert ( + named_array[slice(0, 1), :, :].named_axis + == named_array[{"x": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[:, slice(0, 1), :].named_axis + == named_array[{"y": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[:, :, slice(0, 1)].named_axis + == named_array[{"z": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + + assert ( + named_array[0, :, slice(0, 1)].named_axis + == named_array[{"x": 0, "z": slice(0, 1)}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, slice(0, 1)].named_axis + == named_array[{"y": 0, "z": slice(0, 1)}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[slice(0, 1), 0, :].named_axis + == named_array[{"x": slice(0, 1), "y": 0}].named_axis + == {"x": 0, "z": 1} + ) + + +def test_negative_named_axis_indexing(): + array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1}) + + # test indexing + assert ak.all(array[...] == named_array[...]) + assert ak.all(array[()] == named_array[()]) + + assert ak.all(array[None, :, :, :] == named_array[None, :, :, :]) + assert ak.all(array[:, None, :, :] == named_array[:, None, :, :]) + assert ak.all(array[:, :, None, :] == named_array[:, :, None, :]) + assert ak.all(array[:, :, :, None] == named_array[:, :, :, None]) + + assert ak.all(array[0, :, :] == named_array[{"x": 0}]) + assert ak.all(array[:, 0, :] == named_array[{"y": 0}]) + assert ak.all(array[:, :, 0] == named_array[{"z": 0}]) + + assert ak.all(array[0, :, :] == named_array[{0: 0}]) + assert ak.all(array[:, 0, :] == named_array[{1: 0}]) + assert ak.all(array[:, :, 0] == named_array[{2: 0}]) + + assert ak.all(array[0, :, :] == named_array[{-3: 0}]) + assert ak.all(array[:, 0, :] == named_array[{-2: 0}]) + assert ak.all(array[:, :, 0] == named_array[{-1: 0}]) + + assert ak.all(array[0, 0, :] == named_array[{"x": 0, "y": 0}]) + assert ak.all(array[0, :, 0] == named_array[{"x": 0, "z": 0}]) + assert ak.all(array[:, 0, 0] == named_array[{"y": 0, "z": 0}]) + assert array[0, 0, 0] == named_array[{"x": 0, "y": 0, "z": 0}] + + assert ak.all(array[slice(0, 1), :, :] == named_array[{"x": slice(0, 1)}]) + assert ak.all(array[:, slice(0, 1), :] == named_array[{"y": slice(0, 1)}]) + assert ak.all(array[:, :, slice(0, 1)] == named_array[{"z": slice(0, 1)}]) + + assert ak.all(array[0, :, slice(0, 1)] == named_array[{"x": 0, "z": slice(0, 1)}]) + assert ak.all(array[:, 0, slice(0, 1)] == named_array[{"y": 0, "z": slice(0, 1)}]) + assert ak.all(array[slice(0, 1), 0, :] == named_array[{"x": slice(0, 1), "y": 0}]) + + assert ak.all(array[array > 3] == named_array[named_array > 3]) + + # test naming propagation + assert ( + named_array[...].named_axis + == named_array.named_axis + == {"x": -3, "y": -2, "z": -1} + ) + assert ( + named_array[()].named_axis + == named_array.named_axis + == {"x": -3, "y": -2, "z": -1} + ) + + assert named_array[None, :, :, :].named_axis == {"x": -3, "y": -2, "z": -1} + assert named_array[:, None, :, :].named_axis == {"x": -4, "y": -2, "z": -1} + assert named_array[:, :, None, :].named_axis == {"x": -4, "y": -3, "z": -1} + assert named_array[:, :, :, None].named_axis == {"x": -4, "y": -3, "z": -2} + + assert named_array[None, ...].named_axis == {"x": -3, "y": -2, "z": -1} + assert named_array[:, None, ...].named_axis == {"x": -4, "y": -2, "z": -1} + assert named_array[..., None, :].named_axis == {"x": -4, "y": -3, "z": -1} + assert named_array[..., None].named_axis == {"x": -4, "y": -3, "z": -2} + + assert ( + named_array[0, :, :].named_axis + == named_array[{"x": 0}].named_axis + == {"y": -2, "z": -1} + ) + assert ( + named_array[:, 0, :].named_axis + == named_array[{"y": 0}].named_axis + == {"x": -2, "z": -1} + ) + assert ( + named_array[:, :, 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": -2, "y": -1} + ) + + assert ( + named_array[0, ...].named_axis + == named_array[{"x": 0}].named_axis + == {"y": -2, "z": -1} + ) + assert ( + named_array[..., 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": -2, "y": -1} + ) + + assert named_array[{0: 0}].named_axis == {"y": -2, "z": -1} + assert named_array[{1: 0}].named_axis == {"x": -2, "z": -1} + assert named_array[{2: 0}].named_axis == {"x": -2, "y": -1} + + assert named_array[{-3: 0}].named_axis == {"y": -2, "z": -1} + assert named_array[{-2: 0}].named_axis == {"x": -2, "z": -1} + assert named_array[{-1: 0}].named_axis == {"x": -2, "y": -1} + + assert ( + named_array[0, 0, :].named_axis + == named_array[{"x": 0, "y": 0}].named_axis + == {"z": -1} + ) + assert ( + named_array[0, :, 0].named_axis + == named_array[{"x": 0, "z": 0}].named_axis + == {"y": -1} + ) + assert ( + named_array[:, 0, 0].named_axis + == named_array[{"y": 0, "z": 0}].named_axis + == {"x": -1} + ) + assert not _get_named_axis(named_array[0, 0, 0]) + assert not _get_named_axis(named_array[{"x": 0, "y": 0, "z": 0}]) + + assert ( + named_array[slice(0, 1), :, :].named_axis + == named_array[{"x": slice(0, 1)}].named_axis + == {"x": -3, "y": -2, "z": -1} + ) + assert ( + named_array[:, slice(0, 1), :].named_axis + == named_array[{"y": slice(0, 1)}].named_axis + == {"x": -3, "y": -2, "z": -1} + ) + assert ( + named_array[:, :, slice(0, 1)].named_axis + == named_array[{"z": slice(0, 1)}].named_axis + == {"x": -3, "y": -2, "z": -1} + ) + + assert ( + named_array[0, :, slice(0, 1)].named_axis + == named_array[{"x": 0, "z": slice(0, 1)}].named_axis + == {"y": -2, "z": -1} + ) + assert ( + named_array[:, 0, slice(0, 1)].named_axis + == named_array[{"y": 0, "z": slice(0, 1)}].named_axis + == {"x": -2, "z": -1} + ) + assert ( + named_array[slice(0, 1), 0, :].named_axis + == named_array[{"x": slice(0, 1), "y": 0}].named_axis + == {"x": -2, "z": -1} + ) + + +@pytest.mark.xfail( + sys.platform == "win32", + reason="right-broadcasting (NumPy-style) behaves differently for 32-bit windows", + strict=False, +) +def test_named_axis_right_broadcasting(): + # [NumPy-style] rightbroadcasting: (n, m) -> (1, n, m) + a = ak.Array([1]) # (1,) + b = ak.Array([[10, 20], [30, 40], [50, 60]]) # (3, 2) + + na = ak.with_named_axis(a, named_axis={"y": 0}) + nb = ak.with_named_axis(b, named_axis={"x": 0, "y": 1}) + + naa, nbb = ak.broadcast_arrays(na, nb) + + assert naa.named_axis == nbb.named_axis == {"x": 0, "y": 1} + + na = ak.with_named_axis(a, named_axis={"y": -1}) + nb = ak.with_named_axis(b, named_axis={"y": -2, "x": -1}) + + naa, nbb = ak.broadcast_arrays(na, nb) + + assert naa.named_axis == nbb.named_axis == {"y": -2, "x": -1} + + +def test_named_axis_left_broadcasting(): + # [Awkward-style] leftbroadcasting: (n, m) -> (n, m, 1) + a = ak.Array([[[0, 1, 2], [], [3, 4]], [], [[5], [6, 7, 8, 9]]]) # (3, var, var) + b = ak.Array([[10, 20, 30], [], [40, 50]]) # (3, var) + + na = ak.with_named_axis(a, named_axis=("x", "y", "z")) + nb = ak.with_named_axis(b, named_axis=("x", "y")) + + naa, nbb = ak.broadcast_arrays(na, nb) + + assert naa.named_axis == nbb.named_axis == {"x": 0, "y": 1, "z": 2} + + na = ak.with_named_axis(a, named_axis={"x": -3, "y": -2, "z": -1}) + nb = ak.with_named_axis(b, named_axis={"x": -2, "y": -1}) + + naa, nbb = ak.broadcast_arrays(na, nb) + + assert naa.named_axis == nbb.named_axis == {"x": -3, "y": -2, "z": -1} + + # this is not allowed! + a = ak.with_named_axis(ak.Array([[1, 2], [3, 4]]), ("x", "y")) # {"x": 0, "y": 1} + asum = ak.sum(a, axis="x") # {"y": 0} + + with pytest.raises(ValueError): + _ = a + asum + + # this is allowed! + a = ak.with_named_axis(ak.Array([[1, 2], [3, 4]]), ("x", "y")) # {"x": 0, "y": 1} + asum = ak.sum(a, axis="y") # {"x": 0} + + assert (a + asum).named_axis == {"x": 0, "y": 1} + + +def test_named_axis_unary_ufuncs(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert (-named_array).named_axis == named_array.named_axis + assert (+named_array).named_axis == named_array.named_axis + assert (~named_array).named_axis == named_array.named_axis + assert abs(named_array).named_axis == named_array.named_axis + + +def test_named_axis_binary_ufuncs(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = ak.with_named_axis(array, named_axis=(None, "y")) + named_array2 = ak.with_named_axis(array, named_axis=("x", None)) + named_array3 = ak.with_named_axis(array, named_axis=("x", "y")) + + # just for addition, the rest is the same + # __add__ + assert (array + array).named_axis == {} + assert (named_array1 + array).named_axis == {"y": 1} + assert (named_array2 + array).named_axis == {"x": 0} + assert (named_array3 + array).named_axis == {"x": 0, "y": 1} + + assert (named_array1 + named_array2).named_axis == {"x": 0, "y": 1} + assert (named_array3 + named_array3).named_axis == {"x": 0, "y": 1} + + # __radd__ + assert (array + named_array1).named_axis == {"y": 1} + assert (array + named_array2).named_axis == {"x": 0} + assert (array + named_array3).named_axis == {"x": 0, "y": 1} + + a = ak.with_named_axis(array, named_axis=("x", None)) + b = ak.with_named_axis(array, named_axis=("y", None)) + with pytest.raises( + ValueError, + match="The named axes are incompatible. Got: x and y for positional axis 0", + ): + _ = a + b + + a = ak.with_named_axis(array, named_axis=(None, "x")) + b = ak.with_named_axis(array, named_axis=(None, "y")) + with pytest.raises( + ValueError, + match="The named axes are incompatible. Got: x and y for positional axis 1", + ): + _ = a + b + + +def test_named_axis_ak_all(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.all(array < 4, axis=0) == ak.all(named_array < 4, axis="x")) + assert ak.all(ak.all(array < 4, axis=1) == ak.all(named_array < 4, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.all(named_array < 4, axis=0).named_axis + == ak.all(named_array < 4, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.all(named_array < 4, axis=1).named_axis + == ak.all(named_array < 4, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.all(named_array < 4, axis=0, keepdims=True).named_axis + == ak.all(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.all(named_array < 4, axis=1, keepdims=True).named_axis + == ak.all(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) + + +def test_negative_named_axis_ak_all(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.all(array < 4, axis=-2) == ak.all(named_array < 4, axis="x")) + assert ak.all(ak.all(array < 4, axis=-1) == ak.all(named_array < 4, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.all(named_array < 4, axis=-2).named_axis + == ak.all(named_array < 4, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.all(named_array < 4, axis=-1).named_axis + == ak.all(named_array < 4, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.all(named_array < 4, axis=-2, keepdims=True).named_axis + == ak.all(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.all(named_array < 4, axis=-1, keepdims=True).named_axis + == ak.all(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) + + +def test_named_axis_ak_almost_equal(): + array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = named_array2 = ak.with_named_axis(array1, named_axis=("x", "y")) + + assert ak.almost_equal(array1, array2, check_named_axis=False) == ak.almost_equal( + named_array1, named_array2, check_named_axis=False + ) + assert ak.almost_equal(array1, array2, check_named_axis=True) == ak.almost_equal( + named_array1, named_array2, check_named_axis=True + ) + + assert ak.almost_equal(named_array1, array1, check_named_axis=False) + assert ak.almost_equal(named_array1, array1, check_named_axis=True) + + named_array3 = ak.with_named_axis(array1, named_axis=("x", "muons")) + assert ak.almost_equal(named_array1, named_array3, check_named_axis=False) + assert not ak.almost_equal(named_array1, named_array3, check_named_axis=True) + + +def test_negative_named_axis_ak_almost_equal(): + array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = named_array2 = ak.with_named_axis( + array1, named_axis={"x": -2, "y": -1} + ) + + assert ak.almost_equal(array1, array2, check_named_axis=False) == ak.almost_equal( + named_array1, named_array2, check_named_axis=False + ) + assert ak.almost_equal(array1, array2, check_named_axis=True) == ak.almost_equal( + named_array1, named_array2, check_named_axis=True + ) + + assert ak.almost_equal(named_array1, array1, check_named_axis=False) + assert ak.almost_equal(named_array1, array1, check_named_axis=True) + + named_array3 = ak.with_named_axis(array1, named_axis={"x": -2, "z": -1}) + assert ak.almost_equal(named_array1, named_array3, check_named_axis=False) + assert not ak.almost_equal(named_array1, named_array3, check_named_axis=True) + + +def test_named_axis_ak_angle(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.angle(array) == ak.angle(named_array)) + + # check that result axis names are correctly propagated + assert ak.angle(named_array).named_axis == {"x": 0, "y": 1} + + +def test_negative_named_axis_ak_angle(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.angle(array) == ak.angle(named_array)) + + # check that result axis names are correctly propagated + assert ak.angle(named_array).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_any(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.any(array < 4, axis=0) == ak.any(named_array < 4, axis="x")) + assert ak.all(ak.any(array < 4, axis=1) == ak.any(named_array < 4, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.any(named_array < 4, axis=0).named_axis + == ak.any(named_array < 4, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.any(named_array < 4, axis=1).named_axis + == ak.any(named_array < 4, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.any(named_array < 4, axis=0, keepdims=True).named_axis + == ak.any(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.any(named_array < 4, axis=1, keepdims=True).named_axis + == ak.any(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) + + +def test_negative_named_axis_ak_any(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.any(array < 4, axis=-2) == ak.any(named_array < 4, axis="x")) + assert ak.all(ak.any(array < 4, axis=-1) == ak.any(named_array < 4, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.any(named_array < 4, axis=-2).named_axis + == ak.any(named_array < 4, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.any(named_array < 4, axis=-1).named_axis + == ak.any(named_array < 4, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.any(named_array < 4, axis=-2, keepdims=True).named_axis + == ak.any(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.any(named_array < 4, axis=-1, keepdims=True).named_axis + == ak.any(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) + + +def test_named_axis_ak_argcartesian(): + one = ak.Array([[1], [2], [3]]) + two = ak.Array([[4, 5]]) + three = ak.Array([[6, 7]]) + + named_one = ak.with_named_axis(one, named_axis=("x", "y")) + named_two = ak.with_named_axis(two, named_axis=("x", "y")) + named_three = ak.with_named_axis(three, named_axis=("x", "y")) + + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=False + ).named_axis == {"x": 0, "y": 1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=True + ).named_axis == {"x": 1, "y": 2} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=[0] + ).named_axis == {"x": 1, "y": 2} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=[1] + ).named_axis == {"x": 0, "y": 2} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="x", nested=[0, 1] + ).named_axis == {"x": 2, "y": 3} + + +def test_negative_named_axis_ak_argcartesian(): + one = ak.Array([[1], [2], [3]]) + two = ak.Array([[4, 5]]) + three = ak.Array([[6, 7]]) + + named_one = ak.with_named_axis(one, named_axis={"x": -2, "y": -1}) + named_two = ak.with_named_axis(two, named_axis={"x": -2, "y": -1}) + named_three = ak.with_named_axis(three, named_axis={"x": -2, "y": -1}) + + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=False + ).named_axis == {"x": -1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=True + ).named_axis == {"x": -2, "y": -1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=[0] + ).named_axis == {"x": -1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=[1] + ).named_axis == {"y": -1} + assert ak.argcartesian( + [named_one, named_two, named_three], axis="y", nested=[0, 1] + ).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_argcombinations(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert ( + ak.argcombinations(named_array, 2, axis=0).named_axis == named_array.named_axis + ) + assert ( + ak.argcombinations(named_array, 2, axis=1).named_axis == named_array.named_axis + ) + + +def test_negative_named_axis_ak_argcombinations(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + assert ( + ak.argcombinations(named_array, 2, axis=0).named_axis == named_array.named_axis + ) + assert ( + ak.argcombinations(named_array, 2, axis=1).named_axis == named_array.named_axis + ) + + +def test_named_axis_ak_argmax(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.argmax(array, axis=0) == ak.argmax(named_array, axis="x")) + assert ak.all(ak.argmax(array, axis=1) == ak.argmax(named_array, axis="y")) + assert ak.all( + ak.argmax(array, axis=0, keepdims=True) + == ak.argmax(named_array, axis="x", keepdims=True) + ) + assert ak.all( + ak.argmax(array, axis=1, keepdims=True) + == ak.argmax(named_array, axis="y", keepdims=True) + ) + assert ak.argmax(array, axis=None) == ak.argmax(named_array, axis=None) + + # check that result axis names are correctly propagated + assert ( + ak.argmax(named_array, axis=0).named_axis + == ak.argmax(named_array, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.argmax(named_array, axis=1).named_axis + == ak.argmax(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.argmax(named_array, axis=0, keepdims=True).named_axis + == ak.argmax(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.argmax(named_array, axis=1, keepdims=True).named_axis + == ak.argmax(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.argmax(named_array, axis=None)) + + +def test_negative_named_axis_ak_argmax(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.argmax(array, axis=-2) == ak.argmax(named_array, axis="x")) + assert ak.all(ak.argmax(array, axis=-1) == ak.argmax(named_array, axis="y")) + assert ak.all( + ak.argmax(array, axis=-2, keepdims=True) + == ak.argmax(named_array, axis="x", keepdims=True) + ) + assert ak.all( + ak.argmax(array, axis=-1, keepdims=True) + == ak.argmax(named_array, axis="y", keepdims=True) + ) + assert ak.argmax(array, axis=None) == ak.argmax(named_array, axis=None) + + # check that result axis names are correctly propagated + assert ( + ak.argmax(named_array, axis=-2).named_axis + == ak.argmax(named_array, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.argmax(named_array, axis=-1).named_axis + == ak.argmax(named_array, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.argmax(named_array, axis=-2, keepdims=True).named_axis + == ak.argmax(named_array, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.argmax(named_array, axis=-1, keepdims=True).named_axis + == ak.argmax(named_array, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.argmax(named_array, axis=None)) + + +def test_named_axis_ak_argmin(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.argmin(array, axis=0) == ak.argmin(named_array, axis="x")) + assert ak.all(ak.argmin(array, axis=1) == ak.argmin(named_array, axis="y")) + assert ak.all( + ak.argmin(array, axis=0, keepdims=True) + == ak.argmin(named_array, axis="x", keepdims=True) + ) + assert ak.all( + ak.argmin(array, axis=1, keepdims=True) + == ak.argmin(named_array, axis="y", keepdims=True) + ) + assert ak.argmin(array, axis=None) == ak.argmin(named_array, axis=None) + + # check that result axis names are correctly propagated + assert ( + ak.argmin(named_array, axis=0).named_axis + == ak.argmin(named_array, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.argmin(named_array, axis=1).named_axis + == ak.argmin(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.argmin(named_array, axis=0, keepdims=True).named_axis + == ak.argmin(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.argmin(named_array, axis=1, keepdims=True).named_axis + == ak.argmin(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.argmin(named_array, axis=None)) + + +def test_negative_named_axis_ak_argmin(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.argmin(array, axis=-2) == ak.argmin(named_array, axis="x")) + assert ak.all(ak.argmin(array, axis=-1) == ak.argmin(named_array, axis="y")) + assert ak.all( + ak.argmin(array, axis=-2, keepdims=True) + == ak.argmin(named_array, axis="x", keepdims=True) + ) + assert ak.all( + ak.argmin(array, axis=-1, keepdims=True) + == ak.argmin(named_array, axis="y", keepdims=True) + ) + assert ak.argmin(array, axis=None) == ak.argmin(named_array, axis=None) + + # check that result axis names are correctly propagated + assert ( + ak.argmin(named_array, axis=-2).named_axis + == ak.argmin(named_array, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.argmin(named_array, axis=-1).named_axis + == ak.argmin(named_array, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.argmin(named_array, axis=-2, keepdims=True).named_axis + == ak.argmin(named_array, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.argmin(named_array, axis=-1, keepdims=True).named_axis + == ak.argmin(named_array, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.argmin(named_array, axis=None)) + + +def test_named_axis_ak_argsort(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.argsort(array, axis=0) == ak.argsort(named_array, axis="x")) + assert ak.all(ak.argsort(array, axis=1) == ak.argsort(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.argsort(named_array, axis=0).named_axis + == ak.argsort(named_array, axis="x").named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.argsort(named_array, axis=1).named_axis + == ak.argsort(named_array, axis="y").named_axis + == {"x": 0, "y": 1} + ) + + +def test_negative_named_axis_ak_argsort(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.argsort(array, axis=-2) == ak.argsort(named_array, axis="x")) + assert ak.all(ak.argsort(array, axis=-1) == ak.argsort(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.argsort(named_array, axis=-2).named_axis + == ak.argsort(named_array, axis="x").named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.argsort(named_array, axis=-1).named_axis + == ak.argsort(named_array, axis="y").named_axis + == {"x": -2, "y": -1} + ) + + +def test_named_axis_ak_array_equal(): + array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = named_array2 = ak.with_named_axis(array1, named_axis=("x", "y")) + + assert ak.array_equal(array1, array2, check_named_axis=False) == ak.array_equal( + named_array1, named_array2, check_named_axis=False + ) + assert ak.array_equal(array1, array2, check_named_axis=True) == ak.array_equal( + named_array1, named_array2, check_named_axis=True + ) + + assert ak.array_equal(named_array1, array1, check_named_axis=False) + assert ak.array_equal(named_array1, array1, check_named_axis=True) + + named_array3 = ak.with_named_axis(array1, named_axis=("x", "z")) + assert ak.array_equal(named_array1, named_array3, check_named_axis=False) + assert not ak.array_equal(named_array1, named_array3, check_named_axis=True) + + +def test_negative_named_axis_ak_array_equal(): + array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array1 = named_array2 = ak.with_named_axis( + array1, named_axis={"x": -2, "y": -1} + ) + + assert ak.array_equal(array1, array2, check_named_axis=False) == ak.array_equal( + named_array1, named_array2, check_named_axis=False + ) + assert ak.array_equal(array1, array2, check_named_axis=True) == ak.array_equal( + named_array1, named_array2, check_named_axis=True + ) + + assert ak.array_equal(named_array1, array1, check_named_axis=False) + assert ak.array_equal(named_array1, array1, check_named_axis=True) + + named_array3 = ak.with_named_axis(array1, named_axis={"x": -2, "z": -1}) + assert ak.array_equal(named_array1, named_array3, check_named_axis=False) + assert not ak.array_equal(named_array1, named_array3, check_named_axis=True) + + +def test_named_axis_ak_backend(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert ak.backend(array) == ak.backend(named_array) + + +def test_named_axis_ak_broadcast_fields(): + x = ak.Array([{"x": {"y": 1, "z": 2, "w": [1]}}]) + y = ak.Array([{"x": [{"y": 1}]}]) + + nx = ak.with_named_axis(x, named_axis=("x", "y")) + ny = ak.with_named_axis(y, named_axis=("a", "b")) + + na, nb = ak.broadcast_fields(nx, ny) + assert na.named_axis == {"x": 0, "y": 1} + assert nb.named_axis == {"a": 0, "b": 1} + + +def test_named_axis_ak_cartesian(): + one = ak.Array([[1], [2], [3]]) + two = ak.Array([[4, 5]]) + three = ak.Array([[6, 7]]) + + named_one = ak.with_named_axis(one, named_axis=("x", "y")) + named_two = ak.with_named_axis(two, named_axis=("x", "y")) + named_three = ak.with_named_axis(three, named_axis=("x", "y")) + + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=False + ).named_axis == {"x": 0, "y": 1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=True + ).named_axis == {"x": 1, "y": 2} + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=[0] + ).named_axis == {"x": 1, "y": 2} + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=[1] + ).named_axis == {"x": 0, "y": 2} + assert ak.cartesian( + [named_one, named_two, named_three], axis="x", nested=[0, 1] + ).named_axis == {"x": 2, "y": 3} + + +def test_negative_named_axis_ak_cartesian(): + one = ak.Array([[1], [2], [3]]) + two = ak.Array([[4, 5]]) + three = ak.Array([[6, 7]]) + + named_one = ak.with_named_axis(one, named_axis={"x": -2, "y": -1}) + named_two = ak.with_named_axis(two, named_axis={"x": -2, "y": -1}) + named_three = ak.with_named_axis(three, named_axis={"x": -2, "y": -1}) + + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=False + ).named_axis == {"x": -1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=True + ).named_axis == {"x": -2, "y": -1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=[0] + ).named_axis == {"x": -1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=[1] + ).named_axis == {"y": -1} + assert ak.cartesian( + [named_one, named_two, named_three], axis="y", nested=[0, 1] + ).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_categories(): + pyarrow = pytest.importorskip("pyarrow") # noqa: F841 + + array = ak.str.to_categorical([["one", "two"], ["one", "three"], ["one", "four"]]) + + named_array = ak.with_named_axis(array, named_axis=("a", "b")) + + assert ak.all(ak.categories(array) == ak.categories(named_array)) # FIX: ufuncs + assert ( + ak.categories(array).named_axis == ak.categories(named_array).named_axis == {} + ) + + +def test_named_axis_ak_combinations(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert ak.combinations(named_array, 2, axis=0).named_axis == named_array.named_axis + assert ak.combinations(named_array, 2, axis=1).named_axis == named_array.named_axis + + +def test_negative_named_axis_ak_combinations(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + assert ak.combinations(named_array, 2, axis=-2).named_axis == named_array.named_axis + assert ak.combinations(named_array, 2, axis=-1).named_axis == named_array.named_axis + + +def test_named_axis_ak_concatenate(): + array1 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array3 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array4 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + all_arrays = [array1, array2, array3, array4] + + named_array1 = ak.with_named_axis(array1, named_axis=(None, None)) + named_array2 = ak.with_named_axis(array1, named_axis=(None, "y")) + named_array3 = ak.with_named_axis(array1, named_axis=("x", None)) + named_array4 = ak.with_named_axis(array1, named_axis=("x", "y")) + + all_named_arrays = [named_array1, named_array2, named_array3, named_array4] + + assert ak.all( + ak.concatenate(all_arrays, axis=0) == ak.concatenate(all_named_arrays, axis="x") + ) + assert ak.all( + ak.concatenate(all_arrays, axis=1) == ak.concatenate(all_named_arrays, axis="y") + ) + + assert ak.concatenate(all_named_arrays, axis="x").named_axis == {"x": 0, "y": 1} + assert ak.concatenate(all_named_arrays, axis="y").named_axis == {"x": 0, "y": 1} + + with pytest.raises( + ValueError, + match="The named axes are incompatible. Got: x and y for positional axis 0", + ): + ak.concatenate( + [ + ak.with_named_axis(array1, named_axis=("x", None)), + ak.with_named_axis(array2, named_axis=("y", None)), + ], + axis=0, + ) + + with pytest.raises( + ValueError, + match="The named axes are incompatible. Got: x and y for positional axis 1", + ): + ak.concatenate( + [ + ak.with_named_axis(array1, named_axis=(None, "x")), + ak.with_named_axis(array2, named_axis=(None, "y")), + ], + axis=1, + ) + + +def test_negative_named_axis_ak_concatenate(): + array1 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array3 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + array4 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + all_arrays = [array1, array2, array3, array4] + + named_array1 = ak.with_named_axis(array1, named_axis={}) + named_array2 = ak.with_named_axis(array1, named_axis={"y": -1}) + named_array3 = ak.with_named_axis(array1, named_axis={"x": -2}) + named_array4 = ak.with_named_axis(array1, named_axis={"x": -2, "y": -1}) + + all_named_arrays = [named_array1, named_array2, named_array3, named_array4] + + assert ak.all( + ak.concatenate(all_arrays, axis=-2) + == ak.concatenate(all_named_arrays, axis="x") + ) + assert ak.all( + ak.concatenate(all_arrays, axis=-1) + == ak.concatenate(all_named_arrays, axis="y") + ) + + assert ak.concatenate(all_named_arrays, axis="x").named_axis == {"x": -2, "y": -1} + assert ak.concatenate(all_named_arrays, axis="y").named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_copy(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + assert ak.copy(named_array).named_axis == {"x": 0, "y": 1} + + +# def test_named_axis_ak_corr(): +# array_x = ak.Array([[0, 1.1], [3.3, 4.4]]) +# array_y = ak.Array([[0, 1], [3, 4]]) + +# named_array_x = ak.with_named_axis(array_x, ("x", "y")) +# named_array_y = ak.with_named_axis(array_y, ("x", "y")) + +# assert ak.all( +# ak.corr(array_x, array_y, axis=0) +# == ak.corr(named_array_x, named_array_y, axis="x") +# ) +# assert ak.all( +# ak.corr(array_x, array_y, axis=1) +# == ak.corr(named_array_x, named_array_y, axis="y") +# ) +# assert ak.corr(array_x, array_y, axis=None) == ak.corr( +# named_array_x, named_array_y, axis=None +# ) + +# assert ak.corr(named_array_x, named_array_y, axis="x").named_axis == {"y": 0} +# assert ak.corr(named_array_x, named_array_y, axis="y").named_axis == {"x": 0} +# assert not _get_named_axis(ak.corr(named_array_x, named_array_y, axis=None)) + + +def test_named_axis_ak_count(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.count(array, axis=0) == ak.count(named_array, axis="x")) + assert ak.all(ak.count(array, axis=1) == ak.count(named_array, axis="y")) + assert ak.count(array, axis=None) == ak.count(named_array, axis=None) + + assert ak.count(named_array, axis="x").named_axis == {"y": 0} + assert ak.count(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.count(named_array, axis=None)) + + +def test_negative_named_axis_ak_count(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.count(array, axis=-2) == ak.count(named_array, axis="x")) + assert ak.all(ak.count(array, axis=-1) == ak.count(named_array, axis="y")) + assert ak.count(array, axis=None) == ak.count(named_array, axis=None) + + assert ak.count(named_array, axis="x").named_axis == {"y": -1} + assert ak.count(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.count(named_array, axis=None)) + + +def test_named_axis_ak_count_nonzero(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all( + ak.count_nonzero(array, axis=0) == ak.count_nonzero(named_array, axis="x") + ) + assert ak.all( + ak.count_nonzero(array, axis=1) == ak.count_nonzero(named_array, axis="y") + ) + assert ak.count_nonzero(array, axis=None) == ak.count_nonzero( + named_array, axis=None + ) + + assert ak.count_nonzero(named_array, axis="x").named_axis == {"y": 0} + assert ak.count_nonzero(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.count_nonzero(named_array, axis=None)) + + +def test_negative_named_axis_ak_count_nonzero(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all( + ak.count_nonzero(array, axis=-2) == ak.count_nonzero(named_array, axis="x") + ) + assert ak.all( + ak.count_nonzero(array, axis=-1) == ak.count_nonzero(named_array, axis="y") + ) + assert ak.count_nonzero(array, axis=None) == ak.count_nonzero( + named_array, axis=None + ) + + assert ak.count_nonzero(named_array, axis="x").named_axis == {"y": -1} + assert ak.count_nonzero(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.count_nonzero(named_array, axis=None)) + + +# def test_named_axis_ak_covar(): +# array_x = ak.Array([[0, 1.1], [3.3, 4.4]]) +# array_y = ak.Array([[0, 1], [3, 4]]) + +# named_array_x = ak.with_named_axis(array_x, ("x", "y")) +# named_array_y = ak.with_named_axis(array_y, ("x", "y")) + +# assert ak.all( +# ak.covar(array_x, array_y, axis=0) +# == ak.covar(named_array_x, named_array_y, axis="x") +# ) +# assert ak.all( +# ak.covar(array_x, array_y, axis=1) +# == ak.covar(named_array_x, named_array_y, axis="y") +# ) +# assert ak.covar(array_x, array_y, axis=None) == ak.covar( +# named_array_x, named_array_y, axis=None +# ) + +# assert ak.covar(named_array_x, named_array_y, axis="x").named_axis == {"y": 0} +# assert ak.covar(named_array_x, named_array_y, axis="y").named_axis == {"x": 0} +# assert not _get_named_axis(ak.covar(named_array_x, named_array_y, axis=None)) + + +def test_named_axis_ak_drop_none(): + array = ak.Array([[1, None], [3], [None], [4, None, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.drop_none(array, axis=0) == ak.drop_none(named_array, axis="x")) + assert ak.all(ak.drop_none(array, axis=1) == ak.drop_none(named_array, axis="y")) + assert ak.all( + ak.drop_none(array, axis=None) == ak.drop_none(named_array, axis=None) + ) + + assert ak.drop_none(named_array, axis="x").named_axis == {"x": 0, "y": 1} + assert ak.drop_none(named_array, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.drop_none(named_array, axis=None).named_axis == {"x": 0, "y": 1} + + +def test_negative_named_axis_ak_drop_none(): + array = ak.Array([[1, None], [3], [None], [4, None, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.drop_none(array, axis=-2) == ak.drop_none(named_array, axis="x")) + assert ak.all(ak.drop_none(array, axis=-1) == ak.drop_none(named_array, axis="y")) + assert ak.all( + ak.drop_none(array, axis=None) == ak.drop_none(named_array, axis=None) + ) + + assert ak.drop_none(named_array, axis="x").named_axis == {"x": -2, "y": -1} + assert ak.drop_none(named_array, axis="y").named_axis == {"x": -2, "y": -1} + assert ak.drop_none(named_array, axis=None).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_enforce_type(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.enforce_type(named_array, "var * ?int64").named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_fill_none(): + array = ak.Array([[1.1, None, 2.2], [], [None, 3.3, 4.4]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all( + ak.fill_none(array, 0, axis=0) == ak.fill_none(named_array, 0, axis="x") + ) + assert ak.all( + ak.fill_none(array, 0, axis=1) == ak.fill_none(named_array, 0, axis="y") + ) + assert ak.all( + ak.fill_none(array, 0, axis=None) == ak.fill_none(named_array, 0, axis=None) + ) + + assert ak.fill_none(named_array, 0, axis="x").named_axis == {"x": 0, "y": 1} + assert ak.fill_none(named_array, 0, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.fill_none(named_array, 0, axis=None).named_axis == {"x": 0, "y": 1} + + +def test_negative_named_axis_ak_fill_none(): + array = ak.Array([[1.1, None, 2.2], [], [None, 3.3, 4.4]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all( + ak.fill_none(array, 0, axis=-2) == ak.fill_none(named_array, 0, axis="x") + ) + assert ak.all( + ak.fill_none(array, 0, axis=-1) == ak.fill_none(named_array, 0, axis="y") + ) + assert ak.all( + ak.fill_none(array, 0, axis=None) == ak.fill_none(named_array, 0, axis=None) + ) + + assert ak.fill_none(named_array, 0, axis="x").named_axis == {"x": -2, "y": -1} + assert ak.fill_none(named_array, 0, axis="y").named_axis == {"x": -2, "y": -1} + assert ak.fill_none(named_array, 0, axis=None).named_axis == {"x": -2, "y": -1} + + +def test_named_axis_ak_firsts(): + array = ak.Array([[1.1], [2.2], [], [3.3], [], [], [4.4], [5.5]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.firsts(array, axis=0) == ak.firsts(named_array, axis="x")) + assert ak.all(ak.firsts(array, axis=1) == ak.firsts(named_array, axis="y")) + + assert ak.firsts(named_array, axis="x").named_axis == {"y": 0} + assert ak.firsts(named_array, axis="y").named_axis == {"x": 0} + + +def test_negative_named_axis_ak_firsts(): + array = ak.Array([[1.1], [2.2], [], [3.3], [], [], [4.4], [5.5]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.firsts(array, axis=-2) == ak.firsts(named_array, axis="x")) + assert ak.all(ak.firsts(array, axis=-1) == ak.firsts(named_array, axis="y")) + + assert ak.firsts(named_array, axis="x").named_axis == {"y": -1} + assert ak.firsts(named_array, axis="y").named_axis == {"x": -1} + + +def test_named_axis_ak_flatten(): + array = ak.Array([[[1.1, 2.2]], [[]], [[3.3]], [[]], [[]], [[4.4, 5.5]]]) + + named_array = ak.with_named_axis(array, ("x", "y", "z")) + + assert ak.all(ak.flatten(array, axis=0) == ak.flatten(named_array, axis="x")) + assert ak.all(ak.flatten(array, axis=1) == ak.flatten(named_array, axis="y")) + assert ak.all(ak.flatten(array, axis=2) == ak.flatten(named_array, axis="z")) + assert ak.all(ak.flatten(array, axis=None) == ak.flatten(named_array, axis=None)) + + assert ak.flatten(named_array, axis="x").named_axis == {"x": 0, "y": 1, "z": 2} + assert ak.flatten(named_array, axis="y").named_axis == {"x": 0, "z": 1} + assert ak.flatten(named_array, axis="z").named_axis == {"x": 0, "y": 1} + assert not _get_named_axis(ak.flatten(named_array, axis=None)) + + +def test_negative_named_axis_ak_flatten(): + array = ak.Array([[[1.1, 2.2]], [[]], [[3.3]], [[]], [[]], [[4.4, 5.5]]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1}) + + assert ak.all(ak.flatten(array, axis=-3) == ak.flatten(named_array, axis="x")) + assert ak.all(ak.flatten(array, axis=-2) == ak.flatten(named_array, axis="y")) + assert ak.all(ak.flatten(array, axis=-1) == ak.flatten(named_array, axis="z")) + assert ak.all(ak.flatten(array, axis=None) == ak.flatten(named_array, axis=None)) + + assert ak.flatten(named_array, axis="x").named_axis == {"x": -3, "y": -2, "z": -1} + assert ak.flatten(named_array, axis="y").named_axis == {"x": -2, "z": -1} + assert ak.flatten(named_array, axis="z").named_axis == {"x": -2, "y": -1} + assert not _get_named_axis(ak.flatten(named_array, axis=None)) + + +def test_named_axis_ak_imag(): + array = ak.Array([[1 + 2j], [2 + 1j], []]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.imag(array) == ak.imag(named_array)) + assert ak.imag(named_array).named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_is_none(): + array = ak.Array([[[1, None]], [[3]], [[None]], [[4, None, 6]]]) + + named_array = ak.with_named_axis(array, ("x", "y", "z")) + + assert ak.all(ak.is_none(array, axis=0) == ak.is_none(named_array, axis="x")) + assert ak.all(ak.is_none(array, axis=1) == ak.is_none(named_array, axis="y")) + assert ak.all(ak.is_none(array, axis=2) == ak.is_none(named_array, axis="z")) + + assert ak.is_none(named_array, axis="x").named_axis == {"x": 0} + assert ak.is_none(named_array, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.is_none(named_array, axis="z").named_axis == {"x": 0, "y": 1, "z": 2} + + +def test_negative_named_axis_ak_is_none(): + array = ak.Array([[[1, None]], [[3]], [[None]], [[4, None, 6]]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1}) + + assert ak.all(ak.is_none(array, axis=-3) == ak.is_none(named_array, axis="x")) + assert ak.all(ak.is_none(array, axis=-2) == ak.is_none(named_array, axis="y")) + assert ak.all(ak.is_none(array, axis=-1) == ak.is_none(named_array, axis="z")) + + assert ak.is_none(named_array, axis="x").named_axis == {"z": -1} + assert ak.is_none(named_array, axis="y").named_axis == {"y": -2, "z": -1} + assert ak.is_none(named_array, axis="z").named_axis == {"x": -3, "y": -2, "z": -1} + + +def test_named_axis_ak_isclose(): + a = b = ak.Array( + [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]] + ) + + na = ak.with_named_axis(a, ("x", "y", "z")) + nb = ak.with_named_axis(b, ("x", "y", "z")) + assert ak.all(ak.isclose(a, b) == ak.isclose(na, nb)) + + na = ak.with_named_axis(a, (None, "y", "z")) + nb = ak.with_named_axis(b, ("x", "y", None)) + assert ak.isclose(na, nb).named_axis == {"x": 0, "y": 1, "z": 2} + + +def test_named_axis_ak_local_index(): + array = ak.Array( + [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]] + ) + + named_array = ak.with_named_axis(array, ("x", "y", "z")) + + assert ak.all( + ak.local_index(array, axis=0) == ak.local_index(named_array, axis="x") + ) + assert ak.all( + ak.local_index(array, axis=1) == ak.local_index(named_array, axis="y") + ) + assert ak.all( + ak.local_index(array, axis=2) == ak.local_index(named_array, axis="z") + ) + + assert ak.local_index(named_array, axis="x").named_axis == {"x": 0} + assert ak.local_index(named_array, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.local_index(named_array, axis="z").named_axis == {"x": 0, "y": 1, "z": 2} + + +def test_negative_named_axis_ak_local_index(): + array = ak.Array( + [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]] + ) + named_array = ak.with_named_axis(array, {"x": -3, "y": -2, "z": -1}) + + assert ak.local_index(named_array, axis="x").named_axis == {"z": -1} + assert ak.local_index(named_array, axis="y").named_axis == {"y": -2, "z": -1} + assert ak.local_index(named_array, axis="z").named_axis == { + "x": -3, + "y": -2, + "z": -1, + } + + +def test_named_axis_ak_mask(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + mask = array > 3 + + named_array = ak.with_named_axis(array, ("x", "y")) + named_mask = named_array > 3 + + assert ak.all(ak.mask(array, mask) == ak.mask(named_array, mask)) + assert ak.all(ak.mask(array, mask) == ak.mask(named_array, named_mask)) + + assert ak.mask(named_array, mask).named_axis == named_array.named_axis + assert ak.mask(named_array, named_mask).named_axis == named_array.named_axis + + +def test_named_axis_ak_max(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.max(array, axis=0) == ak.max(named_array, axis="x")) + assert ak.all(ak.max(array, axis=1) == ak.max(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.max(named_array, axis=0).named_axis + == ak.max(named_array, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.max(named_array, axis=1).named_axis + == ak.max(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.max(named_array, axis=0, keepdims=True).named_axis + == ak.max(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.max(named_array, axis=1, keepdims=True).named_axis + == ak.max(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.max(named_array, axis=None)) + + +def test_negative_named_axis_ak_max(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.max(array, axis=-2) == ak.max(named_array, axis="x")) + assert ak.all(ak.max(array, axis=-1) == ak.max(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.max(named_array, axis=-2).named_axis + == ak.max(named_array, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.max(named_array, axis=-1).named_axis + == ak.max(named_array, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.max(named_array, axis=-2, keepdims=True).named_axis + == ak.max(named_array, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.max(named_array, axis=-1, keepdims=True).named_axis + == ak.max(named_array, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.max(named_array, axis=None)) + + +def test_named_axis_ak_mean(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.mean(array, axis=0) == ak.mean(named_array, axis="x")) + assert ak.all(ak.mean(array, axis=1) == ak.mean(named_array, axis="y")) + assert ak.mean(array, axis=None) == ak.mean(named_array, axis=None) + + assert ak.mean(named_array, axis="x").named_axis == {"y": 0} + assert ak.mean(named_array, axis="y").named_axis == {"x": 0} + assert ak.mean(named_array, axis="x", keepdims=True).named_axis == {"x": 0, "y": 1} + assert ak.mean(named_array, axis="y", keepdims=True).named_axis == {"x": 0, "y": 1} + assert not _get_named_axis(ak.mean(named_array, axis=None)) + + +def test_negative_named_axis_ak_mean(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + assert ak.all(ak.mean(array, axis=-2) == ak.mean(named_array, axis="x")) + assert ak.all(ak.mean(array, axis=-1) == ak.mean(named_array, axis="y")) + assert ak.mean(array, axis=None) == ak.mean(named_array, axis=None) + + assert ak.mean(named_array, axis="x").named_axis == {"y": -1} + assert ak.mean(named_array, axis="y").named_axis == {"x": -1} + assert ak.mean(named_array, axis="x", keepdims=True).named_axis == { + "x": -2, + "y": -1, + } + assert ak.mean(named_array, axis="y", keepdims=True).named_axis == { + "x": -2, + "y": -1, + } + assert not _get_named_axis(ak.mean(named_array, axis=None)) + + +def test_named_axis_ak_merge_option_of_records(): + array = ak.Array([None, {"a": 1}, {"a": 2}]) + + named_array = ak.with_named_axis(array, named_axis=("x",)) + + assert ( + ak.merge_option_of_records(named_array, axis="x").named_axis + == named_array.named_axis + ) + + +def test_named_axis_ak_merge_union_of_records(): + array = ak.concatenate(([{"a": 1}], [{"b": 2}])) + + named_array = ak.with_named_axis(array, named_axis=("x",)) + + assert ( + ak.merge_union_of_records(named_array, axis="x").named_axis + == named_array.named_axis + ) + + +def test_named_axis_ak_min(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.min(array, axis=0) == ak.min(named_array, axis="x")) + assert ak.all(ak.min(array, axis=1) == ak.min(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.min(named_array, axis=0).named_axis + == ak.min(named_array, axis="x").named_axis + == {"y": 0} + ) + assert ( + ak.min(named_array, axis=1).named_axis + == ak.min(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.min(named_array, axis=0, keepdims=True).named_axis + == ak.min(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.min(named_array, axis=1, keepdims=True).named_axis + == ak.min(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.min(named_array, axis=None)) + + +def test_negative_named_axis_ak_min(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1}) + + # first check that they work the same + assert ak.all(ak.min(array, axis=-2) == ak.min(named_array, axis="x")) + assert ak.all(ak.min(array, axis=-1) == ak.min(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.min(named_array, axis=-2).named_axis + == ak.min(named_array, axis="x").named_axis + == {"y": -1} + ) + assert ( + ak.min(named_array, axis=-1).named_axis + == ak.min(named_array, axis="y").named_axis + == {"x": -1} + ) + assert ( + ak.min(named_array, axis=-2, keepdims=True).named_axis + == ak.min(named_array, axis="x", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert ( + ak.min(named_array, axis=-1, keepdims=True).named_axis + == ak.min(named_array, axis="y", keepdims=True).named_axis + == {"x": -2, "y": -1} + ) + assert not _get_named_axis(ak.min(named_array, axis=None)) + + +def test_named_axis_ak_moment(): + array = ak.Array([[0, 1.1], [3.3, 4.4]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.moment(array, 0, axis=0) == ak.moment(named_array, 0, axis="x")) + assert ak.all(ak.moment(array, 0, axis=1) == ak.moment(named_array, 0, axis="y")) + assert ak.moment(array, 0, axis=None) == ak.moment(named_array, 0, axis=None) + + assert ak.moment(named_array, 0, axis="x").named_axis == {"y": 0} + assert ak.moment(named_array, 0, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.moment(named_array, 0, axis=None)) + + +def test_negative_named_axis_ak_moment(): + array = ak.Array([[0, 1.1], [3.3, 4.4]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.moment(array, 0, axis=-2) == ak.moment(named_array, 0, axis="x")) + assert ak.all(ak.moment(array, 0, axis=-1) == ak.moment(named_array, 0, axis="y")) + assert ak.moment(array, 0, axis=None) == ak.moment(named_array, 0, axis=None) + + assert ak.moment(named_array, 0, axis="x").named_axis == {"y": -1} + assert ak.moment(named_array, 0, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.moment(named_array, 0, axis=None)) + + +def test_named_axis_ak_nan_to_none(): + array = ak.Array([[0, np.nan], [np.nan], [3.3, 4.4]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.nan_to_none(array) == ak.nan_to_none(named_array)) + assert ak.nan_to_none(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_nan_to_num(): + array = ak.Array([[0, np.nan], [np.nan], [3.3, 4.4]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.nan_to_num(array, nan=0.0) == ak.nan_to_num(named_array, nan=0.0)) + assert ak.nan_to_num(named_array, nan=0.0).named_axis == named_array.named_axis + + +def test_named_axis_ak_num(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.num(array, axis=0) == ak.num(named_array, axis="x") + assert ak.all(ak.num(array, axis=1) == ak.num(named_array, axis="y")) + + assert ak.num(named_array, axis="y").named_axis == {"y": 0} + assert not _get_named_axis(ak.num(named_array, axis="x")) + + +def test_negative_named_axis_ak_num(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.num(array, axis=-2) == ak.num(named_array, axis="x") + assert ak.all(ak.num(array, axis=-1) == ak.num(named_array, axis="y")) + + assert ak.num(named_array, axis="y").named_axis == {"y": 0} + assert not _get_named_axis(ak.num(named_array, axis="x")) + + +def test_named_axis_ak_ones_like(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.ones_like(array) == ak.ones_like(named_array)) + + assert ak.ones_like(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_pad_none(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.pad_none(array, 3, axis=0) == ak.pad_none(named_array, 3, axis=0)) + assert ak.all(ak.pad_none(array, 3, axis=1) == ak.pad_none(named_array, 3, axis=1)) + + assert ak.pad_none(named_array, 3, axis=0).named_axis == named_array.named_axis + assert ak.pad_none(named_array, 3, axis=1).named_axis == named_array.named_axis + + +def test_named_axis_ak_prod(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.prod(array, axis=0) == ak.prod(named_array, axis="x")) + assert ak.all(ak.prod(array, axis=1) == ak.prod(named_array, axis="y")) + assert ak.prod(array, axis=None) == ak.prod(named_array, axis=None) + + assert ak.prod(named_array, axis="x").named_axis == {"y": 0} + assert ak.prod(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.prod(named_array, axis=None)) + + +def test_negative_named_axis_ak_prod(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.prod(array, axis=-2) == ak.prod(named_array, axis="x")) + assert ak.all(ak.prod(array, axis=-1) == ak.prod(named_array, axis="y")) + assert ak.prod(array, axis=None) == ak.prod(named_array, axis=None) + + assert ak.prod(named_array, axis="x").named_axis == {"y": -1} + assert ak.prod(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.prod(named_array, axis=None)) + + +def test_named_axis_ak_ptp(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.ptp(array, axis=0) == ak.ptp(named_array, axis="x")) + assert ak.all(ak.ptp(array, axis=1) == ak.ptp(named_array, axis="y")) + assert ak.ptp(array, axis=None) == ak.ptp(named_array, axis=None) + + assert ak.ptp(named_array, axis="x").named_axis == {"y": 0} + assert ak.ptp(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.ptp(named_array, axis=None)) + + +def test_negative_named_axis_ak_ptp(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.ptp(array, axis=-2) == ak.ptp(named_array, axis="x")) + assert ak.all(ak.ptp(array, axis=-1) == ak.ptp(named_array, axis="y")) + assert ak.ptp(array, axis=None) == ak.ptp(named_array, axis=None) + + assert ak.ptp(named_array, axis="x").named_axis == {"y": -1} + assert ak.ptp(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.ptp(named_array, axis=None)) + + +def test_named_axis_ak_ravel(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.ravel(array) == ak.ravel(named_array)) + + assert not _get_named_axis(ak.ravel(named_array)) + + +def test_named_axis_ak_real(): + array = ak.Array([[1 + 2j], [2 + 1j], []]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.real(array) == ak.real(named_array)) + assert ak.real(named_array).named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_round(): + array = ak.Array([[1.234], [2.345, 3.456], []]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.round(array) == ak.round(named_array)) + assert ak.round(named_array).named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_run_lengths(): + array = ak.Array([[1.1, 1.1, 1.1, 2.2, 3.3], [3.3, 4.4], [4.4, 5.5]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.run_lengths(array) == ak.run_lengths(named_array)) + + assert ak.run_lengths(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_singletons(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.singletons(array, axis=0) == ak.singletons(named_array, axis="x")) + assert ak.all(ak.singletons(array, axis=1) == ak.singletons(named_array, axis="y")) + + assert ak.singletons(named_array, axis=0).named_axis == {"x": 0, "y": 2} + assert ak.singletons(named_array, axis=1).named_axis == {"x": 0, "y": 1} + + +def test_negative_named_axis_ak_singletons(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.singletons(array, axis=-2) == ak.singletons(named_array, axis="x")) + assert ak.all(ak.singletons(array, axis=-1) == ak.singletons(named_array, axis="y")) + + assert ak.singletons(named_array, axis=-2).named_axis == {"x": -3, "y": -1} + assert ak.singletons(named_array, axis=-1).named_axis == {"x": -3, "y": -2} + + +def test_named_axis_ak_softmax(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.softmax(array, axis=-1) == ak.softmax(named_array, axis="y")) + + assert ak.softmax(named_array, axis="y").named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_sort(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y")) + + # first check that they work the same + assert ak.all(ak.sort(array, axis=0) == ak.sort(named_array, axis="x")) + assert ak.all(ak.sort(array, axis=1) == ak.sort(named_array, axis="y")) + + # check that result axis names are correctly propagated + assert ( + ak.sort(named_array, axis=0).named_axis + == ak.sort(named_array, axis="x").named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.sort(named_array, axis=1).named_axis + == ak.sort(named_array, axis="y").named_axis + == {"x": 0, "y": 1} + ) + + +def test_named_axis_ak_std(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.std(array, axis=0) == ak.std(named_array, axis="x")) + assert ak.all(ak.std(array, axis=1) == ak.std(named_array, axis="y")) + assert ak.std(array, axis=None) == ak.std(named_array, axis=None) + + assert ak.std(named_array, axis="x").named_axis == {"y": 0} + assert ak.std(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.std(named_array, axis=None)) + + +def test_negative_named_axis_ak_std(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.std(array, axis=-2) == ak.std(named_array, axis="x")) + assert ak.all(ak.std(array, axis=-1) == ak.std(named_array, axis="y")) + assert ak.std(array, axis=None) == ak.std(named_array, axis=None) + + assert ak.std(named_array, axis="x").named_axis == {"y": -1} + assert ak.std(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.std(named_array, axis=None)) + + +def test_named_axis_ak_strings_astype(): + array = ak.Array([["1", "2"], ["3"], ["4", "5", "6"]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all( + ak.strings_astype(array, np.int32) == ak.strings_astype(named_array, np.int32) + ) + + assert ak.strings_astype(named_array, np.int32).named_axis == named_array.named_axis + + +def test_named_axis_ak_sum(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.sum(array, axis=0) == ak.sum(named_array, axis="x")) + assert ak.all(ak.sum(array, axis=1) == ak.sum(named_array, axis="y")) + assert ak.sum(array, axis=None) == ak.sum(named_array, axis=None) + + assert ak.sum(named_array, axis="x").named_axis == {"y": 0} + assert ak.sum(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.sum(named_array, axis=None)) + + +def test_negative_named_axis_ak_sum(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.sum(array, axis=-2) == ak.sum(named_array, axis="x")) + assert ak.all(ak.sum(array, axis=-1) == ak.sum(named_array, axis="y")) + assert ak.sum(array, axis=None) == ak.sum(named_array, axis=None) + + assert ak.sum(named_array, axis="x").named_axis == {"y": -1} + assert ak.sum(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.sum(named_array, axis=None)) + + +def test_named_axis_ak_to_backend(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.to_backend(named_array, "typetracer").named_axis == named_array.named_axis + + +def test_named_axis_ak_to_packed(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.to_packed(array) == ak.to_packed(named_array)) + + assert ak.to_packed(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_unflatten(): + array = ak.Array([[1, 2, 3, 4], [], [5, 6, 7], [8, 9]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + counts = ak.Array([2, 2, 1, 2, 1, 1]) + + assert ak.all( + ak.unflatten(array, counts, axis=1) + == ak.unflatten(named_array, counts, axis="y") + ) + assert not _get_named_axis(ak.unflatten(named_array, counts, axis="y")) + + +def test_named_axis_ak_unzip(): + array = ak.Array( + [ + {"x": 1.1, "y": [1]}, + {"x": 2.2, "y": [2, 2]}, + {"x": 3.3, "y": [3, 3, 3]}, + ] + ) + named_array = ak.with_named_axis(array, ("x", "y")) + x, y = ak.unzip(named_array) + assert x.named_axis == y.named_axis == {"x": 0, "y": 1} + + +def test_named_axis_ak_values_astype(): + array = ak.Array([[1, 2, 3, 4], [], [5, 6, 7], [8, 9]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all( + ak.values_astype(array, np.float32) == ak.values_astype(named_array, np.float32) + ) + + assert ( + ak.values_astype(named_array, np.float32).named_axis == named_array.named_axis + ) + + +def test_named_axis_ak_var(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.var(array, axis=0) == ak.var(named_array, axis="x")) + assert ak.all(ak.var(array, axis=1) == ak.var(named_array, axis="y")) + assert ak.var(array, axis=None) == ak.var(named_array, axis=None) + + assert ak.var(named_array, axis="x").named_axis == {"y": 0} + assert ak.var(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.var(named_array, axis=None)) + + +def test_negative_named_axis_ak_var(): + array = ak.Array([[1, 2], [3], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, {"x": -2, "y": -1}) + + assert ak.all(ak.var(array, axis=-2) == ak.var(named_array, axis="x")) + assert ak.all(ak.var(array, axis=-1) == ak.var(named_array, axis="y")) + assert ak.var(array, axis=None) == ak.var(named_array, axis=None) + + assert ak.var(named_array, axis="x").named_axis == {"y": -1} + assert ak.var(named_array, axis="y").named_axis == {"x": -1} + assert not _get_named_axis(ak.var(named_array, axis=None)) + + +def test_named_axis_ak_where(): + a = ak.Array([[1, 2], [3, 4]]) + na = ak.with_named_axis(a, ("x", "y")) + + assert ak.all(ak.where(a > 2, 0, 1) == ak.where(na > 2, 0, 1)) + assert ak.where(na > 2, 0, 1).named_axis == {"x": 0, "y": 1} + assert ak.where(na > 2, na, 1).named_axis == {"x": 0, "y": 1} + + nb = ak.with_named_axis(a, ("a", "b")) + with pytest.raises(ValueError): + _ = ak.where(na > 2, nb, 1) + + +def test_named_axis_ak_with_field(): + array = ak.Array( + [ + {"x": 1.1, "y": [1]}, + {"x": 2.2, "y": [2, 2]}, + {"x": 3.3, "y": [3, 3, 3]}, + ] + ) + named_array = ak.with_named_axis(array, ("x", "y")) + xyz = ak.with_field(named_array, ak.Array([[1], [2], [3]]), "z") + x, y, z = ak.unzip(xyz) + assert x.named_axis == y.named_axis == z.named_axis == {"x": 0, "y": 1} + + named_z = ak.with_named_axis(ak.Array([[1], [2], [3]]), ("a", "b")) + with pytest.raises(ValueError): + ak.with_field(named_array, named_z, "z") + + +def test_named_axis_ak_with_name(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.with_name(named_array, "new_name").named_axis == named_array.named_axis + + +def test_named_axis_ak_with_named_axis(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + # tuple + named_array = ak.with_named_axis(array, ("x", "y")) + assert named_array.named_axis == {"x": 0, "y": 1} + + # dict + named_array = ak.with_named_axis(array, {"x": 0, "y": -1}) + assert named_array.named_axis == {"x": 0, "y": -1} + + +def test_named_axis_ak_with_parameter(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ( + ak.with_parameter(named_array, "param", 1.0).named_axis + == named_array.named_axis + ) + + +def test_named_axis_ak_without_parameters(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + named_array_with_parameteter = ak.with_parameter(named_array, "param", 1.0) + + assert ( + ak.without_parameters(named_array_with_parameteter).named_axis + == named_array.named_axis + ) + + +def test_named_axis_ak_zeros_like(): + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.zeros_like(array) == ak.zeros_like(named_array)) + + assert ak.zeros_like(named_array).named_axis == named_array.named_axis + + +def test_named_axis_ak_zip(): + named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("x",)) + named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y")) + + assert ak.zip({"x": named_array1, "y": named_array2}).named_axis == {"x": 0, "y": 1} + + named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("a",)) + named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y")) + + with pytest.raises(ValueError): + _ = ak.zip({"x": named_array1, "y": named_array2})