Skip to content

Commit

Permalink
fix sorts
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 10, 2023
1 parent 331ce22 commit 6010819
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions jax/experimental/array_api/_sorting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ def argsort(x: Array, /, *, axis: int = -1, descending: bool = False,
stable: bool = True) -> Array:
"""Returns the indices that sort an array x along a specified axis."""
del stable # unused
result = jax.numpy.argsort(x, axis=axis)
if descending:
return jax.lax.rev(result, dimensions=[axis + x.size if axis < 0 else axis])
return result
return jax.numpy.argsort(-x, axis=axis)
else:
return jax.numpy.argsort(x, axis=axis)


def sort(x: Array, /, *, axis: int = -1, descending: bool = False,
stable: bool = True) -> Array:
"""Returns a sorted copy of an input array x."""
del stable # unused
result = jax.numpy.sort(x, axis=axis)
if descending:
return jax.lax.rev(result, dimensions=[axis + x.size if axis < 0 else axis])
return jax.lax.rev(result, dimensions=[axis + x.ndim if axis < 0 else axis])
return result

0 comments on commit 6010819

Please sign in to comment.