Commit: SDeMo improvements: inflated response curves, better ensemble support, decision trees (#284)

* feat(demo): decision tree

* feat: inflated partialresponse

* perf: decision tree trains faster

* perf: decision tree trains faster

* feat: variable importance works with ensembles

* feat: partialresponse works with an ensemble

* feat: Shapley works with ensemble models

* doc: decision tree in the tutorial

* feat: ensemble models can be cross-validated

* feat: utility functions for ensembles

* test: partial resp for ensembles

* perf(demo): limit tree depth to 7 always

* bug(demo): fix inflated partial response

* feat: add variables for the bagging

* bug(demo): order of functions in partial response

* doc: use decision tree for the demo

* semver(demo): v0.0.3

* doc: clarify tutorial

* doc: abs needed to map shapley values

* bug(sdt): masking of occurrences in nested polygons

* doc: fix polygon demo

* perf(demo): decision tree training

* perf(demo): decision tree is ~ 2x faster

* feat(demo): outofbag gains thr keyword

* feat(demo): majority function

* doc: cleanup sdemo tutorial
tpoisot authored Sep 28, 2024
1 parent 7848077 commit 1759419
Showing 15 changed files with 418 additions and 115 deletions.
2 changes: 1 addition & 1 deletion SDeMo/Project.toml
@@ -1,7 +1,7 @@
name = "SDeMo"
uuid = "3e5feb82-bcca-434d-9cd5-c11731a21467"
authors = ["Timothée Poisot <timothee.poisot@umontreal.ca>"]
version = "0.0.2"
version = "0.0.3"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
19 changes: 18 additions & 1 deletion SDeMo/docs/src/demo.jl
@@ -76,7 +76,8 @@ ci(cv.validation, mcc)

forwardselection!(sdm, folds, [1])

-# This operation *will retrain* the model. We can now look at the list of selected variables:
+# This operation *will retrain* the model. We can now look at the list of
+# selected variables:

variables(sdm)

@@ -170,6 +171,22 @@ cm = heatmap!(prx, pry, prz, colormap=:Oranges)
Colorbar(f[1,2], cm)
current_figure() #hide

# ## Inflated partial responses

# Inflated partial responses hold the non-focal variables not at their average
# value, but at a summary statistic picked at random among the mean, median,
# maximum, minimum, and a randomly chosen observed value:

f = Figure()
ax = Axis(f[1,1])
prx, pry = partialresponse(sdm, 1; inflated=false, threshold=false)
for i in 1:200
ix, iy = partialresponse(sdm, 1; inflated=true, threshold=false)
lines!(ax, ix, iy, color=(:grey, 0.5))
end
lines!(ax, prx, pry, color=:black, linewidth=4)
current_figure() #hide
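
# The spread of the grey curves around the black mean-based curve gives a
# visual sense of how much the response to the first variable depends on the
# values at which the other variables are held.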

# ## Measuring uncertainty with bagging

# We can wrap our model into a homogeneous ensemble:
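
# The rest of this file is collapsed in the diff. Below is a sketch of the
# ensemble workflow this section introduces; the two-argument
# `Bagging(model, n)` constructor and the number of replicates are assumptions
# not shown in this diff:

ensemble = Bagging(sdm, 30) # 30 bootstrapped replicates of the model
train!(ensemble)
outofbag(ensemble) # out-of-bag confusion matrix, computed in bagging.jl below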
1 change: 1 addition & 0 deletions SDeMo/docs/src/models.md
@@ -18,4 +18,5 @@ MultivariateTransform
```@docs
NaiveBayes
BIOCLIM
DecisionTree
```
4 changes: 4 additions & 0 deletions SDeMo/src/SDeMo.jl
@@ -39,6 +39,10 @@ export NaiveBayes
include("classifiers/bioclim.jl")
export BIOCLIM

# Decision tree
include("classifiers/decisiontree.jl")
export DecisionTree

# Bagging and ensembles
include("ensembles/bagging.jl")
export Bagging, outofbag, bootstrap
194 changes: 194 additions & 0 deletions SDeMo/src/classifiers/decisiontree.jl
@@ -0,0 +1,194 @@
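# A single node of the decision tree. Internal nodes split the observations on
# `variable` at `value`; a node with `variable == 0` is a tip, and
# `prediction` stores the proportion of positive cases that reached it.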
Base.@kwdef mutable struct DecisionNode
parent::Union{DecisionNode, Nothing} = nothing
left::Union{DecisionNode, Nothing} = nothing
right::Union{DecisionNode, Nothing} = nothing
variable::Integer = 0
value::Float64 = 0.0
prediction::Float64 = 0.5
visited::Bool = false
end

function _is_in_node_parent(::Nothing, X)
return [true for _ in axes(X, 2)]
end

function _is_in_node_parent(dn::DecisionNode, X)
if isnothing(dn.parent)
return [true for _ in axes(X, 2)]
end
to_the_left = X[dn.parent.variable, :] .< dn.parent.value
if dn == dn.parent.left
return to_the_left
else
return map(!, to_the_left)
end
end

_pool(::Nothing, X) = _is_in_node_parent(nothing, X)

function _pool(dn::DecisionNode, X)
return _is_in_node_parent(dn, X) .& _pool(dn.parent, X)
end

function train!(dn::DecisionNode, X, y)
v = collect(axes(X, 1))
if dn.visited
return dn
end
dn.visited = true
current_entropy = SDeMo._entropy(y)
dn.prediction = mean(y)
if current_entropy > 0.0
best_gain = -Inf
best_split = (0, 0.0)
found = false
pl, pr = (0.0, 0.0)
for i in eachindex(v)
x = unique(X[v[i], :])
for j in eachindex(x)
left = findall(X[v[i], :] .< x[j])
right = findall(X[v[i], :] .>= x[j])
left_p = length(left) / length(y)
right_p = 1.0 - left_p
left_e = SDeMo._entropy(y[left])
right_e = SDeMo._entropy(y[right])
IG = current_entropy - left_p * left_e - right_p * right_e
if (IG > best_gain) & (IG > 0)
best_gain = IG
best_split = (v[i], x[j])
pl, pr = left_p, right_p
found = true
end
end
end
if found
dn.variable, dn.value = best_split
# New child nodes; a degenerate split (all observations on one side) is
# marked as visited so it is not split further
vl = isone(pl) | iszero(pl)
vr = isone(pr) | iszero(pr)
dn.left = SDeMo.DecisionNode(; parent = dn, prediction = pl, visited = vl)
dn.right = SDeMo.DecisionNode(; parent = dn, prediction = pr, visited = vr)
end
end
return dn
end

"""
DecisionTree
TODO
"""
Base.@kwdef mutable struct DecisionTree <: Classifier
root::DecisionNode = DecisionNode()
maxnodes::Integer = 12
end

tips(::Nothing) = nothing
tips(dt::DecisionTree) = tips(dt.root)
function tips(dn::SDeMo.DecisionNode)
if iszero(dn.variable)
return dn
else
return vcat(tips(dn.left), tips(dn.right))
end
end

depth(dt::DecisionTree) = maximum(depth.(tips(dt)))
depth(dn::DecisionNode) = 1 + depth(dn.parent)
depth(::Nothing) = 0

function merge!(dn::DecisionNode)
dn.variable = 0
dn.value = 0
dn.left = nothing
dn.right = nothing
return dn
end

# Shannon entropy, H(y) = -Σᵢ pᵢ log₂(pᵢ); empty classes are skipped, since
# 0 × log₂(0) would otherwise evaluate to NaN and silently reject pure splits
function _entropy(x::Vector{Bool})
pᵢ = [sum(x), length(x) - sum(x)] ./ length(x)
return -sum(p * log2(p) for p in pᵢ if p > 0.0)
end

function twigs(dt::DecisionTree)
leaves = SDeMo.tips(dt)
leaf_parents = unique([leaf.parent for leaf in leaves])
twig_nodes =
filter(p -> iszero(p.left.variable) & iszero(p.right.variable), leaf_parents)
return twig_nodes
end

function _information_gain(dn::SDeMo.DecisionNode, X, y)
p = findall(SDeMo._pool(dn, X))
pl = [i for i in p if X[dn.variable, i] < dn.value]
pr = setdiff(p, pl)
yl = y[pl]
yr = y[pr]
yt = y[p]
e = SDeMo._entropy(yt)
el = SDeMo._entropy(yl)
er = SDeMo._entropy(yr)
# Weight each side's entropy by its share of the observations
fl = length(yl) / length(yt)
return e - fl * el - (1.0 - fl) * er
end

function prune!(tree, X, y)
tw = twigs(tree)
wrst = Inf
widx = 0
for i in eachindex(tw)
ef = _information_gain(tw[i], X, y)
if ef < wrst
wrst = ef
widx = i
end
end
SDeMo.merge!(tw[widx])
return tree
end

Base.zero(::Type{DecisionTree}) = 0.5

function train!(
dt::DecisionTree,
y::Vector{Bool},
X::Matrix{T},
) where {T <: Number}
root = SDeMo.DecisionNode()
root.prediction = mean(y)
dt.root = root
train!(dt.root, X, y)
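# Grow the tree: six additional rounds of splitting below the root keep the
# tree depth at 7 or less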
for _ in 1:6
for tip in SDeMo.tips(dt)
p = SDeMo._pool(tip, X)
if !(tip.visited)
train!(tip, X[:, findall(p)], y[findall(p)])
end
end
end

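# Prune the least informative twigs until at most `maxnodes` tips remain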
while length(SDeMo.tips(dt)) > dt.maxnodes
prune!(dt, X, y)
end

return dt
end

function StatsAPI.predict(dt::DecisionTree, x::Vector{T}) where {T <: Number}
return predict(dt.root, x)
end

function StatsAPI.predict(dn::DecisionNode, x::Vector{T}) where {T <: Number}
if iszero(dn.variable)
return dn.prediction
else
if x[dn.variable] < dn.value
return predict(dn.left, x)
else
return predict(dn.right, x)
end
end
end

function StatsAPI.predict(dt::DecisionTree, X::Matrix{T}) where {T <: Number}
return vec(mapslices(x -> predict(dt, x), X; dims = 1))
end
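
As a quick end-to-end check of the new classifier, here is a minimal sketch
against the API defined above; the synthetic data, the `maxnodes` value, and
the 0.5 cutoff are illustrative assumptions rather than part of the commit:

using SDeMo, Statistics

X = rand(2, 200) # two environmental variables, 200 synthetic observations
y = collect(X[1, :] .> 0.5) # labels depend only on the first variable

tree = SDeMo.DecisionTree(; maxnodes = 8)
SDeMo.train!(tree, y, X)
ŷ = SDeMo.predict(tree, X) .>= 0.5 # turn scores into labels at a 0.5 cutoff
mean(ŷ .== y) # training accuracy, expected to be close to 1.0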
7 changes: 6 additions & 1 deletion SDeMo/src/ensembles/bagging.jl
@@ -70,9 +70,14 @@ function outofbag(ensemble::Bagging; kwargs...)
for i in valid_models
]
-push!(outcomes, count(pred) > count(pred) // 2)
+push!(outcomes, majority(pred))
end
end

return ConfusionMatrix(outcomes, ensemble.model.y[done_instances])
end


# `majority` is true when more than half of the predictions are positive
majority(pred::AbstractVector{Bool}) = sum(pred) > length(pred) // 2
export majority
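
A usage note for the new helper, with illustrative values:

majority([true, true, false]) # true: two of the three predictions are positive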
11 changes: 11 additions & 0 deletions SDeMo/src/ensembles/pipeline.jl
@@ -1,3 +1,14 @@
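# Ensemble models expose the same accessor functions as a single SDM, mostly
# by forwarding to the model they wrap (or to their first component model)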
features(model::Bagging) = features(model.model)
features(model::Ensemble) = features(first(model.models))
features(model::Bagging, n) = features(model.model, n)
features(model::Ensemble, n) = features(first(model.models), n)
variables(model::Bagging) = variables(model.model)
labels(model::Bagging) = labels(model.model)
labels(model::Ensemble) = labels(first(model.models))
threshold(model::Bagging) = 0.5
threshold(model::Ensemble) = 0.5


"""
train!(ensemble::Bagging; kwargs...)

2 comments on commit 1759419

@tpoisot
Member Author

@JuliaRegistrator register() subdir=SDeMo

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/116217

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a SDeMo-v0.0.3 -m "<description of version>" 1759419a1241bf861b7597090dc40d88a6495aaf
git push origin SDeMo-v0.0.3
