-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SDeMo improvements: inflated response curves, better ensemble support…
…, decision trees (#284) * feat(demo): decision tree * feat: inflated partialresponse * perf: decision tree trains faster * perf: decision tree trains faster * feat: variable importance works with ensembles * feat: partialresponse works with an ensemble * feat: Shapley works with ensemble models * doc: decision tree in the tutorial * feat: ensemble models can be cross-validated * feat: utility functions for ensembles * test: partial resp for ensembles * perf(demo): limit tree depth to 7 always * bug(demo): fix inflated partial response * feat: add variables for the bagging * bug(demo): order of functions in partial response * doc: use decision tree for the demo * semver(demo): v0.0.3 * doc: clarify tutorial * doc: abs needed to map shapley values * bug(sdt): masking of occurrences in nested polygons * doc: fix polygon demo * perf(dem): decision tree training * perf(demo): decision tree is ~ 2x faster * feat(demo): outofbag gains thr keyword * feat(demo): majority function * doc: cleanup sdemo tutorial
- Loading branch information
Showing
15 changed files
with
418 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,4 +18,5 @@ MultivariateTransform | |
```@docs | ||
NaiveBayes | ||
BIOCLIM | ||
DecisionTree | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
Base.@kwdef mutable struct DecisionNode | ||
parent::Union{DecisionNode, Nothing} = nothing | ||
left::Union{DecisionNode, Nothing} = nothing | ||
right::Union{DecisionNode, Nothing} = nothing | ||
variable::Integer = 0 | ||
value::Float64 = 0.0 | ||
prediction::Float64 = 0.5 | ||
visited::Bool = false | ||
end | ||
|
||
function _is_in_node_parent(::Nothing, X) | ||
return [true for _ in axes(X, 2)] | ||
end | ||
|
||
function _is_in_node_parent(dn::DecisionNode, X) | ||
if isnothing(dn.parent) | ||
return [true for _ in axes(X, 2)] | ||
end | ||
to_the_left = X[dn.parent.variable, :] .< dn.parent.value | ||
if dn == dn.parent.left | ||
return to_the_left | ||
else | ||
return map(!, to_the_left) | ||
end | ||
end | ||
|
||
_pool(::Nothing, X) = _is_in_node_parent(nothing, X) | ||
|
||
function _pool(dn::DecisionNode, X) | ||
return _is_in_node_parent(dn, X) .& _pool(dn.parent, X) | ||
end | ||
|
||
function train!(dn::DecisionNode, X, y) | ||
v = collect(axes(X, 1)) | ||
if dn.visited | ||
return dn | ||
end | ||
dn.visited = true | ||
current_entropy = SDeMo._entropy(y) | ||
dn.prediction = mean(y) | ||
if current_entropy > 0.0 | ||
best_gain = -Inf | ||
best_split = (0, 0.0) | ||
found = false | ||
pl, pr = (0.0, 0.0) | ||
for i in eachindex(v) | ||
x = unique(X[v[i], :]) | ||
for j in eachindex(x) | ||
left = findall(X[v[i], :] .< x[j]) | ||
right = findall(X[v[i], :] .>= x[j]) | ||
left_p = length(left) / length(y) | ||
right_p = 1.0 - left_p | ||
left_e = SDeMo._entropy(y[left]) | ||
right_e = SDeMo._entropy(y[right]) | ||
IG = current_entropy - left_p * left_e - right_p * right_e | ||
if (IG > best_gain) & (IG > 0) | ||
best_gain = IG | ||
best_split = (v[i], x[j]) | ||
pl, pr = left_p, right_p | ||
found = true | ||
end | ||
end | ||
end | ||
if found | ||
dn.variable, dn.value = best_split | ||
# New node | ||
vl = isone(pl) .| iszero(pl) | ||
vr = isone(pr) .| iszero(pr) | ||
dn.left = SDeMo.DecisionNode(; parent = dn, prediction = pl, visited = vl) | ||
dn.right = SDeMo.DecisionNode(; parent = dn, prediction = pr, visited = vr) | ||
end | ||
end | ||
return dn | ||
end | ||
|
||
""" | ||
DecisionTree | ||
TODO | ||
""" | ||
Base.@kwdef mutable struct DecisionTree <: Classifier | ||
root::DecisionNode = DecisionNode() | ||
maxnodes::Integer = 12 | ||
end | ||
|
||
tips(::Nothing) = nothing | ||
tips(dt::DecisionTree) = tips(dt.root) | ||
function tips(dn::SDeMo.DecisionNode) | ||
if iszero(dn.variable) | ||
return dn | ||
else | ||
return vcat(tips(dn.left), tips(dn.right)) | ||
end | ||
end | ||
|
||
depth(dt::DecisionTree) = maximum(depth.(tips(dt))) | ||
depth(dn::DecisionNode) = 1 + depth(dn.parent) | ||
depth(::Nothing) = 0 | ||
|
||
function merge!(dn::DecisionNode) | ||
dn.variable = 0 | ||
dn.value = 0 | ||
dn.left = nothing | ||
dn.right = nothing | ||
return dn | ||
end | ||
|
||
function _entropy(x::Vector{Bool}) | ||
pᵢ = [sum(x), length(x) - sum(x)] ./ length(x) | ||
return -sum(pᵢ .* log2.(pᵢ)) | ||
end | ||
|
||
function twigs(dt::DecisionTree) | ||
leaves = SDeMo.tips(dt) | ||
leaf_parents = unique([leaf.parent for leaf in leaves]) | ||
twig_nodes = | ||
filter(p -> iszero(p.left.variable) & iszero(p.right.variable), leaf_parents) | ||
return twig_nodes | ||
end | ||
|
||
function _information_gain(dn::SDeMo.DecisionNode, X, y) | ||
p = findall(SDeMo._pool(dn, X)) | ||
pl = [i for i in p if X[dn.variable,i] < dn.value] | ||
pr = setdiff(p, pl) | ||
yl = y[pl] | ||
yr = y[pr] | ||
yt = y[p] | ||
e = SDeMo._entropy(yt) | ||
el = SDeMo._entropy(yl) | ||
er = SDeMo._entropy(yr) | ||
return e - mean(yl) * el - mean(yr) * er | ||
end | ||
|
||
function prune!(tree, X, y) | ||
tw = twigs(tree) | ||
wrst = Inf | ||
widx = 0 | ||
for i in eachindex(tw) | ||
ef = _information_gain(tw[i], X, y) | ||
if ef < wrst | ||
wrst = ef | ||
widx = i | ||
end | ||
end | ||
SDeMo.merge!(tw[widx]) | ||
return tree | ||
end | ||
|
||
Base.zero(::Type{DecisionTree}) = 0.5 | ||
|
||
function train!( | ||
dt::DecisionTree, | ||
y::Vector{Bool}, | ||
X::Matrix{T}, | ||
) where {T <: Number} | ||
root = SDeMo.DecisionNode() | ||
root.prediction = mean(y) | ||
dt.root = root | ||
train!(dt.root, X, y) | ||
for _ in 1:6 | ||
for tip in SDeMo.tips(dt) | ||
p = SDeMo._pool(tip, X) | ||
if !(tip.visited) | ||
train!(tip, X[:, findall(p)], y[findall(p)]) | ||
end | ||
end | ||
end | ||
|
||
while length(SDeMo.tips(dt)) > dt.maxnodes | ||
prune!(dt, X, y) | ||
end | ||
|
||
return dt | ||
end | ||
|
||
function StatsAPI.predict(dt::DecisionTree, x::Vector{T}) where {T <: Number} | ||
return predict(dt.root, x) | ||
end | ||
|
||
function StatsAPI.predict(dn::DecisionNode, x::Vector{T}) where {T <: Number} | ||
if iszero(dn.variable) | ||
return dn.prediction | ||
else | ||
if x[dn.variable] < dn.value | ||
return predict(dn.left, x) | ||
else | ||
return predict(dn.right, x) | ||
end | ||
end | ||
end | ||
|
||
function StatsAPI.predict(dt::DecisionTree, X::Matrix{T}) where {T <: Number} | ||
return vec(mapslices(x -> predict(dt, x), X; dims = 1)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
1759419
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register() subdir=SDeMo
1759419
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/116217
Tip: Release Notes
Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.
To add them here just re-invoke and the PR will be updated.
Tagging
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: