Skip to content

Commit

Permalink
SDeMo QoL improvements (#289)
Browse files Browse the repository at this point in the history
* feat(demo): maxdepth

* feat: respects depth/length when training

* doc: max

* feat(demo): feature bagging

* feat: export bag features

* bug: correct number of features
  • Loading branch information
tpoisot authored Oct 1, 2024
1 parent b6bf5ee commit a8d9cd0
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 21 deletions.
2 changes: 1 addition & 1 deletion SDeMo/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SDeMo"
uuid = "3e5feb82-bcca-434d-9cd5-c11731a21467"
authors = ["Timothée Poisot <timothee.poisot@umontreal.ca>"]
version = "0.0.3"
version = "0.0.4"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
3 changes: 2 additions & 1 deletion SDeMo/src/SDeMo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ export BIOCLIM
# BIOCLIM
include("classifiers/decisiontree.jl")
export DecisionTree
export maxnodes!, maxdepth!

# Bagging and ensembles
include("ensembles/bagging.jl")
export Bagging, outofbag, bootstrap
export Bagging, outofbag, bootstrap, bagfeatures!

include("ensembles/ensemble.jl")
export Ensemble
Expand Down
41 changes: 38 additions & 3 deletions SDeMo/src/classifiers/decisiontree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,43 @@ end
"""
DecisionTree
TODO
The depth and number of nodes can be adjusted with `maxnodes!` and `maxdepth!`.
"""
Base.@kwdef mutable struct DecisionTree <: Classifier
root::DecisionNode = DecisionNode()
maxnodes::Integer = 12
maxdepth::Integer = 7
end

"""
    maxnodes!(dt::DecisionTree, n::Integer)

Set the maximum number of nodes of the decision tree to `n`, and return the
modified tree.
"""
maxnodes!(dt::DecisionTree, n::Integer) = (dt.maxnodes = n; dt)
"""
    maxdepth!(dt::DecisionTree, n::Integer)

Set the maximum depth of the decision tree to `n`, and return the modified
tree.
"""
maxdepth!(dt::DecisionTree, n::Integer) = (dt.maxdepth = n; dt)

"""
    maxnodes!(sdm::SDM, n)

Forward `maxnodes!` to the classifier of `sdm` when that classifier is a
`DecisionTree`; for any other classifier this is a no-op. Returns the SDM in
all cases.
"""
function maxnodes!(sdm::SDM, n)
    # Only decision trees carry a node budget; silently skip other classifiers
    sdm.classifier isa DecisionTree && maxnodes!(sdm.classifier, n)
    return sdm
end
"""
    maxdepth!(sdm::SDM, n)

Forward `maxdepth!` to the classifier of `sdm` when that classifier is a
`DecisionTree`; for any other classifier this is a no-op. Returns the SDM in
all cases.
"""
function maxdepth!(sdm::SDM, n)
    # Only decision trees carry a depth budget; silently skip other classifiers
    sdm.classifier isa DecisionTree && maxdepth!(sdm.classifier, n)
    return sdm
end

# Entry points for collecting the tips (leaves) of a tree. The `Nothing`
# method presumably handles an absent child branch during recursion — TODO
# confirm against the `DecisionNode` definition. The `DecisionTree` method
# delegates to the recursive method on the root node.
tips(::Nothing) = nothing
tips(dt::DecisionTree) = tips(dt.root)
function tips(dn::SDeMo.DecisionNode)
if iszero(dn.variable)
return dn
return [dn]
else
return vcat(tips(dn.left), tips(dn.right))
end
Expand Down Expand Up @@ -157,7 +182,7 @@ function train!(
root.prediction = mean(y)
dt.root = root
train!(dt.root, X, y)
for _ in 1:6
for _ in 1:(dt.maxdepth-2)
for tip in SDeMo.tips(dt)
p = SDeMo._pool(tip, X)
if !(tip.visited)
Expand Down Expand Up @@ -191,4 +216,14 @@ end

function StatsAPI.predict(dt::DecisionTree, X::Matrix{T}) where {T <: Number}
return vec(mapslices(x -> predict(dt, x), X; dims = 1))
end

@testitem "We can train a decison tree" begin
X, y = SDeMo.__demodata()
model = SDM(MultivariateTransform{PCA}, DecisionTree, X, y)
maxdepth!(model, 3)
@test model.classifier.maxdepth == 3
train!(model)
@test SDeMo.depth(model.classifier) <= 3
@test length(SDeMo.tips(model.classifier)) <= model.classifier.maxnodes
end
30 changes: 25 additions & 5 deletions SDeMo/src/ensembles/bagging.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
bootstrap(y, X; n = 50)
"""
function bootstrap(y, X; n=50)
function bootstrap(y, X; n = 50)
@assert size(y, 1) == size(X, 2)
bags = []
for _ in 1:n
inbag = sample(1:size(X, 2), size(X, 2); replace=true)
inbag = sample(1:size(X, 2), size(X, 2); replace = true)
outbag = setdiff(axes(X, 2), inbag)
push!(bags, (inbag, outbag))
end
Expand All @@ -24,7 +24,7 @@ end
"""
mutable struct Bagging <: AbstractEnsembleSDM
    # Template model; its training data is what the bags index into
    model::SDM
    # One (inbag, outbag) pair of observation indices per replicate
    bags::Vector{Tuple{Vector{Int64}, Vector{Int64}}}
    # One copy of `model` per bag, trained on that bag's observations
    models::Vector{SDM}
end

Expand All @@ -43,7 +43,7 @@ end
Creates a bag from SDM
"""
function Bagging(model::SDM, n::Integer)
    # Draw `n` bootstrap (inbag, outbag) index pairs over the training data
    bags = bootstrap(labels(model), features(model); n = n)
    # Each bag gets an independent deep copy of the template model, so the
    # replicates can be trained (and have their variables changed) separately
    return Bagging(model, bags, [deepcopy(model) for _ in eachindex(bags)])
end

Expand Down Expand Up @@ -77,7 +77,27 @@ function outofbag(ensemble::Bagging; kwargs...)
return ConfusionMatrix(outcomes, ensemble.model.y[done_instances])
end

"""
    bagfeatures!(ensemble::Bagging)

Bag the features of every model in the ensemble, keeping the default number
of features: the square root (rounded up) of the number of variables of the
ensemble.
"""
function bagfeatures!(ensemble::Bagging)
    nkept = ceil(Int64, sqrt(length(variables(ensemble))))
    return bagfeatures!(ensemble, nkept)
end

"""
    bagfeatures!(ensemble::Bagging, n::Integer)

Assign to every model in `ensemble` a random subset of `n` of its variables,
drawn without replacement. Returns the modified ensemble.
"""
function bagfeatures!(ensemble::Bagging, n::Integer)
    for replicate in ensemble.models
        kept = StatsBase.sample(variables(replicate), n; replace = false)
        variables!(replicate, kept)
    end
    return ensemble
end

@testitem "We can bag the features of an ensemble model" begin
X, y = SDeMo.__demodata()
model = SDM(MultivariateTransform{PCA}, DecisionTree, X, y)
ensemble = Bagging(model, 10)
bagfeatures!(ensemble)
for model in ensemble.models
@test length(variables(model)) == ceil(Int64, sqrt(size(X, 1)))
end
end

"""
    majority(pred::AbstractVector{Bool})

Return `true` when strictly more than half of the values in `pred` are
`true`. A single method on `AbstractVector{Bool}` covers both `Vector{Bool}`
and `BitVector` (previously two duplicated methods), with identical results;
an empty vector yields `false`.
"""
majority(pred::AbstractVector{Bool}) = 2 * count(pred) > length(pred)
export majority
export majority
2 changes: 1 addition & 1 deletion SDeMo/src/ensembles/ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ Base.deleteat!(ens::Ensemble, i) = deleteat!(ens.models, i)
outmod = popat!(ens, 1)
@test outmod isa SDM
@test length(ens.models) == 1
end
end
47 changes: 37 additions & 10 deletions docs/src/tutorials/sdemo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ hm = heatmap!(ax,
scatter!(ax, presencelayer; color = :black)
scatter!(ax, bgpoints; color = :red, markersize = 4)
lines!(ax, CHE.geometry[1]; color = :black)
Colorbar(f[1,2], hm)
Colorbar(f[1, 2], hm)
hidedecorations!(ax)
hidespines!(ax)
current_figure() #hide
Expand Down Expand Up @@ -144,9 +144,11 @@ ensemble = Bagging(sdm, 30)
# of the dataset, we can also bootstrap which variables are accessible to each
# model:

for model in ensemble.models
variables!(model, unique(rand(variables(model), length(variables(model)))))
end
bagfeatures!(ensemble)

# By default, the `bagfeatures!` function called on an ensemble will sample the variables
# from the model, so that each model in the ensemble has the square root (rounded _up_) of
# the number of original variables.

# ::: info About this ensemble model
#
Expand Down Expand Up @@ -182,15 +184,25 @@ f = Figure(; size = (600, 600))
ax = Axis(f[1, 1]; aspect = DataAspect(), title = "Prediction")
hm = heatmap!(ax, prd; colormap = :linear_worb_100_25_c53_n256, colorrange = (0, 1))
Colorbar(f[1, 2], hm)
contour!(ax, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5)
contour!(
ax,
predict(ensemble, layers; consensus = majority);
color = :black,
linewidth = 0.5,
)
lines!(ax, CHE.geometry[1]; color = :black)
hidedecorations!(ax)
hidespines!(ax)
ax2 = Axis(f[2, 1]; aspect = DataAspect(), title = "Uncertainty")
hm =
heatmap!(ax2, quantize(unc); colormap = :linear_gow_60_85_c27_n256, colorrange = (0, 1))
Colorbar(f[2, 2], hm)
contour!(ax2, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5)
contour!(
ax2,
predict(ensemble, layers; consensus = majority);
color = :black,
linewidth = 0.5,
)
lines!(ax2, CHE.geometry[1]; color = :black)
hidedecorations!(ax2)
hidespines!(ax2)
Expand All @@ -217,14 +229,24 @@ hm = heatmap!(
colormap = :diverging_gwv_55_95_c39_n256,
colorrange = (-0.3, 0.3),
)
contour!(ax, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5)
contour!(
ax,
predict(ensemble, layers; consensus = majority);
color = :black,
linewidth = 0.5,
)
lines!(ax, CHE.geometry[1]; color = :black) #hide
hidedecorations!(ax)
hidespines!(ax)
Colorbar(f[1, 2], hm)
ax2 = Axis(f[2, 1]; aspect = DataAspect(), title = "Partial response")
hm = heatmap!(ax2, part_v1; colormap = :linear_gow_65_90_c35_n256, colorrange = (0, 1))
contour!(ax2, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5)
contour!(
ax2,
predict(ensemble, layers; consensus = majority);
color = :black,
linewidth = 0.5,
)
lines!(ax2, CHE.geometry[1]; color = :black)
Colorbar(f[2, 2], hm)
hidedecorations!(ax2)
Expand All @@ -251,8 +273,13 @@ heatmap!(
categorical = true,
),
)
contour!(ax, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5)
contour!(
ax,
predict(ensemble, layers; consensus = majority);
color = :black,
linewidth = 0.5,
)
lines!(ax, CHE.geometry[1]; color = :black)
hidedecorations!(ax)
hidespines!(ax)
current_figure() #hide
current_figure() #hide

2 comments on commit a8d9cd0

@tpoisot
Copy link
Member Author

@tpoisot tpoisot commented on a8d9cd0 Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register() subdir=SDeMo

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/116403

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a SDeMo-v0.0.4 -m "<description of version>" a8d9cd0c8bc58ee85ca82ec1302d18d213c35f86
git push origin SDeMo-v0.0.4

Please sign in to comment.