Skip to content

Commit

Permalink
Switch to slower but more accurate cdf-based handling of one-sided co…
Browse files Browse the repository at this point in the history
…nstraints
  • Loading branch information
brenhinkeller committed Jul 4, 2024
1 parent bd16159 commit 93c71ad
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 91 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Chron"
uuid = "68885b1f-77b5-52a7-b2e7-6a8014c56b98"
authors = ["C. Brenhin Keller <cbkeller@dartmouth.edu>"]
version = "0.5.5"
version = "0.6.0"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand Down
103 changes: 27 additions & 76 deletions src/StratMetropolis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,17 +438,11 @@
npoints = length(model_heights)

# Calculate log likelihood of initial proposal
# Proposals younger than age constraint are given a pass if Age_Sidedness is -1 (maximum age)
# proposals older than age constraint are given a pass if Age_Sidedness is +1 (minimum age)

sample_height = copy(Height)
closest = findclosest(sample_height, model_heights)
closest_model_ages = model_ages[closest]
@inbounds for i eachindex(ages,closest_model_ages)
if Age_Sidedness[i] == sign(closest_model_ages[i] - ages[i].μ)
closest_model_ages[i] = ages[i].μ
end
end
ll = strat_ll(closest_model_ages, ages)
ll = strat_ll(closest_model_ages, ages, Age_Sidedness)
ll += normpdf_ll(Height, Height_sigma, sample_height)

# Preallocate variables for MCMC proposals
Expand All @@ -473,14 +467,10 @@

if rand() < 0.1
# Adjust heights
@inbounds for i eachindex(sample_heightₚ)
@inbounds for i eachindex(sample_heightₚ, closestₚ)
sample_heightₚ[i] += randn() * Height_sigma[i]
closestₚ[i] = round(Int,(sample_heightₚ[i] - model_heights[1])/resolution)+1
if closestₚ[i] < 1 # Check we're still within bounds
closestₚ[i] = 1
elseif closestₚ[i] > npoints
closestₚ[i] = npoints
end
closestₚ[i] = round(Int,(sample_heightₚ[i] - first(model_heights))/resolution)+1
closestₚ[i] = max(min(closestₚ[i], lastindex(model_agesₚ)), firstindex(model_agesₚ))
end
else
# Adjust one point at a time then resolve conflicts
Expand All @@ -505,16 +495,11 @@


# Calculate log likelihood of proposal
# Proposals younger than age constraint are given a pass if Age_Sidedness is -1 (maximum age)
# proposal older than age constraint are given a pass if Age_Sidedness is +1 (minimum age)
@inbounds for i eachindex(ages, closest_model_agesₚ)
adjust!(agesₚ, Chronometer, systematic)
@inbounds for i eachindex(closest_model_agesₚ, closestₚ)
closest_model_agesₚ[i] = model_agesₚ[closestₚ[i]]
if Age_Sidedness[i] == sign(closest_model_agesₚ[i] - ages[i].μ)
closest_model_agesₚ[i] = ages[i].μ
end
end
adjust!(agesₚ, Chronometer, systematic)
llₚ = strat_ll(closest_model_agesₚ, agesₚ)
llₚ = strat_ll(closest_model_agesₚ, agesₚ, Age_Sidedness)
llₚ += normpdf_ll(Height, Height_sigma, sample_heightₚ)

# Accept or reject proposal based on likelihood
Expand Down Expand Up @@ -549,14 +534,10 @@

if rand() < 0.1
# Adjust heights
@inbounds for i eachindex(sample_heightₚ)
@inbounds for i eachindex(sample_heightₚ, closestₚ)
sample_heightₚ[i] += randn() * Height_sigma[i]
closestₚ[i] = round(Int,(sample_heightₚ[i] - model_heights[1])/resolution)+1
if closestₚ[i] < 1 # Check we're still within bounds
closestₚ[i] = 1
elseif closestₚ[i] > npoints
closestₚ[i] = npoints
end
closestₚ[i] = round(Int,(sample_heightₚ[i] - first(model_heights))/resolution)+1
closestₚ[i] = max(min(closestₚ[i], lastindex(model_agesₚ)), firstindex(model_agesₚ))
end
else
# Adjust one point at a time then resolve conflicts
Expand All @@ -580,16 +561,11 @@
end

