Skip to content

Commit

Permalink
Switch the multiclass modules to be based on scores rather than probabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Aug 5, 2024
1 parent 8dfc089 commit 468e1ee
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 22 deletions.
6 changes: 4 additions & 2 deletions src/qusi/experimental/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
Neural network model related public interface.
"""
from qusi.internal.hadryss_model import HadryssBinaryClassEndModule, HadryssMultiClassEndModule
from qusi.internal.hadryss_model import HadryssBinaryClassEndModule, \
HadryssMultiClassScoreEndModule, HadryssMultiClassProbabilityEndModule

__all__ = [
'HadryssBinaryClassEndModule',
'HadryssMultiClassEndModule',
'HadryssMultiClassScoreEndModule',
'HadryssMultiClassProbabilityEndModule',
]
4 changes: 2 additions & 2 deletions src/qusi/internal/hadryss_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def new(cls):
return cls()


class HadryssMultiClassEndModule(Module):
class HadryssMultiClassProbabilityEndModule(Module):
"""
A module for the end of the Hadryss model designed for multi classification.
"""
Expand All @@ -260,7 +260,7 @@ def new(cls, number_of_classes: int):
return cls(number_of_classes)


class HadryssMultiClassEndModule2(Module): # TODO: Temporary test for Abhina.
class HadryssMultiClassScoreEndModule(Module):
"""
A module for the end of the Hadryss model designed for multi classification without softmax.
"""
Expand Down
27 changes: 9 additions & 18 deletions src/qusi/internal/metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch import Tensor
from torch.nn import NLLLoss, Module, CrossEntropyLoss
from torch.nn import Module, CrossEntropyLoss, Softmax
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy


class CrossEntropyAlt(Module):
    """
    Cross entropy loss wrapper that accepts raw scores (logits) and a float
    target tensor, casting the target to int64 before delegating to
    ``torch.nn.CrossEntropyLoss`` (which expects integer class indices).
    """
    @classmethod
    def new(cls):
        """Alternate constructor, mirroring the project's ``new`` convention."""
        return cls()

    def __init__(self):
        super().__init__()
        self.cross_entropy = CrossEntropyLoss()

    def __call__(self, preds: Tensor, target: Tensor) -> Tensor:
        """
        :param preds: Raw class scores (logits); presumably (batch, classes) — TODO confirm.
        :param target: Class index targets; cast to int64 as CrossEntropyLoss requires.
        :return: The scalar cross entropy loss.
        """
        target_int = target.to(torch.int64)
        cross_entropy = self.cross_entropy(preds, target_int)
        return cross_entropy


class MulticlassAUROCAlt(Module):
    """
    Multiclass AUROC metric wrapper that accepts raw scores (logits): applies
    a softmax before delegating to ``torchmetrics`` ``MulticlassAUROC``.
    """
    @classmethod
    def new(cls, number_of_classes: int):
        """Alternate constructor, mirroring the project's ``new`` convention."""
        return cls(number_of_classes)

    def __init__(self, number_of_classes: int):
        super().__init__()
        self.multiclass_auroc = MulticlassAUROC(num_classes=number_of_classes)
        # dim=1 assumes (batch, classes) scores — same dim the bare Softmax()
        # implicitly picks for 2D input, but without the deprecation warning.
        self.softmax = Softmax(dim=1)

    def __call__(self, preds: Tensor, target: Tensor) -> Tensor:
        """
        :param preds: Raw class scores (logits); presumably (batch, classes) — TODO confirm.
        :param target: Class index targets; cast to int64 for torchmetrics.
        :return: The multiclass AUROC value.
        """
        probabilities = self.softmax(preds)
        target_int = target.to(torch.int64)
        # Renamed from `cross_entropy`: this holds an AUROC value, not a loss.
        auroc = self.multiclass_auroc(probabilities, target_int)
        return auroc


class MulticlassAccuracyAlt(Module):
    """
    Multiclass accuracy metric wrapper that accepts raw scores (logits):
    applies a softmax before delegating to ``torchmetrics``
    ``MulticlassAccuracy``.
    """
    @classmethod
    def new(cls, number_of_classes: int):
        """Alternate constructor, mirroring the project's ``new`` convention."""
        return cls(number_of_classes)

    def __init__(self, number_of_classes: int):
        super().__init__()
        self.multiclass_accuracy = MulticlassAccuracy(num_classes=number_of_classes)
        # dim=1 assumes (batch, classes) scores — same dim the bare Softmax()
        # implicitly picks for 2D input, but without the deprecation warning.
        self.softmax = Softmax(dim=1)

    def __call__(self, preds: Tensor, target: Tensor) -> Tensor:
        """
        :param preds: Raw class scores (logits); presumably (batch, classes) — TODO confirm.
        :param target: Class index targets; cast to int64 for torchmetrics.
        :return: The multiclass accuracy value.
        """
        probabilities = self.softmax(preds)
        target_int = target.to(torch.int64)
        # Renamed from `cross_entropy`: this holds an accuracy value, not a loss.
        accuracy = self.multiclass_accuracy(probabilities, target_int)
        return accuracy

0 comments on commit 468e1ee

Please sign in to comment.