Skip to content

Commit

Permalink
fix compress_memberships bug, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierBinette committed Nov 17, 2023
1 parent 79f1974 commit d190ec5
Show file tree
Hide file tree
Showing 20 changed files with 111 additions and 180 deletions.
24 changes: 5 additions & 19 deletions er_evaluation/data_structures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,11 @@
└─3─┘ 5
"""
from er_evaluation.data_structures._data_structures import (
MembershipVector,
clusters_to_graph,
clusters_to_membership,
clusters_to_pairs,
compress_memberships,
graph_to_clusters,
graph_to_membership,
graph_to_pairs,
isclusters,
isgraph,
ismembership,
ispairs,
membership_to_clusters,
membership_to_graph,
membership_to_pairs,
pairs_to_clusters,
pairs_to_graph,
pairs_to_membership,
)
MembershipVector, clusters_to_graph, clusters_to_membership,
clusters_to_pairs, compress_memberships, graph_to_clusters,
graph_to_membership, graph_to_pairs, isclusters, isgraph, ismembership,
ispairs, membership_to_clusters, membership_to_graph, membership_to_pairs,
pairs_to_clusters, pairs_to_graph, pairs_to_membership)

__all__ = [
"compress_memberships",
Expand Down
13 changes: 6 additions & 7 deletions er_evaluation/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
The :py:meth:`load_rldata10000_disambiguations` and :py:meth:`load_rldata10000` return ground truth disambiguation, toy predicted disambiguations, and the full RLdata10000 dataframe.
"""

from er_evaluation.datasets.patentsview import load_pv_data, load_pv_disambiguations
from er_evaluation.datasets.rldata import (
load_rldata500,
load_rldata500_disambiguations,
load_rldata10000,
load_rldata10000_disambiguations,
)
from er_evaluation.datasets.patentsview import (load_pv_data,
load_pv_disambiguations)
from er_evaluation.datasets.rldata import (load_rldata500,
load_rldata500_disambiguations,
load_rldata10000,
load_rldata10000_disambiguations)

__all__ = [
"load_pv_data",
Expand Down
30 changes: 8 additions & 22 deletions er_evaluation/error_analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,15 @@
The key advantage of working with the record error table is that it allows sensitivity analyses to be performed. Since all cluster error metrics and representative performance estimators can be computed directly from the record error table, uncertainty regarding error rates can be propagated from the record error table into cluster error metrics and into performance estimates.
"""
from er_evaluation.error_analysis._cluster_error import (
count_extra,
count_missing,
error_indicator,
error_metrics,
expected_extra,
expected_missing,
expected_relative_extra,
expected_relative_missing,
expected_size_difference,
splitting_entropy,
)
count_extra, count_missing, error_indicator, error_metrics, expected_extra,
expected_missing, expected_relative_extra, expected_relative_missing,
expected_size_difference, splitting_entropy)
from er_evaluation.error_analysis._record_error import (
cluster_sizes_from_table,
error_indicator_from_table,
error_metrics_from_table,
expected_extra_from_table,
expected_missing_from_table,
expected_relative_extra_from_table,
expected_relative_missing_from_table,
expected_size_difference_from_table,
pred_cluster_sizes_from_table,
record_error_table,
)
cluster_sizes_from_table, error_indicator_from_table,
error_metrics_from_table, expected_extra_from_table,
expected_missing_from_table, expected_relative_extra_from_table,
expected_relative_missing_from_table, expected_size_difference_from_table,
pred_cluster_sizes_from_table, record_error_table)
from er_evaluation.error_analysis._subgroup_discovery import fit_dt_regressor

__all__ = [
Expand Down
13 changes: 4 additions & 9 deletions er_evaluation/error_analysis/_cluster_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,10 @@

from er_evaluation.data_structures import MembershipVector
from er_evaluation.error_analysis._record_error import (
error_indicator_from_table,
error_metrics_from_table,
expected_extra_from_table,
expected_missing_from_table,
expected_relative_extra_from_table,
expected_relative_missing_from_table,
expected_size_difference_from_table,
record_error_table,
)
error_indicator_from_table, error_metrics_from_table,
expected_extra_from_table, expected_missing_from_table,
expected_relative_extra_from_table, expected_relative_missing_from_table,
expected_size_difference_from_table, record_error_table)
from er_evaluation.utils import relevant_prediction_subset


Expand Down
28 changes: 11 additions & 17 deletions er_evaluation/estimators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,18 @@
**Note:** In order to obtain representative performance estimators, the set of predicted clusters given as an argument to estimator functions should cover the entire population of interest. Typically, this set of predicted clusters will be much larger than the set of sampled clusters.
"""
from er_evaluation.estimators._estimators import (
b_cubed_precision_estimator,
b_cubed_recall_estimator,
cluster_f_estimator,
cluster_precision_estimator,
cluster_recall_estimator,
estimates_table,
pairwise_f_estimator,
pairwise_precision_estimator,
pairwise_recall_estimator,
)
from er_evaluation.estimators._estimators import (b_cubed_precision_estimator,
b_cubed_recall_estimator,
cluster_f_estimator,
cluster_precision_estimator,
cluster_recall_estimator,
estimates_table,
pairwise_f_estimator,
pairwise_precision_estimator,
pairwise_recall_estimator)
from er_evaluation.estimators._summary_estimators import (
avg_cluster_size_estimator,
homonymy_rate_estimator,
matching_rate_estimator,
name_variation_estimator,
summary_estimates_table,
)
avg_cluster_size_estimator, homonymy_rate_estimator,
matching_rate_estimator, name_variation_estimator, summary_estimates_table)

