Skip to content

Commit

Permalink
Reland "Support broadcasting over structured block matrices #53909" (#…
Browse files Browse the repository at this point in the history
…54460)

This was reverted in #54332. This
needs #54459 for the tests to
pass. Opening this now to not forget about it.
  • Loading branch information
jishnub committed May 17, 2024
1 parent 1afb580 commit 81cb537
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
6 changes: 4 additions & 2 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()

const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
for ST in Base.uniontypes(StructuredMatrix)
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
@eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}())
end

Expand Down Expand Up @@ -133,6 +133,7 @@ fails as `zero(::Tuple{Int})` is not defined. However,
iszerodefined(::Type) = false
iszerodefined(::Type{<:Number}) = true
iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T)
iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T)

count_structedmatrix(T, bc::Broadcasted) = sum(Base.Fix2(isa, T), Broadcast.cat_nested(bc); init = 0)

Expand Down Expand Up @@ -160,6 +161,7 @@ fzero(::Type{T}) where T = Some(T)
fzero(r::Ref) = Some(r[])
fzero(t::Tuple{Any}) = Some(only(t))
fzero(S::StructuredMatrix) = Some(zero(eltype(S)))
fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = Some(haszero(T) ? zero(T)*I : nothing)
fzero(x) = nothing
function fzero(bc::Broadcast.Broadcasted)
args = map(fzero, bc.args)
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Random.seed!(1)
struct TypeWithZero end
Base.promote_rule(::Type{TypeWithoutZero}, ::Type{TypeWithZero}) = TypeWithZero
Base.convert(::Type{TypeWithZero}, ::TypeWithoutZero) = TypeWithZero()
Base.zero(x::Union{TypeWithoutZero, TypeWithZero}) = zero(typeof(x))
Base.zero(::Type{<:Union{TypeWithoutZero, TypeWithZero}}) = TypeWithZero()
LinearAlgebra.symmetric(::TypeWithoutZero, ::Symbol) = TypeWithoutZero()
LinearAlgebra.symmetric_type(::Type{TypeWithoutZero}) = TypeWithoutZero
Expand Down
58 changes: 58 additions & 0 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,4 +307,62 @@ end
@test select_first.(missing, diag) isa Matrix{Missing}
end

@testset "broadcast over structured matrices with matrix elements" begin
function standardbroadcastingtests(D, T)
M = [x for x in D]
Dsum = D .+ D
@test Dsum isa T
@test Dsum == M .+ M
Dcopy = copy.(D)
@test Dcopy isa T
@test Dcopy == D
Df = float.(D)
@test Df isa T
@test Df == D
@test eltype(eltype(Df)) <: AbstractFloat
@test (x -> (x,)).(D) == (x -> (x,)).(M)
@test (x -> 1).(D) == ones(Int,size(D))
@test all(==(2), ndims.(D))
@test_throws MethodError size.(D)
end
@testset "Diagonal" begin
@testset "square" begin
A = [1 3; 2 4]
D = Diagonal([A, A])
standardbroadcastingtests(D, Diagonal)
@test sincos.(D) == sincos.(Matrix{eltype(D)}(D))
M = [x for x in D]
@test cos.(D) == cos.(M)
end

@testset "different-sized square blocks" begin
D = Diagonal([ones(3,3), fill(3.0,2,2)])
standardbroadcastingtests(D, Diagonal)
end

@testset "rectangular blocks" begin
D = Diagonal([ones(Bool,3,4), ones(Bool,2,3)])
standardbroadcastingtests(D, Diagonal)
end

@testset "incompatible sizes" begin
A = reshape(1:12, 4, 3)
B = reshape(1:12, 3, 4)
D1 = Diagonal(fill(A, 2))
D2 = Diagonal(fill(B, 2))
@test_throws DimensionMismatch D1 .+ D2
end
end
@testset "Bidiagonal" begin
A = [1 3; 2 4]
B = Bidiagonal(fill(A,3), fill(A,2), :U)
standardbroadcastingtests(B, Bidiagonal)
end
@testset "UpperTriangular" begin
A = [1 3; 2 4]
U = UpperTriangular([(i+j)*A for i in 1:3, j in 1:3])
standardbroadcastingtests(U, UpperTriangular)
end
end

end

0 comments on commit 81cb537

Please sign in to comment.