diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 76f776a..c483152 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: matrix: version: - '1' - - '1.6' + - '1.10' os: - ubuntu-latest - macOS-latest diff --git a/.gitignore b/.gitignore index 788274b..899cd07 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ demos/.ipynb_checkpoints/ docs/build/ docs/site/ docs/Manifest.toml + +Manifest.toml diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..cb77f89 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,984 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.12.1" +manifest_format = "2.0" +project_hash = "c8f5f45579604b7204fcaa029c0a41ea02d98e72" + +[[deps.ADTypes]] +git-tree-sha1 = "27cecae79e5cc9935255f90c53bb831cc3c870d7" +uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +version = "1.18.0" + + [deps.ADTypes.extensions] + ADTypesChainRulesCoreExt = "ChainRulesCore" + ADTypesConstructionBaseExt = "ConstructionBase" + ADTypesEnzymeCoreExt = "EnzymeCore" + + [deps.ADTypes.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + + [deps.AbstractFFTs.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.AbstractOperators]] +deps = ["FastBroadcast", "LinearAlgebra", "OperatorCore", "Polyester", "Random", "RecursiveArrayTools"] +path = "../AbstractOperators" +uuid = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +version = "0.4.0" + + [deps.AbstractOperators.extensions] + GpuExt = "GPUArrays" + LinearMapsExt = "LinearMaps" + + [deps.AbstractOperators.weakdeps] + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "MacroTools"] +git-tree-sha1 = "3b86719127f50670efe356bc11073d84b4ed7a5d" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.42" + + [deps.Accessors.extensions] + AxisKeysExt = "AxisKeys" + IntervalSetsExt = "IntervalSets" + LinearAlgebraExt = "LinearAlgebra" + StaticArraysExt = "StaticArrays" + StructArraysExt = "StructArrays" + TestExt = "Test" + UnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "7e35fca2bdfba44d797c53dfe63a51fabf39bfc0" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.4.0" + + [deps.Adapt.extensions] + AdaptSparseArraysExt = "SparseArrays" + AdaptStaticArraysExt = "StaticArrays" + + [deps.Adapt.weakdeps] + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.2" + +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra"] +git-tree-sha1 = "d81ae5489e13bc03567d4fbbb06c546a5e53c857" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.22.0" + + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceCUDSSExt = ["CUDSS", "CUDA"] + ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" + ArrayInterfaceChainRulesExt = "ChainRules" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceMetalExt = "Metal" + ArrayInterfaceReverseDiffExt = "ReverseDiff" + ArrayInterfaceSparseArraysExt = "SparseArrays" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" + +[[deps.BenchmarkTools]] +deps = ["Compat", "JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] +git-tree-sha1 = "7fecfb1123b8d0232218e2da0c213004ff15358d" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "1.6.3" + +[[deps.Bessels]] +git-tree-sha1 = "4435559dc39793d53a9e3d278e185e920b4619ef" +uuid = "0e736298-9ec6-45e8-9647-e4fc86a2fe38" +version = "0.2.8" + +[[deps.BitTwiddlingConvenienceFunctions]] +deps = ["Static"] +git-tree-sha1 = "f21cfd4950cb9f0587d5067e69405ad2acd27b87" +uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" +version = "0.1.6" + +[[deps.Bzip2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1b96ea4a01afe0ea4090c5c8039690672dd13f2e" +uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" +version = "1.0.9+0" + +[[deps.CPUSummary]] +deps = ["CpuId", "IfElse", "PrecompileTools", "Preferences", "Static"] +git-tree-sha1 = "f3a21d7fc84ba618a779d1ed2fcca2e682865bab" +uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" +version = "0.2.7" + +[[deps.CloseOpenIntervals]] +deps = ["Static", "StaticArrayInterface"] +git-tree-sha1 = "05ba0d07cd4fd8b7a39541e31a7b0254704ea581" +uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" +version = "0.1.13" + +[[deps.CodecBzip2]] +deps = ["Bzip2_jll", "TranscodingStreams"] +git-tree-sha1 = "84990fa864b7f2b4901901ca12736e45ee79068c" +uuid = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" +version = "0.8.5" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "962834c22b66e32aa10f7611c08c8ca4e20749a9" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.8" + +[[deps.Combinatorics]] +git-tree-sha1 = "8010b6bb3388abe68d95743dcbea77650bb2eddf" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.3" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools"] +git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.1" + +[[deps.CommonWorldInvalidations]] +git-tree-sha1 = "ae52d1c52048455e85a387fbee9be553ec2b68d0" +uuid = "f70d9fcc-98c5-4d4a-abd7-e4cdeebd8ca8" +version = "1.0.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "9d8a54ce4b17aa5bdce0ea5c34bc5e7c340d16ad" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.18.1" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.3.0+1" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConstructionBase]] +git-tree-sha1 = "b4b092499347b18a015186eae3042f72267106cb" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.6.0" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseLinearAlgebraExt = "LinearAlgebra" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.CpuId]] +deps = ["Markdown"] +git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" +uuid = "adafc99b-e345-5852-983c-f28acb93d879" +version = "0.3.1" + +[[deps.DSP]] +deps = ["Bessels", "FFTW", "IterTools", "LinearAlgebra", "Polynomials", "Random", "Reexport", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "5989debfc3b38f736e69724818210c67ffee4352" +uuid = "717857b8-e6f2-59f4-9121-6e50c889abd2" +version = "0.8.4" + + [deps.DSP.extensions] + OffsetArraysExt = "OffsetArrays" + + [deps.DSP.weakdeps] + OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" + +[[deps.DSPOperators]] +deps = ["AbstractOperators", "DSP", "FFTW", "LinearAlgebra"] +path = "../AbstractOperators/DSPOperators" +uuid = "d5a72628-6e2f-430e-82f5-561df0bb8116" +version = "0.1.0" + +[[deps.DataStructures]] +deps = ["OrderedCollections"] +git-tree-sha1 = "e357641bb3e0638d353c4b29ea0e40ea644066a6" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.19.3" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.DifferentiationInterface]] +deps = ["ADTypes", "LinearAlgebra"] +git-tree-sha1 = "c8d85ecfcbaef899308706bebdd8b00107f3fb43" +uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +version = "0.6.54" + + [deps.DifferentiationInterface.extensions] + DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" + DifferentiationInterfaceDiffractorExt = "Diffractor" + DifferentiationInterfaceEnzymeExt = ["EnzymeCore", "Enzyme"] + DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation" + DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" + DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" + DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] + DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore" + DifferentiationInterfaceGTPSAExt = "GTPSA" + DifferentiationInterfaceMooncakeExt = "Mooncake" + DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] + DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] + DifferentiationInterfaceSparseArraysExt = "SparseArrays" + DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" + DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings" + DifferentiationInterfaceStaticArraysExt = "StaticArrays" + DifferentiationInterfaceSymbolicsExt = "Symbolics" + DifferentiationInterfaceTrackerExt = "Tracker" + DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] + + [deps.DifferentiationInterface.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" + Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" + FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" + FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" + Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" + PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" + SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.DocStringExtensions]] +git-tree-sha1 = "7442a5dfe1ebb773c29cc2962a8980f47221d76c" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.5" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FFTW]] +deps = ["AbstractFFTs", "FFTW_jll", "Libdl", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] +git-tree-sha1 = "97f08406df914023af55ade2f843c39e99c5d969" +uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +version = "1.10.0" + +[[deps.FFTWOperators]] +deps = ["AbstractOperators", "FFTW", "LinearAlgebra", "Polyester"] +path = "../AbstractOperators/FFTWOperators" +uuid = "c59a084b-ba08-4f3f-af9e-f4298d6caa94" +version = "0.1.0" + +[[deps.FFTW_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6d6219a004b8cf1e0b4dbe27a2860b8e04eba0be" +uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" +version = "3.3.11+0" + +[[deps.FastBroadcast]] +deps = ["ArrayInterface", "LinearAlgebra", "Polyester", "Static", "StaticArrayInterface", "StrideArraysCore"] +git-tree-sha1 = "ab1b34570bcdf272899062e1a56285a53ecaae08" +uuid = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +version = "0.3.5" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "ba6ce081425d0afb2bedd00d9884464f764a9225" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "1.2.2" + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + + [deps.ForwardDiff.weakdeps] + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +version = "1.11.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "83cf05ab16a73219e5f6bd1bdfa9848fa24ac627" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.2.0" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.IntelOpenMP_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl"] +git-tree-sha1 = "ec1debd61c300961f98064cfb21287613ad7f303" +uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" +version = "2025.2.0+0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" + +[[deps.InverseFunctions]] +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.17" +weakdeps = ["Dates", "Test"] + + [deps.InverseFunctions.extensions] + InverseFunctionsDatesExt = "Dates" + InverseFunctionsTestExt = "Test" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "b2d91fe939cae05960e760110b328288867b5758" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.6" + +[[deps.IterTools]] +git-tree-sha1 = "42d5f897009e7ff2cf88db414a389e5ed1bdd023" +uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +version = "1.10.0" + +[[deps.IterativeSolvers]] +deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] +git-tree-sha1 = "59545b0a2b27208b0650df0a46b8e3019f85055b" +uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" +version = "0.9.4" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "0533e564aae234aff59ab625543145446d8b6ec2" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.7.1" + +[[deps.JSON]] +deps = ["Dates", "Logging", "Parsers", "PrecompileTools", "StructUtils", "UUIDs", "Unicode"] +git-tree-sha1 = "eb04df293213df64ddd720c86de3c431f5f8ccf1" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "1.2.1" + + [deps.JSON.extensions] + JSONArrowExt = ["ArrowTypes"] + + [deps.JSON.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "411eccfe8aba0814ffa0fdf4860913ed09c34975" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.3" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JuliaSyntaxHighlighting]] +deps = ["StyledStrings"] +uuid = "ac6e5ff7-fb65-4e79-a425-ec3bc9c03011" +version = "1.12.0" + +[[deps.LayoutPointers]] +deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] +git-tree-sha1 = "a9eaadb366f5493a5654e843864c13d8b107548c" +uuid = "10f19ff3-798f-405d-979b-55457f8fc047" +version = "0.1.17" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +version = "1.11.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.11.1+1" + +[[deps.LibGit2]] +deps = ["LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.9.0+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "OpenSSL_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.3+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.12.0" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "13ca9e2586b89836fd20cccf56e57e2b9ae7f38f" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.29" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" + +[[deps.MKL_jll]] +deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "oneTBB_jll"] +git-tree-sha1 = "282cadc186e7b2ae0eeadbd7a4dffed4196ae2aa" +uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +version = "2025.2.0+0" + +[[deps.MacroTools]] +git-tree-sha1 = "1e0228a030642014fe5cfe68c2c0a818f9e3f522" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.16" + +[[deps.ManualMemory]] +git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" +uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" +version = "0.1.8" + +[[deps.Markdown]] +deps = ["Base64", "JuliaSyntaxHighlighting", "StyledStrings"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" + +[[deps.MathOptInterface]] +deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "DataStructures", "ForwardDiff", "JSON3", "LinearAlgebra", "MutableArithmetics", "NaNMath", "OrderedCollections", "PrecompileTools", "Printf", "SparseArrays", "SpecialFunctions", "Test"] +git-tree-sha1 = "a2cbab4256690aee457d136752c404e001f27768" +uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" +version = "1.46.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2025.5.20" + +[[deps.MutableArithmetics]] +deps = ["LinearAlgebra", "SparseArrays", "Test"] +git-tree-sha1 = "22df8573f8e7c593ac205455ca088989d0a2c7a0" +uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" +version = "1.6.7" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "9b8215b1ee9e78a293f99797cd31375471b2bcae" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.1.3" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.3.0" + +[[deps.OSQP]] +deps = ["Libdl", "LinearAlgebra", "MathOptInterface", "OSQP_jll", "SparseArrays"] +git-tree-sha1 = "50faf456a64ac1ca097b78bcdf288d94708adcdd" +uuid = "ab2f91bb-94b4-55e3-9ba0-7f65df51de79" +version = "0.8.1" + +[[deps.OSQP_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "d0f73698c33e04e557980a06d75c2d82e3f0eb49" +uuid = "9c4f68bf-6205-5545-a508-2878b064d984" +version = "0.600.200+0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.29+0" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.7+0" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.5.1+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1346c9208249809840c91b26703912dff463d335" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.6+0" + +[[deps.OperatorCore]] +path = "../OperatorCore" +uuid = "3945cd23-d97e-4db0-9df2-35342dbd287d" +version = "0.1.1" + +[[deps.OrderedCollections]] +git-tree-sha1 = "05868e21324cede2207c6f0f466b4bfef6d5e7ee" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.8.1" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "7d2f8f21da5db6a806faf7b9b292296da42b2810" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.3" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.12.0" + + [deps.Pkg.extensions] + REPLExt = "REPL" + + [deps.Pkg.weakdeps] + REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Polyester]] +deps = ["ArrayInterface", "BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "ManualMemory", "PolyesterWeave", "Static", "StaticArrayInterface", "StrideArraysCore", "ThreadingUtilities"] +git-tree-sha1 = "6f7cd22a802094d239824c57d94c8e2d0f7cfc7d" +uuid = "f517fe37-dbe3-4b94-8317-1923a5111588" +version = "0.7.18" + +[[deps.PolyesterWeave]] +deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] +git-tree-sha1 = "645bed98cd47f72f67316fd42fc47dee771aefcd" +uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" +version = "0.2.2" + +[[deps.Polynomials]] +deps = ["LinearAlgebra", "OrderedCollections", "RecipesBase", "Requires", "Setfield", "SparseArrays"] +git-tree-sha1 = "972089912ba299fba87671b025cd0da74f5f54f7" +uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" +version = "4.1.0" + + [deps.Polynomials.extensions] + PolynomialsChainRulesCoreExt = "ChainRulesCore" + PolynomialsFFTWExt = "FFTW" + PolynomialsMakieExt = "Makie" + PolynomialsMutableArithmeticsExt = "MutableArithmetics" + + [deps.Polynomials.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" + MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "07a921781cab75691315adc645096ed5e370cb77" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.3.3" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "0f27480397253da18fe2c12a4ba4eb9eb208bf3d" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.5.0" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" + +[[deps.Profile]] +deps = ["StyledStrings"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +version = "1.11.0" + +[[deps.ProximalAlgorithms]] +deps = ["ADTypes", "DifferentiationInterface", "LinearAlgebra", "OperatorCore", "Printf", "ProximalCore"] +path = "../ProximalAlgorithms.jl" +uuid = "140ffc9f-1907-541a-a177-7475e0a401e9" +version = "0.8.0" + +[[deps.ProximalCore]] +deps = ["LinearAlgebra"] +path = "../ProximalCore.jl" +uuid = "dc4f5ac2-75d1-4f31-931e-60435d74994b" +version = "0.2.0" + +[[deps.ProximalOperators]] +deps = ["IterativeSolvers", "LinearAlgebra", "OSQP", "ProximalCore", "SparseArrays", "SuiteSparse", "TSVD"] +path = "../ProximalOperators.jl" +uuid = "a725b495-10eb-56fe-b38b-717eba820537" +version = "0.17.0" +weakdeps = ["RecursiveArrayTools"] + + [deps.ProximalOperators.extensions] + RecursiveArrayToolsExt = "RecursiveArrayTools" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.RecursiveArrayTools]] +deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] +git-tree-sha1 = "51bdb23afaaa551f923a0e990f7c44a4451a26f1" +uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" +version = "3.39.0" + + [deps.RecursiveArrayTools.extensions] + RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" + RecursiveArrayToolsForwardDiffExt = "ForwardDiff" + RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions" + RecursiveArrayToolsMeasurementsExt = "Measurements" + RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" + RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] + RecursiveArrayToolsSparseArraysExt = ["SparseArrays"] + RecursiveArrayToolsStructArraysExt = "StructArrays" + RecursiveArrayToolsTablesExt = ["Tables"] + RecursiveArrayToolsTrackerExt = "Tracker" + RecursiveArrayToolsZygoteExt = "Zygote" + + [deps.RecursiveArrayTools.weakdeps] + FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "62389eeff14780bfe55195b7204c0d8738436d64" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.1" + +[[deps.RuntimeGeneratedFunctions]] +deps = ["ExprTools", "SHA", "Serialization"] +git-tree-sha1 = "2f609ec2295c452685d3142bc4df202686e555d2" +uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +version = "0.5.16" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.SIMDTypes]] +git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" +uuid = "94e857df-77ce-4151-89e5-788b33177be4" +version = "0.1.0" + +[[deps.SciMLPublic]] +git-tree-sha1 = "ed647f161e8b3f2973f24979ec074e8d084f1bee" +uuid = "431bcebd-1456-4ced-9d72-93c2757fff0b" +version = "1.0.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "c5391c6ace3bc430ca630251d02ea9687169ca68" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.2" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.12.0" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "f2685b435df2613e25fc10ad8c26dddb8640f547" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.6.1" + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + + [deps.SpecialFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[[deps.Static]] +deps = ["CommonWorldInvalidations", "IfElse", "PrecompileTools", "SciMLPublic"] +git-tree-sha1 = "49440414711eddc7227724ae6e570c7d5559a086" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "1.3.1" + +[[deps.StaticArrayInterface]] +deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Static"] +git-tree-sha1 = "96381d50f1ce85f2663584c8e886a6ca97e60554" +uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" +version = "1.8.0" + + [deps.StaticArrayInterface.extensions] + StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" + StaticArrayInterfaceStaticArraysExt = "StaticArrays" + + [deps.StaticArrayInterface.weakdeps] + OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "6ab403037779dae8c514bad259f32a447262455a" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.4" + +[[deps.Statistics]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] + +[[deps.StrideArraysCore]] +deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"] +git-tree-sha1 = "83151ba8065a73f53ca2ae98bc7274d817aa30f2" +uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" +version = "0.5.8" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.11.0" + +[[deps.StructUtils]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "79529b493a44927dd5b13dde1c7ce957c2d049e4" +uuid = "ec057cc2-7a8d-4b58-b3b3-92acb9f63b42" +version = "2.6.0" + + [deps.StructUtils.extensions] + StructUtilsMeasurementsExt = ["Measurements"] + StructUtilsTablesExt = ["Tables"] + + [deps.StructUtils.weakdeps] + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + +[[deps.StructuredOptimization]] +deps = ["AbstractOperators", "Combinatorics", "DSP", "DSPOperators", "DifferentiationInterface", "FFTW", "FFTWOperators", "LinearAlgebra", "ProximalAlgorithms", "ProximalCore", "ProximalOperators", "RecursiveArrayTools"] +path = "." +uuid = "46cd3e9d-64ff-517d-a929-236bc1a1fc9d" +version = "0.5.0" + +[[deps.StyledStrings]] +uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b" +version = "1.11.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.8.3+2" + +[[deps.SymbolicIndexingInterface]] +deps = ["Accessors", "ArrayInterface", "RuntimeGeneratedFunctions", "StaticArraysCore"] +git-tree-sha1 = "94c58884e013efff548002e8dc2fdd1cb74dfce5" +uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +version = "0.3.46" + + [deps.SymbolicIndexingInterface.extensions] + SymbolicIndexingInterfacePrettyTablesExt = "PrettyTables" + + [deps.SymbolicIndexingInterface.weakdeps] + PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TSVD]] +deps = ["Adapt", "LinearAlgebra"] +git-tree-sha1 = "c39caef6bae501e5607a6caf68dd9ac6e8addbcb" +uuid = "9449cd9e-2762-5aa3-a617-5413e99d722e" +version = "0.4.4" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +version = "1.11.0" + +[[deps.ThreadingUtilities]] +deps = ["ManualMemory"] +git-tree-sha1 = "d969183d3d244b6c33796b5ed01ab97328f2db85" +uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" +version = "0.5.5" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.3" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.3.1+2" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.15.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.64.0+1" + +[[deps.oneTBB_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl"] +git-tree-sha1 = "1350188a69a6e46f799d3945beef36435ed7262f" +uuid = "1317d2d5-d96f-522e-a858-c73665f53c3e" +version = "2022.0.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.5.0+2" diff --git a/Project.toml b/Project.toml index 73a52e2..746b42a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,29 +1,32 @@ name = "StructuredOptimization" uuid = "46cd3e9d-64ff-517d-a929-236bc1a1fc9d" -version = "0.4.0" +version = "0.5.0" [deps] AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" +DSPOperators = "d5a72628-6e2f-430e-82f5-561df0bb8116" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +FFTWOperators = "c59a084b-ba08-4f3f-af9e-f4298d6caa94" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" [compat] -AbstractOperators = "0.3" -DSP = "0.5.1 - 0.7" +AbstractOperators = "0.4" +Combinatorics = "1.0.2" +DSP = "0.5.1 - 0.8" +DSPOperators = "0.1" +DifferentiationInterface = "0.6" FFTW = "1" -ProximalAlgorithms = "0.5" -ProximalOperators = "0.15" -RecursiveArrayTools = "1 - 2" -julia = "1.4" - -[extras] -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["LinearAlgebra", "Test", "Random"] +FFTWOperators = "0.1" +LinearAlgebra = "1" +ProximalAlgorithms = "0.8" +ProximalCore = "0.2" +ProximalOperators = "0.17" +RecursiveArrayTools = "1 - 3" +julia = "1.10" diff --git a/README.md b/README.md index 6ec97e5..f69ea04 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![Build status](https://github.com/JuliaFirstOrder/StructuredOptimization.jl/workflows/CI/badge.svg)](https://github.com/JuliaFirstOrder/StructuredOptimization.jl/actions?query=workflow%3ACI) [![codecov](https://codecov.io/gh/JuliaFirstOrder/StructuredOptimization.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaFirstOrder/StructuredOptimization.jl) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliafirstorder.github.io/StructuredOptimization.jl/stable) [![](https://img.shields.io/badge/docs-latest-blue.svg)](https://juliafirstorder.github.io/StructuredOptimization.jl/latest) diff --git a/src/StructuredOptimization.jl b/src/StructuredOptimization.jl index e989dad..7a3f5bb 100644 --- a/src/StructuredOptimization.jl +++ b/src/StructuredOptimization.jl @@ -2,24 +2,39 @@ module StructuredOptimization using LinearAlgebra using RecursiveArrayTools -using AbstractOperators +using ProximalCore +using AbstractOperators, DSPOperators, FFTWOperators using ProximalOperators using ProximalAlgorithms +using Combinatorics: permutations, powerset +using ProximalAlgorithms: IterativeAlgorithm, override_parameters -import ProximalAlgorithms: ZeroFPR, PANOC, PANOCplus -export ZeroFPR, PANOC, PANOCplus +ProximalAlgorithms.value_and_gradient(f, x) = begin + y, fy = gradient(f, x) + return fy, y +end +ProximalAlgorithms.value_and_gradient!(grad_f_x, f, x) = begin + fy = gradient!(grad_f_x, f, x) + return fy +end + +abstract type AbstractExpression end + +include("syntax/variable.jl") +include("syntax/expressions/expression.jl") +include("syntax/terms/term.jl") + +const TermOrExpr = Union{Term,AbstractExpression} -include("syntax/syntax.jl") include("calculus/precomposeNonlinear.jl") # TODO move to ProximalOperators? -include("arraypartition.jl") # TODO move to ProximalOperators? +include("calculus/sqrNormL2WithNormalOp.jl") # problem parsing include("solvers/terms_extract.jl") include("solvers/terms_properties.jl") -include("solvers/terms_splitting.jl") +include("solvers/parse.jl") # solver calls -include("solvers/solvers_options.jl") include("solvers/build_solve.jl") include("solvers/minimize.jl") diff --git a/src/arraypartition.jl b/src/arraypartition.jl deleted file mode 100644 index 06eff5e..0000000 --- a/src/arraypartition.jl +++ /dev/null @@ -1,36 +0,0 @@ -import ProximalOperators -import RecursiveArrayTools - -@inline function ProximalOperators.prox( - h, - x::RecursiveArrayTools.ArrayPartition, - gamma... -) - # unwrap - y, fy = ProximalOperators.prox(h, x.x, gamma...) - # wrap - return RecursiveArrayTools.ArrayPartition(y), fy -end - -@inline function ProximalOperators.gradient( - h, - x::RecursiveArrayTools.ArrayPartition -) - # unwrap - grad, fx = ProximalOperators.gradient(h, x.x) - # wrap - return RecursiveArrayTools.ArrayPartition(grad), fx -end - -@inline ProximalOperators.prox!( - y::RecursiveArrayTools.ArrayPartition, - h, - x::RecursiveArrayTools.ArrayPartition, - gamma... -) = ProximalOperators.prox!(y.x, h, x.x, gamma...) - -@inline ProximalOperators.gradient!( - y::RecursiveArrayTools.ArrayPartition, - h, - x::RecursiveArrayTools.ArrayPartition -) = ProximalOperators.gradient!(y.x, h, x.x) diff --git a/src/calculus/precomposeNonlinear.jl b/src/calculus/precomposeNonlinear.jl index 19dec7c..110ad3a 100644 --- a/src/calculus/precomposeNonlinear.jl +++ b/src/calculus/precomposeNonlinear.jl @@ -15,9 +15,9 @@ struct PrecomposeNonlinear{P, end function PrecomposeNonlinear(g::P, G::T) where {P, T} - t, s = domainType(G), size(G,2) + t, s = domain_type(G), size(G,2) bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)) - t, s = codomainType(G), size(G,1) + t, s = codomain_type(G), size(G,1) bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)) bufC2 = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)) PrecomposeNonlinear{P, T, typeof(bufD), typeof(bufC)}(g, G, bufD, bufC, bufC2) diff --git a/src/calculus/sqrNormL2WithNormalOp.jl b/src/calculus/sqrNormL2WithNormalOp.jl new file mode 100644 index 0000000..50276aa --- /dev/null +++ b/src/calculus/sqrNormL2WithNormalOp.jl @@ -0,0 +1,88 @@ +# squared L2 norm (times a constant, or weighted) precomposed with an operator + +""" + SqrNormL2WithNormalOp(λ=1, L::LinearOperator) + +With a nonnegative scalar `λ`, return the squared Euclidean norm +```math +f(x) = \\tfrac{λ}{2}\\|L * x\\|^2. +``` +With a nonnegative array `λ`, return the weighted squared Euclidean norm +```math +f(x) = \\tfrac{1}{2}∑_i λ_i y_i^2 where y = L * x. +``` + +This is a special case of the more general `Precompose(SqrNormL2(), L, 1, 0)` operator, +where `L` is a linear operator, and only the gradient is needed, not the proximal operator. +The gradient of the precomposed squared norm is +```math +\nabla f(x) = Lᴴ * L * x, +``` +and in many cases, there is an optimized implementation of the normal operator `Lᴴ * L` +that makes the compution of the gradient much faster than the naive implementation. + +A notable drawback of this method is that gradient! does not return the +squared norm of `L * x`, but rather the squared norm of `Lᴴ * L * x` (i.e. the +squared norm of the gradient). Most algorithms, however, tolerate this +difference, and it is much faster to compute. +""" +struct SqrNormL2WithNormalOp{T,SC,L<:AbstractOperator,L2<:AbstractOperator} + A::L + AᴴA::L2 + lambda::T + function SqrNormL2WithNormalOp(A, lambda) + @assert A isa AbstractOperator + @assert is_linear(A) + if any(lambda .< 0) + error("coefficients in λ must be nonnegative") + else + AᴴA = A' * A + new{typeof(lambda),all(lambda .> 0),typeof(A),typeof(AᴴA)}(A, AᴴA, lambda) + end + end +end + +is_convex(::Type{<:SqrNormL2WithNormalOp}) = true +is_smooth(::Type{<:SqrNormL2WithNormalOp}) = true +is_separable(::Type{<:SqrNormL2WithNormalOp}) = true +is_generalized_quadratic(::Type{<:SqrNormL2WithNormalOp}) = true +is_strongly_convex(::Type{SqrNormL2WithNormalOp{T,SC}}) where {T,SC} = SC + +SqrNormL2WithNormalOp(A) = SqrNormL2WithNormalOp(A, 1) + +function (f::SqrNormL2WithNormalOp{S})(x) where {S <: Real} + y = f.A * x + return f.lambda / real(eltype(y))(2) * norm(y)^2 +end + +function (f::SqrNormL2WithNormalOp{<:AbstractArray})(x) + y = f.A * x + R = real(eltype(y)) + sqnorm = R(0) + for k in eachindex(y) + sqnorm += f.lambda[k] * abs2(y[k]) + end + return sqnorm / R(2) +end + +function gradient!(y, f::SqrNormL2WithNormalOp{<:Real}, x) + R = real(eltype(y)) + mul!(y, f.AᴴA, x) + sqnx = R(0) + for k in eachindex(y) + y[k] *= f.lambda + sqnx += abs2(y[k]) + end + return f.lambda / R(2) * sqnx +end + +function gradient!(y, f::SqrNormL2WithNormalOp{<:AbstractArray}, x) + R = real(eltype(y)) + mul!(y, f.AᴴA, x) + sqnx = R(0) + for k in eachindex(y) + y[k] *= f.lambda[k] + sqnx += f.lambda[k] * abs2(y[k]) + end + return sqnx / R(2) +end diff --git a/src/solvers/build_solve.jl b/src/solvers/build_solve.jl index b360902..f55032a 100644 --- a/src/solvers/build_solve.jl +++ b/src/solvers/build_solve.jl @@ -1,11 +1,11 @@ -export build +export suggest_algorithm """ - parse_problem(terms::Tuple, solver::ForwardBackwardSolver) + parse_problem(terms::TermSet, solver::IterativeAlgorithm) -Takes as input a tuple containing the terms defining the problem and the solver. +Takes as input a TermSet containing the terms defining the problem and the solver. -Returns a tuple containing the optimization variables and the problem terms +Returns a TermSet containing the optimization variables and the problem terms to be fed into the solver. # Example @@ -21,37 +21,93 @@ julia> p = problem( ls(A*x - b ) , norm(x) <= 1 ); julia> StructuredOptimization.parse_problem(p, PANOCplus()); ``` """ -function parse_problem(terms::Tuple, solver::T) where T <: ForwardBackwardSolver - x = extract_variables(terms) - # Separate smooth and nonsmooth - smooth, nonsmooth = split_smooth(terms) - if is_proximable(nonsmooth) - g = extract_proximable(x, nonsmooth) - kwargs = Dict{Symbol, Any}(:g => g) - if !isempty(smooth) - if is_linear(smooth) - f = extract_functions(smooth) - A = extract_operators(x, smooth) - kwargs[:A] = A - else # ?? - f = extract_functions_nodisp(smooth) - A = extract_affines(x, smooth) - f = PrecomposeNonlinear(f, A) - end - kwargs[:f] = f +function parse_problem(terms::Union{Term,TermSet}, algorithm::T, return_partial::Bool = false) where {T <: IterativeAlgorithm} + terms = terms isa TermSet ? terms : TermSet(terms) + assumptions = ProximalAlgorithms.get_assumptions(algorithm) + variables = extract_variables(terms) + remaining_terms = terms + kwargs = Dict{Symbol, Any}() + for assumption in assumptions + for term_selection in reverse(collect(powerset(remaining_terms, 1))) + term_selection = TermSet(term_selection...) + preparation_result = prepare(term_selection, assumption, variables) + if preparation_result !== nothing + term_selection = collect(term_selection) + remaining_terms = setdiff(remaining_terms, term_selection) + push!(kwargs, preparation_result...) + break + end + end + if isempty(remaining_terms) + return algorithm, kwargs, variables + end end - return (x, kwargs) - end - error("Sorry, I cannot parse this problem for solver of type $(T)") + return return_partial ? (kwargs, remaining_terms) : nothing end +function print_diagnostics(terms::Union{Term,TermSet}, algorithm::T) where {T <: IterativeAlgorithm} + terms = terms isa TermSet ? terms : TermSet(terms) + kwargs, remaining_terms = parse_problem(terms, algorithm, true) + print("The algorithm $(typeof(algorithm).name.name) assumes problem of form: ") + show(ProximalAlgorithms.get_assumptions(algorithm)) + println() + if !isempty(kwargs) + println("Successfully prepared the following terms:") + for (key, value) in kwargs + println(" - $key: $(typeof(value))") + end + end + println("The following terms could not be prepared:") + for term in remaining_terms + println(" - $term") + end +end + +function parse_problem(terms::Union{Term,TermSet}) + terms = terms isa TermSet ? terms : TermSet(terms) + for algorithm in ProximalAlgorithms.get_algorithms() + result = parse_problem(terms, algorithm) + if result !== nothing + return result + end + end + return nothing +end + +function suggest_algorithm(terms::Union{Term,TermSet}, algorithms = ProximalAlgorithms.get_algorithms()) + terms = terms isa TermSet ? terms : TermSet(terms) + suitable_algs = [] + for algorithm in algorithms + result = parse_problem(terms, algorithm) + if result !== nothing + push!(suitable_algs, algorithm) + end + end + return suitable_algs +end + +function print_diagnostics(terms::Union{Term,TermSet}) + terms = terms isa TermSet ? terms : TermSet(terms) + best_algorithm, best_algorithm_remaining_terms = nothing, Inf + for algorithm in ProximalAlgorithms.get_algorithms() + _, remaining_terms = parse_problem(terms, algorithm, true) + if length(remaining_terms) < best_algorithm_remaining_terms + best_algorithm_remaining_terms = length(remaining_terms) + best_algorithm = algorithm + end + end + println("The closest algorithm to the problem is $best_algorithm") + print_diagnostics(terms, best_algorithm) +end export solve """ - solve(terms::Tuple, solver::ForwardBackwardSolver) + solve(terms::Union{Term,TermSet}; kwargs...) + solve(terms::Union{Term,TermSet}, solver::IterativeAlgorithm; kwargs...) + solve(terms::Union{Term,TermSet}, solvers::Union{AbstractVector,Tuple}; kwargs...) -Takes as input a tuple containing the terms defining the problem and the solver options. +Takes as input a Term/TermSet containing the terms defining the problem and the solver options. Solves the problem returning a tuple containing the iterations taken and the build solver. @@ -65,14 +121,57 @@ julia> A, b = randn(10,4), randn(10); julia> p = problem(ls(A*x - b ), norm(x) <= 1); -julia> solve(p, PANOCplus()); +julia> solve(p, PANOCplus(); maxiter=10); julia> ~x ``` """ -function solve(terms::Tuple, solver::ForwardBackwardSolver) - x, kwargs = parse_problem(terms, solver) - x_star, it = solver(; x0 = ~x, kwargs...) - ~x .= x_star - return x, it +function solve(terms::Union{Term,TermSet}, solvers::Union{<:AbstractVector{IterativeAlgorithm},<:Tuple{Vararg{IterativeAlgorithm}}}; kwargs...) + terms = terms isa TermSet ? terms : TermSet(terms) + for solver in solvers + result = parse_problem(terms, solver) + if result isa Nothing + continue + end + _, term_kwargs, x = result + solver = override_parameters(solver; kwargs...) + x_star, it = solver(; x0 = ~x, term_kwargs...) + ~x .= x_star isa Tuple ? x_star[1] : x_star + return x, it + end + if length(solvers) == 1 + print_diagnostics(terms, solvers[1]) + error("Sorry, I cannot parse this problem for solver of type $(typeof(solvers[1]).parameters[1])") + else + print_diagnostics(terms) + error("Sorry, I cannot parse this problem for any of the provided solvers") + end +end + +function solve(terms::Union{Term,TermSet}, solver::IterativeAlgorithm; kwargs...) + terms = terms isa TermSet ? terms : TermSet(terms) + result = parse_problem(terms, solver) + if result === nothing + print_diagnostics(terms, solver) + error("Sorry, I cannot parse this problem for solver of type $(typeof(solver).parameters[1])") + end + _, term_kwargs, x = result + solver = override_parameters(solver; kwargs...) + x_star, it = solver(; x0 = ~x, term_kwargs...) + ~x .= x_star isa Tuple ? x_star[1] : x_star + return x, it +end + +function solve(terms::Union{Term,TermSet}; kwargs...) + terms = terms isa TermSet ? terms : TermSet(terms) + result = parse_problem(terms) + if result === nothing + print_diagnostics(terms) + error("Sorry, I cannot find a suitable solver for this problem") + end + solver, term_kwargs, x = result + solver = override_parameters(solver; kwargs...) + x_star, it = solver(; x0 = ~x, term_kwargs...) + ~x .= x_star + return x, it end diff --git a/src/solvers/minimize.jl b/src/solvers/minimize.jl index b22a5e0..dc37c3b 100644 --- a/src/solvers/minimize.jl +++ b/src/solvers/minimize.jl @@ -1,4 +1,59 @@ -export @minimize +export problem, @minimize, @term + +""" + problems(terms...) + +Constructs a problem. + +# Example + +```julia + +julia> x = Variable(4) +Variable(Float64, (4,)) + +julia> A, b = randn(10,4), randn(10); + +julia> p = problem(ls(A*x-b), norm(x) <= 1) + +``` + +""" +problem(terms...) = begin + flattened_terms = Term[] + for t in terms + if t isa TermSet + append!(flattened_terms, t.terms) + elseif t isa Term + push!(flattened_terms, t) + else + error("All arguments must be of type Term or TermSet") + end + end + TermSet(flattened_terms...) +end + +function expand_terms_with_repr(expr) + if expr isa Expr && expr.head == :call && expr.args[1] == :+ + return Tuple(map(t -> :(Term($(esc(t)), $(string(t)))), expr.args[2:end])) + elseif expr isa Symbol + return (esc(expr),) + elseif expr isa Expr && expr.head == :tuple + return Tuple(first.(expand_terms_with_repr.(expr.args))) + else + return (:(Term($(esc(expr)), $(string(expr)))),) + end +end + +""" + @term expr + +Records the code representation of the term. Useful if later we want to print the term, e.g. when debugging. +""" +macro term(expr) + terms = expand_terms_with_repr(expr) + return Expr(:block, terms...) +end """ @minimize cost [st ctr] [with slv_opt] @@ -29,28 +84,34 @@ Returns as output a tuple containing the optimization variables and the number of iterations spent by the solver algorithm. """ macro minimize(cf::Union{Expr, Symbol}) - cost = esc(cf) - return :(solve(problem($(cost)), default_solver())) + cost = expand_terms_with_repr(cf) + problem_expr = Expr(:call, :problem, cost...) + return :(solve($problem_expr)) end macro minimize(cf::Union{Expr, Symbol}, s::Symbol, cstr::Union{Expr, Symbol}) - cost = esc(cf) - if s == :(st) - constraints = esc(cstr) - return :(solve(problem($(cost), $(constraints)), default_solver())) - elseif s == :(with) + cost = expand_terms_with_repr(cf) + if s == :st + constraints = expand_terms_with_repr(cstr) + terms = (cost..., constraints...) + problem_expr = Expr(:call, :problem, terms...) + return :(solve($problem_expr)) + elseif s == :with solver = esc(cstr) - return :(solve(problem($(cost)), $(solver))) + problem_expr = Expr(:call, :problem, cost...) + return :(solve($problem_expr, $solver)) else error("wrong symbol after cost function! use `st` or `with`") end end macro minimize(cf::Union{Expr, Symbol}, s::Symbol, cstr::Union{Expr, Symbol}, w::Symbol, slv::Union{Expr, Symbol}) - cost = esc(cf) - s != :(st) && error("wrong symbol after cost function! use `st`") - constraints = esc(cstr) - w != :(with) && error("wrong symbol after constraints! use `with`") + cost = expand_terms_with_repr(cf) + s != :st && error("wrong symbol after cost function! use `st`") + constraints = expand_terms_with_repr(cstr) + w != :with && error("wrong symbol after constraints! use `with`") solver = esc(slv) - return :(solve(problem($(cost), $(constraints)), $(solver))) + terms = (cost..., constraints...) + problem_expr = Expr(:call, :problem, terms...) + return :(solve($problem_expr, $solver)) end diff --git a/src/solvers/parse.jl b/src/solvers/parse.jl new file mode 100644 index 0000000..3a6dc75 --- /dev/null +++ b/src/solvers/parse.jl @@ -0,0 +1,619 @@ +function add_to_incompatibilities(incompatibilities, t1, t2) + if haskey(incompatibilities, t1) + push!(incompatibilities[t1], t2) + else + incompatibilities[t1] = Set([t2]) + end + if haskey(incompatibilities, t2) + push!(incompatibilities[t2], t1) + else + incompatibilities[t2] = Set([t1]) + end +end + +function group_by_variables(terms) + variable_bags = Dict{Variable, Vector{Any}}() + for term in terms + for var in variables(term) + if haskey(variable_bags, var) + push!(variable_bags[var], term) + else + variable_bags[var] = [term] + end + end + end + return variable_bags +end + +function can_be_separable_sum(variable_bags) + for (var, term_list) in variable_bags + if length(term_list) > 1 # more than one term for this variable + # Check if any of the terms are sliced + operators = [get_operators_for_var(term, var) for term in term_list] + slicing_masks = [is_sliced(op) ? get_slicing_mask(op) : nothing for op in operators] + for i in eachindex(operators) + if is_sliced(operators[i]) + # This operator is sliced, check if it is overlapping with any other sliced operator + for j in i+1:length(operators) + if is_sliced(operators[j]) && any(slicing_masks[i] .&& slicing_masks[j]) + return false + end + end + else # no slicing -> this term is incompatible with all others + return false + end + end + end + end + return true +end + +function get_unseparable_pairs(variable_bags) + incompatibilities = Dict{Term, Set{Term}}() + for (var, term_list) in variable_bags + if length(term_list) > 1 # more than one term for this variable + # Check if any of the terms are sliced + operators = [get_operators_for_var(term, var) for term in term_list] + slicing_masks = [is_sliced(op) ? get_slicing_mask(op) : nothing for op in operators] + for i in eachindex(operators) + if is_sliced(operators[i]) + # This operator is sliced, check if it is overlapping with any other sliced operator + for j in i+1:length(operators) + if is_sliced(operators[j]) && any(slicing_masks[i] .&& slicing_masks[j]) + add_to_incompatibilities(incompatibilities, term_list[i], term_list[j]) + end + end + else # no slicing -> this term is incompatible with all others + for j in i+1:length(operators) + add_to_incompatibilities(incompatibilities, term_list[i], term_list[j]) + end + end + end + end + end + return incompatibilities +end + +function merge_function_with_operator(op, f, disp, λ) + if is_eye(op) + f = disp == 0 ? f : PrecomposeDiagonal(f, 1.0, disp) + if size(op, 1) != size(op, 2) + f = ReshapeInput(f, size(op, 1)) + end + elseif is_diagonal(op) + if f isa SqrNormL2 + f = SqrNormL2(f.lambda .* diag(op) .^ 2) + else + f = PrecomposeDiagonal(f, diag(op), disp) + end + elseif is_AAc_diagonal(op) + f = Precompose(f, op, diag_AAc(op), disp) + elseif is_linear(op) + # we assume that prox will not be called on this term because it will not give a valid result + f = Precompose(f, op, 1, disp) + else + # we assume that prox will not be called on this term because it will not give a valid result + if disp != 0 + op = AbstractOperators.AffineAdd(op, disp) + end + f = PrecomposeNonlinear(f, op) + end + return λ == 1 ? f : Postcompose(f, λ) +end + +unsatisfied_properties(term, assumptions::ProximalAlgorithms.AssumptionItem) = [property_func for property_func in assumptions.second if !property_func(term)] +does_satisfy(term, assumptions::ProximalAlgorithms.AssumptionItem) = all(property_func(term) for property_func in assumptions.second) + +function prepare(term::Term, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{N, Variable}) where N + if does_satisfy(term, assumption.func) && (!(ProximalCore.is_proximable in assumption.func.second) || is_AAc_diagonal(term.A.L)) + op = extract_operators(variables, term) + disp = displacement(term) + return (assumption.func.first => merge_function_with_operator(op, term.f, disp, term.lambda),) + else + return nothing + end +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.SimpleTerm, ::NTuple{N, Variable}) where N + repr = term.repr !== nothing ? term.repr : string(term) + problematic_properties = unsatisfied_properties(term, assumption.func) + if length(problematic_properties) == 0 + println("Term $repr satisfies all required properties, but the following operator is not AAc diagonal: ", term.A.L) + else + println("Term $repr does not satisfy required property: $(join(problematic_properties, ", "))") + end +end + +function prepare_proximable_single_var_per_term(variable_bags, variables::NTuple{N, Variable}) where {N} + fs = () + for var in variables + if haskey(variable_bags, var) + term_list = variable_bags[var] + if length(term_list) > 1 + #multiple terms per variable + #currently this happens only with GetIndex + fxi,idxs = (),() + for ti in term_list + op = operator(ti) + fxi = (fxi..., merge_function_with_operator(op, ti.f, displacement(ti), ti.lambda)) + if AbstractOperators.ndoms(op, 2) > 1 + op = op[findfirst(==(var), variables(ti))] + end + if typeof(op) <: Compose + idx = op.A[1].idx + else + idx = op.idx + end + idxs = (idxs..., get_slicing_mask(op)) + end + fs = (fs..., SlicedSeparableSum(fxi,idxs)) + else + op = operator(term_list[1]) + disp = displacement(term_list[1]) + fs = (fs..., merge_function_with_operator(op, term_list[1].f, disp, term_list[1].lambda)) + end + else + fs = (fs..., IndFree()) + end + end + return SeparableSum(fs) +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + if any(term -> !does_satisfy(term, assumption.func), terms) + return nothing + end + if ProximalCore.is_proximable in assumption.func.second + if any(!is_AAc_diagonal(affine(term)) for term in terms) + return nothing + end + variable_bags = group_by_variables(terms) + if !can_be_separable_sum(variable_bags) + return nothing + end + if all(length.(values(variable_bags)) .== 1) + # all terms references only one variable + return (assumption.func.first => prepare_proximable_single_var_per_term(variable_bags, variables),) + else + op = extract_operators(variables, terms) + idxs = get_slicing_expr(op) + op = remove_slicing(op) + hcat_ops = tuple([op[i] for i in eachindex(op.A)]...) + μs = AbstractOperators.diag_AAc(op) + f = extract_functions(terms) + return (assumption.func.first => PrecomposedSlicedSeparableSum(f.fs, idxs, hcat_ops, μs),) + end + else + fs = () + for term in terms + if is_linear(term) + f = merge_function_with_operator(extract_operators(variables, term), term.f, displacement(term), term.lambda) + else + f = extract_functions(term) + op = extract_affines(variables, term) + f = PrecomposeNonlinear(f, op) + f = term.lambda == 1 ? f : Postcompose(f, term.lambda) + end + fs = (fs..., f) + end + return (assumption.func.first => ProximalOperators.Sum(fs),) + end +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + print_diagnostics(terms[1], assumption, variables) + return + end + problematic_term_index = findfirst(term -> !does_satisfy(term, assumption.func), terms) + if problematic_term_index !== nothing + problematic_term = terms[problematic_term_index] + repr = problematic_term.repr !== nothing ? problematic_term.repr : string(problematic_term) + problematic_properties = unsatisfied_properties(problematic_term, assumption.func) + println("Term $repr does not satisfy required property: $(join(problematic_properties, ", "))") + elseif any(term -> !is_AAc_diagonal(affine(term)), terms) + println("The following terms contains operators that are not AAc diagonal:") + for term in terms + if !is_AAc_diagonal(affine(term)) + repr = term.repr !== nothing ? term.repr : string(term) + println(" - $repr") + end + end + else + variable_bags = group_by_variables(terms) + incompatibilities = get_unseparable_pairs(variable_bags) + println("The following terms are incompatible with each other:") + for (term, incompatible_terms) in incompatibilities + println(" - $term: $(join(incompatible_terms, ", "))") + end + end +end + +function prepare(term::Term, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where N + op = extract_affines(variables, term) + if does_satisfy(op, assumption.operator) && does_satisfy(term.f, assumption.func) + return ( + assumption.func.first => term.lambda == 1 ? term.f : Postcompose(term.f, term.lambda), + assumption.operator.first => op + ) + else # try preparing as a simple term + tup = prepare(term, ProximalAlgorithms.SimpleTerm(assumption.func), variables) + if tup !== nothing && length(variables) > 1 + example_input = ArrayPartition(Tuple(~var for var in variables)) + tup = (tup..., assumption.operator.first => AbstractOperators.Eye(example_input)) + end + return tup + end +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where N + op = affine(term) + repr = term.repr !== nothing ? term.repr : string(term) + if is_eye(op) + problematic_properties = unsatisfied_properties(term.f, assumption.func) + println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") + else + println("A possible decomposition of term $repr:") + f = term.lambda == 1 ? term.f : Postcompose(term.f, term.lambda) + print(" - ", assumption.func.first, " = ", f) + if !does_satisfy(f, assumption.func) + problematic_properties = unsatisfied_properties(f, assumption.func) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + print(" - ", assumption.operator.first, " = ", op) + if !does_satisfy(op, assumption.operator) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + end + println("When trying to prepare the term as a simple term:") + print_diagnostics(term, ProximalAlgorithms.SimpleTerm(assumption.func), variables) +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + op = extract_affines(variables, terms) + f = extract_functions(terms) + if does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func) + return ( + assumption.func.first => f, + assumption.operator.first => op + ) + else # try preparing as a simple term + return prepare(terms, ProximalAlgorithms.SimpleTerm(assumption.func), variables) + end +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where {N} + op = extract_affines(variables, terms) + f = extract_functions(terms) + repr = string(terms) + if is_eye(op) + for term in terms + problematic_properties = unsatisfied_properties(term.f, assumption.func) + println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") + end + else + println("A possible decomposition of terms $repr:") + print(" - ", assumption.func.first, " = ", f) + if !does_satisfy(f, assumption.func) + problematic_properties = unsatisfied_properties(f, assumption.func) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + print(" - ", assumption.operator.first, " = ", op) + if !does_satisfy(op, assumption.operator) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + end + println("When trying to prepare terms as a simple function:") + print_diagnostics(terms, ProximalAlgorithms.SimpleTerm(assumption.func), variables) +end + +function prepare(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{N, Variable}) where {N} + op = extract_affines(variables, term) + f = extract_functions(term) + if does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₁) + return ( + assumption.func₁.first => f, + assumption.operator.first => op + ) + elseif does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₂) + return ( + assumption.func₂.first => f, + assumption.operator.first => affine(term) + ) + else + # try preparing as a simple term + tup = prepare(term, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) + if tup !== nothing && length(variables) > 1 + example_input = ArrayPartition(tuple([~var for var in variables]...)) + tup = (tup..., assumption.operator.first => AbstractOperators.Eye(example_input)) + end + return tup + end +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{N, Variable}) where {N} + op = affine(term) + f = extract_functions(term) + repr = term.repr !== nothing ? term.repr : string(term) + if is_eye(op) + problematic_properties = unsatisfied_properties(term.f, assumption.func₁) + println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") + else + println("A possible decomposition of term $repr:") + print(" - ", assumption.func₁.first, " = ", f) + if !does_satisfy(f, assumption.func₁) + problematic_properties = unsatisfied_properties(f, assumption.func₁) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + print(" - ", assumption.operator.first, " = ", op) + if !does_satisfy(op, assumption.operator) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + end + println("When trying to prepare the term as a simple term:") + print_diagnostics(term, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + op = extract_affines(variables, terms) + f = extract_functions(terms) + if does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₁) + return ( + assumption.func₁.first => f, + assumption.operator.first => op + ) + elseif does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₂) + return ( + assumption.func₂.first => f, + assumption.operator.first => affine(terms[1].A) + ) + else + # try preparing as a simple term + tup = prepare(terms, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) + if tup === nothing + tup = prepare(terms, ProximalAlgorithms.SimpleTerm(assumption.func₂), variables) + end + if tup !== nothing && length(variables) > 1 + example_input = ArrayPartition(tuple([~var for var in variables]...)) + tup = (tup..., assumption.operator.first => AbstractOperators.Eye(example_input)) + end + return tup + end +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + print_diagnostics(terms[1], assumption, variables) + return + end + op = affine(terms[1].A) + f = extract_functions(terms) + repr = string(terms) + if is_eye(op) + for term in terms + problematic_properties = unsatisfied_properties(term.f, assumption.func₁) + println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") + end + else + println("A possible decomposition of terms $repr:") + print(" - ", assumption.func₁.first, " = ", f) + if !does_satisfy(f, assumption.func₁) + problematic_properties = unsatisfied_properties(f, assumption.func₁) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + print(" - ", assumption.operator.first, " = ", op) + if !does_satisfy(op, assumption.operator) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + println("Alteratively, one can try to prepare the function part as:") + print(" - ", assumption.func₂.first, " = ", f) + if !does_satisfy(f, assumption.func₂) + problematic_properties = unsatisfied_properties(f, assumption.func₂) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + end + end + println("When trying to prepare the term as a simple term:") + print_diagnostics(terms, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) +end + +function prepare(term::Term, assumption::ProximalAlgorithms.LeastSquaresTerm, variables::NTuple{N, Variable}) where N + f = term.f + f_is_ls = f isa ProximalOperators.LeastSquares || f isa ProximalOperators.SqrNormL2 || f isa SqrNormL2WithNormalOp + if !f_is_ls + return nothing + end + if f isa SqrNormL2WithNormalOp + lambda = term.lambda * f.lambda + op = term.f.A + b = displacement(op) + op = remove_displacement(op) + else + lambda = term.lambda + op = extract_operators(variables, term) + b = displacement(term) + end + if !does_satisfy(op, assumption.operator) + return nothing + end + if lambda != 1 + op = lambda * op + b = lambda * b + end + return ( + assumption.operator.first => op, + assumption.b => b, + ) +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.LeastSquaresTerm, variables::NTuple{N, Variable}) where N + op = extract_operators(variables, term) + b = displacement(term) + f = term.f + repr = term.repr !== nothing ? term.repr : string(term) + if !(f isa ProximalOperators.LeastSquares || f isa ProximalOperators.SqrNormL2) + println("Term $repr does not satisfy required property: it is not a least squares function") + else + println("A possible decomposition of term $repr:") + print(" - ", assumption.operator.first, " = ", op) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + print(" - ", assumption.b.first, " = ", b) + end +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.LeastSquaresTerm, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + return nothing +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.LeastSquaresTerm, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + print_diagnostics(terms[1], assumption, variables) + else + println("Cannot prepare terms $terms as a least squares term: only a single term can be prepared as such.") + end +end + +function prepare(term::Term, assumption::ProximalAlgorithms.SquaredL2Term, variables::NTuple{N, Variable}) where N + f = term.f + if displacement(term) != 0 || !(f isa ProximalOperators.SqrNormL2) + return nothing + end + λ = term.lambda * f.lambda + op = extract_affines(variables, term) + if is_eye(op) + return (assumption.λ => λ,) + elseif is_diagonal(op) + return (assumption.λ => λ * diag(op),) + else + return nothing + end +end + +function print_diagnostics(term::Term, ::ProximalAlgorithms.SquaredL2Term, variables::NTuple{N, Variable}) where N + repr = term.repr !== nothing ? term.repr : string(term) + if displacement(term) != 0 + println("Term $repr does not satisfy required property: it has non-zero displacement") + elseif !(term.f isa ProximalOperators.SqrNormL2) + println("Term $repr does not satisfy required property: it is not a squared L2 function") + else + println("Term $repr does not satisfy required property: the operator is not an identity or diagonal") + end +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.SquaredL2Term, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + return nothing +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.SquaredL2Term, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + print_diagnostics(terms[1], assumption, variables) + else + println("Cannot prepare terms $terms as a squared L2 term: only a single term can be prepared as such.") + end +end + +function prepare(term::Term, assumption::ProximalAlgorithms.RepeatedSimpleTerm, variables::NTuple{N, Variable}) where N + simple_assumption = ProximalAlgorithms.SimpleTerm(assumption.func) + return prepare(term, simple_assumption, variables) +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.RepeatedSimpleTerm, variables::NTuple{N, Variable}) where N + simple_assumption = ProximalAlgorithms.SimpleTerm(assumption.func) + print_diagnostics(term, simple_assumption, variables) +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.RepeatedSimpleTerm, variables::NTuple{N, Variable}) where {N} + simple_assumption = ProximalAlgorithms.SimpleTerm(assumption.func) + results = () + for term in terms + result = prepare(term, simple_assumption, variables) + if isnothing(result) + return nothing + end + results = (results..., result[1].second) + end + return (assumption.func.first => results,) +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.RepeatedSimpleTerm, variables::NTuple{N, Variable}) where {N} + simple_assumption = ProximalAlgorithms.SimpleTerm(assumption.func) + for term in terms + if prepare(term, simple_assumption, variables) === nothing + print_diagnostics(term, simple_assumption, variables) + end + end +end + +function prepare(term::Term, assumption::ProximalAlgorithms.RepeatedOperatorTerm, variables::NTuple{N, Variable}) where N + operator_term_assumption = ProximalAlgorithms.OperatorTerm(assumption.func, assumption.operator) + return prepare(term, operator_term_assumption, variables) +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.RepeatedOperatorTerm, variables::NTuple{N, Variable}) where N + operator_term_assumption = ProximalAlgorithms.OperatorTerm(assumption.func, assumption.operator) + print_diagnostics(term, operator_term_assumption, variables) +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.RepeatedOperatorTerm, variables::NTuple{N, Variable}) where {N} + operator_term_assumption = ProximalAlgorithms.OperatorTerm(assumption.func, assumption.operator) + function_results = () + operator_results = () + for term in terms + result = prepare(term, operator_term_assumption, variables) + if isnothing(result) + return nothing + end + function_results = (function_results..., result[1].second) + operator_results = (operator_results..., result[2].second) + end + return ( + assumption.func.first => function_results, + assumption.operator.first => operator_results + ) +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.RepeatedOperatorTerm, variables::NTuple{N, Variable}) where {N} + operator_term_assumption = ProximalAlgorithms.OperatorTerm(assumption.func, assumption.operator) + for term in terms + if prepare(term, operator_term_assumption, variables) === nothing + print_diagnostics(term, operator_term_assumption, variables) + end + end +end diff --git a/src/solvers/solvers_options.jl b/src/solvers/solvers_options.jl deleted file mode 100644 index ff6b963..0000000 --- a/src/solvers/solvers_options.jl +++ /dev/null @@ -1,5 +0,0 @@ -using ProximalAlgorithms - -const ForwardBackwardSolver = ProximalAlgorithms.IterativeAlgorithm - -const default_solver = ProximalAlgorithms.PANOC diff --git a/src/solvers/terms_extract.jl b/src/solvers/terms_extract.jl index 389dea6..a7c583b 100644 --- a/src/solvers/terms_extract.jl +++ b/src/solvers/terms_extract.jl @@ -1,47 +1,41 @@ # returns all variables of a cost function, in terms of appearance extract_variables(t::TermOrExpr) = variables(t) -function extract_variables(t::NTuple{N,TermOrExpr}) where {N} - x = variables.(t) - xAll = x[1] - for i = 2:length(x) - for xi in x[i] - if (xi in xAll) == false - xAll = (xAll...,xi) - end - end - end - return xAll +function extract_variables(t::Union{Tuple, TermSet}) + var_tuples = variables.(t) + vars = collect(Base.Iterators.flatten(var_tuples)) + return tuple(unique(vars)...) end # extract functions from terms function extract_functions(t::Term) - f = displacement(t) == 0 ? t.f : PrecomposeDiagonal(t.f, 1.0, displacement(t)) #for now I keep this - f = t.lambda == 1. ? f : Postcompose(f, t.lambda) #for now I keep this + disp = displacement(t) + f = disp == 0 ? t.f : PrecomposeDiagonal(t.f, one(t.lambda), disp) #for now I keep this + f = t.lambda == 1 ? f : Postcompose(f, t.lambda) #for now I keep this #TODO change this return f end -extract_functions(t::NTuple{N,Term}) where {N} = SeparableSum(extract_functions.(t)) -extract_functions(t::Tuple{Term}) = extract_functions(t[1]) +extract_functions(t::TermSet) = SeparableSum(extract_functions.(t)) # extract functions from terms without displacement function extract_functions_nodisp(t::Term) - f = t.lambda == 1. ? t.f : Postcompose(t.f, t.lambda) + f = t.lambda == 1 ? t.f : Postcompose(t.f, t.lambda) return f end -extract_functions_nodisp(t::NTuple{N,Term}) where {N} = SeparableSum(extract_functions_nodisp.(t)) -extract_functions_nodisp(t::Tuple{Term}) = extract_functions_nodisp(t[1]) +extract_functions_nodisp(t::TermSet) = SeparableSum(extract_functions_nodisp.(t)) # extract operators from terms # returns all operators with an order dictated by xAll #single term, single variable -extract_operators(xAll::Tuple{Variable}, t::TermOrExpr) = operator(t) -extract_operators(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} = extract_operators(xAll, (t,)) +extract_operators(::Tuple{Variable}, t::AbstractExpression) = operator(t) +extract_operators(::Tuple{Variable}, t::Term) = operator(t) +extract_operators(xAll::NTuple{N,Variable}, t::AbstractExpression) where {N} = extract_operators(xAll, (t,)) +extract_operators(xAll::NTuple{N,Variable}, t::Term) where {N} = extract_operators(xAll, TermSet(t,)) #multiple terms, multiple variables -function extract_operators(xAll::NTuple{N,Variable}, t::NTuple{M,TermOrExpr}) where {N,M} +function extract_operators(xAll::NTuple{N,Variable}, t::TermSet) where {N} ops = () for ti in t tex = expand(xAll,ti) @@ -50,7 +44,7 @@ function extract_operators(xAll::NTuple{N,Variable}, t::NTuple{M,TermOrExpr}) wh return vcat(ops...) end -sort_and_extract_operators(xAll::Tuple{Variable}, t::TermOrExpr) = operator(t) +sort_and_extract_operators(::Tuple{Variable}, t::TermOrExpr) = operator(t) function sort_and_extract_operators(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} p = zeros(Int,N) @@ -66,12 +60,13 @@ end # returns all affines with an order dictated by xAll #single term, single variable -extract_affines(xAll::Tuple{Variable}, t::TermOrExpr) = affine(t) - -extract_affines(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} = extract_affines(xAll, (t,)) +extract_affines(::Tuple{Variable}, t::AbstractExpression) = affine(t) +extract_affines(::Tuple{Variable}, t::Term) = affine(t) +extract_affines(xAll::NTuple{N,Variable}, t::AbstractExpression) where {N} = extract_affines(xAll, (t,)) +extract_affines(xAll::NTuple{N,Variable}, t::Term) where {N} = extract_affines(xAll, TermSet(t,)) #multiple terms, multiple variables -function extract_affines(xAll::NTuple{N,Variable}, t::NTuple{M,TermOrExpr}) where {N,M} +function extract_affines(xAll::NTuple{N,Variable}, t::TermSet) where {N} ops = () for ti in t tex = expand(xAll,ti) @@ -80,7 +75,7 @@ function extract_affines(xAll::NTuple{N,Variable}, t::NTuple{M,TermOrExpr}) wher return vcat(ops...) end -sort_and_extract_affines(xAll::Tuple{Variable}, t::TermOrExpr) = affine(t) +sort_and_extract_affines(::Tuple{Variable}, t::TermOrExpr) = affine(t) function sort_and_extract_affines(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} p = zeros(Int,N) @@ -94,7 +89,7 @@ end # expand term domain dimensions function expand(xAll::NTuple{N,Variable}, t::Term) where {N} xt = variables(t) - C = codomainType(operator(t)) + C = codomain_type(operator(t)) size_out = size(operator(t),1) ex = t.A @@ -109,7 +104,7 @@ end function expand(xAll::NTuple{N,Variable}, ex::AbstractExpression) where {N} ex = convert(Expression,ex) xt = variables(ex) - C = codomainType(operator(ex)) + C = codomain_type(operator(ex)) size_out = size(operator(ex),1) for x in xAll @@ -119,62 +114,3 @@ function expand(xAll::NTuple{N,Variable}, ex::AbstractExpression) where {N} end return ex end - -# extract function and merge operator -function extract_merge_functions(t::Term) - if is_sliced(t) - if typeof(operator(t)) <: Compose - op = operator(t).A[2] - else - op = Eye(size(operator(t),1)...) - end - else - op = operator(t) - end - if is_eye(op) - f = displacement(t) == 0 ? t.f : PrecomposeDiagonal(t.f, 1.0, displacement(t)) - elseif is_diagonal(op) - f = PrecomposeDiagonal(t.f, diag(op), displacement(t)) - elseif is_AAc_diagonal(op) - f = Precompose(t.f, op, diag_AAc(op), displacement(t)) - end - f = t.lambda == 1. ? f : Postcompose(f, t.lambda) #for now I keep this - #TODO change this - return f -end - -function extract_proximable(xAll::NTuple{N,Variable}, t::NTuple{M,Term}) where {N,M} - fs = () - for x in xAll - tx = () #terms containing x - for ti in t - if x in variables(ti) - tx = (tx...,ti) #collect terms containing x - end - end - if isempty(tx) - fx = IndFree() - elseif length(tx) == 1 #only one term per variable - fx = extract_proximable(x,tx[1]) - else - #multiple terms per variable - #currently this happens only with GetIndex - fxi,idxs = (),() - for ti in tx - fxi = (fxi..., extract_merge_functions(ti)) - idx = typeof(operator(ti)) <: Compose ? operator(ti).A[1].idx : operator(ti).idx - idxs = (idxs..., idx ) - end - fx = SlicedSeparableSum(fxi,idxs) - end - fs = (fs...,fx) - end - if length(fs) > 1 - return SeparableSum(fs) ##probably change constructor in Prox? - else - return fs[1] - end -end - -extract_proximable(xAll::Variable, t::Term) = extract_merge_functions(t) -extract_proximable(xAll::NTuple{N,Variable}, t::Term) where {N} = extract_proximable(xAll,(t,)) diff --git a/src/solvers/terms_properties.jl b/src/solvers/terms_properties.jl index a95b4f3..45c517c 100644 --- a/src/solvers/terms_properties.jl +++ b/src/solvers/terms_properties.jl @@ -1,25 +1,45 @@ -is_proximable(term::Term) = is_AAc_diagonal(term) +is_proximable(term::Term) = is_proximable(typeof(term.f)) && is_AAc_diagonal(term.A.L) -function is_proximable(terms::Tuple) - # Check that each term is proximable - if any(is_proximable.(terms) .== false) - return false - end +function get_operators_for_var(term, var) + full_operator = affine(term) + if AbstractOperators.ndoms(full_operator, 2) == 1 + return full_operator + else + return full_operator[findfirst(==(var), variables(term))] + end +end + +function is_separable_sum(terms::TermSet) # Construct the set of occurring variables vars = Set() for term in terms union!(vars, variables(term)) end # Check that each variable occurs in only one term - for v in vars - tv = [t for t in terms if v in variables(t)] - if length(tv) != 1 - if all( is_sliced.(tv) ) && all( is_proximable.(tv) ) - return true - else + for var in vars + terms_with_var = [t for t in terms if var in variables(t)] + if length(terms_with_var) != 1 + # All terms must be either or have a single variable + if ! all( length(variables(term)) == 1 || is_separable(term.f) for term in terms_with_var ) return false end + # All terms must be sliced for this variable + operators = [get_operators_for_var(term, var) for term in terms_with_var] + if any(is_sliced(op) for op in operators) + return false + end + # The sliced operators must not overlap + slicing_masks = [is_sliced(op) ? get_slicing_mask(op) : nothing for op in operators] + for i in eachindex(operators), j in i+1:length(operators) + if any(slicing_masks[i] .&& slicing_masks[j]) + return false + end + end end end return true end + +function is_proximable(terms::TermSet) + return all(is_proximable.(terms)) && is_separable_sum(terms) +end diff --git a/src/solvers/terms_splitting.jl b/src/solvers/terms_splitting.jl deleted file mode 100644 index a1dad74..0000000 --- a/src/solvers/terms_splitting.jl +++ /dev/null @@ -1,31 +0,0 @@ -# -# """ -# `split_smooth(cf::Vararg{Term}) -> (smooth, nonsmooth)` -# -# Splits cost function into `SmoothFunction` and `NonSmoothFunction` terms. -# """ -# split_smooth(cf::Vararg{Term}) = cf[findall(is_smooth(cf))],cf[findall((!).(is_smooth(cf)))] -# split_smooth{N}(cf::NTuple{N,Term}) = split_smooth(cf...) -# -# """ -# `split_AAc_diagonal(cf::Vararg{Term}) -> (proximable, non_proximable)` -# -# Splits cost function into terms with L'*L diagonal operator. -# """ -# split_AAc_diagonal(cf::Vararg{Term}) = cf[findall(is_AAc_diagonal(cf))],cf[findall((!).(is_AAc_diagonal(cf)))] -# split_AAc_diagonal{N}(cf::NTuple{N,Term}) = split_AAc_diagonal(cf...) -# -# #""" TODO -# #`split_Quadratic(cf::Vararg{Term}) -> (quadratic, non_quadratic)` -# # -# #Splits cost function into `QuadraticFunction` and non `QuadraticFunction` terms. -# #""" - -split_smooth(terms::Tuple) = - terms[findall(is_smooth.(terms))], terms[findall((!).(is_smooth.(terms)))] - -split_quadratic(terms::Tuple) = - terms[findall(is_quadratic.(terms))], terms[findall((!).(is_quadratic.(terms)))] - -split_AAc_diagonal(terms::Tuple) = - terms[findall(is_AAc_diagonal.(terms))], terms[findall((!).(is_AAc_diagonal.(terms)))] diff --git a/src/syntax/expressions/abstractOperator_bind.jl b/src/syntax/expressions/abstractOperator_bind.jl index c6edbb0..38d8b6d 100644 --- a/src/syntax/expressions/abstractOperator_bind.jl +++ b/src/syntax/expressions/abstractOperator_bind.jl @@ -19,7 +19,7 @@ julia> reshape(A*x-b,2,5) function reshape(a::AbstractExpression, dims...) A = convert(Expression,a) op = Reshape(A.L, dims...) - return Expression{length(A.x)}(A.x,op) + return Expression(A.x,op) end #Reshape @@ -33,7 +33,7 @@ imported = [ ] importedFFTW = [ - :fft :(AbstractOperators.DFT); + :fft :DFT; :rfft :RDFT; :irfft :IRDFT; :ifft :IDFT; @@ -90,7 +90,7 @@ for i = 1:size(fun,1) @eval begin function $f(a::AbstractExpression, args...) A = convert(Expression,a) - op = $fAbsOp(codomainType(operator(A)),size(operator(A),1), args...) + op = $fAbsOp(codomain_type(operator(A)),size(operator(A),1), args...) return op*A end end diff --git a/src/syntax/expressions/addition.jl b/src/syntax/expressions/addition.jl index bb700a0..9f8b4fd 100644 --- a/src/syntax/expressions/addition.jl +++ b/src/syntax/expressions/addition.jl @@ -1,7 +1,7 @@ import Base: +, - """ - +(ex1::AbstractExpression, ex2::AbstractExpression) + +(ex1::AbstractExpression, ex2::AbstractExpression) Add two expressions. @@ -47,112 +47,111 @@ julia> ex3.+z function (+)(a::AbstractExpression, b::AbstractExpression) A = convert(Expression,a) B = convert(Expression,b) - if variables(A) == variables(B) - return Expression{length(A.x)}(A.x,affine(A)+affine(B)) - else - opA = affine(A) - xA = variables(A) - opB = affine(B) - xB = variables(B) + if variables(A) == variables(B) + return Expression(A.x,affine(A)+affine(B)) + else + opA = affine(A) + xA = variables(A) + opB = affine(B) + xB = variables(B) xNew, opNew = Usum_op(xA,xB,opA,opB,true) - return Expression{length(xNew)}(xNew,opNew) - end + return Expression(xNew,opNew) + end end # sum expressions function (-)(a::AbstractExpression, b::AbstractExpression) A = convert(Expression,a) B = convert(Expression,b) - if variables(A) == variables(B) - return Expression{length(A.x)}(A.x,affine(A)-affine(B)) - else - opA = affine(A) - xA = variables(A) - opB = affine(B) - xB = variables(B) + if variables(A) == variables(B) + return Expression(A.x,affine(A)-affine(B)) + else + opA = affine(A) + xA = variables(A) + opB = affine(B) + xB = variables(B) xNew, opNew = Usum_op(xA,xB,opA,opB,false) - return Expression{length(xNew)}(xNew,opNew) - end + return Expression(xNew,opNew) + end end #unsigned sum affines with single variables -function Usum_op(xA::Tuple{Variable}, - xB::Tuple{Variable}, - A::AbstractOperator, - B::AbstractOperator,sign::Bool) +function Usum_op(xA::Tuple{Variable}, xB::Tuple{Variable}, A::AbstractOperator, B::AbstractOperator, sign::Bool) xNew = (xA...,xB...) opNew = sign ? hcat(A,B) : hcat(A,-B) - return xNew, opNew + return xNew, opNew end #unsigned sum: HCAT + AbstractOperator -function Usum_op(xA::NTuple{N,Variable}, - xB::Tuple{Variable}, - A::L1, - B::AbstractOperator,sign::Bool) where {N, M, L1<:HCAT{N}} - if xB[1] in xA +function Usum_op(xA::NTuple{N,Variable}, xB::Tuple{Variable}, A::HCAT{N}, B::AbstractOperator, sign::Bool) where {N} + if xB[1] in xA idx = findfirst(xA.==Ref(xB[1])) S = sign ? A[idx]+B : A[idx]-B - xNew = xA + xNew = xA opNew = hcat(A[1:idx-1],S,A[idx+1:N] ) - else + else xNew = (xA...,xB...) opNew = sign ? hcat(A,B) : hcat(A,-B) - end - return xNew, opNew + end + return xNew, opNew end #unsigned sum: AbstractOperator+HCAT -function Usum_op(xA::Tuple{Variable}, - xB::NTuple{N,Variable}, - A::AbstractOperator, - B::L2,sign::Bool) where {N, M, L2<:HCAT{N}} - if xA[1] in xB +function Usum_op(xA::Tuple{Variable}, xB::NTuple{N,Variable}, A::AbstractOperator, B::HCAT{N}, sign::Bool) where {N} + if xA[1] in xB idx = findfirst(xA.==Ref(xB[1])) S = sign ? A+B[idx] : B[idx]-A - xNew = xB + xNew = xB opNew = sign ? hcat(B[1:idx-1],S,B[idx+1:N] ) : -hcat(B[1:idx-1],S,B[idx+1:N] ) - else + else xNew = (xA...,xB...) opNew = sign ? hcat(A,B) : hcat(A,-B) - end + end - return xNew, opNew + return xNew, opNew end #unsigned sum: HCAT+HCAT -function Usum_op(xA::NTuple{NA,Variable}, - xB::NTuple{NB,Variable}, - A::L1, - B::L2,sign::Bool) where {NA,NB,M, - L1<:HCAT{NB}, - L2<:HCAT{NB} } - xNew = xA - opNew = A - for i in eachindex(xB) - xNew, opNew = Usum_op(xNew, (xB[i],), opNew, B[i], sign) - end +function Usum_op(xA::NTuple{NA,Variable}, xB::NTuple{NB,Variable}, A::HCAT{NA}, B::HCAT{NB}, sign::Bool) where {NA,NB} + xNew = xA + opNew = A + for i in eachindex(xB) + xNew, opNew = Usum_op(xNew, (xB[i],), opNew, B[i], sign) + end return xNew,opNew end #unsigned sum: multivar AbstractOperator + AbstractOperator -function Usum_op(xA::NTuple{N,Variable}, - xB::Tuple{Variable}, - A::AbstractOperator, - B::AbstractOperator,sign::Bool) where {N} - if xB[1] in xA - Z = Zeros(A) #this will be an HCAT +function Usum_op( + xA::NTuple{N,Variable}, xB::Tuple{Variable}, A::AbstractOperator, B::AbstractOperator, sign::Bool +) where {N} + if xB[1] in xA + Z = Zeros(A) #this will be an HCAT xNew, opNew = Usum_op(xA,xB,Z,B,sign) - opNew += A - else + opNew += A + else xNew = (xA...,xB...) opNew = sign ? hcat(A,B) : hcat(A,-B) - end - return xNew, opNew + end + return xNew, opNew +end + +function Usum_op( + xA::Tuple{Variable}, xB::NTuple{N,Variable}, A::AbstractOperator, B::AbstractOperator, sign::Bool +) where {N} + if xA[1] in xB + Z = Zeros(B) #this will be an HCAT + xNew, opNew = Usum_op(xA,xB,A,Z,sign) + opNew += B + else + xNew = (xA...,xB...) + opNew = sign ? hcat(A,B) : hcat(A,-B) + end + return xNew, opNew end """ - +(ex::AbstractExpression, b::Union{AbstractArray,Number}) + +(ex::AbstractExpression, b::Union{AbstractArray,Number}) Add a scalar or an `Array` to an expression: @@ -175,7 +174,7 @@ julia> b = randn(10); julia> size(b), eltype(b) ((10,), Float64) -julia> size(affine(ex),1), codomainType(affine(ex)) +julia> size(affine(ex),1), codomain_type(affine(ex)) ((10,), Float64) julia> ex + b @@ -185,19 +184,19 @@ julia> ex + b """ function (+)(a::AbstractExpression, b::Union{AbstractArray,Number}) A = convert(Expression,a) - return Expression{length(A.x)}(A.x,AffineAdd(affine(A),b)) + return Expression(A.x,AffineAdd(affine(A),b)) end (+)(a::Union{AbstractArray,Number}, b::AbstractExpression) = b+a function (-)(a::AbstractExpression, b::Union{AbstractArray,Number}) A = convert(Expression,a) - return Expression{length(A.x)}(A.x,AffineAdd(affine(A),b,false)) + return Expression(A.x,AffineAdd(affine(A),b,false)) end function (-)(a::Union{AbstractArray,Number}, b::AbstractExpression) B = convert(Expression,b) - return Expression{length(B.x)}(B.x,-AffineAdd(affine(B),a)) + return Expression(B.x,-AffineAdd(affine(B),a)) end # sum with array/scalar @@ -208,14 +207,14 @@ function Broadcast.broadcasted(::typeof(+),a::AbstractExpression, b::AbstractExp B = convert(Expression,b) if size(affine(A),1) != size(affine(B),1) if prod(size(affine(A),1)) > prod(size(affine(B),1)) - B = Expression{length(B.x)}(variables(B), + B = Expression(variables(B), BroadCast(affine(B),size(affine(A),1))) elseif prod(size(affine(B),1)) > prod(size(affine(A),1)) - A = Expression{length(A.x)}(variables(A), + A = Expression(variables(A), BroadCast(affine(A),size(affine(B),1))) - end + end return A+B - end + end return A+B end @@ -224,13 +223,13 @@ function Broadcast.broadcasted(::typeof(-),a::AbstractExpression, b::AbstractExp B = convert(Expression,b) if size(affine(A),1) != size(affine(B),1) if prod(size(affine(A),1)) > prod(size(affine(B),1)) - B = Expression{length(B.x)}(variables(B), + B = Expression(variables(B), BroadCast(affine(B),size(affine(A),1))) elseif prod(size(affine(B),1)) > prod(size(affine(A),1)) - A = Expression{length(A.x)}(variables(A), + A = Expression(variables(A), BroadCast(affine(A),size(affine(B),1))) - end + end return A-B - end + end return A-B end diff --git a/src/syntax/expressions/addition_tricky_part.jl b/src/syntax/expressions/addition_tricky_part.jl new file mode 100644 index 0000000..dc7bbca --- /dev/null +++ b/src/syntax/expressions/addition_tricky_part.jl @@ -0,0 +1,189 @@ +using Base.Iterators: flatten +abstract type OpStructure end + +struct HCatStructure{N} <: OpStructure + op::AbstractOperators.AbstractOperator + structure::NTuple{N,Any} +end + +struct SumStructure{N} <: OpStructure + op::AbstractOperators.AbstractOperator + structure::NTuple{N,Any} +end + +function get_structure(op::AbstractOperators.HCAT, vars) + if length(op.A) == AbstractOperators.ndoms(op, 2) # this is the deepest or only HCAT operator + return HCatStructure(op, vars) + else # there are more nested HCAT operators, let's recurse! + result = () + var_group_counter = 1 + for suboperator in op.A + subvars = vars[var_group_counter:var_group_counter+AbstractOperators.ndoms(suboperator, 2)-1] + if AbstractOperators.ndoms(suboperator, 2) == 1 + returned = subvars + else + returned = get_structure(suboperator, subvars) + @assert returned !== nothing + end + if returned isa Tuple + result = (result..., returned...) + else + result = (result..., returned) + end + var_group_counter += AbstractOperators.ndoms(suboperator, 2) + end + return HCatStructure(op, result) + end +end + +function get_structure(op::AbstractOperators.Sum, vars) + return SumStructure(op, tuple((get_structure(suboperator, vars) for suboperator in op.A)...)) +end + +function get_structure(op, vars) + if op isa AbstractOperators.AbstractOperator && AbstractOperators.ndoms(op, 2) == 1 + return SumStructure(op, vars) + else + for k in 1:fieldcount(typeof(op)) + value = getfield(op, k) + if value isa AbstractOperators.AbstractOperator + return get_structure(value, vars) + elseif value isa Tuple + for v in value + return get_structure(v, vars) + end + end + end + @assert false "This should never happen" + end +end + +function deep_flatten(structure::HCatStructure) + result = () + for item in structure.structure + if isa(item, OpStructure) + sub_flattened = deep_flatten(item) + if sub_flattened === nothing + return nothing + end + result = tuple(result..., sub_flattened...) + else + result = tuple(result..., item) + end + end + return result +end + +function deep_flatten(structure::SumStructure) + nested_structures = tuple((deep_flatten(item) for item in structure.structure)...) + if all(==(nested_structures[1]), nested_structures) + return nested_structures[1] + else + return nothing + end +end + +struct UnregularIndex{N} + max::NTuple{N, Int} + UnregularIndex(max) = any(max .< 1) ? error("max must be >= 1") : new{length(max)}(tuple(max...)) +end + +Base.first(iter::UnregularIndex) = tuple(fill(1, length(iter.max))...) +Base.length(iter::UnregularIndex) = sum(iter.max) + +function Base.iterate(iter::UnregularIndex) + state = first(iter) + return state, state +end + +function Base.iterate(iter::UnregularIndex{N}, state::NTuple{N, Int}) where {N} + if state == iter.max + return nothing + end + currentdim = findfirst(i -> state[i] != iter.max[i], 1:N) + nextstate = tuple((j < currentdim ? 1 : (j == currentdim ? state[j]+1 : state[j]) for j in 1:N)...) + return nextstate, nextstate +end + +get_structure_only(str) = str isa OpStructure ? tuple((get_structure_only(item) for item in str.structure)...) : str + +Base.length(str::OpStructure) = length(str.structure) +Base.getindex(str::OpStructure, i) = str.structure[i] + +permute_structure(str, perm) = tuple((str[i][perm[i]] for i in eachindex(str))...) + +function compute_permutations(st) + result = () + for perm in UnregularIndex(length.(st)) + result = (result..., permute_structure(st, perm)) + end + return result +end + +function get_all_permutations(structure::SumStructure) + product = [get_all_permutations(item) for item in structure.structure] + return tuple((SumStructure(structure.op, st) for st in compute_permutations(product))...) +end + +function get_all_permutations(structure::HCatStructure) + nested_perms = [isa(item, Int) ? (item,) : get_all_permutations(item) for item in structure.structure] + product = compute_permutations(nested_perms) + combinations = flatten(permutations(p) for p in product) + return tuple((HCatStructure(structure.op, tuple(p...)) for p in combinations)...) +end + +function find_feasible_permutation(vars, stA, stB) + stA_perms = get_all_permutations(stA) + stB_perms = get_all_permutations(stB) + stA_pairs = filter(pair -> pair[2] !== nothing, [(s, deep_flatten(s)) for s in stA_perms]) + stB_pairs = filter(pair -> pair[2] !== nothing, [(s, deep_flatten(s)) for s in stB_perms]) + for vars_perm in permutations(vars) + vars_perm = tuple(vars_perm...) + stA_perm = findfirst(pair -> pair[2] == vars_perm, stA_pairs) + if stA_perm === nothing + continue + end + stB_perm = findfirst(pair -> pair[2] == vars_perm, stB_pairs) + if stB_perm === nothing + continue + end + return vars_perm + end + return nothing +end + +function add_missing_vars(old_vars, op, vars) + missing_vars = setdiff(vars, old_vars) + if isempty(missing_vars) + return old_vars, op + end + dummy_ops = [AbstractOperators.Zeros(eltype(~var), size(~var), AbstractOperators.codomain_type(op), size(op, 1)) for var in missing_vars] + new_vars = (old_vars..., missing_vars...) + new_op = AbstractOperators.HCAT(op, dummy_ops...) + return new_vars, new_op +end + +function Usum_op( + xA::NTuple{N,Variable}, xB::NTuple{M,Variable}, A::AbstractOperator, B::AbstractOperator, sign::Bool +) where {N,M} + xNew = tuple(unique((xA...,xB...))...) + xA, A = add_missing_vars(xA, A, xNew) + xB, B = add_missing_vars(xB, B, xNew) + vars_index = tuple((i for i in eachindex(xNew))...) + xA_index = tuple((findfirst(==(x), xNew) for x in xA)...) + xB_index = tuple((findfirst(==(x), xNew) for x in xB)...) + structureA = get_structure(A, xA_index) + structureB = get_structure(B, xB_index) + var_perm = find_feasible_permutation(vars_index, structureA, structureB) + if var_perm === nothing + error("No feasible permutation found") + end + if var_perm != xA_index + A = AbstractOperators.permute(A, invperm([xA_index...])) + end + if var_perm != xB_index + B = AbstractOperators.permute(B, invperm([xB_index...])) + end + opNew = sign ? A+B : A-B + return xNew, opNew +end diff --git a/src/syntax/expressions/expression.jl b/src/syntax/expressions/expression.jl index 08d1f53..5d3fad4 100644 --- a/src/syntax/expressions/expression.jl +++ b/src/syntax/expressions/expression.jl @@ -1,7 +1,7 @@ struct Expression{N,A<:AbstractOperator} <: AbstractExpression x::NTuple{N,Variable} L::A - function Expression{N}(x::NTuple{N,Variable}, L::A) where {N,A<:AbstractOperator} + function Expression(x::NTuple{N,Variable}, L::A) where {N,A<:AbstractOperator} # checks on L ndoms(L,1) > 1 && throw(ArgumentError( "Cannot create expression with LinearOperator with `ndoms(L,1) > 1`" @@ -13,11 +13,11 @@ struct Expression{N,A<:AbstractOperator} <: AbstractExpression check_sz && throw(ArgumentError( "Size of the operator domain $(size(L, 2)) must match size of the variable $(size.(x))" )) - dmL = domainType(L) + dmL = domain_type(L) dmx = eltype.(x) check_dm = length(dmx) == 1 ? dmx[1] != dmL : dmx != dmL check_dm && throw(ArgumentError( - "Type of the operator domain $(domainType(L)) must match type of the variable $(eltype.(x))" + "Type of the operator domain $(domain_type(L)) must match type of the variable $(eltype.(x))" )) new{N,A}(x,L) end @@ -27,12 +27,21 @@ struct AdjointExpression{E <: AbstractExpression} <: AbstractExpression ex::E end -import Base: adjoint +import Base: adjoint, show adjoint(ex::AbstractExpression) = AdjointExpression(convert(Expression,ex)) adjoint(ex::AdjointExpression) = ex.ex +function show(io::IO, ex::Expression) + if length(ex.x) == 1 + print(io, AbstractOperators.fun_name(ex.L), " * ", ex.x[1]) + else + print(io, AbstractOperators.fun_name(ex.L), " * (", join(ex.x, ", "), ")") + end +end + include("utils.jl") include("multiplication.jl") include("addition.jl") +include("addition_tricky_part.jl") include("abstractOperator_bind.jl") diff --git a/src/syntax/expressions/multiplication.jl b/src/syntax/expressions/multiplication.jl index a99f84f..5658422 100644 --- a/src/syntax/expressions/multiplication.jl +++ b/src/syntax/expressions/multiplication.jl @@ -27,7 +27,7 @@ julia> affine(ex2) """ function (*)(L::AbstractOperator, a::AbstractExpression) A = convert(Expression,a) - Expression{length(A.x)}(A.x,L*affine(A)) + Expression(A.x,L*affine(A)) end """ @@ -71,21 +71,21 @@ julia> randn(10,5).*X """ function (*)(m::T, a::Union{AbstractVector,AbstractMatrix}) where {T<:AbstractExpression} M = convert(Expression,m) - op = LMatrixOp(codomainType(affine(M)),size(affine(M),1),a) + op = LMatrixOp(codomain_type(affine(M)),size(affine(M),1),a) return op*M end #LMatrixOp function (*)(M::AbstractMatrix, a::T) where {T<:AbstractExpression} A = convert(Expression,a) - op = MatrixOp(codomainType(affine(A)),size(affine(A),1),M) + op = MatrixOp(codomain_type(affine(A)),size(affine(A),1),M) return op*A end #MatrixOp function Broadcast.broadcasted(::typeof(*), d::D, a::T) where {D <: Union{Number,AbstractArray}, T<:AbstractExpression} A = convert(Expression,a) - op = DiagOp(codomainType(affine(A)),size(affine(A),1),d) + op = DiagOp(codomain_type(affine(A)),size(affine(A),1),d) return op*A end Broadcast.broadcasted(::typeof(*), a::T, d::D) where {D <: Union{Number,AbstractArray}, T<:AbstractExpression} = @@ -94,7 +94,7 @@ d.*a function (*)(coeff::T1, a::T) where {T1<:Number, T<:AbstractExpression} A = convert(Expression,a) - return Expression{length(A.x)}(A.x,coeff*affine(A)) + return Expression(A.x,coeff*affine(A)) end (*)(a::T, coeff::T1) where {T1<:Number, T<:AbstractExpression} = coeff*a ##Scale @@ -132,7 +132,7 @@ function (*)(ex1::AbstractExpression, ex2::AbstractExpression) A = extract_affines(x, ex1) B = extract_affines(x, ex2) op = Ax_mul_Bx(A,B) - exp3 = Expression{length(x)}(x,op) + exp3 = Expression(x,op) return exp3 end # Ax_mul_Bx @@ -144,7 +144,7 @@ function (*)(ex1::AdjointExpression, ex2::AbstractExpression) A = extract_affines(x, ex1) B = extract_affines(x, ex2) op = Axt_mul_Bx(A,B) - exp3 = Expression{length(x)}(x,op) + exp3 = Expression(x,op) return exp3 end # Axt_mul_Bx @@ -156,7 +156,7 @@ function (*)(ex1::AbstractExpression, ex2::AdjointExpression) A = extract_affines(x, ex1) B = extract_affines(x, ex2) op = Ax_mul_Bxt(A,B) - exp3 = Expression{length(x)}(x,op) + exp3 = Expression(x,op) return exp3 end # Ax_mul_Bxt @@ -168,7 +168,7 @@ function Broadcast.broadcasted(::typeof(*), ex1::AbstractExpression, ex2::Abstra A = extract_affines(x, ex1) B = extract_affines(x, ex2) op = HadamardProd(A,B) - exp3 = Expression{length(x)}(x,op) + exp3 = Expression(x,op) return exp3 end # Hadamard diff --git a/src/syntax/expressions/utils.jl b/src/syntax/expressions/utils.jl index 7c0af76..69f11b2 100644 --- a/src/syntax/expressions/utils.jl +++ b/src/syntax/expressions/utils.jl @@ -4,7 +4,7 @@ import Base: convert import AbstractOperators: displacement convert(::Type{Expression},x::Variable{T,N,A}) where {T,N,A} = -Expression{1}((x,),Eye(T,size(x))) +Expression((x,),Eye(T,size(x))) """ variables(ex::Expression) diff --git a/src/syntax/problem.jl b/src/syntax/problem.jl deleted file mode 100644 index 4387ddd..0000000 --- a/src/syntax/problem.jl +++ /dev/null @@ -1,28 +0,0 @@ -export problem - -""" - problems(terms...) - -Constructs a problem. - -# Example - -```julia - -julia> x = Variable(4) -Variable(Float64, (4,)) - -julia> A, b = randn(10,4), randn(10); - -julia> p = problem(ls(A*x-b), norm(x) <= 1) - -``` - -""" -function problem(terms::Vararg) - cf = () - for i = 1:length(terms) - cf = (cf...,terms[i]...) - end - return cf -end diff --git a/src/syntax/syntax.jl b/src/syntax/syntax.jl deleted file mode 100644 index 514514b..0000000 --- a/src/syntax/syntax.jl +++ /dev/null @@ -1,8 +0,0 @@ -abstract type AbstractExpression end - -include("variable.jl") -include("expressions/expression.jl") -include("terms/term.jl") -include("problem.jl") - -const TermOrExpr = Union{Term,AbstractExpression} diff --git a/src/syntax/terms/proximalOperators_bind.jl b/src/syntax/terms/proximalOperators_bind.jl index c507638..fb5aa28 100644 --- a/src/syntax/terms/proximalOperators_bind.jl +++ b/src/syntax/terms/proximalOperators_bind.jl @@ -4,7 +4,7 @@ import LinearAlgebra: norm export norm """ - norm(x::AbstractExpression, p=2, [q,] [dim=1]) + norm(x::AbstractExpression, p=2, [q]; [dim=1]) Returns the norm of `x`. @@ -48,9 +48,9 @@ function norm(ex::AbstractExpression, ::typeof(*)) end # Mixed Norm -function norm(ex::AbstractExpression, p1::Int, p2::Int, dim::Int = 1 ) +function norm(ex::AbstractExpression, p1::Int, p2::Int; dim::Int = 1) if p1 == 2 && p2 == 1 - f = NormL21(1.0,dim) + f = NormL21(1.0, dim) else error("function not implemented") end @@ -59,21 +59,44 @@ end # Least square terms -export ls +export ls, normalop_ls """ ls(x::AbstractExpression) Returns the squared norm (least squares) of `x`: - ```math f (\\mathbf{x}) = \\frac{1}{2} \\| \\mathbf{x} \\|^2 ``` - (shorthand of `1/2*norm(x)^2`). """ ls(ex) = Term(SqrNormL2(), ex) +""" + normalop_ls(x::AbstractExpression) + +Returns the squared norm (least squares) of `L*x`: +```math +f (\\mathbf{L} * \\mathbf{x}) = \\frac{1}{2} \\| \\mathbf{L} * \\mathbf{x} \\|^2 +``` +(shorthand of `1/2*norm(x)^2`). + +The only difference with `ls` comes when gradient! is called. In this case, the +gradient is computed as usual, but the squared norm of the gradient (i.e. the +squared norm of `Lᴴ * L * x`) is returned instead of the squared norm of `L * x`. +This is much faster to compute, if `Lᴴ * L` has a fast implementation. +""" + +normalop_ls(::Variable) = error("normalop_ls does not work with Variables alone. Use ls instead.") +function normalop_ls(ex::Expression) + eye_op = if length(ex.x) == 1 + Eye(domain_type(ex.L), size(ex.L, 2)) + else + HCAT([Eye(domain_type(L), size(L, 2)) for L in ex.L]...) + end + return Term(SqrNormL2WithNormalOp(ex.L), Expression(ex.x, eye_op)) +end + import Base: ^ function (^)(t::Term{T1,T2,T3}, exp::Integer) where {T1, T2 <: NormL2, T3} @@ -138,13 +161,12 @@ Term(CrossEntropy(b), ex) export logisticloss """ - logbarrier(x::AbstractExpression, y::AbstractArray) + logisticloss(x::AbstractExpression, y::Array) Applies the logistic loss function: ```math -f(\\mathbf{x}) = \\sum_{i} \\log(1+ \\exp(-y_i x_i)), +f(\\mathbf{x}) = \\sum_i \\log(1 + \\exp(-y_i x_i)). ``` -where `y` is an array containing ``y_i``. """ logisticloss(ex::AbstractExpression, y::AbstractArray) = Term(LogisticLoss(y, 1.0), ex) diff --git a/src/syntax/terms/term.jl b/src/syntax/terms/term.jl index 2e42973..d72f4b6 100644 --- a/src/syntax/terms/term.jl +++ b/src/syntax/terms/term.jl @@ -1,13 +1,87 @@ -struct Term{T1 <: Real, T2, T3 <: AbstractExpression} - lambda::T1 - f::T2 - A::T3 - Term(lambda::T1, f::T2, ex::T3) where {T1,T2,T3} = new{T1,T2,T3}(lambda,f,ex) +struct Term{T1<:Real,T2,T3<:AbstractExpression} + lambda::T1 + f::T2 + A::T3 + repr::Union{String,Nothing} + function Term(lambda::T1, f::T2, A::T3, repr::Union{String,Nothing}) where {T1<:Real,T2,T3<:AbstractExpression} + T1_ = real(codomain_type(affine(A))) + lambda = convert(T1_, lambda) + return new{T1_,T2,T3}(lambda, f, A, repr) + end +end + +function Term(lambda, f, ex::AbstractExpression) + return Term(lambda, f, ex, nothing) end function Term(f, ex::AbstractExpression) - A = convert(Expression,ex) - Term(1,f, A) + A = convert(Expression, ex) + Term(1, f, A) +end + +function Term(f, ex::AbstractExpression, repr::String) + A = convert(Expression, ex) + Term(1, f, A, repr) +end + +function Term(t::Term, repr::String) + Term(t.lambda, t.f, t.A, repr) +end + +struct TermSet{N,T} + terms::T + function TermSet(terms...) + @assert all(t -> t isa Term, terms) "All elements must be of type Term" + new{length(terms), typeof(terms)}(terms) + end +end + +function Base.iterate(t::TermSet{N}, state=1) where {N} + if state > N + return nothing + else + return (t.terms[state], state + 1) + end +end + +Base.length(::TermSet{N}) where {N} = N +Base.getindex(t::TermSet{N}, i::Int) where {N} = t.terms[i] + +Term(t::TermSet, ::String) = t + +import Base: ==, show + +# Ignore the repr when comparing terms +==(t1::Term, t2::Term) = t1.lambda == t2.lambda && t1.f == t2.f && t1.A == t2.A + +function show(io::IO, t::Term) + if t.repr !== nothing + print(io, t.repr) + else + print(io, t.lambda, " * ", t.f, "(", t.A, ")") + end +end + +function show(io::IO, t::TermSet) + non_indicator_terms = filter(x -> !is_set_indicator(x), t.terms) + indicator_terms = filter(is_set_indicator, t.terms) + for i in 1:length(non_indicator_terms) + show(io, non_indicator_terms[i]) + if i < length(non_indicator_terms) + print(io, " + ") + end + end + if !isempty(indicator_terms) + if !isempty(non_indicator_terms) + print(io, " s.t. ") + end + for i in 1:length(indicator_terms) + show(io, indicator_terms[i]) + if i < length(indicator_terms) + print(io, ", ") + end + end + end end # Operations @@ -16,24 +90,22 @@ end import Base: + -(+)(a::Term,b::Term) = (a,b) -(+)(a::NTuple{N,Term},b::Term) where {N} = (a...,b) -(+)(a::Term,b::NTuple{N,Term}) where {N} = (a,b...) -(+)(a::NTuple{N,Term},b::Tuple{}) where {N} = a -(+)(a::Tuple{},b::NTuple{N,Term}) where {N} = b -(+)(a::NTuple{N,Term},b::NTuple{M,Term}) where {N,M} = (a...,b...) +(+)(a::Term, b::Term) = TermSet(a, b) +(+)(a::TermSet, b::Term) = TermSet(a..., b) +(+)(a::Term, b::TermSet) = TermSet(a, b...) +(+)(a::TermSet, b::TermSet) = TermSet(a..., b...) # Define multiplication by constant import Base: * -function (*)(a::T1, t::Term{T,T2,T3}) where {T1<:Real, T, T2, T3} - coeff = *(promote(a,t.lambda)...) - Term(coeff, t.f, t.A) +function (*)(a::T1, t::Term{T,T2,T3}) where {T1<:Real,T,T2,T3} + coeff = *(promote(a, t.lambda)...) + Term(coeff, t.f, t.A) end -function (*)(a::T1, t::T2) where {T1<:Real, N, T2 <: Tuple{Vararg{<:Term,N}} } - return a.*t +function (*)(a::T1, t::TermSet) where {T1<:Real} + return a .* t end # Properties @@ -44,49 +116,64 @@ affine(t::Term) = affine(t.A) displacement(t::Term) = displacement(t.A) #importing properties from ProximalOperators -import ProximalOperators: - is_affine, - is_cone, - is_convex, - is_generalized_quadratic, - is_prox_accurate, - is_quadratic, - is_separable, - is_set, - is_singleton, - is_smooth, - is_strongly_convex +import ProximalCore: + is_affine_indicator, + is_cone_indicator, + is_convex, + is_generalized_quadratic, + is_proximable, + is_quadratic, + is_separable, + is_set_indicator, + is_singleton_indicator, + is_smooth, + is_locally_smooth, + is_strongly_convex + +is_func_f = [:is_set_indicator, :is_singleton_indicator, :is_smooth, :is_locally_smooth] + +for f in is_func_f + @eval begin + import ProximalCore: $f + $f(t::Term) = $f(t.f) + $f(t::TermSet) = all($f.(t.terms)) + end +end #importing properties from AbstractOperators -is_f = [:is_linear, - :is_eye, - :is_null, - :is_diagonal, - :is_AcA_diagonal, - :is_AAc_diagonal, - :is_orthogonal, - :is_invertible, - :is_full_row_rank, - :is_full_column_rank, - :is_sliced - ] - -for f in is_f - @eval begin - import AbstractOperators: $f - $f(t::Term) = $f(operator(t)) - $f(t::NTuple{N,Term}) where {N} = all($f.(t)) - end +is_op_f = [ + :is_linear, + :is_eye, + :is_null, + :is_diagonal, + :is_AcA_diagonal, + :is_AAc_diagonal, + :is_orthogonal, + :is_invertible, + :is_full_row_rank, + :is_full_column_rank, + :is_sliced, +] + +for f in is_op_f + @eval begin + import AbstractOperators: $f + $f(t::Term) = $f(operator(t)) + $f(t::TermSet) = all($f.(t)) + end end -is_smooth(t::Term) = is_smooth(t.f) -is_convex(t::Term) = is_convex(t.f) && is_linear(t) +is_affine_indicator(t::Term) = is_affine_indicator(t.f) && is_linear(t) +is_cone_indicator(t::Term) = is_cone_indicator(t.f) && is_linear(t) +is_convex(t::Term) = is_convex(t.f) && is_linear(t) is_quadratic(t::Term) = is_quadratic(t.f) && is_linear(t) +is_generalized_quadratic(t::Term) = is_generalized_quadratic(t.f) && is_linear(t) is_strongly_convex(t::Term) = is_strongly_convex(t.f) && is_full_column_rank(operator(t.A)) +is_separable(t::Term) = is_separable(t.f) && is_diagonal(operator(t.A)) include("proximalOperators_bind.jl") # other stuff, to make Term work with iterators import Base: iterate, isempty -iterate(t::Term, state = true) = state ? (t, false) : nothing -isempty(t::Term) = false +iterate(t::Term, state=true) = state ? (t, false) : nothing +isempty(t::Term) = false diff --git a/src/syntax/variable.jl b/src/syntax/variable.jl index d5ede3f..159c698 100644 --- a/src/syntax/variable.jl +++ b/src/syntax/variable.jl @@ -1,36 +1,38 @@ import Base: convert, size, eltype, ~ -export Variable +export Variable, get_name struct Variable{T, N, A <: AbstractArray{T,N}} <: AbstractExpression x::A + name::String + function Variable(x::AbstractArray{T,N}; name::String="x") where {T,N} + A = typeof(x) + new{T,N,A}(x, name) + end end # constructors """ - Variable([T::Type,] dims...) + Variable([T::Type,] dims...; name::String="x") + Variable(x::AbstractArray; name::String="x") -Returns a `Variable` of dimension `dims` initialized with an array of all zeros. - -`Variable(x::AbstractArray)` - -Returns a `Variable` of dimension `size(x)` initialized with `x` +Creates an optimization variable of type `T` and dimensions `dims...`, or from the provided array `x`. +The optional `name` argument allows to specify a name for the variable, which is useful for display purposes. """ -function Variable(T::Type, args::Vararg{I,N}) where {I <: Integer,N} - Variable{T,N,Array{T,N}}(zeros(T, args...)) +function Variable(T::Type, args::Int...; name::String="x") + Variable(zeros(T, args...); name) end -function Variable(args::Vararg{I}) where {I <: Integer} - Variable(zeros(args...)) +function Variable(args::Int...; name::String="x") + Variable(zeros(args...); name) end # Utils function Base.show(io::IO, x::Variable) - print(io, "Variable($(eltype(x.x)), $(size(x.x)))") + print(io, "Variable($(eltype(x.x)), $(size(x.x)), \"$(x.name)\")") end - """ ~(x::Variable) @@ -46,7 +48,7 @@ size(x::Variable, [dim...]) Like `size(A::AbstractArray, [dims...])` returns the tuple containing the dimensions of the variable `x`. """ size(x::Variable) = size(x.x) -size(x::Variable, dim::I) where { I <: Integer} = size(x.x, dim) +size(x::Variable, dim::Integer) = size(x.x, dim) """ eltype(x::Variable) @@ -54,3 +56,10 @@ eltype(x::Variable) Like `eltype(x::AbstractArray)` returns the type of the elements of `x`. """ eltype(x::Variable) = eltype(x.x) + +""" +get_name(x::Variable) + +Returns the name of the variable `x`. If no name was provided at construction, returns `"x"`. +""" +get_name(x::Variable) = x.name diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..6d0eb10 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,43 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" +DSPOperators = "d5a72628-6e2f-430e-82f5-561df0bb8116" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +FFTWOperators = "c59a084b-ba08-4f3f-af9e-f4298d6caa94" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +WaveletOperators = "f3582904-6f60-4bbd-985d-55eab799bc9d" +AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" +ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +StructuredOptimization = "46cd3e9d-64ff-517d-a929-236bc1a1fc9d" + +[compat] +Aqua = "0.8" +DSP = "0.5.1 - 0.8" +DSPOperators = "0.1" +FFTW = "1" +FFTWOperators = "0.1" +LinearAlgebra = "1" +Random = "1" +Test = "1" +WaveletOperators = "0.1" +AbstractOperators = "0.4" +ProximalAlgorithms = "0.8" +ProximalCore = "0.2" +ProximalOperators = "0.17" +RecursiveArrayTools = "1 - 3" + +[sources] +StructuredOptimization = { path = "../" } +AbstractOperators = { path = "../../AbstractOperators" } +ProximalAlgorithms = { path = "../../ProximalAlgorithms.jl" } +ProximalCore = { path = "../../ProximalCore.jl" } +ProximalOperators = { path = "../../ProximalOperators.jl" } +DSPOperators = { path = "../../AbstractOperators/DSPOperators" } +WaveletOperators = { path = "../../AbstractOperators/WaveletOperators" } +FFTWOperators = { path = "../../AbstractOperators/FFTWOperators" } + diff --git a/test/runtests.jl b/test/runtests.jl index b6731bd..5e5a0d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,35 +1,52 @@ using StructuredOptimization -using AbstractOperators +using AbstractOperators, DSPOperators, FFTWOperators using ProximalOperators using ProximalAlgorithms using RecursiveArrayTools using LinearAlgebra, Random using DSP, FFTW using Test +using Aqua Random.seed!(0) @testset "StructuredOptimization" begin + @testset "Calculus" begin + include("test_proxstuff.jl") + end -@testset "Calculus" begin - include("test_proxstuff.jl") -end + @testset "Syntax" begin + include("test_variables.jl") + include("test_expressions.jl") + include("test_AbstractOp_binding.jl") + include("test_terms.jl") + end -@testset "Syntax" begin - include("test_variables.jl") - include("test_expressions.jl") - include("test_AbstractOp_binding.jl") - include("test_terms.jl") -end + @testset "Problem construction" begin + include("test_problem.jl") + include("test_build_minimize.jl") + end -@testset "Problem construction" begin - include("test_problem.jl") - include("test_build_minimize.jl") -end - -@testset "End-to-end tests" begin - include("test_usage_small.jl") - include("test_usage.jl") -end + @testset "End-to-end tests" begin + include("test_usage_small.jl") + include("test_usage.jl") + end + @testset "Aqua" begin + Aqua.test_all(StructuredOptimization; ambiguities=false, piracies=false) + Aqua.test_ambiguities( + StructuredOptimization; exclude=[Base.:(+), Base.:<=, Base.:>=], broken=true + ) + Aqua.test_piracies( + StructuredOptimization; + treat_as_own=[ + ProximalAlgorithms.value_and_gradient, + ProximalAlgorithms.value_and_gradient!, + ProximalOperators.prox, + ProximalOperators.prox!, + ProximalOperators.gradient, + ProximalOperators.gradient!, + ], + ) + end end diff --git a/test/test_AbstractOp_binding.jl b/test/test_AbstractOp_binding.jl index 0ebb192..9f5ab96 100644 --- a/test/test_AbstractOp_binding.jl +++ b/test/test_AbstractOp_binding.jl @@ -52,15 +52,15 @@ ex = x[1:2] # DFT n = 5 -op = AbstractOperators.DFT(Float64,(n,)) +op = DFT(Float64,(n,)) x = Variable(randn(n)) ex = fft(x) @test norm(operator(ex)*(~x)-op*(~x)) <1e-12 # IDFT n = 5 -op = IDFT(Float64,(n,)) -x = Variable(randn(n)) +op = IDFT(ComplexF64,(n,)) +x = Variable(randn(ComplexF64, n)) ex = ifft(x) @test norm(operator(ex)*(~x)-op*(~x)) <1e-12 diff --git a/test/test_build_minimize.jl b/test/test_build_minimize.jl index 828e221..510f7f7 100644 --- a/test/test_build_minimize.jl +++ b/test/test_build_minimize.jl @@ -1,4 +1,4 @@ -using ProximalAlgorithms +using ProximalAlgorithms: ZeroFPR, PANOC, PANOCplus x = Variable(10) A = randn(5, 10) @@ -42,15 +42,18 @@ xp = copy(~x) @test norm(xp-xpg) <= 1e-4 # test nonconvex Rosenbrock function with known minimum -solvers = [ZeroFPR(tol = 1e-6), PANOC(tol = 1e-6)] -for solver in solvers - x = Variable(1) - y = Variable(1) - a,b = 2.0, 100.0 +function test_solver(solver) + x = Variable(1) + y = Variable(1) + a, b = 2.0, 100.0 - cf = norm(x-a)^2+b*norm(pow(x,2)-y)^2 - @minimize cf+1e-10*norm(x,1)+1e-10*norm(y,1) with solver + cf = norm(x - a)^2 + b * norm(pow(x, 2) - y)^2 + @minimize cf + 1e-10 * norm(x, 1) + 1e-10 * norm(y, 1) with solver - @test norm(~x-[a]) < 1e-4 - @test norm(~y-[a^2]) < 1e-4 + @test norm(~x - [a]) < 1e-4 + @test norm(~y - [a^2]) < 1e-4 +end +solvers = [ZeroFPR(; tol=1e-6), PANOC(; tol=1e-6), PANOCplus(; tol=1e-6)] +for solver in solvers + test_solver(solver) end diff --git a/test/test_expressions.jl b/test/test_expressions.jl index 890786d..b56b18d 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -191,43 +191,26 @@ ex3 = ex1+ex2 n = 3 b = randn(n) -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = x1.+x2 @test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = x1.+(x2+2) @test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 @test displacement(ex1) == 2 -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = (x1+2).+(x2+b) @test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 @test displacement(ex1) == (b.+2) -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = x1.+x2 -@test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 - -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = x1.+(x2+2) -@test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 -@test displacement(ex1) == 2 - -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = (x1+b).+(x2+2) -@test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 -@test displacement(ex1) == (b.+2) - n,m =2,4 x1 = Variable(randn(n,m)) -x2 = Variable(randn(1,m)) +x2 = Variable(randn(n,m)) ex1 = x1.+x2+6 @test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 @test displacement(ex1) == 6 @@ -237,43 +220,26 @@ ex1 = x1.+x2+6 n = 3 b = randn(n) -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = x1.-x2 @test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = x1.-(x2+2) @test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 @test displacement(ex1) == -2 -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = (x1+2).-(x2+b) @test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 @test displacement(ex1) == (2 .-b) -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = x1.-x2 -@test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 - -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = x1.-(x2+2) -@test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 -@test displacement(ex1) == -2 - -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = (x1+b).-(x2+2) -@test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 -@test displacement(ex1) == (b.-2) - n,m =2,4 x1 = Variable(randn(n,m)) -x2 = Variable(randn(1,m)) +x2 = Variable(randn(n,m)) ex1 = x1.-x2+6 @test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 @test displacement(ex1) == 6 @@ -316,3 +282,13 @@ ex3 = ex1-ex2 @test_throws DimensionMismatch MatrixOp(randn(10,20))*Variable(20)+randn(11) @test_throws ErrorException MatrixOp(randn(10,20))*Variable(20)+(3+im) +# Advanced (+) sum +x, y, z, w = Variable(rand(10)), Variable(rand(20)), Variable(rand(30)), Variable(rand(40)) +A = randn(10,10) +exA = (z[1:10]+x)+3*(x+z[1:10])+A*(w[1:10]+z[1:10])+(z[1:10]+w[1:10]) +exB = 5*w[1:10]+z[1:10]+z[1:10]+3*y[1:10]+z[1:10] +exC = exA+exB +op = operator(exC) +output = op*ArrayPartition(~z,~x,~w,~y) +expected_output = 4*(~x)+3*(~y)[1:10]+8*(~z)[1:10]+6*(~w)[1:10]+A*((~w)[1:10]+(~z)[1:10]) +@test norm(output-expected_output) < 1e-12 diff --git a/test/test_problem.jl b/test/test_problem.jl index 757e0a0..677071a 100644 --- a/test/test_problem.jl +++ b/test/test_problem.jl @@ -105,163 +105,3 @@ V = StructuredOptimization.extract_operators(xAll,cf) @test typeof(V[6][3]) <: Zeros @test typeof(V[6][4]) <: Zeros @test typeof(V[6][5]) <: Eye - -println("\nTesting splitting Terms\n") - -x = Variable(5) -y = Variable(5) -cf = ls(x)+10*norm(x,2)+ls(x+y) - -f, g = StructuredOptimization.split_smooth(cf) -@test f[1] == cf[1] -@test f[2] == cf[3] -@test g[1] == cf[2] - -cf = ls(x) -f, g = StructuredOptimization.split_smooth((cf,)) -@test f == (cf,) -@test g == () - -cf = norm(x,1)+norm(y,2)+norm(randn(5,5)*x+y,Inf) -xAll = StructuredOptimization.extract_variables(cf) -AAc, nonAAc = StructuredOptimization.split_AAc_diagonal(cf) -@test AAc[1] == cf[1] -@test AAc[2] == cf[2] -@test nonAAc[1] == cf[3] - -cf = ls(sigmoid(x)) + ls(x) -fq, fs = StructuredOptimization.split_quadratic(cf) -@test fs[1] == cf[1] -@test fq[1] == cf[2] - -println("\nTesting extracting Proximable functions\n") -# testing is_proximable -@test StructuredOptimization.is_proximable(AAc) == true -@test StructuredOptimization.is_proximable(nonAAc) == false - -cf = norm(x[1:2],1)+norm(x[3:5]) -xAll = StructuredOptimization.extract_variables(cf) - -@test all(StructuredOptimization.is_AAc_diagonal.(cf)) == true -@test StructuredOptimization.is_proximable(cf) == true - -cf = norm(x[1:2],1)+norm(x[3:5])+norm(x,Inf) -xAll = StructuredOptimization.extract_variables(cf) - -@test all(StructuredOptimization.is_AAc_diagonal.(cf)) == true -@test StructuredOptimization.is_proximable(cf) == false - -# testing extract_proximable -# single variable, single term -x = Variable(randn(5)) -b = randn(5) -cf = 10*norm(x-b,1) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true - -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - 10*norm(~x-b,1)) < 1e-12 - -# single variable, single term, diagonal term -x = Variable(randn(5)) -b = randn(5) -d = randn(5) -cf = 10*norm(d.*x-b,1) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true - -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - 10*norm(d.*~x-b,1)) < 1e-12 - -# single variable, single term, tight frame term -x = Variable(randn(5)) -b = randn(5) -d = randn(5) -cf = 10*norm(dct(x)-b,1) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true - -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - 10*norm(dct(~x)-b,1)) < 1e-12 - -# single variable, single term, tight frame term, fft -# TODO this not working (probably fix needed in ProxOp) -#x = Variable(randn(5)) -#b = randn(5) -#d = randn(5) -#cf = 10*norm(fft(x)-b,1) -#xAll = StructuredOptimization.extract_variables(cf) -#@test StructuredOptimization.is_proximable(cf) == true -# -#f = StructuredOptimization.extract_proximable(xAll,cf) -#@test norm(f(~x) - 10*norm(fft(~x)-b,1)) < 1e-12 - -# single variable, multiple terms with GetIndex -x = Variable(randn(5)) -b = randn(2) -cf = 10*norm(x[1:2]-b,1)+norm(x[3:5],2) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - sum([10*norm((~x)[1:2]-b,1);norm((~x)[3:5],2)])) < 1e-12 - -# single variable, multiple terms with GetIndex composed with dct -x = Variable(randn(5)) -b = randn(2) -cf = 10*norm(x[1:2]-b,1)+norm(dct(x[3:5]),2) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - sum([10*norm((~x)[1:2]-b,1);norm(dct((~x)[3:5]),2)])) < 1e-12 - -# multiple variables, multiple terms -x1 = Variable(randn(5)) -b1 = randn(5) -x2 = Variable(randn(3)) -b2 = randn(3) - -cf = 10*norm(x2-b2,1)+norm(x1+b1,2) -xAll = (x1,x2) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f.fs[1](~x1)-norm(~x1+b1,2) ) < 1e-12 -@test norm(f.fs[2](~x2)-10*norm(~x2-b2,1) ) < 1e-12 - -x1 = Variable(randn(5)) -b1 = randn(5) -x2 = Variable(randn(5)) -b2 = randn(5) - -# TODO fix this? -#cf = 10*norm(x2+x1+b2,1) -#xAll = (x1,x2) -#@test StructuredOptimization.is_proximable(cf) == true -#f = StructuredOptimization.extract_proximable(xAll,cf) -# TODO fix this! in ProxOp? -# @test norm(f((~x1,~x2))-10*norm(~x2+~x1+b2,1) ) < 1e-12 - -# multiple variables, missing terms -x1 = Variable(randn(5)) -b1 = randn(5) -x2 = Variable(randn(3)) -b2 = randn(3) - -cf = 10*norm(x2-b2,1) -xAll = (x1,x2) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test f.fs[1](~x1) == 0. -@test norm(f.fs[2](~x2)-10*norm(~x2-b2,1) ) < 1e-12 - -# multiple variables, multiple terms, with GetIndex -x1 = Variable(randn(5)) -b1 = randn(5) -x2 = Variable(randn(3)) -b2 = randn(3) - -cf = norm(x1[3:5]+b1[3:5],1)+10*norm(x2-b2,1)+norm(x1[1:2]+b1[1:2],2) -xAll = (x1,x2) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f.fs[1](~x1)-norm((~x1)[1:2]+b1[1:2],2)-norm((~x1)[3:5]+b1[3:5],1) ) < 1e-12 -@test norm(f.fs[2](~x2)-10*norm(~x2-b2,1) ) < 1e-12 diff --git a/test/test_proxstuff.jl b/test/test_proxstuff.jl index c4ce361..d1744d7 100644 --- a/test/test_proxstuff.jl +++ b/test/test_proxstuff.jl @@ -26,8 +26,8 @@ r = randn(l,n2) b = randn(l,n2) G = AffineAdd(Ax_mul_Bx( - HCAT(A,Zeros(codomainType(B), size(B,2), size(A,1) )), - HCAT(Zeros(codomainType(A), size(A,2), size(B,1) ),B) + HCAT(A,Zeros(codomain_type(B), size(B,2), size(A,1) )), + HCAT(Zeros(codomain_type(A), size(A,2), size(B,1) ),B) ), b,false) diff --git a/test/test_terms.jl b/test/test_terms.jl index 6988194..8295ef8 100644 --- a/test/test_terms.jl +++ b/test/test_terms.jl @@ -43,7 +43,7 @@ cf = 3*norm(X,2,1) @test cf.lambda - 3 == 0 @test cf.f(~X) == sum( sqrt.(sum((~X).^2, dims=1 )) ) -cf = 4*norm(X,2,1,2) +cf = 4*norm(X,2,1; dim=2) @test cf.lambda - 4 == 0 @test cf.f(~X) == sum( sqrt.(sum((~X).^2, dims=2 )) ) @@ -175,7 +175,7 @@ end cf = 2*norm(x,1) ccf = conj(cf) @test ccf.A == cf.A -@test ccf.f == Conjugate(Postcompose(NormL1(),2)) +@test ccf.f == Conjugate(Postcompose(NormL1(),2.0)) @test_throws ErrorException conj(norm(randn(2,10)*x,1)) cf = 2*norm(x,1) @@ -192,21 +192,6 @@ cf = ls(x) + 10*norm(x, 1) @test cf[2].lambda == 10 @test cf[2].f(~x) == norm(~x,1) -x = Variable(10) -cf = () #empty cost function -cf += 10*norm(x, 1) -@test length(cf) == 1 -@test cf[1].lambda == 10 -@test cf[1].f(~x) == 10*norm(~x,1) - -x = Variable(10) -cf = () #empty cost function -cf += ls(x) + 10*norm(x, 1) -@test cf[1].lambda == 1 -@test cf[1].f(~x) == 0.5*norm(~x)^2 -@test cf[2].lambda == 10 -@test cf[2].f(~x) == norm(~x,1) - # More complex situations x = Variable(10) @@ -261,5 +246,7 @@ cf = norm(w + z)^2 @test StructuredOptimization.is_AcA_diagonal(cf) == false cf = norm(x, 1) + norm(y, 2) -@test StructuredOptimization.is_smooth.(cf) == (false,false) -@test StructuredOptimization.is_AcA_diagonal.(cf) == (true,true) +@test StructuredOptimization.is_smooth.(cf.terms) == (false,false) +@test StructuredOptimization.is_smooth(cf) == false +@test StructuredOptimization.is_AcA_diagonal.(cf.terms) == (true,true) +@test StructuredOptimization.is_AcA_diagonal(cf) == true diff --git a/test/test_usage.jl b/test/test_usage.jl index 8d5f2b8..1ca344b 100644 --- a/test/test_usage.jl +++ b/test/test_usage.jl @@ -1,3 +1,5 @@ +using ProximalAlgorithms: PANOCplus, FastForwardBackward, ZeroFPR, PANOC + Random.seed!(0) ################################################################################ @@ -5,7 +7,6 @@ Random.seed!(0) ################################################################################ println("Testing: regularized least squares, with two variable blocks to make things weird") - m, n1, n2 = 30, 50, 100 A1 = randn(m, n1) @@ -17,55 +18,37 @@ lam2 = 1.0 # Solve with PANOC+ -x1_fpg = Variable(n1) -x2_fpg = Variable(n2) -expr = ls(A1*x1_fpg + A2*x2_fpg - b) + lam1*norm(x1_fpg, 1) + lam2*norm(x2_fpg, 2) +x1_panocplus = Variable(n1) +x2_panocplus = Variable(n2) +expr = ls(A1*x1_panocplus + A2*x2_panocplus - b) + lam1*norm(x1_panocplus, 1) + lam2*norm(x2_panocplus, 2) prob = problem(expr) -@time sol = solve(prob, PANOCplus(tol=1e-10, verbose=false,maxit=20000)) - -# Solve with ZeroFPR +@time sol = solve(prob, PANOCplus()) -x1_zerofpr = Variable(n1) -x2_zerofpr = Variable(n2) -expr = ls(A1*x1_zerofpr + A2*x2_zerofpr - b) + lam1*norm(x1_zerofpr, 1) + lam2*norm(x2_zerofpr, 2) -prob = problem(expr) -@time sol = solve(prob, ZeroFPR(tol=1e-10, verbose=false)) - -# Solve with PANOC - -x1_panoc = Variable(n1) -x2_panoc = Variable(n2) -expr = ls(A1*x1_panoc + A2*x2_panoc - b) + lam1*norm(x1_panoc, 1) + lam2*norm(x2_panoc, 2) -prob = problem(expr) -@time sol = solve(prob, PANOC(tol=1e-10, verbose=false)) - -# Solve with minimize, use default solver/options - -x1 = Variable(n1) -x2 = Variable(n2) -@time sol = @minimize ls(A1*x1 + A2*x2 - b) + lam1*norm(x1, 1) + lam2*norm(x2, 2) - -@test norm(~x1_fpg - ~x1_zerofpr, Inf)/(1+norm(~x1_zerofpr, Inf)) <= 1e-6 -@test norm(~x2_fpg - ~x2_zerofpr, Inf)/(1+norm(~x2_zerofpr, Inf)) <= 1e-6 -@test norm(~x1_fpg - ~x1_panoc, Inf)/(1+norm(~x1_panoc, Inf)) <= 1e-6 -@test norm(~x2_fpg - ~x2_panoc, Inf)/(1+norm(~x2_panoc, Inf)) <= 1e-6 -@test norm(~x1 - ~x1_zerofpr, Inf)/(1+norm(~x1_zerofpr, Inf)) <= 1e-3 -@test norm(~x2 - ~x2_zerofpr, Inf)/(1+norm(~x2_zerofpr, Inf)) <= 1e-3 - -res = A1*~x1_fpg + A2*~x2_fpg - b +res = A1*~x1_panocplus + A2*~x2_panocplus - b grad1 = A1'*res grad2 = A2'*res -ind1_zero = (~x1_fpg .== 0) -subgr1 = lam1*sign.(~x1_fpg) +ind1_zero = (~x1_panocplus .== 0) +subgr1 = lam1*sign.(~x1_panocplus) subdiff1_low, subdiff1_upp = copy(subgr1), copy(subgr1) subdiff1_low[ind1_zero] .= -lam1 subdiff1_upp[ind1_zero] .= +lam1 -subgr2 = lam2*(~x2_fpg/norm(~x2_fpg, 2)) +subgr2 = lam2*(~x2_panocplus/norm(~x2_panocplus, 2)) @test maximum(subdiff1_low + grad1) <= 1e-6 @test maximum(-subdiff1_upp - grad1) <= 1e-6 @test norm(grad2 + subgr2) <= 1e-6 +# Solve with FastForwardBackward + +x1_ffb = Variable(n1) +x2_ffb = Variable(n2) +expr = ls(A1*x1_ffb + A2*x2_ffb - b) + lam1*norm(x1_ffb, 1) + lam2*norm(x2_ffb, 2) +prob = problem(expr) +@time sol = solve(prob, FastForwardBackward()) + +@test norm(~x1_panocplus - ~x1_ffb, Inf)/(1+norm(~x1_ffb, Inf)) <= 1e-6 +@test norm(~x2_panocplus - ~x2_ffb, Inf)/(1+norm(~x2_ffb, Inf)) <= 1e-6 + ############################################################################### ## Lasso problem with known solution ############################################################################### @@ -164,13 +147,13 @@ prob = problem(expr) # Solve with minimize, default solver/options -x = Variable(n) -@time sol = @minimize smooth(norm(A*x - b, 2)) + lam*norm(x, 1) +#x = Variable(n) +#@time sol = @minimize smooth(norm(A*x - b, 2)) + lam*norm(x, 1) @test norm(~x_pg - ~x_fpg, Inf)/(1+norm(~x_pg, Inf)) <= 1e-4 @test norm(~x_pg - ~x_zerofpr, Inf)/(1+norm(~x_pg, Inf)) <= 1e-4 @test norm(~x_pg - ~x_panoc, Inf)/(1+norm(~x_pg, Inf)) <= 1e-4 -@test norm(~x_pg - ~x, Inf)/(1+norm(~x_pg, Inf)) <= 1e-3 +#@test norm(~x_pg - ~x, Inf)/(1+norm(~x_pg, Inf)) <= 1e-3 ################################################################################ ### Box-constrained least-squares @@ -227,11 +210,11 @@ prob = problem(expr, x_panoc in [lb, ub]) # Solve with minimize, default solver/options -x = Variable(n) -@time sol = @minimize ls(A*x - b) st x in [lb, ub] +#x = Variable(n) +#@time sol = @minimize ls(A*x - b) st x in [lb, ub] -@test norm(~x - max.(lb, min.(ub, ~x)), Inf) <= 1e-12 -@test norm(~x - max.(lb, min.(ub, ~x - A'*(A*~x - b))), Inf)/(1+norm(~x, Inf)) <= 1e-4 +#@test norm(~x - max.(lb, min.(ub, ~x)), Inf) <= 1e-12 +#@test norm(~x - max.(lb, min.(ub, ~x - A'*(A*~x - b))), Inf)/(1+norm(~x, Inf)) <= 1e-4 ################################################################################ ### Non-negative least-squares from a known solution diff --git a/test/test_usage_small.jl b/test/test_usage_small.jl index 8503707..14bbcd4 100644 --- a/test/test_usage_small.jl +++ b/test/test_usage_small.jl @@ -1,14 +1,24 @@ +using ProximalAlgorithms: ZeroFPR, PANOC, PANOCplus, ADMM, CGNR + A = randn(3,5) b = randn(3) x_zfpr = Variable(5) prob_zfpr = problem(ls(A*x_zfpr - b) + 1e-3*norm(x_zfpr, 1)) -sol_zfpr = solve(prob_zfpr, ZeroFPR()) +sol_zfpr = solve(prob_zfpr, ZeroFPR(maxit=10)) x_pnc = Variable(5) prob_pnc = problem(ls(A*x_pnc - b) + 1e-3*norm(x_pnc, 1)) -sol_pnc = solve(prob_pnc, PANOC()) +sol_pnc = solve(prob_pnc, PANOC(maxit=10)) x_pncp = Variable(5) prob_pncp = problem(ls(A*x_pncp - b) + 1e-3*norm(x_pncp, 1)) -sol_pncp = solve(prob_pncp, PANOCplus()) +sol_pncp = solve(prob_pncp, PANOCplus(maxit=10)) + +x_admm = Variable(5) +prob_admm = problem(ls(A*x_admm - b) + 1e-3*norm(x_admm, 1)) +sol_admm = solve(prob_admm, ADMM(maxit=10)) + +x_cg = Variable(5) +prob_cg = problem(ls(A*x_cg - b) + 1e-3*norm(x_cg, 2)^2) +sol_cg = solve(prob_cg, CGNR(maxit=10))