diff --git a/aviary/predict.py b/aviary/predict.py
index cce51734..d56c33ab 100644
--- a/aviary/predict.py
+++ b/aviary/predict.py
@@ -7,9 +7,10 @@
 import numpy as np
 import pandas as pd
 import torch
+from torch.nn.functional import softmax
 from tqdm import tqdm
 
-from aviary.core import Normalizer
+from aviary.core import Normalizer, sampled_softmax
 from aviary.utils import get_metrics, print_walltime
 
 if TYPE_CHECKING:
@@ -103,19 +104,30 @@ def make_ensemble_predictions(
         pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
 
         if model.robust:
-            preds, aleat_log_std = preds.T
-            ale_col = (
-                f"{target_col}_aleatoric_std_{idx}"
-                if target_col
-                else f"aleatoric_std_{idx}"
-            )
-            df[pred_col] = preds
-            df[ale_col] = aleatoric_std = np.exp(aleat_log_std)
+            if task_type == "regression":
+                preds, aleat_log_std = preds.T
+                ale_col = (
+                    f"{target_col}_aleatoric_std_{idx}"
+                    if target_col
+                    else f"aleatoric_std_{idx}"
+                )
+                df[pred_col] = preds
+                df[ale_col] = aleatoric_std = np.exp(aleat_log_std)
+            elif task_type == "classification":
+                # need to convert to tensor to use `sampled_softmax`
+                preds = torch.from_numpy(preds).to(device)
+                pre_logits, log_std = preds.chunk(2, dim=1)
+                logits = sampled_softmax(pre_logits, log_std)
+                df[pred_col] = logits.argmax(dim=1).cpu().numpy()
         else:
-            df[pred_col] = preds
+            if task_type == "regression":
+                df[pred_col] = preds
+            else:
+                logits = softmax(preds, dim=1)
+                df[pred_col] = logits.argmax(dim=1).cpu().numpy()
 
         # denormalize predictions if a normalizer was used during training
-        if "normalizer_dict" in checkpoint:
+        if checkpoint["normalizer_dict"][target_name] is not None:
             assert task_type == "regression", "Normalization only takes place for regression."
             normalizer = Normalizer.from_state_dict(
                 checkpoint["normalizer_dict"][target_name]
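
Note on the new classification branch: `aviary.core.sampled_softmax` is only imported here, its body is not part of this diff. The sketch below is an assumption about what such a helper plausibly does, a Monte Carlo average of softmaxes over Gaussian-perturbed pre-logits, consistent with how the branch calls it with `(pre_logits, log_std)` and then takes `argmax`. The name `sampled_softmax_sketch`, the `samples` default, and the demo tensors are illustrative, not the library's actual code.

```python
import torch
from torch.nn.functional import softmax


def sampled_softmax_sketch(
    pre_logits: torch.Tensor, log_std: torch.Tensor, samples: int = 10
) -> torch.Tensor:
    """Estimate class probabilities for a robust classifier by sampling.

    Draws `samples` Gaussian perturbations of the pre-logits using the predicted
    per-class log standard deviations, applies softmax to each draw, and averages.
    """
    std = log_std.exp().repeat_interleave(samples, dim=0)  # (n * samples, n_classes)
    noisy = pre_logits.repeat_interleave(samples, dim=0) + torch.randn_like(std) * std
    probs = softmax(noisy, dim=1).view(len(pre_logits), samples, -1)
    return probs.mean(dim=1)  # (n, n_classes)


# usage mirroring the diff's classification branch (hypothetical shapes)
preds = torch.randn(4, 6)  # 4 samples, 3 classes -> (pre_logits | log_std)
pre_logits, log_std = preds.chunk(2, dim=1)
probs = sampled_softmax_sketch(pre_logits, log_std)
pred_classes = probs.argmax(dim=1)
```

In this sketch, averaging probabilities over draws (rather than softmaxing the mean pre-logits) is what lets the predicted `log_std` propagate aleatoric uncertainty into the class probabilities.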