Skip to content

Commit

Permalink
Merge pull request #146 from ggebbie/ggebbie/invert-model
Browse files Browse the repository at this point in the history
Julia 1.10 plus initial code for circulation inversion
  • Loading branch information
ggebbie authored Jan 4, 2024
2 parents 9a37ef9 + 7a269cc commit e1ecd99
Show file tree
Hide file tree
Showing 12 changed files with 524 additions and 194 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TMI"
uuid = "582500f6-28c8-4d8f-aabe-b197735ec1d4"
authors = ["G Jake Gebbie <ggebbie@whoi.edu>"]
version = "0.2.5"
version = "0.2.7"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Expand Down Expand Up @@ -39,7 +39,7 @@ LineSearches = "7"
MAT = "0.10"
Optim = "1"
StatsBase = "0.33"
julia = "1.9"
julia = "1.9,1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
52 changes: 30 additions & 22 deletions scripts/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.0"
manifest_format = "2.0"
project_hash = "166334166cd783b55c1d11fd1ac6684291bfb997"
project_hash = "6f3e7d525e4b2eff10d9cf99ba23a3e3224a2ad2"

[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
Expand Down Expand Up @@ -149,7 +149,7 @@ version = "3.46.2"
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.5+0"
version = "1.0.5+1"

[[deps.ConcurrentUtilities]]
deps = ["Serialization", "Sockets"]
Expand Down Expand Up @@ -451,9 +451,14 @@ uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.4.0+0"

[[deps.LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
version = "1.6.4+0"

[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
Expand Down Expand Up @@ -540,7 +545,7 @@ version = "1.1.7"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+0"
version = "2.28.2+1"

[[deps.MicroMamba]]
deps = ["Pkg", "Scratch", "micromamba_jll"]
Expand All @@ -559,13 +564,13 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2022.10.11"
version = "2023.1.10"

[[deps.NCDatasets]]
deps = ["CFTime", "CommonDataModel", "DataStructures", "Dates", "NetCDF_jll", "NetworkOptions", "Printf"]
git-tree-sha1 = "afd015e81e60cfbdba04ef59bcdc80e18bd613cd"
git-tree-sha1 = "e201ed836f4486d0a5f593e68b7621b2e24237c5"
uuid = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
version = "0.12.14"
version = "0.12.16"

[[deps.NLSolversBase]]
deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"]
Expand Down Expand Up @@ -604,12 +609,12 @@ version = "1.12.9"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.21+4"
version = "0.3.23+2"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+0"
version = "0.8.1+2"

[[deps.OpenSSL]]
deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"]
Expand All @@ -636,9 +641,9 @@ uuid = "429524aa-4258-5aef-a3af-852621145aeb"
version = "1.7.5"

[[deps.OrderedCollections]]
git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282"
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.0"
version = "1.6.3"

[[deps.PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
Expand Down Expand Up @@ -667,7 +672,7 @@ version = "1.3.0"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.9.2"
version = "1.10.0"

[[deps.PooledArrays]]
deps = ["DataAPI", "Future"]
Expand Down Expand Up @@ -724,7 +729,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[deps.Random]]
deps = ["SHA", "Serialization"]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.Ratios]]
Expand Down Expand Up @@ -818,6 +823,7 @@ version = "1.1.0"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0"

[[deps.SpecialFunctions]]
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
Expand All @@ -843,7 +849,7 @@ version = "1.4.0"
[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
version = "1.9.0"
version = "1.10.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -882,15 +888,15 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "5.10.1+6"
version = "7.2.1+1"

[[deps.TMI]]
deps = ["CSV", "Distances", "Distributions", "Downloads", "GibbsSeaWater", "GoogleDrive", "Interpolations", "LineSearches", "LinearAlgebra", "MAT", "NCDatasets", "NetCDF", "Optim", "OrderedCollections", "SparseArrays", "Statistics", "StatsBase", "UnicodePlots"]
path = ".."
uuid = "582500f6-28c8-4d8f-aabe-b197735ec1d4"
version = "0.2.4"
version = "0.2.6"
weakdeps = ["GeoPythonPlot"]

[deps.TMI.extensions]
Expand Down Expand Up @@ -953,21 +959,23 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[deps.UnicodePlots]]
deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "PrecompileTools", "Printf", "Requires", "SparseArrays", "StaticArrays", "StatsBase"]
git-tree-sha1 = "9fbe3fb6c4bbe4cafb5ce4d15bbec82f0077e1d5"
git-tree-sha1 = "b96de03092fe4b18ac7e4786bee55578d4b75ae8"
uuid = "b8865327-cd53-5732-bb35-84acbb429228"
version = "3.5.2"
version = "3.6.0"

[deps.UnicodePlots.extensions]
FreeTypeExt = ["FileIO", "FreeType"]
ImageInTerminalExt = "ImageInTerminal"
IntervalSetsExt = "IntervalSets"
TermExt = "Term"
UnitfulExt = "Unitful"

[deps.UnicodePlots.weakdeps]
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
FreeType = "b38be410-82b0-50bf-ab77-7b57e271db43"
ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[[deps.UnsafePointers]]
Expand Down Expand Up @@ -1006,12 +1014,12 @@ version = "2.10.3+0"
[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+0"
version = "1.2.13+1"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+0"
version = "5.8.0+1"

[[deps.micromamba_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl"]
Expand All @@ -1027,4 +1035,4 @@ version = "1.52.0+1"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+0"
version = "17.4.0+2"
3 changes: 3 additions & 0 deletions scripts/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ GeoPythonPlot = "1e05a8e1-7dec-4f9e-9d3d-7df52321841b"
GoogleDrive = "91feb7a0-3508-11ea-1e8e-afea2c1c9a19"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
TMI = "582500f6-28c8-4d8f-aabe-b197735ec1d4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
144 changes: 144 additions & 0 deletions scripts/invert_model_TS.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#=%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Under development: the goal is to invert model output
% and get the transport matrix
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% =#

import Pkg; Pkg.activate(".")
using Revise
using TMI
using Test
using GGplot
using LinearAlgebra
using SparseArrays
using Statistics
#, Distributions, LinearAlgebra, Zygote, ForwardDiff, Optim

TMIversion = "modern_90x45x33_GH10_GH12"
A, Alu, γ, TMIfile, L, B = config_from_nc(TMIversion)



pkgdir() = dirname(dirname(pathof(TMI)))
pkgdir(args...) = joinpath(pkgdir(), args...)

pkgdatadir() = joinpath(pkgdir(),"data")
pkgdatadir(args...) = joinpath(pkgdatadir(), args...)

TMIfile = pkgdatadir("TMI_"*TMIversion*".nc")
θtrue = readfield(TMIfile,"θ",γ)

ctrue = vectrue)

#An = A./sum(A;dims=1)

q = A * ctrue

# The first guess for the tracer concentration should be close to the actual tracer concentration
# take first guess as θtrue+0.01
cvec=vectrue).+ 0.1
θguess = unvectrue,cvec)

#first guess tracer control vector is near zero, and we want this to remain relatively small
u = Field(-0.01.*ones(size.wet)),γ,θtrue.name,θtrue.longname,θtrue.units)
uvec = vec(u)



#We need an error covariance matrix
W⁻ = Diagonal(1 ./( ones(sum.wet))).^2)#(1/sum(γ.wet))

