Skip to content

Commit

Permalink
Use all(map(...)) instead of all_in
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed May 2, 2024
1 parent 9e66d1a commit 541a76d
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,23 +360,15 @@ julia> lmul!(F.Q, B)
"""
lmul!(A, B)

# unroll the in(a, b) computation to enable constant propagation
# This is a 2-valued in implementation that doesn't account for missing values
_in(t::AbstractChar, ::Tuple{}) = false
function _in(t::AbstractChar, chars::Tuple{Vararg{AbstractChar}})
return t == first(chars) || _in(t, Base.tail(chars))
end
all_in(chars, (tA, tB)) = _in(tA, chars) && _in(tB, chars)

# THE one big BLAS dispatch
# aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
# if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
if all_in(('N', 'T', 'C'), map(uppercase, (tA_uc, tB_uc)))
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
if tA_uc == 'T' && tB_uc == 'N' && A === B
return syrk_wrapper!(C, 'T', A, _add)
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
Expand Down Expand Up @@ -407,10 +399,11 @@ end
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
# if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB)))
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
gemm_wrapper!(C, tA, tB, A, B, _add)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
Expand Down Expand Up @@ -453,15 +446,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
!iszero(stride(x, 1)) && # We only check input's stride here.
if _in(tA_uc, ('N', 'T', 'C'))
if tA_uc in ('N', 'T', 'C')
return BLAS.gemv!(tA, alpha, A, x, beta, y)
elseif tA_uc == 'S'
return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y)
elseif tA_uc == 'H'
return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y)
end
end
if _in(tA_uc, ('S', 'H'))
if tA_uc in ('S', 'H')
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
Expand All @@ -488,7 +481,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
return y
else
Anew, ta = _in(tA_uc, ('S', 'H')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β))
end
end
Expand All @@ -507,13 +500,13 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
!iszero(stride(x, 1)) && _in(tA_uc, ('N', 'T', 'C'))
!iszero(stride(x, 1)) && tA_uc in ('N', 'T', 'C')
xfl = reinterpret(reshape, T, x) # Use reshape here.
yfl = reinterpret(reshape, T, y)
BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :])
BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :])
return y
elseif _in(tA_uc, ('S', 'H'))
elseif tA_uc in ('S', 'H')
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
Expand Down Expand Up @@ -613,10 +606,11 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, T, mA, nB)
# if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB)))
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
gemm_wrapper!(C, tA, tB, A, B)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
Expand Down Expand Up @@ -789,7 +783,7 @@ end
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
Anew, ta = _in(tA_uc, ('S', 'H')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
return _generic_matvecmul!(C, ta, Anew, B, _add)
end

Expand Down

0 comments on commit 541a76d

Please sign in to comment.