diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ea42e1a..6bdd908 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: version: - - '1.5' + - '1.7' - '1.6' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 3ffcbf2..93d50af 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "Lighthouse" uuid = "ac2c24cd-07f0-4848-96b2-1b82c3ea0e59" authors = ["Beacon Biosignals, Inc."] -version = "0.13.4" +version = "0.14.0" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" @@ -16,15 +17,18 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" [compat] CairoMakie = "0.7" +Legolas = "0.3" Makie = "0.16.5" StatsBase = "0.33" +Tables = "1.7" TensorBoardLogger = "0.1" -julia = "1.5" +julia = "1.6" [extras] CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "CairoMakie", "StableRNGs"] +test = ["Test", "CairoMakie", "StableRNGs", "Tables"] diff --git a/docs/src/index.md b/docs/src/index.md index 36e6b76..3ccffbf 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -38,7 +38,10 @@ accuracy binary_statistics cohens_kappa calibration_curve +EvaluationRow Lighthouse.evaluation_metrics +Lighthouse._evaluation_row_dict +Lighthouse.evaluation_metrics_row ``` ## Utilities diff --git a/src/Lighthouse.jl b/src/Lighthouse.jl index 1ffacb1..6ed6c3a 100644 --- a/src/Lighthouse.jl +++ b/src/Lighthouse.jl @@ -6,6 +6,7 @@ using StatsBase: StatsBase using TensorBoardLogger using Makie using Printf +using Legolas include("plotting.jl") @@ -18,6 +19,9 @@ export confusion_matrix, accuracy, binary_statistics, cohens_kappa, calibration_ include("classifier.jl") export AbstractClassifier +include("row.jl") +export EvaluationRow + include("learn.jl") export LearnLogger, learn!, upon, evaluate!, predict! 
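[Editor's note] The changes below swap the `Dict`-based metrics output for a Legolas-backed `EvaluationRow` (defined in the new `src/row.jl`), with `evaluation_metrics` retained as a `Dict`-returning wrapper. As a minimal sketch of the intended usage — the toy two-class inputs are invented purely for illustration, while the signatures are the ones added in `src/learn.jl` further down in this diff:

```julia
using Lighthouse

# Hypothetical inputs; shapes follow the `evaluation_metrics_row` docstring below.
predicted_soft = [0.8 0.2; 0.3 0.7; 0.1 0.9]  # one row of class probabilities per sample
predicted_hard = [1, 2, 2]
elected_hard = [1, 2, 1]
classes = ["negative", "positive"]

# New in this release: metrics are returned as a schema-checked `EvaluationRow`...
row = Lighthouse.evaluation_metrics_row(predicted_hard, predicted_soft, elected_hard,
                                        classes, 0.0:0.01:1.0)

# ...while `evaluation_metrics` still returns the pre-0.14 `Dict{String,Any}` shape,
# now produced by converting the row via `Lighthouse._evaluation_row_dict`.
plot_data = Lighthouse.evaluation_metrics(predicted_hard, predicted_soft, elected_hard,
                                          classes, 0.0:0.01:1.0)
row.multiclass_kappa == plot_data["multiclass_kappa"]  # same underlying values
```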
diff --git a/src/learn.jl b/src/learn.jl index 8beb8d6..6b539ef 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -211,8 +211,7 @@ function evaluate!(predicted_hard_labels::AbstractVector, _validate_threshold_class(optimal_threshold_class, classes) log_resource_info!(logger, logger_prefix; suffix=logger_suffix) do - plot_data = evaluation_metrics(predicted_hard_labels, - predicted_soft_labels, + plot_data = evaluation_metrics(predicted_hard_labels, predicted_soft_labels, elected_hard_labels, classes, thresholds; votes=votes, optimal_threshold_class=optimal_threshold_class) @@ -226,8 +225,6 @@ function evaluate!(predicted_hard_labels::AbstractVector, return nothing end - - function _calculate_stratified_ea_kappas(predicted_hard_labels, elected_hard_labels, class_count, strata) groups = reduce(∪, strata) @@ -238,17 +235,19 @@ function _calculate_stratified_ea_kappas(predicted_hard_labels, elected_hard_lab elected = elected_hard_labels[index] k = _calculate_ea_kappas(predicted, elected, class_count) push!(kappas, - group => (per_class=k.per_class, multiclass=k.multiclass, n=sum(index))) + group => (per_class=k.per_class_kappas, multiclass=k.multiclass_kappa, + n=sum(index))) end - return sort(kappas; by=p -> last(p).multiclass) + kappas = sort(kappas; by=p -> last(p).multiclass) + return [k = v for (k, v) in kappas] end """ _calculate_ea_kappas(predicted_hard_labels, elected_hard_labels, classes) -Return `NamedTuple` with keys `:per_class`, `:multiclass` containing the Cohen's +Return `NamedTuple` with keys `:per_class_kappas`, `:multiclass_kappa` containing the Cohen's Kappa per-class and over all classes, respectively. The value of output key -`:per_class` is an `Array` such that item `i` is the Cohen's kappa calculated +`:per_class_kappas` is an `Array` such that item `i` is the Cohen's kappa calculated for class `i`. Where... @@ -272,15 +271,15 @@ function _calculate_ea_kappas(predicted_hard_labels, elected_hard_labels, class_ elected = ((label == class_index) + 1 for label in elected_hard_labels) return first(cohens_kappa(CLASS_VS_ALL_CLASS_COUNT, zip(predicted, elected))) end - return (per_class=per_class, multiclass=multiclass) + return (per_class_kappas=per_class, multiclass_kappa=multiclass) end """ _calculate_ira_kappas(votes, classes) -Return `NamedTuple` with keys `:per_class`, `:multiclass` containing the Cohen's +Return `NamedTuple` with keys `:per_class_IRA_kappas`, `:multiclass_IRA_kappas` containing the Cohen's Kappa for inter-rater agreement (IRA) per-class and over all classes, respectively. -The value of output key `:per_class` is an `Array` such that item `i` is the +The value of output key `:per_class_IRA_kappas` is an `Array` such that item `i` is the IRA kappa calculated for class `i`. Where... @@ -292,12 +291,14 @@ Where... - `classes` all possible classes voted on. -Returns `nothing` if `votes` has only a single voter (i.e., a single column) or if +Returns `(per_class_IRA_kappas=missing, multiclass_IRA_kappas=missing)` if `votes` has only a single voter (i.e., a single column) or if no two voters rated the same sample. Note that vote entries of `0` are taken to mean that the voter did not rate that sample. 
""" function _calculate_ira_kappas(votes, classes) - (isnothing(votes) || size(votes, 2) < 2) && return nothing # no votes given or only one expert + # no votes given or only one expert: + (isnothing(votes) || size(votes, 2) < 2) && + return (; per_class_IRA_kappas=missing, multiclass_IRA_kappas=missing) all_hard_label_pairs = Array{Int}(undef, 0, 2) num_voters = size(votes, 2) @@ -307,7 +308,8 @@ function _calculate_ira_kappas(votes, classes) end end hard_label_pairs = filter(row -> all(row .!= 0), collect(eachrow(all_hard_label_pairs))) - length(hard_label_pairs) > 0 || return nothing # No common observations voted on + length(hard_label_pairs) > 0 || + return (; per_class_IRA_kappas=missing, multiclass_IRA_kappas=missing) # No common observations voted on length(hard_label_pairs) < 10 && @warn "...only $(length(hard_label_pairs)) in common, potentially questionable IRA results" @@ -319,7 +321,7 @@ function _calculate_ira_kappas(votes, classes) hard_label_pairs) return first(cohens_kappa(CLASS_VS_ALL_CLASS_COUNT, class_v_other_hard_label_pair)) end - return (per_class=per_class_ira, multiclass=multiclass_ira) + return (; per_class_IRA_kappas=per_class_ira, multiclass_IRA_kappas=multiclass_ira) end function _spearman_corr(predicted_soft_labels, elected_soft_labels) @@ -368,7 +370,6 @@ Where... function _calculate_spearman_correlation(predicted_soft_labels, votes, classes) length(classes) > 2 && throw(ArgumentError("Only valid for 2-class problems")) if !all(x -> x ≈ 1, sum(predicted_soft_labels; dims=2)) - @info predicted_soft_labels throw(ArgumentError("Input probabiliities fail softmax assumption")) end @@ -429,7 +430,7 @@ function _get_optimal_threshold_from_ROC(per_class_roc_curves; thresholds, opt_point = nothing threshold_idx = 1 for point in zip(per_class_roc_curves[class_of_interest_index][1], - per_class_roc_curves[class_of_interest_index][2]) + per_class_roc_curves[class_of_interest_index][2]) d = dist((0, 1), point) if d < min min = d @@ -442,7 +443,9 @@ function _get_optimal_threshold_from_ROC(per_class_roc_curves; thresholds, end function _validate_threshold_class(optimal_threshold_class, classes) - isnothing(optimal_threshold_class) && return nothing + if ismissing(optimal_threshold_class) || isnothing(optimal_threshold_class) + return nothing + end length(classes) == 2 || throw(ArgumentError("Only valid for binary classification problems")) optimal_threshold_class in Set([1, 2]) || @@ -451,16 +454,16 @@ function _validate_threshold_class(optimal_threshold_class, classes) end """ - evaluation_metrics(predicted_hard_labels::AbstractVector, - predicted_soft_labels::AbstractMatrix, - elected_hard_labels::AbstractVector, - classes, - thresholds=0.0:0.01:1.0; - votes::Union{Nothing,AbstractMatrix}=nothing, - strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, - optimal_threshold_class::Union{Nothing,Integer}=nothing) - -Returns dictionary containing a battery of classifier performance + evaluation_metrics_row(predicted_hard_labels::AbstractVector, + predicted_soft_labels::AbstractMatrix, + elected_hard_labels::AbstractVector, + classes, + thresholds=0.0:0.01:1.0; + votes::Union{Nothing,AbstractMatrix}=nothing, + strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, + optimal_threshold_class::Union{Nothing,Integer}=nothing) + +Returns `EvaluationRow` containing a battery of classifier performance metrics that each compare `predicted_soft_labels` and/or `predicted_hard_labels` agaist `elected_hard_labels`. @@ -495,12 +498,12 @@ Where... 
See also [`evaluation_metrics_plot`](@ref). """ -function evaluation_metrics(predicted_hard_labels::AbstractVector, - predicted_soft_labels::AbstractMatrix, - elected_hard_labels::AbstractVector, classes, thresholds; - votes::Union{Nothing,AbstractMatrix}=nothing, - strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, - optimal_threshold_class::Union{Nothing,Integer}=nothing) +function evaluation_metrics_row(predicted_hard_labels::AbstractVector, + predicted_soft_labels::AbstractMatrix, + elected_hard_labels::AbstractVector, classes, thresholds; + votes::Union{Nothing,AbstractMatrix}=nothing, + strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, + optimal_threshold_class::Union{Missing,Integer}=missing) _validate_threshold_class(optimal_threshold_class, classes) class_count = length(classes) @@ -508,22 +511,15 @@ function evaluation_metrics(predicted_hard_labels::AbstractVector, class_labels = string.(class_vector) per_class_stats = per_class_confusion_statistics(predicted_soft_labels, elected_hard_labels, thresholds) - plot_dict = Dict() - plot_dict["class_labels"] = class_labels - plot_dict["thresholds"] = thresholds # ROC curves - plot_dict["per_class_roc_curves"] = [(map(t -> t.false_positive_rate, stats), - map(t -> t.true_positive_rate, stats)) - for stats in per_class_stats] - plot_dict["per_class_roc_aucs"] = [area_under_curve(x, y) - for (x, y) in plot_dict["per_class_roc_curves"]] + per_class_roc_curves = [(map(t -> t.false_positive_rate, stats), + map(t -> t.true_positive_rate, stats)) + for stats in per_class_stats] + per_class_roc_aucs = [area_under_curve(x, y) for (x, y) in per_class_roc_curves] # Optionally calculate optimal threshold - if !isnothing(optimal_threshold_class) - plot_dict["optimal_threshold_class"] = optimal_threshold_class - threshold = nothing - + if !ismissing(optimal_threshold_class) # If votes exist, calculate the threshold based on comparing against # vote probabilities. Otherwise, use the ROC curve. 
if !isnothing(votes) @@ -531,75 +527,97 @@ function evaluation_metrics(predicted_hard_labels::AbstractVector, votes; thresholds=thresholds, class_of_interest_index=optimal_threshold_class) - threshold = c.threshold - plot_dict["discrimination_calibration_curve"] = c.plot_curve_data - plot_dict["discrimination_calibration_score"] = c.mse + optimal_threshold = c.threshold + discrimination_calibration_curve = c.plot_curve_data + discrimination_calibration_score = c.mse expert_cal = _calculate_voter_discrimination_calibration(votes; class_of_interest_index=optimal_threshold_class) - plot_dict["per_expert_discrimination_calibration_curves"] = expert_cal.plot_curve_data - plot_dict["per_expert_discrimination_calibration_scores"] = expert_cal.mse + per_expert_discrimination_calibration_curves = expert_cal.plot_curve_data + per_expert_discrimination_calibration_scores = expert_cal.mse else + discrimination_calibration_curve = missing + discrimination_calibration_score = missing + per_expert_discrimination_calibration_curves = missing + per_expert_discrimination_calibration_scores = missing # ...based on ROC curve otherwise - threshold = _get_optimal_threshold_from_ROC(plot_dict["per_class_roc_curves"]; - thresholds=thresholds, - class_of_interest_index=optimal_threshold_class) + optimal_threshold = _get_optimal_threshold_from_ROC(per_class_roc_curves; + thresholds=thresholds, + class_of_interest_index=optimal_threshold_class) end - plot_dict["optimal_threshold"] = threshold # Recalculate `predicted_hard_labels` with this new threshold other_class = optimal_threshold_class == 1 ? 2 : 1 for (i, row) in enumerate(eachrow(predicted_soft_labels)) - predicted_hard_labels[i] = row[optimal_threshold_class] .>= threshold ? + predicted_hard_labels[i] = row[optimal_threshold_class] .>= optimal_threshold ? 
optimal_threshold_class : other_class end + else + discrimination_calibration_curve = missing + discrimination_calibration_score = missing + per_expert_discrimination_calibration_curves = missing + per_expert_discrimination_calibration_scores = missing + optimal_threshold = missing end # PR curves - plot_dict["per_class_pr_curves"] = [(map(t -> t.true_positive_rate, stats), - map(t -> t.precision, stats)) - for stats in per_class_stats] - - # Cohen's kappa - kappas = _calculate_ea_kappas(predicted_hard_labels, elected_hard_labels, class_count) - plot_dict["per_class_kappas"] = kappas.per_class - plot_dict["multiclass_kappa"] = kappas.multiclass - ira = _calculate_ira_kappas(votes, classes) - if !isnothing(ira) - plot_dict["per_class_IRA_kappas"] = ira.per_class - plot_dict["multiclass_IRA_kappas"] = ira.multiclass - end + per_class_pr_curves = [(map(t -> t.true_positive_rate, stats), + map(t -> t.precision, stats)) for stats in per_class_stats] # Stratified kappas - if !isnothing(strata) - plot_dict["stratified_kappas"] = _calculate_stratified_ea_kappas(predicted_hard_labels, - elected_hard_labels, - class_count, - strata) + if isnothing(strata) + stratified_kappas = missing + else + stratified_kappas = _calculate_stratified_ea_kappas(predicted_hard_labels, + elected_hard_labels, + class_count, strata) end # Reliability calibration curves - per_class_reliability_calibration_curves = map(1:class_count) do class_index + per_class_reliability_calibration = map(1:class_count) do class_index class_probabilities = view(predicted_soft_labels, :, class_index) return calibration_curve(class_probabilities, elected_hard_labels .== class_index) end - plot_dict["per_class_reliability_calibration_curves"] = map(x -> (mean.(x.bins), - x.fractions), - per_class_reliability_calibration_curves) - plot_dict["per_class_reliability_calibration_scores"] = map(x -> x.mean_squared_error, - per_class_reliability_calibration_curves) - - # Confusion matrix - plot_dict["confusion_matrix"] = confusion_matrix(class_count, - zip(predicted_hard_labels, - elected_hard_labels)) + per_class_reliability_calibration_curves = map(x -> (mean.(x.bins), x.fractions), + per_class_reliability_calibration) + per_class_reliability_calibration_scores = map(x -> x.mean_squared_error, + per_class_reliability_calibration) # Log Spearman correlation, iff this is a binary classification problem if length(classes) == 2 && !isnothing(votes) - plot_dict["spearman_correlation"] = _calculate_spearman_correlation(predicted_soft_labels, - votes, classes) + spearman_correlation = _calculate_spearman_correlation(predicted_soft_labels, votes, + classes) + else + spearman_correlation = missing end - return plot_dict + + return EvaluationRow(; class_labels, + confusion_matrix=confusion_matrix(class_count, + zip(predicted_hard_labels, + elected_hard_labels)), + spearman_correlation, per_class_reliability_calibration_curves, + per_class_reliability_calibration_scores, + _calculate_ira_kappas(votes, classes)..., + _calculate_ea_kappas(predicted_hard_labels, elected_hard_labels, + class_count)..., stratified_kappas, + per_class_pr_curves, per_class_roc_curves, per_class_roc_aucs, + discrimination_calibration_curve, discrimination_calibration_score, + per_expert_discrimination_calibration_curves, + per_expert_discrimination_calibration_scores, optimal_threshold, + optimal_threshold_class, thresholds) +end + +""" + evaluation_metrics(args...; optimal_threshold_class=nothing, kwargs...) 
+ +Return [`evaluation_metrics_row`](@ref) after converting output `EvaluationRow` +into a `Dict`. For argument details, see [`evaluation_metrics_row`](@ref). +""" +function evaluation_metrics(args...; optimal_threshold_class=nothing, kwargs...) + row = evaluation_metrics_row(args...; + optimal_threshold_class=something(optimal_threshold_class, + missing), kwargs...) + return _evaluation_row_dict(row) end """ @@ -627,7 +645,6 @@ function evaluation_metrics_plot(predicted_hard_labels::AbstractVector, votes::Union{Nothing,AbstractMatrix}=nothing, strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, optimal_threshold_class::Union{Nothing,Integer}=nothing) - Base.depwarn(""" ``` evaluation_metrics_plot(predicted_hard_labels::AbstractVector, @@ -646,8 +663,8 @@ function evaluation_metrics_plot(predicted_hard_labels::AbstractVector, ``` """, :evaluation_metrics_plot) plot_dict = evaluation_metrics(predicted_hard_labels, predicted_soft_labels, - elected_hard_labels, classes, thresholds; - votes, strata, optimal_threshold_class) + elected_hard_labels, classes, thresholds; votes, strata, + optimal_threshold_class) return evaluation_metrics_plot(plot_dict), plot_dict end @@ -763,7 +780,8 @@ function learn!(model::AbstractClassifier, logger, get_train_batches, get_test_b predict!(model, predicted, get_test_batches(), logger; logger_prefix="$(test_set_logger_prefix)_prediction") evaluate!(map(label -> onecold(model, label), eachrow(predicted)), predicted, - elected, classes(model), logger; logger_prefix="$(test_set_logger_prefix)_evaluation", + elected, classes(model), logger; + logger_prefix="$(test_set_logger_prefix)_evaluation", logger_suffix="_per_epoch", votes=votes, optimal_threshold_class=optimal_threshold_class) post_epoch_callback(current_epoch) diff --git a/src/metrics.jl b/src/metrics.jl index d0bc98c..71f0ab8 100644 --- a/src/metrics.jl +++ b/src/metrics.jl @@ -24,14 +24,14 @@ end accuracy(confusion::AbstractMatrix) Returns the percentage of matching classifications out of total classifications, -or `missing` if `all(iszero, confusion)`. +or `NaN` if `all(iszero, confusion)`. Note that `accuracy(confusion)` is equivalent to overall percent agreement between `confusion`'s row classifier and column classifier. """ function accuracy(confusion::AbstractMatrix) total = sum(confusion) - total == 0 && return missing + total == 0 && return NaN return tr(confusion) / total end @@ -78,15 +78,12 @@ function binary_statistics(confusion::AbstractMatrix, class_index::Integer) false_negative_rate = (false_negatives == 0 && actual_positives == 0) ? (zero(false_negatives) / one(actual_positives)) : (false_negatives / actual_positives) - precision = (true_positives == 0 && predicted_positives == 0) ? missing : + precision = (true_positives == 0 && predicted_positives == 0) ? 
NaN : (true_positives / predicted_positives) - return (predicted_positives=predicted_positives, - predicted_negatives=predicted_negatives, actual_positives=actual_positives, - actual_negatives=actual_negatives, true_positives=true_positives, - true_negatives=true_negatives, false_positives=false_positives, - false_negatives=false_negatives, true_positive_rate=true_positive_rate, - true_negative_rate=true_negative_rate, false_positive_rate=false_positive_rate, - false_negative_rate=false_negative_rate, precision=precision) + return (; predicted_positives, predicted_negatives, actual_positives, actual_negatives, + true_positives, true_negatives, false_positives, false_negatives, + true_positive_rate, true_negative_rate, false_positive_rate, + false_negative_rate, precision) end function binary_statistics(confusion::AbstractMatrix) @@ -105,7 +102,8 @@ Return `(κ, p₀)` where `κ` is Cohen's kappa and `p₀` percent agreement giv their equivalents in [`confusion_matrix`](@ref)). """ function cohens_kappa(class_count, hard_label_pairs) - all(issubset(pair, 1:class_count) for pair in hard_label_pairs) || throw(ArgumentError("Unexpected class in `hard_label_pairs`.")) + all(issubset(pair, 1:class_count) for pair in hard_label_pairs) || + throw(ArgumentError("Unexpected class in `hard_label_pairs`.")) p₀ = accuracy(confusion_matrix(class_count, hard_label_pairs)) pₑ = _probability_of_chance_agreement(class_count, hard_label_pairs) return _cohens_kappa(p₀, pₑ), p₀ @@ -137,7 +135,7 @@ where: - `bins` a vector with `bin_count` `Pairs` specifying the calibration curve's probability bins - `fractions`: a vector where `fractions[i]` is the number of values in `probabilities` - that falls within `bin[i]` over the total number of values within `bin[i]`, or `missing` + that falls within `bin[i]` over the total number of values within `bin[i]`, or `NaN` if the total number of values in `bin[i]` is zero. - `totals`: a vector where `totals[i]` the total number of values within `bin[i]`. - `mean_squared_error`: The mean squared error of `fractions` vs. an ideal calibration curve. @@ -150,12 +148,12 @@ function calibration_curve(probabilities, bitmask; bin_count=10) bins = probability_bins(bin_count) per_bin = [fraction_within(probabilities, bitmask, bin...) for bin in bins] fractions, totals = first.(per_bin), last.(per_bin) - nonempty_indices = findall(!ismissing, fractions) + nonempty_indices = findall(!isnan, fractions) if !isempty(nonempty_indices) ideal = range(mean(first(bins)), mean(last(bins)); length=length(bins)) mean_squared_error = mse(fractions[nonempty_indices], ideal[nonempty_indices]) else - mean_squared_error = missing + mean_squared_error = NaN end return (bins=bins, fractions=fractions, totals=totals, mean_squared_error=mean_squared_error) @@ -179,6 +177,6 @@ function fraction_within(values, bitmask, start, stop) total += 1 end end - fraction = iszero(total) ? missing : (count / total) + fraction = iszero(total) ? 
NaN : (count / total) return (fraction=fraction, total=total) end
diff --git a/src/row.jl b/src/row.jl new file mode 100644 index 0000000..f036f41 --- /dev/null +++ b/src/row.jl @@ -0,0 +1,133 @@ +# Arrow can't handle matrices---so when we write/read matrices, we have to pack and unpack them o_O +# https://github.com/apache/arrow-julia/issues/125 +vec_to_mat(mat::AbstractMatrix) = mat + +function vec_to_mat(vec::AbstractVector) + n = isqrt(length(vec)) + return reshape(vec, n, n) +end + +vec_to_mat(x::Missing) = return missing + +# Redefinition is workaround for https://github.com/beacon-biosignals/Legolas.jl/issues/9 +const EVALUATION_ROW_SCHEMA = Legolas.Schema("lighthouse.evaluation@1") + +""" + const EvaluationRow = Legolas.@row("lighthouse.evaluation@1", + class_labels::Union{Missing,Vector{String}}, + confusion_matrix::Union{Missing,Array{Int64}} = vec_to_mat(confusion_matrix), + discrimination_calibration_curve::Union{Missing, + Tuple{Vector{Float64}, + Vector{Float64}}}, + discrimination_calibration_score::Union{Missing,Float64}, + multiclass_IRA_kappas::Union{Missing,Float64}, + multiclass_kappa::Union{Missing,Float64}, + optimal_threshold::Union{Missing,Float64}, + optimal_threshold_class::Union{Missing,Int64}, + per_class_IRA_kappas::Union{Missing,Vector{Float64}}, + per_class_kappas::Union{Missing,Vector{Float64}}, + stratified_kappas::Union{Missing, + Vector{NamedTuple{(:per_class, + :multiclass, + :n), + Tuple{Vector{Float64}, + Float64, + Int64}}}}, + per_class_pr_curves::Union{Missing, + Vector{Tuple{Vector{Float64}, + Vector{Float64}}}}, + per_class_reliability_calibration_curves::Union{Missing, + Vector{Tuple{Vector{Float64}, + Vector{Float64}}}}, + per_class_reliability_calibration_scores::Union{Missing, + Vector{Float64}}, + per_class_roc_aucs::Union{Missing,Vector{Float64}}, + per_class_roc_curves::Union{Missing, + Vector{Tuple{Vector{Float64}, + Vector{Float64}}}}, + per_expert_discrimination_calibration_curves::Union{Missing, + Vector{Tuple{Vector{Float64}, + Vector{Float64}}}}, + per_expert_discrimination_calibration_scores::Union{Missing, + Vector{Float64}}, + spearman_correlation::Union{Missing, + NamedTuple{(:ρ, :n, + :ci_lower, + :ci_upper), + Tuple{Float64, + Int64, + Float64, + Float64}}}, + thresholds::Union{Missing,Vector{Float64}}) + EvaluationRow(evaluation_row_dict::Dict{String, Any}) -> EvaluationRow + +A type alias for [`Legolas.Row{typeof(Legolas.Schema("lighthouse.evaluation@1"))}`](https://beacon-biosignals.github.io/Legolas.jl/stable/#Legolas.@row) +representing the output metrics computed by [`evaluation_metrics_row`](@ref) and +[`evaluation_metrics`](@ref). + +The constructor that takes `evaluation_row_dict` converts an [`evaluation_metrics`](@ref) +`Dict` of metrics results (e.g. from a Lighthouse version before v0.14.0) into an `EvaluationRow`. +""" +function EvaluationRow(evaluation_row_dict::Dict{String,<:Any}) + row = (; (Symbol(k) => v for (k, v) in pairs(evaluation_row_dict))...)
+ return EvaluationRow(row) +end + +""" + _evaluation_row_dict(row::EvaluationRow) -> Dict{String,Any} + +Convert [`EvaluationRow`](@ref) into a `Dict{String,Any}` of results, as are +output by [`evaluation_metrics`](@ref) (and as was produced before the introduction of +`EvaluationRow` in Lighthouse v0.14.0). +""" +function _evaluation_row_dict(row::EvaluationRow) + return Dict(string(k) => v for (k, v) in pairs(NamedTuple(row)) if !ismissing(v)) +end
diff --git a/test/learn.jl b/test/learn.jl index 1d1b6aa..fc46e54 100644 --- a/test/learn.jl +++ b/test/learn.jl @@ -44,24 +44,25 @@ end @info counted n end end - elected = majority.((rng,), eachrow(votes), (1:length(Lighthouse.classes(model)),)) - Lighthouse.learn!(model, logger, () -> train_batches, () -> test_batches, votes, elected; - epoch_limit=limit, post_epoch_callback=callback) + elected = majority.((rng,), eachrow(votes), + (1:length(Lighthouse.classes(model)),)) + Lighthouse.learn!(model, logger, () -> train_batches, () -> test_batches, votes, + elected; epoch_limit=limit, post_epoch_callback=callback) @test counted == sum(1:limit) end @test length(logger.logged["train/loss_per_batch"]) == length(train_batches) * limit for key in ["test_set_prediction/loss_per_batch", - "test_set_prediction/time_in_seconds_per_batch", - "test_set_prediction/gc_time_in_seconds_per_batch", - "test_set_prediction/allocations_per_batch", - "test_set_prediction/memory_in_mb_per_batch"] + "test_set_prediction/time_in_seconds_per_batch", + "test_set_prediction/gc_time_in_seconds_per_batch", + "test_set_prediction/allocations_per_batch", + "test_set_prediction/memory_in_mb_per_batch"] @test length(logger.logged[key]) == length(test_batches) * limit end for key in ["test_set_prediction/mean_loss_per_epoch", - "test_set_evaluation/time_in_seconds_per_epoch", - "test_set_evaluation/gc_time_in_seconds_per_epoch", - "test_set_evaluation/allocations_per_epoch", - "test_set_evaluation/memory_in_mb_per_epoch"] + "test_set_evaluation/time_in_seconds_per_epoch", + "test_set_evaluation/gc_time_in_seconds_per_epoch", + "test_set_evaluation/allocations_per_epoch", + "test_set_evaluation/memory_in_mb_per_epoch"] @test length(logger.logged[key]) == limit end @test length(logger.logged["test_set_evaluation/metrics_per_epoch"]) == limit @@ -98,24 +99,36 @@ end @test length(logger.logged["wheeeeeee/time_in_seconds_for_all_time"]) == 1 @test length(logger.logged["wheeeeeee/metrics_for_all_time"]) == 1 + # Round-trip `onehot` for codecov + onehot_hard = map(h -> vec(Lighthouse.onehot(model, h)), predicted_hard) + @test map(h -> findfirst(h), onehot_hard) == predicted_hard + # Test startified eval strata = [Set("group $(j % Int(ceil(sqrt(j))))" for j in 1:(i - 1)) for i in 1:size(votes, 1)] - plot_data = evaluation_metrics(predicted_hard, predicted_soft, - elected_hard, model.classes, 0.0:0.01:1.0; - votes=votes, strata=strata) + plot_data = evaluation_metrics(predicted_hard, predicted_soft, elected_hard, + model.classes, 0.0:0.01:1.0; votes, strata) @test haskey(plot_data, "stratified_kappas") plot = evaluation_metrics_plot(plot_data) - plot2, plot_data2 = @test_deprecated evaluation_metrics_plot(predicted_hard, predicted_soft, - elected_hard, model.classes, 0.0:0.01:1.0; - votes=votes, strata=strata) + test_evaluation_metrics_roundtrip(plot_data) + + plot2, plot_data2 = @test_deprecated evaluation_metrics_plot(predicted_hard, + predicted_soft, + elected_hard, + model.classes, + 0.0:0.01:1.0; + votes=votes, + strata=strata) @test isequal(plot_data, plot_data2) # check these are the same + test_evaluation_metrics_roundtrip(plot_data2) # Test plotting plot_data =
last(logger.logged["test_set_evaluation/metrics_per_epoch"]) @test isa(plot_data["thresholds"], AbstractVector) + @test isa(last(plot_data["per_class_pr_curves"]), + Tuple{Vector{Float64},Vector{Float64}}) pr = plot_pr_curves(plot_data["per_class_pr_curves"], plot_data["class_labels"]) @testplot pr @@ -124,15 +137,17 @@ end @testplot roc # Kappa no IRA - kappas_no_ira = plot_kappas(vcat(plot_data["multiclass_kappa"], plot_data["per_class_kappas"]), - vcat("Multiclass", plot_data["class_labels"])) + kappas_no_ira = plot_kappas(vcat(plot_data["multiclass_kappa"], + plot_data["per_class_kappas"]), + vcat("Multiclass", plot_data["class_labels"])) @testplot kappas_no_ira # Kappa with IRA - kappas_ira = plot_kappas(vcat(plot_data["multiclass_kappa"], plot_data["per_class_kappas"]), - vcat("Multiclass", plot_data["class_labels"]), - vcat(plot_data["multiclass_IRA_kappas"], - plot_data["per_class_IRA_kappas"])) + kappas_ira = plot_kappas(vcat(plot_data["multiclass_kappa"], + plot_data["per_class_kappas"]), + vcat("Multiclass", plot_data["class_labels"]), + vcat(plot_data["multiclass_IRA_kappas"], + plot_data["per_class_IRA_kappas"])) @testplot kappas_ira reliability_calibration = plot_reliability_calibration_curves(plot_data["per_class_reliability_calibration_curves"], @@ -177,9 +192,10 @@ end @info counted n end end - elected = majority.((rng,), eachrow(votes), (1:length(Lighthouse.classes(model)),)) - Lighthouse.learn!(model, logger, () -> train_batches, () -> test_batches, votes, elected; - epoch_limit=limit, post_epoch_callback=callback) + elected = majority.((rng,), eachrow(votes), + (1:length(Lighthouse.classes(model)),)) + Lighthouse.learn!(model, logger, () -> train_batches, () -> test_batches, votes, + elected; epoch_limit=limit, post_epoch_callback=callback) @test counted == sum(1:limit) end # Binary classification logs some additional metrics @@ -187,6 +203,7 @@ end limit plot_data = last(logger.logged["test_set_evaluation/metrics_per_epoch"]) @test haskey(plot_data, "spearman_correlation") + test_evaluation_metrics_roundtrip(plot_data) # No `optimal_threshold_class` during learning... 
@test !haskey(plot_data, "optimal_threshold") @@ -194,13 +211,14 @@ end # And now, `optimal_threshold_class` during learning elected = majority.((rng,), eachrow(votes), (1:length(Lighthouse.classes(model)),)) - Lighthouse.learn!(model, logger, () -> train_batches, () -> test_batches, votes, elected; - epoch_limit=limit, optimal_threshold_class=2, + Lighthouse.learn!(model, logger, () -> train_batches, () -> test_batches, votes, + elected; epoch_limit=limit, optimal_threshold_class=2, test_set_logger_prefix="validation_set") plot_data = last(logger.logged["validation_set_evaluation/metrics_per_epoch"]) @test haskey(plot_data, "optimal_threshold") @test haskey(plot_data, "optimal_threshold_class") @test plot_data["optimal_threshold_class"] == 2 + test_evaluation_metrics_roundtrip(plot_data) # `optimal_threshold_class` param invalid @test_throws ArgumentError Lighthouse.learn!(model, logger, () -> train_batches, @@ -226,12 +244,14 @@ end plot_data = last(logger.logged["wheeeeeee/metrics_for_all_time"]) @test !haskey(plot_data, "per_class_IRA_kappas") @test !haskey(plot_data, "multiclass_IRA_kappas") + test_evaluation_metrics_roundtrip(plot_data) evaluate!(predicted_hard, predicted_soft, elected_hard, model.classes, logger; logger_prefix="wheeeeeee", logger_suffix="_for_all_time", votes=votes) plot_data = last(logger.logged["wheeeeeee/metrics_for_all_time"]) @test haskey(plot_data, "per_class_IRA_kappas") @test haskey(plot_data, "multiclass_IRA_kappas") + test_evaluation_metrics_roundtrip(plot_data) # Test `evaluate` for different optimal_threshold classes evaluate!(predicted_hard, predicted_soft, elected_hard, model.classes, logger; @@ -242,6 +262,7 @@ end logger_prefix="wheeeeeee", logger_suffix="_for_all_time", votes=votes, optimal_threshold_class=2) plot_data_2 = last(logger.logged["wheeeeeee/metrics_for_all_time"]) + test_evaluation_metrics_roundtrip(plot_data_2) # The thresholds should not be identical (since they are *inclusive* when applied: # values greater than _or equal to_ the threshold are given the class value) @@ -266,10 +287,12 @@ end end end -@testset "`_calculate_ira_kappas`" begin +@testset "Invalid `_calculate_ira_kappas`" begin classes = ["roy", "gee", "biv"] - @test isnothing(Lighthouse._calculate_ira_kappas([1; 1; 1; 1], classes)) # Only one voter... - @test isnothing(Lighthouse._calculate_ira_kappas([1 0; 1 0; 0 1], classes)) # No observations in common... + @test isequal(Lighthouse._calculate_ira_kappas([1; 1; 1; 1], classes), + (; per_class_IRA_kappas=missing, multiclass_IRA_kappas=missing)) # Only one voter... + @test isequal(Lighthouse._calculate_ira_kappas([1 0; 1 0; 0 1], classes), + (; per_class_IRA_kappas=missing, multiclass_IRA_kappas=missing)) # No observations in common... 
end @testset "Calculate `_spearman_corr`" begin @@ -303,11 +326,9 @@ end # Test NaN spearman due to unranked input votes = [1; 2; 2] - predicted_soft = [ - 0.3 0.7 - 0.3 0.7 - 0.3 0.7 - ] + predicted_soft = [0.3 0.7 + 0.3 0.7 + 0.3 0.7] sp = Lighthouse._calculate_spearman_correlation(predicted_soft, votes, ["oh" "em"]) @test isnan(sp.ρ) @@ -329,11 +350,9 @@ end @test length(single_voter_calibration.mse) == 1 # Test multi-voter voter discrimination calibration - votes = [ - 0 1 1 1 - 1 2 0 0 - 2 1 2 2 - ] # Note: voters 3 and 4 have voted identically + votes = [0 1 1 1 + 1 2 0 0 + 2 1 2 2] # Note: voters 3 and 4 have voted identically voter_calibration = Lighthouse._calculate_voter_discrimination_calibration(votes; class_of_interest_index=1) @test length(voter_calibration.mse) == size(votes, 2) @@ -360,13 +379,11 @@ end end @testset "2-class per_class_confusion_statistics" begin - predicted_soft_labels = [ - 0.51 0.49 - 0.49 0.51 - 0.1 0.9 - 0.9 0.1 - 0.0 1.0 - ] + predicted_soft_labels = [0.51 0.49 + 0.49 0.51 + 0.1 0.9 + 0.9 0.1 + 0.0 1.0] elected_hard_labels = [1, 2, 2, 2, 1] thresholds = [0.25, 0.5, 0.75] class_1, class_2 = Lighthouse.per_class_confusion_statistics(predicted_soft_labels, @@ -441,15 +458,13 @@ end end @testset "3-class per_class_confusion_statistics" begin - predicted_soft_labels = [ - 1/3 1/3 1/3 - 0.1 0.7 0.2 - 0.25 0.25 0.5 - 0.4 0.5 0.1 - 0.0 0.0 1.0 - 0.2 0.5 0.3 - 0.5 0.4 0.1 - ] + predicted_soft_labels = [1/3 1/3 1/3 + 0.1 0.7 0.2 + 0.25 0.25 0.5 + 0.4 0.5 0.1 + 0.0 0.0 1.0 + 0.2 0.5 0.3 + 0.5 0.4 0.1] elected_hard_labels = [1, 2, 2, 1, 3, 3, 1] # TODO would be more robust to have multiple thresholds, but our naive tests # here will have to be refactored to avoid becoming a nightmare if we do that diff --git a/test/metrics.jl b/test/metrics.jl index 2254afe..eca833d 100644 --- a/test/metrics.jl +++ b/test/metrics.jl @@ -65,8 +65,8 @@ @test isapprox(stats.precision, 0.5; atol=0.02) @test confusion_matrix(10, ()) == zeros(10, 10) - @test all(ismissing, cohens_kappa(10, ())) - @test ismissing(accuracy(zeros(10, 10))) + @test all(isnan, cohens_kappa(10, ())) + @test isnan(accuracy(zeros(10, 10))) stats = binary_statistics(zeros(10, 10), 1) @test stats.predicted_positives == 0 @test stats.predicted_negatives == 0 @@ -80,7 +80,7 @@ @test stats.true_negative_rate == 1 @test stats.false_positive_rate == 0 @test stats.false_negative_rate == 0 - @test ismissing(stats.precision) + @test isnan(stats.precision) for p in 0:0.1:1 @test Lighthouse._cohens_kappa(p, p) == 0 @@ -103,6 +103,7 @@ end @test bin_count == length(bins) @test first(first(bins)) == 0.0 && last(last(bins)) == 1.0 @test all(!ismissing, fractions) + @test all(!isnan, fractions) @test all(!iszero, totals) @test all(isapprox.(fractions, 0.5; atol=0.02)) @test all(isapprox.(totals, length(probs) / bin_count; atol=1000)) @@ -120,6 +121,7 @@ end @test bin_count == length(bins) @test first(first(bins)) == 0.0 && last(last(bins)) == 1.0 @test all(!ismissing, fractions) + @test all(!isnan, fractions) @test all(!iszero, totals) @test all(isapprox.(fractions, ideal; atol=0.01)) @test all(totals .== 1_000_000 / bin_count) @@ -132,9 +134,21 @@ end @test bin_count == length(bins) @test first(first(bins)) == 0.0 && last(last(bins)) == 1.0 @test all(!ismissing, fractions) + @test all(!isnan, fractions) @test all(!iszero, totals) @test all(isapprox.(fractions, reverse(ideal); atol=0.01)) @test all(totals .== 1_000_000 / bin_count) @test isapprox(ceil(mean(fractions) * length(bitmask)), count(bitmask); atol=1) @test 
isapprox(mean_squared_error, 1 / 3; atol=0.01) + + # Handle garbage input---ensure non-existant results are NaN + probs = fill(-1, 40) + bitmask = zeros(Bool, 40) + bins, fractions, totals, mean_squared_error = calibration_curve(probs, bitmask; + bin_count) + @test bin_count == length(bins) + @test first(first(bins)) == 0.0 && last(last(bins)) == 1.0 + @test all(isnan, fractions) + @test all(iszero, totals) + @test isnan(mean_squared_error) end diff --git a/test/row.jl b/test/row.jl new file mode 100644 index 0000000..9f836db --- /dev/null +++ b/test/row.jl @@ -0,0 +1,20 @@ +@testset "`vec_to_mat`" begin + mat = [3 5 6; 6 7 8; 9 10 11] + @test Lighthouse.vec_to_mat(vec(mat)) == mat + @test Lighthouse.vec_to_mat(mat) == mat + @test ismissing(Lighthouse.vec_to_mat(missing)) + @test_throws DimensionMismatch Lighthouse.vec_to_mat(collect(1:6)) # Invalid dimensions +end + +@testset "`EvaluationRow` basics" begin + # Most EvaluationRow testing happens via the `test_evaluation_metrics_roundtrip` + # in test/learn.jl + + # Roundtrip from dict + dict = Dict("class_labels" => ["foo", "bar"], "multiclass_kappa" => 3) + test_evaluation_metrics_roundtrip(dict) + + # Handle fun case + mat_dict = Dict("confusion_matrix" => [3 5 6; 6 7 8; 9 10 11]) + test_evaluation_metrics_roundtrip(mat_dict) +end diff --git a/test/runtests.jl b/test/runtests.jl index f15aef5..c92ff9f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ using Lighthouse: plot_reliability_calibration_curves, plot_pr_curves, evaluation_metrics_plot, evaluation_metrics using Base.Threads using CairoMakie +using Legolas, Tables # Needs to be set for figures # returning true for showable("image/png", obj) @@ -25,8 +26,49 @@ macro testplot(fig_name) end end +const EVALUATION_ROW_KEYS = string.(keys(EvaluationRow())) + +function test_evaluation_metrics_roundtrip(row_dict::Dict{String,S}) where {S} + # Make sure we're capturing all metrics keys in our Schema + keys_not_in_schema = setdiff(keys(row_dict), EVALUATION_ROW_KEYS) + @test isempty(keys_not_in_schema) + + # Do the roundtripping (will fail if schema types do not validate after roundtrip) + row = EvaluationRow(row_dict) + rt_row = roundtrip_row(row) + + # Make sure full row roundtrips correctly + @test issetequal(keys(row), keys(rt_row)) + for (k, v) in pairs(row) + if ismissing(v) + @test ismissing(rt_row[k]) + else + @test issetequal(v, rt_row[k]) + end + end + + # Make sure originating metrics dictionary roundtrips correctly + rt_dict = Lighthouse._evaluation_row_dict(rt_row) + for (k, v) in pairs(row_dict) + if ismissing(v) + @test ismissing(rt_dict[k]) + else + @test issetequal(v, rt_dict[k]) + end + end + return nothing +end + +function roundtrip_row(row::EvaluationRow) + p = mktempdir() * "rt_test.arrow" + tbl = [row] + Legolas.write(p, tbl, Lighthouse.EVALUATION_ROW_SCHEMA) + return EvaluationRow(only(Tables.rows(Legolas.read(p)))) +end + include("plotting.jl") include("metrics.jl") include("learn.jl") include("utilities.jl") include("logger.jl") +include("row.jl")
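[Editor's note] For reference, here is a minimal sketch of the Arrow round-trip that the new `test_evaluation_metrics_roundtrip` and `roundtrip_row` helpers above exercise. The dictionary contents and file name are invented for illustration; the `Legolas.write`/`Legolas.read` calls mirror `roundtrip_row`:

```julia
using Lighthouse, Legolas, Tables

# Build a row from a legacy-style metrics `Dict`, as the new constructor in src/row.jl allows.
row = EvaluationRow(Dict("class_labels" => ["foo", "bar"], "multiclass_kappa" => 0.8))

# Serialize against the declared schema and read it back, as `roundtrip_row` does above.
path = joinpath(mktempdir(), "evaluation.arrow")
Legolas.write(path, [row], Lighthouse.EVALUATION_ROW_SCHEMA)
rt_row = EvaluationRow(only(Tables.rows(Legolas.read(path))))

rt_row.multiclass_kappa == row.multiclass_kappa  # scalar fields survive the round-trip
# Matrix-valued fields (e.g. `confusion_matrix`) come back from Arrow as flat vectors and
# are reshaped by `vec_to_mat` when the row is reconstructed.
```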