__all__ = [
"b_cubed_precision_estimator",
Expand Down
17 changes: 6 additions & 11 deletions er_evaluation/estimators/_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,15 @@

from er_evaluation.data_structures import MembershipVector
from er_evaluation.error_analysis import record_error_table
from er_evaluation.estimators._utils import (
_parse_weights,
ratio_of_means_estimator,
validate_prediction_sample,
validate_weights,
)
from er_evaluation.estimators._utils import (_parse_weights,
ratio_of_means_estimator,
validate_prediction_sample,
validate_weights)
from er_evaluation.estimators.from_table import (
b_cubed_precision_estimator_from_table,
b_cubed_recall_estimator_from_table,
cluster_f_estimator_from_table,
b_cubed_recall_estimator_from_table, cluster_f_estimator_from_table,
cluster_precision_estimator_from_table,
cluster_recall_estimator_from_table,
pairwise_f_estimator_from_table,
)
cluster_recall_estimator_from_table, pairwise_f_estimator_from_table)
from er_evaluation.utils import expand_grid


Expand Down
10 changes: 4 additions & 6 deletions er_evaluation/estimators/_summary_estimators.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import pandas as pd

from er_evaluation.data_structures import MembershipVector
from er_evaluation.estimators._utils import (
_parse_weights,
ratio_of_means_estimator,
validate_prediction_sample,
validate_weights,
)
from er_evaluation.estimators._utils import (_parse_weights,
ratio_of_means_estimator,
validate_prediction_sample,
validate_weights)
from er_evaluation.summary import cluster_sizes
from er_evaluation.utils import expand_grid

Expand Down
14 changes: 6 additions & 8 deletions er_evaluation/estimators/from_table.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from er_evaluation.error_analysis import (
cluster_sizes_from_table,
error_indicator_from_table,
expected_missing_from_table,
expected_relative_extra_from_table,
expected_relative_missing_from_table,
expected_size_difference_from_table,
)
from er_evaluation.error_analysis import (cluster_sizes_from_table,
error_indicator_from_table,
expected_missing_from_table,
expected_relative_extra_from_table,
expected_relative_missing_from_table,
expected_size_difference_from_table)
from er_evaluation.estimators._utils import ratio_of_means_estimator


