From f5bfcc503d96ae4e21114d46c97c42196ff7608f Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Sun, 19 Jan 2025 19:11:25 -0500
Subject: [PATCH 1/2] fix: address subset issue highlighted in #95

---
 aviary/utils.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/aviary/utils.py b/aviary/utils.py
index bbe35667..e32d23cb 100644
--- a/aviary/utils.py
+++ b/aviary/utils.py
@@ -316,11 +316,6 @@ def train_ensemble(
             when early stopping. Defaults to None.
         verbose (bool, optional): Whether to show progress bars for each epoch.
     """
-    if isinstance(train_set, Subset):
-        train_set = train_set.dataset
-    if isinstance(val_set, Subset):
-        val_set = val_set.dataset
-
     train_loader = DataLoader(train_set, **data_params)
     print(f"Training on {len(train_set):,} samples")

@@ -354,7 +349,13 @@ def train_ensemble(

         for target, normalizer in normalizer_dict.items():
             if normalizer is not None:
-                sample_target = Tensor(train_set.df[target].values)
+                if isinstance(train_set, Subset):
+                    sample_target = Tensor(
+                        train_set.dataset.df[target].iloc[train_set.indices].values
+                    )
+                else:
+                    sample_target = Tensor(train_set.df[target].values)
+
                 if not restart_params["resume"]:
                     normalizer.fit(sample_target)
                 print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}")
@@ -455,10 +456,6 @@ def results_multitask(
         "------------Evaluate model on Test Set------------\n"
         "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     )
-
-    if isinstance(test_set, Subset):
-        test_set = test_set.dataset
-
     test_loader = DataLoader(test_set, **data_params)
     print(f"Testing on {len(test_set):,} samples")


From 645e69469e34eac83da495e82030efd782142664 Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Sun, 19 Jan 2025 19:44:13 -0500
Subject: [PATCH 2/2] test: add checks that would have caught the test set subset issue

---
 aviary/core.py                     | 2 +-
 aviary/utils.py                    | 1 +
 tests/test_cgcnn_classification.py | 1 +
 tests/test_cgcnn_regression.py     | 1 +
 tests/test_roost_classification.py | 2 ++
 tests/test_roost_regression.py     | 1 +
 tests/test_wren_classification.py  | 1 +
 tests/test_wren_regression.py      | 1 +
 8 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/aviary/core.py b/aviary/core.py
index f8a3ee7c..7425c98c 100644
--- a/aviary/core.py
+++ b/aviary/core.py
@@ -346,7 +346,7 @@ def evaluate(
         metrics_str = " ".join(
             f"{key} {val:<9.2f}" for key, val in avrg_metrics[target].items()
         )
-        print(f"{action:>9}: {target} {metrics_str}")
+        print(f"{action:>9}: {target} N {len(data_loader):,} {metrics_str}")

     return avrg_metrics

diff --git a/aviary/utils.py b/aviary/utils.py
index e32d23cb..6da411b9 100644
--- a/aviary/utils.py
+++ b/aviary/utils.py
@@ -322,6 +322,7 @@ def train_ensemble(
     if val_set is not None:
         data_params.update({"batch_size": 16 * data_params["batch_size"]})
         val_loader = DataLoader(val_set, **data_params)
+        print(f"Validating on {len(val_set):,} samples")
     else:
         val_loader = None

diff --git a/tests/test_cgcnn_classification.py b/tests/test_cgcnn_classification.py
index e75d0e5c..0c8d10cb 100644
--- a/tests/test_cgcnn_classification.py
+++ b/tests/test_cgcnn_classification.py
@@ -136,5 +136,6 @@ def test_cgcnn_clf(df_matbench_phonons):

     ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

+    assert len(targets) == len(test_set) == len(test_idx)
     assert ens_acc > 0.85
     assert ens_roc_auc > 0.9
diff --git a/tests/test_cgcnn_regression.py b/tests/test_cgcnn_regression.py
index 51d14b92..f6c908c6 100644
--- a/tests/test_cgcnn_regression.py
+++ b/tests/test_cgcnn_regression.py
@@ -135,6 +135,7 @@ def test_cgcnn_regression(df_matbench_phonons):

     mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

+    assert len(targets) == len(test_set) == len(test_idx)
     assert r2 > 0.7
     assert mae < 150
     assert rmse < 300
diff --git a/tests/test_roost_classification.py b/tests/test_roost_classification.py
index deb500d7..86322bf1 100644
--- a/tests/test_roost_classification.py
+++ b/tests/test_roost_classification.py
@@ -138,5 +138,7 @@ def test_roost_clf(df_matbench_phonons):

     ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

+    assert len(logits) == ensemble
+    assert len(targets) == len(test_set) == len(test_idx)
     assert ens_acc > 0.9
     assert ens_roc_auc > 0.9
diff --git a/tests/test_roost_regression.py b/tests/test_roost_regression.py
index b24f2311..125b0031 100644
--- a/tests/test_roost_regression.py
+++ b/tests/test_roost_regression.py
@@ -137,6 +137,7 @@ def test_roost_regression(df_matbench_phonons):

     mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

+    assert len(targets) == len(test_set) == len(test_idx)
     assert r2 > 0.7
     assert mae < 150
     assert rmse < 300
diff --git a/tests/test_wren_classification.py b/tests/test_wren_classification.py
index d6f01f9c..c7db96e0 100644
--- a/tests/test_wren_classification.py
+++ b/tests/test_wren_classification.py
@@ -146,5 +146,6 @@ def test_wren_clf(df_matbench_phonons_wyckoff):

     ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

+    assert len(targets) == len(test_set) == len(test_idx)
     assert ens_acc > 0.85
     assert ens_roc_auc > 0.9
diff --git a/tests/test_wren_regression.py b/tests/test_wren_regression.py
index c080dd2a..a8ad73db 100644
--- a/tests/test_wren_regression.py
+++ b/tests/test_wren_regression.py
@@ -145,6 +145,7 @@ def test_wren_regression(df_matbench_phonons_wyckoff):

     mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

+    assert len(targets) == len(test_set) == len(test_idx)
     assert r2 > 0.7
     assert mae < 150
     assert rmse < 300
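
Below is a minimal standalone sketch, separate from the patch above, of why the first commit fits the normalizer through the Subset instead of unwrapping it. The DummyDataset class and the "target" column are hypothetical stand-ins for aviary's DataFrame-backed datasets; only torch.utils.data.Subset and its .dataset / .indices attributes are taken from the patch itself.

import pandas as pd
import torch
from torch.utils.data import Dataset, Subset


class DummyDataset(Dataset):
    """Toy stand-in for a DataFrame-backed dataset (hypothetical, for illustration)."""

    def __init__(self, df: pd.DataFrame) -> None:
        self.df = df

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        return self.df.iloc[idx]


df = pd.DataFrame({"target": [1.0, 2.0, 3.0, 4.0, 5.0]})
train_set = Subset(DummyDataset(df), indices=[0, 2])

# Unwrapping the Subset (the removed behaviour) silently pulls in all 5 samples,
# leaking rows outside the training split into the normalizer fit.
leaked_targets = torch.tensor(train_set.dataset.df["target"].values)
assert len(leaked_targets) == 5

# Indexing the wrapped DataFrame with the Subset's own indices, as the patched
# train_ensemble does, restricts the fit to the 2 training samples.
subset_targets = torch.tensor(
    train_set.dataset.df["target"].iloc[train_set.indices].values
)
assert len(subset_targets) == len(train_set) == 2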