sigmoidal c0 mode breaks create_graph=True (double backward fails) #2

@seungwoo-ji-03

Description

Bug

sigmoidal(x, mode="c0") with create_graph=True produces a gradient tensor with grad_fn=None, making any further backward pass impossible.

```python
import torch
import softtorch

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
y = softtorch.sigmoidal(x, softness=0.1, mode="c0")
g1 = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
g2 = torch.autograd.grad(g1.sum(), x)[0]
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

The "c1", "c2", and "smooth" modes all work correctly.

Root cause

The c0 polynomial in functions.py:1641-1642 is affine:

```python
y = _polyval(torch.tensor([0.5, 0.5], ...), x)   # 0.5*x + 0.5
```

Since _polyval evaluates this via Horner's method (y = 0.5; y = y*x + 0.5), the gradient w.r.t. x is the constant 0.5. PyTorch returns it as a leaf tensor with grad_fn=None, which severs the computation graph.

c1/c2 use higher-order polynomials whose gradients depend on x, so grad_fn is preserved.
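The severed graph can be reproduced without the library at all; a minimal sketch of the affine vs. higher-order behaviour in plain PyTorch (no softtorch involved):

```python
import torch

x = torch.randn(5, dtype=torch.float64, requires_grad=True)

# Affine, as in c0: the gradient (0.5) does not depend on x,
# so autograd returns a constant leaf with grad_fn=None.
g_affine = torch.autograd.grad((0.5 * x + 0.5).sum(), x, create_graph=True)[0]
print(g_affine.grad_fn)            # None — graph severed

# Quadratic, as in c1/c2: the gradient (2x) depends on x,
# so the gradient tensor stays attached to the graph.
g_quad = torch.autograd.grad((x ** 2).sum(), x, create_graph=True)[0]
print(g_quad.grad_fn is not None)  # True — double backward possible
```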

Affected functions

Direct failure (8): sigmoidal, sign, heaviside, round, greater, less, greater_equal, less_equal — all in c0 mode.

Functions like abs, softrelu, eq, not_equal happen to survive because they combine sigmoidal with multiplication (product rule keeps grad_fn alive), but this is incidental.
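The incidental survival is easy to see with a stand-in for the c0 polynomial (a sketch, not the library's actual composition):

```python
import torch

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
affine = 0.5 * x + 0.5      # stand-in for the c0 sigmoidal
y = x * affine              # abs-style composition: x * sigmoidal(x)

# Product rule: dy/dx = affine + 0.5*x depends on x, so grad_fn survives —
# but only because of the outer multiplication, not by design.
g = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
print(g.grad_fn is not None)
```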

Impact

Any use of create_graph=True with c0 mode crashes — this includes Hessian computation, gradient penalties (WGAN-GP), Neural ODE adjoints, and meta-learning (MAML).
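As one concrete instance, the WGAN-GP gradient penalty differentiates through a gradient; a generic sketch of the pattern (plain PyTorch, no softtorch) shows exactly where the second backward happens:

```python
import torch

critic = torch.nn.Linear(4, 1).double()
x = torch.randn(8, 4, dtype=torch.float64, requires_grad=True)

# First backward, kept in the graph for the penalty term.
grads = torch.autograd.grad(critic(x).sum(), x, create_graph=True)[0]
penalty = ((grads.norm(2, dim=1) - 1.0) ** 2).mean()

# Second backward through grads — this is the step that fails
# whenever grads.grad_fn has been severed.
penalty.backward()
```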

Note: torch.autograd.gradgradcheck does not catch this because the true second derivative is mathematically zero, so the numerical check passes. The bug is purely a graph-connectivity issue.
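Because gradgradcheck passes, a regression test has to assert graph connectivity directly rather than numerical values. A minimal sketch of such a check (check_double_backward_connectivity is a hypothetical helper name, not part of the library):

```python
import torch

def check_double_backward_connectivity(fn, x):
    """True iff the gradient of fn stays attached to the autograd graph."""
    g = torch.autograd.grad(fn(x).sum(), x, create_graph=True)[0]
    return g.grad_fn is not None

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
print(check_double_backward_connectivity(lambda t: 0.5 * t + 0.5, x))  # False (c0-like)
print(check_double_backward_connectivity(lambda t: t ** 2, x))         # True  (c1-like)
```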

Suggested fix

Add a zero-valued higher-order term to keep the computation graph alive:

```python
# before
y = _polyval(torch.tensor([0.5, 0.5], ...), x)        # 0.5*x + 0.5

# after
y = _polyval(torch.tensor([0.0, 0.5, 0.5], ...), x)   # 0.0*x² + 0.5*x + 0.5
```

_polyval evaluates the 0.0 * x² term via Horner's method, which is enough for PyTorch to track x in the gradient's graph. The numerical result is identical.

Verified on

  • macOS (Apple Silicon), PyTorch 2.9.1, Python 3.13.7
  • Ubuntu (RTX 3090), PyTorch 2.9.1+cu128, Python 3.12.11
