### Bug

`sigmoidal(x, mode="c0")` with `create_graph=True` produces a gradient tensor with `grad_fn=None`, making any further backward pass impossible.

```python
x = torch.randn(5, dtype=torch.float64, requires_grad=True)
y = softtorch.sigmoidal(x, softness=0.1, mode="c0")
g1 = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
g2 = torch.autograd.grad(g1.sum(), x)[0]
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

`mode="c1"`, `"c2"`, and `"smooth"` all work correctly.
### Root cause

The c0 polynomial in `functions.py:1641-1642` is affine:

```python
y = _polyval(torch.tensor([0.5, 0.5], ...), x)  # 0.5*x + 0.5
```

Since `_polyval` computes this via Horner's method (`y = 0.5; y = y*x + 0.5`), the gradient w.r.t. `x` is the constant scalar 0.5. PyTorch returns this as a leaf tensor with `grad_fn=None`, which severs the computation graph.

c1/c2 use higher-order polynomials whose gradients depend on `x`, so `grad_fn` is preserved.
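The disconnect can be reproduced without softtorch at all. Here is a minimal sketch of the same Horner-style affine evaluation in plain PyTorch (the construction is illustrative, not the library's actual code):

```python
import torch

x = torch.randn(5, dtype=torch.float64, requires_grad=True)

# Horner evaluation of 0.5*x + 0.5, mirroring the c0 branch
y = torch.tensor(0.5, dtype=torch.float64) * x + 0.5

# The backward pass computes grad = grad_output * 0.5 and never reads x,
# so even with create_graph=True the result is a constant leaf tensor.
g1 = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
print(g1.requires_grad, g1.grad_fn)  # False None
```

Any subsequent `torch.autograd.grad(g1.sum(), x)` then fails with the RuntimeError above, because `g1` is disconnected from `x`.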
### Affected functions
Direct failure (8): sigmoidal, sign, heaviside, round, greater, less, greater_equal, less_equal — all in c0 mode.
Functions like abs, softrelu, eq, not_equal happen to survive because they combine sigmoidal with multiplication (product rule keeps grad_fn alive), but this is incidental.
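The incidental survival follows from the product rule; a hedged sketch in plain PyTorch (an abs-style combination, not the library's actual implementation):

```python
import torch

x = torch.randn(5, dtype=torch.float64, requires_grad=True)

s = torch.tensor(0.5, dtype=torch.float64) * x + 0.5  # affine, constant gradient
y = x * s                                             # abs-style combination

# Product rule: d(x*s)/dx = s + 0.5*x — the backward pass now reads s
# (which depends on x), so the gradient stays connected to the graph.
g1 = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
print(g1.grad_fn is not None)  # True

g2 = torch.autograd.grad(g1.sum(), x)[0]  # second derivative of 0.5*x² + 0.5*x is 1
```

This is why only the eight functions above fail directly: their c0 branch returns the affine polynomial without multiplying it by anything that depends on `x`.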
### Impact
Any use of create_graph=True with c0 mode crashes — this includes Hessian computation, gradient penalties (WGAN-GP), Neural ODE adjoints, and meta-learning (MAML).
Note: torch.autograd.gradgradcheck does not catch this because the true second derivative is mathematically zero, so the numerical check passes. The bug is purely a graph-connectivity issue.
### Suggested fix
Add a zero-valued higher-order term to keep the computation graph alive:
```python
# before
y = _polyval(torch.tensor([0.5, 0.5], ...), x)       # 0.5*x + 0.5

# after
y = _polyval(torch.tensor([0.0, 0.5, 0.5], ...), x)  # 0.0*x² + 0.5*x + 0.5
```

`_polyval` evaluates `0.0 * x²` via Horner's method, which is enough to make PyTorch track `x` in the graph. The numerical result is identical.
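The fix can be checked in isolation. A sketch assuming Horner evaluation of the padded coefficient list (again plain PyTorch, standing in for `_polyval`):

```python
import torch

x = torch.randn(5, dtype=torch.float64, requires_grad=True)

# Horner evaluation of 0.0*x² + 0.5*x + 0.5 with the padded coefficients:
# y = 0.0; y = y*x + 0.5; y = y*x + 0.5
y = (torch.tensor(0.0, dtype=torch.float64) * x + 0.5) * x + 0.5

g1 = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
print(g1.grad_fn is not None)  # True — the graph stays alive

# The double backward now succeeds; the second derivative is zero, as expected.
g2 = torch.autograd.grad(g1.sum(), x)[0]
print(torch.allclose(g2, torch.zeros_like(g2)))  # True
```

The zero leading coefficient changes nothing numerically (`g1` is still 0.5 everywhere); it only forces the backward pass to read `x`, which is exactly the graph-connectivity property the c0 mode is missing.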
### Verified on
- macOS (Apple Silicon), PyTorch 2.9.1, Python 3.13.7
- Ubuntu (RTX 3090), PyTorch 2.9.1+cu128, Python 3.12.11