#I want to allow a bunch of error in the surface part of the tracer conservation please
Qerror = ones(size.wet))
Qerror[:,:,1].=0
Qfield = Field(Qerror,γ,θtrue.name,θtrue.longname,θtrue.units)
Qvec = vec(Qfield)


Q⁻ = Diagonal(1 ./( ones(sum.wet))).^2)
A0= A .* 0.2
non_zero_indices1, non_zero_indices2, non_zero_values = findnz(A0)

non_zero_indices = hcat(non_zero_indices1, non_zero_indices2)


convec = [uvec; non_zero_values]
ulength=length(uvec)

# get sample J value
F = costfunction_gridded_model(convec,non_zero_indices,u,A0,ctrue,cvec,q,W⁻,Q⁻,γ)
fg!(F,G,x) = costfunction_gridded_model!(F,G,x,non_zero_indices,u,A0,ctrue,cvec,q,W⁻,Q⁻,γ)
fg(x) = costfunction_gridded_model(x,non_zero_indices,u,A0,ctrue,cvec,q,W⁻,Q⁻,γ)
f(x) = fg(x)[1]
J₀,gJ₀ = fg(convec)

#### gradient check ###################
# check with forward differences
ϵ = 1e-3
#ii = rand(1:sum(γ.wet[:,:,1]))
println(size(length(convec)))
ii = rand(1:length(convec))
println("Location for test =",ii)
δu = copy(convec); δu[ii] += ϵ
∇f_finite = (f(δu) - f(convec))/ϵ
println(∇f_finite)

fg!(J₀,gJ₀,(convec+δu)./2) # J̃₀ is not overwritten
∇f = gJ₀[ii]
println(∇f)

# error less than 10 percent?
println("Percent error ",100*abs(∇f - ∇f_finite)/abs(∇f + ∇f_finite))
#### end gradient check #################


#print(length(convec))
# filter the data with an Optim.jl method
iterations = 5
out = steadyclimatology(convec,fg!,iterations)

# reconstruct by hand to double-check.
= unvec((W⁻ * u),out.minimizer[begin:ulength])


# reconstruct tracer map
c₀ = θguess
= θguess+

Δc̃ =- θtrue
Δc₀ = θguess - θtrue

Anew = A0 + sparse(non_zero_indices[:, 1], non_zero_indices[:, 2], out.minimizer[ulength+1:end])
onesvec = ones(size(q))

Adiff1 = sum((A.-A0).^2)
Adiff2 = sum((A.-Anew).^2)
oldf = sum((non_zero_values).^2)
newf = sum((out.minimizer[ulength+1:end]).^2)
tracer_cons1 = sum((A0*cvec-q).^2)
tracer_cons2 = sum((Anew*(cvec+out.minimizer[begin:ulength])-q).^2)
mass_cons1 = sum((A0*onesvec-onesvec).^2)
mass_cons2 = sum((Anew*onesvec-onesvec).^2)


println("A difference before: $Adiff1")
println("A difference after: $Adiff2")
println("old tracer cons:$tracer_cons1")
println("new tracer cons:$tracer_cons2")
println("old mass cons:$mass_cons1")
println("new mass cons:$mass_cons2")


# plot the difference
level = 15 # your choice 1-33
depth = γ.depth[level]

cntrs = 0:0.5:15
label = "True θ"
planviewplottrue, depth, cntrs, titlelabel=label)
readline()

cntrs = 0:0.5:15
label = "Optimized θ"
planviewplot(c̃, depth, cntrs, titlelabel=label)


Loading

0 comments on commit e1ecd99

Please sign in to comment.