From 468e1ee3d22069d23df99bab9e11945e5e72a15c Mon Sep 17 00:00:00 2001
From: golmschenk
Date: Mon, 5 Aug 2024 13:31:50 -0400
Subject: [PATCH] Switch the multiclass modules to be based on scores rather
 than probabilities

---
 src/qusi/experimental/model.py     |  6 ++++--
 src/qusi/internal/hadryss_model.py |  4 ++--
 src/qusi/internal/metric.py        | 27 +++++++++------------------
 3 files changed, 15 insertions(+), 22 deletions(-)

diff --git a/src/qusi/experimental/model.py b/src/qusi/experimental/model.py
index 258be56..d4e2f19 100644
--- a/src/qusi/experimental/model.py
+++ b/src/qusi/experimental/model.py
@@ -1,9 +1,11 @@
 """
 Neural network model related public interface.
 """
-from qusi.internal.hadryss_model import HadryssBinaryClassEndModule, HadryssMultiClassEndModule
+from qusi.internal.hadryss_model import HadryssBinaryClassEndModule, \
+    HadryssMultiClassScoreEndModule, HadryssMultiClassProbabilityEndModule
 
 __all__ = [
     'HadryssBinaryClassEndModule',
-    'HadryssMultiClassEndModule',
+    'HadryssMultiClassScoreEndModule',
+    'HadryssMultiClassProbabilityEndModule',
 ]
diff --git a/src/qusi/internal/hadryss_model.py b/src/qusi/internal/hadryss_model.py
index bab3717..00c77f2 100644
--- a/src/qusi/internal/hadryss_model.py
+++ b/src/qusi/internal/hadryss_model.py
@@ -239,7 +239,7 @@ def new(cls):
         return cls()
 
 
-class HadryssMultiClassEndModule(Module):
+class HadryssMultiClassProbabilityEndModule(Module):
     """
     A module for the end of the Hadryss model designed for multi classification.
     """
@@ -260,7 +260,7 @@ def new(cls, number_of_classes: int):
         return cls(number_of_classes)
 
 
-class HadryssMultiClassEndModule2(Module):  # TODO: Temporary test for Abhina.
+class HadryssMultiClassScoreEndModule(Module):
     """
     A module for the end of the Hadryss model designed for multi classification without softmax.
     """
diff --git a/src/qusi/internal/metric.py b/src/qusi/internal/metric.py
index d50a9ce..22860be 100644
--- a/src/qusi/internal/metric.py
+++ b/src/qusi/internal/metric.py
@@ -1,6 +1,6 @@
 import torch
 from torch import Tensor
-from torch.nn import NLLLoss, Module, CrossEntropyLoss
+from torch.nn import Module, CrossEntropyLoss, Softmax
 from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy
 
 
@@ -9,21 +9,6 @@ class CrossEntropyAlt(Module):
     def new(cls):
         return cls()
 
-    def __init__(self):
-        super().__init__()
-        self.nll_loss = NLLLoss()
-
-    def __call__(self, preds: Tensor, target: Tensor):
-        predicted_log_probabilities = torch.log(preds)
-        target_int = target.to(torch.int64)
-        cross_entropy = self.nll_loss(predicted_log_probabilities, target_int)
-        return cross_entropy
-
-class CrossEntropyAlt2(Module):  # TODO: Temporary test for Abhina.
-    @classmethod
-    def new(cls):
-        return cls()
-
     def __init__(self):
         super().__init__()
         self.cross_entropy = CrossEntropyLoss()
@@ -33,6 +18,7 @@ def __call__(self, preds: Tensor, target: Tensor):
         cross_entropy = self.cross_entropy(preds, target_int)
         return cross_entropy
 
+
 class MulticlassAUROCAlt(Module):
     @classmethod
     def new(cls, number_of_classes: int):
@@ -41,12 +27,15 @@ def new(cls, number_of_classes: int):
     def __init__(self, number_of_classes: int):
         super().__init__()
         self.multiclass_auroc = MulticlassAUROC(num_classes=number_of_classes)
+        self.softmax = Softmax()
 
     def __call__(self, preds: Tensor, target: Tensor):
+        probabilities = self.softmax(preds)
         target_int = target.to(torch.int64)
-        cross_entropy = self.multiclass_auroc(preds, target_int)
+        cross_entropy = self.multiclass_auroc(probabilities, target_int)
         return cross_entropy
 
+
 class MulticlassAccuracyAlt(Module):
     @classmethod
     def new(cls, number_of_classes: int):
@@ -55,8 +44,10 @@ def new(cls, number_of_classes: int):
     def __init__(self, number_of_classes: int):
         super().__init__()
         self.multiclass_accuracy = MulticlassAccuracy(num_classes=number_of_classes)
+        self.softmax = Softmax()
 
     def __call__(self, preds: Tensor, target: Tensor):
+        probabilities = self.softmax(preds)
         target_int = target.to(torch.int64)
-        cross_entropy = self.multiclass_accuracy(preds, target_int)
+        cross_entropy = self.multiclass_accuracy(probabilities, target_int)
         return cross_entropy