# Calculate log likelihood of proposal
# Proposals younger than age constraint are given a pass if Age_Sidedness is -1 (maximum age)
# proposal older than age constraint are given a pass if Age_Sidedness is +1 (minimum age)
@inbounds for i eachindex(ages, closest_model_agesₚ)
adjust!(agesₚ, Chronometer, systematic)
@inbounds for i eachindex(closest_model_agesₚ, closestₚ)
closest_model_agesₚ[i] = model_agesₚ[closestₚ[i]]
if Age_Sidedness[i] == sign(closest_model_agesₚ[i] - ages[i].μ)
closest_model_agesₚ[i] = ages[i].μ
end
end
adjust!(agesₚ, Chronometer, systematic)
llₚ = strat_ll(closest_model_agesₚ, agesₚ)
llₚ = strat_ll(closest_model_agesₚ, agesₚ, Age_Sidedness)
llₚ += normpdf_ll(Height, Height_sigma, sample_heightₚ)

# Accept or reject proposal based on likelihood
Expand Down Expand Up @@ -619,17 +595,10 @@
npoints = length(model_heights)

# Calculate log likelihood of initial proposal
# Proposals younger than age constraint are given a pass if Age_Sidedness is -1 (maximum age)
# proposals older than age constraint are given a pass if Age_Sidedness is +1 (minimum age)
sample_height = copy(Height)
closest = findclosest(sample_height, model_heights)
closest_model_ages = model_ages[closest]
@inbounds for i eachindex(ages)
if Age_Sidedness[i] == sign(closest_model_ages[i] - ages[i].μ)
closest_model_ages[i] = ages[i].μ
end
end
ll = strat_ll(closest_model_ages, ages)
ll = strat_ll(closest_model_ages, ages, Age_Sidedness)
ll += normpdf_ll(Height, Height_sigma, sample_height)

# Ensure there is only one effective hiatus at most for each height node
Expand Down Expand Up @@ -673,14 +642,10 @@

if rand() < 0.1
# Adjust heights
@inbounds for i eachindex(sample_heightₚ)
@inbounds for i eachindex(sample_heightₚ, closestₚ)
sample_heightₚ[i] += randn() * Height_sigma[i]
closestₚ[i] = round(Int,(sample_heightₚ[i] - model_heights[1])/resolution)+1
if closestₚ[i] < 1 # Check we're still within bounds
closestₚ[i] = 1
elseif closestₚ[i] > npoints
closestₚ[i] = npoints
end
closestₚ[i] = round(Int,(sample_heightₚ[i] - first(model_heights))/resolution)+1
closestₚ[i] = max(min(closestₚ[i], lastindex(model_agesₚ)), firstindex(model_agesₚ))
end
else
# Adjust one point at a time then resolve conflicts
Expand Down Expand Up @@ -718,16 +683,11 @@


# Calculate log likelihood of proposal
# Proposals younger than age constraint are given a pass if Age_Sidedness is -1 (maximum age)
# proposal older than age constraint are given a pass if Age_Sidedness is +1 (minimum age)
@inbounds for i eachindex(ages)
adjust!(agesₚ, Chronometer, systematic)
@inbounds for i eachindex(closest_model_agesₚ, closestₚ)
closest_model_agesₚ[i] = model_agesₚ[closestₚ[i]]
if Age_Sidedness[i] == sign(closest_model_agesₚ[i] - ages[i].μ)
closest_model_agesₚ[i] = ages[i].μ
end
end
adjust!(agesₚ, Chronometer, systematic)
llₚ = strat_ll(closest_model_agesₚ, agesₚ)
llₚ = strat_ll(closest_model_agesₚ, agesₚ, Age_Sidedness)
llₚ += normpdf_ll(Height, Height_sigma, sample_heightₚ)

