Skip to content

Commit

Permalink
regularize axis makes sure now that its type is either int or None
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Oct 9, 2024
1 parent 595b240 commit 83c0aa6
Show file tree
Hide file tree
Showing 41 changed files with 96 additions and 165 deletions.
20 changes: 14 additions & 6 deletions src/awkward/_regularize.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,19 @@ def is_non_string_like_sequence(obj) -> bool:
return not isinstance(obj, (str, bytes)) and isinstance(obj, Sequence)


def regularize_axis(axis: Any) -> int | Any:
def regularize_axis(axis: Any, none_allowed: bool = True) -> int | None:
"""
This function's purpose is to convert "0" to 0, "1" to 1, etc., but leave any other value as it is.
This function's main purpose is to convert [np,cp,...].array(0) to 0.
"""
try:
return int(axis)
except (TypeError, ValueError):
return axis
if is_integer_like(axis):
regularized_axis = int(axis)
else:
regularized_axis = axis
cond = is_integer(regularized_axis)
msg = f"'axis' must be an integer, not {axis!r}"
if none_allowed:
cond = cond or regularized_axis is None
msg = f"'axis' must be an integer or None, not {axis!r}"
if not cond:
raise TypeError(msg)
return regularized_axis
7 changes: 2 additions & 5 deletions src/awkward/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_remove_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("all",)

Expand Down Expand Up @@ -76,8 +76,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
Expand All @@ -93,8 +91,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
total=layout.minmax_depth[1],
)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis, none_allowed=True)

reducer = ak._reducers.All()

Expand Down
7 changes: 2 additions & 5 deletions src/awkward/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_remove_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("any",)

Expand Down Expand Up @@ -76,8 +76,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
Expand All @@ -93,8 +91,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
total=layout.minmax_depth[1],
)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis, none_allowed=True)

reducer = ak._reducers.Any()

Expand Down
3 changes: 0 additions & 3 deletions src/awkward/operations/ak_argcartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis

__all__ = ("argcartesian",)

Expand Down Expand Up @@ -107,8 +106,6 @@ def argcartesian(


def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs):
axis = regularize_axis(axis)

if isinstance(arrays, Mapping):
index_arrays = {n: ak.local_index(x, axis) for n, x in arrays.items()}
else:
Expand Down
10 changes: 8 additions & 2 deletions src/awkward/operations/ak_argcombinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _get_named_axis, _named_axis_to_positional_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis

Expand Down Expand Up @@ -93,15 +94,20 @@ def _impl(
behavior,
attrs,
):
axis = regularize_axis(axis)

if parameters is None:
parameters = {}
else:
parameters = dict(parameters)
if with_name is not None:
parameters["__record__"] = with_name

# Handle named axis
named_axis = _get_named_axis(array)
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(named_axis, axis)

axis = regularize_axis(axis, none_allowed=False)

if axis < 0:
raise ValueError("the 'axis' for argcombinations must be non-negative")
else:
Expand Down
7 changes: 2 additions & 5 deletions src/awkward/operations/ak_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_remove_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("argmax", "nanargmax")

Expand Down Expand Up @@ -141,8 +141,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
Expand All @@ -158,8 +156,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
total=layout.minmax_depth[1],
)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis, none_allowed=True)

reducer = ak._reducers.ArgMax()

Expand Down
7 changes: 2 additions & 5 deletions src/awkward/operations/ak_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_remove_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("argmin", "nanargmin")

Expand Down Expand Up @@ -138,8 +138,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
Expand All @@ -155,8 +153,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
total=layout.minmax_depth[1],
)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis, none_allowed=True)

reducer = ak._reducers.ArgMin()

Expand Down
7 changes: 2 additions & 5 deletions src/awkward/operations/ak_argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
_named_axis_to_positional_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("argsort",)

Expand Down Expand Up @@ -77,15 +77,12 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(named_axis, axis)

if not is_integer(axis):
raise TypeError(f"'axis' must be an integer by now, not {axis!r}")
axis = regularize_axis(axis, none_allowed=False)

