From 6010819f457e26c93f8cf347907013aed2884120 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 10 Nov 2023 13:47:45 -0800 Subject: [PATCH] fix sorts --- jax/experimental/array_api/_sorting_functions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/jax/experimental/array_api/_sorting_functions.py b/jax/experimental/array_api/_sorting_functions.py index d5c9e4be2506..139593f203cf 100644 --- a/jax/experimental/array_api/_sorting_functions.py +++ b/jax/experimental/array_api/_sorting_functions.py @@ -20,16 +20,17 @@ def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: """Returns the indices that sort an array x along a specified axis.""" del stable # unused - result = jax.numpy.argsort(x, axis=axis) if descending: - return jax.lax.rev(result, dimensions=[axis + x.size if axis < 0 else axis]) - return result + return jax.numpy.argsort(-x, axis=axis) + else: + return jax.numpy.argsort(x, axis=axis) def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: """Returns a sorted copy of an input array x.""" + del stable # unused result = jax.numpy.sort(x, axis=axis) if descending: - return jax.lax.rev(result, dimensions=[axis + x.size if axis < 0 else axis]) + return jax.lax.rev(result, dimensions=[axis + x.ndim if axis < 0 else axis]) return result