# Add log likelihood for hiatus duration
Expand Down Expand Up @@ -767,14 +727,10 @@

if rand() < 0.1
# Adjust heights
@inbounds for i eachindex(sample_heightₚ)
@inbounds for i eachindex(sample_heightₚ, closestₚ)
sample_heightₚ[i] += randn() * Height_sigma[i]
closestₚ[i] = round(Int,(sample_heightₚ[i] - model_heights[1])/resolution)+1
if closestₚ[i] < 1 # Check we're still within bounds
closestₚ[i] = 1
elseif closestₚ[i] > npoints
closestₚ[i] = npoints
end
closestₚ[i] = round(Int,(sample_heightₚ[i] - first(model_heights))/resolution)+1
closestₚ[i] = max(min(closestₚ[i], lastindex(model_agesₚ)), firstindex(model_agesₚ))
end
else
# Adjust one point at a time then resolve conflicts
Expand Down Expand Up @@ -811,16 +767,11 @@
end

# Calculate log likelihood of proposal
# Proposals younger than age constraint are given a pass if Age_Sidedness is -1 (maximum age)
# proposal older than age constraint are given a pass if Age_Sidedness is +1 (minimum age)
@inbounds for i eachindex(ages)
adjust!(agesₚ, Chronometer, systematic)
@inbounds for i eachindex(closest_model_agesₚ, closestₚ)
closest_model_agesₚ[i] = model_agesₚ[closestₚ[i]]
if Age_Sidedness[i] == sign(closest_model_agesₚ[i] - ages[i].μ)
closest_model_agesₚ[i] = ages[i].μ
end
end
adjust!(agesₚ, Chronometer, systematic)
llₚ = strat_ll(closest_model_agesₚ, agesₚ)
llₚ = strat_ll(closest_model_agesₚ, agesₚ, Age_Sidedness)
llₚ += normpdf_ll(Height, Height_sigma, sample_heightₚ)

# Add log likelihood for hiatus duration
Expand Down
36 changes: 33 additions & 3 deletions src/Utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,44 @@
## --- log likelihood functions allowing for arbitrary Distributions

# Use dispatch to let us reduce duplication
strat_ll(x, ages::AbstractVector{<:Normal}) = normpdf_ll(x, ages)
strat_ll(x::Real, age::Distribution) = logpdf(age, x)
function strat_ll(x, ages)
ll = zero(float(eltype(x)))
@inbounds for i in eachindex(x, ages)
ll += logpdf(ages[i], x[i])
ll += fastlogpdf(ages[i], x[i])
end
return ll
end
function strat_ll(x, ages, sidedness)
ll = zero(float(eltype(x)))
@inbounds for i in eachindex(x, ages, sidedness)
ll += if sidedness[i] > 0 # Minimum age
logcdf(ages[i], x[i])
elseif sidedness[i] < 0 # Maximum age
logccdf(ages[i], x[i])
else
fastlogpdf(ages[i], x[i])
end
end
return ll
end

# function strat_ll(x, ages, sidedness)
# ll = zero(float(eltype(x)))
# @inbounds for i in eachindex(x, ages, sidedness)
# ll += if sidedness[i] == sign(x[i] - mean(ages[i]))
# fastlogpdf(ages[i], mean(ages[i])))
# else
# fastlogpdf(ages[i], x[i]))
# end
# end
# return ll
# end

fastlogpdf(d, x::Real) = logpdf(d, x)
function fastlogpdf(d::Normal, x::Real)
δ, σ = (x - d.μ), d.σ
- δ*δ/(2*σ*σ)
end


