Skip to content

Commit

Permalink
Reduce matmul latency by splitting small matmul (#54421)
Browse files Browse the repository at this point in the history
This splits the `matmul2x2` and `matmul3x3` into components that depend
on `MulAddMul` and those that don't depend on it. This improves
compilation time, as the `MulAddMul`-independent methods won't need to
be recompiled in the `@stable_muladdmul` branches.

TTFX (each call timed in a separate session):
```julia
julia> using LinearAlgebra

julia> A = rand(2,2); B = Symmetric(rand(2,2)); C = zeros(2,2);

julia> @time mul!(C, A, B);
  1.927468 seconds (5.67 M allocations: 282.523 MiB, 12.09% gc time, 100.00% compilation time) # nightly v"1.12.0-DEV.492"
  1.282717 seconds (4.46 M allocations: 228.816 MiB, 4.58% gc time, 100.00% compilation time) # This PR

julia> A = rand(2,2); B = rand(2,2); C = zeros(2,2);

julia> @time mul!(C, A, B);
  1.653368 seconds (5.75 M allocations: 291.586 MiB, 13.94% gc time, 100.00% compilation time) # nightly
  1.148330 seconds (4.46 M allocations: 230.714 MiB, 4.47% gc time, 100.00% compilation time) # This PR
```

Edit: Not inlining the function seems to incur a runtime perfomance
cost.
```julia
julia> using LinearAlgebra

julia> A = rand(3,3); B = rand(size(A)...); C = zeros(size(A));

julia> @Btime mul!($C, $A, $B);
  23.923 ns (0 allocations: 0 bytes) # nightly
  31.732 ns (0 allocations: 0 bytes) # This PR
```
Adding `@inline` annotations resolves this difference, but this
reintroduces the compilation latency. The tradeoff is perhaps ok, as
users may use `StaticArrays` for performance-critical matrix
multiplications.
  • Loading branch information
jishnub authored May 12, 2024
1 parent 25c8128 commit 5006312
Showing 1 changed file with 89 additions and 110 deletions.
199 changes: 89 additions & 110 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -930,164 +930,143 @@ end


# multiply 2x2 matrices
function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
Base.@constprop :aggressive function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
function __matmul_checks(C, A, B, sz)
require_one_based_indexing(C, A, B)
if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end
if !(size(A) == size(B) == size(C) == (2,2))
if !(size(A) == size(B) == size(C) == sz)
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
end
return nothing
end

# separate function with the core of matmul2x2! that doesn't depend on a MulAddMul
Base.@constprop :aggressive function _matmul2x2_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix)
__matmul_checks(C, A, B, (2,2))
__matmul2x2_elements(tA, tB, A, B)
end
Base.@constprop :aggressive function __matmul2x2_elements(tA, A::AbstractMatrix)
@inbounds begin
if tA == 'N'
tA_uc = uppercase(tA) # possibly unwrap a WrapperChar
if tA_uc == 'N'
A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2]
elseif tA == 'T'
elseif tA_uc == 'T'
# TODO making these lazy could improve perf
A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1]))
A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2]))
elseif tA == 'C'
elseif tA_uc == 'C'
# TODO making these lazy could improve perf
A11 = copy(A[1,1]'); A12 = copy(A[2,1]')
A21 = copy(A[1,2]'); A22 = copy(A[2,2]')
elseif tA == 'S'
A11 = symmetric(A[1,1], :U); A12 = A[1,2]
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U)
elseif tA == 's'
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1]))
A21 = A[2,1]; A22 = symmetric(A[2,2], :L)
elseif tA == 'H'
A11 = hermitian(A[1,1], :U); A12 = A[1,2]
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U)
else # if tA == 'h'
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L)
end
if tB == 'N'
B11 = B[1,1]; B12 = B[1,2];
B21 = B[2,1]; B22 = B[2,2]
elseif tB == 'T'
# TODO making these lazy could improve perf
B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1]))
B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2]))
elseif tB == 'C'
# TODO making these lazy could improve perf
B11 = copy(B[1,1]'); B12 = copy(B[2,1]')
B21 = copy(B[1,2]'); B22 = copy(B[2,2]')
elseif tB == 'S'
B11 = symmetric(B[1,1], :U); B12 = B[1,2]
B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U)
elseif tB == 's'
B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1]))
B21 = B[2,1]; B22 = symmetric(B[2,2], :L)
elseif tB == 'H'
B11 = hermitian(B[1,1], :U); B12 = B[1,2]
B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U)
else # if tB == 'h'
B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1]))
B21 = B[2,1]; B22 = hermitian(B[2,2], :L)
elseif tA_uc == 'S'
if isuppercase(tA) # tA == 'S'
A11 = symmetric(A[1,1], :U); A12 = A[1,2]
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U)
else
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1]))
A21 = A[2,1]; A22 = symmetric(A[2,2], :L)
end
elseif tA_uc == 'H'
if isuppercase(tA) # tA == 'H'
A11 = hermitian(A[1,1], :U); A12 = A[1,2]
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U)
else # if tA == 'h'
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L)
end
end
end # inbounds
A11, A12, A21, A22
end
Base.@constprop :aggressive __matmul2x2_elements(tA, tB, A, B) = __matmul2x2_elements(tA, A), __matmul2x2_elements(tB, B)

Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
(A11, A12, A21, A22), (B11, B12, B21, B22) = _matmul2x2_elements(C, tA, tB, A, B)
@inbounds begin
_modify!(_add, A11*B11 + A12*B21, C, (1,1))
_modify!(_add, A11*B12 + A12*B22, C, (1,2))
_modify!(_add, A21*B11 + A22*B21, C, (2,1))
_modify!(_add, A11*B12 + A12*B22, C, (1,2))
_modify!(_add, A21*B12 + A22*B22, C, (2,2))
end # inbounds
C
end

