From a8d9cd0c8bc58ee85ca82ec1302d18d213c35f86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Poisot?= Date: Tue, 1 Oct 2024 12:00:48 -0400 Subject: [PATCH] SDeMo QoL improvements (#289) * feat(demo)?: maxdepth * feat: respects depth/length when training * doc: max * feat(demo): feature bagging * feat: export bag features * bug: correct number of features --- SDeMo/Project.toml | 2 +- SDeMo/src/SDeMo.jl | 3 +- SDeMo/src/classifiers/decisiontree.jl | 41 +++++++++++++++++++++-- SDeMo/src/ensembles/bagging.jl | 30 ++++++++++++++--- SDeMo/src/ensembles/ensemble.jl | 2 +- docs/src/tutorials/sdemo.jl | 47 +++++++++++++++++++++------ 6 files changed, 104 insertions(+), 21 deletions(-) diff --git a/SDeMo/Project.toml b/SDeMo/Project.toml index 4a881efe9..cddabba56 100644 --- a/SDeMo/Project.toml +++ b/SDeMo/Project.toml @@ -1,7 +1,7 @@ name = "SDeMo" uuid = "3e5feb82-bcca-434d-9cd5-c11731a21467" authors = ["Timothée Poisot "] -version = "0.0.3" +version = "0.0.4" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/SDeMo/src/SDeMo.jl b/SDeMo/src/SDeMo.jl index ea25b26b9..7d65192a0 100644 --- a/SDeMo/src/SDeMo.jl +++ b/SDeMo/src/SDeMo.jl @@ -42,10 +42,11 @@ export BIOCLIM # BIOCLIM include("classifiers/decisiontree.jl") export DecisionTree +export maxnodes!, maxdepth! # Bagging and ensembles include("ensembles/bagging.jl") -export Bagging, outofbag, bootstrap +export Bagging, outofbag, bootstrap, bagfeatures! include("ensembles/ensemble.jl") export Ensemble diff --git a/SDeMo/src/classifiers/decisiontree.jl b/SDeMo/src/classifiers/decisiontree.jl index f376b5b6a..73330dcfd 100644 --- a/SDeMo/src/classifiers/decisiontree.jl +++ b/SDeMo/src/classifiers/decisiontree.jl @@ -76,18 +76,43 @@ end """ DecisionTree -TODO +The depth and number of nodes can be adjusted with `maxnodes!` and `maxdepth!`. """ Base.@kwdef mutable struct DecisionTree <: Classifier root::DecisionNode = DecisionNode() maxnodes::Integer = 12 + maxdepth::Integer = 7 +end + +function maxnodes!(dt::DecisionTree, n::Integer) + dt.maxnodes = n + return dt +end +function maxdepth!(dt::DecisionTree, n::Integer) + dt.maxdepth = n + return dt +end + +function maxnodes!(sdm::SDM, n) + if sdm.classifier isa DecisionTree + maxnodes!(sdm.classifier, n) + return sdm + end + return sdm +end +function maxdepth!(sdm::SDM, n) + if sdm.classifier isa DecisionTree + maxdepth!(sdm.classifier, n) + return sdm + end + return sdm end tips(::Nothing) = nothing tips(dt::DecisionTree) = tips(dt.root) function tips(dn::SDeMo.DecisionNode) if iszero(dn.variable) - return dn + return [dn] else return vcat(tips(dn.left), tips(dn.right)) end @@ -157,7 +182,7 @@ function train!( root.prediction = mean(y) dt.root = root train!(dt.root, X, y) - for _ in 1:6 + for _ in 1:(dt.maxdepth-2) for tip in SDeMo.tips(dt) p = SDeMo._pool(tip, X) if !(tip.visited) @@ -191,4 +216,14 @@ end function StatsAPI.predict(dt::DecisionTree, X::Matrix{T}) where {T <: Number} return vec(mapslices(x -> predict(dt, x), X; dims = 1)) +end + +@testitem "We can train a decison tree" begin + X, y = SDeMo.__demodata() + model = SDM(MultivariateTransform{PCA}, DecisionTree, X, y) + maxdepth!(model, 3) + @test model.classifier.maxdepth == 3 + train!(model) + @test SDeMo.depth(model.classifier) <= 3 + @test length(SDeMo.tips(model.classifier)) <= model.classifier.maxnodes end \ No newline at end of file diff --git a/SDeMo/src/ensembles/bagging.jl b/SDeMo/src/ensembles/bagging.jl index 502a38c11..0e24c0c61 100644 --- a/SDeMo/src/ensembles/bagging.jl +++ b/SDeMo/src/ensembles/bagging.jl @@ -1,11 +1,11 @@ """ bootstrap(y, X; n = 50) """ -function bootstrap(y, X; n=50) +function bootstrap(y, X; n = 50) @assert size(y, 1) == size(X, 2) bags = [] for _ in 1:n - inbag = sample(1:size(X, 2), size(X, 2); replace=true) + inbag = sample(1:size(X, 2), size(X, 2); replace = true) outbag = setdiff(axes(X, 2), inbag) push!(bags, (inbag, outbag)) end @@ -24,7 +24,7 @@ end """ mutable struct Bagging <: AbstractEnsembleSDM model::SDM - bags::Vector{Tuple{Vector{Int64},Vector{Int64}}} + bags::Vector{Tuple{Vector{Int64}, Vector{Int64}}} models::Vector{SDM} end @@ -43,7 +43,7 @@ end Creates a bag from SDM """ function Bagging(model::SDM, n::Integer) - bags = bootstrap(labels(model), features(model); n=n) + bags = bootstrap(labels(model), features(model); n = n) return Bagging(model, bags, [deepcopy(model) for _ in eachindex(bags)]) end @@ -77,7 +77,27 @@ function outofbag(ensemble::Bagging; kwargs...) return ConfusionMatrix(outcomes, ensemble.model.y[done_instances]) end +bagfeatures!(ensemble::Bagging) = + bagfeatures!(ensemble, ceil(Int64, sqrt(length(variables(ensemble))))) + +function bagfeatures!(ensemble::Bagging, n::Integer) + for model in ensemble.models + sampled_variables = StatsBase.sample(variables(model), n; replace = false) + variables!(model, sampled_variables) + end + return ensemble +end + +@testitem "We can bag the features of an ensemble model" begin + X, y = SDeMo.__demodata() + model = SDM(MultivariateTransform{PCA}, DecisionTree, X, y) + ensemble = Bagging(model, 10) + bagfeatures!(ensemble) + for model in ensemble.models + @test length(variables(model)) == ceil(Int64, sqrt(size(X, 1))) + end +end majority(pred::Vector{Bool}) = sum(pred) > length(pred) // 2 majority(pred::BitVector) = sum(pred) > length(pred) // 2 -export majority \ No newline at end of file +export majority diff --git a/SDeMo/src/ensembles/ensemble.jl b/SDeMo/src/ensembles/ensemble.jl index 4009b07d4..57bb2e797 100644 --- a/SDeMo/src/ensembles/ensemble.jl +++ b/SDeMo/src/ensembles/ensemble.jl @@ -41,4 +41,4 @@ Base.deleteat!(ens::Ensemble, i) = deleteat!(ens.models, i) outmod = popat!(ens, 1) @test outmod isa SDM @test length(ens.models) == 1 -end \ No newline at end of file +end diff --git a/docs/src/tutorials/sdemo.jl b/docs/src/tutorials/sdemo.jl index 4bf6aff99..c6de24936 100644 --- a/docs/src/tutorials/sdemo.jl +++ b/docs/src/tutorials/sdemo.jl @@ -85,7 +85,7 @@ hm = heatmap!(ax, scatter!(ax, presencelayer; color = :black) scatter!(ax, bgpoints; color = :red, markersize = 4) lines!(ax, CHE.geometry[1]; color = :black) -Colorbar(f[1,2], hm) +Colorbar(f[1, 2], hm) hidedecorations!(ax) hidespines!(ax) current_figure() #hide @@ -144,9 +144,11 @@ ensemble = Bagging(sdm, 30) # of the dataset, we can also bootstrap which variables are accessible to each # model: -for model in ensemble.models - variables!(model, unique(rand(variables(model), length(variables(model))))) -end +bagfeatures!(ensemble) + +# By default, the `bagfeatures!` function called on an ensemble will sample the variables +# forom the model, so that each model in the ensemble has the square root (rounded _up_) of +# the number of original variables. # ::: info About this ensemble model # @@ -182,7 +184,12 @@ f = Figure(; size = (600, 600)) ax = Axis(f[1, 1]; aspect = DataAspect(), title = "Prediction") hm = heatmap!(ax, prd; colormap = :linear_worb_100_25_c53_n256, colorrange = (0, 1)) Colorbar(f[1, 2], hm) -contour!(ax, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5) +contour!( + ax, + predict(ensemble, layers; consensus = majority); + color = :black, + linewidth = 0.5, +) lines!(ax, CHE.geometry[1]; color = :black) hidedecorations!(ax) hidespines!(ax) @@ -190,7 +197,12 @@ ax2 = Axis(f[2, 1]; aspect = DataAspect(), title = "Uncertainty") hm = heatmap!(ax2, quantize(unc); colormap = :linear_gow_60_85_c27_n256, colorrange = (0, 1)) Colorbar(f[2, 2], hm) -contour!(ax2, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5) +contour!( + ax2, + predict(ensemble, layers; consensus = majority); + color = :black, + linewidth = 0.5, +) lines!(ax2, CHE.geometry[1]; color = :black) hidedecorations!(ax2) hidespines!(ax2) @@ -217,14 +229,24 @@ hm = heatmap!( colormap = :diverging_gwv_55_95_c39_n256, colorrange = (-0.3, 0.3), ) -contour!(ax, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5) +contour!( + ax, + predict(ensemble, layers; consensus = majority); + color = :black, + linewidth = 0.5, +) lines!(ax, CHE.geometry[1]; color = :black) #hide hidedecorations!(ax) hidespines!(ax) Colorbar(f[1, 2], hm) ax2 = Axis(f[2, 1]; aspect = DataAspect(), title = "Partial response") hm = heatmap!(ax2, part_v1; colormap = :linear_gow_65_90_c35_n256, colorrange = (0, 1)) -contour!(ax2, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5) +contour!( + ax2, + predict(ensemble, layers; consensus = majority); + color = :black, + linewidth = 0.5, +) lines!(ax2, CHE.geometry[1]; color = :black) Colorbar(f[2, 2], hm) hidedecorations!(ax2) @@ -251,8 +273,13 @@ heatmap!( categorical = true, ), ) -contour!(ax, predict(ensemble, layers; consensus=majority); color = :black, linewidth = 0.5) +contour!( + ax, + predict(ensemble, layers; consensus = majority); + color = :black, + linewidth = 0.5, +) lines!(ax, CHE.geometry[1]; color = :black) hidedecorations!(ax) hidespines!(ax) -current_figure() #hide \ No newline at end of file +current_figure() #hide