@@ -40,7 +40,10 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
40
40
end ,
41
41
# out-of-place versions
42
42
@thunk (if isempty (x) || p == 0
43
- zero .(x) .* (zero (y) * zero (real (Δy)))
43
+ # Note: post-julia-1.11 the zero.(Diagonal(Float64[;])) .* 0.0)
44
+ # only infers down to Union(Diagonal{Float64}, Matrix{Float64})
45
+ # rather than Diagonal{Float64}, so we cast it back.
46
+ maybe_withsomezeros_rewrap (x, zero .(x) .* (zero (y) * zero (real (Δy))))
44
47
elseif p == 2
45
48
_norm2_back (x, y, Δy)
46
49
elseif p == 1
@@ -72,7 +75,10 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number})
72
75
end
73
76
,
74
77
@thunk (if isempty (x)
75
- zero .(x) .* (zero (y) * zero (real (Δy)))
78
+ # Note: post-julia-1.11 the zero.(Diagonal(Float64[;])) .* 0.0)
79
+ # only infers down to Union(Diagonal{Float64}, Matrix{Float64})
80
+ # rather than Diagonal{Float64}, so we cast it back.
81
+ maybe_withsomezeros_rewrap (x, zero .(x) .* (zero (y) * zero (real (Δy))))
76
82
else
77
83
_norm2_back (x, y, Δy)
78
84
end )
@@ -99,7 +105,7 @@ function rrule(::typeof(norm), x::Number, p::Real)
99
105
function norm_pullback (ȳ)
100
106
Δy = unthunk (ȳ)
101
107
∂x = if iszero (Δy) || iszero (p)
102
- zero (x) * zero (real (Δy))
108
+ zero (x) * zero (real (Δy))
103
109
else
104
110
signx = x isa Real ? sign (x) : x * pinv (y)
105
111
signx * real (Δy)
0 commit comments