add class_labels to ClassRow (#78)
* update rows

* add class_labels to classrow

* fix tests

* this is a breaking version right?

* Update Project.toml

Co-authored-by: Eric Hanson <5846501+ericphanson@users.noreply.github.com>

* unbump schema

* formatting

* Update test/row.jl

Co-authored-by: Eric Hanson <5846501+ericphanson@users.noreply.github.com>

* Update src/row.jl

Co-authored-by: Eric Hanson <5846501+ericphanson@users.noreply.github.com>

* coalesce class labels

* Update test/row.jl

Co-authored-by: Eric Hanson <5846501+ericphanson@users.noreply.github.com>

* test for tradeoff metrics

* start adding in tests

Co-authored-by: Eric Hanson <5846501+ericphanson@users.noreply.github.com>
josephsdavid and ericphanson authored Jun 3, 2022
1 parent fc48b0d commit 0540cdd
Showing 5 changed files with 97 additions and 97 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lighthouse"
uuid = "ac2c24cd-07f0-4848-96b2-1b82c3ea0e59"
authors = ["Beacon Biosignals, Inc."]
version = "0.14.9"
version = "0.14.10"

[deps]
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
113 changes: 50 additions & 63 deletions src/metrics.jl
@@ -196,22 +196,21 @@ end

"""
get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_index;
-                         thresholds, binarize=binarize_by_threshold)
+                         thresholds, binarize=binarize_by_threshold, class_labels=missing)
Return [`TradeoffMetricsRow`] calculated for the given `class_index`, with the following
fields guaranteed to be non-missing: `roc_curve`, `roc_auc`, `pr_curve`,
`reliability_calibration_curve`, `reliability_calibration_score`. $(BINARIZE_NOTE)
(`class_index`).
"""
function get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_index;
-                              thresholds, binarize=binarize_by_threshold)
-    stats = per_threshold_confusion_statistics(predicted_soft_labels,
-                                               elected_hard_labels, thresholds,
-                                               class_index; binarize)
+                              thresholds, binarize=binarize_by_threshold,
+                              class_labels=missing)
+    stats = per_threshold_confusion_statistics(predicted_soft_labels, elected_hard_labels,
+                                               thresholds, class_index; binarize)
roc_curve = (map(t -> t.false_positive_rate, stats),
map(t -> t.true_positive_rate, stats))
-    pr_curve = (map(t -> t.true_positive_rate, stats),
-                map(t -> t.precision, stats))
+    pr_curve = (map(t -> t.true_positive_rate, stats), map(t -> t.precision, stats))

class_probabilities = view(predicted_soft_labels, :, class_index)
reliability_calibration = calibration_curve(class_probabilities,
@@ -220,118 +219,117 @@ function get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_
reliability_calibration.fractions)
reliability_calibration_score = reliability_calibration.mean_squared_error

-    return TradeoffMetricsRow(; class_index, roc_curve,
-                              roc_auc=area_under_curve(roc_curve...),
-                              pr_curve, reliability_calibration_curve,
-                              reliability_calibration_score)
+    return TradeoffMetricsRow(; class_index, class_labels, roc_curve,
+                              roc_auc=area_under_curve(roc_curve...), pr_curve,
+                              reliability_calibration_curve, reliability_calibration_score)
end

"""
get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_hard_labels, class_index;
-                                           thresholds, binarize=binarize_by_threshold)
+                                           thresholds, binarize=binarize_by_threshold, class_labels=missing)
Return [`TradeoffMetricsRow`] calculated for the given `class_index`. In addition
to metrics calculated by [`get_tradeoff_metrics`](@ref), additionally calculates
`spearman_correlation`-based metrics. $(BINARIZE_NOTE) (`class_index`).
"""
function get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_hard_labels,
votes, class_index; thresholds,
-                                                binarize=binarize_by_threshold)
+                                                binarize=binarize_by_threshold,
+                                                class_labels=missing)
basic_row = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels,
-                                     class_index; thresholds, binarize)
+                                     class_index; thresholds, binarize, class_labels)
corr = _calculate_spearman_correlation(predicted_soft_labels, votes)
row = Tables.rowmerge(basic_row,
-                          (;
-                           spearman_correlation=corr.ρ,
+                          (; spearman_correlation=corr.ρ,
spearman_correlation_ci_upper=corr.ci_upper,
-                           spearman_correlation_ci_lower=corr.ci_lower,
-                           n_samples=corr.n))
+                           spearman_correlation_ci_lower=corr.ci_lower, n_samples=corr.n))
return TradeoffMetricsRow(; row...)
end

"""
get_hardened_metrics(predicted_hard_labels, elected_hard_labels, class_index;
-                         thresholds)
+                         class_labels=missing)
Return [`HardenedMetricsRow`] calculated for the given `class_index`, with the following
field guaranteed to be non-missing: expert-algorithm agreement (`ea_kappa`).
"""
-function get_hardened_metrics(predicted_hard_labels, elected_hard_labels, class_index)
-    return HardenedMetricsRow(; class_index,
+function get_hardened_metrics(predicted_hard_labels, elected_hard_labels, class_index;
+                              class_labels=missing)
+    return HardenedMetricsRow(; class_index, class_labels,
ea_kappa=_calculate_ea_kappa(predicted_hard_labels,
elected_hard_labels,
class_index))
end

"""
get_hardened_metrics_multirater(predicted_hard_labels, elected_hard_labels, class_index;
-                                    thresholds)
+                                    class_labels=missing)
Return [`HardenedMetricsRow`] calculated for the given `class_index`. In addition
to metrics calculated by [`get_hardened_metrics`](@ref), additionally calculates
`discrimination_calibration_curve` and `discrimination_calibration_score`.
"""
-function get_hardened_metrics_multirater(predicted_hard_labels, elected_hard_labels,
-                                         votes, class_index)
+function get_hardened_metrics_multirater(predicted_hard_labels, elected_hard_labels, votes,
+                                         class_index; class_labels=missing)
basic_row = get_hardened_metrics(predicted_hard_labels, elected_hard_labels,
-                                     class_index)
+                                     class_index; class_labels)
cal = _calculate_discrimination_calibration(predicted_hard_labels, votes;
class_of_interest_index=class_index)
row = Tables.rowmerge(basic_row,
-                          (;
-                           discrimination_calibration_curve=cal.plot_curve_data,
+                          (; discrimination_calibration_curve=cal.plot_curve_data,
discrimination_calibration_score=cal.mse))
return HardenedMetricsRow(; row...)
end

"""
get_hardened_metrics_multiclass(predicted_hard_labels, elected_hard_labels,
-                                    class_count)
+                                    class_count; class_labels=missing)
Return [`HardenedMetricsRow`] calculated over all `class_count` classes. Calculates
expert-algorithm agreement (`ea_kappa`) over all classes, as well as the multiclass
`confusion_matrix`.
"""
function get_hardened_metrics_multiclass(predicted_hard_labels, elected_hard_labels,
-                                         class_count)
+                                         class_count; class_labels=missing)
ea_kappa = first(cohens_kappa(class_count,
zip(predicted_hard_labels, elected_hard_labels)))
-    return HardenedMetricsRow(; class_index=:multiclass,
+    return HardenedMetricsRow(; class_index=:multiclass, class_labels,
confusion_matrix=confusion_matrix(class_count,
zip(predicted_hard_labels,
elected_hard_labels)),
ea_kappa)
end

"""
-    get_label_metrics_multirater(votes, class_index)
+    get_label_metrics_multirater(votes, class_index; class_labels=missing)
Return [`LabelMetricsRow`] calculated for the given `class_index`, with the following
field guaranteed to be non-missing: `per_expert_discrimination_calibration_curves`,
`per_expert_discrimination_calibration_scores`, interrater-agreement (`ira_kappa`).
"""
-function get_label_metrics_multirater(votes, class_index)
+function get_label_metrics_multirater(votes, class_index; class_labels=missing)
size(votes, 2) > 1 ||
throw(ArgumentError("Input `votes` is not multirater (`size(votes) == $(size(votes))`)"))
expert_cal = _calculate_voter_discrimination_calibration(votes;
class_of_interest_index=class_index)
per_expert_discrimination_calibration_curves = expert_cal.plot_curve_data
per_expert_discrimination_calibration_scores = expert_cal.mse
-    return LabelMetricsRow(; class_index, per_expert_discrimination_calibration_curves,
+    return LabelMetricsRow(; class_index, class_labels,
+                           per_expert_discrimination_calibration_curves,
per_expert_discrimination_calibration_scores,
ira_kappa=_calculate_ira_kappa(votes, class_index))
end

"""
-    get_label_metrics_multirater_multiclass(votes, class_count)
+    get_label_metrics_multirater_multiclass(votes, class_count; class_labels=missing)
Return [`LabelMetricsRow`] calculated over all `class_count` classes. Calculates
the multiclass interrater agreement (`ira_kappa`).
"""
-function get_label_metrics_multirater_multiclass(votes, class_count)
+function get_label_metrics_multirater_multiclass(votes, class_count; class_labels=missing)
size(votes, 2) > 1 ||
throw(ArgumentError("Input `votes` is not multirater (`size(votes) == $(size(votes))`)"))
-    return LabelMetricsRow(; class_index=:multiclass,
+    return LabelMetricsRow(; class_index=:multiclass, class_labels,
ira_kappa=_calculate_ira_kappa_multiclass(votes, class_count))
end

@@ -420,10 +418,8 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
elected_hard_labels::AbstractVector, classes,
thresholds=0.0:0.01:1.0;
votes::Union{Nothing,Missing,AbstractMatrix}=nothing,
-                                strata::Union{Nothing,
-                                              AbstractVector{Set{T}} where T}=nothing,
-                                optimal_threshold_class::Union{Missing,Nothing,
-                                                               Integer}=missing,
+                                strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing,
+                                optimal_threshold_class::Union{Missing,Nothing,Integer}=missing,
binarize=binarize_by_threshold)
class_labels = string.(collect(classes)) # Plots.jl expects this to be an `AbstractVector`
class_indices = 1:length(classes)
@@ -433,15 +429,12 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
# so do that here as well.
tradeoff_metrics_rows = if length(classes) == 2 && has_value(votes)
map(ic -> get_tradeoff_metrics_binary_multirater(predicted_soft_labels,
-                                                      elected_hard_labels, votes,
-                                                      ic;
+                                                      elected_hard_labels, votes, ic;
thresholds, binarize),
class_indices)
else
-        map(ic -> get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels,
-                                       ic;
-                                       thresholds, binarize),
-            class_indices)
+        map(ic -> get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, ic;
+                                       thresholds, binarize), class_indices)
end

# Step 2a: Choose optimal threshold and use it to harden predictions
@@ -455,8 +448,7 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
optimal_threshold = cal.threshold
elseif has_value(optimal_threshold_class)
roc_curve = tradeoff_metrics_rows[findfirst(==(optimal_threshold_class),
-                                                     tradeoff_metrics_rows.classes),
-                                             :]
+                                                     tradeoff_metrics_rows.classes), :]
optimal_threshold = _get_optimal_threshold_from_ROC(roc_curve, thresholds)
else
@warn "Not selecting and/or using optimal threshold; using `predicted_hard_labels` provided by default"
@@ -480,8 +472,7 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
elected_hard_labels, votes,
class_index), class_indices)
else
-        map(class_index -> get_hardened_metrics(predicted_hard_labels,
-                                                elected_hard_labels,
+        map(class_index -> get_hardened_metrics(predicted_hard_labels, elected_hard_labels,
class_index), class_indices)
end
hardened_metrics_table = vcat(hardened_metrics_table,
@@ -505,8 +496,7 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
stratified_kappas = has_value(strata) ?
_calculate_stratified_ea_kappas(predicted_hard_labels,
elected_hard_labels,
-                                                        length(classes),
-                                                        strata) : missing
+                                                        length(classes), strata) : missing

return _evaluation_row(tradeoff_metrics_rows, hardened_metrics_table,
labels_metrics_table; optimal_threshold_class, class_labels,
@@ -551,9 +541,8 @@ to support [`evaluation_metrics_row`](@ref):
- `label_metrics_table`: table of [`LabelMetricsRow`](@ref)s
"""
function _evaluation_row(tradeoff_metrics_table, hardened_metrics_table,
-                         label_metrics_table;
-                         optimal_threshold_class=missing, class_labels, thresholds,
-                         optimal_threshold, stratified_kappas=missing)
+                         label_metrics_table; optimal_threshold_class=missing, class_labels,
+                         thresholds, optimal_threshold, stratified_kappas=missing)
tradeoff_rows, _ = _split_classes_from_multiclass(tradeoff_metrics_table)
hardened_rows, hardened_multi = _split_classes_from_multiclass(hardened_metrics_table)
label_rows, labels_multi = _split_classes_from_multiclass(label_metrics_table)
@@ -614,8 +603,7 @@ function _evaluation_row(tradeoff_metrics_table, hardened_metrics_table,

# from label_metrics_table
per_expert_discrimination_calibration_curves,
-                           multiclass_IRA_kappas,
-                           per_class_IRA_kappas,
+                           multiclass_IRA_kappas, per_class_IRA_kappas,
per_expert_discrimination_calibration_scores,

# from kwargs:
@@ -745,8 +733,7 @@ function _calculate_ira_kappa(votes, class_index)
hard_label_pairs = _prep_hard_label_pairs(votes)
length(hard_label_pairs) == 0 && return missing
CLASS_VS_ALL_CLASS_COUNT = 2
-    class_v_other_hard_label_pair = map(row -> 1 .+ (row .== class_index),
-                                        hard_label_pairs)
+    class_v_other_hard_label_pair = map(row -> 1 .+ (row .== class_index), hard_label_pairs)
return first(cohens_kappa(CLASS_VS_ALL_CLASS_COUNT, class_v_other_hard_label_pair))
end

@@ -818,8 +805,8 @@ function _calculate_optimal_threshold_from_discrimination_calibration(predicted_
bin_count = min(size(votes, 2) + 1, 10)
per_threshold_curves = map(thresholds) do thresh
pred_soft = view(predicted_soft_labels, :, class_of_interest_index)
-        return calibration_curve(elected_probabilities,
-                                  binarize.(pred_soft, thresh); bin_count=bin_count)
+        return calibration_curve(elected_probabilities, binarize.(pred_soft, thresh);
+                                  bin_count=bin_count)
end
i_min = argmin([c.mean_squared_error for c in per_threshold_curves])
curve = per_threshold_curves[i_min]
@@ -900,8 +887,8 @@ function per_class_confusion_statistics(predicted_soft_labels::AbstractMatrix,
class_count = size(predicted_soft_labels, 2)
return map(1:class_count) do i
return per_threshold_confusion_statistics(predicted_soft_labels,
-                                                  elected_hard_labels,
-                                                  thresholds, i; binarize)
+                                                  elected_hard_labels, thresholds, i;
+                                                  binarize)
end
end

(Diffs for the remaining 3 changed files are not shown here.)
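For reference, a minimal usage sketch of the `class_labels` keyword introduced in this commit follows. It is not part of the diff: the toy inputs and the choice of `class_labels` as the full per-class label vector (matching `string.(collect(classes))` in `evaluation_metrics_row`) are illustrative assumptions only.

# Hypothetical example of the new `class_labels` keyword; all inputs below are made up.
# Soft labels are (samples × classes), hard labels are integer class indices, and
# `votes` is (samples × raters).
using Lighthouse: get_tradeoff_metrics, get_hardened_metrics, get_label_metrics_multirater

predicted_soft_labels = [0.9 0.1; 0.3 0.7; 0.2 0.8; 0.6 0.4]
predicted_hard_labels = [1, 2, 2, 1]
elected_hard_labels = [1, 2, 2, 2]
votes = [1 1; 2 2; 2 1; 2 2]          # two raters
class_labels = ["absent", "present"]  # assumed: one label per class index

tradeoff_row = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, 2;
                                    thresholds=0.0:0.01:1.0, class_labels)
hardened_row = get_hardened_metrics(predicted_hard_labels, elected_hard_labels, 2;
                                    class_labels)
label_row = get_label_metrics_multirater(votes, 2; class_labels)

As before, omitting the keyword leaves `class_labels` as `missing`, so existing callers are unaffected.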

2 comments on commit 0540cdd

@josephsdavid (Contributor, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/61680

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.10 -m "<description of version>" 0540cdd5fcfc7a95cdd0b8c080febc42da736dbe
git push origin v0.14.10