out = ak._do.argsort(layout, axis, ascending, stable)

Expand Down
3 changes: 1 addition & 2 deletions src/awkward/operations/ak_cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,10 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
# use strategy "unify" (see: awkward._namedaxis)
out_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays))

axis = regularize_axis(axis)

# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(out_named_axis, axis)
axis = regularize_axis(axis, none_allowed=False)
max_ndim = max(layout.minmax_depth[1] for layout in layouts)

if with_name is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/operations/ak_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,13 @@ def _impl(
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(named_axis, axis)

axis = regularize_axis(axis, none_allowed=False)

if with_name is None:
pass
elif parameters is None:
Expand Down
3 changes: 1 addition & 2 deletions src/awkward/operations/ak_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,11 @@ def _impl(arrays, axis, mergebool, highlevel, behavior, attrs):
)
)

axis = regularize_axis(axis)

# Handle named axis
merged_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays))
# Step 1: normalize named axis to positional axis
axis = _named_axis_to_positional_axis(merged_named_axis, axis)
axis = regularize_axis(axis, none_allowed=False)
# Step 2: propagate named axis from input to output,
# use strategy "unify" (see: awkward._namedaxis)
out_named_axis = merged_named_axis
Expand Down
7 changes: 5 additions & 2 deletions src/awkward/operations/ak_corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ensure_same_backend,
maybe_highlevel_to_lowlevel,
)
from awkward._namedaxis import _get_named_axis
from awkward._namedaxis import _get_named_axis, _is_valid_named_axis
from awkward._nplikes import ufuncs
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
Expand Down Expand Up @@ -88,7 +88,10 @@ def corr(


def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
axis = regularize_axis(axis)
if _is_valid_named_axis(axis):
raise NotImplementedError("named axis not yet supported for ak.corr")

axis = regularize_axis(axis, none_allowed=True)

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, y_layout, weight_layout = ensure_same_backend(
Expand Down
7 changes: 2 additions & 5 deletions src/awkward/operations/ak_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
_remove_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("count",)

Expand Down Expand Up @@ -118,8 +118,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
Expand All @@ -135,8 +133,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
total=layout.minmax_depth[1],
)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis, none_allowed=True)

reducer = ak._reducers.Count()

Expand Down
7 changes: 2 additions & 5 deletions src/awkward/operations/ak_count_nonzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
_remove_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("count_nonzero",)

Expand Down Expand Up @@ -77,8 +77,6 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
Expand All @@ -94,8 +92,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
total=layout.minmax_depth[1],
)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis, none_allowed=True)

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
Expand Down
6 changes: 4 additions & 2 deletions src/awkward/operations/ak_covar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ensure_same_backend,
maybe_highlevel_to_lowlevel,
)
from awkward._namedaxis import _get_named_axis
from awkward._namedaxis import _get_named_axis, _is_valid_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis

Expand Down Expand Up @@ -85,7 +85,9 @@ def covar(


def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
axis = regularize_axis(axis)
if _is_valid_named_axis(axis):
raise NotImplementedError("named axis not yet supported for ak.covar")
axis = regularize_axis(axis, none_allowed=True)

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, y_layout, weight_layout = ensure_same_backend(
Expand Down
7 changes: 2 additions & 5 deletions src/awkward/operations/ak_drop_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
_named_axis_to_positional_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis
from awkward.errors import AxisError

__all__ = ("drop_none",)
Expand Down Expand Up @@ -72,15 +72,12 @@ def _impl(array, axis, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

axis = regularize_axis(axis)

# Handle named axis
named_axis = _get_named_axis(ctx)
# Step 1: Normalize named axis to positional axis
axis = _named_axis_to_positional_axis(named_axis, axis)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis, none_allowed=True)

if axis is None:
# if the outer layout is_option, drop_nones without affecting offsets
Expand Down
Loading

0 comments on commit 83c0aa6

Please sign in to comment.