From 83c0aa6787fd338376cad8171f8443dad2c5c3a2 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 9 Oct 2024 10:37:15 -0400 Subject: [PATCH] regularize axis makes sure now that its type is either int or None --- src/awkward/_regularize.py | 20 +++++++++++++------ src/awkward/operations/ak_all.py | 7 ++----- src/awkward/operations/ak_any.py | 7 ++----- src/awkward/operations/ak_argcartesian.py | 3 --- src/awkward/operations/ak_argcombinations.py | 10 ++++++++-- src/awkward/operations/ak_argmax.py | 7 ++----- src/awkward/operations/ak_argmin.py | 7 ++----- src/awkward/operations/ak_argsort.py | 7 ++----- src/awkward/operations/ak_cartesian.py | 3 +-- src/awkward/operations/ak_combinations.py | 4 ++-- src/awkward/operations/ak_concatenate.py | 3 +-- src/awkward/operations/ak_corr.py | 7 +++++-- src/awkward/operations/ak_count.py | 7 ++----- src/awkward/operations/ak_count_nonzero.py | 7 ++----- src/awkward/operations/ak_covar.py | 6 ++++-- src/awkward/operations/ak_drop_none.py | 7 ++----- src/awkward/operations/ak_fill_none.py | 7 ++----- src/awkward/operations/ak_firsts.py | 7 ++----- src/awkward/operations/ak_flatten.py | 3 +-- src/awkward/operations/ak_from_regular.py | 2 +- src/awkward/operations/ak_is_none.py | 7 ++----- src/awkward/operations/ak_linear_fit.py | 3 --- src/awkward/operations/ak_local_index.py | 7 ++----- src/awkward/operations/ak_max.py | 7 ++----- src/awkward/operations/ak_mean.py | 15 +++++--------- .../operations/ak_merge_option_of_records.py | 4 ++-- .../operations/ak_merge_union_of_records.py | 4 ++-- src/awkward/operations/ak_min.py | 7 ++----- src/awkward/operations/ak_moment.py | 3 --- src/awkward/operations/ak_num.py | 9 +++------ src/awkward/operations/ak_pad_none.py | 7 ++----- src/awkward/operations/ak_prod.py | 7 ++----- src/awkward/operations/ak_ptp.py | 7 ++----- src/awkward/operations/ak_singletons.py | 7 ++----- src/awkward/operations/ak_softmax.py | 3 +-- src/awkward/operations/ak_sort.py | 7 ++----- src/awkward/operations/ak_std.py | 7 ++----- src/awkward/operations/ak_sum.py | 7 ++----- src/awkward/operations/ak_to_regular.py | 2 +- src/awkward/operations/ak_unflatten.py | 7 ++----- src/awkward/operations/ak_var.py | 3 +-- 41 files changed, 96 insertions(+), 165 deletions(-) diff --git a/src/awkward/_regularize.py b/src/awkward/_regularize.py index 0dca836cb6..6f78a18409 100644 --- a/src/awkward/_regularize.py +++ b/src/awkward/_regularize.py @@ -51,11 +51,19 @@ def is_non_string_like_sequence(obj) -> bool: return not isinstance(obj, (str, bytes)) and isinstance(obj, Sequence) -def regularize_axis(axis: Any) -> int | Any: +def regularize_axis(axis: Any, none_allowed: bool = True) -> int | None: """ - This function's purpose is to convert "0" to 0, "1" to 1, etc., but leave any other value as it is. + This function's main purpose is to convert [np,cp,...].array(0) to 0. """ - try: - return int(axis) - except (TypeError, ValueError): - return axis + if is_integer_like(axis): + regularized_axis = int(axis) + else: + regularized_axis = axis + cond = is_integer(regularized_axis) + msg = f"'axis' must be an integer, not {axis!r}" + if none_allowed: + cond = cond or regularized_axis is None + msg = f"'axis' must be an integer or None, not {axis!r}" + if not cond: + raise TypeError(msg) + return regularized_axis diff --git a/src/awkward/operations/ak_all.py b/src/awkward/operations/ak_all.py index 985795c858..98a22520ba 100644 --- a/src/awkward/operations/ak_all.py +++ b/src/awkward/operations/ak_all.py @@ -13,7 +13,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("all",) @@ -76,8 +76,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -93,8 +91,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.All() diff --git a/src/awkward/operations/ak_any.py b/src/awkward/operations/ak_any.py index ac71938b54..e99065d97c 100644 --- a/src/awkward/operations/ak_any.py +++ b/src/awkward/operations/ak_any.py @@ -13,7 +13,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("any",) @@ -76,8 +76,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -93,8 +91,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.Any() diff --git a/src/awkward/operations/ak_argcartesian.py b/src/awkward/operations/ak_argcartesian.py index 12deed5749..f012290cbe 100644 --- a/src/awkward/operations/ak_argcartesian.py +++ b/src/awkward/operations/ak_argcartesian.py @@ -7,7 +7,6 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis __all__ = ("argcartesian",) @@ -107,8 +106,6 @@ def argcartesian( def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs): - axis = regularize_axis(axis) - if isinstance(arrays, Mapping): index_arrays = {n: ak.local_index(x, axis) for n, x in arrays.items()} else: diff --git a/src/awkward/operations/ak_argcombinations.py b/src/awkward/operations/ak_argcombinations.py index c2b4793c95..337f77cec1 100644 --- a/src/awkward/operations/ak_argcombinations.py +++ b/src/awkward/operations/ak_argcombinations.py @@ -5,6 +5,7 @@ import awkward as ak from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext +from awkward._namedaxis import _get_named_axis, _named_axis_to_positional_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -93,8 +94,6 @@ def _impl( behavior, attrs, ): - axis = regularize_axis(axis) - if parameters is None: parameters = {} else: @@ -102,6 +101,13 @@ def _impl( if with_name is not None: parameters["__record__"] = with_name + # Handle named axis + named_axis = _get_named_axis(array) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) + + axis = regularize_axis(axis, none_allowed=False) + if axis < 0: raise ValueError("the 'axis' for argcombinations must be non-negative") else: diff --git a/src/awkward/operations/ak_argmax.py b/src/awkward/operations/ak_argmax.py index 397c7bef4f..ef9b37e57c 100644 --- a/src/awkward/operations/ak_argmax.py +++ b/src/awkward/operations/ak_argmax.py @@ -13,7 +13,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("argmax", "nanargmax") @@ -141,8 +141,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -158,8 +156,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.ArgMax() diff --git a/src/awkward/operations/ak_argmin.py b/src/awkward/operations/ak_argmin.py index f1e46fe415..6982a4d407 100644 --- a/src/awkward/operations/ak_argmin.py +++ b/src/awkward/operations/ak_argmin.py @@ -13,7 +13,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("argmin", "nanargmin") @@ -138,8 +138,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -155,8 +153,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.ArgMin() diff --git a/src/awkward/operations/ak_argsort.py b/src/awkward/operations/ak_argsort.py index 4ecfb1d2af..7c92d6a645 100644 --- a/src/awkward/operations/ak_argsort.py +++ b/src/awkward/operations/ak_argsort.py @@ -11,7 +11,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("argsort",) @@ -77,15 +77,12 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) out = ak._do.argsort(layout, axis, ascending, stable) diff --git a/src/awkward/operations/ak_cartesian.py b/src/awkward/operations/ak_cartesian.py index 6a82b59d1d..0f46f449c9 100644 --- a/src/awkward/operations/ak_cartesian.py +++ b/src/awkward/operations/ak_cartesian.py @@ -251,11 +251,10 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr # use strategy "unify" (see: awkward._namedaxis) out_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays)) - axis = regularize_axis(axis) - # Handle named axis # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(out_named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) max_ndim = max(layout.minmax_depth[1] for layout in layouts) if with_name is not None: diff --git a/src/awkward/operations/ak_combinations.py b/src/awkward/operations/ak_combinations.py index 900a3fd31c..284023f2cd 100644 --- a/src/awkward/operations/ak_combinations.py +++ b/src/awkward/operations/ak_combinations.py @@ -221,13 +221,13 @@ def _impl( with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) + if with_name is None: pass elif parameters is None: diff --git a/src/awkward/operations/ak_concatenate.py b/src/awkward/operations/ak_concatenate.py index a4155d834b..3e086f7e8c 100644 --- a/src/awkward/operations/ak_concatenate.py +++ b/src/awkward/operations/ak_concatenate.py @@ -128,12 +128,11 @@ def _impl(arrays, axis, mergebool, highlevel, behavior, attrs): ) ) - axis = regularize_axis(axis) - # Handle named axis merged_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays)) # Step 1: normalize named axis to positional axis axis = _named_axis_to_positional_axis(merged_named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) # Step 2: propagate named axis from input to output, # use strategy "unify" (see: awkward._namedaxis) out_named_axis = merged_named_axis diff --git a/src/awkward/operations/ak_corr.py b/src/awkward/operations/ak_corr.py index c468d72648..e646a43b0f 100644 --- a/src/awkward/operations/ak_corr.py +++ b/src/awkward/operations/ak_corr.py @@ -10,7 +10,7 @@ ensure_same_backend, maybe_highlevel_to_lowlevel, ) -from awkward._namedaxis import _get_named_axis +from awkward._namedaxis import _get_named_axis, _is_valid_named_axis from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -88,7 +88,10 @@ def corr( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) + if _is_valid_named_axis(axis): + raise NotImplementedError("named axis not yet supported for ak.corr") + + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( diff --git a/src/awkward/operations/ak_count.py b/src/awkward/operations/ak_count.py index 359ff3b739..f9b8c48481 100644 --- a/src/awkward/operations/ak_count.py +++ b/src/awkward/operations/ak_count.py @@ -12,7 +12,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("count",) @@ -118,8 +118,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -135,8 +133,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.Count() diff --git a/src/awkward/operations/ak_count_nonzero.py b/src/awkward/operations/ak_count_nonzero.py index ebe44241e3..74a8b23033 100644 --- a/src/awkward/operations/ak_count_nonzero.py +++ b/src/awkward/operations/ak_count_nonzero.py @@ -12,7 +12,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("count_nonzero",) @@ -77,8 +77,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -94,8 +92,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") diff --git a/src/awkward/operations/ak_covar.py b/src/awkward/operations/ak_covar.py index 06029c2405..7c8fe930fe 100644 --- a/src/awkward/operations/ak_covar.py +++ b/src/awkward/operations/ak_covar.py @@ -10,7 +10,7 @@ ensure_same_backend, maybe_highlevel_to_lowlevel, ) -from awkward._namedaxis import _get_named_axis +from awkward._namedaxis import _get_named_axis, _is_valid_named_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import regularize_axis @@ -85,7 +85,9 @@ def covar( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) + if _is_valid_named_axis(axis): + raise NotImplementedError("named axis not yet supported for ak.covar") + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( diff --git a/src/awkward/operations/ak_drop_none.py b/src/awkward/operations/ak_drop_none.py index c83ea36db5..d81770f78f 100644 --- a/src/awkward/operations/ak_drop_none.py +++ b/src/awkward/operations/ak_drop_none.py @@ -10,7 +10,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("drop_none",) @@ -72,15 +72,12 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) if axis is None: # if the outer layout is_option, drop_nones without affecting offsets diff --git a/src/awkward/operations/ak_fill_none.py b/src/awkward/operations/ak_fill_none.py index 3a008909e3..fb3dbfd019 100644 --- a/src/awkward/operations/ak_fill_none.py +++ b/src/awkward/operations/ak_fill_none.py @@ -10,7 +10,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("fill_none",) @@ -86,15 +86,12 @@ def _impl(array, value, axis, highlevel, behavior, attrs): ), ) - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) if isinstance(value_layout, ak.record.Record): value_layout = value_layout.array[value_layout.at : value_layout.at + 1] diff --git a/src/awkward/operations/ak_firsts.py b/src/awkward/operations/ak_firsts.py index 24670afbbb..79fba6eb51 100644 --- a/src/awkward/operations/ak_firsts.py +++ b/src/awkward/operations/ak_firsts.py @@ -11,7 +11,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("firsts",) @@ -64,8 +64,6 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False) - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -78,8 +76,7 @@ def _impl(array, axis, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) if maybe_posaxis(layout, axis, 1) == 0: # specialized logic; it's tested in test_0582-propagate-context-in-broadcast_and_apply.py diff --git a/src/awkward/operations/ak_flatten.py b/src/awkward/operations/ak_flatten.py index 4699cfcead..3805d28e71 100644 --- a/src/awkward/operations/ak_flatten.py +++ b/src/awkward/operations/ak_flatten.py @@ -182,12 +182,11 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) # Step 2: propagate named axis from input to output, # if axis == None: use strategy "remove all" (see: awkward._namedaxis) # if axis == 0: use strategy "keep all" (see: awkward._namedaxis) diff --git a/src/awkward/operations/ak_from_regular.py b/src/awkward/operations/ak_from_regular.py index 624297928a..9fe2800a2b 100644 --- a/src/awkward/operations/ak_from_regular.py +++ b/src/awkward/operations/ak_from_regular.py @@ -56,7 +56,7 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False) - axis = regularize_axis(axis) + axis = regularize_axis(axis, none_allowed=True) if axis is None: diff --git a/src/awkward/operations/ak_is_none.py b/src/awkward/operations/ak_is_none.py index 64acad2c65..d9a58a5478 100644 --- a/src/awkward/operations/ak_is_none.py +++ b/src/awkward/operations/ak_is_none.py @@ -11,7 +11,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("is_none",) @@ -49,15 +49,12 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) # Step 2: propagate named axis from input to output, # use strategy "keep up to" (see: awkward._namedaxis) diff --git a/src/awkward/operations/ak_linear_fit.py b/src/awkward/operations/ak_linear_fit.py index d751907427..01ac0f3297 100644 --- a/src/awkward/operations/ak_linear_fit.py +++ b/src/awkward/operations/ak_linear_fit.py @@ -7,7 +7,6 @@ from awkward._layout import HighLevelContext, ensure_same_backend from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis __all__ = ("linear_fit",) @@ -95,8 +94,6 @@ def linear_fit( def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, y_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), diff --git a/src/awkward/operations/ak_local_index.py b/src/awkward/operations/ak_local_index.py index e98d789f78..d5e7089dbc 100644 --- a/src/awkward/operations/ak_local_index.py +++ b/src/awkward/operations/ak_local_index.py @@ -11,7 +11,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("local_index",) @@ -96,15 +96,12 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) # Step 2: propagate named axis from input to output, # use strategy "keep up to" (see: awkward._namedaxis) diff --git a/src/awkward/operations/ak_max.py b/src/awkward/operations/ak_max.py index b22d981759..319b2c7bed 100644 --- a/src/awkward/operations/ak_max.py +++ b/src/awkward/operations/ak_max.py @@ -13,7 +13,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("max", "nanmax") @@ -151,8 +151,6 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -168,8 +166,7 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.Max(initial) diff --git a/src/awkward/operations/ak_mean.py b/src/awkward/operations/ak_mean.py index 354a6db4e3..a9b38ce1f0 100644 --- a/src/awkward/operations/ak_mean.py +++ b/src/awkward/operations/ak_mean.py @@ -14,11 +14,10 @@ ) from awkward._namedaxis import ( _get_named_axis, - _is_valid_named_axis, _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("mean", "nanmean") @@ -195,16 +194,12 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) - axis = regularize_axis(axis) - # Handle named axis - if named_axis := _get_named_axis(ctx): - if _is_valid_named_axis(axis): - # Step 1: Normalize named axis to positional axis - axis = _named_axis_to_positional_axis(named_axis, axis) + named_axis = _get_named_axis(ctx) + # Step 1: Normalize named axis to positional axis + axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) with np.errstate(invalid="ignore", divide="ignore"): if weight is None: diff --git a/src/awkward/operations/ak_merge_option_of_records.py b/src/awkward/operations/ak_merge_option_of_records.py index 6648e07e64..17402e77a6 100644 --- a/src/awkward/operations/ak_merge_option_of_records.py +++ b/src/awkward/operations/ak_merge_option_of_records.py @@ -56,12 +56,12 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) + # First, normalise type-invsible "index-of-records" to "record-of-index" def apply_displace_index(layout, backend, **kwargs): if (layout.is_indexed and not layout.is_option) and layout.content.is_record: diff --git a/src/awkward/operations/ak_merge_union_of_records.py b/src/awkward/operations/ak_merge_union_of_records.py index 33b5007b32..0094203947 100644 --- a/src/awkward/operations/ak_merge_union_of_records.py +++ b/src/awkward/operations/ak_merge_union_of_records.py @@ -66,12 +66,12 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=False) + def invert_record_union( tags: ArrayLike, index: ArrayLike, contents ) -> ak.contents.RecordArray: diff --git a/src/awkward/operations/ak_min.py b/src/awkward/operations/ak_min.py index 405f757e7e..1b9189f740 100644 --- a/src/awkward/operations/ak_min.py +++ b/src/awkward/operations/ak_min.py @@ -13,7 +13,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("min", "nanmin") @@ -151,8 +151,6 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -168,8 +166,7 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.Min(initial) diff --git a/src/awkward/operations/ak_moment.py b/src/awkward/operations/ak_moment.py index 46a9003c4b..2c8e29adb1 100644 --- a/src/awkward/operations/ak_moment.py +++ b/src/awkward/operations/ak_moment.py @@ -15,7 +15,6 @@ _get_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import regularize_axis from awkward._typing import Mapping __all__ = ("moment",) @@ -103,8 +102,6 @@ def _impl( behavior: Mapping | None, attrs: Mapping | None, ): - axis = regularize_axis(axis) - with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout, weight_layout = ensure_same_backend( ctx.unwrap(x, allow_record=False, primitive_policy="error"), diff --git a/src/awkward/operations/ak_num.py b/src/awkward/operations/ak_num.py index 45991f0d53..705a1e1c63 100644 --- a/src/awkward/operations/ak_num.py +++ b/src/awkward/operations/ak_num.py @@ -11,7 +11,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward._typing import Mapping from awkward.errors import AxisError @@ -32,7 +32,7 @@ def num( """ Args: array: Array-like data (anything #ak.to_layout recognizes). - axis (AxisName): The dimension at which this operation is applied. The + axis (int): The dimension at which this operation is applied. The outermost dimension is `0`, followed by `1`, etc., and negative values count backward from the innermost: `-1` is the innermost dimension, `-2` is the next level up, etc. @@ -106,8 +106,6 @@ def _impl( with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -116,8 +114,7 @@ def _impl( # use strategy "keep one" (see: awkward._namedaxis) out_named_axis = _keep_named_axis(named_axis, axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) if maybe_posaxis(layout, axis, 1) == 0: index_nplike = layout.backend.index_nplike diff --git a/src/awkward/operations/ak_pad_none.py b/src/awkward/operations/ak_pad_none.py index f3fcec07c7..17bb3035ac 100644 --- a/src/awkward/operations/ak_pad_none.py +++ b/src/awkward/operations/ak_pad_none.py @@ -10,7 +10,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("pad_none",) @@ -120,14 +120,11 @@ def _impl(array, target, axis, clip, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) out = ak._do.pad_none(layout, target, axis, clip=clip) diff --git a/src/awkward/operations/ak_prod.py b/src/awkward/operations/ak_prod.py index 5f46aaac4b..d3d1a050c3 100644 --- a/src/awkward/operations/ak_prod.py +++ b/src/awkward/operations/ak_prod.py @@ -13,7 +13,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("prod", "nanprod") @@ -128,8 +128,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -145,8 +143,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.Prod() diff --git a/src/awkward/operations/ak_ptp.py b/src/awkward/operations/ak_ptp.py index 8389d5d0e2..6d4beafbd5 100644 --- a/src/awkward/operations/ak_ptp.py +++ b/src/awkward/operations/ak_ptp.py @@ -16,7 +16,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("ptp",) @@ -91,15 +91,12 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) with np.errstate(invalid="ignore", divide="ignore"): maxi = ak.operations.ak_max._impl( diff --git a/src/awkward/operations/ak_singletons.py b/src/awkward/operations/ak_singletons.py index a98fac633a..4de6a59151 100644 --- a/src/awkward/operations/ak_singletons.py +++ b/src/awkward/operations/ak_singletons.py @@ -11,7 +11,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis from awkward.errors import AxisError __all__ = ("singletons",) @@ -64,15 +64,12 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) # Step 2: propagate named axis from input to output, # use strategy "add one" (see: awkward._namedaxis) diff --git a/src/awkward/operations/ak_softmax.py b/src/awkward/operations/ak_softmax.py index 4b9223f7c5..b2cb11bff0 100644 --- a/src/awkward/operations/ak_softmax.py +++ b/src/awkward/operations/ak_softmax.py @@ -84,12 +84,11 @@ def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout = ctx.unwrap(x, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) x = ctx.wrap(x_layout) diff --git a/src/awkward/operations/ak_sort.py b/src/awkward/operations/ak_sort.py index 0b3c812c06..0864fc5d98 100644 --- a/src/awkward/operations/ak_sort.py +++ b/src/awkward/operations/ak_sort.py @@ -11,7 +11,7 @@ _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("sort",) @@ -66,15 +66,12 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) out = ak._do.sort(layout, axis, ascending, stable) diff --git a/src/awkward/operations/ak_std.py b/src/awkward/operations/ak_std.py index a40b3c05ec..7926b341fe 100644 --- a/src/awkward/operations/ak_std.py +++ b/src/awkward/operations/ak_std.py @@ -18,7 +18,7 @@ ) from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("std", "nanstd") @@ -185,15 +185,12 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) with np.errstate(invalid="ignore", divide="ignore"): out = ufuncs.sqrt( diff --git a/src/awkward/operations/ak_sum.py b/src/awkward/operations/ak_sum.py index 1f2e82d6f9..ae6a40aef8 100644 --- a/src/awkward/operations/ak_sum.py +++ b/src/awkward/operations/ak_sum.py @@ -13,7 +13,7 @@ _remove_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._regularize import is_integer, regularize_axis +from awkward._regularize import regularize_axis __all__ = ("sum", "nansum") @@ -278,8 +278,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis @@ -295,8 +293,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): total=layout.minmax_depth[1], ) - if not is_integer(axis) and axis is not None: - raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=True) reducer = ak._reducers.Sum() diff --git a/src/awkward/operations/ak_to_regular.py b/src/awkward/operations/ak_to_regular.py index b72e48d7c5..ae9f9cc3da 100644 --- a/src/awkward/operations/ak_to_regular.py +++ b/src/awkward/operations/ak_to_regular.py @@ -66,7 +66,7 @@ def to_regular(array, axis=1, *, highlevel=True, behavior=None, attrs=None): def _impl(array, axis, highlevel, behavior, attrs): - axis = regularize_axis(axis) + axis = regularize_axis(axis, none_allowed=True) with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") diff --git a/src/awkward/operations/ak_unflatten.py b/src/awkward/operations/ak_unflatten.py index 1f5e201312..83a3b8f2b4 100644 --- a/src/awkward/operations/ak_unflatten.py +++ b/src/awkward/operations/ak_unflatten.py @@ -13,7 +13,7 @@ from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length from awkward._nplikes.typetracer import is_unknown_scalar -from awkward._regularize import is_integer, is_integer_like, regularize_axis +from awkward._regularize import is_integer_like, regularize_axis __all__ = ("unflatten",) @@ -107,15 +107,12 @@ def _impl(array, counts, axis, highlevel, behavior, attrs): ), ) - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) - if not is_integer(axis): - raise TypeError(f"'axis' must be an integer by now, not {axis!r}") + axis = regularize_axis(axis, none_allowed=False) if is_integer_like(maybe_counts_layout): # Regularize unknown values to unknown lengths diff --git a/src/awkward/operations/ak_var.py b/src/awkward/operations/ak_var.py index 4774977d31..759f5edf1c 100644 --- a/src/awkward/operations/ak_var.py +++ b/src/awkward/operations/ak_var.py @@ -190,12 +190,11 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) - axis = regularize_axis(axis) - # Handle named axis named_axis = _get_named_axis(ctx) # Step 1: Normalize named axis to positional axis axis = _named_axis_to_positional_axis(named_axis, axis) + axis = regularize_axis(axis, none_allowed=True) with np.errstate(invalid="ignore", divide="ignore"): if weight is None: