From fb25c1f512433522ff3f4034a61052c80f562d3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hakkel=20Tam=C3=A1s?= Date: Tue, 19 Nov 2019 14:39:23 +0100 Subject: [PATCH] multiplication by scalar --- src/BuildCompTree.jl | 38 +++++++++++++++++++++----------------- test/functionality.jl | 8 ++++++-- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/src/BuildCompTree.jl b/src/BuildCompTree.jl index 8f88387..ad61385 100644 --- a/src/BuildCompTree.jl +++ b/src/BuildCompTree.jl @@ -14,14 +14,14 @@ Base.:+(FO1::FunOp, FO2::FunOp) = begin FunctionOperatorComposite(FO1, FO2, :+) end -Base.:+(FO1::FunOp, S::LinearAlgebra.UniformScaling) = begin - assertAddDimScaling(FO1, S) - FunctionOperatorComposite(FO1, createScalingForAddSub(FO1, S), :+) +Base.:+(FO::FunOp, S::LinearAlgebra.UniformScaling) = begin + assertAddDimScaling(FO, S) + FunctionOperatorComposite(FO, createScalingForAddSub(FO, S), :+) end -Base.:+(S::LinearAlgebra.UniformScaling, FO2::FunOp) = begin - assertAddDimScaling(FO2, S) - FunctionOperatorComposite(createScalingForAddSub(FO2, S), FO2, :+) +Base.:+(S::LinearAlgebra.UniformScaling, FO::FunOp) = begin + assertAddDimScaling(FO, S) + FunctionOperatorComposite(createScalingForAddSub(FO, S), FO, :+) end Base.:-(FO1::FunOp, FO2::FunOp) = begin @@ -30,14 +30,14 @@ Base.:-(FO1::FunOp, FO2::FunOp) = begin FunctionOperatorComposite(FO1, FO2, :-) end -Base.:-(FO1::FunOp, S::LinearAlgebra.UniformScaling) = begin - assertAddDimScaling(FO1, S) - FunctionOperatorComposite(FO1, createScalingForAddSub(FO1, S), :-) +Base.:-(FO::FunOp, S::LinearAlgebra.UniformScaling) = begin + assertAddDimScaling(FO, S) + FunctionOperatorComposite(FO, createScalingForAddSub(FO, S), :-) end -Base.:-(S::LinearAlgebra.UniformScaling, FO2::FunOp) = begin - assertAddDimScaling(FO2, S) - FunctionOperatorComposite(createScalingForAddSub(FO2, S), FO2, :-) +Base.:-(S::LinearAlgebra.UniformScaling, FO::FunOp) = begin + assertAddDimScaling(FO, S) + FunctionOperatorComposite(createScalingForAddSub(FO, S), FO, :-) end Base.:*(FO1::FunOp, FO2::FunOp) = begin @@ -52,17 +52,21 @@ Base.:*(FO::FunctionOperator, S::LinearAlgebra.UniformScaling{Bool}) = Base.:*(FO::FunctionOperatorComposite, S::LinearAlgebra.UniformScaling{Bool}) = FunctionOperatorComposite(FO, name = getName(FO) * " * I") -Base.:*(FO1::FunOp, S::LinearAlgebra.UniformScaling) = - FunctionOperatorComposite(FO1, createScalingForMult(FO1, S, FO1.inDims), :*) - Base.:*(S::LinearAlgebra.UniformScaling{Bool}, FO::FunctionOperator) = FunctionOperator(FO, name = "I * " * getName(FO)) Base.:*(S::LinearAlgebra.UniformScaling{Bool}, FO::FunctionOperatorComposite) = FunctionOperatorComposite(FO, name = "I * " * getName(FO)) -Base.:*(S::LinearAlgebra.UniformScaling, FO2::FunOp) = - FunctionOperatorComposite(createScalingForMult(FO2, S, FO2.outDims), FO2, :*) +Base.:*(FO::FunOp, S::LinearAlgebra.UniformScaling) = + FunctionOperatorComposite(FO, createScalingForMult(FO, S, FO.inDims), :*) + +Base.:*(S::LinearAlgebra.UniformScaling, FO::FunOp) = + FunctionOperatorComposite(createScalingForMult(FO, S, FO.outDims), FO, :*) + +Base.:*(FO::FunOp, λ::Number) = FO * (λ*I) + +Base.:*(λ::Number, FO::FunOp) = (λ*I) * FO # Adjoint operator creates a new FunctionOperatorComposite object, toggles the adjoint field and # switches the input and output dimension constraints (and also voids plan for FunctionOperatorComposite) diff --git a/test/functionality.jl b/test/functionality.jl index 4165d53..88eeaa0 100644 --- a/test/functionality.jl +++ b/test/functionality.jl @@ -35,6 +35,10 @@ using FunctionOperators, LinearAlgebra, Test @test Op₁' * I * (ones(10,10)*8) == ones(10,10)*2 @test I * Op₁ * (ones(10,10)*2) == ones(10,10)*8 @test I * Op₁' * (ones(10,10)*8) == ones(10,10)*2 + @test Op₁ * 2 * (ones(10,10)*2) == ones(10,10)*64 + @test Op₁' * 2 * (ones(10,10)*4) == ones(10,10)*2 + @test 2 * Op₁ * (ones(10,10)*2) == ones(10,10)*16 + @test 2 * Op₁' * (ones(10,10)*8) == ones(10,10)*4 @test Op₃ * Op₁ * (ones(10,10)*2) == ones(10,10)*8 .* w @test (Op₃ * Op₁)' * (ones(10,10)*8 .* w) == Op₁' * Op₃' * (ones(10,10)*8 .* w) @test (Op₃ * Op₁)' * (ones(10,10)*8 .* w) == ones(10,10)*2 @@ -62,9 +66,9 @@ using FunctionOperators, LinearAlgebra, Test manual_tests() end @testset "With automatic reshape" begin - FO_settings.auto_reshape = true + FunctionOperators_global_settings.auto_reshape = true manual_tests() - FO_settings.auto_reshape = false + FunctionOperators_global_settings.auto_reshape = false end end @testset "Adjoint of addition/substraction" begin