Skip to content

Commit 92e7161

Browse files
committed
Add alternative cross entropy loss
1 parent 3dbc712 commit 92e7161

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

src/qusi/internal/hadryss_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,24 @@ def forward(self, x: Tensor) -> Tensor:
258258
@classmethod
259259
def new(cls, number_of_classes: int):
260260
return cls(number_of_classes)
261+
262+
263+
class HadryssMultiClassEndModule2(Module): # TODO: Temporary test for Abhina.
264+
"""
265+
A module for the end of the Hadryss model designed for multi classification without softmax.
266+
"""
267+
def __init__(self, number_of_classes: int):
268+
super().__init__()
269+
self.number_of_classes: int = number_of_classes
270+
self.prediction_layer = Conv1d(in_channels=20, out_channels=self.number_of_classes, kernel_size=1)
271+
self.soft_max = Softmax(dim=1)
272+
273+
def forward(self, x: Tensor) -> Tensor:
274+
x = self.prediction_layer(x)
275+
x = self.soft_max(x)
276+
x = torch.reshape(x, (-1, self.number_of_classes))
277+
return x
278+
279+
@classmethod
280+
def new(cls, number_of_classes: int):
281+
return cls(number_of_classes)

src/qusi/internal/metric.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from torch import Tensor
3-
from torch.nn import NLLLoss, Module
3+
from torch.nn import NLLLoss, Module, CrossEntropyLoss
44
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy
55

66

@@ -19,6 +19,20 @@ def __call__(self, preds: Tensor, target: Tensor):
1919
cross_entropy = self.nll_loss(predicted_log_probabilities, target_int)
2020
return cross_entropy
2121

22+
class CrossEntropyAlt2(Module): # TODO: Temporary test for Abhina.
23+
@classmethod
24+
def new(cls):
25+
return cls()
26+
27+
def __init__(self):
28+
super().__init__()
29+
self.cross_entropy = CrossEntropyLoss()
30+
31+
def __call__(self, preds: Tensor, target: Tensor):
32+
target_int = target.to(torch.int64)
33+
cross_entropy = self.cross_entropy(preds, target_int)
34+
return cross_entropy
35+
2236
class MulticlassAUROCAlt(Module):
2337
@classmethod
2438
def new(cls, number_of_classes: int):

0 commit comments

Comments
 (0)