Skip to content

Commit

Permalink
Add experimental multiclass configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Jun 12, 2024
1 parent 3a88036 commit 50e3259
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 18 deletions.
11 changes: 11 additions & 0 deletions src/qusi/experimental/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
Metric related public interface.
"""
from qusi.internal.metric import CrossEntropyAlt, MulticlassAccuracyAlt, MulticlassAUROCAlt

__all__ = [
'CrossEntropyAlt',
'MulticlassAccuracyAlt',
'MulticlassAUROCAlt',
]

9 changes: 9 additions & 0 deletions src/qusi/experimental/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Neural network model related public interface.
"""
from qusi.internal.hadryss_model import HadryssBinaryClassEndModule, HadryssMultiClassEndModule

__all__ = [
'HadryssBinaryClassEndModule',
'HadryssMultiClassEndModule',
]
57 changes: 47 additions & 10 deletions src/qusi/internal/hadryss_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
MaxPool1d,
Module,
Sigmoid,
Softmax,
)
from typing_extensions import Self

Expand All @@ -22,11 +23,10 @@ class Hadryss(Module):
A 1D convolutional neural network model for light curve data that will auto-size itself for a given input light
curve length.
"""
def __init__(self, *, input_length: int):
def __init__(self, *, input_length: int, end_module: Module):
super().__init__()
self.input_length: int = input_length
pooling_sizes, dense_size = self.determine_block_pooling_sizes_and_dense_size()
self.sigmoid = Sigmoid()
self.block0 = LightCurveNetworkBlock(
input_channels=1,
output_channels=8,
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(self, *, input_length: int):
self.block10 = LightCurveNetworkBlock(
input_channels=20, output_channels=20, kernel_size=1, pooling_size=1
)
self.prediction_layer = Conv1d(in_channels=20, out_channels=1, kernel_size=1)
self.end_module = end_module

def forward(self, x: Tensor) -> Tensor:
x = x.reshape([-1, 1, self.input_length])
Expand All @@ -121,20 +121,21 @@ def forward(self, x: Tensor) -> Tensor:
x = self.block8(x)
x = self.block9(x)
x = self.block10(x)
x = self.prediction_layer(x)
x = self.sigmoid(x)
x = torch.reshape(x, (-1,))
x = self.end_module(x)
return x

@classmethod
def new(cls, input_length: int = 3500) -> Self:
def new(cls, input_length: int = 3500, end_module: Module | None = None) -> Self:
"""
Creates a new Hadryss model.
:param input_length: The length of the input to auto-size the network to.
:param end_module: The end module of the network. Defaults to a `HadryssBinaryClassEndModule`.
:return: The model.
"""
instance = cls(input_length=input_length)
if end_module is None:
end_module = HadryssBinaryClassEndModule.new()
instance = cls(input_length=input_length, end_module=end_module)
return instance

def determine_block_pooling_sizes_and_dense_size(self) -> (list[int], int):
Expand Down Expand Up @@ -211,7 +212,43 @@ def forward(self, x):
if not self.spatial:
old_shape = x.shape
x = torch.reshape(x, [-1, torch.prod(torch.tensor(old_shape[1:]))])
x = self.batch_normalization(x)
if not self.spatial:
x = self.batch_normalization(x)
x = torch.reshape(x, old_shape)
else:
x = self.batch_normalization(x)
return x


class HadryssBinaryClassEndModule(Module):
def __init__(self):
super().__init__()
self.prediction_layer = Conv1d(in_channels=20, out_channels=1, kernel_size=1)
self.sigmoid = Sigmoid()

def forward(self, x: Tensor) -> Tensor:
x = self.prediction_layer(x)
x = self.sigmoid(x)
x = torch.reshape(x, (-1,))
return x

@classmethod
def new(cls):
return cls()


class HadryssMultiClassEndModule(Module):
def __init__(self, number_of_classes: int):
super().__init__()
self.number_of_classes: int = number_of_classes
self.prediction_layer = Conv1d(in_channels=20, out_channels=self.number_of_classes, kernel_size=1)
self.soft_max = Softmax(dim=1)

def forward(self, x: Tensor) -> Tensor:
x = self.prediction_layer(x)
x = self.soft_max(x)
x = torch.reshape(x, (-1, self.number_of_classes))
return x

@classmethod
def new(cls, number_of_classes: int):
return cls(number_of_classes)
48 changes: 48 additions & 0 deletions src/qusi/internal/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
from torch import Tensor
from torch.nn import NLLLoss, Module
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy


class CrossEntropyAlt(Module):
@classmethod
def new(cls):
return cls()

def __init__(self):
super().__init__()
self.nll_loss = NLLLoss()

def __call__(self, preds: Tensor, target: Tensor):
predicted_log_probabilities = torch.log(preds)
target_int = target.to(torch.int64)
cross_entropy = self.nll_loss(predicted_log_probabilities, target_int)
return cross_entropy

class MulticlassAUROCAlt(Module):
@classmethod
def new(cls, number_of_classes: int):
return cls(number_of_classes=number_of_classes)

def __init__(self, number_of_classes: int):
super().__init__()
self.multiclass_auroc = MulticlassAUROC(num_classes=number_of_classes)

def __call__(self, preds: Tensor, target: Tensor):
target_int = target.to(torch.int64)
cross_entropy = self.multiclass_auroc(preds, target_int)
return cross_entropy

class MulticlassAccuracyAlt(Module):
@classmethod
def new(cls, number_of_classes: int):
return cls(number_of_classes=number_of_classes)

def __init__(self, number_of_classes: int):
super().__init__()
self.multiclass_accuracy = MulticlassAccuracy(num_classes=number_of_classes)

def __call__(self, preds: Tensor, target: Tensor):
target_int = target.to(torch.int64)
cross_entropy = self.multiclass_accuracy(preds, target_int)
return cross_entropy
49 changes: 46 additions & 3 deletions src/qusi/internal/toy_light_curve_collection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import math

import random
from functools import partial

from pathlib import Path

import numpy as np
from scipy import signal

from qusi.internal.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset
from qusi.internal.light_curve import LightCurve
from qusi.internal.light_curve_collection import (
LightCurveObservationCollection,
create_constant_label_for_path_function, LightCurveCollection,
)
from qusi.internal.light_curve_dataset import LightCurveDataset
from qusi.internal.light_curve_dataset import LightCurveDataset, \
default_light_curve_observation_post_injection_transform


class ToyLightCurve:
Expand Down Expand Up @@ -61,7 +68,7 @@ def toy_flat_light_curve_load_times_and_fluxes(_path: Path) -> (np.ndarray, np.n


def toy_sine_wave_light_curve_load_times_and_fluxes(
_path: Path,
_path: Path,
) -> (np.ndarray, np.ndarray):
"""
Loads a sine wave toy light curve.
Expand Down Expand Up @@ -115,4 +122,40 @@ def get_toy_finite_light_curve_dataset() -> FiniteStandardLightCurveDataset:
get_toy_sine_wave_light_curve_collection(),
get_toy_flat_light_curve_collection(),
]
)
)