Expand Down
25 changes: 8 additions & 17 deletions er_evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,14 @@
- Records with NA cluster identifier in the reference or predicted clusterings are dropped.
- The metrics in this module do not provide representative performance estimates. They are only useful for comparing two clusterings, such as a predicted clustering and a reference clustering. For representative performance estimates, see the :mod:`er_evaluation.estimators` module.
"""
from er_evaluation.metrics._metrics import (
adjusted_rand_score,
b_cubed_f,
b_cubed_precision,
b_cubed_recall,
cluster_completeness,
cluster_f,
cluster_homogeneity,
cluster_precision,
cluster_recall,
cluster_v_measure,
metrics_table,
pairwise_f,
pairwise_precision,
pairwise_recall,
rand_score,
)
from er_evaluation.metrics._metrics import (adjusted_rand_score, b_cubed_f,
b_cubed_precision, b_cubed_recall,
cluster_completeness, cluster_f,
cluster_homogeneity,
cluster_precision, cluster_recall,
cluster_v_measure, metrics_table,
pairwise_f, pairwise_precision,
pairwise_recall, rand_score)

__all__ = [
"adjusted_rand_score",
Expand Down
10 changes: 4 additions & 6 deletions er_evaluation/metrics/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from scipy.special import comb

from er_evaluation.data_structures import MembershipVector
from er_evaluation.error_analysis import (
error_indicator,
expected_relative_extra_from_table,
expected_relative_missing_from_table,
record_error_table,
)
from er_evaluation.error_analysis import (error_indicator,
expected_relative_extra_from_table,
expected_relative_missing_from_table,
record_error_table)
from er_evaluation.summary import number_of_links
from er_evaluation.utils import expand_grid

Expand Down
27 changes: 10 additions & 17 deletions er_evaluation/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
"""
Helper Plots and Visualizations
"""
from er_evaluation.plots._dtree_plots import (
make_dt_regressor_plot,
plot_dt_regressor_sunburst,
plot_dt_regressor_tree,
plot_dt_regressor_treemap,
)
from er_evaluation.plots._dtree_plots import (make_dt_regressor_plot,
plot_dt_regressor_sunburst,
plot_dt_regressor_tree,
plot_dt_regressor_treemap)
from er_evaluation.plots._fairness import plot_performance_disparities
from er_evaluation.plots._plots import (
add_ests_to_summaries,
compare_plots,
plot_cluster_errors,
plot_cluster_sizes_distribution,
plot_comparison,
plot_entropy_curve,
plot_estimates,
plot_metrics,
plot_summaries,
)
from er_evaluation.plots._plots import (add_ests_to_summaries, compare_plots,
plot_cluster_errors,
plot_cluster_sizes_distribution,
plot_comparison, plot_entropy_curve,
plot_estimates, plot_metrics,
plot_summaries)

__all__ = [
"add_ests_to_summaries",
Expand Down
3 changes: 2 additions & 1 deletion er_evaluation/plots/_dtree_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import plotly.graph_objects as go

from er_evaluation.error_analysis import fit_dt_regressor
from er_evaluation.plots._dtree_data import build_sunburst_data, create_igraph_tree
from er_evaluation.plots._dtree_data import (build_sunburst_data,
create_igraph_tree)


def make_dt_regressor_plot(
Expand Down
17 changes: 9 additions & 8 deletions er_evaluation/plots/_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

from er_evaluation.data_structures import MembershipVector
from er_evaluation.error_analysis import error_metrics
from er_evaluation.estimators import (
estimates_table,
pairwise_precision_estimator,
pairwise_recall_estimator,
summary_estimates_table,
)
from er_evaluation.estimators import (estimates_table,
pairwise_precision_estimator,
pairwise_recall_estimator,
summary_estimates_table)
from er_evaluation.estimators._utils import _parse_weights
from er_evaluation.metrics import metrics_table, pairwise_f, pairwise_precision, pairwise_recall
from er_evaluation.summary import cluster_hill_number, cluster_sizes_distribution, summary_statistics
from er_evaluation.metrics import (metrics_table, pairwise_f,
pairwise_precision, pairwise_recall)
from er_evaluation.summary import (cluster_hill_number,
cluster_sizes_distribution,
summary_statistics)

DEFAULT_METRICS = {
"Pairwise precision": pairwise_precision,
Expand Down
20 changes: 8 additions & 12 deletions er_evaluation/summary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,14 @@
# 'homonymy_rate': 0.5,
# 'name_variation_rate': 0.5}
"""
from er_evaluation.summary._summary import (
average_cluster_size,
cluster_hill_number,
cluster_sizes,
cluster_sizes_distribution,
homonymy_rate,
matching_rate,
name_variation_rate,
number_of_clusters,
number_of_links,
summary_statistics,
)
from er_evaluation.summary._summary import (average_cluster_size,
cluster_hill_number, cluster_sizes,
cluster_sizes_distribution,
homonymy_rate, matching_rate,
name_variation_rate,
number_of_clusters,
number_of_links,
summary_statistics)

__all__ = [
"average_cluster_size",
Expand Down
11 changes: 4 additions & 7 deletions er_evaluation/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
"""
Utility Functions
"""
from er_evaluation.utils._utils import (
expand_grid,
load_module_parquet,
load_module_tsv,
relevant_prediction_subset,
sample_clusters,
)
from er_evaluation.utils._utils import (expand_grid, load_module_parquet,
load_module_tsv,
relevant_prediction_subset,
sample_clusters)

__all__ = [
"expand_grid",
Expand Down
7 changes: 5 additions & 2 deletions examples/subgroup_discovery/app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import numpy as np
import plotly.graph_objects as go
import streamlit as st
from data_prep import categorical_features, features_df, numerical_features, pred, reference
from data_prep import (categorical_features, features_df, numerical_features,
pred, reference)

from er_evaluation.error_analysis import error_indicator, expected_relative_extra, expected_relative_missing
from er_evaluation.error_analysis import (error_indicator,
expected_relative_extra,
expected_relative_missing)
from er_evaluation.estimators._utils import _parse_weights
from er_evaluation.plots import make_dt_regressor_plot

Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_structures/test_compress_memberships.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from er_evaluation.data_structures import compress_memberships


def test_keep_na_values_in_index():
    """Check the NA count of both series returned by ``compress_memberships``.

    Two small membership series with partially overlapping index labels are
    compressed together; each returned series must contain exactly three
    NA entries.
    """
    # Each input holds one explicit NA and shares only labels 0 and 4 with
    # the other input.
    left = pd.Series([pd.NA, 1, 2, 3], index=[-1, 0, 4, 7])
    right = pd.Series([1, pd.NA, 2, 3], index=[1, 0, 4, 8])

    cs1, cs2 = compress_memberships(left, right)

    # NOTE(review): presumably both outputs are aligned on the union of the
    # two indices (six labels), giving one explicit NA plus two missing
    # labels per series -- confirm against compress_memberships.
    na_count_first = cs1.isna().sum()
    assert na_count_first == 3
    na_count_second = cs2.isna().sum()
    assert na_count_second == 3

Loading

0 comments on commit d190ec5

Please sign in to comment.