## --- End of File
12 changes: 6 additions & 6 deletions test/testStratOnly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ config.sieve = round(Int,npoints_approx) # Record one out of every nsieve steps

# Test that results match expectation, within some tolerance
@test mdl.Age isa Vector{Float64}
@test mdl.Age [775.639803231674, 764.7522135364757, 753.6908761271258, 743.7955507171365, 734.1450626270426, 724.4181380914805, 719.9344128500521, 715.730084005377, 711.5837749049695, 708.9681700152063, 706.4443998136725, 703.9187181054767, 701.2964403515165, 698.4067159228537, 694.0564170150618] atol=1
@test mdl.Age_025CI [749.86998905391, 746.6946942599805, 744.0594596745302, 723.3563855942004, 718.369656399334, 715.5038496825758, 706.6306024739025, 703.0289240992544, 700.7652433271606, 697.635218301761, 695.8957733017851, 694.6899457591284, 693.7027382576433, 692.8123972703291, 681.3406483559285] atol=3
@test mdl.Age_975CI [811.2671899519365, 797.4625865888173, 763.3573868897032, 760.0445173042854, 754.7002448815252, 733.5275337761803, 731.2848593336894, 728.2781265900221, 722.32529630248, 720.7715897667372, 718.9216447983222, 716.5564071732707, 713.0417479766365, 704.0850806207267, 702.9561071444978] atol=3
@test mdl.Age [774.446455402488, 764.138871634567, 753.6234690338521, 743.7736679540459, 734.1288211670727, 724.4222568319159, 719.9641790169545, 715.7853469234702, 711.659901307201, 709.0927902799843, 706.6073112860937, 704.1349331543528, 701.576806557881, 698.7594637725156, 695.9077192905107] atol=1
@test mdl.Age_025CI [749.7477463948217, 746.630273097855, 743.9690750302638, 723.4084389968576, 718.3588586689344, 715.5410003896017, 706.7023568283762, 703.1283100699234, 700.9213523438647, 697.904849247825, 696.2105445428034, 695.0433927602111, 694.1016935971145, 693.2625238063798, 684.8895558316834] atol=3
@test mdl.Age_975CI [806.7021880750214, 794.9197023904139, 763.2743319392429, 760.0094083209071, 754.7169964175238, 733.5482560401001, 731.2806438300145, 728.3050781269916, 722.3491038457139, 720.830748668773, 718.9846443130959, 716.6641608791327, 713.2397206761551, 704.3594827974073, 703.4763901911105] atol=3
# Test that all age-depth models are in stratigraphic order
@test all([issorted(x, rev=true) for x in eachcol(agedist)])
@test all(!isnan, agedist)
Expand All @@ -51,9 +51,9 @@ hiatus.Duration_sigma = [ 3.1, 2.0 ]

