Skip to content

Commit

Permalink
add linear eval
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-santiago committed Aug 30, 2023
1 parent b9dbb3d commit d140f83
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions autoencoders/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ class Constants:
SRC = HERE.parents[0]
REPO = HERE.parents[1]
DATA = REPO.joinpath("data")
OUTPUTS = REPO.joinpath("outputs")
SEED = 43
43 changes: 43 additions & 0 deletions autoencoders/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json
import pathlib
from typing import Dict, Union

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from torchmetrics.classification import AUROC, MulticlassAccuracy

from autoencoders.data import get_mnist_dataset


def evaluate_linear(
module: pl.LightningModule, trainer: pl.Trainer, train_length: int = 8000, n_classes: int = 10
):
ckpt_path = trainer.checkpoint_callback.best_model_path
encoder = module.load_from_checkpoint(ckpt_path, map_location=torch.device("cpu"))

ds = get_mnist_dataset(train=False)
x_train = encoder.encode(ds.data[:train_length].unsqueeze(1) / 255).numpy()
y_train = ds.targets[:train_length].numpy()
x_test = encoder.encode(ds.data[train_length:].unsqueeze(1) / 255).numpy()
y_test = ds.targets[train_length:]

lr = LogisticRegression(max_iter=300)
lr.fit(x_train, y_train)
labels = lr.predict(x_test)
labels_ohe = F.one_hot(torch.tensor(labels)).float()

acc = MulticlassAccuracy(num_classes=n_classes)
auc = AUROC(task="multiclass", num_classes=n_classes)

# pull out .item() for metrics tensors as tensors are not json serializable
return {
"acc": round(acc(torch.tensor(labels), y_test).item(), 4),
"auc": round(auc(labels_ohe, y_test).item(), 4),
}


def to_json(results: Dict, filepath: Union[pathlib.Path, str]):
with open(filepath, "a") as fp:
fp.write(json.dumps(results) + "\n")
9 changes: 9 additions & 0 deletions autoencoders/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import hydra
import omegaconf

import autoencoders.constants
import autoencoders.eval
import autoencoders.utils

constants = autoencoders.constants.Constants()


@hydra.main(config_path="conf", config_name="config", version_base="1.3")
def main(cfg):
Expand All @@ -22,6 +26,11 @@ def main(cfg):
trainer.fit(model=model, train_dataloaders=train_dl)
trainer.checkpoint_callback.to_yaml()

results = autoencoders.eval.evaluate_linear(module=model, trainer=trainer)
autoencoders.eval.to_json(
results={cfg.model.name: results}, filepath=constants.OUTPUTS.joinpath("results.json")
)


if __name__ == "__main__":
main()

0 comments on commit d140f83

Please sign in to comment.