Looking for JAX? See SoftJAX.
SoftTorch provides soft differentiable drop-in replacements for traditionally non-differentiable functions in PyTorch, including
- elementwise operators: abs, relu, clamp, sign, round and heaviside;
- tensor-valued operators: (arg)max, (arg)min, (arg)quantile, (arg)median, (arg)sort, (arg)topk and rank;
- comparison operators such as greater, eq or isclose;
- logical operators such as logical_and, all or any;
- functions for selection with indices such as where, take_along_dim or index_select.
All operators offer multiple modes (controlling smoothness or boundedness of the relaxation) and adjustable softening strength.
All operators also support straight-through estimation, using the non-differentiable function in the forward pass and the soft relaxation in the backward pass.
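The sketch below shows the straight-through idea for a single scalar heaviside step, in plain Python. It is not SoftTorch's implementation. The sigmoid relaxation, the softness tau = 0.1 and the function name heaviside_st are assumptions chosen for illustration. The forward value comes from the exact step, and the derivative reported for the backward pass comes from the soft relaxation.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def heaviside_st(x: float, tau: float = 0.1) -> tuple[float, float]:
    """Hypothetical straight-through heaviside: returns (forward value, backward derivative).

    Assumption: the soft relaxation is sigmoid(x / tau).
    """
    hard = 1.0 if x > 0 else 0.0       # forward pass: exact, non-differentiable step
    s = sigmoid(x / tau)               # soft relaxation of the step
    soft_grad = s * (1.0 - s) / tau    # d/dx sigmoid(x / tau), used in the backward pass
    return hard, soft_grad


for x in (-0.2, 0.3):
    value, grad = heaviside_st(x)
    print(f"x={x:+.1f}  forward={value}  backward={grad:.4f}")
```

Autograd frameworks usually get the same effect with hard + (soft - soft.detach()). The forward value equals hard, while gradients flow only through soft.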
SoftTorch functions are drop-in replacements for their non-differentiable PyTorch counterparts. Functions operating on indices need special care. Discrete indices are relaxed into distributions over indices, and this changes the shape of returned and accepted values. For example, st.argmax applied to a vector of length 4 returns 4 weights rather than a single integer index.
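The plain-Python sketch below shows what relaxing an index into a distribution means for shapes. The helpers one_hot, select and to_index are hypothetical and not part of SoftTorch. The weight vectors are copied from the argmax outputs shown later on this page, which use x = [-0.2, -1.0, 0.3, 1.0].

```python
def one_hot(index: int, n: int) -> list[float]:
    """Turn an integer index into a length-n distribution with all mass on that index."""
    return [1.0 if i == index else 0.0 for i in range(n)]


def select(weights: list[float], values: list[float]) -> float:
    """Replace values[index] by a weighted sum over all positions."""
    return sum(w * v for w, v in zip(weights, values))


def to_index(weights: list[float]) -> int:
    """Recover a discrete index as the position with the largest weight."""
    return max(range(len(weights)), key=weights.__getitem__)


x = [-0.2, -1.0, 0.3, 1.0]
hard = one_hot(3, len(x))                 # same as st.argmax(x, mode="hard") shown later
soft = [0.0215, 0.0022, 0.1176, 0.8586]   # st.argmax(x) in the default soft mode, shown later

print(select(hard, x))   # exactly x[3]
print(select(soft, x))   # a convex combination of x, dominated by x[3]
print(to_index(soft))    # 3
```

The soft weighted sum is about 0.8874. This matches the soft-mode st.max(x) value printed further down.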
Requires Python 3.12+.
pip install softtorch
Documentation is available at https://a-paulus.github.io/softtorch/.
import torch
import softtorch as st
x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
# Elementwise functions
print("\nTorch absolute:", torch.abs(x))
print("SoftTorch absolute (hard mode):", st.abs(x, mode="hard"))
print("SoftTorch absolute (soft mode):", st.abs(x))
print("\nTorch clamp:", torch.clamp(x, -0.5, 0.5))
print("SoftTorch clamp (hard mode):", st.clamp(x, -0.5, 0.5, mode="hard"))
print("SoftTorch clamp (soft mode):", st.clamp(x, -0.5, 0.5))
print("\nTorch heaviside:", torch.heaviside(x, torch.tensor(0.5)))
print("SoftTorch heaviside (hard mode):", st.heaviside(x, mode="hard"))
print("SoftTorch heaviside (soft mode):", st.heaviside(x))
print("\nTorch ReLU:", torch.nn.functional.relu(x))
print("SoftTorch ReLU (hard mode):", st.relu(x, mode="hard"))
print("SoftTorch ReLU (soft mode):", st.relu(x))
print("\nTorch round:", torch.round(x))
print("SoftTorch round (hard mode):", st.round(x, mode="hard"))
print("SoftTorch round (soft mode):", st.round(x))
print("\nTorch sign:", torch.sign(x))
print("SoftTorch sign (hard mode):", st.sign(x, mode="hard"))
print("SoftTorch sign (soft mode):", st.sign(x))
Torch absolute: tensor([0.2000, 1.0000, 0.3000, 1.0000])
SoftTorch absolute (hard mode): tensor([0.2000, 1.0000, 0.3000, 1.0000])
SoftTorch absolute (soft mode): tensor([0.1523, 0.9999, 0.2715, 0.9999])
Torch clamp: tensor([-0.2000, -0.5000, 0.3000, 0.5000])
SoftTorch clamp (hard mode): tensor([-0.2000, -0.5000, 0.3000, 0.5000])
SoftTorch clamp (soft mode): tensor([-0.1952, -0.4993, 0.2873, 0.4993])
Torch heaviside: tensor([0., 0., 1., 1.])
SoftTorch heaviside (hard mode): tensor([0., 0., 1., 1.])
SoftTorch heaviside (soft mode): tensor([0.1192, 0.0000, 0.9526, 1.0000])
Torch ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000])
SoftTorch ReLU (hard mode): tensor([0.0000, 0.0000, 0.3000, 1.0000])
SoftTorch ReLU (soft mode): tensor([0.0127, 0.0000, 0.3049, 1.0000])
Torch round: tensor([-0., -1., 0., 1.])
SoftTorch round (hard mode): tensor([-0., -1., 0., 1.])
SoftTorch round (soft mode): tensor([-0.0465, -1.0000, 0.1189, 1.0000])
Torch sign: tensor([-1., -1., 1., 1.])
SoftTorch sign (hard mode): tensor([-1., -1., 1., 1.])
SoftTorch sign (soft mode): tensor([-0.7616, -0.9999, 0.9051, 0.9999])
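Most of the default soft-mode numbers above can be reproduced to the printed four decimals by a few sigmoid-based formulas with softness tau = 0.1. This is inferred from the printed output, not taken from SoftTorch's documented implementation. Other modes will differ, and soft round is not covered here. The plain-Python sketch below spells out the formulas.

```python
import math

TAU = 0.1  # assumed softness, inferred from the printed soft-mode outputs


def soft_heaviside(x: float, tau: float = TAU) -> float:
    return 1.0 / (1.0 + math.exp(-x / tau))        # sigmoid(x / tau)


def soft_sign(x: float, tau: float = TAU) -> float:
    return 2.0 * soft_heaviside(x, tau) - 1.0       # equals tanh(x / (2 * tau))


def soft_abs(x: float, tau: float = TAU) -> float:
    return x * soft_sign(x, tau)


def soft_relu(x: float, tau: float = TAU) -> float:
    # tau * softplus(x / tau); log1p keeps small values accurate
    return tau * math.log1p(math.exp(x / tau))


def soft_clamp(x: float, lo: float, hi: float, tau: float = TAU) -> float:
    # Uses clamp(x, lo, hi) = lo + relu(x - lo) - relu(x - hi), with relu softened
    return lo + soft_relu(x - lo, tau) - soft_relu(x - hi, tau)


x = [-0.2, -1.0, 0.3, 1.0]
print([round(soft_heaviside(v), 4) for v in x])
print([round(soft_sign(v), 4) for v in x])
print([round(soft_abs(v), 4) for v in x])
print([round(soft_relu(v), 4) for v in x])
print([round(soft_clamp(v, -0.5, 0.5), 4) for v in x])
```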
# Tensor-valued operators
print("\nTorch max:", torch.max(x))
print("SoftTorch max (hard mode):", st.max(x, mode="hard"))
print("SoftTorch max (soft mode):", st.max(x))
print("\nTorch min:", torch.min(x))
print("SoftTorch min (hard mode):", st.min(x, mode="hard"))
print("SoftTorch min (soft mode):", st.min(x))
print("\nTorch sort:", torch.sort(x).values)
print("SoftTorch sort (hard mode):", st.sort(x, mode="hard").values)
print("SoftTorch sort (soft mode):", st.sort(x).values)
print("\nTorch quantile:", torch.quantile(x, q=0.2))
print("SoftTorch quantile (hard mode):", st.quantile(x, q=0.2, mode="hard"))
print("SoftTorch quantile (soft mode):", st.quantile(x, q=0.2))
print("\nTorch median:", torch.median(x))
print("SoftTorch median (hard mode):", st.median(x, mode="hard"))
print("SoftTorch median (soft mode):", st.median(x))
print("\nTorch topk:", torch.topk(x, k=3).values)
print("SoftTorch topk (hard mode):", st.topk(x, k=3, mode="hard").values)
print("SoftTorch topk (soft mode):", st.topk(x, k=3).values)
# Note: torch.argsort(torch.argsort(x)) gives 0-based ranks, while st.rank in the output below is 1-based
print("\nTorch rank:", torch.argsort(torch.argsort(x)))
print("SoftTorch rank (hard mode):", st.rank(x, mode="hard", descending=False))
print("SoftTorch rank (soft mode):", st.rank(x, descending=False))
Torch max: tensor(1.)
SoftTorch max (hard mode): tensor(1.)
SoftTorch max (soft mode): tensor(0.8874)
Torch min: tensor(-1.)
SoftTorch min (hard mode): tensor(-1.)
SoftTorch min (soft mode): tensor(-0.8996)
Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000])
SoftTorch sort (hard mode): tensor([-1.0000, -0.2000, 0.3000, 1.0000])
SoftTorch sort (soft mode): tensor([-0.8792, -0.1641, 0.2767, 0.8738])
Torch quantile: tensor(-0.5200)
SoftTorch quantile (hard mode): tensor(-0.5200)
SoftTorch quantile (soft mode): tensor(-0.4501)
Torch median: tensor(-0.2000)
SoftTorch median (hard mode): tensor(-0.2000)
SoftTorch median (soft mode): tensor(-0.1641)
Torch topk: tensor([ 1.0000, 0.3000, -0.2000])
SoftTorch topk (hard mode): tensor([ 1.0000, 0.3000, -0.2000])
SoftTorch topk (soft mode): tensor([ 0.8738, 0.2767, -0.1641])
Torch rank: tensor([1, 0, 2, 3])
SoftTorch rank (hard mode): tensor([2., 1., 3., 4.])
SoftTorch rank (soft mode): tensor([1.9950, 1.0548, 3.0239, 3.9228])
# Sort: sweep over methods
print("\nTorch sort:", torch.sort(x).values)
print("SoftTorch sort (softsort):", st.sort(x, method="softsort", softness=0.1).values)
print("SoftTorch sort (neuralsort):", st.sort(x, method="neuralsort", softness=0.1).values)
print("SoftTorch sort (fast_soft_sort):", st.sort(x, method="fast_soft_sort", softness=2.0).values)
print("SoftTorch sort (ot):", st.sort(x, method="ot", softness=0.1).values)
print("SoftTorch sort (sorting_network):", st.sort(x, method="sorting_network", softness=0.1).values)
# Sort: sweep over modes
print("\nTorch sort:", torch.sort(x).values)
for mode in ["hard", "smooth", "c0", "c1", "c2"]:
    print(f"SoftTorch sort ({mode}):", st.sort(x, softness=0.5, mode=mode).values)
Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000])
SoftTorch sort (softsort): tensor([-0.8996, -0.1705, 0.2847, 0.8874])
SoftTorch sort (neuralsort): tensor([-0.8792, -0.1641, 0.2767, 0.8738])
SoftTorch sort (fast_soft_sort): tensor([-0.7462, -0.1971, 0.2938, 0.8569])
SoftTorch sort (ot): tensor([-0.7324, -0.2396, 0.3286, 0.7434])
SoftTorch sort (sorting_network): tensor([-0.7999, -0.2672, 0.3847, 0.7863])
Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000])
SoftTorch sort (hard): tensor([-1.0000, -0.2000, 0.3000, 1.0000])
SoftTorch sort (smooth): tensor([-0.6057, -0.1997, 0.2729, 0.6281])
SoftTorch sort (c0): tensor([-1.0000, -0.6313, 0.6525, 0.9824])
SoftTorch sort (c1): tensor([-0.9982, -0.5432, 0.5814, 0.9837])
SoftTorch sort (c2): tensor([-0.9978, -0.4905, 0.5425, 0.9903])
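As a reference point for the method sweep, here is a plain-Python sketch of the relaxation described in the SoftSort paper listed among the algorithms below. Each row is a softmax over the negative absolute distances between one sorted score and all unsorted scores. The sketch uses ascending order to match this page. How SoftTorch maps its softness parameter onto tau is not shown here, so the numbers are not expected to match st.sort(method="softsort") exactly.

```python
import math


def softsort_matrix(s: list[float], tau: float) -> list[list[float]]:
    """Row i is a softmax over positions j, peaked at the i-th smallest element.

    Sketch of the SoftSort relaxation with d = |.|, written for ascending order.
    """
    matrix = []
    for target in sorted(s):
        logits = [-abs(target - v) / tau for v in s]
        m = max(logits)                              # subtract max for numerical stability
        exps = [math.exp(l - m) for l in logits]
        total = sum(exps)
        matrix.append([e / total for e in exps])
    return matrix


def soft_sort_values(s: list[float], tau: float) -> list[float]:
    """Soft sorted values: each row of the relaxed permutation matrix applied to s."""
    return [sum(p * v for p, v in zip(row, s)) for row in softsort_matrix(s, tau)]


x = [-0.2, -1.0, 0.3, 1.0]
for tau in (0.01, 0.1, 1.0):
    print(tau, [round(v, 4) for v in soft_sort_values(x, tau)])
```

As tau shrinks, each row approaches a one-hot vector and the values approach torch.sort. As tau grows, the rows become nearly uniform and all values drift toward the mean of x.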
# Operators returning indices
print("\nTorch argmax:", torch.argmax(x))
print("SoftTorch argmax (hard mode):", st.argmax(x, mode="hard"))
print("SoftTorch argmax (soft mode):", st.argmax(x))
print("\nTorch argmin:", torch.argmin(x))
print("SoftTorch argmin (hard mode):", st.argmin(x, mode="hard"))
print("SoftTorch argmin (soft mode):", st.argmin(x))
print("\nTorch argquantile:", "Not implemented in standard PyTorch")
print("SoftTorch argquantile (hard mode):", st.argquantile(x, q=0.2, mode="hard"))
print("SoftTorch argquantile (soft mode):", st.argquantile(x, q=0.2))
print("\nTorch argmedian:", torch.median(x, dim=0).indices)
print("SoftTorch argmedian (hard mode):", st.median(x, mode="hard", dim=0).indices)
print("SoftTorch argmedian (soft mode):", st.median(x, dim=0).indices)
print("\nTorch argsort:", torch.argsort(x))
print("SoftTorch argsort (hard mode):", st.argsort(x, mode="hard"))
print("SoftTorch argsort (soft mode):", st.argsort(x))
print("\nTorch argtopk:", torch.topk(x, k=3).indices)
print("SoftTorch argtopk (hard mode):", st.topk(x, k=3, mode="hard").indices)
print("SoftTorch argtopk (soft mode):", st.topk(x, k=3).indices)
Torch argmax: tensor(3)
SoftTorch argmax (hard mode): tensor([0., 0., 0., 1.])
SoftTorch argmax (soft mode): tensor([0.0215, 0.0022, 0.1176, 0.8586])
Torch argmin: tensor(1)
SoftTorch argmin (hard mode): tensor([0., 1., 0., 0.])
SoftTorch argmin (soft mode): tensor([0.0922, 0.8885, 0.0169, 0.0023])
Torch argquantile: Not implemented in standard PyTorch
SoftTorch argquantile (hard mode): tensor([0.6000, 0.4000, 0.0000, 0.0000])
SoftTorch argquantile (soft mode): tensor([0.5403, 0.3693, 0.0902, 0.0001])
Torch argmedian: tensor(0)
SoftTorch argmedian (hard mode): tensor([1., 0., 0., 0.])
SoftTorch argmedian (soft mode): tensor([0.8009, 0.0491, 0.1498, 0.0002])
Torch argsort: tensor([1, 0, 2, 3])
SoftTorch argsort (hard mode): tensor([[0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
SoftTorch argsort (soft mode): tensor([[0.1494, 0.8496, 0.0009, 0.0000],
        [0.8009, 0.0491, 0.1498, 0.0002],
        [0.1418, 0.0001, 0.7899, 0.0681],
        [0.0011, 0.0000, 0.1784, 0.8205]])
Torch argtopk: tensor([3, 2, 0])
SoftTorch argtopk (hard mode): tensor([[0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.]])
SoftTorch argtopk (soft mode): tensor([[0.0011, 0.0000, 0.1784, 0.8205],
        [0.1418, 0.0001, 0.7899, 0.0681],
        [0.8009, 0.0491, 0.1498, 0.0002]])
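The index outputs and the value outputs on this page are linked. In each case shown, the value equals a weighted sum of x using the printed index weights. The plain-Python check below copies the printed numbers. It restates what these outputs show, and assumes (without confirmation from the documentation) that the relation holds in general.

```python
x = [-0.2, -1.0, 0.3, 1.0]


def weighted(weights: list[float], values: list[float]) -> float:
    """Weighted sum that replaces indexing with a relaxed index."""
    return sum(w * v for w, v in zip(weights, values))


# Weights copied from the outputs above
argmax_soft = [0.0215, 0.0022, 0.1176, 0.8586]
argquantile_hard = [0.6000, 0.4000, 0.0000, 0.0000]
argsort_soft = [
    [0.1494, 0.8496, 0.0009, 0.0000],
    [0.8009, 0.0491, 0.1498, 0.0002],
    [0.1418, 0.0001, 0.7899, 0.0681],
    [0.0011, 0.0000, 0.1784, 0.8205],
]

print(round(weighted(argmax_soft, x), 4))                    # compare with SoftTorch max (soft mode)
print(round(weighted(argquantile_hard, x), 4))               # compare with Torch quantile
print([round(weighted(row, x), 4) for row in argsort_soft])  # compare with SoftTorch sort (soft mode)
```

The hard argquantile weights put 0.4 on -1.0 and 0.6 on -0.2. That is the linear interpolation torch.quantile uses for q = 0.2 over 4 elements. The soft argmedian weights equal the second argsort row, which is consistent with torch.median returning the lower of the two middle values.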
y = torch.tensor([0.2, -0.5, 0.5, -1.0])
# Comparison operators
print("\nTorch greater:", torch.greater(x, y))
print("SoftTorch greater (hard mode):", st.greater(x, y, mode="hard"))
print("SoftTorch greater (soft mode):", st.greater(x, y))
print("\nTorch greater equal:", torch.greater_equal(x, y))
print("SoftTorch greater equal (hard mode):", st.greater_equal(x, y, mode="hard"))
print("SoftTorch greater equal (soft mode):", st.greater_equal(x, y))
print("\nTorch less:", torch.less(x, y))
print("SoftTorch less (hard mode):", st.less(x, y, mode="hard"))
print("SoftTorch less (soft mode):", st.less(x, y))
print("\nTorch less equal:", torch.less_equal(x, y))
print("SoftTorch less equal (hard mode):", st.less_equal(x, y, mode="hard"))
print("SoftTorch less equal (soft mode):", st.less_equal(x, y))
print("\nTorch eq:", torch.eq(x, y))
print("SoftTorch eq (hard mode):", st.eq(x, y, mode="hard"))
print("SoftTorch eq (soft mode):", st.eq(x, y))
print("\nTorch not equal:", torch.not_equal(x, y))
print("SoftTorch not equal (hard mode):", st.not_equal(x, y, mode="hard"))
print("SoftTorch not equal (soft mode):", st.not_equal(x, y))
print("\nTorch isclose:", torch.isclose(x, y))
print("SoftTorch isclose (hard mode):", st.isclose(x, y, mode="hard"))
print("SoftTorch isclose (soft mode):", st.isclose(x, y))
Torch greater: tensor([False, False, False, True])
SoftTorch greater (hard mode): tensor([0., 0., 0., 1.])
SoftTorch greater (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000])
Torch greater equal: tensor([False, False, False, True])
SoftTorch greater equal (hard mode): tensor([0., 0., 0., 1.])
SoftTorch greater equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000])
Torch less: tensor([ True, True, True, False])
SoftTorch less (hard mode): tensor([1., 1., 1., 0.])
SoftTorch less (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000])
Torch less equal: tensor([ True, True, True, False])
SoftTorch less equal (hard mode): tensor([1., 1., 1., 0.])
SoftTorch less equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000])
Torch eq: tensor([False, False, False, False])
SoftTorch eq (hard mode): tensor([0., 0., 0., 0.])
SoftTorch eq (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000])
Torch not equal: tensor([True, True, True, True])
SoftTorch not equal (hard mode): tensor([1., 1., 1., 1.])
SoftTorch not equal (soft mode): tensor([0.9586, 0.9857, 0.6420, 1.0000])
Torch isclose: tensor([False, False, False, False])
SoftTorch isclose (hard mode): tensor([0., 0., 0., 0.])
SoftTorch isclose (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000])
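The soft greater and less outputs above are reproduced to four decimals by a logistic function of the difference x - y with softness tau = 0.1. This is inferred from the printed numbers rather than taken from SoftTorch's documentation. The soft eq and isclose values follow a different relaxation that this sketch does not attempt to reproduce.

```python
import math

TAU = 0.1  # assumed softness, inferred from the printed outputs


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def soft_greater(a: float, b: float, tau: float = TAU) -> float:
    # Near 1 when a >> b, near 0 when a << b, exactly 0.5 when a == b
    return sigmoid((a - b) / tau)


def soft_less(a: float, b: float, tau: float = TAU) -> float:
    return 1.0 - soft_greater(a, b, tau)   # equivalently sigmoid((b - a) / tau)


x = [-0.2, -1.0, 0.3, 1.0]
y = [0.2, -0.5, 0.5, -1.0]
print([round(soft_greater(a, b), 4) for a, b in zip(x, y)])
print([round(soft_less(a, b), 4) for a, b in zip(x, y)])
```

Under this relaxation, greater and greater_equal give the same value, which matches their identical soft outputs above. In hard mode the two differ only at ties.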
# Logical operators
fuzzy_a = torch.tensor([0.1, 0.2, 0.8, 1.0])
fuzzy_b = torch.tensor([0.7, 0.3, 0.1, 0.9])
bool_a = fuzzy_a >= 0.5
bool_b = fuzzy_b >= 0.5
print("\nTorch AND:", torch.logical_and(bool_a, bool_b))
print("SoftTorch AND:", st.logical_and(fuzzy_a, fuzzy_b))
print("\nTorch OR:", torch.logical_or(bool_a, bool_b))
print("SoftTorch OR:", st.logical_or(fuzzy_a, fuzzy_b))
print("\nTorch NOT:", torch.logical_not(bool_a))
print("SoftTorch NOT:", st.logical_not(fuzzy_a))
print("\nTorch XOR:", torch.logical_xor(bool_a, bool_b))
print("SoftTorch XOR:", st.logical_xor(fuzzy_a, fuzzy_b))
print("\nTorch ALL:", torch.all(bool_a))
print("SoftTorch ALL:", st.all(fuzzy_a))
print("\nTorch ANY:", torch.any(bool_a))
print("SoftTorch ANY:", st.any(fuzzy_a))
# Selection operators
print("\nTorch Where:", torch.where(bool_a, x, y))
print("SoftTorch Where:", st.where(fuzzy_a, x, y))
Torch AND: tensor([False, False, False, True])
SoftTorch AND: tensor([0.0700, 0.0600, 0.0800, 0.9000])
Torch OR: tensor([ True, False, True, True])
SoftTorch OR: tensor([0.7300, 0.4400, 0.8200, 1.0000])
Torch NOT: tensor([ True, True, False, False])
SoftTorch NOT: tensor([0.9000, 0.8000, 0.2000, 0.0000])
Torch XOR: tensor([ True, False, True, False])
SoftTorch XOR: tensor([0.6411, 0.3464, 0.7256, 0.1000])
Torch ALL: tensor(False)
SoftTorch ALL: tensor(0.0160)
Torch ANY: tensor(True)
SoftTorch ANY: tensor(1.)
Torch Where: tensor([ 0.2000, -0.5000, 0.3000, 1.0000])
SoftTorch Where: tensor([ 0.1600, -0.6000, 0.3400, 1.0000])
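All the soft logical and selection outputs above match product (probabilistic) fuzzy logic, in which each input is read as a probability of being true. This correspondence is read off the printed numbers, and SoftTorch may offer other modes. A plain-Python sketch:

```python
import math


def soft_and(a: float, b: float) -> float:
    return a * b


def soft_or(a: float, b: float) -> float:
    return a + b - a * b


def soft_not(a: float) -> float:
    return 1.0 - a


def soft_xor(a: float, b: float) -> float:
    # (a AND NOT b) OR (NOT a AND b), built from the operators above
    return soft_or(soft_and(a, soft_not(b)), soft_and(soft_not(a), b))


def soft_all(values: list[float]) -> float:
    return math.prod(values)


def soft_any(values: list[float]) -> float:
    # NOT (all values are false)
    return 1.0 - math.prod(1.0 - v for v in values)


def soft_where(cond: float, a: float, b: float) -> float:
    # Blend the two branches by the soft condition instead of picking one
    return cond * a + (1.0 - cond) * b


fuzzy_a = [0.1, 0.2, 0.8, 1.0]
fuzzy_b = [0.7, 0.3, 0.1, 0.9]
x = [-0.2, -1.0, 0.3, 1.0]
y = [0.2, -0.5, 0.5, -1.0]

print([round(soft_xor(a, b), 4) for a, b in zip(fuzzy_a, fuzzy_b)])
print(round(soft_all(fuzzy_a), 4), round(soft_any(fuzzy_a), 4))
print([round(soft_where(c, a, b), 4) for c, a, b in zip(fuzzy_a, x, y)])
```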
# Straight-through operators: hard function in the forward pass, soft relaxation in the backward pass
print("Straight-through ReLU:", st.relu_st(x))
print("Straight-through sort:", st.sort_st(x).values)
print("Straight-through argtopk:", st.topk_st(x, k=3).indices)
print("Straight-through greater:", st.greater_st(x, y))
# And many more...
Straight-through ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000])
Straight-through sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000])
Straight-through argtopk: tensor([[0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.]])
Straight-through greater: tensor([0., 0., 0., 1.])
If this library helped your academic work, please consider citing:
@article{paulus2026softjax,
title={{SoftJAX} \& {SoftTorch}: Empowering Automatic Differentiation Libraries with Informative Gradients},
author={Paulus, Anselm and Geist, A.\ Ren\'e and Musil, V\'it and Hoffmann, Sebastian and Beker, Onur and Martius, Georg},
journal={arXiv preprint},
year={2026},
eprint={2603.08824}
}
Also consider starring the project on GitHub!
Special thanks and credit go to Patrick Kidger for the awesome JAX repositories that served as the basis for the documentation of this project.
This project is still relatively young. If you have any suggestions for improvement or other feedback, please reach out or raise a GitHub issue!
Differentiable sorting, top-k and rank
DiffSort: Differentiable sorting networks in PyTorch.
DiffTopK: Differentiable top-k in PyTorch.
FastSoftSort: Fast differentiable sorting and ranking in JAX.
Differentiable Top-k with Optimal Transport in JAX.
SoftSort: Differentiable argsort in PyTorch and TensorFlow.
Other
DiffLogic: Differentiable logic gate networks in PyTorch.
SmoothOT: Smooth and Sparse Optimal Transport.
JaxOpt: Differentiable optimization in JAX.
SoftTorch builds on or implements various algorithms for, e.g., differentiable top-k, sorting and ranking, including:
Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application
Differentiable Ranks and Sorting using Optimal Transport
Differentiable Top-k with Optimal Transport
SoftSort: A Continuous Relaxation for the argsort Operator
Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances
Smooth and Sparse Optimal Transport
Smooth Approximations of the Rounding Function
Fast Differentiable Sorting and Ranking
Differentiable Sorting Networks for Scalable Sorting and Ranking Supervision
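The first listed paper gives a short sort-based algorithm for the Euclidean projection of a vector onto the probability simplex. Below is a plain-Python sketch of that algorithm. Whether and where SoftTorch calls it internally is not stated here, so treat it only as an illustration of the technique.

```python
def project_to_simplex(v: list[float]) -> list[float]:
    """Euclidean projection of v onto {w : w_i >= 0, sum_i w_i = 1}.

    Sort-based algorithm: with u sorted in descending order, find the largest rho
    such that u_rho + (1 - sum_{j <= rho} u_j) / rho > 0, then shift every entry
    by that threshold and clip at zero.
    """
    u = sorted(v, reverse=True)
    cumulative = 0.0
    lam = 0.0
    for j, u_j in enumerate(u, start=1):
        cumulative += u_j
        candidate = (1.0 - cumulative) / j
        if u_j + candidate > 0:
            lam = candidate  # threshold for the largest valid rho seen so far
    return [max(vi + lam, 0.0) for vi in v]


print(project_to_simplex([2.0, 0.0]))        # [1.0, 0.0]
print(project_to_simplex([0.2, 0.1, 0.0]))   # all three entries shifted up by the same amount
```

The condition holds for j = 1 (it reduces to 1 > 0) and stays true exactly up to rho. A single pass that keeps the last valid threshold is therefore enough.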
Please check the API Documentation for implementation details.