From 753182f3cb05ecba865e6b158d83e21babd78056 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Wed, 4 Mar 2026 23:21:49 -0500 Subject: [PATCH 01/14] Delete train_output.txt --- examples/train_output.txt | 71 --------------------------------------- 1 file changed, 71 deletions(-) delete mode 100644 examples/train_output.txt diff --git a/examples/train_output.txt b/examples/train_output.txt deleted file mode 100644 index c76430e..0000000 --- a/examples/train_output.txt +++ /dev/null @@ -1,71 +0,0 @@ -Precompiling packages... -Info Given FlowMarginals was explicitly requested, output will be shown live  -ERROR: LoadError: UndefVarError: `linked_vec_length` not defined in `Bijectors` -Suggestion: check for spelling errors or missing imports. -Stacktrace: - [1] getproperty(x::Module, f::Symbol) - @ Base ./Base_compiler.jl:47 - [2] top-level scope - @ ~/Desktop/work/FlowMarginals.jl/src/distribution.jl:95 - [3] include(mapexpr::Function, mod::Module, _path::String) - @ Base ./Base.jl:307 - [4] top-level scope - @ ~/Desktop/work/FlowMarginals.jl/src/FlowMarginals.jl:12 - [5] include(mod::Module, _path::String) - @ Base ./Base.jl:306 - [6] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::Nothing) - @ Base ./loading.jl:2996 - [7] top-level scope - @ stdin:5 - [8] eval(m::Module, e::Any) - @ Core ./boot.jl:489 - [9] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) - @ Base ./loading.jl:2842 - [10] include_string - @ ./loading.jl:2852 [inlined] - [11] exec_options(opts::Base.JLOptions) - @ Base ./client.jl:315 - [12] _start() - @ Base ./client.jl:550 -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/src/distribution.jl:95 -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/src/FlowMarginals.jl:1 -in expression starting at stdin:5 - ✗ 
FlowMarginals - 0 dependencies successfully precompiled in 5 seconds. 254 already precompiled. - -ERROR: LoadError: The following 1 direct dependency failed to precompile: - -FlowMarginals - -Failed to precompile FlowMarginals [d4f3a2b1-9e8c-4d7f-b5a6-2c1e0f9d8e7a] to "/home/marcobonici/.julia/compiled/v1.12/FlowMarginals/jl_lWeZRJ". -ERROR: LoadError: UndefVarError: `linked_vec_length` not defined in `Bijectors` -Suggestion: check for spelling errors or missing imports. -Stacktrace: - [1] getproperty(x::Module, f::Symbol) - @ Base ./Base_compiler.jl:47 - [2] top-level scope - @ ~/Desktop/work/FlowMarginals.jl/src/distribution.jl:95 - [3] include(mapexpr::Function, mod::Module, _path::String) - @ Base ./Base.jl:307 - [4] top-level scope - @ ~/Desktop/work/FlowMarginals.jl/src/FlowMarginals.jl:12 - [5] include(mod::Module, _path::String) - @ Base ./Base.jl:306 - [6] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::Nothing) - @ Base ./loading.jl:2996 - [7] top-level scope - @ stdin:5 - [8] eval(m::Module, e::Any) - @ Core ./boot.jl:489 - [9] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) - @ Base ./loading.jl:2842 - [10] include_string - @ ./loading.jl:2852 [inlined] - [11] exec_options(opts::Base.JLOptions) - @ Base ./client.jl:315 - [12] _start() - @ Base ./client.jl:550 -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/src/distribution.jl:95 -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/src/FlowMarginals.jl:1 -in expression starting at stdin: -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/examples/train_multinormal.jl:11 From b1c733cc845a723944a739841df80bd26f11c0ba Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 5 Mar 2026 01:30:11 -0500 Subject: [PATCH 02/14] Adding NSF --- Project.toml 
| 6 + examples/Manifest.toml | 6 +- examples/train_multinormal_nsf.jl | 130 ++++++++++++++++++ src/SimpleFlows.jl | 5 +- src/distribution.jl | 30 +++-- src/io.jl | 31 +++-- src/nsf.jl | 211 ++++++++++++++++++++++++++++++ src/splines.jl | 154 ++++++++++++++++++++++ test/runtests.jl | 2 + test/test_nsf.jl | 52 ++++++++ test/test_splines.jl | 93 +++++++++++++ 11 files changed, 695 insertions(+), 25 deletions(-) create mode 100644 examples/train_multinormal_nsf.jl create mode 100644 src/nsf.jl create mode 100644 src/splines.jl create mode 100644 test/test_nsf.jl create mode 100644 test/test_splines.jl diff --git a/Project.toml b/Project.toml index efcacd0..6904989 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,15 @@ authors = ["marcobonici "] [deps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -20,12 +23,15 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Bijectors = "0.15.17" +ChainRulesCore = "1.26.0" ConcreteStructs = "0.2.3" Distributions = "0.25.123" +ForwardDiff = "1.3.2" JSON = "1.4.0" LinearAlgebra = "1.12.0" Lux = "1.31.3" MLUtils = "0.4.8" +NNlib = "0.9.33" NPZ = "0.4.3" Optimisers = "0.4.7" Random = "1.11.0" diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 0aabf77..5e4edc0 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -1753,10 +1753,8 @@ uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" version 
= "0.1.0" [[deps.SimpleFlows]] -deps = ["Bijectors", "ConcreteStructs", "Distributions", "JSON", "LinearAlgebra", "Lux", "MLUtils", "NPZ", "Optimisers", "Random", "Statistics", "Test", "Zygote"] -git-tree-sha1 = "9ee8f18c24029fe3ad3362abade4e005e283e4d1" -repo-rev = "main" -repo-url = "https://github.com/CosmologicalEmulators/SimpleFlows.jl" +deps = ["Bijectors", "ChainRulesCore", "ConcreteStructs", "Distributions", "ForwardDiff", "JSON", "LinearAlgebra", "Lux", "MLUtils", "NNlib", "NPZ", "Optimisers", "Random", "Statistics", "Test", "Zygote"] +path = ".." uuid = "7aff1418-a6e2-48c2-ba03-ae8e32e98757" version = "0.1.0" diff --git a/examples/train_multinormal_nsf.jl b/examples/train_multinormal_nsf.jl new file mode 100644 index 0000000..6c31a90 --- /dev/null +++ b/examples/train_multinormal_nsf.jl @@ -0,0 +1,130 @@ +""" +Train a Neural Spline Flow (NSF) normalizing flow to approximate a 4-dimensional correlated Gaussian, +then use Turing.jl to run NUTS inference under both the exact prior and the NF prior. + +Usage: + julia --project=. -t 4 examples/train_multinormal_nsf.jl + +Note: Turing.jl must be installed in your global environment or activated environment. +""" + +using SimpleFlows +using Distributions, LinearAlgebra, Random, Statistics, Printf +using Turing + +rng = Random.MersenneTwister(42) + +# ── 1. Define a non-trivial 4D target distribution ─────────────────────────── + +μ = [1.0, -0.5, 2.0, 0.3] +Σ = [1.00 0.50 0.10 0.00; + 0.50 1.00 0.30 0.20; + 0.10 0.30 1.00 0.40; + 0.00 0.20 0.40 1.00] +target = MvNormal(μ, Σ) + +# ── 2. Draw training data ───────────────────────────────────────────────────── + +N_train = 10_000 +data = Float64.(rand(rng, target, N_train)) # 4 × 10_000 + +println("Target mean: ", round.(mean(data; dims=2)[:]; digits=3)) +println("Target std: ", round.(std(data; dims=2)[:]; digits=3)) + +# ── 3. 
Build and train the flow ───────────────────────────────────────────────
+
+# Note: NSF typically requires fewer transforms but slightly larger hidden layers
+# or just more expressive bins (K).
+flow = FlowDistribution(Float64;
+    architecture = :NSF,
+    n_transforms = 4,
+    dist_dims = 4,
+    hidden_layer_sizes = [32, 32],
+    K = 8,
+    tail_bound = 3.0,
+    rng,
+)
+
+println("\nTraining NSF (4 transforms, K=8, [32, 32] hidden units, 500 epochs)…")
+train_flow!(flow, data;
+    n_epochs = 500,
+    lr = 1e-3,
+    batch_size = 512,
+    verbose = true,
+)
+
+# ── 4. Evaluate fit ───────────────────────────────────────────────────────────
+
+N_test = 5_000
+x_test = Float64.(rand(rng, target, N_test))
+
+lp_true = mean(logpdf(target, Float64.(x_test)))
+lp_flow = mean(logpdf(flow, x_test))
+
+println("\n── Density Fit ──────────────────────────────────────────")
+println("Mean log-pdf (true distribution): ", round(lp_true; digits=4))
+println("Mean log-pdf (trained NSF flow): ", round(Float64(lp_flow); digits=4))
+println("Difference: ", round(abs(lp_true - lp_flow); digits=4))
+
+samples = Distributions.rand(rng, flow, N_test) # 4 × N_test
+println("\nFlow sample mean: ", round.(mean(samples; dims=2)[:]; digits=3))
+println("Flow sample std: ", round.(std(samples; dims=2)[:]; digits=3))
+
+# ── 5. Save the trained flow ─────────────────────────────────────────────────
+
+save_dir = joinpath(@__DIR__, "..", "trained_flows", "nsf_mvn_4d")
+save_trained_flow(save_dir, flow)
+println("\nFlow saved to $save_dir")
+
+# ── 6. Reload and verify round-trip ──────────────────────────────────────────
+
+flow2 = load_trained_flow(save_dir; rng)
+lp_reloaded = mean(logpdf(flow2, x_test))
+println("Mean log-pdf after reload: ", round(Float64(lp_reloaded); digits=4))
+# `@assert` (unlike `@test`) has no `a ≈ b atol=…` form: a trailing `atol=1f-3`
+# would be taken as the failure-message expression and the comparison would run
+# with the default tolerance. Pass the tolerance to `isapprox` explicitly.
+@assert isapprox(lp_reloaded, lp_flow; atol=1f-3) "Round-trip failed!"
+println("Round-trip OK ✓")
+
+# ── 7.
Turing.jl inference demo ─────────────────────────────────────────────── + +println("\n── Turing Inference ─────────────────────────────────────────") + +θ_true = [1.0, -0.5, 2.0, 0.3] +y_obs = θ_true[1] + 0.5 * randn(rng) +println("Observed y: ", round(y_obs; digits=4), " (true θ[1] = $(θ_true[1]))") + +@model function inference_model(y_obs, prior) + θ ~ prior + y_obs ~ Normal(θ[1], 0.5) +end + +n_samples = 1000 +n_chains = 4 + +# ── 7a. Exact MvNormal prior ───────────────────────────────────────────────── +println("\nSampling with exact MvNormal prior ($n_chains chains × $n_samples samples)…") +chain_exact = sample( + inference_model(y_obs, target), + NUTS(), MCMCThreads(), n_samples, n_chains; + progress = false, +) + +# ── 7b. Trained NSF prior ──────────────────────────────────────────────────── +println("Sampling with trained NSF prior ($n_chains chains × $n_samples samples)…") +chain_nf = sample( + inference_model(y_obs, flow2), + NUTS(), MCMCThreads(), n_samples, n_chains; + progress = false, +) + +# ── 8. 
Compare posteriors ──────────────────────────────────────────────────── + +println("\n── Posterior comparison (θ[1]) ──────────────────────────────") +exact_θ1 = vec(chain_exact[:, "θ[1]", :]) +nf_θ1 = vec(chain_nf[:, "θ[1]", :]) + +@printf " True θ[1]: %.4f\n" θ_true[1] +@printf " Observed y: %.4f\n" y_obs +@printf " Posterior mean (exact): %.4f ± %.4f\n" mean(exact_θ1) std(exact_θ1) +@printf " Posterior mean (NSF): %.4f ± %.4f\n" mean(nf_θ1) std(nf_θ1) + +println("\n✨ End of script: Turing MCMC sampling with NSF finished successfully!") diff --git a/src/SimpleFlows.jl b/src/SimpleFlows.jl index 034d494..2d06472 100644 --- a/src/SimpleFlows.jl +++ b/src/SimpleFlows.jl @@ -12,9 +12,12 @@ include("normalizer.jl") include("distribution.jl") include("training.jl") include("io.jl") +include("splines.jl") +include("nsf.jl") -export RealNVP, FlowDistribution +export RealNVP, NeuralSplineFlow, FlowDistribution, NSFCouplingLayer export MinMaxNormalizer export train_flow!, save_trained_flow, load_trained_flow +export unconstrained_rational_quadratic_spline end diff --git a/src/distribution.jl b/src/distribution.jl index 5c1d648..f09b02f 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -11,8 +11,8 @@ A trained normalizing flow wrapped as a `Distributions.jl` - `n_dims`, `hidden_dims`, `n_layers`: architecture metadata for serialization - `normalizer`: fitted `MinMaxNormalizer` (always present after training) """ -mutable struct FlowDistribution{T<:Real} <: ContinuousMultivariateDistribution - model :: RealNVP +mutable struct FlowDistribution{T<:Real, M<:AbstractLuxLayer} <: ContinuousMultivariateDistribution + model :: M ps st n_dims :: Int @@ -21,31 +21,37 @@ mutable struct FlowDistribution{T<:Real} <: ContinuousMultivariateDistribution end """ - FlowDistribution([Type=Float32]; n_transforms, dist_dims, hidden_layer_sizes, - hidden_dims=64, n_layers=3, - activation=gelu, rng=Random.default_rng()) + FlowDistribution([Type=Float32]; architecture=:RealNVP, 
n_transforms, dist_dims, + hidden_layer_sizes, hidden_dims=64, n_layers=3, + activation=gelu, rng=Random.default_rng(), K=8, tail_bound=3.0) Construct and randomly initialise a `FlowDistribution`. - -`hidden_layer_sizes` sets the width of each hidden layer independently. -For convenience, you may pass `hidden_dims` and `n_layers` instead, which will -expand to `fill(hidden_dims, n_layers)`. -The `normalizer` field is `nothing` until `train_flow!` is called. +`architecture` can be `:RealNVP` or `:NSF`. """ function FlowDistribution(::Type{T}=Float32; + architecture=:RealNVP, n_transforms::Int, dist_dims::Int, hidden_layer_sizes::Vector{Int}=Int[], hidden_dims::Int=64, n_layers::Int=3, activation=gelu, + K=8, tail_bound=3.0, rng::AbstractRNG=Random.default_rng()) where {T<:Real} # If no vector given, fall back to the scalar convenience args if isempty(hidden_layer_sizes) hidden_layer_sizes = fill(hidden_dims, n_layers) end - model = RealNVP(; n_transforms, dist_dims, hidden_layer_sizes, activation) + + model = if architecture == :RealNVP + RealNVP(; n_transforms, dist_dims, hidden_layer_sizes, activation) + elseif architecture == :NSF + NeuralSplineFlow(; n_transforms, dist_dims, hidden_layer_sizes, K, tail_bound, activation) + else + error("Unknown architecture: $architecture. Supported: :RealNVP, :NSF") + end + ps, st = Lux.setup(rng, model) ps = Lux.fmap(x -> x isa AbstractArray ? 
T.(x) : x, ps)
-    return FlowDistribution{T}(model, ps, st, dist_dims, hidden_layer_sizes, nothing)
+    return FlowDistribution{T, typeof(model)}(model, ps, st, dist_dims, hidden_layer_sizes, nothing)
 end
 
 # ── Distributions.jl interface ────────────────────────────────────────────────
diff --git a/src/io.jl b/src/io.jl
index 830930e..8280860 100644
--- a/src/io.jl
+++ b/src/io.jl
@@ -34,30 +34,45 @@ end
 # ── Architecture dict ─────────────────────────────────────────────────────────
 
 function _flow_to_dict(flow::FlowDistribution)
-    return Dict(
-        "architecture" => "RealNVP",
+    d = Dict(
+        "architecture" => (flow.model isa RealNVP ? "RealNVP" : flow.model isa NeuralSplineFlow ? "NSF" : error("Unsupported model type: $(typeof(flow.model))")), # never label an unknown model as NSF
         "n_transforms" => flow.model.n_transforms,
         "dist_dims" => flow.n_dims,
         "hidden_layer_sizes" => flow.hidden_layer_sizes,
         "activation" => "gelu",
     )
+    if flow.model isa NeuralSplineFlow
+        d["K"] = flow.model.K
+        d["tail_bound"] = flow.model.tail_bound
+    end
+    return d
 end
 
 function _build_flow_from_dict(d::AbstractDict, ::Type{T}=Float32, rng::AbstractRNG=Random.default_rng()) where {T<:Real}
     arch = d["architecture"]
-    arch == "RealNVP" || error("Unknown architecture: $arch")
+    arch_sym = arch == "RealNVP" ? :RealNVP : arch == "NSF" ? :NSF : error("Unknown architecture: $arch") # keep rejecting unknown names instead of building an NSF for them
+
     # Support both new (hidden_layer_sizes) and legacy (hidden_dims + n_layers) formats
     hidden_layer_sizes = if haskey(d, "hidden_layer_sizes")
         Int.(d["hidden_layer_sizes"])
     else
         fill(Int(d["hidden_dims"]), Int(d["n_layers"]))
     end
-    return FlowDistribution(T;
-        n_transforms = Int(d["n_transforms"]),
-        dist_dims = Int(d["dist_dims"]),
-        hidden_layer_sizes = hidden_layer_sizes,
-        rng,
+
+    kwargs = Dict{Symbol, Any}(
+        :architecture => arch_sym,
+        :n_transforms => Int(d["n_transforms"]),
+        :dist_dims => Int(d["dist_dims"]),
+        :hidden_layer_sizes => hidden_layer_sizes,
+        :rng => rng,
     )
+
+    if arch == "NSF"
+        kwargs[:K] = Int(d["K"])
+        kwargs[:tail_bound] = Float64(d["tail_bound"])
+    end
+
+    return FlowDistribution(T; kwargs...)
end
 
 # ── Public API ────────────────────────────────────────────────────────────────
diff --git a/src/nsf.jl b/src/nsf.jl
new file mode 100644
index 0000000..8643844
--- /dev/null
+++ b/src/nsf.jl
@@ -0,0 +1,211 @@
+# src/nsf.jl
+using Lux
+using Bijectors
+using Random
+using LinearAlgebra
+
+"""
+    NSFSplineBijector(mask, params, K, tail_bound)
+
+Rational-quadratic spline transform applied to the rows of a `(D, N)` input that are selected by `mask`.
+`mask` is a Boolean vector of length `D` (true = row is transformed, false = row is copied through unchanged).
+`params` holds the raw spline parameters; it is reshaped to `(count(mask), 3K - 1, N)` and split into `K` widths, `K` heights and `K - 1` derivatives.
+"""
+# Bijectors and Layers for NSF
+struct NSFSplineBijector
+    mask
+    params
+    K::Int
+    tail_bound::Float64
+end
+
+function forward_and_log_det(b::NSFSplineBijector, x::AbstractArray)
+    # x is (D, N)
+    D, N = size(x)
+    mask = b.mask
+    K = b.K
+    tail_bound = b.tail_bound
+    params = b.params
+
+    # x_tr: variables to be transformed
+    x_tr = x[mask, :]
+    D_tr = size(x_tr, 1)
+
+    # Reshape params to (D_tr, 3K-1, N)
+    params = reshape(params, D_tr, 3*K - 1, N)
+
+    # Partition params
+    w_unnorm = params[:, 1:K, :]
+    h_unnorm = params[:, K+1:2*K, :]
+    dv_unnorm = params[:, 2*K+1:end, :]
+
+    # Flatten everything to call the spline function
+    x_tr_flat = vec(x_tr)
+    w_flat = reshape(permutedims(w_unnorm, (1, 3, 2)), D_tr * N, K)
+    h_flat = reshape(permutedims(h_unnorm, (1, 3, 2)), D_tr * N, K)
+    dv_flat = reshape(permutedims(dv_unnorm, (1, 3, 2)), D_tr * N, K-1)
+
+    y_tr_flat, lad_flat = unconstrained_rational_quadratic_spline(
+        x_tr_flat, w_flat, h_flat, dv_flat, eltype(x)(tail_bound)
+    )
+
+    y_tr = reshape(y_tr_flat, D_tr, N)
+    lad_tr = reshape(lad_flat, D_tr, N)
+
+    # We yield the full transformed y (only for masked dims)
+    # The MaskedCoupling logic will handle the identity part.
+    # But wait, MaskedCoupling expects the bijector to return an array of same size as x?
+    # No, MaskedCoupling: y, log_det = transform_fn(params)
+    # Then y = ifelse.(bj.mask, y, x)
+    # So y MUST have same size as x.
+
+    # Reconstruction using comprehension (Zygote friendly)
+    # We need to find the index into y_tr for each masked dimension.
+    # tr_indices[i] will be the row index in y_tr if mask[i] is true.
+    tr_indices = cumsum(mask)
+
+    y = vcat([mask[i] ? y_tr[tr_indices[i]:tr_indices[i], :] : x[i:i, :] for i in 1:D]...)
+    log_det = vcat([mask[i] ? lad_tr[tr_indices[i]:tr_indices[i], :] : fill(zero(eltype(x)), 1, N) for i in 1:D]...)
+
+    return y, log_det
+end
+
+"""
+    inverse_and_log_det(b::NSFSplineBijector, y)
+
+Inverse pass of the spline transform. Rows of the `(D, N)` array `y` selected by `b.mask`
+are sent through the spline with `inverse=true`; the remaining rows are copied.
+Returns `(x, log_det)`, both `(D, N)`; `log_det` is zero on the rows that were copied.
+"""
+function inverse_and_log_det(b::NSFSplineBijector, y::AbstractArray)
+    # y is (D, N)
+    D, N = size(y)
+    mask = b.mask
+    K = b.K
+    tail_bound = b.tail_bound
+    params = b.params
+
+    y_tr = y[mask, :]
+    D_tr = size(y_tr, 1)
+
+    params = reshape(params, D_tr, 3*K - 1, N)
+    w_unnorm = params[:, 1:K, :]
+    h_unnorm = params[:, K+1:2*K, :]
+    dv_unnorm = params[:, 2*K+1:end, :]
+
+    y_tr_flat = vec(y_tr)
+    w_flat = reshape(permutedims(w_unnorm, (1, 3, 2)), D_tr * N, K)
+    h_flat = reshape(permutedims(h_unnorm, (1, 3, 2)), D_tr * N, K)
+    dv_flat = reshape(permutedims(dv_unnorm, (1, 3, 2)), D_tr * N, K-1)
+
+    x_tr_flat, lad_flat = unconstrained_rational_quadratic_spline(
+        y_tr_flat, w_flat, h_flat, dv_flat, eltype(y)(tail_bound);
+        inverse=true
+    )
+
+    x_tr = reshape(x_tr_flat, D_tr, N)
+    lad_tr = reshape(lad_flat, D_tr, N)
+
+    # tr_indices[i] is the row of x_tr that belongs to dimension i when mask[i] is true.
+    tr_indices = cumsum(mask)
+
+    x = vcat([mask[i] ? x_tr[tr_indices[i]:tr_indices[i], :] : y[i:i, :] for i in 1:D]...)
+    log_det = vcat([mask[i] ? lad_tr[tr_indices[i]:tr_indices[i], :] : fill(zero(eltype(y)), 1, N) for i in 1:D]...)
+
+    return x, log_det
+end
+
+function NSFCouplingBijector_from_flat(params, mask, K, tail_bound)
+    return NSFSplineBijector(mask, params, K, tail_bound)
+end
+
+"""
+    NeuralSplineFlow(; n_transforms, dist_dims, hidden_layer_sizes, K=8, tail_bound=3.0, activation=gelu)
+
+Neural Spline Flow (NSF) with rational quadratic coupling layers.
+One conditioner MLP is built per transform; `K` is the number of spline bins and
+`tail_bound` is stored as `Float64`.
+"""
+@concrete struct NeuralSplineFlow <: Lux.AbstractLuxContainerLayer{(:conditioners,)}
+    conditioners
+    dist_dims :: Int
+    n_transforms :: Int
+    hidden_layer_sizes :: Vector{Int}
+    K :: Int
+    tail_bound :: Float64
+end
+
+function NeuralSplineFlow(; n_transforms::Int, dist_dims::Int,
+        hidden_layer_sizes::Vector{Int}, K=8, tail_bound=3.0, activation=gelu)
+    D = dist_dims
+    # Layer i transforms the dimensions j with j % 2 == i % 2 (the mask built in
+    # `Lux.initialstates` below). For odd D that count alternates between layers,
+    # so each conditioner is sized for its own layer rather than for a single
+    # shared `D - D ÷ 2`; otherwise the reshape to (D_tr, 3K-1, N) in the
+    # bijector fails on every second layer.
+    n_transformed(i) = count(j -> j % 2 == i % 2, 1:D)
+    # Conditioner output size: D_tr * (3K - 1)  (K widths, K heights, K-1 derivatives)
+    mlps = [MLP(D, hidden_layer_sizes, n_transformed(i) * (3*K - 1); activation)
+            for i in 1:n_transforms]
+    keys_ = ntuple(i -> Symbol(:conditioners_, i), n_transforms)
+    conditioners = NamedTuple{keys_}(Tuple(mlps))
+    return NeuralSplineFlow(conditioners, D, n_transforms, hidden_layer_sizes, K, Float64(tail_bound))
+end
+
+function Lux.initialstates(rng::AbstractRNG, m::NeuralSplineFlow)
+    mask_list = [Bool.(collect(1:(m.dist_dims)) .% 2 .== i % 2)
+                 for i in 1:(m.n_transforms)]
+    return (; mask_list, conditioners=Lux.initialstates(rng, m.conditioners))
+end
+
+# Generic log_prob and draw_samples
+function log_prob(model::Union{RealNVP, NeuralSplineFlow}, ps, st, x::AbstractMatrix)
+    lp = nothing
+    for i in model.n_transforms:-1:1
+        k = keys(model.conditioners)[i]
+        mask = st.mask_list[i]
+        cond_fn = let m = model.conditioners[k], p = ps.conditioners[k],
+                      s = st.conditioners[k]
+            x_cond -> Lux.apply(m, x_cond, p, s)[1]
+        end
+
+        bj = if model isa RealNVP
+            MaskedCoupling(mask, cond_fn, AffineBijector)
+        else
+            MaskedCoupling(mask, cond_fn, p ->
NSFCouplingBijector_from_flat(p, mask, model.K, model.tail_bound))
+        end
+
+        x, ld = inverse_and_log_det(bj, x)
+        lp = isnothing(lp) ? ld : lp .+ ld
+    end
+    base_lp = dsum(gaussian_logpdf.(x); dims=(1,))
+    return isnothing(lp) ? base_lp : lp .+ base_lp
+end
+
+function draw_samples(rng::AbstractRNG, ::Type{T}, model::Union{RealNVP, NeuralSplineFlow},
+                      ps, st, n_samples::Int) where T
+    x = randn(rng, T, model.dist_dims, n_samples)
+    for i in 1:(model.n_transforms)
+        k = keys(model.conditioners)[i]
+        mask = st.mask_list[i]
+        cond_fn = let m = model.conditioners[k], p = ps.conditioners[k],
+                      s = st.conditioners[k]
+            x_cond -> Lux.apply(m, x_cond, p, s)[1]
+        end
+
+        bj = if model isa RealNVP
+            MaskedCoupling(mask, cond_fn, AffineBijector)
+        else
+            MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, mask, model.K, model.tail_bound))
+        end
+
+        x, _ = forward_and_log_det(bj, x)
+    end
+    return x
+end
+
+# Helper to build the bijector from flat parameters
+struct NSFSplineConstructor
+    mask
+    K::Int
+    tail_bound::Float64
+end
+
+function (c::NSFSplineConstructor)(params)
+    return NSFSplineBijector(c.mask, params, c.K, c.tail_bound)
+end
+
+# NOTE: `log_prob` and `draw_samples` above build the bijector through `NSFCouplingBijector_from_flat`;
+# `NSFSplineConstructor` is a callable form of the same construction.
diff --git a/src/splines.jl b/src/splines.jl
new file mode 100644
index 0000000..b8c12e2
--- /dev/null
+++ b/src/splines.jl
@@ -0,0 +1,154 @@
+# src/splines.jl
+using NNlib
+using ChainRulesCore
+using ForwardDiff
+
+# Helper to find bins. Zygote ignores gradients for indices.
+function compute_bin_idx(cum_arrays::AbstractMatrix{T}, inputs::AbstractVector{T}, K::Int) where {T<:Real}
+    # Row i of `cum_arrays` holds the bin edges for `inputs[i]`; the result is the
+    # index of the last edge <= the input, clamped into 1:K.
+    M = length(inputs)
+    bin_idx = zeros(Int, M)
+    for i in 1:M
+        # Handle ForwardDiff values properly if needed, but searchsortedlast supports Duals natively
+        idx = searchsortedlast(@view(cum_arrays[i, :]), inputs[i])
+        bin_idx[i] = clamp(idx, 1, K)
+    end
+    return bin_idx
+end
+ChainRulesCore.@non_differentiable compute_bin_idx(Any...)
+
+"""
+    unconstrained_rational_quadratic_spline(inputs, unnormalized_widths, unnormalized_heights,
+        unnormalized_derivatives, tail_bound=3.0, min_bin_width=1e-3, min_bin_height=1e-3,
+        min_derivative=1e-3; inverse=false)
+
+Inputs:
+- `inputs`: Vector of length M.
+- `unnormalized_widths`: Matrix of shape (M, K).
+- `unnormalized_heights`: Matrix of shape (M, K).
+- `unnormalized_derivatives`: Matrix of shape (M, K-1).
+- `tail_bound`, `min_bin_width`, `min_bin_height`, `min_derivative`: optional *positional*
+  arguments; they are converted to the promoted element type of the array arguments.
+- `inverse`: keyword; when `true` the inverse transform is evaluated.
+
+All inputs are flattened. Values outside `[-tail_bound, tail_bound]` are returned unchanged
+with a zero log-determinant. Returns `(outputs, logabsdet)`.
+Throws `ArgumentError` when `min_bin_width * K` or `min_bin_height * K` exceeds 1.
+"""
+function unconstrained_rational_quadratic_spline(
+    inputs::AbstractVector{T_in},
+    unnormalized_widths::AbstractMatrix{T_w},
+    unnormalized_heights::AbstractMatrix{T_h},
+    unnormalized_derivatives::AbstractMatrix{T_dv},
+    tail_bound=3.0,
+    min_bin_width=1e-3,
+    min_bin_height=1e-3,
+    min_derivative=1e-3;
+    inverse::Bool=false
+) where {T_in<:Real, T_w<:Real, T_h<:Real, T_dv<:Real}
+    # Operations will be performed in the promoted type
+    T = promote_type(T_in, T_w, T_h, T_dv)
+
+    # Cast everything to the consistency type T
+    inputs = T.(inputs)
+    unnormalized_widths = T.(unnormalized_widths)
+    unnormalized_heights = T.(unnormalized_heights)
+    unnormalized_derivatives = T.(unnormalized_derivatives)
+
+    # Convert the positional tuning arguments to T
+    tail_bound = T(tail_bound)
+    min_bin_width = T(min_bin_width)
+    min_bin_height = T(min_bin_height)
+    min_derivative = T(min_derivative)
+
+    M, K = size(unnormalized_widths)
+
+    # Every bin is given at least `min_bin_*` of the unit interval; with K bins that floor
+    # must fit, otherwise `1 - min_bin_* * K` turns negative and the bin sizes are invalid.
+    min_bin_width * K <= 1 || throw(ArgumentError("Minimal bin width too large for the number of bins"))
+    min_bin_height * K <= 1 || throw(ArgumentError("Minimal bin height too large for the number of bins"))
+
+    # 1. Pad derivatives with constant for linear tails
+    constant = T(log(exp(1 - min_derivative) - 1))
+    pad_c = fill(constant, M)
+    unnorm_derivs = hcat(pad_c, unnormalized_derivatives, pad_c) # (M, K+1)
+
+    # 2.
Extract valid interior regions (using Mask to avoid mutating variables for Zygote) + inside_mask = (inputs .>= -tail_bound) .& (inputs .<= tail_bound) + + # We will compute the spline for ALL points, but only conditionally select the output + # This avoids Zygote mutating array issues, relying on `ifelse`. + # To avoid NaNs or domain errors, we clamp the inputs outside the tail bound. + clamped_inputs = clamp.(inputs, -tail_bound, tail_bound) + + # 3. Compute normalize bin parameters + widths_raw = NNlib.softmax(unnormalized_widths; dims=2) + widths = min_bin_width .+ (1 - min_bin_width * K) .* widths_raw + + cumwidths_raw = cumsum(widths; dims=2) + # Scale to (0, 1) then to (-tail_bound, tail_bound) + # We use hcat to ensure exact boundaries + interior_cumwidths = (2 * tail_bound) .* (@view cumwidths_raw[:, 1:end-1]) .- tail_bound + cumwidths = hcat(fill(-tail_bound, M), interior_cumwidths, fill(tail_bound, M)) + widths = cumwidths[:, 2:end] .- cumwidths[:, 1:end-1] + + derivatives = min_derivative .+ NNlib.softplus.(unnorm_derivs) + + heights_raw = NNlib.softmax(unnormalized_heights; dims=2) + heights = min_bin_height .+ (1 - min_bin_height * K) .* heights_raw + + cumheights_raw = cumsum(heights; dims=2) + interior_cumheights = (2 * tail_bound) .* (@view cumheights_raw[:, 1:end-1]) .- tail_bound + cumheights = hcat(fill(-tail_bound, M), interior_cumheights, fill(tail_bound, M)) + heights = cumheights[:, 2:end] .- cumheights[:, 1:end-1] + + # 4. Find the appropriate bin + if inverse + bin_idx = compute_bin_idx(cumheights, clamped_inputs, K) + else + bin_idx = compute_bin_idx(cumwidths, clamped_inputs, K) + end + + # 5. 
Gather bin specific parameters + # Zygote friendly gather using linear indexing + # cumwidths is (M, K+1) + # widths is (M, K) + linear_indices_k_plus_1 = (bin_idx .- 1) .* M .+ (1:M) + linear_indices_k = (bin_idx .- 1) .* M .+ (1:M) + + input_cumwidths = cumwidths[linear_indices_k_plus_1] + input_bin_widths = widths[linear_indices_k] + + input_cumheights = cumheights[linear_indices_k_plus_1] + input_heights = heights[linear_indices_k] + + delta = heights ./ widths + input_delta = delta[linear_indices_k] + + input_derivatives = derivatives[linear_indices_k_plus_1] + input_derivatives_plus_one = derivatives[linear_indices_k_plus_1 .+ M] + + # 6. Evaluate spline equations + if inverse + a = (clamped_inputs .- input_cumheights) .* (input_derivatives .+ input_derivatives_plus_one .- 2 .* input_delta) .+ input_heights .* (input_delta .- input_derivatives) + b = input_heights .* input_derivatives .- (clamped_inputs .- input_cumheights) .* (input_derivatives .+ input_derivatives_plus_one .- 2 .* input_delta) + c = -input_delta .* (clamped_inputs .- input_cumheights) + + discriminant = b.^2 .- 4 .* a .* c + # Numerical stability: splines are monotonic so discriminant >= 0. + # We use abs and a tiny epsilon for square root gradients. 
+ root = (2 .* c) ./ (-b .- sqrt.(abs.(discriminant) .+ T(1e-12))) + + rq_outputs = root .* input_bin_widths .+ input_cumwidths + + theta_one_minus_theta = root .* (1 .- root) + denominator = input_delta .+ ((input_derivatives .+ input_derivatives_plus_one .- 2 .* input_delta) .* theta_one_minus_theta) + derivative_numerator = (input_delta.^2) .* (input_derivatives_plus_one .* root.^2 .+ 2 .* input_delta .* theta_one_minus_theta .+ input_derivatives .* (1 .- root).^2) + + rq_logabsdet = log.(abs.(derivative_numerator) .+ T(1e-12)) .- 2 .* log.(abs.(denominator) .+ T(1e-12)) + rq_logabsdet = -rq_logabsdet + else + theta = (clamped_inputs .- input_cumwidths) ./ input_bin_widths + theta_one_minus_theta = theta .* (1 .- theta) + + numerator = input_heights .* (input_delta .* theta.^2 .+ input_derivatives .* theta_one_minus_theta) + denominator = input_delta .+ ((input_derivatives .+ input_derivatives_plus_one .- 2 .* input_delta) .* theta_one_minus_theta) + rq_outputs = input_cumheights .+ numerator ./ denominator + + derivative_numerator = (input_delta.^2) .* (input_derivatives_plus_one .* theta.^2 .+ 2 .* input_delta .* theta_one_minus_theta .+ input_derivatives .* (1 .- theta).^2) + rq_logabsdet = log.(abs.(derivative_numerator) .+ T(1e-12)) .- 2 .* log.(abs.(denominator) .+ T(1e-12)) + end + + # 7. 
Apply identity outside bounds + outputs = ifelse.(inside_mask, rq_outputs, inputs) + logabsdet = ifelse.(inside_mask, rq_logabsdet, zero(T)) + + return outputs, logabsdet +end diff --git a/test/runtests.jl b/test/runtests.jl index 675c57d..c70ac25 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,4 +9,6 @@ rng = Random.MersenneTwister(42) include("test_realnvp.jl") include("test_io.jl") include("test_normalizer.jl") + include("test_splines.jl") + include("test_nsf.jl") end diff --git a/test/test_nsf.jl b/test/test_nsf.jl new file mode 100644 index 0000000..d9f06d4 --- /dev/null +++ b/test/test_nsf.jl @@ -0,0 +1,52 @@ +using SimpleFlows +using Test +using Random +using Statistics +using Distributions + +@testset "NSF Integrated Model" begin + rng = Random.default_rng() + Random.seed!(rng, 123) + + dist_dims = 2 + n_transforms = 2 + K = 4 + + # 1. Initialization + dist = FlowDistribution(Float32; + architecture=:NSF, + n_transforms=n_transforms, + dist_dims=dist_dims, + hidden_dims=16, + n_layers=2, + K=K + ) + + @test dist.model isa NeuralSplineFlow + @test dist.model.K == K + + # 2. Forward pass (logpdf) + x = randn(Float32, dist_dims, 10) + lp = logpdf(dist, x) + @test length(lp) == 10 + @test all(isfinite, lp) + + # 3. Sampling + samples = rand(rng, dist, 100) + @test size(samples) == (dist_dims, 100) + @test all(isfinite, samples) + + # 4. 
Training (Smoke test) + # Simple data: Gaussian + target_data = randn(Float32, dist_dims, 200) + + # Initialize with normalizer + train_flow!(dist, target_data; n_epochs=1, batch_size=100) + + @test !isnothing(dist.normalizer) + + # Verify logpdf still works after training + lp_after = logpdf(dist, target_data[:, 1:10]) + @test length(lp_after) == 10 + @test all(isfinite, lp_after) +end diff --git a/test/test_splines.jl b/test/test_splines.jl new file mode 100644 index 0000000..7db1c5a --- /dev/null +++ b/test/test_splines.jl @@ -0,0 +1,93 @@ +using SimpleFlows +using Test +using NPZ +using ForwardDiff +using Zygote +using LinearAlgebra + +@testset "Rational Quadratic Splines" begin + # 1. Load reference data (F64) + data = npzread("/tmp/nsf_test_data_f64.npz") + + for T in [Float32, Float64] + @testset "Precision: $T" begin + inputs = T.(data["inputs"]) + unnormalized_widths = T.(data["unnormalized_widths"]) + unnormalized_heights = T.(data["unnormalized_heights"]) + unnormalized_derivatives = T.(data["unnormalized_derivatives"]) + + ref_outputs = T.(data["outputs"]) + ref_logabsdet = T.(data["logabsdet"]) + + D, N = size(inputs) + + @testset "Numerical Parity with PyTorch" begin + for d in 1:D + out, lad = unconstrained_rational_quadratic_spline( + inputs[d, :], + unnormalized_widths[d, :, :], + unnormalized_heights[d, :, :], + unnormalized_derivatives[d, :, :], + T(3.0) # tail_bound + ) + + @test out ≈ ref_outputs[d, :] atol=(T == Float32 ? 1e-4 : 1e-6) + @test lad ≈ ref_logabsdet[d, :] atol=(T == Float32 ? 
1e-4 : 1e-6) + end + end + + @testset "Invertibility" begin + for d in 1:D + out, _ = unconstrained_rational_quadratic_spline( + inputs[d, :], + unnormalized_widths[d, :, :], + unnormalized_heights[d, :, :], + unnormalized_derivatives[d, :, :], + T(3.0) + ) + + back, _ = unconstrained_rational_quadratic_spline( + out, + unnormalized_widths[d, :, :], + unnormalized_heights[d, :, :], + unnormalized_derivatives[d, :, :], + T(3.0); + inverse=true + ) + + @test back ≈ inputs[d, :] atol=(T == Float32 ? 1e-4 : 1e-5) + end + end + + @testset "ForwardDiff Compatibility" begin + d = 1 + x = inputs[d, 1:5] + w = unnormalized_widths[d, 1:5, :] + h = unnormalized_heights[d, 1:5, :] + dv = unnormalized_derivatives[d, 1:5, :] + + f(x_in) = sum(unconstrained_rational_quadratic_spline(x_in, w, h, dv, T(3.0))[1]) + + g = ForwardDiff.gradient(f, x) + @test all(isfinite, g) + @test length(g) == 5 + end + + @testset "Zygote Compatibility" begin + d = 1 + x = inputs[d, 1:5] + w = unnormalized_widths[d, 1:5, :] + h = unnormalized_heights[d, 1:5, :] + dv = unnormalized_derivatives[d, 1:5, :] + + f(x_in) = sum(unconstrained_rational_quadratic_spline(x_in, w, h, dv, T(3.0))[1]) + + gz = Zygote.gradient(f, x)[1] + @test all(isfinite, gz) + + gf = ForwardDiff.gradient(f, x) + @test gz ≈ gf atol=T == Float32 ? 
1e-4 : 1e-10 + end + end + end +end From 2c870d8e99c525bcefb102e97f41cd6979c2a1e5 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 5 Mar 2026 08:28:19 -0500 Subject: [PATCH 03/14] Adding MAF --- examples/Manifest.toml | 2 +- examples/Project.toml | 2 + examples/train_multinormal_maf.jl | 111 ++++++++++++++++++++++++ src/SimpleFlows.jl | 9 +- src/distribution.jl | 4 +- src/generic_ops.jl | 84 +++++++++++++++++++ src/io.jl | 22 ++++- src/made.jl | 135 ++++++++++++++++++++++++++++++ src/maf.jl | 99 ++++++++++++++++++++++ src/nsf.jl | 45 ---------- src/realnvp.jl | 45 ---------- src/training.jl | 14 +++- test/runtests.jl | 2 + test/test_maf.jl | 106 +++++++++++++++++++++++ 14 files changed, 579 insertions(+), 101 deletions(-) create mode 100644 examples/train_multinormal_maf.jl create mode 100644 src/generic_ops.jl create mode 100644 src/made.jl create mode 100644 src/maf.jl create mode 100644 test/test_maf.jl diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 5e4edc0..46edf98 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.12.1" manifest_format = "2.0" -project_hash = "64c6eed7b4f2bf47e0f4350e25ac6a3f80fb5a0d" +project_hash = "3d2728b3089e56da155cbe07b6f794382965b85e" [[deps.ADTypes]] git-tree-sha1 = "f7304359109c768cf32dc5fa2d371565bb63b68a" diff --git a/examples/Project.toml b/examples/Project.toml index 5225e9e..7618a38 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,5 +2,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" SimpleFlows = "7aff1418-a6e2-48c2-ba03-ae8e32e98757" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/examples/train_multinormal_maf.jl b/examples/train_multinormal_maf.jl new file mode 100644 index 0000000..5e6f85d --- 
/dev/null +++ b/examples/train_multinormal_maf.jl @@ -0,0 +1,111 @@ +# examples/train_multinormal_maf.jl +using SimpleFlows +using Random +using Statistics +using Distributions, LinearAlgebra, Random, Statistics, Printf +using Turing +using Optimisers + +function run_maf_example() + # ── 1. Setup Data ────────────────────────────────────────────────────────── + rng = Random.default_rng() + Random.seed!(rng, 123) + + dist_dims = 4 + # Create a correlated 4D Gaussian + μ_true = [1.0, -0.5, 2.0, 0.3] + Σ_true = [1.0 0.5 0.2 0.1; + 0.5 1.0 0.4 0.2; + 0.2 0.4 1.0 0.5; + 0.1 0.2 0.5 1.0] + + true_dist = MvNormal(μ_true, Σ_true) + n_samples = 5000 + target_data = rand(rng, true_dist, n_samples) + + println("Target mean: ", round.(mean(target_data, dims=2)[:], digits=3)) + println("Target std: ", round.(std(target_data, dims=2)[:], digits=3)) + println() + + # ── 2. Initialize MAF Flow ─────────────────────────────────────────────── + # MAF is generally more expressive than RealNVP for same number of layers + n_transforms = 4 + hidden_layer_sizes = [32, 32] + + println("Training MAF ($n_transforms transforms, $(hidden_layer_sizes) units, 500 epochs)…") + + dist = FlowDistribution(Float32; + architecture=:MAF, + n_transforms=n_transforms, + dist_dims=dist_dims, + hidden_layer_sizes=hidden_layer_sizes, + rng=rng + ) + + # ── 3. Train ────────────────────────────────────────────────────────────── + # We use a smaller learning rate for MAF as it can be more sensitive + train_flow!(dist, target_data; n_epochs=500, batch_size=200, opt=Optimisers.Adam(1e-3)) + + # ── 4. 
Evaluate Density ────────────────────────────────────────────────── + test_data = rand(rng, true_dist, 1000) + lp_true = logpdf(true_dist, test_data) + lp_flow = logpdf(dist, test_data) + + println("\n── Density Fit ──────────────────────────────────────────") + println("Mean log-pdf (true distribution): ", round(mean(lp_true), digits=4)) + println("Mean log-pdf (trained MAF flow): ", round(mean(lp_flow), digits=4)) + println("Difference: ", round(abs(mean(lp_true) - mean(lp_flow)), digits=4)) + + samples = rand(rng, dist, 5000) + println("\nFlow sample mean: ", round.(mean(samples, dims=2)[:], digits=3)) + println("Flow sample std: ", round.(std(samples, dims=2)[:], digits=3)) + + # ── 5. Serialization ───────────────────────────────────────────────────── + save_path = joinpath(@__DIR__, "../trained_flows/maf_mvn_4d") + save_trained_flow(save_path, dist) + + # Reload and verify + dist_reloaded = load_trained_flow(save_path) + lp_reloaded = logpdf(dist_reloaded, test_data) + println("Mean log-pdf after reload: ", round(mean(lp_reloaded), digits=4)) + if isapprox(mean(lp_flow), mean(lp_reloaded), atol=1e-5) + println("Round-trip OK ✓") + else + println("Round-trip FAILED ✗") + end + + # ── 6. 
Turing.jl Integration ─────────────────────────────────────────── + println("\n── Turing Inference ─────────────────────────────────────────") + # Observed data: one sample from θ[1] with noise + θ_true = μ_true + y_obs = θ_true[1] + 0.1 * randn(rng) + println("Observed y: ", round(y_obs, digits=4), " (true θ[1] = ", θ_true[1], ")") + + @model function linear_model(y, prior_dist) + θ ~ prior_dist + y ~ Normal(θ[1], 0.1) + end + + # Sampling with exact prior + println("\nSampling with exact MvNormal prior (4 chains × 1000 samples)…") + chain_exact = sample(linear_model(y_obs, true_dist), HMC(0.1, 10), MCMCThreads(), 1000, 4; progress=false) + + # Sampling with MAF prior + println("Sampling with trained MAF prior (4 chains × 1000 samples)…") + chain_maf = sample(linear_model(y_obs, dist), HMC(0.1, 10), MCMCThreads(), 1000, 4; progress=false) + + println("\n── Posterior comparison (θ[1]) ──────────────────────────────") + exact_θ1 = vec(chain_exact[:, "θ[1]", :]) + maf_θ1 = vec(chain_maf[:, "θ[1]", :]) + + @printf " True θ[1]: %.4f\n" θ_true[1] + @printf " Observed y: %.4f\n" y_obs + @printf " Posterior mean (exact): %.4f ± %.4f\n" mean(exact_θ1) std(exact_θ1) + @printf " Posterior mean (MAF): %.4f ± %.4f\n" mean(maf_θ1) std(maf_θ1) + + println("\n✨ End of script: Turing MCMC sampling with MAF finished successfully!") +end + +if abspath(PROGRAM_FILE) == joinpath(@__DIR__, "train_multinormal_maf.jl") + run_maf_example() +end diff --git a/src/SimpleFlows.jl b/src/SimpleFlows.jl index 2d06472..390a688 100644 --- a/src/SimpleFlows.jl +++ b/src/SimpleFlows.jl @@ -9,13 +9,16 @@ using Optimisers, Zygote include("layers.jl") include("realnvp.jl") include("normalizer.jl") +include("splines.jl") +include("nsf.jl") +include("made.jl") +include("maf.jl") +include("generic_ops.jl") include("distribution.jl") include("training.jl") include("io.jl") -include("splines.jl") -include("nsf.jl") -export RealNVP, NeuralSplineFlow, FlowDistribution, NSFCouplingLayer +export RealNVP, 
NeuralSplineFlow, MaskedAutoregressiveFlow, FlowDistribution, NSFCouplingLayer export MinMaxNormalizer export train_flow!, save_trained_flow, load_trained_flow export unconstrained_rational_quadratic_spline diff --git a/src/distribution.jl b/src/distribution.jl index f09b02f..25b2a2a 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -45,8 +45,10 @@ function FlowDistribution(::Type{T}=Float32; RealNVP(; n_transforms, dist_dims, hidden_layer_sizes, activation) elseif architecture == :NSF NeuralSplineFlow(; n_transforms, dist_dims, hidden_layer_sizes, K, tail_bound, activation) + elseif architecture == :MAF + MaskedAutoregressiveFlow(; n_transforms, dist_dims, hidden_layer_sizes, activation) else - error("Unknown architecture: $architecture. Supported: :RealNVP, :NSF") + error("Unknown architecture: $architecture. Supported: :RealNVP, :NSF, :MAF") end ps, st = Lux.setup(rng, model) diff --git a/src/generic_ops.jl b/src/generic_ops.jl new file mode 100644 index 0000000..2d59b7a --- /dev/null +++ b/src/generic_ops.jl @@ -0,0 +1,84 @@ +# src/generic_ops.jl + +""" + log_prob(model, ps, st, x) -> Vector + +Compute per-sample log-probability of `x` (shape: dist_dims × batch) under the flow. +Pure functional — no mutations, safe for Zygote. +Supports RealNVP, NeuralSplineFlow, and MaskedAutoregressiveFlow. +""" +function log_prob(model::Union{RealNVP, NeuralSplineFlow, MaskedAutoregressiveFlow}, ps, st, x::AbstractMatrix) + lp = nothing + for i in model.n_transforms:-1:1 + ks = model isa MaskedAutoregressiveFlow ? 
keys(model.mades) : keys(model.conditioners) + k = ks[i] + + if model isa MaskedAutoregressiveFlow + bj = MAFBijector(model.mades[k], ps.mades[k], st.mades[k]) + # Inverse is x -> u, which is what we need for log_prob + x, ld = Bijectors.with_logabsdet_jacobian(bj, x) + else + mask = st.mask_list[i] + cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], + s = st.conditioners[k] + x_cond -> Lux.apply(m, x_cond, p, s)[1] + end + + bj = if model isa RealNVP + MaskedCoupling(mask, cond_fn, AffineBijector) + else + MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, mask, model.K, model.tail_bound)) + end + x, ld = inverse_and_log_det(bj, x) + end + + lp = isnothing(lp) ? ld : lp .+ ld + + # Apply ReversePermute between MAF blocks + if model isa MaskedAutoregressiveFlow && i > 1 + x = x[end:-1:1, :] + end + end + base_lp = dsum(gaussian_logpdf.(x); dims=(1,)) + return isnothing(lp) ? base_lp : lp .+ base_lp +end + +""" + draw_samples(rng, T, model, ps, st, n_samples) -> Matrix + +Sample from the flow by pushing Gaussian noise through the forward transforms. +Supports RealNVP, NeuralSplineFlow, and MaskedAutoregressiveFlow. +""" +function draw_samples(rng::AbstractRNG, ::Type{T}, model::Union{RealNVP, NeuralSplineFlow, MaskedAutoregressiveFlow}, + ps, st, n_samples::Int) where T + x = randn(rng, T, model.dist_dims, n_samples) + for i in 1:(model.n_transforms) + ks = model isa MaskedAutoregressiveFlow ? 
keys(model.mades) : keys(model.conditioners) + k = ks[i] + + if model isa MaskedAutoregressiveFlow + bj = MAFBijector(model.mades[k], ps.mades[k], st.mades[k]) + x, _ = forward_and_log_det(bj, x) + else + mask = st.mask_list[i] + cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], + s = st.conditioners[k] + x_cond -> Lux.apply(m, x_cond, p, s)[1] + end + + bj = if model isa RealNVP + MaskedCoupling(mask, cond_fn, AffineBijector) + else + MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, mask, model.K, model.tail_bound)) + end + + x, _ = forward_and_log_det(bj, x) + end + + # Apply ReversePermute between MAF blocks + if model isa MaskedAutoregressiveFlow && i < model.n_transforms + x = x[end:-1:1, :] + end + end + return x +end diff --git a/src/io.jl b/src/io.jl index 8280860..6004068 100644 --- a/src/io.jl +++ b/src/io.jl @@ -34,8 +34,18 @@ end # ── Architecture dict ───────────────────────────────────────────────────────── function _flow_to_dict(flow::FlowDistribution) + arch = if flow.model isa RealNVP + "RealNVP" + elseif flow.model isa NeuralSplineFlow + "NSF" + elseif flow.model isa MaskedAutoregressiveFlow + "MAF" + else + error("Unknown architecture type: $(typeof(flow.model))") + end + d = Dict( - "architecture" => (flow.model isa RealNVP ? "RealNVP" : "NSF"), + "architecture" => arch, "n_transforms" => flow.model.n_transforms, "dist_dims" => flow.n_dims, "hidden_layer_sizes" => flow.hidden_layer_sizes, @@ -50,7 +60,15 @@ end function _build_flow_from_dict(d::AbstractDict, ::Type{T}=Float32, rng::AbstractRNG=Random.default_rng()) where {T<:Real} arch = d["architecture"] - arch_sym = (arch == "RealNVP" ? 
:RealNVP : :NSF) + arch_sym = if arch == "RealNVP" + :RealNVP + elseif arch == "NSF" + :NSF + elseif arch == "MAF" + :MAF + else + error("Unknown architecture: $arch") + end # Support both new (hidden_layer_sizes) and legacy (hidden_dims + n_layers) formats hidden_layer_sizes = if haskey(d, "hidden_layer_sizes") diff --git a/src/made.jl b/src/made.jl new file mode 100644 index 0000000..0e82a54 --- /dev/null +++ b/src/made.jl @@ -0,0 +1,135 @@ +# src/made.jl +using Lux +using Random +using Statistics + +""" + create_degrees(input_dims::Int, hidden_dims::Vector{Int}, out_dims::Int) + +Generate degrees for each unit in a MADE network to ensure autoregressive property. +Returns a list of degree vectors (one for input, each hidden layer, and output). +""" +function create_degrees(input_dims::Int, hidden_dims::Vector{Int}, out_dims::Int; sequential::Bool=true) + # Degrees for input: 1 to D + degrees = [collect(1:input_dims)] + + # Degrees for hidden layers + for h_dim in hidden_dims + if sequential + # Use a deterministic pattern: repeat 1 to D-1 + d_hidden = [((i - 1) % (input_dims - 1)) + 1 for i in 1:h_dim] + push!(degrees, d_hidden) + else + # Sample hidden degrees in [1, D-1] + push!(degrees, rand(1:(input_dims-1), h_dim)) + end + end + + # Degrees for output: same as input (usually out_dims = input_dims or 2*input_dims) + # If out_dims is a multiple of input_dims, we repeat the degrees + if out_dims % input_dims == 0 + n_repeats = out_dims ÷ input_dims + push!(degrees, repeat(collect(1:input_dims), n_repeats)) + else + # Fallback for arbitrary output sizes (though usually it's D or 2D) + push!(degrees, collect(1:out_dims)) + end + + return degrees +end + +""" + create_masks(degrees::Vector{Vector{Int}}) + +Create binary masks from a list of degree vectors. 
+""" +function create_masks(degrees::Vector{Vector{Int}}) + masks = Matrix{Float32}[] + # Layer l goes from degrees[l] to degrees[l+1] + for i in 1:(length(degrees) - 1) + m_curr = degrees[i] + m_next = degrees[i+1] + + # Matrix shape is (length(m_next), length(m_curr)) + mask = zeros(Float32, length(m_next), length(m_curr)) + + for r in 1:length(m_next) + for c in 1:length(m_curr) + if i < length(degrees)-1 + # Hidden layers: weight exists if degree(prev) <= degree(curr) + if m_next[r] >= m_curr[c] + mask[r, c] = 1.0f0 + end + else + # Output layer: weight exists if degree(prev) < degree(curr) + # This ensures y_i depends only on x_{<i} + if m_next[r] > m_curr[c] + mask[r, c] = 1.0f0 + end + end + end + end + push!(masks, mask) + end + return masks +end + +# ── MaskedDense Layer ──────────────────────────────────────────────────────── + +struct MaskedDense{F} <: Lux.AbstractLuxLayer + activation::F + in_dims::Int + out_dims::Int + mask::Matrix{Float32} +end + +function MaskedDense(in_dims::Int, out_dims::Int, mask::Matrix{Float32}, activation=identity) + return MaskedDense(activation, in_dims, out_dims, mask) +end + +function Lux.initialparameters(rng::AbstractRNG, l::MaskedDense) + return ( + weight = Lux.glorot_uniform(rng, l.out_dims, l.in_dims), + bias = zeros(Float32, l.out_dims) + ) +end + +function Lux.initialstates(rng::AbstractRNG, l::MaskedDense) + return NamedTuple() +end + +function (l::MaskedDense)(x::AbstractArray, ps, st) + # y = (W .* mask) * x + b + # We use ps.weight .* l.mask for differentiability + y = (ps.weight .* l.mask) * x .+ ps.bias + return l.activation.(y), st +end + +# ── MADE Network ──────────────────────────────────────────────────────────── + +struct MADE <: Lux.AbstractLuxContainerLayer{(:layers,)} + layers::Lux.Chain + input_dims::Int + hidden_dims::Vector{Int} + out_dims::Int +end + +function MADE(input_dims::Int, hidden_dims::Vector{Int}, out_dims::Int; + activation=relu, sequential::Bool=true) + degrees = create_degrees(input_dims, hidden_dims, 
out_dims; sequential) + masks = create_masks(degrees) + + layers_list = [] + for i in 1:length(masks) + in_d = size(masks[i], 2) + out_d = size(masks[i], 1) + act = (i == length(masks)) ? identity : activation + push!(layers_list, MaskedDense(in_d, out_d, masks[i], act)) + end + + return MADE(Lux.Chain(layers_list...), input_dims, hidden_dims, out_dims) +end + +function (l::MADE)(x::AbstractArray, ps, st) + return Lux.apply(l.layers, x, ps.layers, st.layers) +end diff --git a/src/maf.jl b/src/maf.jl new file mode 100644 index 0000000..e4b67ca --- /dev/null +++ b/src/maf.jl @@ -0,0 +1,99 @@ +# src/maf.jl +using Lux +using Bijectors +using Random +using LinearAlgebra + +# ── MAF Bijector ──────────────────────────────────────────────────────────── + +struct MAFBijector + made + ps + st +end + +""" + inverse_and_log_det(b::MAFBijector, x::AbstractArray) + +Density estimation pass (Inverse): u = (x - m) * exp(-alpha). Fast O(1). +""" +function Bijectors.with_logabsdet_jacobian(b::MAFBijector, x::AbstractMatrix) + # MADE(x) returns (m, log_alpha) concatenated + out, _ = Lux.apply(b.made, x, b.ps, b.st) + D = size(x, 1) + m = out[1:D, :] + log_alpha = out[D+1:end, :] + + # Inverse transform + u = (x .- m) .* exp.(-log_alpha) + + # Log-determinant: sum of -log_alpha across dimensions + lad = -sum(log_alpha; dims=1) + + # Return as (output, logabsdet) + return u, vec(lad) +end + +""" + forward_and_log_det(b::MAFBijector, u::AbstractArray) + +Sampling pass (Forward): x_i = u_i * exp(alpha_i) + m_i. Sequential O(D). +""" +function forward_and_log_det(b::MAFBijector, u::AbstractMatrix) + D, N = size(u) + T = eltype(u) + + # Initialize x with zeros. We will fill it dimension by dimension. + # To be Zygote friendly, we can use a loop and vcat/hcat or just use a copy if allowed. 
+ # Actually, sampling is usually not the target of AD during training, but for completeness: + + x = zeros(T, D, N) + + for i in 1:D + # Compute parameters for the current state of x + out, _ = Lux.apply(b.made, x, b.ps, b.st) + m_i = out[i, :] + log_alpha_i = out[D+i, :] + + # Update row i of x + # x[i, :] = u[i, :] .* exp.(log_alpha_i) .+ m_i + # Non-mutating version for Zygote: + new_row = u[i:i, :] .* exp.(log_alpha_i') .+ m_i' + x = vcat(x[1:i-1, :], new_row, x[i+1:end, :]) + end + + # Compute final log_alpha for the fully constructed x to get log_det + out_final, _ = Lux.apply(b.made, x, b.ps, b.st) + log_alpha_final = out_final[D+1:end, :] + lad = sum(log_alpha_final; dims=1) + + return x, vec(lad) +end + +# ── ReversePermute ────────────────────────────────────────────────────────── + +struct ReversePermute <: Lux.AbstractLuxLayer end + +function (l::ReversePermute)(x::AbstractArray, ps, st) + return x[end:-1:1, :], st +end + +# ── MaskedAutoregressiveFlow ─────────────────────────────────────────────── + +@concrete struct MaskedAutoregressiveFlow <: Lux.AbstractLuxContainerLayer{(:mades,)} + mades + dist_dims::Int + n_transforms::Int + hidden_layer_sizes::Vector{Int} +end + +function MaskedAutoregressiveFlow(; n_transforms::Int, dist_dims::Int, hidden_layer_sizes::Vector{Int}, activation=relu) + mades_list = [MADE(dist_dims, hidden_layer_sizes, 2 * dist_dims; activation) for _ in 1:n_transforms] + keys_ = ntuple(i -> Symbol(:made_, i), n_transforms) + mades = NamedTuple{keys_}(Tuple(mades_list)) + return MaskedAutoregressiveFlow(mades, dist_dims, n_transforms, hidden_layer_sizes) +end + +function Lux.initialstates(rng::AbstractRNG, m::MaskedAutoregressiveFlow) + return (mades = Lux.initialstates(rng, m.mades),) +end diff --git a/src/nsf.jl b/src/nsf.jl index 8643844..b16af8f 100644 --- a/src/nsf.jl +++ b/src/nsf.jl @@ -150,51 +150,6 @@ function Lux.initialstates(rng::AbstractRNG, m::NeuralSplineFlow) return (; mask_list, 
conditioners=Lux.initialstates(rng, m.conditioners)) end -# Generic log_prob and draw_samples -function log_prob(model::Union{RealNVP, NeuralSplineFlow}, ps, st, x::AbstractMatrix) - lp = nothing - for i in model.n_transforms:-1:1 - k = keys(model.conditioners)[i] - mask = st.mask_list[i] - cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], - s = st.conditioners[k] - x_cond -> Lux.apply(m, x_cond, p, s)[1] - end - - bj = if model isa RealNVP - MaskedCoupling(mask, cond_fn, AffineBijector) - else - MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, mask, model.K, model.tail_bound)) - end - - x, ld = inverse_and_log_det(bj, x) - lp = isnothing(lp) ? ld : lp .+ ld - end - base_lp = dsum(gaussian_logpdf.(x); dims=(1,)) - return isnothing(lp) ? base_lp : lp .+ base_lp -end - -function draw_samples(rng::AbstractRNG, ::Type{T}, model::Union{RealNVP, NeuralSplineFlow}, - ps, st, n_samples::Int) where T - x = randn(rng, T, model.dist_dims, n_samples) - for i in 1:(model.n_transforms) - k = keys(model.conditioners)[i] - mask = st.mask_list[i] - cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], - s = st.conditioners[k] - x_cond -> Lux.apply(m, x_cond, p, s)[1] - end - - bj = if model isa RealNVP - MaskedCoupling(mask, cond_fn, AffineBijector) - else - MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, mask, model.K, model.tail_bound)) - end - - x, _ = forward_and_log_det(bj, x) - end - return x -end # Helper to build the bijector from flat parameters struct NSFSplineConstructor diff --git a/src/realnvp.jl b/src/realnvp.jl index 1c4fac1..135770e 100644 --- a/src/realnvp.jl +++ b/src/realnvp.jl @@ -40,48 +40,3 @@ function Lux.initialstates(rng::AbstractRNG, m::RealNVP) for i in 1:(m.n_transforms)] return (; mask_list, conditioners=Lux.initialstates(rng, m.conditioners)) end - -# ── Pure-functional log-prob and sample ────────────────────────────────────── - -""" - log_prob(model, ps, st, x) -> Vector - -Compute 
per-sample log-probability of `x` (shape: dist_dims × batch) under the flow. -Pure functional — no mutations, safe for Zygote. -""" -function log_prob(model::RealNVP, ps, st, x::AbstractMatrix) - lp = nothing - for i in model.n_transforms:-1:1 - k = keys(model.conditioners)[i] - mask = st.mask_list[i] - cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], - s = st.conditioners[k] - x_cond -> Lux.apply(m, x_cond, p, s)[1] - end - bj = MaskedCoupling(mask, cond_fn, AffineBijector) - x, ld = inverse_and_log_det(bj, x) - lp = isnothing(lp) ? ld : lp .+ ld - end - base_lp = dsum(gaussian_logpdf.(x); dims=(1,)) - return isnothing(lp) ? base_lp : lp .+ base_lp -end - -""" - draw_samples(rng, T, model, ps, st, n_samples) -> Matrix - -Sample from the flow by pushing Gaussian noise through the forward transforms. -""" -function draw_samples(rng::AbstractRNG, ::Type{T}, model::RealNVP, - ps, st, n_samples::Int) where T - x = randn(rng, T, model.dist_dims, n_samples) - for i in 1:(model.n_transforms) - k = keys(model.conditioners)[i] - cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], - s = st.conditioners[k] - x_cond -> Lux.apply(m, x_cond, p, s)[1] - end - bj = MaskedCoupling(st.mask_list[i], cond_fn, AffineBijector) - x, _ = forward_and_log_det(bj, x) - end - return x -end diff --git a/src/training.jl b/src/training.jl index c4b55f5..5348bdf 100644 --- a/src/training.jl +++ b/src/training.jl @@ -16,15 +16,21 @@ negative log-likelihood on the normalised data. 
""" function train_flow!(flow::FlowDistribution{T}, data::AbstractMatrix; n_epochs::Int=1000, - lr::Real=T(1f-3), + lr::Union{Nothing, Real}=nothing, batch_size::Int=256, - verbose::Bool=true) where {T} + verbose::Bool=true, + opt=nothing) where {T} # Always fit and apply a min-max normalizer flow.normalizer = MinMaxNormalizer(T.(data)) data_T = normalize(flow.normalizer, T.(data)) - opt = Optimisers.OptimiserChain(Optimisers.ClipGrad(T(1)), Optimisers.Adam(T(lr))) - opt_state = Optimisers.setup(opt, flow.ps) + actual_opt = if isnothing(opt) + actual_lr = isnothing(lr) ? T(1f-3) : T(lr) + Optimisers.OptimiserChain(Optimisers.ClipGrad(T(1)), Optimisers.Adam(actual_lr)) + else + opt + end + opt_state = Optimisers.setup(actual_opt, flow.ps) loader = DataLoader(data_T; batchsize=batch_size, shuffle=true) diff --git a/test/runtests.jl b/test/runtests.jl index c70ac25..7c8e931 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test using SimpleFlows using Random, Distributions, LinearAlgebra +using Lux, Zygote, Bijectors, ForwardDiff rng = Random.MersenneTwister(42) @@ -11,4 +12,5 @@ rng = Random.MersenneTwister(42) include("test_normalizer.jl") include("test_splines.jl") include("test_nsf.jl") + include("test_maf.jl") end diff --git a/test/test_maf.jl b/test/test_maf.jl new file mode 100644 index 0000000..d43073d --- /dev/null +++ b/test/test_maf.jl @@ -0,0 +1,106 @@ +using SimpleFlows +using Test +using Random +using Lux +using ForwardDiff +using Zygote +using Bijectors +using Statistics +using LinearAlgebra + +@testset "MAF Components" begin + rng = Random.default_rng() + Random.seed!(rng, 123) + + D = 4 + hidden = [16, 16] + + @testset "MADE Autoregressive Property" begin + # Create a MADE network + made = SimpleFlows.MADE(D, hidden, 2*D) + ps, st = Lux.setup(rng, made) + + # We check the Jacobian of the outputs with respect to the inputs. + # For an autoregressive model, the Jacobian of y_i w.r.t x_j must be zero if j >= i. 
+ # Since we have 2 outputs per input (m and log_alpha), we check both blocks. + + function f(x) + out, _ = Lux.apply(made, x, ps, st) + return out + end + + x0 = randn(Float32, D) + J = ForwardDiff.jacobian(f, x0) # (2*D, D) + + # Block 1: m (1:D, 1:D) + J_m = J[1:D, :] + @test isapprox(tril(J_m, -1), J_m, atol=1e-7) + + # Block 2: log_alpha (D+1:2D, 1:D) + J_la = J[D+1:end, :] + @test isapprox(tril(J_la, -1), J_la, atol=1e-7) + end + + @testset "Differentiability" begin + made = SimpleFlows.MADE(D, hidden, 2*D) + ps, st = Lux.setup(rng, made) + x = randn(Float32, D, 5) + + # Zygote test + gs = Zygote.gradient(ps) do p + out, _ = Lux.apply(made, x, p, st) + sum(out) + end + @test gs[1] !== nothing + + # ForwardDiff test + function g(p_vec) + # Reconstruct NamedTuple structure if possible, but easier to test w.r.t x + return nothing + end + + gx = ForwardDiff.gradient(x -> sum(Lux.apply(made, x, ps, st)[1]), x[:, 1]) + @test all(isfinite, gx) + end + + @testset "MAFBijector Invertibility" begin + made = SimpleFlows.MADE(D, hidden, 2*D) + ps, st = Lux.setup(rng, made) + bj = SimpleFlows.MAFBijector(made, ps, st) + + x = randn(Float32, D, 4) + + # Inverse (Density pass) + u, lad_inv = Bijectors.with_logabsdet_jacobian(bj, x) + + # Forward (Sampling pass) + x_rec, lad_fwd = SimpleFlows.forward_and_log_det(bj, u) + + @test x ≈ x_rec atol=1e-4 + @test lad_inv ≈ -lad_fwd atol=1e-4 + end + + @testset "Integrated MAF Model" begin + dist = FlowDistribution(Float32; + architecture=:MAF, + n_transforms=2, + dist_dims=D, + hidden_dims=16, + n_layers=2 + ) + + x = randn(Float32, D, 10) + lp = logpdf(dist, x) + @test length(lp) == 10 + @test all(isfinite, lp) + + samples = rand(rng, dist, 10) + @test size(samples) == (D, 10) + @test all(isfinite, samples) + + # Training smoke test + data = randn(Float32, D, 100) + train_flow!(dist, data; n_epochs=1, batch_size=50) + @test !isnothing(dist.normalizer) + end +end From 6cfaff566f724413dc4d50b5fef2ac3116803e83 Mon Sep 17 00:00:00 
2001 From: marcobonici Date: Thu, 5 Mar 2026 08:39:01 -0500 Subject: [PATCH 04/14] Update README.md --- README.md | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 771b521..a50a463 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,10 @@ distributions inside [`Turing.jl`](https://turinglang.org/) probabilistic progra ## Features -- **RealNVP** (Real-valued Non-Volume Preserving) architecture with affine coupling layers +- **Architectures**: + - **RealNVP** (Real-valued Non-Volume Preserving) + - **NSF** (Neural Spline Flow with Rational Quadratic Splines) + - **MAF** (Masked Autoregressive Flow with MADE) - Full `Distributions.jl` interface: `logpdf`, `rand`, `length` - `Bijectors.jl` compatible → plug directly into Turing models as a prior - Training via `Optimisers.Adam` + `Zygote.jl` autodiff, mini-batched with `MLUtils` @@ -15,10 +18,10 @@ distributions inside [`Turing.jl`](https://turinglang.org/) probabilistic progra ## Quick Start ```julia -using SimpleFlows, Distributions +using SimpleFlows, Distributions, LinearAlgebra -# 1. Build a 4-dim flow -flow = FlowDistribution(; n_transforms=6, dist_dims=4, hidden_dims=64, n_layers=3) +# 1. Build a 4-dim flow (options: :RealNVP, :NSF, :MAF) +flow = FlowDistribution(Float32; architecture=:RealNVP, n_transforms=6, dist_dims=4, hidden_layer_sizes=[64, 64, 64]) # 2. Sample training data from your target distribution data = Float32.(rand(MvNormal(zeros(4), I), 10_000)) @@ -60,8 +63,8 @@ my_flow/ | Architecture | Status | |---|---| | RealNVP | ✅ Done | -| MAF | 📋 Planned | -| NSF | 📋 Planned | +| MAF | ✅ Done | +| NSF | ✅ Done | ## Running Tests @@ -69,8 +72,10 @@ my_flow/ julia --project=. -e "using Pkg; Pkg.test()" ``` -## Running the Example +## Running the Examples ```bash julia --project=. examples/train_multinormal.jl +julia --project=. examples/train_multinormal_nsf.jl +julia --project=. 
examples/train_multinormal_maf.jl ``` From 2af05c7436330e856b469e0fd843a39411ed4353 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 5 Mar 2026 11:20:02 -0500 Subject: [PATCH 05/14] Small changes; more tests --- src/SimpleFlows.jl | 4 +- src/distribution.jl | 8 ++-- src/generic_ops.jl | 4 +- src/io.jl | 2 +- src/layers.jl | 31 ++++++++---- src/normalizer.jl | 36 +++++++++----- src/nsf.jl | 104 ++++++++++++++-------------------------- src/realnvp.jl | 22 ++++++--- src/training.jl | 7 +-- test/runtests.jl | 3 +- test/test_io.jl | 76 +++++++++++++++++------------ test/test_layers.jl | 2 +- test/test_normalizer.jl | 19 ++++++++ test/test_training.jl | 54 +++++++++++++++++++++ 14 files changed, 232 insertions(+), 140 deletions(-) create mode 100644 test/test_training.jl diff --git a/src/SimpleFlows.jl b/src/SimpleFlows.jl index 390a688..2135773 100644 --- a/src/SimpleFlows.jl +++ b/src/SimpleFlows.jl @@ -18,9 +18,9 @@ include("distribution.jl") include("training.jl") include("io.jl") -export RealNVP, NeuralSplineFlow, MaskedAutoregressiveFlow, FlowDistribution, NSFCouplingLayer +export RealNVP, NeuralSplineFlow, MaskedAutoregressiveFlow, FlowDistribution export MinMaxNormalizer -export train_flow!, save_trained_flow, load_trained_flow +export train_flow!, save_trained_flow, load_trained_flow, normalize, denormalize export unconstrained_rational_quadratic_spline end diff --git a/src/distribution.jl b/src/distribution.jl index 25b2a2a..e4cc8a2 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -62,12 +62,12 @@ Distributions.length(d::FlowDistribution) = d.n_dims function _apply_normalizer(d::FlowDistribution{T}, x::AbstractMatrix{<:Real}) where {T} isnothing(d.normalizer) && return x, zero(T) - return normalize(d.normalizer, x), d.normalizer.log_jac + return SimpleFlows.normalize(d.normalizer, x), d.normalizer.log_jac end function _apply_normalizer(d::FlowDistribution{T}, x::AbstractVector{<:Real}) where {T} isnothing(d.normalizer) && return x, zero(T) - 
return normalize(d.normalizer, x), d.normalizer.log_jac + return SimpleFlows.normalize(d.normalizer, x), d.normalizer.log_jac end function Distributions.logpdf(d::FlowDistribution, x::AbstractVector{<:Real}) @@ -85,14 +85,14 @@ end function Base.rand(rng::AbstractRNG, d::FlowDistribution{T}) where {T} z = draw_samples(rng, T, d.model, d.ps, d.st, 1) - x = isnothing(d.normalizer) ? z : denormalize(d.normalizer, z) + x = isnothing(d.normalizer) ? z : SimpleFlows.denormalize(d.normalizer, z) return T.(vec(x)) end function Distributions.rand(rng::AbstractRNG, d::FlowDistribution{T}, n::Int) where {T} z = draw_samples(rng, T, d.model, d.ps, d.st, n) isnothing(d.normalizer) && return z - return denormalize(d.normalizer, z) + return SimpleFlows.denormalize(d.normalizer, z) end # ── Bijectors.jl interface ─────────────────────────────────────────────────── diff --git a/src/generic_ops.jl b/src/generic_ops.jl index 2d59b7a..097ea63 100644 --- a/src/generic_ops.jl +++ b/src/generic_ops.jl @@ -27,7 +27,7 @@ function log_prob(model::Union{RealNVP, NeuralSplineFlow, MaskedAutoregressiveFl bj = if model isa RealNVP MaskedCoupling(mask, cond_fn, AffineBijector) else - MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, mask, model.K, model.tail_bound)) + MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, model.K, model.tail_bound)) end x, ld = inverse_and_log_det(bj, x) end @@ -69,7 +69,7 @@ function draw_samples(rng::AbstractRNG, ::Type{T}, model::Union{RealNVP, NeuralS bj = if model isa RealNVP MaskedCoupling(mask, cond_fn, AffineBijector) else - MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, mask, model.K, model.tail_bound)) + MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, model.K, model.tail_bound)) end x, _ = forward_and_log_det(bj, x) diff --git a/src/io.jl b/src/io.jl index 6004068..33a0d90 100644 --- a/src/io.jl +++ b/src/io.jl @@ -151,7 +151,7 @@ function load_trained_flow(path::String; 
rng::AbstractRNG=Random.default_rng()) xmin = flat["normalizer_xmin"] xmax = flat["normalizer_xmax"] T_norm = eltype(xmin) - flow.normalizer = MinMaxNormalizer(xmin, xmax, sum(-log.(xmax .- xmin))) + flow.normalizer = MinMaxNormalizer{T_norm}(xmin, xmax, T_norm(sum(-log.(xmax .- xmin)))) delete!(flat, "normalizer_xmin") delete!(flat, "normalizer_xmax") end diff --git a/src/layers.jl b/src/layers.jl index 605052d..e5265d9 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -49,19 +49,32 @@ masked dimensions are transformed. bijector_constructor end -function _apply_mask(bj::MaskedCoupling, x::AbstractArray, transform_fn) - x_cond = x .* .!bj.mask # pass-through dims → conditioner input - params = bj.conditioner(x_cond) - y, log_det = transform_fn(params) - log_det = log_det .* bj.mask # only masked dims contribute to log|J| - y = ifelse.(bj.mask, y, x) # pass through unmasked dims unchanged - return y, dsum(log_det; dims=Tuple(1:(ndims(x) - 1))) +function _apply_mask(bj::MaskedCoupling, x::AbstractMatrix, transform_fn) + D, N = size(x) + m = bj.mask + + # 1. Conditioning + x_cond = x .* .!m + params = bj.conditioner(x_cond) + + # 2. Transform the active dims only + x_tr = x[m, :] + bj_inner = bj.bijector_constructor(params) + y_tr, ld_tr = transform_fn(bj_inner, x_tr) + + # 3. Reconstruct full y + # We use a Zygote-friendly reconstruction + tr_idx = cumsum(m) + y = vcat([m[i] ? y_tr[tr_idx[i]:tr_idx[i], :] : x[i:i, :] for i in 1:D]...) 
+ + # sum log-dets over the transformed dims + return y, dsum(ld_tr; dims=(1,)) end function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray) - _apply_mask(bj, x, p -> forward_and_log_det(bj.bijector_constructor(p), x)) + _apply_mask(bj, x, forward_and_log_det) end function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray) - _apply_mask(bj, y, p -> inverse_and_log_det(bj.bijector_constructor(p), y)) + _apply_mask(bj, y, inverse_and_log_det) end diff --git a/src/normalizer.jl b/src/normalizer.jl index 43e1ba9..d34c56a 100644 --- a/src/normalizer.jl +++ b/src/normalizer.jl @@ -21,11 +21,17 @@ end Fit a `MinMaxNormalizer` from training data (shape `n_dims × n_samples`). """ -function MinMaxNormalizer(data::AbstractMatrix{T}) where T - xmin = vec(minimum(data; dims=2)) - xmax = vec(maximum(data; dims=2)) - log_jac = sum(-log.(xmax .- xmin)) - return MinMaxNormalizer{T}(xmin, xmax, log_jac) +function MinMaxNormalizer(x::AbstractMatrix{T}) where {T} + x_min = vec(minimum(x, dims=2)) + x_max = vec(maximum(x, dims=2)) + + if any(x_min .≈ x_max) + throw(ArgumentError("Data has zero variance along one or more dimensions. Cannot initialize MinMaxNormalizer.")) + end + + # Base volume change: sum(-log(x_max - x_min)) + logabsdet = sum(-log.(x_max .- x_min)) + return MinMaxNormalizer{T}(x_min, x_max, logabsdet) end """ @@ -33,19 +39,23 @@ end Apply the forward transform: `z = (x - x_min) / (x_max - x_min)`. """ -normalize(n::MinMaxNormalizer, x::AbstractMatrix) = - (x .- n.x_min) ./ (n.x_max .- n.x_min) +function normalize(n::MinMaxNormalizer, x::AbstractMatrix) + return (x .- n.x_min) ./ (n.x_max .- n.x_min) +end -normalize(n::MinMaxNormalizer, x::AbstractVector) = - (x .- n.x_min) ./ (n.x_max .- n.x_min) +function normalize(n::MinMaxNormalizer, x::AbstractVector) + return (x .- n.x_min) ./ (n.x_max .- n.x_min) +end """ denormalize(n::MinMaxNormalizer, z) -> x Apply the inverse transform: `x = z * (x_max - x_min) + x_min`. 
""" -denormalize(n::MinMaxNormalizer, z::AbstractMatrix) = - z .* (n.x_max .- n.x_min) .+ n.x_min +function denormalize(n::MinMaxNormalizer, z::AbstractMatrix) + return z .* (n.x_max .- n.x_min) .+ n.x_min +end -denormalize(n::MinMaxNormalizer, z::AbstractVector) = - z .* (n.x_max .- n.x_min) .+ n.x_min +function denormalize(n::MinMaxNormalizer, z::AbstractVector) + return z .* (n.x_max .- n.x_min) .+ n.x_min +end diff --git a/src/nsf.jl b/src/nsf.jl index b16af8f..53ae8c2 100644 --- a/src/nsf.jl +++ b/src/nsf.jl @@ -5,32 +5,26 @@ using Random using LinearAlgebra """ - NSFCouplingLayer(mask, conditioner; K=8, tail_bound=3.0) + NSFSplineBijector(mask, params, K, tail_bound) A Neural Spline Flow (NSF) coupling layer using Rational Quadratic Splines. `mask` is a binary vector (1 for variables to transform, 0 for variables that condition). -`conditioner` is a Lux network that takes variables with 0 and returns spline parameters for variables with 1. +`params` is the output of the conditioner for the transformed dimensions. 
""" # Bijectors and Layers for NSF struct NSFSplineBijector - mask params K::Int tail_bound::Float64 end function forward_and_log_det(b::NSFSplineBijector, x::AbstractArray) - # x is (D, N) - D, N = size(x) - mask = b.mask + # x is (D_tr, N) + D_tr, N = size(x) K = b.K tail_bound = b.tail_bound params = b.params - # x_tr: variables to be transformed - x_tr = x[mask, :] - D_tr = size(x_tr, 1) - # Reshape params to (D_tr, 3K-1, N) params = reshape(params, D_tr, 3*K - 1, N) @@ -40,79 +34,51 @@ function forward_and_log_det(b::NSFSplineBijector, x::AbstractArray) dv_unnorm = params[:, 2*K+1:end, :] # Flatten everything to call the spline function - x_tr_flat = vec(x_tr) + x_flat = vec(x) w_flat = reshape(permutedims(w_unnorm, (1, 3, 2)), D_tr * N, K) h_flat = reshape(permutedims(h_unnorm, (1, 3, 2)), D_tr * N, K) dv_flat = reshape(permutedims(dv_unnorm, (1, 3, 2)), D_tr * N, K-1) - y_tr_flat, lad_flat = unconstrained_rational_quadratic_spline( - x_tr_flat, w_flat, h_flat, dv_flat, eltype(x)(tail_bound) + y_flat, lad_flat = unconstrained_rational_quadratic_spline( + x_flat, w_flat, h_flat, dv_flat, eltype(x)(tail_bound) ) - y_tr = reshape(y_tr_flat, D_tr, N) - lad_tr = reshape(lad_flat, D_tr, N) - - # We yield the full transformed y (only for masked dims) - # The MaskedCoupling logic will handle the identity part. - # But wait, MaskedCoupling expects the bijector to return an array of same size as x? - # No, MaskedCoupling: y, log_det = transform_fn(params) - # Then y = ifelse.(bj.mask, y, x) - # So y MUST have same size as x. - - # Reconstruction using comprehension (Zygote friendly) - # We need to find the index into y_tr for each masked dimension. - # tr_indices[i] will be the row index in y_tr if mask[i] is true. - tr_indices = cumsum(mask) + y = reshape(y_flat, D_tr, N) + lad = reshape(lad_flat, D_tr, N) - y = vcat([mask[i] ? y_tr[tr_indices[i]:tr_indices[i], :] : x[i:i, :] for i in 1:D]...) - log_det = vcat([mask[i] ? 
lad_tr[tr_indices[i]:tr_indices[i], :] : fill(zero(eltype(x)), 1, N) for i in 1:D]...) - - return y, log_det - - return y, log_det + return y, lad end function inverse_and_log_det(b::NSFSplineBijector, y::AbstractArray) - # y is (D, N) - D, N = size(y) - mask = b.mask + # y is (D_tr, N) + D_tr, N = size(y) K = b.K tail_bound = b.tail_bound params = b.params - y_tr = y[mask, :] - D_tr = size(y_tr, 1) - params = reshape(params, D_tr, 3*K - 1, N) w_unnorm = params[:, 1:K, :] h_unnorm = params[:, K+1:2*K, :] dv_unnorm = params[:, 2*K+1:end, :] - y_tr_flat = vec(y_tr) + y_flat = vec(y) w_flat = reshape(permutedims(w_unnorm, (1, 3, 2)), D_tr * N, K) h_flat = reshape(permutedims(h_unnorm, (1, 3, 2)), D_tr * N, K) dv_flat = reshape(permutedims(dv_unnorm, (1, 3, 2)), D_tr * N, K-1) - x_tr_flat, lad_flat = unconstrained_rational_quadratic_spline( - y_tr_flat, w_flat, h_flat, dv_flat, eltype(y)(tail_bound); + x_flat, lad_flat = unconstrained_rational_quadratic_spline( + y_flat, w_flat, h_flat, dv_flat, eltype(y)(tail_bound); inverse=true ) - x_tr = reshape(x_tr_flat, D_tr, N) - lad_tr = reshape(lad_flat, D_tr, N) + x = reshape(x_flat, D_tr, N) + lad = reshape(lad_flat, D_tr, N) - tr_indices = cumsum(mask) - - x = vcat([mask[i] ? x_tr[tr_indices[i]:tr_indices[i], :] : y[i:i, :] for i in 1:D]...) - log_det = vcat([mask[i] ? lad_tr[tr_indices[i]:tr_indices[i], :] : fill(zero(eltype(y)), 1, N) for i in 1:D]...) - - return x, log_det - - return x, log_det + return x, lad end -function NSFCouplingBijector_from_flat(params, mask, K, tail_bound) - return NSFSplineBijector(mask, params, K, tail_bound) +function NSFCouplingBijector_from_flat(params, K, tail_bound) + return NSFSplineBijector(params, K, tail_bound) end """ @@ -122,6 +88,7 @@ Neural Spline Flow (NSF) with rational quadratic coupling layers. 
""" @concrete struct NeuralSplineFlow <: Lux.AbstractLuxContainerLayer{(:conditioners,)} conditioners + mask_list :: Vector{BitVector} dist_dims :: Int n_transforms :: Int hidden_layer_sizes :: Vector{Int} @@ -131,23 +98,27 @@ end function NeuralSplineFlow(; n_transforms::Int, dist_dims::Int, hidden_layer_sizes::Vector{Int}, K=8, tail_bound=3.0, activation=gelu) - # Number of transformed dimensions in each layer (mask alternate half) D = dist_dims - D_tr = D - (D ÷ 2) # Approximately half - # Conditioner output size: D_tr * (3K - 1) - out_dims = D_tr * (3*K - 1) - mlps = [MLP(D, hidden_layer_sizes, out_dims; activation) - for _ in 1:n_transforms] + # Pre-generate masks to know out_dims per layer + mask_list = [BitVector(collect(1:D) .% 2 .== i % 2) + for i in 1:n_transforms] + + mlps = [] + for i in 1:n_transforms + m = mask_list[i] + D_tr = sum(m) + out_dims = D_tr * (3*K - 1) + push!(mlps, MLP(D, hidden_layer_sizes, out_dims; activation)) + end + keys_ = ntuple(i -> Symbol(:conditioners_, i), n_transforms) conditioners = NamedTuple{keys_}(Tuple(mlps)) - return NeuralSplineFlow(conditioners, D, n_transforms, hidden_layer_sizes, K, Float64(tail_bound)) + return NeuralSplineFlow(conditioners, mask_list, D, n_transforms, hidden_layer_sizes, K, Float64(tail_bound)) end function Lux.initialstates(rng::AbstractRNG, m::NeuralSplineFlow) - mask_list = [Bool.(collect(1:(m.dist_dims)) .% 2 .== i % 2) - for i in 1:(m.n_transforms)] - return (; mask_list, conditioners=Lux.initialstates(rng, m.conditioners)) + return (; mask_list=m.mask_list, conditioners=Lux.initialstates(rng, m.conditioners)) end @@ -159,8 +130,5 @@ struct NSFSplineConstructor end function (c::NSFSplineConstructor)(params) - return NSFCouplingLayer(c.mask, params, c.K, c.tail_bound) + return NSFSplineBijector(c.mask, params, c.K, c.tail_bound) end - -# Wait, I need to fix MaskedCoupling to work with NSFCouplingLayer -# or just use the logic in NSFCouplingLayer. 
diff --git a/src/realnvp.jl b/src/realnvp.jl index 135770e..96cca68 100644 --- a/src/realnvp.jl +++ b/src/realnvp.jl @@ -21,6 +21,7 @@ Masks alternate between even/odd dimensions. """ @concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)} conditioners + mask_list :: Vector{BitVector} dist_dims :: Int n_transforms :: Int hidden_layer_sizes :: Vector{Int} @@ -28,15 +29,24 @@ end function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_layer_sizes::Vector{Int}, activation=gelu) - mlps = [MLP(dist_dims, hidden_layer_sizes, 2 * dist_dims; activation) - for _ in 1:n_transforms] + D = dist_dims + + # Pre-generate masks + mask_list = [BitVector(collect(1:D) .% 2 .== i % 2) + for i in 1:n_transforms] + + mlps = [] + for i in 1:n_transforms + m = mask_list[i] + D_tr = sum(m) + push!(mlps, MLP(D, hidden_layer_sizes, 2 * D_tr; activation)) + end + keys_ = ntuple(i -> Symbol(:conditioners_, i), n_transforms) conditioners = NamedTuple{keys_}(Tuple(mlps)) - return RealNVP(conditioners, dist_dims, n_transforms, hidden_layer_sizes) + return RealNVP(conditioners, mask_list, D, n_transforms, hidden_layer_sizes) end function Lux.initialstates(rng::AbstractRNG, m::RealNVP) - mask_list = [Bool.(collect(1:(m.dist_dims)) .% 2 .== i % 2) - for i in 1:(m.n_transforms)] - return (; mask_list, conditioners=Lux.initialstates(rng, m.conditioners)) + return (; mask_list=m.mask_list, conditioners=Lux.initialstates(rng, m.conditioners)) end diff --git a/src/training.jl b/src/training.jl index 5348bdf..f32906d 100644 --- a/src/training.jl +++ b/src/training.jl @@ -20,9 +20,10 @@ function train_flow!(flow::FlowDistribution{T}, data::AbstractMatrix; batch_size::Int=256, verbose::Bool=true, opt=nothing) where {T} - # Always fit and apply a min-max normalizer + # 5. 
Fit & Attach Normalizer (on the first batch or full data) + # We fit it on the full training data for simplicity flow.normalizer = MinMaxNormalizer(T.(data)) - data_T = normalize(flow.normalizer, T.(data)) + data_norm = SimpleFlows.normalize(flow.normalizer, data) actual_opt = if isnothing(opt) actual_lr = isnothing(lr) ? T(1f-3) : T(lr) @@ -32,7 +33,7 @@ function train_flow!(flow::FlowDistribution{T}, data::AbstractMatrix; end opt_state = Optimisers.setup(actual_opt, flow.ps) - loader = DataLoader(data_T; batchsize=batch_size, shuffle=true) + loader = DataLoader(data_norm; batchsize=batch_size, shuffle=true) for epoch in 1:n_epochs total_loss = zero(T) diff --git a/test/runtests.jl b/test/runtests.jl index 7c8e931..c79623b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,14 @@ using Test using SimpleFlows using Random, Distributions, LinearAlgebra -using Lux, Zygote, Bijectors, ForwardDiff +using Lux, Zygote, Bijectors, ForwardDiff, JSON rng = Random.MersenneTwister(42) @testset "SimpleFlows.jl" begin include("test_layers.jl") include("test_realnvp.jl") + include("test_training.jl") include("test_io.jl") include("test_normalizer.jl") include("test_splines.jl") diff --git a/test/test_io.jl b/test/test_io.jl index e56f140..ab2d2ca 100644 --- a/test/test_io.jl +++ b/test/test_io.jl @@ -2,39 +2,55 @@ rng = Random.MersenneTwister(4) dims = 3 - # Test backward-compat scalar API - flow = FlowDistribution(; n_transforms=3, dist_dims=dims, - hidden_dims=16, n_layers=2, rng) - - # We don't train here. Training stability is tested in test_realnvp.jl. - # The IO test only verifies that model serialization survives round-trip. 
- - # Save - tmpdir = mktempdir() - save_trained_flow(tmpdir, flow) - - @test isfile(joinpath(tmpdir, "flow_setup.json")) - @test isfile(joinpath(tmpdir, "weights.npz")) - - # Load - flow2 = load_trained_flow(tmpdir; rng) - - # Dimensions must match - @test Distributions.length(flow2) == dims - - # logpdf must be identical at test points - x_test = randn(rng, Float32, dims, 20) - lp1 = logpdf(flow, x_test) - lp2 = logpdf(flow2, x_test) - @test lp1 ≈ lp2 atol=1f-4 - - # Also test the new per-layer vector API round-trips correctly - flow3 = FlowDistribution(; n_transforms=2, dist_dims=dims, - hidden_layer_sizes=[32, 64, 32], rng) + for arch in [:RealNVP, :NSF, :MAF] + flow = if arch == :NSF + FlowDistribution(Float32; architecture=arch, n_transforms=3, dist_dims=dims, + hidden_layer_sizes=[16, 16], K=4, tail_bound=2.0, rng=rng) + else + FlowDistribution(Float32; architecture=arch, n_transforms=3, dist_dims=dims, + hidden_layer_sizes=[16, 16], rng=rng) + end + + # Save + tmpdir = mktempdir() + save_trained_flow(tmpdir, flow) + + @test isfile(joinpath(tmpdir, "flow_setup.json")) + @test isfile(joinpath(tmpdir, "weights.npz")) + + # Load + flow2 = load_trained_flow(tmpdir; rng) + + # Dimensions must match + @test Distributions.length(flow2) == dims + @test typeof(flow2.model) == typeof(flow.model) + + # logpdf must be identical at test points + x_test = randn(rng, Float32, dims, 20) + lp1 = logpdf(flow, x_test) + lp2 = logpdf(flow2, x_test) + @test lp1 ≈ lp2 atol=1f-4 + end + + # Test backward-compat scalar API and legacy JSON load + flow3 = FlowDistribution(; architecture=:RealNVP, n_transforms=2, dist_dims=dims, + hidden_dims=16, n_layers=2, rng=rng) tmpdir2 = mktempdir() save_trained_flow(tmpdir2, flow3) + + # Manually modify JSON to simulate old format + setup_file = joinpath(tmpdir2, "flow_setup.json") + setup_dict = JSON.parsefile(setup_file) + delete!(setup_dict, "hidden_layer_sizes") + setup_dict["hidden_dims"] = 16 + setup_dict["n_layers"] = 2 + 
open(setup_file, "w") do io + JSON.print(io, setup_dict, 4) + end + flow4 = load_trained_flow(tmpdir2; rng) - @test flow4.hidden_layer_sizes == [32, 64, 32] + @test flow4.hidden_layer_sizes == [16, 16] + x_test = randn(rng, Float32, dims, 20) lp3 = logpdf(flow3, x_test) lp4 = logpdf(flow4, x_test) @test lp3 ≈ lp4 atol=1f-4 diff --git a/test/test_layers.jl b/test/test_layers.jl index 5770e7c..b934017 100644 --- a/test/test_layers.jl +++ b/test/test_layers.jl @@ -28,7 +28,7 @@ end mask = Bool.(collect(1:n) .% 2 .== 0) # [false, true, false, true, false, true] # Trivial conditioner: returns zeros (identity bijector) - conditioner = x_cond -> zeros(Float32, 2n, B) + conditioner = x_cond -> zeros(Float32, 2*sum(mask), B) bj = SimpleFlows.MaskedCoupling(mask, conditioner, SimpleFlows.AffineBijector) x = randn(rng, Float32, n, B) diff --git a/test/test_normalizer.jl b/test/test_normalizer.jl index a491a0a..6c81749 100644 --- a/test/test_normalizer.jl +++ b/test/test_normalizer.jl @@ -1,4 +1,23 @@ @testset "MinMaxNormalizer" begin + # 1. Normal functioning data + x = Float32[1 2 3; 4 5 6] + norm = MinMaxNormalizer(x) + + @test norm.x_min == Float32[1, 4] + @test norm.x_max == Float32[3, 6] + + x_norm = SimpleFlows.normalize(norm, x) + @test size(x_norm) == size(x) + @test all(x_norm .>= 0) + @test all(x_norm .<= 1) + + x_unnorm = SimpleFlows.denormalize(norm, x_norm) + @test x_unnorm ≈ x + + # 2. 
Zero variance edge case + x_zero_var = Float32[1 1 1; 4 5 6] + @test_throws ArgumentError MinMaxNormalizer(x_zero_var) + rng = Random.MersenneTwister(7) dims = 4 diff --git a/test/test_training.jl b/test/test_training.jl new file mode 100644 index 0000000..978d4b9 --- /dev/null +++ b/test/test_training.jl @@ -0,0 +1,54 @@ +using SimpleFlows +using Test +using Random +using Optimisers + +@testset "Custom Optimizer & Training" begin + rng = Random.default_rng() + Random.seed!(rng, 10) + + dims = 2 + N = 100 + data = randn(Float32, dims, N) + + flow = FlowDistribution(Float32; architecture=:RealNVP, n_transforms=2, dist_dims=dims, + hidden_layer_sizes=[16, 16], rng) + + # Initial NLL + lp_init = -mean(logpdf(flow, data)) + + # Custom optimizer + custom_opt = Optimisers.Adam(1f-4) + train_flow!(flow, data; n_epochs=5, batch_size=50, opt=custom_opt, verbose=false) + + # Check that training occurred without error and actually evaluated + lp_trained = -mean(logpdf(flow, data)) + @test isfinite(lp_trained) + @test !isnothing(flow.normalizer) +end + +@testset "Automatic Type Conversion" begin + rng = Random.default_rng() + Random.seed!(rng, 42) + + dims = 2 + N = 100 + # User passes Float64 data accidentally + data_f64 = randn(Float64, dims, N) + + # Model is strictly Float32 + flow = FlowDistribution(Float32; architecture=:RealNVP, n_transforms=2, dist_dims=dims, + hidden_layer_sizes=[16, 16], rng) + + # Should not throw MethodError + train_flow!(flow, data_f64; n_epochs=2, batch_size=50, verbose=false) + + # Normalizer should now be correctly typed as MinMaxNormalizer{Float32} + @test flow.normalizer isa MinMaxNormalizer{Float32} + + # logpdf on Float64 data should evaluate correctly without errors. + # The output type may promote to Float64 depending on input type. 
+ lp = logpdf(flow, data_f64) + @test all(isfinite, lp) +end + From ef16bd73ec5c627511e5eedd5ae47652c608b82d Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 5 Mar 2026 11:20:47 -0500 Subject: [PATCH 06/14] Renaming file --- examples/{train_multinormal.jl => train_multinormal_nvp.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{train_multinormal.jl => train_multinormal_nvp.jl} (100%) diff --git a/examples/train_multinormal.jl b/examples/train_multinormal_nvp.jl similarity index 100% rename from examples/train_multinormal.jl rename to examples/train_multinormal_nvp.jl From a52a8e0bbb22979d71d71187bad31b2e40f79e53 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 5 Mar 2026 11:37:10 -0500 Subject: [PATCH 07/14] More tests --- test/test_maf.jl | 19 +++++++++++++++++-- test/test_splines.jl | 23 +++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/test/test_maf.jl b/test/test_maf.jl index d43073d..95f24db 100644 --- a/test/test_maf.jl +++ b/test/test_maf.jl @@ -20,9 +20,24 @@ using LinearAlgebra made = SimpleFlows.MADE(D, hidden, 2*D) ps, st = Lux.setup(rng, made) + # Explicit Mask Multiplication (nflows inspired) + # Extract masks from MaskedDense layers to ensure their product is strictly lower triangular. + m1 = made.layers.layer_1.mask + m2 = made.layers.layer_2.mask + m3 = made.layers.layer_3.mask + total_mask = m3 * m2 * m1 + + # Block 1: m (1:D, 1:D) + M_m = total_mask[1:D, :] + @test isapprox(tril(M_m, -1), M_m, atol=1e-7) + @test count(>(0), M_m) > 0 # make sure it's not identically zero + + # Block 2: log_alpha (D+1:2D, 1:D) + M_la = total_mask[D+1:end, :] + @test isapprox(tril(M_la, -1), M_la, atol=1e-7) + @test count(>(0), M_la) > 0 + # We check the Jacobian of the outputs with respect to the inputs. - # For an autoregressive model, the Jacobian of y_i w.r.t x_j must be zero if j >= i. - # Since we have 2 outputs per input (m and log_alpha), we check both blocks. 
function f(x) out, _ = Lux.apply(made, x, ps, st) diff --git a/test/test_splines.jl b/test/test_splines.jl index 7db1c5a..3a067c6 100644 --- a/test/test_splines.jl +++ b/test/test_splines.jl @@ -88,6 +88,29 @@ using LinearAlgebra gf = ForwardDiff.gradient(f, x) @test gz ≈ gf atol=T == Float32 ? 1e-4 : 1e-10 end + + @testset "Boundary & Tail Consistency" begin + # x explicitly outside the tail_bound + x_out = T.([-4.0, 4.0, -10.0, 10.0]) + N = length(x_out) + K = 5 + w = randn(T, N, K) + h = randn(T, N, K) + dv = randn(T, N, K-1) + + # Forward should be strict linear identity outside bounds + y_out, lad_out = unconstrained_rational_quadratic_spline(x_out, w, h, dv, T(3.0)) + + @test y_out ≈ x_out atol=1e-6 + @test lad_out ≈ zeros(T, N) atol=1e-6 + + # Inverse should also be strict identity + x_rec, lad_inv = unconstrained_rational_quadratic_spline(x_out, w, h, dv, T(3.0); inverse=true) + + @test x_rec ≈ x_out atol=1e-6 + @test lad_inv ≈ zeros(T, N) atol=1e-6 + end end + end end From 7cbd094ec0b92141c9048a377e2d782ac0c8744b Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 5 Mar 2026 12:28:08 -0500 Subject: [PATCH 08/14] Added tests. --- test/test_layers.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ test/test_training.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/test/test_layers.jl b/test/test_layers.jl index b934017..9ae7d8b 100644 --- a/test/test_layers.jl +++ b/test/test_layers.jl @@ -43,3 +43,45 @@ end x_rec, _ = SimpleFlows.inverse_and_log_det(bj, y) @test x_rec ≈ x atol=1f-5 end + +@testset "Explicit Dense Jacobian Inverses" begin + rng = Random.MersenneTwister(3) + n = 4 + + # 1. 
AffineBijector + params = randn(rng, Float32, 2n, 1) + b_aff = SimpleFlows.AffineBijector(params) + + x_single = randn(rng, Float32, n) + f_aff(x) = SimpleFlows.forward_and_log_det(b_aff, reshape(x, n, 1))[1][:] + finv_aff(y) = SimpleFlows.inverse_and_log_det(b_aff, reshape(y, n, 1))[1][:] + + y_aff = f_aff(x_single) + J_aff = ForwardDiff.jacobian(f_aff, x_single) + Jinv_aff = ForwardDiff.jacobian(finv_aff, y_aff) + + # Check J * J_inv ≈ I + @test J_aff * Jinv_aff ≈ I(n) atol=1f-4 + # Check determinant matches (Affine returns per-dimension ld) + ld_fwd = SimpleFlows.forward_and_log_det(b_aff, reshape(x_single, n, 1))[2][:] + @test log(abs(det(J_aff))) ≈ sum(ld_fwd) atol=1f-4 + + # 2. MaskedCoupling + mask = [false, true, false, true] + # Trivial deterministic conditioner for testing structural zeros. MUST be deterministic for ForwardDiff! + fixed_params = randn(rng, Float32, 2*sum(mask), 1) + conditioner = x_cond -> fixed_params .+ sum(x_cond) * 0.0f0 # Ensure dual numbers propagate if needed + bj_m = SimpleFlows.MaskedCoupling(mask, conditioner, SimpleFlows.AffineBijector) + + f_m(x) = SimpleFlows.forward_and_log_det(bj_m, reshape(x, n, 1))[1][:] + finv_m(y) = SimpleFlows.inverse_and_log_det(bj_m, reshape(y, n, 1))[1][:] + + x_m = randn(rng, Float32, n) + y_m = f_m(x_m) + + J_m = ForwardDiff.jacobian(f_m, x_m) + Jinv_m = ForwardDiff.jacobian(finv_m, y_m) + + @test J_m * Jinv_m ≈ I(n) atol=1f-4 +end + diff --git a/test/test_training.jl b/test/test_training.jl index 978d4b9..2c9ce8f 100644 --- a/test/test_training.jl +++ b/test/test_training.jl @@ -52,3 +52,34 @@ end @test all(isfinite, lp) end +@testset "Reparameterization Trick Zygote Gradients" begin + rng = Random.default_rng() + Random.seed!(rng, 100) + + dims = 2 + + # We test that we can differentiate a loss function evaluated on the output of draw_samples. + # This proves that our flows support the reparameterization trick natively. 
+ for arch in [:RealNVP, :NSF, :MAF] + flow = FlowDistribution(Float32; architecture=arch, n_transforms=1, dist_dims=dims, + hidden_layer_sizes=[16], K=4, tail_bound=2.0, rng) + + # We need a stable test_rng since draw_samples is stochastic. + # Zygote can sometimes struggle with capturing global RNGs during AD, so we pass one explicitly + # and ignore the derivative with respect to it (which is 0). + test_rng = Random.default_rng() + Random.seed!(test_rng, 42) + + function loss_fn(ps) + x_samples = SimpleFlows.draw_samples(test_rng, Float32, flow.model, ps, flow.st, 10) + return sum(x_samples.^2) + end + + loss, (dps,) = Zygote.withgradient(loss_fn, flow.ps) + + @test isfinite(loss) + @test !isnothing(dps) + end +end + + From e57d85f8dbcc4c8e55784ee3ffbfc1677a24f5c4 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Mon, 9 Mar 2026 11:52:36 -0400 Subject: [PATCH 09/14] Adding dependencies --- Project.toml | 32 ++++++++++++++++---------------- src/distribution.jl | 16 +++++++++------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index 6904989..9838883 100644 --- a/Project.toml +++ b/Project.toml @@ -22,19 +22,19 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Bijectors = "0.15.17" -ChainRulesCore = "1.26.0" -ConcreteStructs = "0.2.3" -Distributions = "0.25.123" -ForwardDiff = "1.3.2" -JSON = "1.4.0" -LinearAlgebra = "1.12.0" -Lux = "1.31.3" -MLUtils = "0.4.8" -NNlib = "0.9.33" -NPZ = "0.4.3" -Optimisers = "0.4.7" -Random = "1.11.0" -Statistics = "1.11.1" -Test = "1.11.0" -Zygote = "0.7.10" +Bijectors = "0.15" +ChainRulesCore = "1.26" +ConcreteStructs = "0.2" +Distributions = "0.25" +ForwardDiff = "1.3" +JSON = "0.21" +LinearAlgebra = "1.12" +Lux = "1.31" +MLUtils = "0.4" +NNlib = "0.9" +NPZ = "0.4" +Optimisers = "0.4" +Random = "1.11" +Statistics = "1.11" +Test = "1.11" +Zygote = "0.7" diff --git a/src/distribution.jl b/src/distribution.jl index e4cc8a2..5187731 
100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -101,10 +101,12 @@ end # therefore no parameter transformation is required for HMC/NUTS. Bijectors.bijector(::FlowDistribution) = identity -# Explicitly implement the VectorBijectors interface -Bijectors.VectorBijectors.vec_length(d::FlowDistribution) = d.n_dims -Bijectors.VectorBijectors.linked_vec_length(d::FlowDistribution) = d.n_dims -Bijectors.VectorBijectors.to_vec(d::FlowDistribution) = Base.identity -Bijectors.VectorBijectors.from_vec(d::FlowDistribution) = Base.identity -Bijectors.VectorBijectors.to_linked_vec(d::FlowDistribution) = Base.identity -Bijectors.VectorBijectors.from_linked_vec(d::FlowDistribution) = Base.identity +if isdefined(Bijectors, :VectorBijectors) + # Explicitly implement the VectorBijectors interface + Bijectors.VectorBijectors.vec_length(d::FlowDistribution) = d.n_dims + Bijectors.VectorBijectors.linked_vec_length(d::FlowDistribution) = d.n_dims + Bijectors.VectorBijectors.to_vec(d::FlowDistribution) = Base.identity + Bijectors.VectorBijectors.from_vec(d::FlowDistribution) = Base.identity + Bijectors.VectorBijectors.to_linked_vec(d::FlowDistribution) = Base.identity + Bijectors.VectorBijectors.from_linked_vec(d::FlowDistribution) = Base.identity +end From 9c6072ff48c64a39204085d3778c577ecf5594f6 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Mon, 16 Mar 2026 17:53:27 -0400 Subject: [PATCH 10/14] Adding tests --- Project.toml | 4 ++++ test/runtests.jl | 1 + test/test_ad.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+) create mode 100644 test/test_ad.jl diff --git a/Project.toml b/Project.toml index 9838883..1496fb4 100644 --- a/Project.toml +++ b/Project.toml @@ -7,12 +7,14 @@ authors = ["marcobonici "] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" 
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" @@ -25,12 +27,14 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Bijectors = "0.15" ChainRulesCore = "1.26" ConcreteStructs = "0.2" +DifferentiationInterface = "0.7.16" Distributions = "0.25" ForwardDiff = "1.3" JSON = "0.21" LinearAlgebra = "1.12" Lux = "1.31" MLUtils = "0.4" +Mooncake = "0.5.23" NNlib = "0.9" NPZ = "0.4" Optimisers = "0.4" diff --git a/test/runtests.jl b/test/runtests.jl index c79623b..8887791 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,4 +14,5 @@ rng = Random.MersenneTwister(42) include("test_splines.jl") include("test_nsf.jl") include("test_maf.jl") + include("test_ad.jl") end diff --git a/test/test_ad.jl b/test/test_ad.jl new file mode 100644 index 0000000..76faebe --- /dev/null +++ b/test/test_ad.jl @@ -0,0 +1,47 @@ +using Test +using SimpleFlows +using Distributions +using Random +using DifferentiationInterface +using ForwardDiff +using Zygote +using Mooncake + +println("Starting AD tests...") + +@testset "Automatic Differentiation" begin + rng = Random.MersenneTwister(42) + n_dims = 2 + architectures = [:RealNVP, :NSF, :MAF] + + for arch in architectures + println("Testing architecture: ", arch) + @testset "$arch" begin + d = FlowDistribution(Float64; architecture=arch, n_transforms=2, dist_dims=n_dims, hidden_layer_sizes=[16, 16], n_layers=2) + x = rand(rng, n_dims) + + f(x) = logpdf(d, x) + val = f(x) + @test isfinite(val) + + for (name, backend) in [ + ("ForwardDiff", AutoForwardDiff()), + ("Zygote", AutoZygote()), + ("Mooncake", 
AutoMooncake(config=nothing)) + ] + println(" Testing backend: ", name) + @testset "$name" begin + g = try + DifferentiationInterface.gradient(f, backend, x) + catch e + @error "Failed to differentiate $arch with $name" exception=(e, catch_backtrace()) + nothing + end + @test g !== nothing + @test length(g) == n_dims + @test all(isfinite, g) + end + end + end + end +end From 07199c4ccf841dbc238ca0cfd4bb0065d08ec05b Mon Sep 17 00:00:00 2001 From: marcobonici Date: Mon, 16 Mar 2026 20:40:49 -0400 Subject: [PATCH 11/14] Add GitHub Actions CI, fix performance allocations, and move AD tools to test targets --- .github/workflows/test.yml | 59 ++++++++++++++++++++++++++++++++++++++ Project.toml | 13 +++++++-- src/SimpleFlows.jl | 1 + src/layers.jl | 25 +++++++++++++--- src/nsf.jl | 12 ++++---- 5 files changed, 97 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..6f2656c --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,59 @@ +name: CI +on: + schedule: + - cron: 0 0 * * * + pull_request: + branches: + - main + - develop + push: + branches: + - "**" + tags: "*" +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1.11" + - "1.12" + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + + # Step to download the test data required for test_splines.jl + - 
name: Download test data + run: | + # This should ideally pull from a permanent remote URL. + # For now we create a dummy file to avoid test failures until the real file is placed somewhere. + # If a URL exists, replace this with: curl -o /tmp/nsf_test_data_f64.npz "URL" + echo "Placeholder for downloading /tmp/nsf_test_data_f64.npz" + + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + file: lcov.info \ No newline at end of file diff --git a/Project.toml b/Project.toml index 1496fb4..ef53b4b 100644 --- a/Project.toml +++ b/Project.toml @@ -7,23 +7,21 @@ authors = ["marcobonici "] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +BenchmarkTools = "1.6.3" Bijectors = "0.15" ChainRulesCore = "1.26" ConcreteStructs = "0.2" @@ -42,3 +40,12 @@ Random = "1.11" Statistics = "1.11" Test = "1.11" Zygote = "0.7" + +[extras] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Mooncake = 
"da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test", "DifferentiationInterface", "Mooncake", "BenchmarkTools"] diff --git a/src/SimpleFlows.jl b/src/SimpleFlows.jl index 2135773..c7b9850 100644 --- a/src/SimpleFlows.jl +++ b/src/SimpleFlows.jl @@ -5,6 +5,7 @@ using Distributions using Bijectors using JSON, NPZ using Optimisers, Zygote +using ChainRulesCore include("layers.jl") include("realnvp.jl") diff --git a/src/layers.jl b/src/layers.jl index e5265d9..8e25d59 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -22,7 +22,7 @@ end function AffineBijector(params::AbstractArray) n = size(params, 1) ÷ 2 idx = ntuple(Returns(Colon()), ndims(params) - 1) - return AffineBijector(params[1:n, idx...], params[(n + 1):end, idx...]) + return @views AffineBijector(params[1:n, idx...], params[(n + 1):end, idx...]) end function forward_and_log_det(b::AffineBijector, x::AbstractArray) @@ -63,14 +63,31 @@ function _apply_mask(bj::MaskedCoupling, x::AbstractMatrix, transform_fn) y_tr, ld_tr = transform_fn(bj_inner, x_tr) # 3. Reconstruct full y - # We use a Zygote-friendly reconstruction - tr_idx = cumsum(m) - y = vcat([m[i] ? y_tr[tr_idx[i]:tr_idx[i], :] : x[i:i, :] for i in 1:D]...) 
+ y = _reconstruct(m, x, y_tr) # sum log-dets over the transformed dims return y, dsum(ld_tr; dims=(1,)) end +function _reconstruct(m::AbstractArray{Bool}, x::AbstractMatrix, y_tr::AbstractMatrix) + y = similar(x) + y[.!m, :] .= x[.!m, :] + y[m, :] .= y_tr + return y +end + +function ChainRulesCore.rrule(::typeof(_reconstruct), m::AbstractArray{Bool}, x::AbstractMatrix, y_tr::AbstractMatrix) + y = _reconstruct(m, x, y_tr) + function _reconstruct_pullback(Δy) + Δx = similar(x) + Δx[.!m, :] .= Δy[.!m, :] + Δx[m, :] .= 0 + Δy_tr = Δy[m, :] + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), Δx, Δy_tr + end + return y, _reconstruct_pullback +end + function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray) _apply_mask(bj, x, forward_and_log_det) end diff --git a/src/nsf.jl b/src/nsf.jl index 53ae8c2..78f6b31 100644 --- a/src/nsf.jl +++ b/src/nsf.jl @@ -29,9 +29,9 @@ function forward_and_log_det(b::NSFSplineBijector, x::AbstractArray) params = reshape(params, D_tr, 3*K - 1, N) # Partition params - w_unnorm = params[:, 1:K, :] - h_unnorm = params[:, K+1:2*K, :] - dv_unnorm = params[:, 2*K+1:end, :] + w_unnorm = @view params[:, 1:K, :] + h_unnorm = @view params[:, K+1:2*K, :] + dv_unnorm = @view params[:, 2*K+1:end, :] # Flatten everything to call the spline function x_flat = vec(x) @@ -57,9 +57,9 @@ function inverse_and_log_det(b::NSFSplineBijector, y::AbstractArray) params = b.params params = reshape(params, D_tr, 3*K - 1, N) - w_unnorm = params[:, 1:K, :] - h_unnorm = params[:, K+1:2*K, :] - dv_unnorm = params[:, 2*K+1:end, :] + w_unnorm = @view params[:, 1:K, :] + h_unnorm = @view params[:, K+1:2*K, :] + dv_unnorm = @view params[:, 2*K+1:end, :] y_flat = vec(y) w_flat = reshape(permutedims(w_unnorm, (1, 3, 2)), D_tr * N, K) From 27818eb6614a189477357bdda97d0df9850c3ad2 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Mon, 16 Mar 2026 20:44:44 -0400 Subject: [PATCH 12/14] Update Project.toml --- Project.toml | 8 ++++---- 1 file changed, 4 
insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ef53b4b..baf82e1 100644 --- a/Project.toml +++ b/Project.toml @@ -29,16 +29,16 @@ DifferentiationInterface = "0.7.16" Distributions = "0.25" ForwardDiff = "1.3" JSON = "0.21" -LinearAlgebra = "1.12" +LinearAlgebra = "1.10" Lux = "1.31" MLUtils = "0.4" Mooncake = "0.5.23" NNlib = "0.9" NPZ = "0.4" Optimisers = "0.4" -Random = "1.11" -Statistics = "1.11" -Test = "1.11" +Random = "1.10" +Statistics = "1.10" +Test = "1.10" Zygote = "0.7" [extras] From 159a9d43caa62a1610dd8fe97c1cc1c8ab92d54d Mon Sep 17 00:00:00 2001 From: marcobonici Date: Mon, 16 Mar 2026 21:28:58 -0400 Subject: [PATCH 13/14] Fix CI by adding missing test data for NSF and updating test paths --- .github/workflows/test.yml | 8 -------- test/data/nsf_test_data_f64.npz | Bin 0 -> 7024 bytes test/test_splines.jl | 2 +- 3 files changed, 1 insertion(+), 9 deletions(-) create mode 100644 test/data/nsf_test_data_f64.npz diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6f2656c..e1fdfd3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,14 +41,6 @@ jobs: ${{ runner.os }}-test-${{ env.cache-name }}- ${{ runner.os }}-test- ${{ runner.os }}- - - # Step to download the test data required for test_splines.jl - - name: Download test data - run: | - # This should ideally pull from a permanent remote URL. - # For now we create a dummy file to avoid test failures until the real file is placed somewhere. 
- # If a URL exists, replace this with: curl -o /tmp/nsf_test_data_f64.npz "URL" - echo "Placeholder for downloading /tmp/nsf_test_data_f64.npz" - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 diff --git a/test/data/nsf_test_data_f64.npz b/test/data/nsf_test_data_f64.npz new file mode 100644 index 0000000000000000000000000000000000000000..f538b56b08addbff6c31f8536d450515a719f5dc GIT binary patch literal 7024 zcmd5>c{tVI);AXoB1I(`iZX;!8fZzHBuY}|WQa;KhD?zmDpQIgl_-@)_L+TPLG(fo>;y`!bo&;3K@mz}H@_ML3ZuUakqZ&s2ORaDp_E9xcs?+zaJ zta9;DnOblQwLYZ6RRy14s}1sr*1+?koUWzB3glfpd2+Q!AGnGru&Z*A!PFv)b#2ld zZWGc}cA1{VzF1JEmTnT+RwF2UtIh=aWjH9sO*-ztc z6kL6}>)I)tg^UIcYAur~7)2Ys>yK1&Rq%jIXnyG`sx_Gc%7Dc!`rsk|5R3jryuD5k=LmN@Pk~2a?oZR_v zf9wnyEf;A@Mi;TEMMWRFzce7cIWnR=i55siQjBb)-%-l zWqZK|+q+!3V((~Ye%aQ;%F^tbt);WgKOX|uKOVwA*8lt=$cpax_um4~x2oa0Xs+ zW-XE%$aqINue}_~QSjvRIjUYaoH$R_x$jCr2Xa-tm30B4ZNj39%?-G_ztD8Q9IM*V z)pUnbGfY@DzNuW?ss`n~X4nJHk+JTSu7T|L0g&u{Hz=x1M)4A#@>b{9sA$eAs^de% zy*rMV8Cz0d-#beR&)p^@NhK&$7}SA!zEZqf@DP}Fzh9!TjExYXMzRevAz}neC-I^u z6MPH;^V8aBP^KEZ;)n$i(k@+VH+|;=Gt6Y6E|&>dwBqtt?E^HRoPHE|cXSM-zs1=V zZt6h?wbT^Ghp*7q&{c7zqY|!796B1QPlX|)?{?STj^b^-!q#gtBuw^x8z*B(0iDVO zCH5uN_+0Fn<=Z|YYIP;OH@Wi>+5@t^Iu`Xo`-h2+>)o9oyT@dueOy0SUhCv;;{T43 z(|7yk&oR*WYgpO_&jCo0`0+Ju4-<;1?*qAWr=YpyV?m-!G6XV~aJx11;)@53F3s!u zF{U$oli#uu7Y72W&>Qh%+KmM&>f4iDMA3WZ4&QJeN2by_aHpo0W0`m{Q zyqj(6K(`=HDTdWJ98TBSr*bV7o(9JXOWBt~fg-->m7T=q?f$EAaX&~lu29v@Y)5rd z4~`pywV*&X5Q)7o2F*_=?!1hpW1gg1siRvFqMxzlBJwzFd_6H|SKEzOk7pFM8Zsa~ zpd{Hbb_`bCvkkFeOoPo0VY|1-DlvFVZz#Vw5v-=nA1(bb4bgFaG;ku}z&Dxp)0<{- zom)m&>y2jIrbFFX&6|#VU!+HtU73KJVM4youp2}-dtXaVhMc|eYwrWnDNLyfFm{MX#r?PxlnobdtdM=#ElP5ufvZHkY zHxtw*majkaVr5X}gy23ZaJ}P zGDN1n($H~4B6EYn1gxTpFl%J;@RiGIW7=2`IP|@IKP$|lM%K@_X6J9grJa6lA~7AX zUf9BQEolUpJxOBu787VUs~6*TDFf?7=NyFA<^U!VZEQU&@cl%zZHm#avvKip3a>LT+C^(ZG$O(wm|&1$v&38&l0k zQ1F;!IR}vpzXRd3E=QkVKq&Ml5ZYMTUakN(t3}@Q!RN>wDFjjm#LvvEy3{!jlzjp1D<$huhHGgZ_iPgmeIOi4onZhJ 
zl~H30=U_drUU|%$3V>6(bfM@D6!ohp^;__85%QOEHt8X-FV|%zDc3{zq7c<>z61~- zF}{&`kA~A#JsMs!y=cuc3GsX@Cow(biPqXi3VaJ%*|+O73HOwJXFIh( zxvCitZcb&=z>@y@Q0Wmw?_svsBkPzbs-&;Uz$W)F*~14Mg=IB>+AT&{VNC!R?Dw^&crMk^XjifOla(j zG8XEW6Rma;B{q<~fu3R;Bde2LHn+(svGlRhdDM(UwtI zA{pGhB=tKizrSAb{?KFeThnxDd3QNJ8(uY=SXK&OgWm)TDRjYn$?Eas=v+`sZ%nER zm_l2@o$@&z<1k2nI>bco2FUW$(I9j8XH2q+!?6Im6;-RlZMLYYAZLW(vj8h zsUBZIFHlK$-|1}ogeLN0o~sYi(Mw_R+tLsH&{nunUrIOv`3)W)F#X&LbJI4u4b(e0 zK1lc3s@@N;-7jqH)9Hb|8iY14W3^7TVJzgZ zzTdE$2#$LrUapSHK?ggtgoJ4lZn9rqiB)~z%fXv)Hn#?8!F3N?=p%3-Oj7roWC8T% zlUhiV9SA$Nt|s@;&@CWWrO}NDUSoVLB@UlqrK6SOTeWfEX5Z$SSUC+a?%H-Hv=n`w zX|Eak&V=6rmh2UmH~CyOa0n0+W&tJ+Y5W9fr_o^Bnu z_qxvmryrAePhayYB#j{Z$(dp5NC9wHe^ggH?+G-&N4%q(>%o=XEj01$5ES#0nl7id zV)%mg>gVEaM>JY09~_#Ls~cIi%*|E;kOG)T6!zuO^V58%^km zB;$U4FY&&f61b0Bcs@=K!Qgm7dNO|nFy<8VHraKdjOv@GV^_NI@CCEo_VNteai4O@ z;$;uEI|Zw3bMV6S|S*SKQ zJ01M|Jj?7#vmw#KpBTTQ3cWScqjuTUV?9%?o`w|H=!0t^Xf+q0)bl7d~$A9~S-ekCFw3jA9 z8GsGff`tMsXdrMmTr2zZFc?r|mw2lZ@v7!<%#QOE99r2%4Lp4rmnsY$EIxM&8S>k2 z>rbb{TQUA+?`&&u(Q&u?5$dhbd*`;2Yt0PCyeKz~JlX)4hynZP+O4q0g@12Y)+9bi za@{TdeHwa=<41C=SqRk`533CBj6!g^gol;?EHH`PDY_Z$@OCKaskq=LbkLG~YiT!( z1Yhk}{L=jprc?N2P-F_ERAkutme%1?yN?dF0Tle|l>|aX)0i#b)SxHO2o(xBY!{zR zV)Vs)`@{urjXS@N+w0~qo=rLIESo=v>r>aLXs@e5wePM^*7ioG=k0AJF!Hfa`_9~JD-vFe-gqo&BLzH%f^9)_0wb0UZ_X+i#|D4a!JMuE;HcAb z3a^=h&+YP6uT&|J?p%C3ylNcC5r-dC1ya2^I28%==h$>M+p$OVMMOR=Dt}u)g_K3({Eo(QDr@Xh&?U3-~bu zLtYaK)8!?2#&5JqWJehk%58XeZ@dvaIs&TF9cz$p$Hx609!yv&z*_V?s|PJ4t{M|B zOkij1ca9fS1}c3ypu3p61HkxZD3pGIyfOumkMw#xYV!P@Mh+cRzmXLMpY>tRd!39( zml>E7cUEpR_LCy?)v;=wExF0~ZXB~jJKCP43O*ldIb4fm{R+z^o31=nNODN|0@!}1ViwPE`y zl*`C-wRu5>zKA8J_c*_S=8I+;$45j87TGk(vdU4MIa$yWa-9Ji%Y|5ZtYo{hGsfu zi$l(~gLo_5RglDhh|GX{pM7Rv+0H3HR)idydyF8$Gi-!p40<5c^kOL>5#fPyf4{j0?W1tWXAW`!C+9#nY$JgTyLWrK5%gkj2?VBn9Mzi4bw3! 
zTK*$w@^)&%yMF4A0-wvmwP)PM2ET)()bg!&mUlr+p+Lk3F$M~aOKp6^n+paeTtzE9I0<_=)DqQJG@y*+ zy`>&8bMVTae{WMQ72c7(gESsZLzPif^|p0vgyPg25m5s_&`;skYR+gTbd9|Hs@$Fj zk@v$UmsghJEA_nZVuNImV-aw>cb-M<%W(P0Z;lzbFJ#H|8`f5=9vsVIW$gy(6y|;B zydD%dy)uYtHVv<0S*@c&*a!_3-P2PKI0#QFK3_iZiIUA+i{k1R%rNWFnYogR>DwlF`qC%xu+klSZ|h}* zcyY-OqMkH7ev@?jeR>Z_txs<0%^8PwbB|J`>=}Idth#R_a~z*sw90R+i-goG&xQ@I zFCwhaF+4H9P-IUUSJdL$-i@xi$wt%aBbXy*R+;S7hAXbWlNoG%3b*6eJzsNw8VQNv z!p`nd(AK$4d`svQJjp(AGWSkBw#PnmbPlItNAlGw4Y_#?Db?Cc6CS|~g=>CQ_v>-$ zl#ch-mm^@tDsz-jM#d1^E`$2*?RZZsd7sXsGT@W>cw}6Nj=sKY2L>L{z%x)f`&1rt z!S`9&S;M^8aWO{Hf$; zN#&Okq2H8na{r;^UkWULYWYb=e^SezQ}{P6yet0D@&`HnQ`65-`0Fy3{HCdW70bT{ p#y Date: Mon, 16 Mar 2026 21:42:22 -0400 Subject: [PATCH 14/14] Migrate NSF test data from binary NPZ to git-friendly JSON --- test/data/nsf_test_data.json | 733 ++++++++++++++++++++++++++++++++ test/data/nsf_test_data_f64.npz | Bin 7024 -> 0 bytes test/test_splines.jl | 24 +- 3 files changed, 748 insertions(+), 9 deletions(-) create mode 100644 test/data/nsf_test_data.json delete mode 100644 test/data/nsf_test_data_f64.npz diff --git a/test/data/nsf_test_data.json b/test/data/nsf_test_data.json new file mode 100644 index 0000000..f1eeb64 --- /dev/null +++ b/test/data/nsf_test_data.json @@ -0,0 +1,733 @@ +{ + "inputs": { + "data": [ + 0.2995562877502637, + 0.4285173645433462, + -0.03720300727052493, + -0.9848309379659244, + 0.24237779080894106, + 0.2318702888580676, + 1.0352472972512925, + 0.4573409128131671, + 0.28320739441795323, + -1.0214571469180447, + -0.010085521702337353, + -1.990778610968323, + -0.23294220898382634, + -0.8272632292381021, + 0.08948016046560357, + -0.2515086919419446, + 0.6711532211923003, + 1.7342104609172728, + -1.2174058025747443, + -1.0221641302789066, + 0.7817673742546976, + -1.4920727283504345, + -0.44460513945120644, + 0.14417369704825783, + -1.7903080050220619, + 0.4212009947066779, + -0.26711077156188034, + 1.5695068008576476, + -1.748402801489203, + 1.504572275678796, + -0.8118406564041991, 
+ -1.1983222673785627, + 0.17347640142361004, + 1.2715909667771088, + 2.3464758864728754, + 0.005247036919813949, + 0.26684314514213736, + 0.618669993885617, + -1.5168508562286191, + -0.004495386168535826 + ], + "shape": [ + 4, + 10 + ] + }, + "unnormalized_widths": { + "data": [ + -0.4419745502775611, + 0.2716846620032497, + 1.056891815220453, + -1.766660462065303, + 0.9862554128547407, + -0.5671097889774945, + -0.236684969212557, + 0.006353391113712342, + -0.9123003462772722, + -0.6303061763998286, + -0.39855088265346844, + 0.6913826362365914, + -0.2669938796469903, + 0.6572105842025696, + -1.6658941949653996, + -1.9141455779164431, + 0.9987723890648004, + 0.7339018277109903, + 0.29946905473477337, + -1.0776595577849437, + -0.03349665465339956, + -0.6908995915989536, + 1.2349001065104936, + 0.24146450013043255, + 0.5416192945872663, + -0.6946378083376783, + -0.7430063996503045, + -0.1178562026752313, + 1.5611147357425514, + -0.7144626307394635, + 0.4653856846834471, + 0.9340946495293597, + -0.0536259352810602, + 1.024893906047934, + -1.6038248406643751, + -1.8998197674219859, + 0.8604090757781153, + -0.4640699001641785, + -0.44543010608126465, + -0.1733949212480647, + -0.9017421148082815, + 0.007740374152002723, + -0.5570592823399676, + 1.4743616877386294, + -0.7171640736861331, + 0.20098777565954465, + 0.2677046171271731, + 2.5516359902155408, + -0.4752797645451217, + 0.8235071239868901, + -0.06274129239437512, + -0.6671725645531117, + -0.5656502203844694, + 0.4981374432781232, + -1.7052945505852435, + 0.2850774357972623, + 1.419073672127401, + 1.5143278017369346, + 0.1203101077666772, + 0.2689397009780952, + -1.296892295720691, + 0.18009974321625744, + -0.8593183061725375, + 1.0955977195397684, + -0.8269759713539973, + -1.679818997288417, + 0.13591860618347618, + 0.5917908476437124, + -0.28742284714076166, + 0.6587220194553163, + 0.2709020500524623, + -1.149354106366614, + -0.27339532971884056, + 0.05755031834255014, + 0.8500577134471289, + 0.9478581977287968, 
+ -2.2564392962632804, + 0.3110659164171673, + 0.06194949970627075, + 1.2335858050609876, + -1.0639335147664135, + 1.1537781680086128, + 1.4596822474021245, + 1.5373855077150955, + 1.386220003799335, + 0.24116311123514478, + 1.6076780385623348, + 1.7935701786527096, + -0.13499022826985557, + -0.25067215098463164, + 1.1518687500322395, + -0.16649431697299996, + -0.3651179239279022, + 0.7065959337431358, + -1.567067835424215, + -0.3357557133568141, + 0.05654075352457806, + 1.4204192411699395, + 0.6755821599239246, + -0.9357700117370953, + 0.10049062711321136, + 0.6228199266745641, + -0.4495829648314594, + -0.914993910952663, + -0.14374993267103833, + -1.203808008890236, + 0.07285694107615862, + 2.0141519804280326, + -0.7930321767525238, + -0.477376582450768, + -0.7873047578273823, + -0.07195237845460918, + 0.8247512770846365, + 0.1908921136519597, + -1.4979151343522192, + -0.5774972459244567, + -0.27133546200730857, + 0.9813962548064747, + 0.7390035236473482, + 1.3917635799868133, + 0.6491852100764046, + 1.2039325588275476, + 0.9874814278993451, + 0.7542670407256892, + 0.5781475806410152, + 1.0836356031788037, + -0.6390685802515629, + -0.48981275707449656, + 1.2623218276299724, + -0.27964434459183907, + 0.55464291590441, + 0.6308280535150428, + -0.6187129493373275, + -1.81710056168182, + 0.12247935993756598, + 0.7788349259051482, + -0.17017167872263755, + 0.043301694413358134, + -2.089678394322878, + -0.9289772719574583, + -0.8074523880762338, + 0.39794575558975936, + 0.7354513501370189, + -1.236806983856194, + -0.35701990837483416, + -0.6836642340698702, + -0.17214669458374612, + -0.4073494437665049, + 0.7496549362718903, + 0.9152466532285749, + 1.292950920601979, + 0.09414818556627164, + -0.38366959466913275, + -1.2609592964755483, + 1.0775925750761937, + 0.6273242190434466, + 1.5303753088083207, + -1.3500142617068518, + 0.542614186666463, + -0.9181737147627111, + -0.06381394399314673, + 0.9978706959761301, + 1.622897663758754, + -0.16614605841589397, + 
0.34666368467190317, + -0.9591633938060772, + -1.6966150893683734, + -1.2814351976260474, + -0.26992650119900263, + 0.10762582501581724, + -1.2859520744264117, + 1.5023826600840733, + 0.05731628365819024, + -1.4611815421227203, + 0.11844701357721124, + -1.7097420738232683, + -0.6481962873525597, + -0.6199302841848613, + -0.041677237819211586, + -1.912936627272994, + -0.22069702937409055, + -0.8373282020867259, + -1.2102204195821435, + 1.1367489812535863, + 2.147097339967129, + -0.39588656758208457, + 0.028256106155769695, + 0.4780347016744464, + -0.7879795594263654, + -0.4819936051666644, + 0.2945839162465962, + -0.23035924547300898, + 0.3179810499974549, + 0.6885676803300341, + -1.3749595715279772, + 0.08511336387296158, + 1.2413785514414426, + -0.16773829248295685, + -0.6272882310444161, + 0.8067649829193638 + ], + "shape": [ + 4, + 10, + 5 + ] + }, + "unnormalized_heights": { + "data": [ + -0.20314173587596263, + 0.41098923499226014, + 1.0423699879551078, + 1.0111021595209633, + -0.4005606761098726, + -1.5400573604406471, + -0.27680484002022715, + 0.22713168745929768, + 0.6772607869261693, + 0.17859923880951847, + 0.8068577169597195, + -0.007188030468140175, + 1.498700405481595, + -0.03019556689050345, + 0.36384746366328413, + 0.5946658094550371, + 0.8488782882116401, + -1.4076687315200758, + -0.6554716217755595, + -0.031424129829246916, + 0.002119877636176426, + -0.9320608828189856, + -0.4836971414215709, + -0.6652501624993471, + 0.6929251584566181, + 0.09402273928088646, + 1.6129488803273395, + -0.6274502774485388, + -0.6354031061312012, + -0.11014269025953972, + -0.5707540620409073, + 0.2782765045840637, + -0.131168025357405, + -0.014621424093204098, + -0.9264218514606051, + 0.4745604420052093, + -0.8170580592759659, + 1.090316288383937, + -0.4324600332331524, + 1.335598774236891, + 0.19522143581405244, + 0.7739177723108309, + 1.5545069571528942, + -1.1665892611773085, + -1.3514084083199862, + 1.548288167937, + 0.1962794539558957, + -0.5133236613036435, + 
0.07380298739515019, + -0.6768899292657692, + 0.29925023340409146, + -0.296150919949663, + 1.797754234791071, + -0.2256181742270194, + 0.7611949933670529, + -0.4752871063125886, + 0.3315215039285197, + 1.029960808199997, + 0.6816100143946399, + 0.6027021077218985, + -1.0103642468106544, + 1.8110431275164358, + -1.242898069994524, + -0.018728729244763993, + 0.3435160724199721, + -1.4671607394596426, + -0.19175241722939693, + -0.6020537188199624, + -1.2173643223619721, + -1.3293671463788634, + 0.6536593231539053, + 1.3132412975399521, + -0.03123667988254042, + 0.6955788134794414, + 0.14788650580579352, + -0.7267532390860204, + 0.11041826907812055, + -0.9430074874135619, + -0.024493143878920487, + -0.48165316592257534, + 0.24917168366560993, + -1.2762292342683303, + 1.5849730687530486, + -0.7991803895400539, + -0.9017993731080081, + 0.9288465316923816, + 0.8731470189472956, + -0.5447919995948828, + 1.9920454844101763, + 0.8504258697467754, + 0.8239191817788585, + -0.7591698936771379, + 0.23182973839731683, + 0.19368409881631005, + -0.8829189545685341, + 1.1838674589073483, + 0.12891223912089034, + 2.177326020841736, + -0.29262366803238393, + -0.1269124796513161, + -1.1195878651441893, + -1.4858191479600487, + 0.2730005353024221, + 0.056387506785403815, + 0.23274595474814475, + -1.219750082862799, + 0.6164470190315824, + 1.313122245567395, + 1.231436269706581, + 0.7101787049919178, + -0.29432626362033404, + 0.09290205186381154, + 0.5124772757514626, + 0.5722528182048355, + 0.0069296369726892315, + 0.05265567712992814, + 0.024119499618272, + -0.1734020738525506, + 1.980986216379721, + -0.3625440782730144, + -0.46610241187034174, + -0.1434065696905791, + 0.7199388940089966, + 0.18661178764402744, + -1.6874596492942056, + 0.08390244767011035, + 0.010205327986369436, + -1.4490806809539707, + -0.33044907219753716, + 0.9425485271043043, + 1.0202696745617659, + 1.1809578134345797, + -0.4778501888006765, + 0.5058710736668895, + 0.26915516555170543, + -0.48408003859590065, + 
-1.2238040021173409, + 0.4365480813048655, + 0.9362614654567366, + -0.132580840204314, + -1.2632446956429335, + 0.7389689210845545, + 0.11501735450468505, + -0.48921141096064424, + -1.2546554715250677, + -1.3274229125145334, + -0.20570447804044192, + -0.14388808008330586, + -0.12220432019599553, + -0.01670896018294231, + -0.47610390457226875, + -0.2426945946706641, + 0.8173767150222387, + -1.481720319316519, + 0.3306231644074124, + -0.22060606606267064, + -0.9954168307888536, + 0.9803566659279874, + -0.2658724362020457, + 0.238746989349253, + -0.9029006394421663, + 0.199025684270468, + 2.6380786365181796, + 0.8742346952912949, + 0.8505308967944145, + 0.3572785103117973, + -0.8838416943783096, + -0.5830552203885363, + -0.39032216747009113, + 1.8547476587140206, + 1.3346481596559676, + -0.14668266051521825, + 0.29946965798993297, + 0.1038460098414191, + -0.09599439254200862, + -0.34832623473704294, + -0.6243273110799229, + -0.42153332375097163, + 0.1418089057088939, + 0.8622146933847438, + 1.452604901391212, + 0.6644577289945207, + 1.0157942265133257, + -0.7306000207552646, + 0.8387576576827396, + 0.6099936154303692, + 0.42155963063098567, + 0.9953749290985876, + -1.5066610314377438, + 0.43343311885236446, + -1.3121219556703199, + -0.7507398495238957, + -0.2772483620390129, + -0.15323533618009222, + 0.8545701717376761, + -0.952862648290038, + -0.7025528681323762, + 0.0627858851186502, + 0.011103897250099562, + 0.44470601078973404 + ], + "shape": [ + 4, + 10, + 5 + ] + }, + "unnormalized_derivatives": { + "data": [ + -0.5073466627133774, + 0.4640089663510305, + 0.42905396760192394, + -0.024940177091060257, + 0.2037444999220601, + 1.3825460700799999, + -0.18594279184343007, + -0.13199258968708819, + 0.95499847971227, + 0.36061593000733744, + 1.5299017471873564, + 0.1682232955566942, + -0.3435279013303228, + -1.7971615354714086, + 0.4693013003698667, + 1.493186440171208, + 0.12356938192583447, + -0.17905683036898526, + -0.6760364508503255, + 0.393978576087819, + 
0.64949409980517, + -1.1431268808798611, + 0.40178181976552013, + 1.5351166370632565, + 0.8854633803752585, + 0.8200277372952126, + 0.04080196720634493, + 1.7491823889925064, + -0.012391613968610236, + 1.1152018043934078, + 0.6179595164022554, + -0.2567371630523861, + -1.676382260579765, + -0.23626544563697244, + 2.779642620636317, + 1.3711024513262018, + 1.6243002171198435, + 0.023223437469212883, + 0.3333688295519079, + 1.776099742658068, + 0.23740050210541427, + -0.10972801634299746, + -0.4166779124774877, + -0.5503437577515686, + 0.7109974741886405, + -0.3132870217495473, + 0.20474610893599518, + 0.529929231024398, + 0.7392324154597758, + -1.289651880858446, + -0.5846732948053854, + -0.2975140552555919, + -0.3602379545435858, + -0.28291660184068396, + 0.16236109077792268, + 1.6561866571093768, + -0.08969148970529098, + -0.1403733323733132, + 1.653704353836856, + -0.06677018843436126, + 0.7967438370130172, + -1.0487567518729046, + 0.38249765783308975, + -0.09510169198820755, + -0.26937819568016647, + 1.0506749629051164, + -0.8217549133103285, + -1.2605463844968352, + 0.06630459711885793, + 1.4068597142149075, + -0.23906347041924664, + 0.40581500848228097, + 0.34824046380765017, + 1.4576160043118096, + 1.1435094589159382, + -1.1561650245007897, + -2.0501574100159003, + 1.972996262759935, + -1.4023249157404538, + -1.252355835467333, + -0.5939367844331447, + -0.8883536777451707, + 0.8826283214585245, + -0.08537770872725445, + 1.2430570323434795, + -0.8288361760454861, + 0.3740588912011981, + 0.6908593029167234, + -0.237603292141705, + -1.4491304746103755, + -1.1503461466656624, + -0.06074698383716112, + -0.7611869714619356, + -0.21606476673148656, + -0.34312322567272335, + 1.4977421535617197, + -1.0246406937201316, + -0.6374111111312827, + -0.6874324035898947, + 0.4555428253347865, + 0.747331595480188, + 0.746181920181807, + 0.3114390139045884, + -0.593601279371061, + -0.8244884478356891, + -0.2384891022367264, + -0.6674385398325516, + 1.1105282112735941, + 
-0.28770161892364415, + -0.48758783299086855, + -0.8176805756822754, + -0.4529194605131728, + 0.4266942129682067, + 0.959445148595098, + -1.2042407824096886, + 0.32657568443451324, + 0.9547547392627056, + 1.111198856760241, + 1.1033711088631535, + -0.9222774516745698, + 0.015288963161081063, + 0.692635853414707, + 1.6871393477277539, + 1.3793012042095691, + -0.8531034123811346, + -1.6106188284100986, + -0.2882501627532076, + -0.07348821377104808, + -0.252825439264895, + -0.8911729028578352, + -1.4964971728273821, + 0.45852900331123253, + -1.2090009788666547, + -0.15934430798101568, + 1.357487397797586, + -0.6324727023902199, + 1.2747860118984116, + 0.8169733493456708, + -1.1168848140267789, + -0.9667484660578313, + 1.211796851951881, + 1.426407289877944, + -0.3775828028349923, + 0.5675809579603023, + -0.003458751262658035, + 0.20991311225897344, + -2.1396527931651366, + -0.18661113504227145, + 0.425938031142508, + 0.9491730513198636, + -1.1358577850612366, + 0.4790736156730508, + -1.2943260486129644, + -0.8749887305231243, + 0.5211816427115885, + -0.8246528349739011, + 1.7195791979873167, + -0.3173219909708357, + 0.21110777465868505, + -1.4271230211989725 + ], + "shape": [ + 4, + 10, + 4 + ] + }, + "outputs": { + "data": [ + 1.9339028943779266, + 1.11421988558889, + -0.42197727509315885, + -0.476029405566483, + -1.2483046495149022, + 1.6004071514895317, + 0.5313728916219365, + 0.10384965984925965, + 2.456998918588678, + -2.2632805151911053, + -0.5857489536913694, + -2.2219616276614382, + 1.869369003838893, + -1.93153721655368, + 1.5859303738009816, + 1.5589315450576469, + 0.5017763041989596, + 1.5363818054493383, + -2.3176639527548177, + -1.4333335079417524, + -1.445733610842785, + 0.12056301759087473, + -2.574201538280003, + -0.55456591149102, + 0.10180164225306532, + -0.19378840390634317, + 0.9227754990410411, + 0.011018458865574399, + -2.684602298027738, + 0.8095865207597701, + 0.8266911303353263, + -2.2455603887630513, + -0.34474197332753287, + 
2.0481347447494245, + 0.4375288729644756, + 1.033164040845549, + 1.7872257701259382, + -0.07206412074167462, + -2.1356471141029694, + 0.48098264485491926 + ], + "shape": [ + 4, + 10 + ] + }, + "logabsdet": { + "data": [ + -1.4675210692989342, + -1.067310211976055, + 0.14419547484850925, + -1.9485519507529072, + -2.20764645599918, + -1.1681871821958048, + -0.5374838340935822, + -0.9097839152439218, + -3.3201129842532837, + -3.675045844551764, + -0.9144255442750324, + -0.3373297475889202, + -0.6997453365127888, + -1.2178046448117705, + -0.4445588126514184, + 0.9267799822797531, + -0.6963832281734724, + 0.5840714217957924, + 0.03488088162906966, + -0.805999892145376, + -1.1272514008047247, + 1.0451391736003843, + -1.5991455469573763, + -0.4562886598302148, + 1.2521448786399687, + -1.6697280722784094, + 0.019017996066540865, + 0.20786467131674569, + -2.580447328202009, + -1.0605820634029133, + 1.4089747749221635, + -1.3185885277804952, + -0.49619613672187163, + 0.018390055827107177, + -0.3367753785704455, + 1.224586394069803, + -2.354301504990093, + -2.935230110842744, + -1.0977673032699535, + -1.5464665560652113 + ], + "shape": [ + 4, + 10 + ] + } +} \ No newline at end of file diff --git a/test/data/nsf_test_data_f64.npz b/test/data/nsf_test_data_f64.npz deleted file mode 100644 index f538b56b08addbff6c31f8536d450515a719f5dc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7024 zcmd5>c{tVI);AXoB1I(`iZX;!8fZzHBuY}|WQa;KhD?zmDpQIgl_-@)_L+TPLG(fo>;y`!bo&;3K@mz}H@_ML3ZuUakqZ&s2ORaDp_E9xcs?+zaJ zta9;DnOblQwLYZ6RRy14s}1sr*1+?koUWzB3glfpd2+Q!AGnGru&Z*A!PFv)b#2ld zZWGc}cA1{VzF1JEmTnT+RwF2UtIh=aWjH9sO*-ztc z6kL6}>)I)tg^UIcYAur~7)2Ys>yK1&Rq%jIXnyG`sx_Gc%7Dc!`rsk|5R3jryuD5k=LmN@Pk~2a?oZR_v zf9wnyEf;A@Mi;TEMMWRFzce7cIWnR=i55siQjBb)-%-l zWqZK|+q+!3V((~Ye%aQ;%F^tbt);WgKOX|uKOVwA*8lt=$cpax_um4~x2oa0Xs+ zW-XE%$aqINue}_~QSjvRIjUYaoH$R_x$jCr2Xa-tm30B4ZNj39%?-G_ztD8Q9IM*V z)pUnbGfY@DzNuW?ss`n~X4nJHk+JTSu7T|L0g&u{Hz=x1M)4A#@>b{9sA$eAs^de% 
zy*rMV8Cz0d-#beR&)p^@NhK&$7}SA!zEZqf@DP}Fzh9!TjExYXMzRevAz}neC-I^u z6MPH;^V8aBP^KEZ;)n$i(k@+VH+|;=Gt6Y6E|&>dwBqtt?E^HRoPHE|cXSM-zs1=V zZt6h?wbT^Ghp*7q&{c7zqY|!796B1QPlX|)?{?STj^b^-!q#gtBuw^x8z*B(0iDVO zCH5uN_+0Fn<=Z|YYIP;OH@Wi>+5@t^Iu`Xo`-h2+>)o9oyT@dueOy0SUhCv;;{T43 z(|7yk&oR*WYgpO_&jCo0`0+Ju4-<;1?*qAWr=YpyV?m-!G6XV~aJx11;)@53F3s!u zF{U$oli#uu7Y72W&>Qh%+KmM&>f4iDMA3WZ4&QJeN2by_aHpo0W0`m{Q zyqj(6K(`=HDTdWJ98TBSr*bV7o(9JXOWBt~fg-->m7T=q?f$EAaX&~lu29v@Y)5rd z4~`pywV*&X5Q)7o2F*_=?!1hpW1gg1siRvFqMxzlBJwzFd_6H|SKEzOk7pFM8Zsa~ zpd{Hbb_`bCvkkFeOoPo0VY|1-DlvFVZz#Vw5v-=nA1(bb4bgFaG;ku}z&Dxp)0<{- zom)m&>y2jIrbFFX&6|#VU!+HtU73KJVM4youp2}-dtXaVhMc|eYwrWnDNLyfFm{MX#r?PxlnobdtdM=#ElP5ufvZHkY zHxtw*majkaVr5X}gy23ZaJ}P zGDN1n($H~4B6EYn1gxTpFl%J;@RiGIW7=2`IP|@IKP$|lM%K@_X6J9grJa6lA~7AX zUf9BQEolUpJxOBu787VUs~6*TDFf?7=NyFA<^U!VZEQU&@cl%zZHm#avvKip3a>LT+C^(ZG$O(wm|&1$v&38&l0k zQ1F;!IR}vpzXRd3E=QkVKq&Ml5ZYMTUakN(t3}@Q!RN>wDFjjm#LvvEy3{!jlzjp1D<$huhHGgZ_iPgmeIOi4onZhJ zl~H30=U_drUU|%$3V>6(bfM@D6!ohp^;__85%QOEHt8X-FV|%zDc3{zq7c<>z61~- zF}{&`kA~A#JsMs!y=cuc3GsX@Cow(biPqXi3VaJ%*|+O73HOwJXFIh( zxvCitZcb&=z>@y@Q0Wmw?_svsBkPzbs-&;Uz$W)F*~14Mg=IB>+AT&{VNC!R?Dw^&crMk^XjifOla(j zG8XEW6Rma;B{q<~fu3R;Bde2LHn+(svGlRhdDM(UwtI zA{pGhB=tKizrSAb{?KFeThnxDd3QNJ8(uY=SXK&OgWm)TDRjYn$?Eas=v+`sZ%nER zm_l2@o$@&z<1k2nI>bco2FUW$(I9j8XH2q+!?6Im6;-RlZMLYYAZLW(vj8h zsUBZIFHlK$-|1}ogeLN0o~sYi(Mw_R+tLsH&{nunUrIOv`3)W)F#X&LbJI4u4b(e0 zK1lc3s@@N;-7jqH)9Hb|8iY14W3^7TVJzgZ zzTdE$2#$LrUapSHK?ggtgoJ4lZn9rqiB)~z%fXv)Hn#?8!F3N?=p%3-Oj7roWC8T% zlUhiV9SA$Nt|s@;&@CWWrO}NDUSoVLB@UlqrK6SOTeWfEX5Z$SSUC+a?%H-Hv=n`w zX|Eak&V=6rmh2UmH~CyOa0n0+W&tJ+Y5W9fr_o^Bnu z_qxvmryrAePhayYB#j{Z$(dp5NC9wHe^ggH?+G-&N4%q(>%o=XEj01$5ES#0nl7id zV)%mg>gVEaM>JY09~_#Ls~cIi%*|E;kOG)T6!zuO^V58%^km zB;$U4FY&&f61b0Bcs@=K!Qgm7dNO|nFy<8VHraKdjOv@GV^_NI@CCEo_VNteai4O@ z;$;uEI|Zw3bMV6S|S*SKQ zJ01M|Jj?7#vmw#KpBTTQ3cWScqjuTUV?9%?o`w|H=!0t^Xf+q0)bl7d~$A9~S-ekCFw3jA9 z8GsGff`tMsXdrMmTr2zZFc?r|mw2lZ@v7!<%#QOE99r2%4Lp4rmnsY$EIxM&8S>k2 
z>rbb{TQUA+?`&&u(Q&u?5$dhbd*`;2Yt0PCyeKz~JlX)4hynZP+O4q0g@12Y)+9bi za@{TdeHwa=<41C=SqRk`533CBj6!g^gol;?EHH`PDY_Z$@OCKaskq=LbkLG~YiT!( z1Yhk}{L=jprc?N2P-F_ERAkutme%1?yN?dF0Tle|l>|aX)0i#b)SxHO2o(xBY!{zR zV)Vs)`@{urjXS@N+w0~qo=rLIESo=v>r>aLXs@e5wePM^*7ioG=k0AJF!Hfa`_9~JD-vFe-gqo&BLzH%f^9)_0wb0UZ_X+i#|D4a!JMuE;HcAb z3a^=h&+YP6uT&|J?p%C3ylNcC5r-dC1ya2^I28%==h$>M+p$OVMMOR=Dt}u)g_K3({Eo(QDr@Xh&?U3-~bu zLtYaK)8!?2#&5JqWJehk%58XeZ@dvaIs&TF9cz$p$Hx609!yv&z*_V?s|PJ4t{M|B zOkij1ca9fS1}c3ypu3p61HkxZD3pGIyfOumkMw#xYV!P@Mh+cRzmXLMpY>tRd!39( zml>E7cUEpR_LCy?)v;=wExF0~ZXB~jJKCP43O*ldIb4fm{R+z^o31=nNODN|0@!}1ViwPE`y zl*`C-wRu5>zKA8J_c*_S=8I+;$45j87TGk(vdU4MIa$yWa-9Ji%Y|5ZtYo{hGsfu zi$l(~gLo_5RglDhh|GX{pM7Rv+0H3HR)idydyF8$Gi-!p40<5c^kOL>5#fPyf4{j0?W1tWXAW`!C+9#nY$JgTyLWrK5%gkj2?VBn9Mzi4bw3! zTK*$w@^)&%yMF4A0-wvmwP)PM2ET)()bg!&mUlr+p+Lk3F$M~aOKp6^n+paeTtzE9I0<_=)DqQJG@y*+ zy`>&8bMVTae{WMQ72c7(gESsZLzPif^|p0vgyPg25m5s_&`;skYR+gTbd9|Hs@$Fj zk@v$UmsghJEA_nZVuNImV-aw>cb-M<%W(P0Z;lzbFJ#H|8`f5=9vsVIW$gy(6y|;B zydD%dy)uYtHVv<0S*@c&*a!_3-P2PKI0#QFK3_iZiIUA+i{k1R%rNWFnYogR>DwlF`qC%xu+klSZ|h}* zcyY-OqMkH7ev@?jeR>Z_txs<0%^8PwbB|J`>=}Idth#R_a~z*sw90R+i-goG&xQ@I zFCwhaF+4H9P-IUUSJdL$-i@xi$wt%aBbXy*R+;S7hAXbWlNoG%3b*6eJzsNw8VQNv z!p`nd(AK$4d`svQJjp(AGWSkBw#PnmbPlItNAlGw4Y_#?Db?Cc6CS|~g=>CQ_v>-$ zl#ch-mm^@tDsz-jM#d1^E`$2*?RZZsd7sXsGT@W>cw}6Nj=sKY2L>L{z%x)f`&1rt z!S`9&S;M^8aWO{Hf$; zN#&Okq2H8na{r;^UkWULYWYb=e^SezQ}{P6yet0D@&`HnQ`65-`0Fy3{HCdW70bT{ p#y