Skip to content

Commit dc5006c

Browse files
committed
Fix type inference failure in norm on structural matrix
1 parent e055009 commit dc5006c

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
4040
end,
4141
# out-of-place versions
4242
@thunk(if isempty(x) || p == 0
43-
zero.(x) .* (zero(y) * zero(real(Δy)))
43+
# Note: post-julia-1.11 the zero.(Diagonal(Float64[;])) .* 0.0)
44+
# only infers down to Union(Diagonal{Float64}, Matrix{Float64})
45+
# rather than Diagonal{Float64}, so we cast it back.
46+
maybe_withsomezeros_rewrap(x, zero.(x) .* (zero(y) * zero(real(Δy))))
4447
elseif p == 2
4548
_norm2_back(x, y, Δy)
4649
elseif p == 1
@@ -72,7 +75,10 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number})
7275
end
7376
,
7477
@thunk(if isempty(x)
75-
zero.(x) .* (zero(y) * zero(real(Δy)))
78+
# Note: post-julia-1.11 the zero.(Diagonal(Float64[;])) .* 0.0)
79+
# only infers down to Union(Diagonal{Float64}, Matrix{Float64})
80+
# rather than Diagonal{Float64}, so we cast it back.
81+
maybe_withsomezeros_rewrap(x, zero.(x) .* (zero(y) * zero(real(Δy))))
7682
else
7783
_norm2_back(x, y, Δy)
7884
end)
@@ -99,7 +105,7 @@ function rrule(::typeof(norm), x::Number, p::Real)
99105
function norm_pullback(ȳ)
100106
Δy = unthunk(ȳ)
101107
∂x = if iszero(Δy) || iszero(p)
102-
zero(x) * zero(real(Δy))
108+
zero(x) * zero(real(Δy))
103109
else
104110
signx = x isa Real ? sign(x) : x * pinv(y)
105111
signx * real(Δy)

src/rulesets/LinearAlgebra/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ for S in [
5858
:UnitLowerTriangular,
5959
]
6060
@eval withsomezeros_rewrap(::$S, x) = $S(x)
61+
@eval maybe_withsomezeros_rewrap(::$S, x) = $S(x)
6162
end
63+
maybe_withsomezeros_rewrap(::AbstractArray, x) = x
6264

6365
# Bidiagonal, Tridiagonal have more complicated storage.
6466
# AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent()))

0 commit comments

Comments
 (0)