Skip to content

Commit

Permalink
Merge pull request #50 from CompRhys/wrenformer-ensemble-preds
Browse files Browse the repository at this point in the history
Wrenformer ensemble predictions
  • Loading branch information
janosh authored Jun 24, 2022
2 parents 79417d5 + 4d0470f commit 5d012d7
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 22 deletions.
14 changes: 7 additions & 7 deletions aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,8 +769,8 @@ def save_results_dict(


def get_metrics(
targets: np.ndarray,
predictions: np.ndarray,
targets: np.ndarray | pd.Series,
predictions: np.ndarray | pd.Series,
type: Literal["regression", "classification"],
prec: int = 4,
) -> dict:
Expand All @@ -791,17 +791,17 @@ def get_metrics(
metrics = {}

if type == "regression":
metrics["mae"] = np.abs(targets - predictions).mean()
metrics["rmse"] = ((targets - predictions) ** 2).mean() ** 0.5
metrics["r2"] = r2_score(targets, predictions)
metrics["MAE"] = np.abs(targets - predictions).mean()
metrics["RMSE"] = ((targets - predictions) ** 2).mean() ** 0.5
metrics["R2"] = r2_score(targets, predictions)
elif type == "classification":
pred_labels = predictions.argmax(axis=1)

metrics["accuracy"] = accuracy_score(targets, pred_labels)
metrics["balanced_accuracy"] = balanced_accuracy_score(targets, pred_labels)
metrics["f1"] = f1_score(targets, pred_labels)
metrics["F1"] = f1_score(targets, pred_labels)
class1_probas = predictions[:, 1]
metrics["rocauc"] = roc_auc_score(targets, class1_probas)
metrics["ROCAUC"] = roc_auc_score(targets, class1_probas)

metrics = {key: round(float(val), prec) for key, val in metrics.items()}

Expand Down
18 changes: 12 additions & 6 deletions aviary/wrenformer/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ def get_composition_embedding(formula: str) -> Tensor:

def df_to_in_mem_dataloader(
df: pd.DataFrame,
target_col: str,
input_col: str = "wyckoff",
id_col: str = "material_id",
target_col: str = None,
id_col: str = None,
embedding_type: Literal["wyckoff", "composition"] = "wyckoff",
device: str = None,
**kwargs,
Expand All @@ -137,10 +137,12 @@ def df_to_in_mem_dataloader(
Args:
df (pd.DataFrame): Expected to have columns input_col, target_col, id_col.
target_col (str): Column name holding the target values.
input_col (str): Column name holding the input values (Aflow Wyckoff labels or composition
strings) from which initial embeddings will be constructed. Defaults to "wyckoff".
id_col (str): Column name holding material identifiers. Defaults to "material_id".
target_col (str): Column name holding the target values. Defaults to None. Only leave this
empty if making predictions since target tensor will be set to list of Nones.
id_col (str): Column name holding sample IDs. Defaults to None. If None, IDs will be
the dataframe index.
embedding_type ('wyckoff' | 'composition'): Defaults to "wyckoff".
device (str): torch.device to load tensors onto. Defaults to
"cuda" if torch.cuda.is_available() else "cpu".
Expand All @@ -162,14 +164,18 @@ def df_to_in_mem_dataloader(
if embedding_type == "wyckoff"
else get_composition_embedding
)
targets = torch.tensor(df[target_col], device=device)
targets = (
torch.tensor(df[target_col].to_numpy(), device=device)
if target_col in df
else np.empty(len(df))
)
if targets.dtype == torch.bool:
targets = targets.long() # convert binary classification targets to 0 and 1
inputs = np.empty(len(initial_embeddings), dtype=object)
for idx, tensor in enumerate(initial_embeddings):
inputs[idx] = tensor.to(device)

ids = df[id_col].to_numpy()
ids = (df[id_col] if id_col in df else df.index).to_numpy()
data_loader = InMemoryDataLoader(
[inputs, targets, ids], collate_fn=collate_batch, **kwargs
)
Expand Down
125 changes: 123 additions & 2 deletions aviary/wrenformer/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
from __future__ import annotations

import json
import time
from contextlib import contextmanager
from typing import Generator, Literal

import pandas as pd
import torch
from tqdm import tqdm

from aviary.core import BaseModelClass
from aviary.utils import get_metrics
from aviary.wrenformer.data import df_to_in_mem_dataloader
from aviary.wrenformer.model import Wrenformer

__author__ = "Janosh Riebesell"
__date__ = "2022-05-10"


def _int_keys(dct: dict) -> dict:
# JSON stringifies all dict keys during serialization and does not revert
Expand Down Expand Up @@ -45,14 +59,14 @@ def merge_json_on_disk(
pass

def non_serializable_handler(obj: object) -> str:
# replace functions and classes in dct with string indicating a non-serializable type
# replace functions and classes in dct with string indicating it's a non-serializable type
return f"<not serializable: {type(obj).__qualname__}>"

with open(file_path, "w") as file:
default = (
non_serializable_handler if on_non_serializable == "annotate" else None
)
json.dump(dct, file, default=default)
json.dump(dct, file, default=default, indent=2)


@contextmanager
Expand All @@ -78,3 +92,110 @@ def print_walltime(
finally:
run_time = time.perf_counter() - start_time
print(f"{end_desc} took {run_time:.2f} sec")


def make_ensemble_predictions(
checkpoint_paths: list[str],
df: pd.DataFrame,
target_col: str = None,
input_col: str = "wyckoff",
model_class: type[BaseModelClass] = Wrenformer,
device: str = None,
print_metrics: bool = True,
warn_target_mismatch: bool = False,
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
"""Make predictions using an ensemble of Wrenformer models.
Args:
checkpoint_paths (list[str]): File paths to model checkpoints created with torch.save().
df (pd.DataFrame): Dataframe to make predictions on. Will be returned with additional
columns holding model predictions (and uncertainties for robust models) for each
model checkpoint.
target_col (str): Column holding target values. Defaults to None. If None, will not print
performance metrics.
input_col (str, optional): Column holding input values. Defaults to 'wyckoff'.
device (str, optional): torch.device. Defaults to "cuda" if torch.cuda.is_available()
else "cpu".
print_metrics (bool, optional): Whether to print performance metrics. Defaults to True
if target_col is not None.
warn_target_mismatch (bool, optional): Whether to warn if target_col != target_name from
model checkpoint. Defaults to False.
Returns:
pd.DataFrame: Input dataframe with added columns for model and ensemble predictions. If
target_col is not None, returns a 2nd dataframe containing model and ensemble metrics.
"""
# TODO: Add support for predicting all tasks a multi-task models was trained on. Currently only
# handles single targets.
device = device or ("cuda" if torch.cuda.is_available() else "cpu")

data_loader = df_to_in_mem_dataloader(
df=df,
target_col=target_col,
input_col=input_col,
batch_size=512,
embedding_type="wyckoff",
)

print(f"Predicting with {len(checkpoint_paths):,} model checkpoints(s)")

for idx, checkpoint_path in enumerate(tqdm(checkpoint_paths), 1):
checkpoint = torch.load(checkpoint_path, map_location=device)

model_params = checkpoint["model_params"]
target_name, task_type = list(model_params["task_dict"].items())[0]
assert task_type in ("regression", "classification"), f"invalid {task_type = }"
if target_name != target_col and warn_target_mismatch:
print(
f"Warning: {target_col = } does not match {target_name = } in checkpoint. "
"If this is not by accident, disable this warning by passing warn_target=False."
)
model = model_class(**model_params)
model.to(device)

model.load_state_dict(checkpoint["model_state"])

with torch.no_grad():
predictions = torch.cat([model(*inputs)[0] for inputs, *_ in data_loader])

if model.robust:
predictions, aleat_log_std = predictions.chunk(2, dim=1)
aleat_std = aleat_log_std.exp().cpu().numpy().squeeze()
df[f"aleatoric_std_{idx}"] = aleat_std.tolist()

predictions = predictions.cpu().numpy().squeeze()
pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
df[pred_col] = predictions.tolist()

df_preds = df.filter(regex=r"_pred_\d")
df[f"{target_col}_pred_ens"] = ensemble_preds = df_preds.mean(axis=1)
df[f"{target_col}_epistemic_std_ens"] = epistemic_std = df_preds.std(axis=1)

if df.columns.str.startswith("aleatoric_std_").sum() > 0:
aleatoric_std = df.filter(regex=r"aleatoric_std_\d").mean(axis=1)
df[f"{target_col}_aleatoric_std_ens"] = aleatoric_std
df[f"{target_col}_total_std_ens"] = (
epistemic_std**2 + aleatoric_std**2
) ** 0.5

if target_col and print_metrics:
targets = df[target_col]
all_model_metrics = pd.DataFrame(
[
get_metrics(targets, df_preds[pred_col], task_type)
for pred_col in df_preds
],
index=df_preds.columns,
)

print("\nSingle model performance:")
print(all_model_metrics.describe().round(4).loc[["mean", "std"]])

ensemble_metrics = get_metrics(targets, ensemble_preds, task_type)

print("\nEnsemble performance:")
for key, val in ensemble_metrics.items():
print(f"{key:<8} {val:.3}")
return df, all_model_metrics

return df
81 changes: 81 additions & 0 deletions examples/mp_wbm/use_trained_wrenformer_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

import os
from glob import glob

import pandas as pd
import wandb

from aviary import ROOT
from aviary.wrenformer.utils import make_ensemble_predictions

__author__ = "Janosh Riebesell"
__date__ = "2022-06-23"

"""Script that downloads checkpoints for an ensemble of Wrenformer models trained on
the MP+WBM dataset and makes predictions on the test set, then prints ensemble metrics.
"""


data_path = f"{ROOT}/datasets/2022-06-09-mp+wbm.json.gz"
target_col = "e_form"
test_size = 0.05
df = pd.read_json(data_path)
# shuffle with same random seed as in run_wrenformer() to get identical train/test split
df = df.sample(frac=1, random_state=0)
train_df = df.sample(frac=1 - test_size, random_state=0)
test_df = df.drop(train_df.index)


load_checkpoints_from_wandb = True

if load_checkpoints_from_wandb:
wandb.login()
wandb_api = wandb.Api()

runs = wandb_api.runs("aviary/mp-wbm", filters={"tags": {"$in": ["ensemble-id-2"]}})

print(
f"Loading checkpoints for the following run IDs:\n{', '.join(run.id for run in runs)}\n"
)

checkpoint_paths: list[str] = []
for run in runs:
run_path = "/".join(run.path)
checkpoint_dir = f"{ROOT}/.wandb_checkpoints/{run_path}"
os.makedirs(checkpoint_dir, exist_ok=True)

checkpoint_path = f"{checkpoint_dir}/checkpoint.pth"
checkpoint_paths.append(checkpoint_path)

# download checkpoint from wandb if not already present
if os.path.isfile(checkpoint_path):
continue
wandb.restore("checkpoint.pth", root=checkpoint_dir, run_path=run_path)
else:
# load checkpoints from local run dirs
checkpoint_paths = glob(
f"{ROOT}/examples/mp_wbm/job-logs/wandb/run-20220621_13*/files/checkpoint.pth"
)

print(f"Predicting with {len(checkpoint_paths):,} model checkpoints(s)")

test_df, ensemble_metrics = make_ensemble_predictions(
checkpoint_paths, df=test_df, target_col=target_col
)

test_df.to_csv(f"{ROOT}/examples/mp_wbm/ensemble-predictions.csv")


# print output:
# Predicting with 10 model checkpoints(s)
#
# Single model performance:
# MAE RMSE R2
# mean 0.0369 0.1218 0.9864
# std 0.0005 0.0014 0.0003
#
# Ensemble performance:
# MAE 0.0308
# RMSE 0.118
# R2 0.987
7 changes: 4 additions & 3 deletions examples/wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def run_wrenformer(

# save model checkpoint
if checkpoint is not None:
state_dict = {
checkpoint_dict = {
"model_params": model_params,
"model_state": inference_model.state_dict(),
"optimizer_state": optimizer_instance.state_dict(),
Expand All @@ -327,16 +327,17 @@ def run_wrenformer(
"metrics": test_metrics,
"run_name": run_name,
"normalizer_dict": normalizer_dict,
"run_params": run_params,
}
if checkpoint == "local":
os.makedirs(f"{ROOT}/models", exist_ok=True)
checkpoint_path = f"{ROOT}/models/{timestamp}-{run_name}.pth"
torch.save(state_dict, checkpoint_path)
torch.save(checkpoint_dict, checkpoint_path)
if checkpoint == "wandb":
assert (
wandb_project and wandb.run is not None
), "can't save model checkpoint to Weights and Biases, wandb.run is None"
torch.save(state_dict, f"{wandb.run.dir}/checkpoint.pth")
torch.save(checkpoint_dict, f"{wandb.run.dir}/checkpoint.pth")

# record test set metrics and scatter/ROC plots to wandb
if wandb_project:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def test_wrenformer_regression(df_matbench_phonons_wyckoff):
epochs=30,
)

assert test_metrics["mae"] < 260, test_metrics
assert test_metrics["rmse"] < 420, test_metrics
assert test_metrics["r2"] > 0.1, test_metrics
assert test_metrics["MAE"] < 260, test_metrics
assert test_metrics["RMSE"] < 420, test_metrics
assert test_metrics["R2"] > 0.1, test_metrics


def test_wrenformer_classification(df_matbench_phonons_wyckoff):
Expand All @@ -36,4 +36,4 @@ def test_wrenformer_classification(df_matbench_phonons_wyckoff):
)

assert test_metrics["accuracy"] > 0.7, test_metrics
assert test_metrics["rocauc"] > 0.8, test_metrics
assert test_metrics["ROCAUC"] > 0.8, test_metrics

0 comments on commit 5d012d7

Please sign in to comment.