diff --git a/src/qusi/experimental/model.py b/src/qusi/experimental/model.py
index 258be56..d4e2f19 100644
--- a/src/qusi/experimental/model.py
+++ b/src/qusi/experimental/model.py
@@ -1,9 +1,11 @@
 """
 Neural network model related public interface.
 """
-from qusi.internal.hadryss_model import HadryssBinaryClassEndModule, HadryssMultiClassEndModule
+from qusi.internal.hadryss_model import HadryssBinaryClassEndModule, \
+    HadryssMultiClassScoreEndModule, HadryssMultiClassProbabilityEndModule

 __all__ = [
     'HadryssBinaryClassEndModule',
-    'HadryssMultiClassEndModule',
+    'HadryssMultiClassScoreEndModule',
+    'HadryssMultiClassProbabilityEndModule',
 ]
diff --git a/src/qusi/internal/hadryss_model.py b/src/qusi/internal/hadryss_model.py
index bab3717..00c77f2 100644
--- a/src/qusi/internal/hadryss_model.py
+++ b/src/qusi/internal/hadryss_model.py
@@ -239,7 +239,7 @@ def new(cls):
         return cls()


-class HadryssMultiClassEndModule(Module):
+class HadryssMultiClassProbabilityEndModule(Module):
     """
     A module for the end of the Hadryss model designed for multi classification.
     """
@@ -260,7 +260,7 @@ def new(cls, number_of_classes: int):
         return cls(number_of_classes)


-class HadryssMultiClassEndModule2(Module):  # TODO: Temporary test for Abhina.
+class HadryssMultiClassScoreEndModule(Module):
     """
     A module for the end of the Hadryss model designed for multi classification without softmax.
     """
diff --git a/src/qusi/internal/metric.py b/src/qusi/internal/metric.py
index d50a9ce..22860be 100644
--- a/src/qusi/internal/metric.py
+++ b/src/qusi/internal/metric.py
@@ -1,6 +1,6 @@
 import torch
 from torch import Tensor
-from torch.nn import NLLLoss, Module, CrossEntropyLoss
+from torch.nn import Module, CrossEntropyLoss, Softmax
 from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy


@@ -9,21 +9,6 @@ class CrossEntropyAlt(Module):
     def new(cls):
         return cls()

-    def __init__(self):
-        super().__init__()
-        self.nll_loss = NLLLoss()
-
-    def __call__(self, preds: Tensor, target: Tensor):
-        predicted_log_probabilities = torch.log(preds)
-        target_int = target.to(torch.int64)
-        cross_entropy = self.nll_loss(predicted_log_probabilities, target_int)
-        return cross_entropy
-
-class CrossEntropyAlt2(Module):  # TODO: Temporary test for Abhina.
-    @classmethod
-    def new(cls):
-        return cls()
-
     def __init__(self):
         super().__init__()
         self.cross_entropy = CrossEntropyLoss()
@@ -33,6 +18,7 @@ def __call__(self, preds: Tensor, target: Tensor):
         cross_entropy = self.cross_entropy(preds, target_int)
         return cross_entropy

+
 class MulticlassAUROCAlt(Module):
     @classmethod
     def new(cls, number_of_classes: int):
@@ -41,12 +27,15 @@ def new(cls, number_of_classes: int):
     def __init__(self, number_of_classes: int):
         super().__init__()
         self.multiclass_auroc = MulticlassAUROC(num_classes=number_of_classes)
+        self.softmax = Softmax()

     def __call__(self, preds: Tensor, target: Tensor):
+        probabilities = self.softmax(preds)
         target_int = target.to(torch.int64)
-        cross_entropy = self.multiclass_auroc(preds, target_int)
+        cross_entropy = self.multiclass_auroc(probabilities, target_int)
         return cross_entropy

+
 class MulticlassAccuracyAlt(Module):
     @classmethod
     def new(cls, number_of_classes: int):
@@ -55,8 +44,10 @@ def new(cls, number_of_classes: int):
     def __init__(self, number_of_classes: int):
         super().__init__()
         self.multiclass_accuracy = MulticlassAccuracy(num_classes=number_of_classes)
+        self.softmax = Softmax()

     def __call__(self, preds: Tensor, target: Tensor):
+        probabilities = self.softmax(preds)
         target_int = target.to(torch.int64)
-        cross_entropy = self.multiclass_accuracy(preds, target_int)
+        cross_entropy = self.multiclass_accuracy(probabilities, target_int)
         return cross_entropy
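
Note on the combined change: after this patch, the multi-class pieces share one contract. HadryssMultiClassScoreEndModule emits raw scores (logits); CrossEntropyAlt hands those scores straight to CrossEntropyLoss, which applies log-softmax internally (the removed log-then-NLLLoss path instead assumed probability inputs, where a zero probability makes torch.log produce -inf); and MulticlassAUROCAlt/MulticlassAccuracyAlt softmax the scores themselves before calling torchmetrics. Below is a minimal sketch of that contract using only torch and torchmetrics; the batch size and class count are illustrative, and dim=1 is spelled out here where the diff relies on Softmax's default dimension.

import torch
from torch.nn import CrossEntropyLoss, Softmax
from torchmetrics.classification import MulticlassAUROC

number_of_classes = 3
scores = torch.randn(8, number_of_classes)  # Raw logits, as a score end module emits.
target = torch.randint(0, number_of_classes, (8,)).to(torch.int64)

# Loss side: CrossEntropyLoss applies log-softmax internally, so it takes raw scores.
loss = CrossEntropyLoss()(scores, target)

# Metric side: convert scores to probabilities before calling into torchmetrics,
# mirroring what the updated metric wrappers do.
probabilities = Softmax(dim=1)(scores)
auroc = MulticlassAUROC(num_classes=number_of_classes)(probabilities, target)

print(loss, auroc)

One small follow-up worth considering: constructing torch.nn.Softmax() without a dim argument triggers PyTorch's implicit-dimension deprecation warning at call time, so passing dim=1 explicitly in the new metric wrappers would silence it and make the reduction axis unambiguous.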