Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aviary/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def evaluate(
metrics_str = " ".join(
f"{key} {val:<9.2f}" for key, val in avrg_metrics[target].items()
)
print(f"{action:>9}: {target} {metrics_str}")
print(f"{action:>9}: {target} N {len(data_loader):,} {metrics_str}")

return avrg_metrics

Expand Down
18 changes: 8 additions & 10 deletions aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,17 +316,13 @@ def train_ensemble(
when early stopping. Defaults to None.
verbose (bool, optional): Whether to show progress bars for each epoch.
"""
if isinstance(train_set, Subset):
train_set = train_set.dataset
if isinstance(val_set, Subset):
val_set = val_set.dataset

train_loader = DataLoader(train_set, **data_params)
print(f"Training on {len(train_set):,} samples")

if val_set is not None:
data_params.update({"batch_size": 16 * data_params["batch_size"]})
val_loader = DataLoader(val_set, **data_params)
print(f"Validating on {len(val_set):,} samples")
else:
val_loader = None

Expand Down Expand Up @@ -354,7 +350,13 @@ def train_ensemble(

for target, normalizer in normalizer_dict.items():
if normalizer is not None:
sample_target = Tensor(train_set.df[target].values)
if isinstance(train_set, Subset):
sample_target = Tensor(
train_set.dataset.df[target].iloc[train_set.indices].values
)
else:
sample_target = Tensor(train_set.df[target].values)

if not restart_params["resume"]:
normalizer.fit(sample_target)
print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}")
Expand Down Expand Up @@ -455,10 +457,6 @@ def results_multitask(
"------------Evaluate model on Test Set------------\n"
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
)

if isinstance(test_set, Subset):
test_set = test_set.dataset

test_loader = DataLoader(test_set, **data_params)
print(f"Testing on {len(test_set):,} samples")

Expand Down
1 change: 1 addition & 0 deletions tests/test_cgcnn_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,6 @@ def test_cgcnn_clf(df_matbench_phonons):

ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert ens_acc > 0.85
assert ens_roc_auc > 0.9
1 change: 1 addition & 0 deletions tests/test_cgcnn_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def test_cgcnn_regression(df_matbench_phonons):

mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert r2 > 0.7
assert mae < 150
assert rmse < 300
2 changes: 2 additions & 0 deletions tests/test_roost_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,7 @@ def test_roost_clf(df_matbench_phonons):

ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

assert len(logits) == ensemble
assert len(targets) == len(test_set) == len(test_idx)
assert ens_acc > 0.9
assert ens_roc_auc > 0.9
1 change: 1 addition & 0 deletions tests/test_roost_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def test_roost_regression(df_matbench_phonons):

mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert r2 > 0.7
assert mae < 150
assert rmse < 300
1 change: 1 addition & 0 deletions tests/test_wren_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,5 +146,6 @@ def test_wren_clf(df_matbench_phonons_wyckoff):

ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert ens_acc > 0.85
assert ens_roc_auc > 0.9
1 change: 1 addition & 0 deletions tests/test_wren_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def test_wren_regression(df_matbench_phonons_wyckoff):

mae, rmse, r2 = get_metrics(targets, y_ens, task).values()

assert len(targets) == len(test_set) == len(test_idx)
assert r2 > 0.7
assert mae < 150
assert rmse < 300
Loading