Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions aviary/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import numpy as np
import pandas as pd
import torch
from torch.nn.functional import softmax
from tqdm import tqdm

from aviary.core import Normalizer
from aviary.core import Normalizer, sampled_softmax
from aviary.utils import get_metrics, print_walltime

if TYPE_CHECKING:
Expand Down Expand Up @@ -103,19 +104,30 @@ def make_ensemble_predictions(
pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"

if model.robust:
preds, aleat_log_std = preds.T
ale_col = (
f"{target_col}_aleatoric_std_{idx}"
if target_col
else f"aleatoric_std_{idx}"
)
df[pred_col] = preds
df[ale_col] = aleatoric_std = np.exp(aleat_log_std)
if task_type == "regression":
preds, aleat_log_std = preds.T
ale_col = (
f"{target_col}_aleatoric_std_{idx}"
if target_col
else f"aleatoric_std_{idx}"
)
df[pred_col] = preds
df[ale_col] = aleatoric_std = np.exp(aleat_log_std)
elif task_type == "classification":
# need to convert to tensor to use `sampled_softmax`
preds = torch.from_numpy(preds).to(device)
pre_logits, log_std = preds.chunk(2, dim=1)
logits = sampled_softmax(pre_logits, log_std)
df[pred_col] = logits.argmax(dim=1).cpu().numpy()
else:
df[pred_col] = preds
if task_type == "regression":
df[pred_col] = preds
else:
logits = softmax(preds, dim=1)
df[pred_col] = logits.argmax(dim=1).cpu().numpy()

# denormalize predictions if a normalizer was used during training
if "normalizer_dict" in checkpoint:
if checkpoint["normalizer_dict"][target_name] is not None:
assert task_type == "regression", "Normalization only takes place for regression."
normalizer = Normalizer.from_state_dict(
checkpoint["normalizer_dict"][target_name]
Expand Down
Loading