# Test that results match expectation, within some tolerance
@test mdl.Age isa Vector{Float64}
@test mdl.Age [776.1645401715448, 765.3913415310423, 754.4578701374423, 749.1402317775008, 729.1734073022601, 724.1504941010227, 720.8064612101196, 717.6325237531024, 714.3911204374407, 712.9981743647901, 701.8475241735739, 700.4609977090519, 699.1253113675604, 697.6951927842167, 693.5283072170754] atol=1
@test mdl.Age_025CI [750.7546389575148, 747.721013702703, 745.2360260172313, 735.8600442968074, 717.8362687469073, 715.6571312160221, 709.9467580290137, 707.3352431871043, 705.6012864195188, 704.1631160813711, 693.9692554028434, 693.2871054139605, 692.6791996279173, 692.0920645898433, 681.1134640520288] atol=3
@test mdl.Age_975CI [811.4534097071294, 797.7943496043165, 763.8616858646202, 761.1731148399618, 742.6187531312476, 732.7757692120042, 730.8844293232036, 728.3748870708009, 723.607690594645, 722.5139094919781, 711.2193585524507, 709.48421725086, 707.154594077965, 703.3249584517113, 702.2469721879497] atol=3
@test mdl.Age [774.9621590393864, 764.7483478838244, 754.3928287806855, 749.1279876514806, 729.1819943547974, 724.1793043771663, 720.8913053718023, 717.7627759721249, 714.5509500060313, 713.2002472528285, 702.1296086101366, 700.7704707147774, 699.4643412956174, 698.0733121394825, 695.3696935693185] atol=1
@test mdl.Age_025CI [750.6328462444853, 747.6305055671014, 745.2219453018333, 735.8795895625685, 717.8606462193696, 715.6849749469166, 710.1425439567424, 707.5591665793866, 705.808337538727, 704.4299789235162, 694.3410849948168, 693.699614909472, 693.1108571564564, 692.5485716537205, 684.6340487749085] atol=3
@test mdl.Age_975CI [806.6715104282963, 794.8753648212958, 763.7844862197752, 761.1148678026058, 742.5015528417246, 732.771524797932, 730.9079222710706, 728.4137307378472, 723.6929471023955, 722.6480626953821, 711.3953411778708, 709.6936433474348, 707.3968615013218, 703.6205967015544, 702.7858695444451] atol=3
# Test that all age-depth models are in stratigraphic order
@test all([issorted(x, rev=true) for x in eachcol(agedist)])
@test all(!isnan, agedist)
Expand Down
10 changes: 5 additions & 5 deletions test/testUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@

## --- Other utility functions for log likelihoods

@test strat_ll(66.02, BilinearExponential(-5.9345103368858085, 66.00606672179812, 0.17739474265630253, 70.57882331291309, 0.6017142541555505)) (-3.251727600957773 -5.9345103368858085)
@test strat_ll(950, Radiocarbon(1000, 10, intcal13)) -7.342817323513984
@test strat_ll([66.02,], [BilinearExponential(-5.9345103368858085, 66.00606672179812, 0.17739474265630253, 70.57882331291309, 0.6017142541555505)]) (-3.251727600957773 -5.9345103368858085)
@test strat_ll([950,], [Radiocarbon(1000, 10, intcal13)]) -7.342817323513984

@test strat_ll([0.0, 0.0], [Normal(0,1), Normal(0,1)]) 0
@test strat_ll([0.0, 0.5], [Normal(0,1), Uniform(0,1)]) -0.9189385332046728
@test strat_ll([0.0, 0.5, 66.02], [Normal(0,1), Uniform(0,1), BilinearExponential(-5.9345103368858085, 66.00606672179812, 0.17739474265630253, 70.57882331291309, 0.6017142541555505)]) (-0.9189385332046728 -3.251727600957773 -5.9345103368858085)
@test strat_ll([0.0, 0.5, 66.02, 900], [Normal(0,1), Uniform(0,1), BilinearExponential(-5.9345103368858085, 66.00606672179812, 0.17739474265630253, 70.57882331291309, 0.6017142541555505), Radiocarbon(1000, 10, intcal13)]) -26.680063245418374
@test strat_ll([0.0, 0.5], [Normal(0,1), Uniform(0,1)]) 0
@test strat_ll([0.0, 0.5, 66.02], [Normal(0,1), Uniform(0,1), BilinearExponential(-5.9345103368858085, 66.00606672179812, 0.17739474265630253, 70.57882331291309, 0.6017142541555505)]) (-3.251727600957773 -5.9345103368858085)
@test strat_ll([0.0, 0.5, 66.02, 900], [Normal(0,1), Uniform(0,1), BilinearExponential(-5.9345103368858085, 66.00606672179812, 0.17739474265630253, 70.57882331291309, 0.6017142541555505), Radiocarbon(1000, 10, intcal13)]) -25.761124712213704

## ---

0 comments on commit 93c71ad

Please sign in to comment.