Skip to content

Commit

Permalink
Add minimal tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jan 21, 2024
1 parent 47d4caf commit 0c03a6b
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 29 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Run tests

on:
push:
branches:
- main
pull_request:

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ['1']
julia-arch: [x64]
os: [ubuntu-latest]

steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}
arch: ${{ matrix.julia-arch }}
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
# with:
# annotate: true
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,10 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
QuadGK = "2.9"
SpecialFunctions = "2.3"
StaticArrays = "1.9"
StatsFuns = "1.3"
LinearAlgebra = "1"
Random = "1"
julia = "1.10"
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# BootstrapAsymptotics

## Algorithms

- Bootstrap
- Full resampling
- ?

## State evolutions

| | Ridge | Logistic |
| ------- | ----- | -------- |
| 1 algo | | |
| 2 algos | | |
| 3 algos | | |

## Notations

- `n`: number of samples
Expand Down
22 changes: 11 additions & 11 deletions src/ridge/state_evolution.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
## Update overlaps

function update_overlaps(hatoverlaps::HatOverlaps; λ::Real)
(; m̂, Q̂, V̂) = hatoverlaps
R = inv* I + )
m = R *
Q = R * ( * ' + ) * R'
(; m_hat, Q_hat, V_hat) = hatoverlaps
R = inv* I + V_hat)
m = R * m_hat
Q = R * (m_hat * m_hat' + Q_hat) * R'
V = R
return Overlaps(m, Q, V)
end
Expand All @@ -17,22 +17,22 @@ function update_hatoverlaps(
(; m, Q, V) = overlaps

Q⁻¹ = inv(Q)
vstar = ρ - sum(m' * Q⁻¹ * m)
v_star = ρ - sum(m' * Q⁻¹ * m)
B = vcat(m', m') * Q⁻¹ - I

m̂, Q̂, V̂ = zero(m), zero(Q), zero(V)
m_hat, Q_hat, V_hat = zero(m), zero(Q), zero(V)

for p1 in 0:pmax, p2 in 0:pmax
P = Diagonal(SVector(p1, p2))
G = inv(I + P * V) * P
proba = weight_dist(p1, p2)

+=* proba) * (G * SVector(1, 1))
+=* proba) * (G * ((vstar + σ²) .+ B * Q * B') * G')
+=* proba) * G
m_hat +=* proba) * (G * SVector(1, 1))
Q_hat +=* proba) * (G * ((v_star + σ²) .+ B * Q * B') * G')
V_hat +=* proba) * G
end

return HatOverlaps(α * m̂, α * Q̂, α *)
return HatOverlaps(m_hat, Q_hat, V_hat)
end

## State evolution
Expand All @@ -49,7 +49,7 @@ function state_evolution(
max_iteration=100,
)
overlaps, hatoverlaps = Overlaps(), HatOverlaps()

for _ in 0:max_iteration
new_hatoverlaps = update_hatoverlaps(overlaps; weight_dist, α, σ², ρ, pmax)
new_overlaps = update_overlaps(new_hatoverlaps; λ)
Expand Down
14 changes: 7 additions & 7 deletions src/utils/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ struct Overlaps{Vec1<:AbstractVector,Mat1<:AbstractMatrix,Mat2<:AbstractMatrix}
end

struct HatOverlaps{Vec1<:AbstractVector,Mat1<:AbstractMatrix,Mat2<:AbstractMatrix}
::Vec1
::Mat1
::Mat2
m_hat::Vec1
Q_hat::Mat1
V_hat::Mat2
end

function Overlaps()
Expand All @@ -29,10 +29,10 @@ function Overlaps()
end

function HatOverlaps()
= SVector(0.0, 0.0)
= SMatrix{2,2}(1.0, 0.01, 0.01, 1.0)
= Diagonal(SVector(1.0, 1.0))
return HatOverlaps(m̂, Q̂, V̂)
m_hat = SVector(0.0, 0.0)
Q_hat = SMatrix{2,2}(1.0, 0.01, 0.01, 1.0)
V_hat = Diagonal(SVector(1.0, 1.0))
return HatOverlaps(m_hat, Q_hat, V_hat)
end

function relative_difference(overlaps::Overlaps, overlaps_ref::Overlaps)
Expand Down
24 changes: 20 additions & 4 deletions test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.0"
manifest_format = "2.0"
project_hash = "fb067cfbb20385de2ccfb638ba57f270383524bb"
project_hash = "43e725ec011071cc34d587db20d7f6c50a6b5c58"

[[deps.Aqua]]
deps = ["Compat", "Pkg", "Test"]
Expand Down Expand Up @@ -37,13 +37,15 @@ deps = ["TOML", "UUIDs"]
git-tree-sha1 = "75bd5b6fc5089df449b5d35fa501c846c9b6549b"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.12.0"
weakdeps = ["Dates", "LinearAlgebra"]

[deps.Compat.extensions]
CompatLinearAlgebraExt = "LinearAlgebra"

[deps.Compat.weakdeps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.5+1"

[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
Expand Down Expand Up @@ -116,6 +118,10 @@ version = "1.11.0+1"
[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

Expand All @@ -139,6 +145,11 @@ version = "2023.1.10"
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+2"

[[deps.OrderedCollections]]
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand Down Expand Up @@ -231,6 +242,11 @@ deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+1"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
20 changes: 13 additions & 7 deletions test/ridge.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
using BootstrapAsymptotics
using LinearAlgebra
using Test

overlaps, hatoverlaps = state_evolution(Ridge(); weight_dist=indep_poisson, α=1.0, λ=1e-4, σ²=1.0);
overlaps.m
overlaps.Q
overlaps.V
hatoverlaps.
hatoverlaps.
hatoverlaps.
overlaps, hatoverlaps = state_evolution(
Ridge(); weight_dist=indep_poisson, α=1.0, λ=1e-4, σ²=1.0
);

@test overlaps.m [0.631987, 0.631987] rtol = 1e-3
@test overlaps.Q [2.34849 1.1545; 1.1545 2.34849] rtol = 1e-3
@test overlaps.V Diagonal([3680.13, 3680.13]) rtol = 1e-3

@test hatoverlaps.m_hat [0.00017173, 0.00017173] rtol = 1e-3
@test hatoverlaps.Q_hat [1.43915e-7 5.57537e-8; 5.57537e-8 1.43915e-7] rtol = 1e-3
@test hatoverlaps.V_hat Diagonal([0.00017173, 0.00017173]) rtol = 1e-3
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using JuliaFormatter
using Test

@testset verbose = true "BootstrapAsymptotics" begin
@testset "Code quality" begin
Aqua.test_all(BootstrapAsymptotics)
end
@testset "Ridge" begin
include("ridge.jl")
end
Expand Down

0 comments on commit 0c03a6b

Please sign in to comment.