Skip to content

Commit

Permalink
Reduce generic matrix*vector latency (#56289)
Browse files Browse the repository at this point in the history
```julia
julia> using LinearAlgebra

julia> A = rand(Int,4,4); x = rand(Int,4); y = similar(x);

julia> @time mul!(y, A, x, 2, 2);
  0.330489 seconds (792.22 k allocations: 41.519 MiB, 8.75% gc time, 99.99% compilation time) # master
  0.134212 seconds (339.89 k allocations: 17.103 MiB, 15.23% gc time, 99.98% compilation time) # This PR
```
Main changes:
- `generic_matvecmul!` and `_generic_matvecmul!` now accept `alpha` and
`beta` arguments instead of `MulAddMul(alpha, beta)`. The methods that
accept a `MulAddMul(alpha, beta)` are also retained for backward
compatibility, but these now forward `alpha` and `beta`, instead of the
other way around.
- Narrow the scope of the `@stable_muladdmul` applications. We now
construct the `MulAddMul(alpha, beta)` object only where it is needed in
a function call, and we annotate the call site with `@stable_muladdmul`.
This leads to smaller branches.
- Create a new internal function with methods for the `'N'`, `'T'` and
`'C'` cases, so that firstly, there's less code duplication, and
secondly, the `_generic_matvecmul!` method is now simple enough to
enable constant propagation. This eliminates the unnecessary branches,
and only the one that is taken is compiled.

Together, this reduces the TTFX substantially.
  • Loading branch information
jishnub authored Oct 23, 2024
1 parent 005608a commit b9b4dfa
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 67 deletions.
126 changes: 60 additions & 66 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,20 @@ _mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, tA, A, x, alpha, beta)
generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
gemv!(y, tA, A, x, _add.alpha, _add.beta)

# Real (possibly transposed) matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, tA, A, x, alpha, beta)
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
gemv!(y, tA, A, x, _add.alpha, _add.beta)

# Complex matrix times real vector.
# Reinterpret the matrix as a real matrix and do real matvec computation.
# works only in cooperation with BLAS when A is untransposed (tA == 'N')
# but that check is included in gemv! anyway
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, tA, A, x, alpha, beta)
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
gemv!(y, tA, A, x, _add.alpha, _add.beta)

# Vector-Matrix multiplication
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
Expand Down Expand Up @@ -539,9 +532,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
if tA_uc in ('S', 'H')
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, α, β)
else
return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, α, β)
end
end

Expand All @@ -564,7 +557,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
return y
else
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
return @stable_muladdmul _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β))
return _generic_matvecmul!(y, ta, Anew, x, α, β)
end
end

Expand All @@ -591,9 +584,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
elseif tA_uc in ('S', 'H')
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, α, β)
else
return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, α, β)
end
end

Expand Down Expand Up @@ -825,82 +818,83 @@ end
# NOTE: the generic version is also called as fallback for
# strides != 1 cases

Base.@constprop :aggressive generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, alpha::Number, beta::Number) =
@stable_muladdmul generic_matvecmul!(C, tA, A, B, MulAddMul(alpha, beta))
# legacy method, retained for backward compatibility
generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) =
generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
alpha::Number, beta::Number)
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
return _generic_matvecmul!(C, ta, Anew, B, _add)
return _generic_matvecmul!(C, ta, Anew, B, alpha, beta)
end

function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
@assert tA in ('N', 'T', 'C')
mB = length(B)
mA, nA = lapack_size(tA, A)
if mB != nA
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
end
if mA != length(C)
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
end

# legacy method, retained for backward compatibility
_generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) =
_generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
function __generic_matvecmul!(f::F, C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number) where {F}
Astride = size(A, 1)

@inbounds begin
if tA == 'T' # fastest case
if nA == 0
for k = 1:mA
_modify!(_add, false, C, k)
end
else
for k = 1:mA
aoffs = (k-1)*Astride
firstterm = transpose(A[aoffs + 1])*B[1]
s = zero(firstterm + firstterm)
for i = 1:nA
s += transpose(A[aoffs+i]) * B[i]
end
_modify!(_add, s, C, k)
end
end
elseif tA == 'C'
if nA == 0
for k = 1:mA
_modify!(_add, false, C, k)
if length(B) == 0
for k = eachindex(C)
@stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k)
end
else
for k = 1:mA
for k = eachindex(C)
aoffs = (k-1)*Astride
firstterm = A[aoffs + 1]'B[1]
firstterm = f(A[aoffs + 1]) * B[1]
s = zero(firstterm + firstterm)
for i = 1:nA
s += A[aoffs + i]'B[i]
for i = eachindex(B)
s += f(A[aoffs+i]) * B[i]
end
_modify!(_add, s, C, k)
@stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k)
end
end
else # tA == 'N'
for i = 1:mA
if !iszero(_add.beta)
C[i] *= _add.beta
elseif mB == 0
end
end
function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number)
Astride = size(A, 1)
@inbounds begin
for i = eachindex(C)
if !iszero(beta)
C[i] *= beta
elseif length(B) == 0
C[i] = false
else
C[i] = zero(A[i]*B[1] + A[i]*B[1])
end
end
for k = 1:mB
for k = eachindex(B)
aoffs = (k-1)*Astride
b = _add(B[k])
for i = 1:mA
b = @stable_muladdmul MulAddMul(alpha,beta)(B[k])
for i = eachindex(C)
C[i] += A[aoffs + i] * b
end
end
end
end # @inbounds
return C
end
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number)
require_one_based_indexing(C, A, B)
@assert tA in ('N', 'T', 'C')
mB = length(B)
mA, nA = lapack_size(tA, A)
if mB != nA
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
end
if mA != length(C)
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
end

if tA == 'T' # fastest case
__generic_matvecmul!(transpose, C, A, B, alpha, beta)
elseif tA == 'C'
__generic_matvecmul!(adjoint, C, A, B, alpha, beta)
else # tA == 'N'
__generic_matvecmul!(identity, C, A, B, alpha, beta)
end
C
end

Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ for TC in (:AbstractVector, :AbstractMatrix)
if isone(alpha) && iszero(beta)
return _trimul!(C, A, B)
else
return @stable_muladdmul generic_matvecmul!(C, 'N', A, B, MulAddMul(alpha, beta))
return _generic_matvecmul!(C, 'N', A, B, alpha, beta)
end
end
end
Expand Down

0 comments on commit b9b4dfa

Please sign in to comment.