def get_square_wave_light_curve_observation_collection() -> LightCurveObservationCollection:
return LightCurveObservationCollection.new(
get_paths_function=toy_light_curve_get_paths_function,
load_times_and_fluxes_from_path_function=square_wave_light_curve_load_times_and_fluxes,
load_label_from_path_function=create_constant_label_for_path_function(2),
)


square_wave_random_generator = random.Random()


def square_wave_light_curve_load_times_and_fluxes(_path: Path) -> (np.ndarray, np.ndarray):
"""
Loads a square wave light curve.
"""
length = 100
number_of_cycles = square_wave_random_generator.random() + 1 * 9
linear_space = np.linspace(0, 1, length, endpoint=False)
phases = math.tau * number_of_cycles * linear_space
times = np.arange(length, dtype=np.float32)
fluxes = signal.square(phases)
return times, fluxes


def get_toy_multi_class_light_curve_dataset() -> LightCurveDataset:
return LightCurveDataset.new(
standard_light_curve_collections=[
get_toy_flat_light_curve_observation_collection(),
get_toy_sine_wave_light_curve_observation_collection(),
get_square_wave_light_curve_observation_collection(),
],
post_injection_transform=partial(default_light_curve_observation_post_injection_transform,
length=100, number_of_classes=3, randomize=False)
)
27 changes: 22 additions & 5 deletions tests/unit_tests/test_hydryss_model.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,48 @@
import torch

from qusi.internal.hadryss_model import Hadryss
from qusi.internal.hadryss_model import Hadryss, HadryssBinaryClassEndModule, \
HadryssMultiClassEndModule


def test_lengths_give_correct_output_size():
hadryss50 = Hadryss(input_length=50)
hadryss50 = Hadryss.new(input_length=50)

output50 = hadryss50(torch.arange(50, dtype=torch.float32).reshape([1, 50]))

assert output50.shape == torch.Size([1])

hadryss1000 = Hadryss(input_length=1000)
hadryss1000 = Hadryss.new(input_length=1000)

output1000 = hadryss1000(torch.arange(1000, dtype=torch.float32).reshape([1, 1000]))

assert output1000.shape == torch.Size([1])

hadryss3673 = Hadryss(input_length=3673)
hadryss3673 = Hadryss.new(input_length=3673)

output3673 = hadryss3673(torch.arange(3673, dtype=torch.float32).reshape([1, 3673]))

assert output3673.shape == torch.Size([1])

hadryss100000 = Hadryss(input_length=100000)
hadryss100000 = Hadryss.new(input_length=100000)

output100000 = hadryss100000(
torch.arange(100000, dtype=torch.float32).reshape([1, 100000])
)

assert output100000.shape == torch.Size([1])


def test_binary_classification_end_module_produces_expected_shape():
model = Hadryss.new(input_length=100, end_module=HadryssBinaryClassEndModule.new())

output = model(torch.arange(7 * 100, dtype=torch.float32).reshape([7, 100]))

assert output.shape == torch.Size([7])


def test_multi_class_classification_end_module_produces_expected_shape():
model = Hadryss.new(input_length=100, end_module=HadryssMultiClassEndModule.new(number_of_classes=3))

output = model(torch.arange(7 * 100, dtype=torch.float32).reshape([7, 100]))

assert output.shape == torch.Size([7, 3])

0 comments on commit 50e3259

Please sign in to comment.