Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name: CI
on:
schedule:
- cron: 0 0 * * *
pull_request:
branches:
- main
- develop
push:
branches:
- "**"
tags: "*"
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- "1.10"
- "1.11"
- "1.12"
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v4
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-

- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
file: lcov.info
45 changes: 31 additions & 14 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,47 @@ authors = ["marcobonici <bonici.marco@gmail.com>"]

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Bijectors = "0.15.17"
ConcreteStructs = "0.2.3"
Distributions = "0.25.123"
JSON = "1.4.0"
LinearAlgebra = "1.12.0"
Lux = "1.31.3"
MLUtils = "0.4.8"
NPZ = "0.4.3"
Optimisers = "0.4.7"
Random = "1.11.0"
Statistics = "1.11.1"
Test = "1.11.0"
Zygote = "0.7.10"
BenchmarkTools = "1.6.3"
Bijectors = "0.15"
ChainRulesCore = "1.26"
ConcreteStructs = "0.2"
DifferentiationInterface = "0.7.16"
Distributions = "0.25"
ForwardDiff = "1.3"
JSON = "0.21"
LinearAlgebra = "1.10"
Lux = "1.31"
MLUtils = "0.4"
Mooncake = "0.5.23"
NNlib = "0.9"
NPZ = "0.4"
Optimisers = "0.4"
Random = "1.10"
Statistics = "1.10"
Test = "1.10"
Zygote = "0.7"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "DifferentiationInterface", "Mooncake", "BenchmarkTools"]
19 changes: 12 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ distributions inside [`Turing.jl`](https://turinglang.org/) probabilistic progra

## Features

- **RealNVP** (Real-valued Non-Volume Preserving) architecture with affine coupling layers
- **Architectures**:
- **RealNVP** (Real-valued Non-Volume Preserving)
- **NSF** (Neural Spline Flow with Rational Quadratic Splines)
- **MAF** (Masked Autoregressive Flow with MADE)
- Full `Distributions.jl` interface: `logpdf`, `rand`, `length`
- `Bijectors.jl` compatible → plug directly into Turing models as a prior
- Training via `Optimisers.Adam` + `Zygote.jl` autodiff, mini-batched with `MLUtils`
Expand All @@ -15,10 +18,10 @@ distributions inside [`Turing.jl`](https://turinglang.org/) probabilistic progra
## Quick Start

```julia
using SimpleFlows, Distributions
using SimpleFlows, Distributions, LinearAlgebra

# 1. Build a 4-dim flow
flow = FlowDistribution(; n_transforms=6, dist_dims=4, hidden_dims=64, n_layers=3)
# 1. Build a 4-dim flow (options: :RealNVP, :NSF, :MAF)
flow = FlowDistribution(Float32; architecture=:RealNVP, n_transforms=6, dist_dims=4, hidden_layer_sizes=[64, 64, 64])

# 2. Sample training data from your target distribution
data = Float32.(rand(MvNormal(zeros(4), I), 10_000))
Expand Down Expand Up @@ -60,17 +63,19 @@ my_flow/
| Architecture | Status |
|---|---|
| RealNVP | ✅ Done |
| MAF | 📋 Planned |
| NSF | 📋 Planned |
| MAF | ✅ Done |
| NSF | ✅ Done |

## Running Tests

```bash
julia --project=. -e "using Pkg; Pkg.test()"
```

## Running the Example
## Running the Examples

```bash
julia --project=. examples/train_multinormal.jl
julia --project=. examples/train_multinormal_nsf.jl
julia --project=. examples/train_multinormal_maf.jl
```
8 changes: 3 additions & 5 deletions examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.12.1"
manifest_format = "2.0"
project_hash = "64c6eed7b4f2bf47e0f4350e25ac6a3f80fb5a0d"
project_hash = "3d2728b3089e56da155cbe07b6f794382965b85e"

[[deps.ADTypes]]
git-tree-sha1 = "f7304359109c768cf32dc5fa2d371565bb63b68a"
Expand Down Expand Up @@ -1753,10 +1753,8 @@ uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
version = "0.1.0"

[[deps.SimpleFlows]]
deps = ["Bijectors", "ConcreteStructs", "Distributions", "JSON", "LinearAlgebra", "Lux", "MLUtils", "NPZ", "Optimisers", "Random", "Statistics", "Test", "Zygote"]
git-tree-sha1 = "9ee8f18c24029fe3ad3362abade4e005e283e4d1"
repo-rev = "main"
repo-url = "https://github.com/CosmologicalEmulators/SimpleFlows.jl"
deps = ["Bijectors", "ChainRulesCore", "ConcreteStructs", "Distributions", "ForwardDiff", "JSON", "LinearAlgebra", "Lux", "MLUtils", "NNlib", "NPZ", "Optimisers", "Random", "Statistics", "Test", "Zygote"]
path = ".."
uuid = "7aff1418-a6e2-48c2-ba03-ae8e32e98757"
version = "0.1.0"

Expand Down
2 changes: 2 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
SimpleFlows = "7aff1418-a6e2-48c2-ba03-ae8e32e98757"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
111 changes: 111 additions & 0 deletions examples/train_multinormal_maf.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# examples/train_multinormal_maf.jl
using SimpleFlows
using Random
using Statistics
using Distributions, LinearAlgebra, Random, Statistics, Printf
using Turing
using Optimisers

function run_maf_example()
# ── 1. Setup Data ──────────────────────────────────────────────────────────
rng = Random.default_rng()
Random.seed!(rng, 123)

dist_dims = 4
# Create a correlated 4D Gaussian
μ_true = [1.0, -0.5, 2.0, 0.3]
Σ_true = [1.0 0.5 0.2 0.1;
0.5 1.0 0.4 0.2;
0.2 0.4 1.0 0.5;
0.1 0.2 0.5 1.0]

true_dist = MvNormal(μ_true, Σ_true)
n_samples = 5000
target_data = rand(rng, true_dist, n_samples)

println("Target mean: ", round.(mean(target_data, dims=2)[:], digits=3))
println("Target std: ", round.(std(target_data, dims=2)[:], digits=3))
println()

# ── 2. Initialize MAF Flow ───────────────────────────────────────────────
# MAF is generally more expressive than RealNVP for same number of layers
n_transforms = 4
hidden_layer_sizes = [32, 32]

println("Training MAF ($n_transforms transforms, $(hidden_layer_sizes) units, 500 epochs)…")

dist = FlowDistribution(Float32;
architecture=:MAF,
n_transforms=n_transforms,
dist_dims=dist_dims,
hidden_layer_sizes=hidden_layer_sizes,
rng=rng
)

# ── 3. Train ──────────────────────────────────────────────────────────────
# We use a smaller learning rate for MAF as it can be more sensitive
train_flow!(dist, target_data; n_epochs=500, batch_size=200, opt=Optimisers.Adam(1e-3))

# ── 4. Evaluate Density ──────────────────────────────────────────────────
test_data = rand(rng, true_dist, 1000)
lp_true = logpdf(true_dist, test_data)
lp_flow = logpdf(dist, test_data)

println("\n── Density Fit ──────────────────────────────────────────")
println("Mean log-pdf (true distribution): ", round(mean(lp_true), digits=4))
println("Mean log-pdf (trained MAF flow): ", round(mean(lp_flow), digits=4))
println("Difference: ", round(abs(mean(lp_true) - mean(lp_flow)), digits=4))

samples = rand(rng, dist, 5000)
println("\nFlow sample mean: ", round.(mean(samples, dims=2)[:], digits=3))
println("Flow sample std: ", round.(std(samples, dims=2)[:], digits=3))

# ── 5. Serialization ─────────────────────────────────────────────────────
save_path = joinpath(@__DIR__, "../trained_flows/maf_mvn_4d")
save_trained_flow(save_path, dist)

# Reload and verify
dist_reloaded = load_trained_flow(save_path)
lp_reloaded = logpdf(dist_reloaded, test_data)
println("Mean log-pdf after reload: ", round(mean(lp_reloaded), digits=4))
if isapprox(mean(lp_flow), mean(lp_reloaded), atol=1e-5)
println("Round-trip OK ✓")
else
println("Round-trip FAILED ✗")
end

# ── 6. Turing.jl Integration ───────────────────────────────────────────
println("\n── Turing Inference ─────────────────────────────────────────")
# Observed data: one sample from θ[1] with noise
θ_true = μ_true
y_obs = θ_true[1] + 0.1 * randn(rng)
println("Observed y: ", round(y_obs, digits=4), " (true θ[1] = ", θ_true[1], ")")

@model function linear_model(y, prior_dist)
θ ~ prior_dist
y ~ Normal(θ[1], 0.1)
end

# Sampling with exact prior
println("\nSampling with exact MvNormal prior (4 chains × 1000 samples)…")
chain_exact = sample(linear_model(y_obs, true_dist), HMC(0.1, 10), MCMCThreads(), 1000, 4; progress=false)

# Sampling with MAF prior
println("Sampling with trained MAF prior (4 chains × 1000 samples)…")
chain_maf = sample(linear_model(y_obs, dist), HMC(0.1, 10), MCMCThreads(), 1000, 4; progress=false)

println("\n── Posterior comparison (θ[1]) ──────────────────────────────")
exact_θ1 = vec(chain_exact[:, "θ[1]", :])
maf_θ1 = vec(chain_maf[:, "θ[1]", :])

@printf " True θ[1]: %.4f\n" θ_true[1]
@printf " Observed y: %.4f\n" y_obs
@printf " Posterior mean (exact): %.4f ± %.4f\n" mean(exact_θ1) std(exact_θ1)
@printf " Posterior mean (MAF): %.4f ± %.4f\n" mean(maf_θ1) std(maf_θ1)

println("\n✨ End of script: Turing MCMC sampling with MAF finished successfully!")
end

if abspath(PROGRAM_FILE) == joinpath(@__DIR__, "train_multinormal_maf.jl")
run_maf_example()
end
Loading
Loading