From 85cf2285ac56230cf2c79c30ffc8a4727a057dc6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 19 Oct 2025 19:48:09 +0200 Subject: [PATCH 1/4] FIX: Wrap torch.argsort to set stable=True by default --- array_api_compat/torch/_aliases.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 91161d24..7810e057 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -241,6 +241,21 @@ def sort( ) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values + +# Wrap torch.argsort to set stable=True by default +def argsort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: + + return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs) + + def _normalize_axes(axis, ndim): axes = [] if ndim == 0 and axis: From 31e65041f2605040a91f9d56ff6e7ef8fb671daf Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Sun, 19 Oct 2025 19:52:16 +0200 Subject: [PATCH 2/4] Apply suggestion from @Copilot Remove the empty line with trailing whitespace inside the function body. This line serves no purpose and should be deleted. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- array_api_compat/torch/_aliases.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 7810e057..715182a1 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -252,7 +252,6 @@ def argsort( stable: bool = True, **kwargs: object, ) -> Array: - return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs) From 1233b7bf65f382fa5153e8c9190692e67cdd89cc Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 19 Oct 2025 19:58:18 +0200 Subject: [PATCH 3/4] fix linting --- array_api_compat/torch/fft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 76342980..f11b3eb5 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import Literal -import torch +import torch # noqa: F401 import torch.fft from ._typing import Array From 1fafddae1633f1141483c21877beff8f9e9729b5 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 28 Oct 2025 13:47:24 +0100 Subject: [PATCH 4/4] added to aliases --- array_api_compat/torch/_aliases.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 715182a1..23dafde9 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -851,9 +851,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', - 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum', - 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', - 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', + 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', + 'argsort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', + 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',