# Multiply 3x3 matrices
function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
Base.@constprop :aggressive function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end
if !(size(A) == size(B) == size(C) == (3,3))
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
end
# separate function with the core of matmul3x3! that doesn't depend on a MulAddMul
Base.@constprop :aggressive function _matmul3x3_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix)
__matmul_checks(C, A, B, (3,3))
__matmul3x3_elements(tA, tB, A, B)
end
Base.@constprop :aggressive function __matmul3x3_elements(tA, A::AbstractMatrix)
@inbounds begin
if tA == 'N'
tA_uc = uppercase(tA) # possibly unwrap a WrapperChar
if tA_uc == 'N'
A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3]
A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3]
A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3]
elseif tA == 'T'
elseif tA_uc == 'T'
# TODO making these lazy could improve perf
A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2])); A23 = copy(transpose(A[3,2]))
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = copy(transpose(A[3,3]))
elseif tA == 'C'
elseif tA_uc == 'C'
# TODO making these lazy could improve perf
A11 = copy(A[1,1]'); A12 = copy(A[2,1]'); A13 = copy(A[3,1]')
A21 = copy(A[1,2]'); A22 = copy(A[2,2]'); A23 = copy(A[3,2]')
A31 = copy(A[1,3]'); A32 = copy(A[2,3]'); A33 = copy(A[3,3]')
elseif tA == 'S'
A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3]
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U)
elseif tA == 's'
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L)
elseif tA == 'H'
A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3]
A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U)
else # if tA == 'h'
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L)
elseif tA_uc == 'S'
if isuppercase(tA) # tA == 'S'
A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3]
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U)
else
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L)
end
elseif tA_uc == 'H'
if isuppercase(tA) # tA == 'H'
A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3]
A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U)
else # if tA == 'h'
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L)
end
end
end # inbounds
A11, A12, A13, A21, A22, A23, A31, A32, A33
end
Base.@constprop :aggressive __matmul3x3_elements(tA, tB, A, B) = __matmul3x3_elements(tA, A), __matmul3x3_elements(tB, B)

if tB == 'N'
B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3]
B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3]
B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3]
elseif tB == 'T'
# TODO making these lazy could improve perf
B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1]))
B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2])); B23 = copy(transpose(B[3,2]))
B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = copy(transpose(B[3,3]))
elseif tB == 'C'
# TODO making these lazy could improve perf
B11 = copy(B[1,1]'); B12 = copy(B[2,1]'); B13 = copy(B[3,1]')
B21 = copy(B[1,2]'); B22 = copy(B[2,2]'); B23 = copy(B[3,2]')
B31 = copy(B[1,3]'); B32 = copy(B[2,3]'); B33 = copy(B[3,3]')
elseif tB == 'S'
B11 = symmetric(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3]
B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U); B23 = B[2,3]
B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = symmetric(B[3,3], :U)
elseif tB == 's'
B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1]))
B21 = B[2,1]; B22 = symmetric(B[2,2], :L); B23 = copy(transpose(B[3,2]))
B31 = B[3,1]; B32 = B[3,2]; B33 = symmetric(B[3,3], :L)
elseif tB == 'H'
B11 = hermitian(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3]
B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U); B23 = B[2,3]
B31 = copy(adjoint(B[1,3])); B32 = copy(adjoint(B[2,3])); B33 = hermitian(B[3,3], :U)
else # if tB == 'h'
B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1])); B13 = copy(adjoint(B[3,1]))
B21 = B[2,1]; B22 = hermitian(B[2,2], :L); B23 = copy(adjoint(B[3,2]))
B31 = B[3,1]; B32 = B[3,2]; B33 = hermitian(B[3,3], :L)
end
Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())

_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
_modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
_modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))
(A11, A12, A13, A21, A22, A23, A31, A32, A33),
(B11, B12, B13, B21, B22, B23, B31, B32, B33) = _matmul3x3_elements(C, tA, tB, A, B)

@inbounds begin
_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
_modify!(_add, A21*B11 + A22*B21 + A23*B31, C, (2,1))
_modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
_modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))

_modify!(_add, A31*B11 + A32*B21 + A33*B31, C, (3,1))

_modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
_modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
_modify!(_add, A31*B12 + A32*B22 + A33*B32, C, (3,2))

_modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))
_modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))
_modify!(_add, A31*B13 + A32*B23 + A33*B33, C, (3,3))
end # inbounds
C
Expand Down

0 comments on commit 5006312

Please sign in to comment.