diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..e1fdfd3 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,51 @@ +name: CI +on: + schedule: + - cron: 0 0 * * * + pull_request: + branches: + - main + - develop + push: + branches: + - "**" + tags: "*" +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1.11" + - "1.12" + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + file: lcov.info \ No newline at end of file diff --git a/Project.toml b/Project.toml index efcacd0..baf82e1 100644 --- a/Project.toml +++ b/Project.toml @@ -5,30 +5,47 @@ authors = ["marcobonici "] [deps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" Optimisers 
= "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Bijectors = "0.15.17" -ConcreteStructs = "0.2.3" -Distributions = "0.25.123" -JSON = "1.4.0" -LinearAlgebra = "1.12.0" -Lux = "1.31.3" -MLUtils = "0.4.8" -NPZ = "0.4.3" -Optimisers = "0.4.7" -Random = "1.11.0" -Statistics = "1.11.1" -Test = "1.11.0" -Zygote = "0.7.10" +BenchmarkTools = "1.6.3" +Bijectors = "0.15" +ChainRulesCore = "1.26" +ConcreteStructs = "0.2" +DifferentiationInterface = "0.7.16" +Distributions = "0.25" +ForwardDiff = "1.3" +JSON = "0.21" +LinearAlgebra = "1.10" +Lux = "1.31" +MLUtils = "0.4" +Mooncake = "0.5.23" +NNlib = "0.9" +NPZ = "0.4" +Optimisers = "0.4" +Random = "1.10" +Statistics = "1.10" +Test = "1.10" +Zygote = "0.7" + +[extras] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test", "DifferentiationInterface", "Mooncake", "BenchmarkTools"] diff --git a/README.md b/README.md index 771b521..a50a463 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,10 @@ distributions inside [`Turing.jl`](https://turinglang.org/) probabilistic progra ## Features -- **RealNVP** (Real-valued Non-Volume Preserving) architecture with affine coupling layers +- **Architectures**: + - **RealNVP** (Real-valued Non-Volume Preserving) + - **NSF** (Neural Spline Flow with Rational Quadratic Splines) + - **MAF** (Masked Autoregressive Flow with MADE) - Full `Distributions.jl` interface: `logpdf`, `rand`, `length` - `Bijectors.jl` compatible โ†’ plug directly into Turing models as a prior - Training via `Optimisers.Adam` + `Zygote.jl` autodiff, mini-batched with `MLUtils` @@ -15,10 +18,10 @@ distributions inside 
[`Turing.jl`](https://turinglang.org/) probabilistic progra ## Quick Start ```julia -using SimpleFlows, Distributions +using SimpleFlows, Distributions, LinearAlgebra -# 1. Build a 4-dim flow -flow = FlowDistribution(; n_transforms=6, dist_dims=4, hidden_dims=64, n_layers=3) +# 1. Build a 4-dim flow (options: :RealNVP, :NSF, :MAF) +flow = FlowDistribution(Float32; architecture=:RealNVP, n_transforms=6, dist_dims=4, hidden_layer_sizes=[64, 64, 64]) # 2. Sample training data from your target distribution data = Float32.(rand(MvNormal(zeros(4), I), 10_000)) @@ -60,8 +63,8 @@ my_flow/ | Architecture | Status | |---|---| | RealNVP | โœ… Done | -| MAF | ๐Ÿ“‹ Planned | -| NSF | ๐Ÿ“‹ Planned | +| MAF | โœ… Done | +| NSF | โœ… Done | ## Running Tests @@ -69,8 +72,10 @@ my_flow/ julia --project=. -e "using Pkg; Pkg.test()" ``` -## Running the Example +## Running the Examples ```bash -julia --project=. examples/train_multinormal.jl +julia --project=. examples/train_multinormal_nvp.jl +julia --project=. examples/train_multinormal_nsf.jl +julia --project=.
examples/train_multinormal_maf.jl ``` diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 0aabf77..46edf98 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.12.1" manifest_format = "2.0" -project_hash = "64c6eed7b4f2bf47e0f4350e25ac6a3f80fb5a0d" +project_hash = "3d2728b3089e56da155cbe07b6f794382965b85e" [[deps.ADTypes]] git-tree-sha1 = "f7304359109c768cf32dc5fa2d371565bb63b68a" @@ -1753,10 +1753,8 @@ uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" version = "0.1.0" [[deps.SimpleFlows]] -deps = ["Bijectors", "ConcreteStructs", "Distributions", "JSON", "LinearAlgebra", "Lux", "MLUtils", "NPZ", "Optimisers", "Random", "Statistics", "Test", "Zygote"] -git-tree-sha1 = "9ee8f18c24029fe3ad3362abade4e005e283e4d1" -repo-rev = "main" -repo-url = "https://github.com/CosmologicalEmulators/SimpleFlows.jl" +deps = ["Bijectors", "ChainRulesCore", "ConcreteStructs", "Distributions", "ForwardDiff", "JSON", "LinearAlgebra", "Lux", "MLUtils", "NNlib", "NPZ", "Optimisers", "Random", "Statistics", "Test", "Zygote"] +path = ".." 
uuid = "7aff1418-a6e2-48c2-ba03-ae8e32e98757" version = "0.1.0" diff --git a/examples/Project.toml b/examples/Project.toml index 5225e9e..7618a38 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,5 +2,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" SimpleFlows = "7aff1418-a6e2-48c2-ba03-ae8e32e98757" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/examples/train_multinormal_maf.jl b/examples/train_multinormal_maf.jl new file mode 100644 index 0000000..5e6f85d --- /dev/null +++ b/examples/train_multinormal_maf.jl @@ -0,0 +1,111 @@ +# examples/train_multinormal_maf.jl +using SimpleFlows +using Random +using Statistics +using Distributions, LinearAlgebra, Random, Statistics, Printf +using Turing +using Optimisers + +function run_maf_example() + # โ”€โ”€ 1. Setup Data โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + rng = Random.default_rng() + Random.seed!(rng, 123) + + dist_dims = 4 + # Create a correlated 4D Gaussian + ฮผ_true = [1.0, -0.5, 2.0, 0.3] + ฮฃ_true = [1.0 0.5 0.2 0.1; + 0.5 1.0 0.4 0.2; + 0.2 0.4 1.0 0.5; + 0.1 0.2 0.5 1.0] + + true_dist = MvNormal(ฮผ_true, ฮฃ_true) + n_samples = 5000 + target_data = rand(rng, true_dist, n_samples) + + println("Target mean: ", round.(mean(target_data, dims=2)[:], digits=3)) + println("Target std: ", round.(std(target_data, dims=2)[:], digits=3)) + println() + + # โ”€โ”€ 2. 
Initialize MAF Flow โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # MAF is generally more expressive than RealNVP for same number of layers + n_transforms = 4 + hidden_layer_sizes = [32, 32] + + println("Training MAF ($n_transforms transforms, $(hidden_layer_sizes) units, 500 epochs)โ€ฆ") + + dist = FlowDistribution(Float32; + architecture=:MAF, + n_transforms=n_transforms, + dist_dims=dist_dims, + hidden_layer_sizes=hidden_layer_sizes, + rng=rng + ) + + # โ”€โ”€ 3. Train โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # We use a smaller learning rate for MAF as it can be more sensitive + train_flow!(dist, target_data; n_epochs=500, batch_size=200, opt=Optimisers.Adam(1e-3)) + + # โ”€โ”€ 4. Evaluate Density โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + test_data = rand(rng, true_dist, 1000) + lp_true = logpdf(true_dist, test_data) + lp_flow = logpdf(dist, test_data) + + println("\nโ”€โ”€ Density Fit โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€") + println("Mean log-pdf (true distribution): ", round(mean(lp_true), digits=4)) + println("Mean log-pdf (trained MAF flow): ", round(mean(lp_flow), digits=4)) + println("Difference: ", round(abs(mean(lp_true) - mean(lp_flow)), digits=4)) + + samples = rand(rng, dist, 5000) + println("\nFlow sample mean: ", round.(mean(samples, dims=2)[:], digits=3)) + println("Flow sample std: ", round.(std(samples, dims=2)[:], digits=3)) + + # โ”€โ”€ 5. 
Serialization โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + save_path = joinpath(@__DIR__, "../trained_flows/maf_mvn_4d") + save_trained_flow(save_path, dist) + + # Reload and verify + dist_reloaded = load_trained_flow(save_path) + lp_reloaded = logpdf(dist_reloaded, test_data) + println("Mean log-pdf after reload: ", round(mean(lp_reloaded), digits=4)) + if isapprox(mean(lp_flow), mean(lp_reloaded), atol=1e-5) + println("Round-trip OK โœ“") + else + println("Round-trip FAILED โœ—") + end + + # โ”€โ”€ 6. Turing.jl Integration โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + println("\nโ”€โ”€ Turing Inference โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€") + # Observed data: one sample from ฮธ[1] with noise + ฮธ_true = ฮผ_true + y_obs = ฮธ_true[1] + 0.1 * randn(rng) + println("Observed y: ", round(y_obs, digits=4), " (true ฮธ[1] = ", ฮธ_true[1], ")") + + @model function linear_model(y, prior_dist) + ฮธ ~ prior_dist + y ~ Normal(ฮธ[1], 0.1) + end + + # Sampling with exact prior + println("\nSampling with exact MvNormal prior (4 chains ร— 1000 samples)โ€ฆ") + chain_exact = sample(linear_model(y_obs, true_dist), HMC(0.1, 10), MCMCThreads(), 1000, 4; progress=false) + + # Sampling with MAF prior + println("Sampling with trained MAF prior (4 chains ร— 1000 samples)โ€ฆ") + chain_maf = sample(linear_model(y_obs, dist), HMC(0.1, 10), MCMCThreads(), 1000, 4; progress=false) + + println("\nโ”€โ”€ Posterior comparison (ฮธ[1]) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€") + exact_ฮธ1 = vec(chain_exact[:, "ฮธ[1]", :]) + maf_ฮธ1 = vec(chain_maf[:, "ฮธ[1]", :]) + + @printf " True ฮธ[1]: %.4f\n" ฮธ_true[1] + @printf " Observed y: %.4f\n" y_obs + @printf " 
Posterior mean (exact): %.4f ยฑ %.4f\n" mean(exact_ฮธ1) std(exact_ฮธ1) + @printf " Posterior mean (MAF): %.4f ยฑ %.4f\n" mean(maf_ฮธ1) std(maf_ฮธ1) + + println("\nโœจ End of script: Turing MCMC sampling with MAF finished successfully!") +end + +if abspath(PROGRAM_FILE) == joinpath(@__DIR__, "train_multinormal_maf.jl") + run_maf_example() +end diff --git a/examples/train_multinormal_nsf.jl b/examples/train_multinormal_nsf.jl new file mode 100644 index 0000000..6c31a90 --- /dev/null +++ b/examples/train_multinormal_nsf.jl @@ -0,0 +1,130 @@ +""" +Train a Neural Spline Flow (NSF) normalizing flow to approximate a 4-dimensional correlated Gaussian, +then use Turing.jl to run NUTS inference under both the exact prior and the NF prior. + +Usage: + julia --project=. -t 4 examples/train_multinormal_nsf.jl + +Note: Turing.jl must be installed in your global environment or activated environment. +""" + +using SimpleFlows +using Distributions, LinearAlgebra, Random, Statistics, Printf +using Turing + +rng = Random.MersenneTwister(42) + +# โ”€โ”€ 1. Define a non-trivial 4D target distribution โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +ฮผ = [1.0, -0.5, 2.0, 0.3] +ฮฃ = [1.00 0.50 0.10 0.00; + 0.50 1.00 0.30 0.20; + 0.10 0.30 1.00 0.40; + 0.00 0.20 0.40 1.00] +target = MvNormal(ฮผ, ฮฃ) + +# โ”€โ”€ 2. Draw training data โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +N_train = 10_000 +data = Float64.(rand(rng, target, N_train)) # 4 ร— 10_000 + +println("Target mean: ", round.(mean(data; dims=2)[:]; digits=3)) +println("Target std: ", round.(std(data; dims=2)[:]; digits=3)) + +# โ”€โ”€ 3. 
Build and train the flow โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +# Note: NSF typically requires fewer transforms but slightly larger hidden layers +# or just more expressive bins (K). +flow = FlowDistribution(Float64; + architecture = :NSF, + n_transforms = 4, + dist_dims = 4, + hidden_layer_sizes = [32, 32], + K = 8, + tail_bound = 3.0, + rng, +) + +println("\nTraining NSF (4 transforms, K=8, [32, 32] hidden units, 500 epochs)โ€ฆ") +train_flow!(flow, data; + n_epochs = 500, + lr = 1e-3, + batch_size = 512, + verbose = true, +) + +# โ”€โ”€ 4. Evaluate fit โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +N_test = 5_000 +x_test = Float64.(rand(rng, target, N_test)) + +lp_true = mean(logpdf(target, Float64.(x_test))) +lp_flow = mean(logpdf(flow, x_test)) + +println("\nโ”€โ”€ Density Fit โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€") +println("Mean log-pdf (true distribution): ", round(lp_true; digits=4)) +println("Mean log-pdf (trained NSF flow): ", round(Float64(lp_flow); digits=4)) +println("Difference: ", round(abs(lp_true - lp_flow); digits=4)) + +samples = Distributions.rand(rng, flow, N_test) # 4 ร— N_test +println("\nFlow sample mean: ", round.(mean(samples; dims=2)[:]; digits=3)) +println("Flow sample std: ", round.(std(samples; dims=2)[:]; digits=3)) + +# โ”€โ”€ 5. Save the trained flow โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +save_dir = joinpath(@__DIR__, "..", "trained_flows", "nsf_mvn_4d") +save_trained_flow(save_dir, flow) +println("\nFlow saved to $save_dir") + +# โ”€โ”€ 6. 
Reload and verify round-trip โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +flow2 = load_trained_flow(save_dir; rng) +lp_reloaded = mean(logpdf(flow2, x_test)) +println("Mean log-pdf after reload: ", round(Float64(lp_reloaded); digits=4)) +@assert isapprox(lp_reloaded, lp_flow; atol=1f-3) "Round-trip failed!" # Base.@assert has no @test-style trailing kwargs: a bare `atol=...` would be taken as the message and the tolerance ignored +println("Round-trip OK โœ“") + +# โ”€โ”€ 7. Turing.jl inference demo โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +println("\nโ”€โ”€ Turing Inference โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€") + +ฮธ_true = [1.0, -0.5, 2.0, 0.3] +y_obs = ฮธ_true[1] + 0.5 * randn(rng) +println("Observed y: ", round(y_obs; digits=4), " (true ฮธ[1] = $(ฮธ_true[1]))") + +@model function inference_model(y_obs, prior) + ฮธ ~ prior + y_obs ~ Normal(ฮธ[1], 0.5) +end + +n_samples = 1000 +n_chains = 4 + +# โ”€โ”€ 7a. Exact MvNormal prior โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +println("\nSampling with exact MvNormal prior ($n_chains chains ร— $n_samples samples)โ€ฆ") +chain_exact = sample( + inference_model(y_obs, target), + NUTS(), MCMCThreads(), n_samples, n_chains; + progress = false, +) + +# โ”€โ”€ 7b. Trained NSF prior โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +println("Sampling with trained NSF prior ($n_chains chains ร— $n_samples samples)โ€ฆ") +chain_nf = sample( + inference_model(y_obs, flow2), + NUTS(), MCMCThreads(), n_samples, n_chains; + progress = false, +) + +# โ”€โ”€ 8.
Compare posteriors โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +println("\nโ”€โ”€ Posterior comparison (ฮธ[1]) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€") +exact_ฮธ1 = vec(chain_exact[:, "ฮธ[1]", :]) +nf_ฮธ1 = vec(chain_nf[:, "ฮธ[1]", :]) + +@printf " True ฮธ[1]: %.4f\n" ฮธ_true[1] +@printf " Observed y: %.4f\n" y_obs +@printf " Posterior mean (exact): %.4f ยฑ %.4f\n" mean(exact_ฮธ1) std(exact_ฮธ1) +@printf " Posterior mean (NSF): %.4f ยฑ %.4f\n" mean(nf_ฮธ1) std(nf_ฮธ1) + +println("\nโœจ End of script: Turing MCMC sampling with NSF finished successfully!") diff --git a/examples/train_multinormal.jl b/examples/train_multinormal_nvp.jl similarity index 100% rename from examples/train_multinormal.jl rename to examples/train_multinormal_nvp.jl diff --git a/examples/train_output.txt b/examples/train_output.txt deleted file mode 100644 index c76430e..0000000 --- a/examples/train_output.txt +++ /dev/null @@ -1,71 +0,0 @@ -Precompiling packages... -Info Given FlowMarginals was explicitly requested, output will be shown live  -ERROR: LoadError: UndefVarError: `linked_vec_length` not defined in `Bijectors` -Suggestion: check for spelling errors or missing imports. 
-Stacktrace: - [1] getproperty(x::Module, f::Symbol) - @ Base ./Base_compiler.jl:47 - [2] top-level scope - @ ~/Desktop/work/FlowMarginals.jl/src/distribution.jl:95 - [3] include(mapexpr::Function, mod::Module, _path::String) - @ Base ./Base.jl:307 - [4] top-level scope - @ ~/Desktop/work/FlowMarginals.jl/src/FlowMarginals.jl:12 - [5] include(mod::Module, _path::String) - @ Base ./Base.jl:306 - [6] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::Nothing) - @ Base ./loading.jl:2996 - [7] top-level scope - @ stdin:5 - [8] eval(m::Module, e::Any) - @ Core ./boot.jl:489 - [9] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) - @ Base ./loading.jl:2842 - [10] include_string - @ ./loading.jl:2852 [inlined] - [11] exec_options(opts::Base.JLOptions) - @ Base ./client.jl:315 - [12] _start() - @ Base ./client.jl:550 -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/src/distribution.jl:95 -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/src/FlowMarginals.jl:1 -in expression starting at stdin:5 - โœ— FlowMarginals - 0 dependencies successfully precompiled in 5 seconds. 254 already precompiled. - -ERROR: LoadError: The following 1 direct dependency failed to precompile: - -FlowMarginals - -Failed to precompile FlowMarginals [d4f3a2b1-9e8c-4d7f-b5a6-2c1e0f9d8e7a] to "/home/marcobonici/.julia/compiled/v1.12/FlowMarginals/jl_lWeZRJ". -ERROR: LoadError: UndefVarError: `linked_vec_length` not defined in `Bijectors` -Suggestion: check for spelling errors or missing imports. 
-Stacktrace: - [1] getproperty(x::Module, f::Symbol) - @ Base ./Base_compiler.jl:47 - [2] top-level scope - @ ~/Desktop/work/FlowMarginals.jl/src/distribution.jl:95 - [3] include(mapexpr::Function, mod::Module, _path::String) - @ Base ./Base.jl:307 - [4] top-level scope - @ ~/Desktop/work/FlowMarginals.jl/src/FlowMarginals.jl:12 - [5] include(mod::Module, _path::String) - @ Base ./Base.jl:306 - [6] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::Nothing) - @ Base ./loading.jl:2996 - [7] top-level scope - @ stdin:5 - [8] eval(m::Module, e::Any) - @ Core ./boot.jl:489 - [9] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) - @ Base ./loading.jl:2842 - [10] include_string - @ ./loading.jl:2852 [inlined] - [11] exec_options(opts::Base.JLOptions) - @ Base ./client.jl:315 - [12] _start() - @ Base ./client.jl:550 -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/src/distribution.jl:95 -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/src/FlowMarginals.jl:1 -in expression starting at stdin: -in expression starting at /home/marcobonici/Desktop/work/FlowMarginals.jl/examples/train_multinormal.jl:11 diff --git a/src/SimpleFlows.jl b/src/SimpleFlows.jl index 034d494..c7b9850 100644 --- a/src/SimpleFlows.jl +++ b/src/SimpleFlows.jl @@ -5,16 +5,23 @@ using Distributions using Bijectors using JSON, NPZ using Optimisers, Zygote +using ChainRulesCore include("layers.jl") include("realnvp.jl") include("normalizer.jl") +include("splines.jl") +include("nsf.jl") +include("made.jl") +include("maf.jl") +include("generic_ops.jl") include("distribution.jl") include("training.jl") include("io.jl") -export RealNVP, FlowDistribution +export RealNVP, NeuralSplineFlow, MaskedAutoregressiveFlow, FlowDistribution export MinMaxNormalizer -export train_flow!, 
save_trained_flow, load_trained_flow +export train_flow!, save_trained_flow, load_trained_flow, normalize, denormalize +export unconstrained_rational_quadratic_spline end diff --git a/src/distribution.jl b/src/distribution.jl index 5c1d648..5187731 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -11,8 +11,8 @@ A trained normalizing flow wrapped as a `Distributions.jl` - `n_dims`, `hidden_dims`, `n_layers`: architecture metadata for serialization - `normalizer`: fitted `MinMaxNormalizer` (always present after training) """ -mutable struct FlowDistribution{T<:Real} <: ContinuousMultivariateDistribution - model :: RealNVP +mutable struct FlowDistribution{T<:Real, M<:AbstractLuxLayer} <: ContinuousMultivariateDistribution + model :: M ps st n_dims :: Int @@ -21,31 +21,39 @@ mutable struct FlowDistribution{T<:Real} <: ContinuousMultivariateDistribution end """ - FlowDistribution([Type=Float32]; n_transforms, dist_dims, hidden_layer_sizes, - hidden_dims=64, n_layers=3, - activation=gelu, rng=Random.default_rng()) + FlowDistribution([Type=Float32]; architecture=:RealNVP, n_transforms, dist_dims, + hidden_layer_sizes, hidden_dims=64, n_layers=3, + activation=gelu, rng=Random.default_rng(), K=8, tail_bound=3.0) Construct and randomly initialise a `FlowDistribution`. - -`hidden_layer_sizes` sets the width of each hidden layer independently. -For convenience, you may pass `hidden_dims` and `n_layers` instead, which will -expand to `fill(hidden_dims, n_layers)`. -The `normalizer` field is `nothing` until `train_flow!` is called. +`architecture` can be `:RealNVP`, `:NSF`, or `:MAF`.
""" function FlowDistribution(::Type{T}=Float32; + architecture=:RealNVP, n_transforms::Int, dist_dims::Int, hidden_layer_sizes::Vector{Int}=Int[], hidden_dims::Int=64, n_layers::Int=3, activation=gelu, + K=8, tail_bound=3.0, rng::AbstractRNG=Random.default_rng()) where {T<:Real} # If no vector given, fall back to the scalar convenience args if isempty(hidden_layer_sizes) hidden_layer_sizes = fill(hidden_dims, n_layers) end - model = RealNVP(; n_transforms, dist_dims, hidden_layer_sizes, activation) + + model = if architecture == :RealNVP + RealNVP(; n_transforms, dist_dims, hidden_layer_sizes, activation) + elseif architecture == :NSF + NeuralSplineFlow(; n_transforms, dist_dims, hidden_layer_sizes, K, tail_bound, activation) + elseif architecture == :MAF + MaskedAutoregressiveFlow(; n_transforms, dist_dims, hidden_layer_sizes, activation) + else + error("Unknown architecture: $architecture. Supported: :RealNVP, :NSF, :MAF") + end + ps, st = Lux.setup(rng, model) ps = Lux.fmap(x -> x isa AbstractArray ? 
T.(x) : x, ps) - return FlowDistribution{T}(model, ps, st, dist_dims, hidden_layer_sizes, nothing) + return FlowDistribution{T, typeof(model)}(model, ps, st, dist_dims, hidden_layer_sizes, nothing) end # โ”€โ”€ Distributions.jl interface โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @@ -54,12 +62,12 @@ Distributions.length(d::FlowDistribution) = d.n_dims function _apply_normalizer(d::FlowDistribution{T}, x::AbstractMatrix{<:Real}) where {T} isnothing(d.normalizer) && return x, zero(T) - return normalize(d.normalizer, x), d.normalizer.log_jac + return SimpleFlows.normalize(d.normalizer, x), d.normalizer.log_jac end function _apply_normalizer(d::FlowDistribution{T}, x::AbstractVector{<:Real}) where {T} isnothing(d.normalizer) && return x, zero(T) - return normalize(d.normalizer, x), d.normalizer.log_jac + return SimpleFlows.normalize(d.normalizer, x), d.normalizer.log_jac end function Distributions.logpdf(d::FlowDistribution, x::AbstractVector{<:Real}) @@ -77,14 +85,14 @@ end function Base.rand(rng::AbstractRNG, d::FlowDistribution{T}) where {T} z = draw_samples(rng, T, d.model, d.ps, d.st, 1) - x = isnothing(d.normalizer) ? z : denormalize(d.normalizer, z) + x = isnothing(d.normalizer) ? z : SimpleFlows.denormalize(d.normalizer, z) return T.(vec(x)) end function Distributions.rand(rng::AbstractRNG, d::FlowDistribution{T}, n::Int) where {T} z = draw_samples(rng, T, d.model, d.ps, d.st, n) isnothing(d.normalizer) && return z - return denormalize(d.normalizer, z) + return SimpleFlows.denormalize(d.normalizer, z) end # โ”€โ”€ Bijectors.jl interface โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @@ -93,10 +101,12 @@ end # therefore no parameter transformation is required for HMC/NUTS. 
Bijectors.bijector(::FlowDistribution) = identity -# Explicitly implement the VectorBijectors interface -Bijectors.VectorBijectors.vec_length(d::FlowDistribution) = d.n_dims -Bijectors.VectorBijectors.linked_vec_length(d::FlowDistribution) = d.n_dims -Bijectors.VectorBijectors.to_vec(d::FlowDistribution) = Base.identity -Bijectors.VectorBijectors.from_vec(d::FlowDistribution) = Base.identity -Bijectors.VectorBijectors.to_linked_vec(d::FlowDistribution) = Base.identity -Bijectors.VectorBijectors.from_linked_vec(d::FlowDistribution) = Base.identity +if isdefined(Bijectors, :VectorBijectors) + # Explicitly implement the VectorBijectors interface + Bijectors.VectorBijectors.vec_length(d::FlowDistribution) = d.n_dims + Bijectors.VectorBijectors.linked_vec_length(d::FlowDistribution) = d.n_dims + Bijectors.VectorBijectors.to_vec(d::FlowDistribution) = Base.identity + Bijectors.VectorBijectors.from_vec(d::FlowDistribution) = Base.identity + Bijectors.VectorBijectors.to_linked_vec(d::FlowDistribution) = Base.identity + Bijectors.VectorBijectors.from_linked_vec(d::FlowDistribution) = Base.identity +end diff --git a/src/generic_ops.jl b/src/generic_ops.jl new file mode 100644 index 0000000..097ea63 --- /dev/null +++ b/src/generic_ops.jl @@ -0,0 +1,84 @@ +# src/generic_ops.jl + +""" + log_prob(model, ps, st, x) -> Vector + +Compute per-sample log-probability of `x` (shape: dist_dims ร— batch) under the flow. +Pure functional โ€” no mutations, safe for Zygote. +Supports RealNVP, NeuralSplineFlow, and MaskedAutoregressiveFlow. +""" +function log_prob(model::Union{RealNVP, NeuralSplineFlow, MaskedAutoregressiveFlow}, ps, st, x::AbstractMatrix) + lp = nothing + for i in model.n_transforms:-1:1 + ks = model isa MaskedAutoregressiveFlow ? 
keys(model.mades) : keys(model.conditioners) + k = ks[i] + + if model isa MaskedAutoregressiveFlow + bj = MAFBijector(model.mades[k], ps.mades[k], st.mades[k]) + # Inverse is x -> u, which is what we need for log_prob + x, ld = Bijectors.with_logabsdet_jacobian(bj, x) + else + mask = st.mask_list[i] + cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], + s = st.conditioners[k] + x_cond -> Lux.apply(m, x_cond, p, s)[1] + end + + bj = if model isa RealNVP + MaskedCoupling(mask, cond_fn, AffineBijector) + else + MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, model.K, model.tail_bound)) + end + x, ld = inverse_and_log_det(bj, x) + end + + lp = isnothing(lp) ? ld : lp .+ ld + + # Apply ReversePermute between MAF blocks + if model isa MaskedAutoregressiveFlow && i > 1 + x = x[end:-1:1, :] + end + end + base_lp = dsum(gaussian_logpdf.(x); dims=(1,)) + return isnothing(lp) ? base_lp : lp .+ base_lp +end + +""" + draw_samples(rng, T, model, ps, st, n_samples) -> Matrix + +Sample from the flow by pushing Gaussian noise through the forward transforms. +Supports RealNVP, NeuralSplineFlow, and MaskedAutoregressiveFlow. +""" +function draw_samples(rng::AbstractRNG, ::Type{T}, model::Union{RealNVP, NeuralSplineFlow, MaskedAutoregressiveFlow}, + ps, st, n_samples::Int) where T + x = randn(rng, T, model.dist_dims, n_samples) + for i in 1:(model.n_transforms) + ks = model isa MaskedAutoregressiveFlow ? 
keys(model.mades) : keys(model.conditioners) + k = ks[i] + + if model isa MaskedAutoregressiveFlow + bj = MAFBijector(model.mades[k], ps.mades[k], st.mades[k]) + x, _ = forward_and_log_det(bj, x) + else + mask = st.mask_list[i] + cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], + s = st.conditioners[k] + x_cond -> Lux.apply(m, x_cond, p, s)[1] + end + + bj = if model isa RealNVP + MaskedCoupling(mask, cond_fn, AffineBijector) + else + MaskedCoupling(mask, cond_fn, p -> NSFCouplingBijector_from_flat(p, model.K, model.tail_bound)) + end + + x, _ = forward_and_log_det(bj, x) + end + + # Apply ReversePermute between MAF blocks + if model isa MaskedAutoregressiveFlow && i < model.n_transforms + x = x[end:-1:1, :] + end + end + return x +end diff --git a/src/io.jl b/src/io.jl index 830930e..33a0d90 100644 --- a/src/io.jl +++ b/src/io.jl @@ -34,30 +34,63 @@ end # โ”€โ”€ Architecture dict โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ function _flow_to_dict(flow::FlowDistribution) - return Dict( - "architecture" => "RealNVP", + arch = if flow.model isa RealNVP + "RealNVP" + elseif flow.model isa NeuralSplineFlow + "NSF" + elseif flow.model isa MaskedAutoregressiveFlow + "MAF" + else + error("Unknown architecture type: $(typeof(flow.model))") + end + + d = Dict( + "architecture" => arch, "n_transforms" => flow.model.n_transforms, "dist_dims" => flow.n_dims, "hidden_layer_sizes" => flow.hidden_layer_sizes, "activation" => "gelu", ) + if flow.model isa NeuralSplineFlow + d["K"] = flow.model.K + d["tail_bound"] = flow.model.tail_bound + end + return d end function _build_flow_from_dict(d::AbstractDict, ::Type{T}=Float32, rng::AbstractRNG=Random.default_rng()) where {T<:Real} arch = d["architecture"] - arch == "RealNVP" || error("Unknown architecture: $arch") + arch_sym = if arch == "RealNVP" + :RealNVP + elseif arch == "NSF" + :NSF + 
elseif arch == "MAF" + :MAF + else + error("Unknown architecture: $arch") + end + # Support both new (hidden_layer_sizes) and legacy (hidden_dims + n_layers) formats hidden_layer_sizes = if haskey(d, "hidden_layer_sizes") Int.(d["hidden_layer_sizes"]) else fill(Int(d["hidden_dims"]), Int(d["n_layers"])) end - return FlowDistribution(T; - n_transforms = Int(d["n_transforms"]), - dist_dims = Int(d["dist_dims"]), - hidden_layer_sizes = hidden_layer_sizes, - rng, + + kwargs = Dict{Symbol, Any}( + :architecture => arch_sym, + :n_transforms => Int(d["n_transforms"]), + :dist_dims => Int(d["dist_dims"]), + :hidden_layer_sizes => hidden_layer_sizes, + :rng => rng, ) + + if arch == "NSF" + kwargs[:K] = Int(d["K"]) + kwargs[:tail_bound] = Float64(d["tail_bound"]) + end + + return FlowDistribution(T; kwargs...) end # โ”€โ”€ Public API โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @@ -118,7 +151,7 @@ function load_trained_flow(path::String; rng::AbstractRNG=Random.default_rng()) xmin = flat["normalizer_xmin"] xmax = flat["normalizer_xmax"] T_norm = eltype(xmin) - flow.normalizer = MinMaxNormalizer(xmin, xmax, sum(-log.(xmax .- xmin))) + flow.normalizer = MinMaxNormalizer{T_norm}(xmin, xmax, T_norm(sum(-log.(xmax .- xmin)))) delete!(flat, "normalizer_xmin") delete!(flat, "normalizer_xmax") end diff --git a/src/layers.jl b/src/layers.jl index 605052d..8e25d59 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -22,7 +22,7 @@ end function AffineBijector(params::AbstractArray) n = size(params, 1) รท 2 idx = ntuple(Returns(Colon()), ndims(params) - 1) - return AffineBijector(params[1:n, idx...], params[(n + 1):end, idx...]) + return @views AffineBijector(params[1:n, idx...], params[(n + 1):end, idx...]) end function forward_and_log_det(b::AffineBijector, x::AbstractArray) @@ -49,19 +49,49 @@ masked dimensions are transformed. 
bijector_constructor end -function _apply_mask(bj::MaskedCoupling, x::AbstractArray, transform_fn) - x_cond = x .* .!bj.mask # pass-through dims โ†’ conditioner input - params = bj.conditioner(x_cond) - y, log_det = transform_fn(params) - log_det = log_det .* bj.mask # only masked dims contribute to log|J| - y = ifelse.(bj.mask, y, x) # pass through unmasked dims unchanged - return y, dsum(log_det; dims=Tuple(1:(ndims(x) - 1))) +function _apply_mask(bj::MaskedCoupling, x::AbstractMatrix, transform_fn) + D, N = size(x) + m = bj.mask + + # 1. Conditioning + x_cond = x .* .!m + params = bj.conditioner(x_cond) + + # 2. Transform the active dims only + x_tr = x[m, :] + bj_inner = bj.bijector_constructor(params) + y_tr, ld_tr = transform_fn(bj_inner, x_tr) + + # 3. Reconstruct full y + y = _reconstruct(m, x, y_tr) + + # sum log-dets over the transformed dims + return y, dsum(ld_tr; dims=(1,)) +end + +function _reconstruct(m::AbstractArray{Bool}, x::AbstractMatrix, y_tr::AbstractMatrix) + y = similar(x) + y[.!m, :] .= x[.!m, :] + y[m, :] .= y_tr + return y +end + +function ChainRulesCore.rrule(::typeof(_reconstruct), m::AbstractArray{Bool}, x::AbstractMatrix, y_tr::AbstractMatrix) + y = _reconstruct(m, x, y_tr) + function _reconstruct_pullback(ฮ”y) + ฮ”x = similar(x) + ฮ”x[.!m, :] .= ฮ”y[.!m, :] + ฮ”x[m, :] .= 0 + ฮ”y_tr = ฮ”y[m, :] + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ฮ”x, ฮ”y_tr + end + return y, _reconstruct_pullback end function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray) - _apply_mask(bj, x, p -> forward_and_log_det(bj.bijector_constructor(p), x)) + _apply_mask(bj, x, forward_and_log_det) end function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray) - _apply_mask(bj, y, p -> inverse_and_log_det(bj.bijector_constructor(p), y)) + _apply_mask(bj, y, inverse_and_log_det) end diff --git a/src/made.jl b/src/made.jl new file mode 100644 index 0000000..0e82a54 --- /dev/null +++ b/src/made.jl @@ -0,0 +1,135 @@ +# src/made.jl 
+using Lux +using Random +using Statistics + +""" + create_degrees(input_dims::Int, hidden_dims::Vector{Int}, out_dims::Int) + +Generate degrees for each unit in a MADE network to ensure autoregressive property. +Returns a list of degree vectors (one for input, each hidden layer, and output). +""" +function create_degrees(input_dims::Int, hidden_dims::Vector{Int}, out_dims::Int; sequential::Bool=true) + # Degrees for input: 1 to D + degrees = [collect(1:input_dims)] + + # Degrees for hidden layers + for h_dim in hidden_dims + if sequential + # Use a deterministic pattern: repeat 1 to D-1 + d_hidden = [((i - 1) % (input_dims - 1)) + 1 for i in 1:h_dim] + push!(degrees, d_hidden) + else + # Sample hidden degrees in [1, D-1] + push!(degrees, rand(1:(input_dims-1), h_dim)) + end + end + + # Degrees for output: same as input (usually out_dims = input_dims or 2*input_dims) + # If out_dims is a multiple of input_dims, we repeat the degrees + if out_dims % input_dims == 0 + n_repeats = out_dims รท input_dims + push!(degrees, repeat(collect(1:input_dims), n_repeats)) + else + # Fallback for arbitrary output sizes (though usually it's D or 2D) + push!(degrees, collect(1:out_dims)) + end + + return degrees +end + +""" + create_masks(degrees::Vector{Vector{Int}}) + +Create binary masks from a list of degree vectors. 
+""" +function create_masks(degrees::Vector{Vector{Int}}) + masks = Matrix{Float32}[] + # Layer l goes from degrees[l] to degrees[l+1] + for i in 1:(length(degrees) - 1) + m_curr = degrees[i] + m_next = degrees[i+1] + + # Matrix shape is (length(m_next), length(m_curr)) + mask = zeros(Float32, length(m_next), length(m_curr)) + + for r in 1:length(m_next) + for c in 1:length(m_curr) + if i < length(degrees)-1 + # Hidden layers: weight exists if degree(prev) <= degree(curr) + if m_next[r] >= m_curr[c] + mask[r, c] = 1.0f0 + end + else + # Output layer: weight exists if degree(prev) < degree(curr) + # This ensures y_i depends only on x_{ m_curr[c] + mask[r, c] = 1.0f0 + end + end + end + end + push!(masks, mask) + end + return masks +end + +# โ”€โ”€ MaskedDense Layer โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +struct MaskedDense{F} <: Lux.AbstractLuxLayer + activation::F + in_dims::Int + out_dims::Int + mask::Matrix{Float32} +end + +function MaskedDense(in_dims::Int, out_dims::Int, mask::Matrix{Float32}, activation=identity) + return MaskedDense(activation, in_dims, out_dims, mask) +end + +function Lux.initialparameters(rng::AbstractRNG, l::MaskedDense) + return ( + weight = Lux.glorot_uniform(rng, l.out_dims, l.in_dims), + bias = zeros(Float32, l.out_dims) + ) +end + +function Lux.initialstates(rng::AbstractRNG, l::MaskedDense) + return NamedTuple() +end + +function (l::MaskedDense)(x::AbstractArray, ps, st) + # y = (W .* mask) * x + b + # We use ps.weight .* l.mask for differentiability + y = (ps.weight .* l.mask) * x .+ ps.bias + return l.activation.(y), st +end + +# โ”€โ”€ MADE Network โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +struct MADE <: Lux.AbstractLuxContainerLayer{(:layers,)} + 
layers::Lux.Chain + input_dims::Int + hidden_dims::Vector{Int} + out_dims::Int +end + +function MADE(input_dims::Int, hidden_dims::Vector{Int}, out_dims::Int; + activation=relu, sequential::Bool=true) + degrees = create_degrees(input_dims, hidden_dims, out_dims; sequential) + masks = create_masks(degrees) + + layers_list = [] + for i in 1:length(masks) + in_d = size(masks[i], 2) + out_d = size(masks[i], 1) + act = (i == length(masks)) ? identity : activation + push!(layers_list, MaskedDense(in_d, out_d, masks[i], act)) + end + + return MADE(Lux.Chain(layers_list...), input_dims, hidden_dims, out_dims) +end + +function (l::MADE)(x::AbstractArray, ps, st) + return Lux.apply(l.layers, x, ps.layers, st.layers) +end diff --git a/src/maf.jl b/src/maf.jl new file mode 100644 index 0000000..e4b67ca --- /dev/null +++ b/src/maf.jl @@ -0,0 +1,99 @@ +# src/maf.jl +using Lux +using Bijectors +using Random +using LinearAlgebra + +# โ”€โ”€ MAF Bijector โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +struct MAFBijector + made + ps + st +end + +""" + inverse_and_log_det(b::MAFBijector, x::AbstractArray) + +Density estimation pass (Inverse): u = (x - m) * exp(-alpha). Fast O(1). +""" +function Bijectors.with_logabsdet_jacobian(b::MAFBijector, x::AbstractMatrix) + # MADE(x) returns (m, log_alpha) concatenated + out, _ = Lux.apply(b.made, x, b.ps, b.st) + D = size(x, 1) + m = out[1:D, :] + log_alpha = out[D+1:end, :] + + # Inverse transform + u = (x .- m) .* exp.(-log_alpha) + + # Log-determinant: sum of -log_alpha across dimensions + lad = -sum(log_alpha; dims=1) + + # Return as (output, logabsdet) + return u, vec(lad) +end + +""" + forward_and_log_det(b::MAFBijector, u::AbstractArray) + +Sampling pass (Forward): x_i = u_i * exp(alpha_i) + m_i. Sequential O(D). 
+""" +function forward_and_log_det(b::MAFBijector, u::AbstractMatrix) + D, N = size(u) + T = eltype(u) + + # Initialize x with zeros. We will fill it dimension by dimension. + # To be Zygote friendly, we can use a loop and vcat/hcat or just use a copy if allowed. + # Actually, sampling is usually not the target of AD during training, but for completeness: + + x = zeros(T, D, N) + + for i in 1:D + # Compute parameters for the current state of x + out, _ = Lux.apply(b.made, x, b.ps, b.st) + m_i = out[i, :] + log_alpha_i = out[D+i, :] + + # Update row i of x + # x[i, :] = u[i, :] .* exp.(log_alpha_i) .+ m_i + # Non-mutating version for Zygote: + new_row = u[i:i, :] .* exp.(log_alpha_i') .+ m_i' + x = vcat(x[1:i-1, :], new_row, x[i+1:end, :]) + end + + # Compute final log_alpha for the fully constructed x to get log_det + out_final, _ = Lux.apply(b.made, x, b.ps, b.st) + log_alpha_final = out_final[D+1:end, :] + lad = sum(log_alpha_final; dims=1) + + return x, vec(lad) +end + +# โ”€โ”€ ReversePermute โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +struct ReversePermute <: Lux.AbstractLuxLayer end + +function (l::ReversePermute)(x::AbstractArray, ps, st) + return x[end:-1:1, :], st +end + +# โ”€โ”€ MaskedAutoregressiveFlow โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +@concrete struct MaskedAutoregressiveFlow <: Lux.AbstractLuxContainerLayer{(:mades,)} + mades + dist_dims::Int + n_transforms::Int + hidden_layer_sizes::Vector{Int} +end + +function MaskedAutoregressiveFlow(; n_transforms::Int, dist_dims::Int, hidden_layer_sizes::Vector{Int}, activation=relu) + mades_list = [MADE(dist_dims, hidden_layer_sizes, 2 * dist_dims; activation) for _ in 1:n_transforms] + keys_ = ntuple(i -> Symbol(:made_, i), n_transforms) + mades = 
NamedTuple{keys_}(Tuple(mades_list)) + return MaskedAutoregressiveFlow(mades, dist_dims, n_transforms, hidden_layer_sizes) +end + +function Lux.initialstates(rng::AbstractRNG, m::MaskedAutoregressiveFlow) + return (mades = Lux.initialstates(rng, m.mades),) +end diff --git a/src/normalizer.jl b/src/normalizer.jl index 43e1ba9..d34c56a 100644 --- a/src/normalizer.jl +++ b/src/normalizer.jl @@ -21,11 +21,17 @@ end Fit a `MinMaxNormalizer` from training data (shape `n_dims ร— n_samples`). """ -function MinMaxNormalizer(data::AbstractMatrix{T}) where T - xmin = vec(minimum(data; dims=2)) - xmax = vec(maximum(data; dims=2)) - log_jac = sum(-log.(xmax .- xmin)) - return MinMaxNormalizer{T}(xmin, xmax, log_jac) +function MinMaxNormalizer(x::AbstractMatrix{T}) where {T} + x_min = vec(minimum(x, dims=2)) + x_max = vec(maximum(x, dims=2)) + + if any(x_min .โ‰ˆ x_max) + throw(ArgumentError("Data has zero variance along one or more dimensions. Cannot initialize MinMaxNormalizer.")) + end + + # Base volume change: sum(-log(x_max - x_min)) + logabsdet = sum(-log.(x_max .- x_min)) + return MinMaxNormalizer{T}(x_min, x_max, logabsdet) end """ @@ -33,19 +39,23 @@ end Apply the forward transform: `z = (x - x_min) / (x_max - x_min)`. """ -normalize(n::MinMaxNormalizer, x::AbstractMatrix) = - (x .- n.x_min) ./ (n.x_max .- n.x_min) +function normalize(n::MinMaxNormalizer, x::AbstractMatrix) + return (x .- n.x_min) ./ (n.x_max .- n.x_min) +end -normalize(n::MinMaxNormalizer, x::AbstractVector) = - (x .- n.x_min) ./ (n.x_max .- n.x_min) +function normalize(n::MinMaxNormalizer, x::AbstractVector) + return (x .- n.x_min) ./ (n.x_max .- n.x_min) +end """ denormalize(n::MinMaxNormalizer, z) -> x Apply the inverse transform: `x = z * (x_max - x_min) + x_min`. 
""" -denormalize(n::MinMaxNormalizer, z::AbstractMatrix) = - z .* (n.x_max .- n.x_min) .+ n.x_min +function denormalize(n::MinMaxNormalizer, z::AbstractMatrix) + return z .* (n.x_max .- n.x_min) .+ n.x_min +end -denormalize(n::MinMaxNormalizer, z::AbstractVector) = - z .* (n.x_max .- n.x_min) .+ n.x_min +function denormalize(n::MinMaxNormalizer, z::AbstractVector) + return z .* (n.x_max .- n.x_min) .+ n.x_min +end diff --git a/src/nsf.jl b/src/nsf.jl new file mode 100644 index 0000000..78f6b31 --- /dev/null +++ b/src/nsf.jl @@ -0,0 +1,134 @@ +# src/nsf.jl +using Lux +using Bijectors +using Random +using LinearAlgebra + +""" + NSFSplineBijector(mask, params, K, tail_bound) + +A Neural Spline Flow (NSF) coupling layer using Rational Quadratic Splines. +`mask` is a binary vector (1 for variables to transform, 0 for variables that condition). +`params` is the output of the conditioner for the transformed dimensions. +""" +# Bijectors and Layers for NSF +struct NSFSplineBijector + params + K::Int + tail_bound::Float64 +end + +function forward_and_log_det(b::NSFSplineBijector, x::AbstractArray) + # x is (D_tr, N) + D_tr, N = size(x) + K = b.K + tail_bound = b.tail_bound + params = b.params + + # Reshape params to (D_tr, 3K-1, N) + params = reshape(params, D_tr, 3*K - 1, N) + + # Partition params + w_unnorm = @view params[:, 1:K, :] + h_unnorm = @view params[:, K+1:2*K, :] + dv_unnorm = @view params[:, 2*K+1:end, :] + + # Flatten everything to call the spline function + x_flat = vec(x) + w_flat = reshape(permutedims(w_unnorm, (1, 3, 2)), D_tr * N, K) + h_flat = reshape(permutedims(h_unnorm, (1, 3, 2)), D_tr * N, K) + dv_flat = reshape(permutedims(dv_unnorm, (1, 3, 2)), D_tr * N, K-1) + + y_flat, lad_flat = unconstrained_rational_quadratic_spline( + x_flat, w_flat, h_flat, dv_flat, eltype(x)(tail_bound) + ) + + y = reshape(y_flat, D_tr, N) + lad = reshape(lad_flat, D_tr, N) + + return y, lad +end + +function inverse_and_log_det(b::NSFSplineBijector, y::AbstractArray) + # y 
is (D_tr, N) + D_tr, N = size(y) + K = b.K + tail_bound = b.tail_bound + params = b.params + + params = reshape(params, D_tr, 3*K - 1, N) + w_unnorm = @view params[:, 1:K, :] + h_unnorm = @view params[:, K+1:2*K, :] + dv_unnorm = @view params[:, 2*K+1:end, :] + + y_flat = vec(y) + w_flat = reshape(permutedims(w_unnorm, (1, 3, 2)), D_tr * N, K) + h_flat = reshape(permutedims(h_unnorm, (1, 3, 2)), D_tr * N, K) + dv_flat = reshape(permutedims(dv_unnorm, (1, 3, 2)), D_tr * N, K-1) + + x_flat, lad_flat = unconstrained_rational_quadratic_spline( + y_flat, w_flat, h_flat, dv_flat, eltype(y)(tail_bound); + inverse=true + ) + + x = reshape(x_flat, D_tr, N) + lad = reshape(lad_flat, D_tr, N) + + return x, lad +end + +function NSFCouplingBijector_from_flat(params, K, tail_bound) + return NSFSplineBijector(params, K, tail_bound) +end + +""" + NeuralSplineFlow(; n_transforms, dist_dims, hidden_dims, n_layers, K=8, tail_bound=3.0, activation=gelu) + +Neural Spline Flow (NSF) with rational quadratic coupling layers. 
+""" +@concrete struct NeuralSplineFlow <: Lux.AbstractLuxContainerLayer{(:conditioners,)} + conditioners + mask_list :: Vector{BitVector} + dist_dims :: Int + n_transforms :: Int + hidden_layer_sizes :: Vector{Int} + K :: Int + tail_bound :: Float64 +end + +function NeuralSplineFlow(; n_transforms::Int, dist_dims::Int, + hidden_layer_sizes::Vector{Int}, K=8, tail_bound=3.0, activation=gelu) + D = dist_dims + + # Pre-generate masks to know out_dims per layer + mask_list = [BitVector(collect(1:D) .% 2 .== i % 2) + for i in 1:n_transforms] + + mlps = [] + for i in 1:n_transforms + m = mask_list[i] + D_tr = sum(m) + out_dims = D_tr * (3*K - 1) + push!(mlps, MLP(D, hidden_layer_sizes, out_dims; activation)) + end + + keys_ = ntuple(i -> Symbol(:conditioners_, i), n_transforms) + conditioners = NamedTuple{keys_}(Tuple(mlps)) + return NeuralSplineFlow(conditioners, mask_list, D, n_transforms, hidden_layer_sizes, K, Float64(tail_bound)) +end + +function Lux.initialstates(rng::AbstractRNG, m::NeuralSplineFlow) + return (; mask_list=m.mask_list, conditioners=Lux.initialstates(rng, m.conditioners)) +end + + +# Helper to build the bijector from flat parameters +struct NSFSplineConstructor + mask + K::Int + tail_bound::Float64 +end + +function (c::NSFSplineConstructor)(params) + return NSFSplineBijector(c.mask, params, c.K, c.tail_bound) +end diff --git a/src/realnvp.jl b/src/realnvp.jl index 1c4fac1..96cca68 100644 --- a/src/realnvp.jl +++ b/src/realnvp.jl @@ -21,6 +21,7 @@ Masks alternate between even/odd dimensions. 
""" @concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)} conditioners + mask_list :: Vector{BitVector} dist_dims :: Int n_transforms :: Int hidden_layer_sizes :: Vector{Int} @@ -28,60 +29,24 @@ end function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_layer_sizes::Vector{Int}, activation=gelu) - mlps = [MLP(dist_dims, hidden_layer_sizes, 2 * dist_dims; activation) - for _ in 1:n_transforms] + D = dist_dims + + # Pre-generate masks + mask_list = [BitVector(collect(1:D) .% 2 .== i % 2) + for i in 1:n_transforms] + + mlps = [] + for i in 1:n_transforms + m = mask_list[i] + D_tr = sum(m) + push!(mlps, MLP(D, hidden_layer_sizes, 2 * D_tr; activation)) + end + keys_ = ntuple(i -> Symbol(:conditioners_, i), n_transforms) conditioners = NamedTuple{keys_}(Tuple(mlps)) - return RealNVP(conditioners, dist_dims, n_transforms, hidden_layer_sizes) + return RealNVP(conditioners, mask_list, D, n_transforms, hidden_layer_sizes) end function Lux.initialstates(rng::AbstractRNG, m::RealNVP) - mask_list = [Bool.(collect(1:(m.dist_dims)) .% 2 .== i % 2) - for i in 1:(m.n_transforms)] - return (; mask_list, conditioners=Lux.initialstates(rng, m.conditioners)) -end - -# โ”€โ”€ Pure-functional log-prob and sample โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - -""" - log_prob(model, ps, st, x) -> Vector - -Compute per-sample log-probability of `x` (shape: dist_dims ร— batch) under the flow. -Pure functional โ€” no mutations, safe for Zygote. -""" -function log_prob(model::RealNVP, ps, st, x::AbstractMatrix) - lp = nothing - for i in model.n_transforms:-1:1 - k = keys(model.conditioners)[i] - mask = st.mask_list[i] - cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], - s = st.conditioners[k] - x_cond -> Lux.apply(m, x_cond, p, s)[1] - end - bj = MaskedCoupling(mask, cond_fn, AffineBijector) - x, ld = inverse_and_log_det(bj, x) - lp = isnothing(lp) ? 
ld : lp .+ ld - end - base_lp = dsum(gaussian_logpdf.(x); dims=(1,)) - return isnothing(lp) ? base_lp : lp .+ base_lp -end - -""" - draw_samples(rng, T, model, ps, st, n_samples) -> Matrix - -Sample from the flow by pushing Gaussian noise through the forward transforms. -""" -function draw_samples(rng::AbstractRNG, ::Type{T}, model::RealNVP, - ps, st, n_samples::Int) where T - x = randn(rng, T, model.dist_dims, n_samples) - for i in 1:(model.n_transforms) - k = keys(model.conditioners)[i] - cond_fn = let m = model.conditioners[k], p = ps.conditioners[k], - s = st.conditioners[k] - x_cond -> Lux.apply(m, x_cond, p, s)[1] - end - bj = MaskedCoupling(st.mask_list[i], cond_fn, AffineBijector) - x, _ = forward_and_log_det(bj, x) - end - return x + return (; mask_list=m.mask_list, conditioners=Lux.initialstates(rng, m.conditioners)) end diff --git a/src/splines.jl b/src/splines.jl new file mode 100644 index 0000000..b8c12e2 --- /dev/null +++ b/src/splines.jl @@ -0,0 +1,154 @@ +# src/splines.jl +using NNlib +using ChainRulesCore +using ForwardDiff + +# Helper to find bins. Zygote ignores gradients for indices. +function compute_bin_idx(cum_arrays::AbstractMatrix{T}, inputs::AbstractVector{T}, K::Int) where {T<:Real} + M = length(inputs) + bin_idx = zeros(Int, M) + for i in 1:M + # Handle ForwardDiff values properly if needed, but searchsortedlast supports Duals natively + idx = searchsortedlast(@view(cum_arrays[i, :]), inputs[i]) + bin_idx[i] = clamp(idx, 1, K) + end + return bin_idx +end +ChainRulesCore.@non_differentiable compute_bin_idx(Any...) + +""" + unconstrained_rational_quadratic_spline(inputs, widths, heights, derivs; kwargs...) + +Inputs: +- `inputs`: Vector of length M. +- `unnormalized_widths`: Matrix of shape (M, K). +- `unnormalized_heights`: Matrix of shape (M, K). +- `unnormalized_derivatives`: Matrix of shape (M, K-1). + +All inputs are flattened. Returns `(outputs, logabsdet)`. 
+""" +function unconstrained_rational_quadratic_spline( + inputs::AbstractVector{T_in}, + unnormalized_widths::AbstractMatrix{T_w}, + unnormalized_heights::AbstractMatrix{T_h}, + unnormalized_derivatives::AbstractMatrix{T_dv}, + tail_bound=3.0, + min_bin_width=1e-3, + min_bin_height=1e-3, + min_derivative=1e-3; + inverse::Bool=false +) where {T_in<:Real, T_w<:Real, T_h<:Real, T_dv<:Real} + # Operations will be performed in the promoted type + T = promote_type(T_in, T_w, T_h, T_dv) + + # Cast everything to the consistency type T + inputs = T.(inputs) + unnormalized_widths = T.(unnormalized_widths) + unnormalized_heights = T.(unnormalized_heights) + unnormalized_derivatives = T.(unnormalized_derivatives) + + # Convert keyword arguments to T + tail_bound = T(tail_bound) + min_bin_width = T(min_bin_width) + min_bin_height = T(min_bin_height) + min_derivative = T(min_derivative) + + M, K = size(unnormalized_widths) + + # 1. Pad derivatives with constant for linear tails + constant = T(log(exp(1 - min_derivative) - 1)) + pad_c = fill(constant, M) + unnorm_derivs = hcat(pad_c, unnormalized_derivatives, pad_c) # (M, K+1) + + # 2. Extract valid interior regions (using Mask to avoid mutating variables for Zygote) + inside_mask = (inputs .>= -tail_bound) .& (inputs .<= tail_bound) + + # We will compute the spline for ALL points, but only conditionally select the output + # This avoids Zygote mutating array issues, relying on `ifelse`. + # To avoid NaNs or domain errors, we clamp the inputs outside the tail bound. + clamped_inputs = clamp.(inputs, -tail_bound, tail_bound) + + # 3. 
Compute normalize bin parameters + widths_raw = NNlib.softmax(unnormalized_widths; dims=2) + widths = min_bin_width .+ (1 - min_bin_width * K) .* widths_raw + + cumwidths_raw = cumsum(widths; dims=2) + # Scale to (0, 1) then to (-tail_bound, tail_bound) + # We use hcat to ensure exact boundaries + interior_cumwidths = (2 * tail_bound) .* (@view cumwidths_raw[:, 1:end-1]) .- tail_bound + cumwidths = hcat(fill(-tail_bound, M), interior_cumwidths, fill(tail_bound, M)) + widths = cumwidths[:, 2:end] .- cumwidths[:, 1:end-1] + + derivatives = min_derivative .+ NNlib.softplus.(unnorm_derivs) + + heights_raw = NNlib.softmax(unnormalized_heights; dims=2) + heights = min_bin_height .+ (1 - min_bin_height * K) .* heights_raw + + cumheights_raw = cumsum(heights; dims=2) + interior_cumheights = (2 * tail_bound) .* (@view cumheights_raw[:, 1:end-1]) .- tail_bound + cumheights = hcat(fill(-tail_bound, M), interior_cumheights, fill(tail_bound, M)) + heights = cumheights[:, 2:end] .- cumheights[:, 1:end-1] + + # 4. Find the appropriate bin + if inverse + bin_idx = compute_bin_idx(cumheights, clamped_inputs, K) + else + bin_idx = compute_bin_idx(cumwidths, clamped_inputs, K) + end + + # 5. Gather bin specific parameters + # Zygote friendly gather using linear indexing + # cumwidths is (M, K+1) + # widths is (M, K) + linear_indices_k_plus_1 = (bin_idx .- 1) .* M .+ (1:M) + linear_indices_k = (bin_idx .- 1) .* M .+ (1:M) + + input_cumwidths = cumwidths[linear_indices_k_plus_1] + input_bin_widths = widths[linear_indices_k] + + input_cumheights = cumheights[linear_indices_k_plus_1] + input_heights = heights[linear_indices_k] + + delta = heights ./ widths + input_delta = delta[linear_indices_k] + + input_derivatives = derivatives[linear_indices_k_plus_1] + input_derivatives_plus_one = derivatives[linear_indices_k_plus_1 .+ M] + + # 6. 
Evaluate spline equations + if inverse + a = (clamped_inputs .- input_cumheights) .* (input_derivatives .+ input_derivatives_plus_one .- 2 .* input_delta) .+ input_heights .* (input_delta .- input_derivatives) + b = input_heights .* input_derivatives .- (clamped_inputs .- input_cumheights) .* (input_derivatives .+ input_derivatives_plus_one .- 2 .* input_delta) + c = -input_delta .* (clamped_inputs .- input_cumheights) + + discriminant = b.^2 .- 4 .* a .* c + # Numerical stability: splines are monotonic so discriminant >= 0. + # We use abs and a tiny epsilon for square root gradients. + root = (2 .* c) ./ (-b .- sqrt.(abs.(discriminant) .+ T(1e-12))) + + rq_outputs = root .* input_bin_widths .+ input_cumwidths + + theta_one_minus_theta = root .* (1 .- root) + denominator = input_delta .+ ((input_derivatives .+ input_derivatives_plus_one .- 2 .* input_delta) .* theta_one_minus_theta) + derivative_numerator = (input_delta.^2) .* (input_derivatives_plus_one .* root.^2 .+ 2 .* input_delta .* theta_one_minus_theta .+ input_derivatives .* (1 .- root).^2) + + rq_logabsdet = log.(abs.(derivative_numerator) .+ T(1e-12)) .- 2 .* log.(abs.(denominator) .+ T(1e-12)) + rq_logabsdet = -rq_logabsdet + else + theta = (clamped_inputs .- input_cumwidths) ./ input_bin_widths + theta_one_minus_theta = theta .* (1 .- theta) + + numerator = input_heights .* (input_delta .* theta.^2 .+ input_derivatives .* theta_one_minus_theta) + denominator = input_delta .+ ((input_derivatives .+ input_derivatives_plus_one .- 2 .* input_delta) .* theta_one_minus_theta) + rq_outputs = input_cumheights .+ numerator ./ denominator + + derivative_numerator = (input_delta.^2) .* (input_derivatives_plus_one .* theta.^2 .+ 2 .* input_delta .* theta_one_minus_theta .+ input_derivatives .* (1 .- theta).^2) + rq_logabsdet = log.(abs.(derivative_numerator) .+ T(1e-12)) .- 2 .* log.(abs.(denominator) .+ T(1e-12)) + end + + # 7. 
Apply identity outside bounds + outputs = ifelse.(inside_mask, rq_outputs, inputs) + logabsdet = ifelse.(inside_mask, rq_logabsdet, zero(T)) + + return outputs, logabsdet +end diff --git a/src/training.jl b/src/training.jl index c4b55f5..f32906d 100644 --- a/src/training.jl +++ b/src/training.jl @@ -16,17 +16,24 @@ negative log-likelihood on the normalised data. """ function train_flow!(flow::FlowDistribution{T}, data::AbstractMatrix; n_epochs::Int=1000, - lr::Real=T(1f-3), + lr::Union{Nothing, Real}=nothing, batch_size::Int=256, - verbose::Bool=true) where {T} - # Always fit and apply a min-max normalizer + verbose::Bool=true, + opt=nothing) where {T} + # 5. Fit & Attach Normalizer (on the first batch or full data) + # We fit it on the full training data for simplicity flow.normalizer = MinMaxNormalizer(T.(data)) - data_T = normalize(flow.normalizer, T.(data)) + data_norm = SimpleFlows.normalize(flow.normalizer, data) - opt = Optimisers.OptimiserChain(Optimisers.ClipGrad(T(1)), Optimisers.Adam(T(lr))) - opt_state = Optimisers.setup(opt, flow.ps) + actual_opt = if isnothing(opt) + actual_lr = isnothing(lr) ? 
T(1f-3) : T(lr) + Optimisers.OptimiserChain(Optimisers.ClipGrad(T(1)), Optimisers.Adam(actual_lr)) + else + opt + end + opt_state = Optimisers.setup(actual_opt, flow.ps) - loader = DataLoader(data_T; batchsize=batch_size, shuffle=true) + loader = DataLoader(data_norm; batchsize=batch_size, shuffle=true) for epoch in 1:n_epochs total_loss = zero(T) diff --git a/test/data/nsf_test_data.json b/test/data/nsf_test_data.json new file mode 100644 index 0000000..f1eeb64 --- /dev/null +++ b/test/data/nsf_test_data.json @@ -0,0 +1,733 @@ +{ + "inputs": { + "data": [ + 0.2995562877502637, + 0.4285173645433462, + -0.03720300727052493, + -0.9848309379659244, + 0.24237779080894106, + 0.2318702888580676, + 1.0352472972512925, + 0.4573409128131671, + 0.28320739441795323, + -1.0214571469180447, + -0.010085521702337353, + -1.990778610968323, + -0.23294220898382634, + -0.8272632292381021, + 0.08948016046560357, + -0.2515086919419446, + 0.6711532211923003, + 1.7342104609172728, + -1.2174058025747443, + -1.0221641302789066, + 0.7817673742546976, + -1.4920727283504345, + -0.44460513945120644, + 0.14417369704825783, + -1.7903080050220619, + 0.4212009947066779, + -0.26711077156188034, + 1.5695068008576476, + -1.748402801489203, + 1.504572275678796, + -0.8118406564041991, + -1.1983222673785627, + 0.17347640142361004, + 1.2715909667771088, + 2.3464758864728754, + 0.005247036919813949, + 0.26684314514213736, + 0.618669993885617, + -1.5168508562286191, + -0.004495386168535826 + ], + "shape": [ + 4, + 10 + ] + }, + "unnormalized_widths": { + "data": [ + -0.4419745502775611, + 0.2716846620032497, + 1.056891815220453, + -1.766660462065303, + 0.9862554128547407, + -0.5671097889774945, + -0.236684969212557, + 0.006353391113712342, + -0.9123003462772722, + -0.6303061763998286, + -0.39855088265346844, + 0.6913826362365914, + -0.2669938796469903, + 0.6572105842025696, + -1.6658941949653996, + -1.9141455779164431, + 0.9987723890648004, + 0.7339018277109903, + 0.29946905473477337, + 
-1.0776595577849437, + -0.03349665465339956, + -0.6908995915989536, + 1.2349001065104936, + 0.24146450013043255, + 0.5416192945872663, + -0.6946378083376783, + -0.7430063996503045, + -0.1178562026752313, + 1.5611147357425514, + -0.7144626307394635, + 0.4653856846834471, + 0.9340946495293597, + -0.0536259352810602, + 1.024893906047934, + -1.6038248406643751, + -1.8998197674219859, + 0.8604090757781153, + -0.4640699001641785, + -0.44543010608126465, + -0.1733949212480647, + -0.9017421148082815, + 0.007740374152002723, + -0.5570592823399676, + 1.4743616877386294, + -0.7171640736861331, + 0.20098777565954465, + 0.2677046171271731, + 2.5516359902155408, + -0.4752797645451217, + 0.8235071239868901, + -0.06274129239437512, + -0.6671725645531117, + -0.5656502203844694, + 0.4981374432781232, + -1.7052945505852435, + 0.2850774357972623, + 1.419073672127401, + 1.5143278017369346, + 0.1203101077666772, + 0.2689397009780952, + -1.296892295720691, + 0.18009974321625744, + -0.8593183061725375, + 1.0955977195397684, + -0.8269759713539973, + -1.679818997288417, + 0.13591860618347618, + 0.5917908476437124, + -0.28742284714076166, + 0.6587220194553163, + 0.2709020500524623, + -1.149354106366614, + -0.27339532971884056, + 0.05755031834255014, + 0.8500577134471289, + 0.9478581977287968, + -2.2564392962632804, + 0.3110659164171673, + 0.06194949970627075, + 1.2335858050609876, + -1.0639335147664135, + 1.1537781680086128, + 1.4596822474021245, + 1.5373855077150955, + 1.386220003799335, + 0.24116311123514478, + 1.6076780385623348, + 1.7935701786527096, + -0.13499022826985557, + -0.25067215098463164, + 1.1518687500322395, + -0.16649431697299996, + -0.3651179239279022, + 0.7065959337431358, + -1.567067835424215, + -0.3357557133568141, + 0.05654075352457806, + 1.4204192411699395, + 0.6755821599239246, + -0.9357700117370953, + 0.10049062711321136, + 0.6228199266745641, + -0.4495829648314594, + -0.914993910952663, + -0.14374993267103833, + -1.203808008890236, + 0.07285694107615862, + 
2.0141519804280326, + -0.7930321767525238, + -0.477376582450768, + -0.7873047578273823, + -0.07195237845460918, + 0.8247512770846365, + 0.1908921136519597, + -1.4979151343522192, + -0.5774972459244567, + -0.27133546200730857, + 0.9813962548064747, + 0.7390035236473482, + 1.3917635799868133, + 0.6491852100764046, + 1.2039325588275476, + 0.9874814278993451, + 0.7542670407256892, + 0.5781475806410152, + 1.0836356031788037, + -0.6390685802515629, + -0.48981275707449656, + 1.2623218276299724, + -0.27964434459183907, + 0.55464291590441, + 0.6308280535150428, + -0.6187129493373275, + -1.81710056168182, + 0.12247935993756598, + 0.7788349259051482, + -0.17017167872263755, + 0.043301694413358134, + -2.089678394322878, + -0.9289772719574583, + -0.8074523880762338, + 0.39794575558975936, + 0.7354513501370189, + -1.236806983856194, + -0.35701990837483416, + -0.6836642340698702, + -0.17214669458374612, + -0.4073494437665049, + 0.7496549362718903, + 0.9152466532285749, + 1.292950920601979, + 0.09414818556627164, + -0.38366959466913275, + -1.2609592964755483, + 1.0775925750761937, + 0.6273242190434466, + 1.5303753088083207, + -1.3500142617068518, + 0.542614186666463, + -0.9181737147627111, + -0.06381394399314673, + 0.9978706959761301, + 1.622897663758754, + -0.16614605841589397, + 0.34666368467190317, + -0.9591633938060772, + -1.6966150893683734, + -1.2814351976260474, + -0.26992650119900263, + 0.10762582501581724, + -1.2859520744264117, + 1.5023826600840733, + 0.05731628365819024, + -1.4611815421227203, + 0.11844701357721124, + -1.7097420738232683, + -0.6481962873525597, + -0.6199302841848613, + -0.041677237819211586, + -1.912936627272994, + -0.22069702937409055, + -0.8373282020867259, + -1.2102204195821435, + 1.1367489812535863, + 2.147097339967129, + -0.39588656758208457, + 0.028256106155769695, + 0.4780347016744464, + -0.7879795594263654, + -0.4819936051666644, + 0.2945839162465962, + -0.23035924547300898, + 0.3179810499974549, + 0.6885676803300341, + -1.3749595715279772, + 
0.08511336387296158, + 1.2413785514414426, + -0.16773829248295685, + -0.6272882310444161, + 0.8067649829193638 + ], + "shape": [ + 4, + 10, + 5 + ] + }, + "unnormalized_heights": { + "data": [ + -0.20314173587596263, + 0.41098923499226014, + 1.0423699879551078, + 1.0111021595209633, + -0.4005606761098726, + -1.5400573604406471, + -0.27680484002022715, + 0.22713168745929768, + 0.6772607869261693, + 0.17859923880951847, + 0.8068577169597195, + -0.007188030468140175, + 1.498700405481595, + -0.03019556689050345, + 0.36384746366328413, + 0.5946658094550371, + 0.8488782882116401, + -1.4076687315200758, + -0.6554716217755595, + -0.031424129829246916, + 0.002119877636176426, + -0.9320608828189856, + -0.4836971414215709, + -0.6652501624993471, + 0.6929251584566181, + 0.09402273928088646, + 1.6129488803273395, + -0.6274502774485388, + -0.6354031061312012, + -0.11014269025953972, + -0.5707540620409073, + 0.2782765045840637, + -0.131168025357405, + -0.014621424093204098, + -0.9264218514606051, + 0.4745604420052093, + -0.8170580592759659, + 1.090316288383937, + -0.4324600332331524, + 1.335598774236891, + 0.19522143581405244, + 0.7739177723108309, + 1.5545069571528942, + -1.1665892611773085, + -1.3514084083199862, + 1.548288167937, + 0.1962794539558957, + -0.5133236613036435, + 0.07380298739515019, + -0.6768899292657692, + 0.29925023340409146, + -0.296150919949663, + 1.797754234791071, + -0.2256181742270194, + 0.7611949933670529, + -0.4752871063125886, + 0.3315215039285197, + 1.029960808199997, + 0.6816100143946399, + 0.6027021077218985, + -1.0103642468106544, + 1.8110431275164358, + -1.242898069994524, + -0.018728729244763993, + 0.3435160724199721, + -1.4671607394596426, + -0.19175241722939693, + -0.6020537188199624, + -1.2173643223619721, + -1.3293671463788634, + 0.6536593231539053, + 1.3132412975399521, + -0.03123667988254042, + 0.6955788134794414, + 0.14788650580579352, + -0.7267532390860204, + 0.11041826907812055, + -0.9430074874135619, + -0.024493143878920487, + 
-0.48165316592257534, + 0.24917168366560993, + -1.2762292342683303, + 1.5849730687530486, + -0.7991803895400539, + -0.9017993731080081, + 0.9288465316923816, + 0.8731470189472956, + -0.5447919995948828, + 1.9920454844101763, + 0.8504258697467754, + 0.8239191817788585, + -0.7591698936771379, + 0.23182973839731683, + 0.19368409881631005, + -0.8829189545685341, + 1.1838674589073483, + 0.12891223912089034, + 2.177326020841736, + -0.29262366803238393, + -0.1269124796513161, + -1.1195878651441893, + -1.4858191479600487, + 0.2730005353024221, + 0.056387506785403815, + 0.23274595474814475, + -1.219750082862799, + 0.6164470190315824, + 1.313122245567395, + 1.231436269706581, + 0.7101787049919178, + -0.29432626362033404, + 0.09290205186381154, + 0.5124772757514626, + 0.5722528182048355, + 0.0069296369726892315, + 0.05265567712992814, + 0.024119499618272, + -0.1734020738525506, + 1.980986216379721, + -0.3625440782730144, + -0.46610241187034174, + -0.1434065696905791, + 0.7199388940089966, + 0.18661178764402744, + -1.6874596492942056, + 0.08390244767011035, + 0.010205327986369436, + -1.4490806809539707, + -0.33044907219753716, + 0.9425485271043043, + 1.0202696745617659, + 1.1809578134345797, + -0.4778501888006765, + 0.5058710736668895, + 0.26915516555170543, + -0.48408003859590065, + -1.2238040021173409, + 0.4365480813048655, + 0.9362614654567366, + -0.132580840204314, + -1.2632446956429335, + 0.7389689210845545, + 0.11501735450468505, + -0.48921141096064424, + -1.2546554715250677, + -1.3274229125145334, + -0.20570447804044192, + -0.14388808008330586, + -0.12220432019599553, + -0.01670896018294231, + -0.47610390457226875, + -0.2426945946706641, + 0.8173767150222387, + -1.481720319316519, + 0.3306231644074124, + -0.22060606606267064, + -0.9954168307888536, + 0.9803566659279874, + -0.2658724362020457, + 0.238746989349253, + -0.9029006394421663, + 0.199025684270468, + 2.6380786365181796, + 0.8742346952912949, + 0.8505308967944145, + 0.3572785103117973, + -0.8838416943783096, + 
-0.5830552203885363, + -0.39032216747009113, + 1.8547476587140206, + 1.3346481596559676, + -0.14668266051521825, + 0.29946965798993297, + 0.1038460098414191, + -0.09599439254200862, + -0.34832623473704294, + -0.6243273110799229, + -0.42153332375097163, + 0.1418089057088939, + 0.8622146933847438, + 1.452604901391212, + 0.6644577289945207, + 1.0157942265133257, + -0.7306000207552646, + 0.8387576576827396, + 0.6099936154303692, + 0.42155963063098567, + 0.9953749290985876, + -1.5066610314377438, + 0.43343311885236446, + -1.3121219556703199, + -0.7507398495238957, + -0.2772483620390129, + -0.15323533618009222, + 0.8545701717376761, + -0.952862648290038, + -0.7025528681323762, + 0.0627858851186502, + 0.011103897250099562, + 0.44470601078973404 + ], + "shape": [ + 4, + 10, + 5 + ] + }, + "unnormalized_derivatives": { + "data": [ + -0.5073466627133774, + 0.4640089663510305, + 0.42905396760192394, + -0.024940177091060257, + 0.2037444999220601, + 1.3825460700799999, + -0.18594279184343007, + -0.13199258968708819, + 0.95499847971227, + 0.36061593000733744, + 1.5299017471873564, + 0.1682232955566942, + -0.3435279013303228, + -1.7971615354714086, + 0.4693013003698667, + 1.493186440171208, + 0.12356938192583447, + -0.17905683036898526, + -0.6760364508503255, + 0.393978576087819, + 0.64949409980517, + -1.1431268808798611, + 0.40178181976552013, + 1.5351166370632565, + 0.8854633803752585, + 0.8200277372952126, + 0.04080196720634493, + 1.7491823889925064, + -0.012391613968610236, + 1.1152018043934078, + 0.6179595164022554, + -0.2567371630523861, + -1.676382260579765, + -0.23626544563697244, + 2.779642620636317, + 1.3711024513262018, + 1.6243002171198435, + 0.023223437469212883, + 0.3333688295519079, + 1.776099742658068, + 0.23740050210541427, + -0.10972801634299746, + -0.4166779124774877, + -0.5503437577515686, + 0.7109974741886405, + -0.3132870217495473, + 0.20474610893599518, + 0.529929231024398, + 0.7392324154597758, + -1.289651880858446, + -0.5846732948053854, + 
-0.2975140552555919, + -0.3602379545435858, + -0.28291660184068396, + 0.16236109077792268, + 1.6561866571093768, + -0.08969148970529098, + -0.1403733323733132, + 1.653704353836856, + -0.06677018843436126, + 0.7967438370130172, + -1.0487567518729046, + 0.38249765783308975, + -0.09510169198820755, + -0.26937819568016647, + 1.0506749629051164, + -0.8217549133103285, + -1.2605463844968352, + 0.06630459711885793, + 1.4068597142149075, + -0.23906347041924664, + 0.40581500848228097, + 0.34824046380765017, + 1.4576160043118096, + 1.1435094589159382, + -1.1561650245007897, + -2.0501574100159003, + 1.972996262759935, + -1.4023249157404538, + -1.252355835467333, + -0.5939367844331447, + -0.8883536777451707, + 0.8826283214585245, + -0.08537770872725445, + 1.2430570323434795, + -0.8288361760454861, + 0.3740588912011981, + 0.6908593029167234, + -0.237603292141705, + -1.4491304746103755, + -1.1503461466656624, + -0.06074698383716112, + -0.7611869714619356, + -0.21606476673148656, + -0.34312322567272335, + 1.4977421535617197, + -1.0246406937201316, + -0.6374111111312827, + -0.6874324035898947, + 0.4555428253347865, + 0.747331595480188, + 0.746181920181807, + 0.3114390139045884, + -0.593601279371061, + -0.8244884478356891, + -0.2384891022367264, + -0.6674385398325516, + 1.1105282112735941, + -0.28770161892364415, + -0.48758783299086855, + -0.8176805756822754, + -0.4529194605131728, + 0.4266942129682067, + 0.959445148595098, + -1.2042407824096886, + 0.32657568443451324, + 0.9547547392627056, + 1.111198856760241, + 1.1033711088631535, + -0.9222774516745698, + 0.015288963161081063, + 0.692635853414707, + 1.6871393477277539, + 1.3793012042095691, + -0.8531034123811346, + -1.6106188284100986, + -0.2882501627532076, + -0.07348821377104808, + -0.252825439264895, + -0.8911729028578352, + -1.4964971728273821, + 0.45852900331123253, + -1.2090009788666547, + -0.15934430798101568, + 1.357487397797586, + -0.6324727023902199, + 1.2747860118984116, + 0.8169733493456708, + -1.1168848140267789, + 
-0.9667484660578313, + 1.211796851951881, + 1.426407289877944, + -0.3775828028349923, + 0.5675809579603023, + -0.003458751262658035, + 0.20991311225897344, + -2.1396527931651366, + -0.18661113504227145, + 0.425938031142508, + 0.9491730513198636, + -1.1358577850612366, + 0.4790736156730508, + -1.2943260486129644, + -0.8749887305231243, + 0.5211816427115885, + -0.8246528349739011, + 1.7195791979873167, + -0.3173219909708357, + 0.21110777465868505, + -1.4271230211989725 + ], + "shape": [ + 4, + 10, + 4 + ] + }, + "outputs": { + "data": [ + 1.9339028943779266, + 1.11421988558889, + -0.42197727509315885, + -0.476029405566483, + -1.2483046495149022, + 1.6004071514895317, + 0.5313728916219365, + 0.10384965984925965, + 2.456998918588678, + -2.2632805151911053, + -0.5857489536913694, + -2.2219616276614382, + 1.869369003838893, + -1.93153721655368, + 1.5859303738009816, + 1.5589315450576469, + 0.5017763041989596, + 1.5363818054493383, + -2.3176639527548177, + -1.4333335079417524, + -1.445733610842785, + 0.12056301759087473, + -2.574201538280003, + -0.55456591149102, + 0.10180164225306532, + -0.19378840390634317, + 0.9227754990410411, + 0.011018458865574399, + -2.684602298027738, + 0.8095865207597701, + 0.8266911303353263, + -2.2455603887630513, + -0.34474197332753287, + 2.0481347447494245, + 0.4375288729644756, + 1.033164040845549, + 1.7872257701259382, + -0.07206412074167462, + -2.1356471141029694, + 0.48098264485491926 + ], + "shape": [ + 4, + 10 + ] + }, + "logabsdet": { + "data": [ + -1.4675210692989342, + -1.067310211976055, + 0.14419547484850925, + -1.9485519507529072, + -2.20764645599918, + -1.1681871821958048, + -0.5374838340935822, + -0.9097839152439218, + -3.3201129842532837, + -3.675045844551764, + -0.9144255442750324, + -0.3373297475889202, + -0.6997453365127888, + -1.2178046448117705, + -0.4445588126514184, + 0.9267799822797531, + -0.6963832281734724, + 0.5840714217957924, + 0.03488088162906966, + -0.805999892145376, + -1.1272514008047247, + 1.0451391736003843, 
+ -1.5991455469573763, + -0.4562886598302148, + 1.2521448786399687, + -1.6697280722784094, + 0.019017996066540865, + 0.20786467131674569, + -2.580447328202009, + -1.0605820634029133, + 1.4089747749221635, + -1.3185885277804952, + -0.49619613672187163, + 0.018390055827107177, + -0.3367753785704455, + 1.224586394069803, + -2.354301504990093, + -2.935230110842744, + -1.0977673032699535, + -1.5464665560652113 + ], + "shape": [ + 4, + 10 + ] + } +} \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 675c57d..8887791 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,18 @@ using Test using SimpleFlows using Random, Distributions, LinearAlgebra +using Lux, Zygote, Bijectors, ForwardDiff, JSON rng = Random.MersenneTwister(42) @testset "SimpleFlows.jl" begin include("test_layers.jl") include("test_realnvp.jl") + include("test_training.jl") include("test_io.jl") include("test_normalizer.jl") + include("test_splines.jl") + include("test_nsf.jl") + include("test_maf.jl") + include("test_ad.jl") end diff --git a/test/test_ad.jl b/test/test_ad.jl new file mode 100644 index 0000000..76faebe --- /dev/null +++ b/test/test_ad.jl @@ -0,0 +1,47 @@ +using Test +using SimpleFlows +using Distributions +using Random +using DifferentiationInterface +using ForwardDiff +using Zygote +using Mooncake + +println("Starting AD tests...") + +@testset "Automatic Differentiation" begin + rng = Random.MersenneTwister(42) + n_dims = 2 + architectures = [:RealNVP, :NSF, :MAF] + + for arch in architectures + println("Testing architecture: ", arch) + @testset "$arch" begin + d = FlowDistribution(Float64; architecture=arch, n_transforms=2, dist_dims=n_dims, hidden_layer_sizes=[16, 16], n_layers=2) + x = rand(rng, n_dims) + + f(x) = logpdf(d, x) + val = f(x) + @test isfinite(val) + + for (name, backend) in [ + ("ForwardDiff", AutoForwardDiff()), + ("Zygote", AutoZygote()), + ("Mooncake", AutoMooncake(config=nothing)) + ] + println(" Testing backend: ", name) + 
@testset "$name" begin + g = try + DifferentiationInterface.gradient(f, backend, x) + catch e + @error "Failed to differentiate $arch with $name" exception=(e, catch_backtrace()) + nothing + end + @test g !== nothing + @test length(g) == n_dims + @test all(isfinite, g) + end + end + end + end +end diff --git a/test/test_io.jl b/test/test_io.jl index e56f140..ab2d2ca 100644 --- a/test/test_io.jl +++ b/test/test_io.jl @@ -2,39 +2,55 @@ rng = Random.MersenneTwister(4) dims = 3 - # Test backward-compat scalar API - flow = FlowDistribution(; n_transforms=3, dist_dims=dims, - hidden_dims=16, n_layers=2, rng) - - # We don't train here. Training stability is tested in test_realnvp.jl. - # The IO test only verifies that model serialization survives round-trip. - - # Save - tmpdir = mktempdir() - save_trained_flow(tmpdir, flow) - - @test isfile(joinpath(tmpdir, "flow_setup.json")) - @test isfile(joinpath(tmpdir, "weights.npz")) - - # Load - flow2 = load_trained_flow(tmpdir; rng) - - # Dimensions must match - @test Distributions.length(flow2) == dims - - # logpdf must be identical at test points - x_test = randn(rng, Float32, dims, 20) - lp1 = logpdf(flow, x_test) - lp2 = logpdf(flow2, x_test) - @test lp1 ≈ lp2 atol=1f-4 - - # Also test the new per-layer vector API round-trips correctly - flow3 = FlowDistribution(; n_transforms=2, dist_dims=dims, - hidden_layer_sizes=[32, 64, 32], rng) + for arch in [:RealNVP, :NSF, :MAF] + flow = if arch == :NSF + FlowDistribution(Float32; architecture=arch, n_transforms=3, dist_dims=dims, + hidden_layer_sizes=[16, 16], K=4, tail_bound=2.0, rng=rng) + else + FlowDistribution(Float32; architecture=arch, n_transforms=3, dist_dims=dims, + hidden_layer_sizes=[16, 16], rng=rng) + end + + # Save + tmpdir = mktempdir() + save_trained_flow(tmpdir, flow) + + @test isfile(joinpath(tmpdir, "flow_setup.json")) + @test isfile(joinpath(tmpdir, "weights.npz")) + + # Load + flow2 = load_trained_flow(tmpdir; rng) + + # Dimensions must match + @test 
Distributions.length(flow2) == dims + @test typeof(flow2.model) == typeof(flow.model) + + # logpdf must be identical at test points + x_test = randn(rng, Float32, dims, 20) + lp1 = logpdf(flow, x_test) + lp2 = logpdf(flow2, x_test) + @test lp1 ≈ lp2 atol=1f-4 + end + + # Test backward-compat scalar API and legacy JSON load + flow3 = FlowDistribution(; architecture=:RealNVP, n_transforms=2, dist_dims=dims, + hidden_dims=16, n_layers=2, rng=rng) tmpdir2 = mktempdir() save_trained_flow(tmpdir2, flow3) + + # Manually modify JSON to simulate old format + setup_file = joinpath(tmpdir2, "flow_setup.json") + setup_dict = JSON.parsefile(setup_file) + delete!(setup_dict, "hidden_layer_sizes") + setup_dict["hidden_dims"] = 16 + setup_dict["n_layers"] = 2 + open(setup_file, "w") do io + JSON.print(io, setup_dict, 4) + end + flow4 = load_trained_flow(tmpdir2; rng) - @test flow4.hidden_layer_sizes == [32, 64, 32] + @test flow4.hidden_layer_sizes == [16, 16] + x_test = randn(rng, Float32, dims, 20) lp3 = logpdf(flow3, x_test) lp4 = logpdf(flow4, x_test) @test lp3 ≈ lp4 atol=1f-4 diff --git a/test/test_layers.jl b/test/test_layers.jl index 5770e7c..9ae7d8b 100644 --- a/test/test_layers.jl +++ b/test/test_layers.jl @@ -28,7 +28,7 @@ end mask = Bool.(collect(1:n) .% 2 .== 0) # [false, true, false, true, false, true] # Trivial conditioner: returns zeros (identity bijector) - conditioner = x_cond -> zeros(Float32, 2n, B) + conditioner = x_cond -> zeros(Float32, 2*sum(mask), B) bj = SimpleFlows.MaskedCoupling(mask, conditioner, SimpleFlows.AffineBijector) x = randn(rng, Float32, n, B) @@ -43,3 +43,45 @@ end x_rec, _ = SimpleFlows.inverse_and_log_det(bj, y) @test x_rec ≈ x atol=1f-5 end + +@testset "Explicit Dense Jacobian Inverses" begin + rng = Random.MersenneTwister(3) + n = 4 + + # 1. 
AffineBijector + params = randn(rng, Float32, 2n, 1) + b_aff = SimpleFlows.AffineBijector(params) + + x_single = randn(rng, Float32, n) + f_aff(x) = SimpleFlows.forward_and_log_det(b_aff, reshape(x, n, 1))[1][:] + finv_aff(y) = SimpleFlows.inverse_and_log_det(b_aff, reshape(y, n, 1))[1][:] + + y_aff = f_aff(x_single) + J_aff = ForwardDiff.jacobian(f_aff, x_single) + Jinv_aff = ForwardDiff.jacobian(finv_aff, y_aff) + + # Check J * J_inv ≈ I + @test J_aff * Jinv_aff ≈ I(n) atol=1f-4 + # Check determinant matches (Affine returns per-dimension ld) + ld_fwd = SimpleFlows.forward_and_log_det(b_aff, reshape(x_single, n, 1))[2][:] + @test log(abs(det(J_aff))) ≈ sum(ld_fwd) atol=1f-4 + + # 2. MaskedCoupling + mask = [false, true, false, true] + # Trivial deterministic conditioner for testing structural zeros. MUST be deterministic for ForwardDiff! + fixed_params = randn(rng, Float32, 2*sum(mask), 1) + conditioner = x_cond -> fixed_params .+ sum(x_cond) * 0.0f0 # Ensure dual numbers propagate if needed + bj_m = SimpleFlows.MaskedCoupling(mask, conditioner, SimpleFlows.AffineBijector) + + f_m(x) = SimpleFlows.forward_and_log_det(bj_m, reshape(x, n, 1))[1][:] + finv_m(y) = SimpleFlows.inverse_and_log_det(bj_m, reshape(y, n, 1))[1][:] + + x_m = randn(rng, Float32, n) + y_m = f_m(x_m) + + J_m = ForwardDiff.jacobian(f_m, x_m) + Jinv_m = ForwardDiff.jacobian(finv_m, y_m) + + @test J_m * Jinv_m ≈ I(n) atol=1f-4 +end + diff --git a/test/test_maf.jl b/test/test_maf.jl new file mode 100644 index 0000000..95f24db --- /dev/null +++ b/test/test_maf.jl @@ -0,0 +1,121 @@ +using SimpleFlows +using Test +using Random +using Lux +using ForwardDiff +using Zygote +using Bijectors +using Statistics +using LinearAlgebra + +@testset "MAF Components" begin + rng = Random.default_rng() + Random.seed!(rng, 123) + + D = 4 + hidden = [16, 16] + + @testset "MADE Autoregressive Property" begin + # Create a MADE network + made = SimpleFlows.MADE(D, hidden, 2*D) + ps, st = Lux.setup(rng, made) + + 
# Explicit Mask Multiplication (nflows inspired) + # Extract masks from MaskedDense layers to ensure their product is strictly lower triangular. + m1 = made.layers.layer_1.mask + m2 = made.layers.layer_2.mask + m3 = made.layers.layer_3.mask + total_mask = m3 * m2 * m1 + + # Block 1: m (1:D, 1:D) + M_m = total_mask[1:D, :] + @test isapprox(tril(M_m, -1), M_m, atol=1e-7) + @test count(>(0), M_m) > 0 # make sure it's not identically zero + + # Block 2: log_alpha (D+1:2D, 1:D) + M_la = total_mask[D+1:end, :] + @test isapprox(tril(M_la, -1), M_la, atol=1e-7) + @test count(>(0), M_la) > 0 + + # We check the Jacobian of the outputs with respect to the inputs. + + function f(x) + out, _ = Lux.apply(made, x, ps, st) + return out + end + + x0 = randn(Float32, D) + J = ForwardDiff.jacobian(f, x0) # (2*D, D) + + # Block 1: m (1:D, 1:D) + J_m = J[1:D, :] + @test isapprox(tril(J_m, -1), J_m, atol=1e-7) + + # Block 2: log_alpha (D+1:2D, 1:D) + J_la = J[D+1:end, :] + @test isapprox(tril(J_la, -1), J_la, atol=1e-7) + end + + @testset "Differentiability" begin + made = SimpleFlows.MADE(D, hidden, 2*D) + ps, st = Lux.setup(rng, made) + x = randn(Float32, D, 5) + + # Zygote test + gs = Zygote.gradient(ps) do p + out, _ = Lux.apply(made, x, p, st) + sum(out) + end + @test gs[1] !== nothing + + # ForwardDiff test + function g(p_vec) + # Reconstruct NamedTuple structure if possible, but easier to test w.r.t x + return nothing + end + + gx = ForwardDiff.gradient(x -> sum(Lux.apply(made, x, ps, st)[1]), x[:, 1]) + @test all(isfinite, gx) + end + + @testset "MAFBijector Invertibility" begin + made = SimpleFlows.MADE(D, hidden, 2*D) + ps, st = Lux.setup(rng, made) + bj = SimpleFlows.MAFBijector(made, ps, st) + + x = randn(Float32, D, 4) + + # Inverse (Density pass) + u, lad_inv = Bijectors.with_logabsdet_jacobian(bj, x) + + # Forward (Sampling pass) + x_rec, lad_fwd = SimpleFlows.forward_and_log_det(bj, u) + + @test x ≈ x_rec atol=1e-4 + @test lad_inv ≈ -lad_fwd atol=1e-4 + end + + 
@testset "Integrated MAF Model" begin + dist = FlowDistribution(Float32; + architecture=:MAF, + n_transforms=2, + dist_dims=D, + hidden_dims=16, + n_layers=2 + ) + + x = randn(Float32, D, 10) + lp = logpdf(dist, x) + @test length(lp) == 10 + @test all(isfinite, lp) + + samples = rand(rng, dist, 10) + @test size(samples) == (D, 10) + @test all(isfinite, samples) + + # Training smoke test + data = randn(Float32, D, 100) + train_flow!(dist, data; n_epochs=1, batch_size=50) + @test !isnothing(dist.normalizer) + end +end diff --git a/test/test_normalizer.jl b/test/test_normalizer.jl index a491a0a..6c81749 100644 --- a/test/test_normalizer.jl +++ b/test/test_normalizer.jl @@ -1,4 +1,23 @@ @testset "MinMaxNormalizer" begin + # 1. Normal functioning data + x = Float32[1 2 3; 4 5 6] + norm = MinMaxNormalizer(x) + + @test norm.x_min == Float32[1, 4] + @test norm.x_max == Float32[3, 6] + + x_norm = SimpleFlows.normalize(norm, x) + @test size(x_norm) == size(x) + @test all(x_norm .>= 0) + @test all(x_norm .<= 1) + + x_unnorm = SimpleFlows.denormalize(norm, x_norm) + @test x_unnorm ≈ x + + # 2. Zero variance edge case + x_zero_var = Float32[1 1 1; 4 5 6] + @test_throws ArgumentError MinMaxNormalizer(x_zero_var) + rng = Random.MersenneTwister(7) dims = 4 diff --git a/test/test_nsf.jl b/test/test_nsf.jl new file mode 100644 index 0000000..d9f06d4 --- /dev/null +++ b/test/test_nsf.jl @@ -0,0 +1,52 @@ +using SimpleFlows +using Test +using Random +using Statistics +using Distributions + +@testset "NSF Integrated Model" begin + rng = Random.default_rng() + Random.seed!(rng, 123) + + dist_dims = 2 + n_transforms = 2 + K = 4 + + # 1. Initialization + dist = FlowDistribution(Float32; + architecture=:NSF, + n_transforms=n_transforms, + dist_dims=dist_dims, + hidden_dims=16, + n_layers=2, + K=K + ) + + @test dist.model isa NeuralSplineFlow + @test dist.model.K == K + + # 2. 
Forward pass (logpdf) + x = randn(Float32, dist_dims, 10) + lp = logpdf(dist, x) + @test length(lp) == 10 + @test all(isfinite, lp) + + # 3. Sampling + samples = rand(rng, dist, 100) + @test size(samples) == (dist_dims, 100) + @test all(isfinite, samples) + + # 4. Training (Smoke test) + # Simple data: Gaussian + target_data = randn(Float32, dist_dims, 200) + + # Initialize with normalizer + train_flow!(dist, target_data; n_epochs=1, batch_size=100) + + @test !isnothing(dist.normalizer) + + # Verify logpdf still works after training + lp_after = logpdf(dist, target_data[:, 1:10]) + @test length(lp_after) == 10 + @test all(isfinite, lp_after) +end diff --git a/test/test_splines.jl b/test/test_splines.jl new file mode 100644 index 0000000..a00f39d --- /dev/null +++ b/test/test_splines.jl @@ -0,0 +1,122 @@ +using SimpleFlows +using Test +using JSON +using ForwardDiff +using Zygote +using LinearAlgebra + +@testset "Rational Quadratic Splines" begin + # 1. Load reference data + data = JSON.parsefile(joinpath(@__DIR__, "data", "nsf_test_data.json")) + + # Helper to reconstruct array from JSON dict + function get_array(T, dict, key) + d = dict[key] + reshape(T.(d["data"]), d["shape"]...) + end + + for T in [Float32, Float64] + @testset "Precision: $T" begin + inputs = get_array(T, data, "inputs") + unnormalized_widths = get_array(T, data, "unnormalized_widths") + unnormalized_heights = get_array(T, data, "unnormalized_heights") + unnormalized_derivatives = get_array(T, data, "unnormalized_derivatives") + + ref_outputs = get_array(T, data, "outputs") + ref_logabsdet = get_array(T, data, "logabsdet") + + D, N = size(inputs) + + @testset "Numerical Parity with PyTorch" begin + for d in 1:D + out, lad = unconstrained_rational_quadratic_spline( + inputs[d, :], + unnormalized_widths[d, :, :], + unnormalized_heights[d, :, :], + unnormalized_derivatives[d, :, :], + T(3.0) # tail_bound + ) + + @test out ≈ ref_outputs[d, :] atol=(T == Float32 ? 
1e-4 : 1e-6) + @test lad ≈ ref_logabsdet[d, :] atol=(T == Float32 ? 1e-4 : 1e-6) + end + end + + @testset "Invertibility" begin + for d in 1:D + out, _ = unconstrained_rational_quadratic_spline( + inputs[d, :], + unnormalized_widths[d, :, :], + unnormalized_heights[d, :, :], + unnormalized_derivatives[d, :, :], + T(3.0) + ) + + back, _ = unconstrained_rational_quadratic_spline( + out, + unnormalized_widths[d, :, :], + unnormalized_heights[d, :, :], + unnormalized_derivatives[d, :, :], + T(3.0); + inverse=true + ) + + @test back ≈ inputs[d, :] atol=(T == Float32 ? 1e-4 : 1e-5) + end + end + + @testset "ForwardDiff Compatibility" begin + d = 1 + x = inputs[d, 1:5] + w = unnormalized_widths[d, 1:5, :] + h = unnormalized_heights[d, 1:5, :] + dv = unnormalized_derivatives[d, 1:5, :] + + f(x_in) = sum(unconstrained_rational_quadratic_spline(x_in, w, h, dv, T(3.0))[1]) + + g = ForwardDiff.gradient(f, x) + @test all(isfinite, g) + @test length(g) == 5 + end + + @testset "Zygote Compatibility" begin + d = 1 + x = inputs[d, 1:5] + w = unnormalized_widths[d, 1:5, :] + h = unnormalized_heights[d, 1:5, :] + dv = unnormalized_derivatives[d, 1:5, :] + + f(x_in) = sum(unconstrained_rational_quadratic_spline(x_in, w, h, dv, T(3.0))[1]) + + gz = Zygote.gradient(f, x)[1] + @test all(isfinite, gz) + + gf = ForwardDiff.gradient(f, x) + @test gz ≈ gf atol=T == Float32 ? 
1e-4 : 1e-10 + end + + @testset "Boundary & Tail Consistency" begin + # x explicitly outside the tail_bound + x_out = T.([-4.0, 4.0, -10.0, 10.0]) + N = length(x_out) + K = 5 + w = randn(T, N, K) + h = randn(T, N, K) + dv = randn(T, N, K-1) + + # Forward should be strict linear identity outside bounds + y_out, lad_out = unconstrained_rational_quadratic_spline(x_out, w, h, dv, T(3.0)) + + @test y_out ≈ x_out atol=1e-6 + @test lad_out ≈ zeros(T, N) atol=1e-6 + + # Inverse should also be strict identity + x_rec, lad_inv = unconstrained_rational_quadratic_spline(x_out, w, h, dv, T(3.0); inverse=true) + + @test x_rec ≈ x_out atol=1e-6 + @test lad_inv ≈ zeros(T, N) atol=1e-6 + end + end + + end +end diff --git a/test/test_training.jl b/test/test_training.jl new file mode 100644 index 0000000..2c9ce8f --- /dev/null +++ b/test/test_training.jl @@ -0,0 +1,85 @@ +using SimpleFlows +using Test +using Random +using Optimisers + +@testset "Custom Optimizer & Training" begin + rng = Random.default_rng() + Random.seed!(rng, 10) + + dims = 2 + N = 100 + data = randn(Float32, dims, N) + + flow = FlowDistribution(Float32; architecture=:RealNVP, n_transforms=2, dist_dims=dims, + hidden_layer_sizes=[16, 16], rng) + + # Initial NLL + lp_init = -mean(logpdf(flow, data)) + + # Custom optimizer + custom_opt = Optimisers.Adam(1f-4) + train_flow!(flow, data; n_epochs=5, batch_size=50, opt=custom_opt, verbose=false) + + # Check that training occurred without error and actually evaluated + lp_trained = -mean(logpdf(flow, data)) + @test isfinite(lp_trained) + @test !isnothing(flow.normalizer) +end + +@testset "Automatic Type Conversion" begin + rng = Random.default_rng() + Random.seed!(rng, 42) + + dims = 2 + N = 100 + # User passes Float64 data accidentally + data_f64 = randn(Float64, dims, N) + + # Model is strictly Float32 + flow = FlowDistribution(Float32; architecture=:RealNVP, n_transforms=2, dist_dims=dims, + hidden_layer_sizes=[16, 16], rng) + + # Should not throw MethodError + 
train_flow!(flow, data_f64; n_epochs=2, batch_size=50, verbose=false) + + # Normalizer should now be correctly typed as MinMaxNormalizer{Float32} + @test flow.normalizer isa MinMaxNormalizer{Float32} + + # logpdf on Float64 data should evaluate correctly without errors. + # The output type may promote to Float64 depending on input type. + lp = logpdf(flow, data_f64) + @test all(isfinite, lp) +end + +@testset "Reparameterization Trick Zygote Gradients" begin + rng = Random.default_rng() + Random.seed!(rng, 100) + + dims = 2 + + # We test that we can differentiate a loss function evaluated on the output of draw_samples. + # This proves that our flows support the reparameterization trick natively. + for arch in [:RealNVP, :NSF, :MAF] + flow = FlowDistribution(Float32; architecture=arch, n_transforms=1, dist_dims=dims, + hidden_layer_sizes=[16], K=4, tail_bound=2.0, rng) + + # We need a stable test_rng since draw_samples is stochastic. + # Zygote can sometimes struggle with capturing global RNGs during AD, so we pass one explicitly + # and ignore the derivative with respect to it (which is 0). + test_rng = Random.default_rng() + Random.seed!(test_rng, 42) + + function loss_fn(ps) + x_samples = SimpleFlows.draw_samples(test_rng, Float32, flow.model, ps, flow.st, 10) + return sum(x_samples.^2) + end + + loss, (dps,) = Zygote.withgradient(loss_fn, flow.ps) + + @test isfinite(loss) + @test !isnothing(dps) + end +end + +