Skip to content

Commit

Permalink
alpha,beta instead of MulAddMul in _generic_matmatmul!
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 23, 2024
1 parent cbdea0a commit 05452f4
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.HEMM})
end
end
Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.NONE})
@stable_muladdmul _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(alpha, beta))
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end

# legacy method
Expand All @@ -540,8 +540,8 @@ function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::S
gemm_wrapper!(C, tA, tB, A, B, α, β)
end
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{false}) where {T<:BlasReal}
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal}
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end
# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
Expand Down Expand Up @@ -743,7 +743,7 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
gemm_wrapper!(C, tA, tB, A, B, true, false)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul())
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), true, false)
end
end

Expand All @@ -770,7 +770,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
_fullstride2(A) && _fullstride2(B) && _fullstride2(C))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
end
# legacy method
gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
Expand Down Expand Up @@ -805,7 +805,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
end
# legacy method
gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
Expand Down Expand Up @@ -975,12 +975,16 @@ end
# aggressive const prop makes mixed eltype mul!(C, A, B) invoke _generic_matmatmul! directly
# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul = MulAddMul()) =
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta)
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, alpha::Number, beta::Number) =
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)

@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
_add::MulAddMul{ais1}) where {T,S,R,ais1}
# legacy method
_generic_matmatmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) =
_generic_matmatmul!(C, A, B, _add.alpha, _add.beta)

@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat, B::AbstractVecOrMat,
alpha::Number, beta::Number) where {R}
AxM = axes(A, 1)
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
BxK = axes(B, 1)
Expand All @@ -996,34 +1000,33 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
if BxN != CxN
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
end
_rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false)
if isbitstype(R) && sizeof(R) 16 && !(A isa Adjoint || A isa Transpose)
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
_rmul_or_fill!(C, beta)
(iszero(alpha) || isempty(A) || isempty(B)) && return C
@inbounds for n in BxN, k in BxK
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
Balpha = _rmul_alpha(B[k,n])
Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
@simd for m in AxM
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
end
end
elseif isbitstype(R) && sizeof(R) 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
_rmul_or_fill!(C, beta)
(iszero(alpha) || isempty(A) || isempty(B)) && return C
t = wrapperop(A)
pB = parent(B)
pA = parent(A)
tmp = similar(C, CxN)
ci = first(CxM)
ta = t(_add.alpha)
ta = t(alpha)
for i in AxM
mul!(tmp, pB, view(pA, :, i))
@views C[ci,:] .+= t.(ta .* tmp)
ci += 1
end
else
if iszero(_add.alpha) || isempty(A) || isempty(B)
return _rmul_or_fill!(C, _add.beta)
if iszero(alpha) || isempty(A) || isempty(B)
return _rmul_or_fill!(C, beta)
end
a1 = first(AxK)
b1 = first(BxK)
Expand All @@ -1033,7 +1036,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
@simd for k in AxK
Ctmp = muladd(A[i, k], B[k, j], Ctmp)
end
_modify!(_add, Ctmp, C, (i,j))
@stable_muladdmul _modify!(MulAddMul(alpha,beta), Ctmp, C, (i,j))
end
end
return C
Expand Down

0 comments on commit 05452f4

Please sign in to comment.