SDeMo bugfix and QOL (#306)
* bug(demo): CV with a consensus argument

* semver(demo): v0.0.6

* feat(demo): validation measures work on vectors of CM

* doc(demo): show CM on vectors

* feat(demo): cv / train on bagged models

* doc(demo): fix docstrings
tpoisot authored Oct 15, 2024
1 parent 6cb2461 commit 5abe9a6
Showing 5 changed files with 125 additions and 48 deletions.
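For orientation, the two quality-of-life additions (validation measures on vectors of confusion matrices, and cross-validation of bagged models) boil down to the usage sketched below. This is a hedged sketch assembled from the demo and test items in the diffs that follow; `SDeMo.__demodata()` is the internal helper the test items use to load example data.

using SDeMo

X, y = SDeMo.__demodata()                               # internal demo data used by the test items
sdm = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y) # partial constructor syntax from the demo
folds = kfold(sdm; k = 10)
cv = crossvalidate(sdm, folds)

mcc(cv.validation)        # new in this commit: mean MCC across the validation folds
mcc(cv.validation, true)  # new in this commit: tuple of (mean, CI)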
2 changes: 1 addition & 1 deletion SDeMo/Project.toml
@@ -1,7 +1,7 @@
name = "SDeMo"
uuid = "3e5feb82-bcca-434d-9cd5-c11731a21467"
authors = ["Timothée Poisot <timothee.poisot@umontreal.ca>"]
version = "0.0.5"
version = "0.0.6"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
54 changes: 31 additions & 23 deletions SDeMo/docs/src/demo.jl
@@ -26,7 +26,7 @@ size(X)

# ## Setting up the model

## We will start with an initial model that uses a PCA to transform the data, and
# We will start with an initial model that uses a PCA to transform the data, and
# then a Naive Bayes Classifier for the classification. Note that this is the
# partial syntax where we use the default threshold, and all the variables:

@@ -53,11 +53,11 @@ cv = crossvalidate(sdm, folds);
measures = [mcc, balancedaccuracy, ppv, npv, trueskill, markedness]
cvresult = [mean(measure.(set)) for measure in measures, set in cv]
pretty_table(
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
)

# Assuming we want to get a simple idea of what the MCC is for the validation
@@ -69,6 +69,14 @@ mcc.(cv.validation)

ci(cv.validation, mcc)

# We can also get the same output by calling a function on a vector of `ConfusionMatrix`, *e.g.*

mcc(cv.validation)

# Adding the `true` argument returns a tuple with the 95% CI:

mcc(cv.validation, true)

# ## Variable selection

# We will now select variables using forward selection, but with the added
@@ -87,11 +95,11 @@ cv2 = crossvalidate(sdm, folds)
measures = [mcc, balancedaccuracy, ppv, npv, trueskill, markedness]
cvresult = [mean(measure.(set)) for measure in measures, set in cv2]
pretty_table(
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
)

# Quite clearly! Before thinking about the relative importance of variables, we
@@ -136,11 +144,11 @@ varimp = variableimportance(sdm, folds)
# In relative terms, this is:

pretty_table(
hcat(variables(sdm), varimp ./ sum(varimp));
alignment=[:l, :c],
backend=Val(:markdown),
header=["Variable", "Importance"],
formatters=(ft_printf("%5.3f", 2), ft_printf("%d", 1))
hcat(variables(sdm), varimp ./ sum(varimp));
alignment=[:l, :c],
backend=Val(:markdown),
header=["Variable", "Importance"],
formatters=(ft_printf("%5.3f", 2), ft_printf("%d", 1))
)

# ## Partial response curve
@@ -184,8 +192,8 @@ f = Figure()
ax = Axis(f[1, 1])
prx, pry = partialresponse(sdm, 1; inflated=false, threshold=false)
for i in 1:200
ix, iy = partialresponse(sdm, 1; inflated=true, threshold=false)
lines!(ax, ix, iy, color=(:grey, 0.5))
ix, iy = partialresponse(sdm, 1; inflated=true, threshold=false)
lines!(ax, ix, iy, color=(:grey, 0.5))
end
lines!(ax, prx, pry, color=:black, linewidth=4)
current_figure() #hide
@@ -279,11 +287,11 @@ cf = counterfactual(sdm, instance(sdm, inst; strict=false), target, 200.0; thres
# is:

pretty_table(
hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm)]);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Variable", "Obs.", "Counterf."],
formatters=(ft_printf("%4.1f", [2, 3]), ft_printf("%d", 1))
hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm)]);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Variable", "Obs.", "Counterf."],
formatters=(ft_printf("%4.1f", [2, 3]), ft_printf("%d", 1))
)

# We can check the prediction that would be made on the counterfactual:
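As a quick gloss on the two documentation lines added to the demo above: broadcasting a measure over the folds is unchanged, while the un-broadcast call on the vector is the new aggregated form. A sketch reusing the `cv` object defined earlier in the demo:

mcc.(cv.validation)       # unchanged: one MCC value per validation fold
mcc(cv.validation)        # new: the mean MCC across the folds
mcc(cv.validation, true)  # new: a tuple of the mean and the CI, i.e. ci(cv.validation, mcc)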
59 changes: 38 additions & 21 deletions SDeMo/src/crossvalidation/crossvalidation.jl
@@ -68,17 +68,15 @@ function montecarlo(y, X; n = 100, kwargs...)
return [holdout(y, X; kwargs...) for _ in 1:n]
end


@testitem "We can do montecarlo validation" begin
X, y = SDeMo.__demodata()
model = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
folds = montecarlo(model; n=10)
folds = montecarlo(model; n = 10)
cv = crossvalidate(model, folds)
@test eltype(cv.validation) <: ConfusionMatrix
@test length(cv.training) == 10
end


"""
kfold(y, X; k = 10, permute = true)
@@ -113,32 +111,39 @@ function kfold(y, X; k = 10, permute = true)
return folds
end


@testitem "We can do kfold validation" begin
X, y = SDeMo.__demodata()
model = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
folds = kfold(model; k=12)
folds = kfold(model; k = 12)
cv = crossvalidate(model, folds)
@test eltype(cv.validation) <: ConfusionMatrix
@test length(cv.training) == 12
end


for op in (:leaveoneout, :holdout, :montecarlo, :kfold)
eval(quote
"""
$($op)(sdm::SDM)
Version of `$($op)` using the instances and labels of an SDM.
"""
$op(sdm::SDM, args...; kwargs...) = $op(labels(sdm), features(sdm), args...; kwargs...)
end)
eval(
quote
"""
$($op)(sdm::SDM)
Version of `$($op)` using the instances and labels of an SDM.
"""
$op(sdm::SDM, args...; kwargs...) =
$op(labels(sdm), features(sdm), args...; kwargs...)
"""
$($op)(sdm::Bagging)
Version of `$($op)` using the instances and labels of a bagged SDM. In this case, the instances of the model used as a reference to build the bagged model are used.
"""
$op(sdm::Bagging, args...; kwargs...) = $op(sdm.model, args...; kwargs...)
end,
)
end

@testitem "We can split data in an SDM" begin
X, y = SDeMo.__demodata()
sdm = SDM(MultivariateTransform{PCA}(), BIOCLIM(), 0.01, X, y, 1:size(X, 1))
folds = montecarlo(sdm; n=10)
folds = montecarlo(sdm; n = 10)
@test length(folds) == 10
end

@@ -153,27 +158,39 @@ This method returns two vectors of `ConfusionMatrix`, with the confusion matrix
for each set of validation data first, and the confusion matrix for the training
data second.
"""
function crossvalidate(sdm, folds; thr = nothing, kwargs...)
function crossvalidate(sdm::T, folds; thr = nothing, kwargs...) where {T <: AbstractSDM}
Cv = zeros(ConfusionMatrix, length(folds))
Ct = zeros(ConfusionMatrix, length(folds))
models = [deepcopy(sdm) for _ in Base.OneTo(Threads.nthreads())]
Threads.@threads for i in eachindex(folds)
trn, val = folds[i]
train!(models[Threads.threadid()]; training = trn, kwargs...)
pred = predict(models[Threads.threadid()], features(sdm)[:, val]; threshold = false)
ontrn = predict(models[Threads.threadid()], features(sdm)[:, trn]; threshold = false)
ontrn =
predict(models[Threads.threadid()], features(sdm)[:, trn]; threshold = false)
thr = isnothing(thr) ? threshold(sdm) : thr
Cv[i] = ConfusionMatrix(pred, labels(sdm)[val], thr)
Ct[i] = ConfusionMatrix(ontrn, labels(sdm)[trn], thr)
end
return (validation = Cv, training = Ct)
end

@testitem "We can crossvalidate an SDM" begin
@testitem "We can cross-validate an SDM" begin
X, y = SDeMo.__demodata()
sdm = SDM(MultivariateTransform{PCA}(), BIOCLIM(), 0.5, X, y, [1,2,12])
sdm = SDM(MultivariateTransform{PCA}(), BIOCLIM(), 0.5, X, y, [1, 2, 12])
train!(sdm)
cv = crossvalidate(sdm, kfold(sdm; k=15))
cv = crossvalidate(sdm, kfold(sdm; k = 15))
@test eltype(cv.validation) <: ConfusionMatrix
@test eltype(cv.training) <: ConfusionMatrix
end

@testitem "We can cross-validate an ensemble model using the consensus keyword" begin
using Statistics
X, y = SDeMo.__demodata()
sdm = SDM(MultivariateTransform{PCA}(), NaiveBayes(), 0.5, X, y, [1, 2, 12])
ens = Bagging(sdm, 10)
train!(ens)
cv = crossvalidate(ens, kfold(ens; k = 15); consensus = median)
@test eltype(cv.validation) <: ConfusionMatrix
@test eltype(cv.training) <: ConfusionMatrix
end
end
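The new `Bagging` dispatch above means the fold constructors and `crossvalidate` can be called directly on an ensemble, with the `consensus` keyword carried along without breaking training (the bug this commit fixes). The sketch below mirrors the new test item; `median` is just one possible consensus function.

using SDeMo, Statistics  # median is used as the consensus function

X, y = SDeMo.__demodata()
sdm = SDM(MultivariateTransform{PCA}(), NaiveBayes(), 0.5, X, y, [1, 2, 12])
ens = Bagging(sdm, 10)      # bagged ensemble built around the reference SDM
train!(ens)
folds = kfold(ens; k = 15)  # new: folds built from the reference model's instances
cv = crossvalidate(ens, folds; consensus = median)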
51 changes: 50 additions & 1 deletion SDeMo/src/crossvalidation/validation.jl
@@ -229,4 +229,53 @@ specificity(M::ConfusionMatrix) = tnr(M)
Alias for `ppv`, the positive predictive value
"""
precision(M::ConfusionMatrix) = ppv(M)
precision(M::ConfusionMatrix) = ppv(M)

for op in (
:tpr,
:tnr,
:fpr,
:fnr,
:ppv,
:npv,
:fdir,
:fomr,
:plr,
:nlr,
:accuracy,
:balancedaccuracy,
:f1,
:fscore,
:trueskill,
:markedness,
:dor,
:κ,
:mcc,
:specificity,
:sensitivity,
:recall,
:precision,
)
eval(
quote
"""
$($op)(C::Vector{ConfusionMatrix}, full::Bool=false)
Version of `$($op)` using a vector of confusion matrices. Returns the mean, and when the second argument is `true`, returns a tuple where the second element is the CI.
"""
function $op(
C::Vector{ConfusionMatrix},
full::Bool = false,
args...;
kwargs...,
)
m = $op.(C, args...; kwargs...)
if full
return (mean(m), ci(C, $op))
else
return mean(m)
end
end
end,
)
end
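For readability, here is roughly what one of the generated methods looks like once the loop above has run, hand-expanded for `mcc` inside the package (where `mean` and `ci` are in scope); a sketch of the expansion, not the generated code verbatim:

function mcc(C::Vector{ConfusionMatrix}, full::Bool = false, args...; kwargs...)
    m = mcc.(C, args...; kwargs...)   # apply the scalar measure to every matrix
    if full
        return (mean(m), ci(C, mcc)) # mean plus the confidence interval
    else
        return mean(m)
    end
end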
7 changes: 5 additions & 2 deletions SDeMo/src/ensembles/pipeline.jl
@@ -17,10 +17,13 @@ Trains all the model in an ensemble model - the keyword arguments are passed to
includes the transformers.
"""
function train!(ensemble::Bagging; kwargs...)
# The ensemble model can be given a consensus argument, in which case we drop it for
# training, as it is relevant for prediction only
trainargs = filter(kw -> kw.first != :consensus, kwargs)
Threads.@threads for m in eachindex(ensemble.models)
train!(ensemble.models[m]; training = ensemble.bags[m][1], kwargs...)
train!(ensemble.models[m]; training = ensemble.bags[m][1], trainargs...)
end
train!(ensemble.model; kwargs...)
train!(ensemble.model; trainargs...)
return ensemble
end

Expand Down

2 comments on commit 5abe9a6

@tpoisot
Member Author


@JuliaRegistrator register() subdir=SDeMo

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/117316

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR; if TagBot is installed, it will also be added to the
release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a SDeMo-v0.0.6 -m "<description of version>" 5abe9a6b037d9cfc47f7bf9ee9f8b436606739cd
git push origin SDeMo-v0.0.6
