diff --git a/minerva/models/nets/cnn_ha_etal.py b/minerva/models/nets/cnn_ha_etal.py deleted file mode 100644 index 9e6d88c..0000000 --- a/minerva/models/nets/cnn_ha_etal.py +++ /dev/null @@ -1,300 +0,0 @@ -from typing import List, Tuple -import torch -from torchmetrics import Accuracy - -from minerva.models.nets.base import SimpleSupervisedModel - - -# Implementation of Multi-modal Convolutional Neural Networks for Activity -# Recognition, from Ha, Yu, and Choi. -# https://ieeexplore.ieee.org/document/7379657 - - -class ZeroPadder2D(torch.nn.Module): - def __init__(self, pad_at: List[int], padding_size: int): - super().__init__() - self.pad_at = pad_at - self.padding_size = padding_size - - def forward(self, x): - # X = (Batch, channels, H, W) - # X = (8, 1, 6, 60) - - for i in self.pad_at: - left = x[:, :, :i, :] - zeros = torch.zeros( - x.shape[0], x.shape[1], self.padding_size, x.shape[3] - ) - right = x[:, :, i:, :] - - x = torch.cat([left, zeros, right], dim=2) - # print(f"-- Left.shape: {left.shape}") - # print(f"-- Zeros.shape: {zeros.shape}") - # print(f"-- Right.shape: {right.shape}") - # print(f"-- X.shape: {x.shape}") - - return x - - def __str__(self) -> str: - return f"ZeroPadder2D(pad_at={self.pad_at}, padding_size={self.padding_size})" - - def __repr__(self) -> str: - return str(self) - - -class CNN_HaEtAl_1D(SimpleSupervisedModel): - def __init__( - self, - input_shape: Tuple[int, int, int] = (1, 6, 60), - num_classes: int = 6, - learning_rate: float = 1e-3, - ): - self.input_shape = input_shape - self.num_classes = num_classes - - backbone = self._create_backbone(input_shape=input_shape) - self.fc_input_channels = self._calculate_fc_input_features( - backbone, input_shape - ) - fc = self._create_fc(self.fc_input_channels, num_classes) - super().__init__( - backbone=backbone, - fc=fc, - learning_rate=learning_rate, - flatten=True, - loss_fn=torch.nn.CrossEntropyLoss(), - val_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, - test_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, - ) - - def _create_backbone(self, input_shape: Tuple[int, int]) -> torch.nn.Module: - return torch.nn.Sequential( - # Add padding - # ZeroPadder2D( - # pad_at=self.pad_at, - # padding_size=4 - 1, # kernel size - 1 - # ), - # First 2D convolutional layer - torch.nn.Conv2d( - in_channels=input_shape[0], - out_channels=32, - kernel_size=(1, 4), - stride=(1, 1), - ), - torch.nn.ReLU(), - torch.nn.MaxPool2d( - kernel_size=(1, 3), - stride=(1, 3), - ), - - # Second 2D convolutional layer - torch.nn.Conv2d( - in_channels=32, - out_channels=64, - kernel_size=(1, 5), - stride=(1, 1), - ), - torch.nn.ReLU(), - torch.nn.MaxPool2d( - kernel_size=(1, 3), - stride=(1, 3), - ), - ) - - def _calculate_fc_input_features( - self, backbone: torch.nn.Module, input_shape: Tuple[int, int, int] - ) -> int: - """Run a single forward pass with a random input to get the number of - features after the convolutional layers. - - Parameters - ---------- - backbone : torch.nn.Module - The backbone of the network - input_shape : Tuple[int, int, int] - The input shape of the network. - - Returns - ------- - int - The number of features after the convolutional layers. 
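# What ZeroPadder2D (above) actually does: it inserts `padding_size` zero-filled
# rows along the height (sensor) axis at each index listed in `pad_at`. A minimal
# sketch, assuming the post-refactor import path minerva.models.nets.cnns that
# this diff introduces below:
import torch
from minerva.models.nets.cnns import ZeroPadder2D

padder = ZeroPadder2D(pad_at=[3], padding_size=2)
x = torch.rand(8, 1, 6, 60)          # (batch, channels, sensors, time_steps)
y = padder(x)
print(y.shape)                       # torch.Size([8, 1, 8, 60])
assert (y[:, :, 3:5, :] == 0).all()  # the two inserted rows are all zeros
# Caveat: torch.zeros is called without device/dtype arguments, so inputs on a
# GPU (or in half precision) would fail as the module is written.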
- """ - random_input = torch.randn(1, *input_shape) - with torch.no_grad(): - out = backbone(random_input) - return out.view(out.size(0), -1).size(1) - - def _create_fc( - self, input_features: int, num_classes: int - ) -> torch.nn.Module: - return torch.nn.Sequential( - torch.nn.Linear(in_features=input_features, out_features=128), - torch.nn.ReLU(), - torch.nn.Dropout(0.5), - torch.nn.Linear(in_features=128, out_features=num_classes), - # torch.nn.Softmax(dim=1), - ) - - -class CNN_HaEtAl_2D(SimpleSupervisedModel): - def __init__( - self, - pad_at: List[int]= (3, ), - input_shape: Tuple[int, int, int] = (1, 6, 60), - num_classes: int = 6, - learning_rate: float = 1e-3, - ): - self.pad_at = pad_at - self.input_shape = input_shape - self.num_classes = num_classes - - backbone = self._create_backbone(input_shape=input_shape) - self.fc_input_channels = self._calculate_fc_input_features( - backbone, input_shape - ) - fc = self._create_fc(self.fc_input_channels, num_classes) - super().__init__( - backbone=backbone, - fc=fc, - learning_rate=learning_rate, - flatten=True, - loss_fn=torch.nn.CrossEntropyLoss(), - val_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, - test_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, - ) - - def _create_backbone(self, input_shape: Tuple[int, int]) -> torch.nn.Module: - first_kernel_size = 4 - return torch.nn.Sequential( - # Add padding - ZeroPadder2D( - pad_at=self.pad_at, - padding_size=first_kernel_size - 1, # kernel size - 1 - ), - # First 2D convolutional layer - torch.nn.Conv2d( - in_channels=input_shape[0], - out_channels=32, - kernel_size=(first_kernel_size, first_kernel_size), - stride=(1, 1), - ), - torch.nn.ReLU(), - torch.nn.MaxPool2d( - kernel_size=(3, 3), - stride=(3, 3), - padding=1, - ), - - # Second 2D convolutional layer - torch.nn.Conv2d( - in_channels=32, - out_channels=64, - kernel_size=(5, 5), - stride=(1, 1), - padding=2, - ), - torch.nn.ReLU(), - torch.nn.MaxPool2d( - kernel_size=(3, 3), - stride=(3, 3), - padding=1, - ), - ) - - def _calculate_fc_input_features( - self, backbone: torch.nn.Module, input_shape: Tuple[int, int, int] - ) -> int: - """Run a single forward pass with a random input to get the number of - features after the convolutional layers. - - Parameters - ---------- - backbone : torch.nn.Module - The backbone of the network - input_shape : Tuple[int, int, int] - The input shape of the network. - - Returns - ------- - int - The number of features after the convolutional layers. 
- """ - random_input = torch.randn(1, *input_shape) - with torch.no_grad(): - out = backbone(random_input) - return out.view(out.size(0), -1).size(1) - - def _create_fc( - self, input_features: int, num_classes: int - ) -> torch.nn.Module: - return torch.nn.Sequential( - torch.nn.Linear(in_features=input_features, out_features=128), - torch.nn.ReLU(), - torch.nn.Dropout(0.5), - torch.nn.Linear(in_features=128, out_features=num_classes), - # torch.nn.Softmax(dim=1), - ) - - - -# def test_cnn_1d(): -# input_shape = (1, 6, 60) - -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) - -# model = CNN_HaEtAl_1D( -# input_shape=input_shape, num_classes=6, learning_rate=1e-3 -# ) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) - -# trainer.fit(model, datamodule=data_module) - - -# def test_cnn_2d(): -# input_shape = (1, 6, 60) - -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) - -# model = CNN_HaEtAl_2D( -# pad_at=[3], input_shape=input_shape, num_classes=6, learning_rate=1e-3 -# ) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) - -# trainer.fit(model, datamodule=data_module) - - -# if __name__ == "__main__": -# import logging -# logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) -# logging.getLogger("lightning").setLevel(logging.ERROR) -# logging.getLogger("lightning.pytorch.core").setLevel(logging.ERROR) - - -# test_cnn_1d() -# test_cnn_2d() diff --git a/minerva/models/nets/cnn_pf.py b/minerva/models/nets/cnn_pf.py deleted file mode 100644 index dffb638..0000000 --- a/minerva/models/nets/cnn_pf.py +++ /dev/null @@ -1,299 +0,0 @@ -from typing import List, Tuple -import torch -from torchmetrics import Accuracy - -from minerva.models.nets.base import SimpleSupervisedModel - -# Convolutional Neural Networks for Human Activity Recognition using Multiple -# Accelerometer and Gyroscope Sensors, from Ha, and Choi. 
-# https://ieeexplore.ieee.org/document/7727224 - -class ZeroPadder2D(torch.nn.Module): - def __init__(self, pad_at: List[int], padding_size: int): - super().__init__() - self.pad_at = pad_at - self.padding_size = padding_size - - def forward(self, x): - # X = (Batch, channels, H, W) - # X = (8, 1, 6, 60) - - for i in self.pad_at: - left = x[:, :, :i, :] - zeros = torch.zeros( - x.shape[0], x.shape[1], self.padding_size, x.shape[3] - ) - right = x[:, :, i:, :] - - x = torch.cat([left, zeros, right], dim=2) - # print(f"-- Left.shape: {left.shape}") - # print(f"-- Zeros.shape: {zeros.shape}") - # print(f"-- Right.shape: {right.shape}") - # print(f"-- X.shape: {x.shape}") - - return x - - def __str__(self) -> str: - return f"ZeroPadder2D(pad_at={self.pad_at}, padding_size={self.padding_size})" - - def __repr__(self) -> str: - return str(self) - - -class CNN_PF_Backbone(torch.nn.Module): - def __init__( - self, - pad_at: int, - input_shape: Tuple[int, int, int], - out_channels: int = 16, - include_middle: bool = False, - ): - super().__init__() - self.pad_at = pad_at - self.input_shape = input_shape - self.include_middle = include_middle - self.out_channels = out_channels - self.first_pad_size = 3 - 1 # kernel -1 - - self.first_padder = ZeroPadder2D( - pad_at=(pad_at,), - padding_size=self.first_pad_size, - ) - - self.upper_part = torch.nn.Sequential( - torch.nn.Conv2d( - in_channels=self.input_shape[0], - out_channels=self.out_channels, - kernel_size=(3, 3), - stride=(1, 1), - ), - torch.nn.ReLU(), - torch.nn.MaxPool2d( - kernel_size=(2, 3), - stride=(2, 3), - padding=1, - ), - ) - - self.lower_part = torch.nn.Sequential( - torch.nn.Conv2d( - in_channels=self.input_shape[0], - out_channels=self.out_channels, - kernel_size=(3, 3), - stride=(1, 1), - ), - torch.nn.ReLU(), - torch.nn.MaxPool2d( - kernel_size=(2, 3), - stride=(2, 3), - padding=1, - ), - ) - - if self.include_middle: - self.middle_part = torch.nn.Sequential( - torch.nn.Conv2d( - in_channels=self.input_shape[0], - out_channels=self.out_channels, - kernel_size=(3, 3), - stride=(1, 1), - ), - torch.nn.ReLU(), - torch.nn.MaxPool2d( - kernel_size=(2, 3), - stride=(2, 3), - padding=1, - ), - ) - - self.shared_part = torch.nn.Sequential( - torch.nn.Conv2d( - in_channels=( - self.out_channels * 3 - if self.include_middle - else self.out_channels * 2 - ), - out_channels=64, - kernel_size=(3, 5), - stride=(1, 1), - ), - torch.nn.ReLU(), - torch.nn.MaxPool2d( - kernel_size=(2, 3), - stride=(2, 3), - padding=1, - ), - ) - - def forward(self, x): - # X = (batch_size, channels, sensors, time_steps) - # X = (8, 1, 6, 60) - - # After pad: (8, 1, 8, 60) - x = self.first_padder(x) - - # upper slice (8, 1, 5, 60) - upper_x = x[:, :, : self.pad_at + self.first_pad_size, :] - upper_x = self.upper_part(upper_x) - zeros_1 = torch.zeros( - upper_x.size(0), upper_x.size(1), 3 - 1, upper_x.size(3) - ) - - upper_x = torch.cat( - [upper_x, zeros_1], - dim=2, - ) - - # lower slice (8, 1, 5, 60) - lower_x = x[:, :, self.pad_at :, :] - lower_x = self.lower_part(lower_x) - zeros_2 = torch.zeros( - lower_x.size(0), lower_x.size(1), 3 - 1, lower_x.size(3) - ) - - lower_x = torch.cat( - [zeros_2, lower_x], - dim=2, - ) - - if self.include_middle: - # x is already middle - middle_x = self.middle_part(x) - concatenated_x = torch.cat([upper_x, middle_x, lower_x], dim=1) - - else: - concatenated_x = torch.cat([upper_x, lower_x], dim=1) - - result_x = self.shared_part(concatenated_x) - return result_x - - -class CNN_PF_2D(SimpleSupervisedModel): - def __init__( - self, 
- pad_at: int, - input_shape: Tuple[int, int, int] = (1, 6, 60), - out_channels: int = 16, - num_classes: int = 6, - learning_rate: float = 1e-3, - include_middle: bool = False, - ): - self.pad_at = pad_at - self.input_shape = input_shape - self.out_channels = out_channels - self.num_classes = num_classes - - backbone = CNN_PF_Backbone( - pad_at=pad_at, - input_shape=input_shape, - out_channels=out_channels, - include_middle=include_middle, - ) - self.fc_input_channels = self._calculate_fc_input_features( - backbone, input_shape - ) - fc = self._create_fc(self.fc_input_channels, num_classes) - super().__init__( - backbone=backbone, - fc=fc, - learning_rate=learning_rate, - flatten=True, - loss_fn=torch.nn.CrossEntropyLoss(), - val_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, - test_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, - ) - - def _calculate_fc_input_features( - self, backbone: torch.nn.Module, input_shape: Tuple[int, int, int] - ) -> int: - """Run a single forward pass with a random input to get the number of - features after the convolutional layers. - - Parameters - ---------- - backbone : torch.nn.Module - The backbone of the network - input_shape : Tuple[int, int, int] - The input shape of the network. - - Returns - ------- - int - The number of features after the convolutional layers. - """ - random_input = torch.randn(1, *input_shape) - with torch.no_grad(): - out = backbone(random_input) - return out.view(out.size(0), -1).size(1) - - def _create_fc( - self, input_features: int, num_classes: int - ) -> torch.nn.Module: - return torch.nn.Sequential( - torch.nn.Linear(in_features=input_features, out_features=512), - torch.nn.ReLU(), - torch.nn.Dropout(0.5), - torch.nn.Linear(in_features=512, out_features=num_classes), - # torch.nn.Softmax(dim=1), - ) - - -class CNN_PFF_2D(CNN_PF_2D): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, include_middle=True) - - -# def test_cnn_pf_2d(): -# input_shape = (1, 6, 60) - -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) - -# model = CNN_PF_2D(pad_at=3, input_shape=input_shape) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) - -# trainer.fit(model, datamodule=data_module) - - -# def test_cnn_pff_2d(): -# input_shape = (1, 6, 60) - -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) - -# model = CNN_PFF_2D(pad_at=3, input_shape=input_shape) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) - -# trainer.fit(model, datamodule=data_module) - - -# if __name__ == "__main__": -# import logging - -# logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) -# logging.getLogger("lightning").setLevel(logging.ERROR) -# logging.getLogger("lightning.pytorch.core").setLevel(logging.ERROR) - -# # test_cnn_1d() -# test_cnn_pf_2d() -# test_cnn_pff_2d() diff --git a/minerva/models/nets/cnns.py b/minerva/models/nets/cnns.py new file mode 100644 index 0000000..15fb9af --- /dev/null +++ b/minerva/models/nets/cnns.py @@ -0,0 +1,410 @@ +from typing import List, Tuple + +import torch +from torchmetrics import Accuracy + +from minerva.models.nets.base import SimpleSupervisedModel + + +class ZeroPadder2D(torch.nn.Module): + def __init__(self, pad_at: List[int], padding_size: int): + super().__init__() + self.pad_at = 
pad_at + self.padding_size = padding_size + + def forward(self, x): + + for i in self.pad_at: + left = x[:, :, :i, :] + zeros = torch.zeros(x.shape[0], x.shape[1], self.padding_size, x.shape[3]) + right = x[:, :, i:, :] + + x = torch.cat([left, zeros, right], dim=2) + + return x + + def __str__(self) -> str: + return f"ZeroPadder2D(pad_at={self.pad_at}, padding_size={self.padding_size})" + + def __repr__(self) -> str: + return str(self) + + +class CNN_HaEtAl_1D(SimpleSupervisedModel): + + def __init__( + self, + input_shape: Tuple[int, int, int] = (1, 6, 60), + num_classes: int = 6, + learning_rate: float = 1e-3, + ): + self.input_shape = input_shape + self.num_classes = num_classes + + backbone = self._create_backbone(input_shape=input_shape) + self.fc_input_channels = self._calculate_fc_input_features( + backbone, input_shape + ) + fc = self._create_fc(self.fc_input_channels, num_classes) + super().__init__( + backbone=backbone, + fc=fc, + learning_rate=learning_rate, + flatten=True, + loss_fn=torch.nn.CrossEntropyLoss(), + val_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, + test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, + ) + + def _create_backbone(self, input_shape: Tuple[int, int]) -> torch.nn.Module: + return torch.nn.Sequential( + # First 2D convolutional layer + torch.nn.Conv2d( + in_channels=input_shape[0], + out_channels=32, + kernel_size=(1, 4), + stride=(1, 1), + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=(1, 3), + stride=(1, 3), + ), + # Second 2D convolutional layer + torch.nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=(1, 5), + stride=(1, 1), + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=(1, 3), + stride=(1, 3), + ), + ) + + def _calculate_fc_input_features( + self, backbone: torch.nn.Module, input_shape: Tuple[int, int, int] + ) -> int: + """Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. 
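# Forward-pass smoke test for CNN_HaEtAl_1D as defined above (this mirrors the
# updated tests/models/nets/test_cnn_ha_etal.py later in the diff):
import torch
from minerva.models.nets.cnns import CNN_HaEtAl_1D

model = CNN_HaEtAl_1D(input_shape=(1, 6, 60), num_classes=6, learning_rate=1e-3)
y = model(torch.rand(1, 1, 6, 60))
print(y.shape)  # expected torch.Size([1, 6]): raw logits, since CrossEntropyLoss
                # is applied by the training loop rather than a final Softmax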
+ """ + random_input = torch.randn(1, *input_shape) + with torch.no_grad(): + out = backbone(random_input) + return out.view(out.size(0), -1).size(1) + + def _create_fc(self, input_features: int, num_classes: int) -> torch.nn.Module: + return torch.nn.Sequential( + torch.nn.Linear(in_features=input_features, out_features=128), + torch.nn.ReLU(), + torch.nn.Dropout(0.5), + torch.nn.Linear(in_features=128, out_features=num_classes), + # torch.nn.Softmax(dim=1), + ) + + +class CNN_HaEtAl_2D(SimpleSupervisedModel): + def __init__( + self, + pad_at: List[int] = (3,), + input_shape: Tuple[int, int, int] = (1, 6, 60), + num_classes: int = 6, + learning_rate: float = 1e-3, + ): + self.pad_at = pad_at + self.input_shape = input_shape + self.num_classes = num_classes + + backbone = self._create_backbone(input_shape=input_shape) + self.fc_input_channels = self._calculate_fc_input_features( + backbone, input_shape + ) + fc = self._create_fc(self.fc_input_channels, num_classes) + super().__init__( + backbone=backbone, + fc=fc, + learning_rate=learning_rate, + flatten=True, + loss_fn=torch.nn.CrossEntropyLoss(), + val_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, + test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, + ) + + def _create_backbone(self, input_shape: Tuple[int, int]) -> torch.nn.Module: + first_kernel_size = 4 + return torch.nn.Sequential( + # Add padding + ZeroPadder2D( + pad_at=self.pad_at, + padding_size=first_kernel_size - 1, # kernel size - 1 + ), + # First 2D convolutional layer + torch.nn.Conv2d( + in_channels=input_shape[0], + out_channels=32, + kernel_size=(first_kernel_size, first_kernel_size), + stride=(1, 1), + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=(3, 3), + stride=(3, 3), + padding=1, + ), + # Second 2D convolutional layer + torch.nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=(5, 5), + stride=(1, 1), + padding=2, + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=(3, 3), + stride=(3, 3), + padding=1, + ), + ) + + def _calculate_fc_input_features( + self, backbone: torch.nn.Module, input_shape: Tuple[int, int, int] + ) -> int: + """Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. 
+ """ + random_input = torch.randn(1, *input_shape) + with torch.no_grad(): + out = backbone(random_input) + return out.view(out.size(0), -1).size(1) + + def _create_fc(self, input_features: int, num_classes: int) -> torch.nn.Module: + return torch.nn.Sequential( + torch.nn.Linear(in_features=input_features, out_features=128), + torch.nn.ReLU(), + torch.nn.Dropout(0.5), + torch.nn.Linear(in_features=128, out_features=num_classes), + # torch.nn.Softmax(dim=1), + ) + + +class CNN_PF_Backbone(torch.nn.Module): + def __init__( + self, + pad_at: int, + input_shape: Tuple[int, int, int], + out_channels: int = 16, + include_middle: bool = False, + ): + super().__init__() + self.pad_at = pad_at + self.input_shape = input_shape + self.include_middle = include_middle + self.out_channels = out_channels + self.first_pad_size = 3 - 1 # kernel -1 + + self.first_padder = ZeroPadder2D( + pad_at=(pad_at,), + padding_size=self.first_pad_size, + ) + + self.upper_part = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=self.input_shape[0], + out_channels=self.out_channels, + kernel_size=(3, 3), + stride=(1, 1), + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=(2, 3), + stride=(2, 3), + padding=1, + ), + ) + + self.lower_part = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=self.input_shape[0], + out_channels=self.out_channels, + kernel_size=(3, 3), + stride=(1, 1), + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=(2, 3), + stride=(2, 3), + padding=1, + ), + ) + + if self.include_middle: + self.middle_part = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=self.input_shape[0], + out_channels=self.out_channels, + kernel_size=(3, 3), + stride=(1, 1), + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=(2, 3), + stride=(2, 3), + padding=1, + ), + ) + + self.shared_part = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=( + self.out_channels * 3 + if self.include_middle + else self.out_channels * 2 + ), + out_channels=64, + kernel_size=(3, 5), + stride=(1, 1), + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=(2, 3), + stride=(2, 3), + padding=1, + ), + ) + + def forward(self, x): + # X = (batch_size, channels, sensors, time_steps) + # X = (8, 1, 6, 60) + + # After pad: (8, 1, 8, 60) + x = self.first_padder(x) + + # upper slice (8, 1, 5, 60) + upper_x = x[:, :, : self.pad_at + self.first_pad_size, :] + upper_x = self.upper_part(upper_x) + zeros_1 = torch.zeros(upper_x.size(0), upper_x.size(1), 3 - 1, upper_x.size(3)) + + upper_x = torch.cat( + [upper_x, zeros_1], + dim=2, + ) + + # lower slice (8, 1, 5, 60) + lower_x = x[:, :, self.pad_at :, :] + lower_x = self.lower_part(lower_x) + zeros_2 = torch.zeros(lower_x.size(0), lower_x.size(1), 3 - 1, lower_x.size(3)) + + lower_x = torch.cat( + [zeros_2, lower_x], + dim=2, + ) + + if self.include_middle: + # x is already middle + middle_x = self.middle_part(x) + concatenated_x = torch.cat([upper_x, middle_x, lower_x], dim=1) + + else: + concatenated_x = torch.cat([upper_x, lower_x], dim=1) + + result_x = self.shared_part(concatenated_x) + return result_x + + +class CNN_PF_2D(SimpleSupervisedModel): + def __init__( + self, + pad_at: int, + input_shape: Tuple[int, int, int] = (1, 6, 60), + out_channels: int = 16, + num_classes: int = 6, + learning_rate: float = 1e-3, + include_middle: bool = False, + ): + self.pad_at = pad_at + self.input_shape = input_shape + self.out_channels = out_channels + self.num_classes = num_classes + + backbone = CNN_PF_Backbone( + pad_at=pad_at, + input_shape=input_shape, + 
out_channels=out_channels, + include_middle=include_middle, + ) + self.fc_input_channels = self._calculate_fc_input_features( + backbone, input_shape + ) + fc = self._create_fc(self.fc_input_channels, num_classes) + super().__init__( + backbone=backbone, + fc=fc, + learning_rate=learning_rate, + flatten=True, + loss_fn=torch.nn.CrossEntropyLoss(), + val_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, + test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, + ) + + def _calculate_fc_input_features( + self, backbone: torch.nn.Module, input_shape: Tuple[int, int, int] + ) -> int: + """Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + """ + random_input = torch.randn(1, *input_shape) + with torch.no_grad(): + out = backbone(random_input) + return out.view(out.size(0), -1).size(1) + + def _create_fc(self, input_features: int, num_classes: int) -> torch.nn.Module: + return torch.nn.Sequential( + torch.nn.Linear(in_features=input_features, out_features=512), + torch.nn.ReLU(), + torch.nn.Dropout(0.5), + torch.nn.Linear(in_features=512, out_features=num_classes), + ) + + +class CNN_PFF_2D(CNN_PF_2D): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, include_middle=True) diff --git a/minerva/models/nets/imu_transformer.py b/minerva/models/nets/imu_transformer.py index cd84dc9..4dddad9 100644 --- a/minerva/models/nets/imu_transformer.py +++ b/minerva/models/nets/imu_transformer.py @@ -1,10 +1,11 @@ from typing import Tuple + +import lightning as L import torch from torch import nn from torch.nn import TransformerEncoder, TransformerEncoderLayer -from minerva.models.nets.base import SimpleSupervisedModel -import lightning as L +from minerva.models.nets.base import SimpleSupervisedModel """ IMUTransformerEncoder model @@ -197,7 +198,7 @@ def __init__( dropout_factor=dropout_factor, ) self.fc_input_channels = self._calculate_fc_input_features( - backbone, input_shape + backbone, input_shape ) fc = self._create_fc(self.fc_input_channels, hidden_dim, num_classes) @@ -206,7 +207,7 @@ def __init__( fc=fc, learning_rate=learning_rate, loss_fn=torch.nn.CrossEntropyLoss(), - flatten=True + flatten=True, ) def _create_backbone(self, input_shape, hidden_dim, dropout_factor): @@ -219,7 +220,6 @@ def _create_backbone(self, input_shape, hidden_dim, dropout_factor): torch.nn.MaxPool1d(kernel_size=2), ) - def _calculate_fc_input_features( self, backbone: torch.nn.Module, input_shape: Tuple[int, int] ) -> int: @@ -228,64 +228,9 @@ def _calculate_fc_input_features( out = backbone(random_input) return out.view(out.size(0), -1).size(1) - def _create_fc(self, input_features, hidden_dim, num_classes): return torch.nn.Sequential( torch.nn.Linear(input_features, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, num_classes), ) - - -# def test_imu_transformer(): -# input_shape = (6, 60) - -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) - -# model = IMUTransformerEncoder( -# input_shape=input_shape, num_classes=6, learning_rate=1e-3 -# ) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) - -# trainer.fit(model, 
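# Usage of the two partial-fusion classifiers defined above: CNN_PF_2D fuses
# only the upper and lower sensor branches, while CNN_PFF_2D is the same model
# with include_middle=True, adding a third branch over the full input. This
# mirrors tests/models/nets/test_cnn_pf.py further down:
import torch
from minerva.models.nets.cnns import CNN_PF_2D, CNN_PFF_2D

for cls in (CNN_PF_2D, CNN_PFF_2D):
    model = cls(pad_at=3, input_shape=(1, 6, 60), num_classes=6)
    y = model(torch.rand(1, 1, 6, 60))
    print(cls.__name__, y.shape)  # expected torch.Size([1, 6]) for both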
datamodule=data_module) - - -# def test_imu_cnn(): -# input_shape = (6, 60) - -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) - -# model = IMUCNN( -# input_shape=input_shape, num_classes=6, learning_rate=1e-3 -# ) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) - -# trainer.fit(model, datamodule=data_module) - - -# if __name__ == "__main__": -# import logging - -# logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) -# logging.getLogger("lightning").setLevel(logging.ERROR) -# logging.getLogger("lightning.pytorch.core").setLevel(logging.ERROR) -# test_imu_transformer() -# test_imu_cnn() \ No newline at end of file diff --git a/minerva/models/nets/inception_time.py b/minerva/models/nets/inception_time.py index e6e3bbb..91a6a0c 100644 --- a/minerva/models/nets/inception_time.py +++ b/minerva/models/nets/inception_time.py @@ -1,14 +1,14 @@ -import numpy as np import time - - from typing import Tuple + +import lightning as L +import numpy as np import torch from torch import nn from torch.nn import TransformerEncoder, TransformerEncoderLayer from torchmetrics import Accuracy + from minerva.models.nets.base import SimpleSupervisedModel -import lightning as L class InceptionModule(torch.nn.Module): @@ -229,12 +229,8 @@ def __init__( learning_rate=learning_rate, flatten=True, loss_fn=torch.nn.CrossEntropyLoss(), - val_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, - test_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, + val_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, + test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, ) def _calculate_fc_input_features( @@ -260,55 +256,8 @@ def _calculate_fc_input_features( out = backbone(random_input) return out.view(out.size(0), -1).size(1) - def _create_fc( - self, input_features: int, num_classes: int - ) -> torch.nn.Module: + def _create_fc(self, input_features: int, num_classes: int) -> torch.nn.Module: return torch.nn.Sequential( torch.nn.Linear(in_features=input_features, out_features=num_classes), # torch.nn.Softmax(dim=1), ) - - -# def test_inception_time(): -# input_shape = (6, 60) - -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) - -# model = InceptionTime( -# input_shape=input_shape, num_classes=6, learning_rate=1e-3 -# ) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) - -# trainer.fit(model, datamodule=data_module) - - - - # from torchview import draw_graph - - # model = InceptionTime() - # result = model(torch.rand(1, 6, 60)) - # print(f"Result.shape: {result.shape}") - # model_graph = draw_graph( - # model, - # input_size=(64, 6, 60), - # device="cpu", - # expand_nested=True, - # show_shapes=True, - # save_graph=True, - # filename="inception_graph", - # ) - # model_graph.visual_graph.render("inception_graph.png", format="png") - # print(f"Graph saved to `inception_graph.png`") - - -# if __name__ == "__main__": -# test_inception_time() diff --git a/minerva/models/nets/resnet_1d.py b/minerva/models/nets/resnet.py similarity index 75% rename from minerva/models/nets/resnet_1d.py rename to minerva/models/nets/resnet.py index 80d4a49..bc5b71a 100644 --- a/minerva/models/nets/resnet_1d.py +++ b/minerva/models/nets/resnet.py @@ -1,29 +1,25 @@ -import numpy as np import time - 
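# The imu_transformer.py and inception_time.py hunks above are import-order and
# formatting cleanups plus removal of commented-out smoke tests; the public
# interfaces are unchanged. What those deleted tests boiled down to:
import torch

from minerva.models.nets.imu_transformer import IMUCNN, IMUTransformerEncoder
from minerva.models.nets.inception_time import InceptionTime

for cls in (IMUTransformerEncoder, IMUCNN, InceptionTime):
    model = cls(input_shape=(6, 60), num_classes=6, learning_rate=1e-3)
    y = model(torch.rand(1, 6, 60))
    print(cls.__name__, y.shape)  # expected torch.Size([1, 6]) each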
- from functools import partial from typing import Literal, Tuple + +import lightning as L +import numpy as np import torch from torch import nn from torch.nn import TransformerEncoder, TransformerEncoderLayer from torchmetrics import Accuracy + from minerva.models.nets.base import SimpleSupervisedModel -import lightning as L class ConvolutionalBlock(torch.nn.Module): - def __init__( - self, in_channels: int, activation_cls: torch.nn.Module = None - ): + def __init__(self, in_channels: int, activation_cls: torch.nn.Module = None): super().__init__() self.in_channels = in_channels self.activation_cls = activation_cls self.block = torch.nn.Sequential( - torch.nn.Conv1d( - in_channels, out_channels=64, kernel_size=5, stride=1 - ), + torch.nn.Conv1d(in_channels, out_channels=64, kernel_size=5, stride=1), torch.nn.BatchNorm1d(64), activation_cls(), torch.nn.MaxPool1d(2), @@ -120,9 +116,7 @@ def __init__( ) self.residual_blocks = torch.nn.Sequential( *[ - residual_block_cls( - in_channels=64, activation_cls=activation_cls - ) + residual_block_cls(in_channels=64, activation_cls=activation_cls) for _ in range(num_residual_blocks) ] ) @@ -166,12 +160,8 @@ def __init__( learning_rate=learning_rate, flatten=True, loss_fn=torch.nn.CrossEntropyLoss(), - val_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, - test_metrics={ - "acc": Accuracy(task="multiclass", num_classes=num_classes) - }, + val_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, + test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, ) def _calculate_fc_input_features( @@ -208,8 +198,8 @@ def __init__(self, *args, **kwargs): activation_cls=torch.nn.ELU, num_residual_blocks=8, ) - - + + # Deep Residual Network for Smartwatch-Based User Identification through Complex Hand Movements (ResNetSE1D) class ResNetSE1D_8(ResNet1DBase): def __init__(self, *args, **kwargs): @@ -221,6 +211,7 @@ def __init__(self, *args, **kwargs): num_residual_blocks=8, ) + # resnet-se: Channel Attention-Based Deep Residual Network for Complex Activity Recognition Using Wrist-Worn Wearable Sensors # Changes the activation function to ReLU and the number of residual blocks to 5 (compared to ResNetSE1D_8) class ResNetSE1D_5(ResNet1DBase): @@ -232,72 +223,3 @@ def __init__(self, *args, **kwargs): activation_cls=torch.nn.ReLU, num_residual_blocks=5, ) - - -# def test_resnet_8(): -# input_shape = (6, 60) -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) -# model = ResNet1D_8( -# input_shape=input_shape, -# num_classes=6, -# learning_rate=1e-3, -# ) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) -# trainer.fit(model, datamodule=data_module) - - -# def test_resnet_se_8(): -# input_shape = (6, 60) -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) -# model = ResNetSE1D_8( -# input_shape=input_shape, -# num_classes=6, -# learning_rate=1e-3, -# ) -# print(model) - -# trainer = L.Trainer( -# max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) -# trainer.fit(model, datamodule=data_module) - - -# def test_resnet_se_5(): -# input_shape = (6, 60) -# data_module = RandomDataModule( -# num_samples=8, -# num_classes=6, -# input_shape=input_shape, -# batch_size=8, -# ) -# model = ResNetSE1D_5( -# input_shape=input_shape, -# num_classes=6, -# learning_rate=1e-3, -# ) -# print(model) - -# trainer = L.Trainer( -# 
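# The resnet_1d.py -> resnet.py rename keeps the three public variants intact:
# ResNet1D_8 (plain residual blocks, ELU, 8 blocks), ResNetSE1D_8
# (squeeze-and-excitation blocks, 8 of them) and ResNetSE1D_5 (SE blocks, ReLU,
# 5 blocks). Forward-pass sketch matching the updated
# tests/models/nets/test_resnet_1d.py:
import torch
from minerva.models.nets.resnet import ResNet1D_8, ResNetSE1D_5, ResNetSE1D_8

for cls in (ResNet1D_8, ResNetSE1D_8, ResNetSE1D_5):
    model = cls(input_shape=(6, 60), num_classes=6, learning_rate=1e-3)
    y = model(torch.rand(1, 6, 60))
    print(cls.__name__, y.shape)  # expected torch.Size([1, 6])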
max_epochs=1, logger=False, devices=1, accelerator="cpu" -# ) -# trainer.fit(model, datamodule=data_module) - - -# if __name__ == "__main__": -# test_resnet_8() -# test_resnet_se_8() -# test_resnet_se_5() diff --git a/minerva/models/nets/sfm.py b/minerva/models/nets/sfm.py deleted file mode 100644 index 8e89644..0000000 --- a/minerva/models/nets/sfm.py +++ /dev/null @@ -1,493 +0,0 @@ -from functools import partial - -import lightning as L -import numpy as np -import torch -import torch.nn as nn -from timm.models.vision_transformer import Block, PatchEmbed - -from minerva.utils.position_embedding import get_2d_sincos_pos_embed - - -class MaskedAutoencoderViT(L.LightningModule): - """ - Masked Autoencoder with VisionTransformer backbone. - - Args: - img_size (int): Size of input image. - patch_size (int): Size of image patch. - in_chans (int): Number of input channels. - embed_dim (int): Dimension of token embeddings. - depth (int): Number of transformer blocks. - num_heads (int): Number of attention heads. - decoder_embed_dim (int): Dimension of decoder embeddings. - decoder_depth (int): Number of decoder transformer blocks. - decoder_num_heads (int): Number of decoder attention heads. - mlp_ratio (float): Ratio of MLP hidden layer size to embedding size. - norm_layer (torch.nn.LayerNorm): Normalization layer. - norm_pix_loss (bool): Whether to normalize pixel loss. - - References: - - timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm - - DeiT: https://github.com/facebookresearch/deit - """ - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=1, - embed_dim=1024, - depth=24, - num_heads=16, - decoder_embed_dim=512, - decoder_depth=8, - decoder_num_heads=16, - mlp_ratio=4.0, - norm_layer=nn.LayerNorm, - norm_pix_loss=False, - ): - super().__init__() - - # MAE encoder specifics - self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter( - torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False - ) # fixed sin-cos embedding - self.in_chans = in_chans - self.blocks = nn.ModuleList( - [ - Block( - embed_dim, - num_heads, - mlp_ratio, - qkv_bias=True, - norm_layer=norm_layer, - ) - for _ in range(depth) - ] - ) - self.norm = norm_layer(embed_dim) - - # MAE decoder specifics - self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) - - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) - - self.decoder_pos_embed = nn.Parameter( - torch.zeros(1, num_patches + 1, decoder_embed_dim), - requires_grad=False, - ) # fixed sin-cos embedding - - self.decoder_blocks = nn.ModuleList( - [ - Block( - decoder_embed_dim, - decoder_num_heads, - mlp_ratio, - qkv_bias=True, - norm_layer=norm_layer, - ) - for _ in range(decoder_depth) - ] - ) - - self.decoder_norm = norm_layer(decoder_embed_dim) - self.decoder_pred = nn.Linear( - decoder_embed_dim, patch_size**2 * in_chans, bias=True - ) # decoder to patch - - self.norm_pix_loss = norm_pix_loss - - self.initialize_weights() - - def initialize_weights(self): - # Initialization - pos_embed = get_2d_sincos_pos_embed( - self.pos_embed.shape[-1], - int(self.patch_embed.num_patches**0.5), - cls_token=True, - ) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) - - decoder_pos_embed = get_2d_sincos_pos_embed( - self.decoder_pos_embed.shape[-1], - int(self.patch_embed.num_patches**0.5), - cls_token=True, - ) - 
self.decoder_pos_embed.data.copy_( - torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) - ) - - w = self.patch_embed.proj.weight.data - torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - torch.nn.init.normal_(self.cls_token, std=0.02) - torch.nn.init.normal_(self.mask_token, std=0.02) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - torch.nn.init.xavier_uniform_(m.weight) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def patchify(self, imgs): # input: (32, 1, 224, 224) - """ - Extract patches from input images. - - Args: - imgs (torch.Tensor): Input images of shape (N, C, H, W). - - Returns: - torch.Tensor: Patches of shape (N, num_patches, patch_size^2 * in_chans). - """ - p = self.patch_embed.patch_size[0] - assert ( - imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 - ) # only square images are supported, and the size must be divisible by the patch size - - h = w = imgs.shape[2] // p - x = imgs.reshape( - (imgs.shape[0], self.in_chans, h, p, w, p) - ) # Transform images into (32, 1, 14, 16, 14, 16) - x = torch.einsum("nchpwq->nhwpqc", x) # reshape into (32, 14, 14, 16, 16, 1) - x = x.reshape( - (imgs.shape[0], h * w, p**2 * self.in_chans) - ) # Transform into (32, 196, 256) - return x - - def unpatchify(self, x): - """ - Reconstruct images from patches. - - Args: - x (torch.Tensor): Patches of shape (N, L, patch_size^2 * in_chans). - - Returns: - torch.Tensor: Reconstructed images of shape (N, C, H, W). - """ - p = self.patch_embed.patch_size[0] - h = w = int(x.shape[1] ** 0.5) - assert h * w == x.shape[1] - - x = x.reshape((x.shape[0], h, w, p, p, 3)) - x = torch.einsum("nhwpqc->nchpwq", x) - imgs = x.reshape((x.shape[0], 3, h * p, h * p)) - return imgs - - def random_masking(self, x, mask_ratio): - """ - Perform per-sample random masking by per-sample shuffling. - - Args: - x (torch.Tensor): Input tensor of shape (N, L, D). - mask_ratio (float): Ratio of values to mask. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Masked input, binary mask, shuffled indices. - """ - N, L, D = x.shape - len_keep = int(L * (1 - mask_ratio)) - - noise = torch.rand(N, L, device=x.device) - - ids_shuffle = torch.argsort(noise, dim=1) - ids_restore = torch.argsort(ids_shuffle, dim=1) - - ids_keep = ids_shuffle[:, :len_keep] - x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) - - mask = torch.ones(N, L, device=x.device) - mask[:, :len_keep] = 0 - mask = torch.gather(mask, dim=1, index=ids_restore) - - return x_masked, mask, ids_restore - - def forward_encoder(self, x, mask_ratio): - """ - Forward pass through the encoder. - - Args: - x (torch.Tensor): Input tensor of shape (N, C, H, W). - mask_ratio (float): Ratio of values to mask. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Encoded representation, binary mask, shuffled indices. - """ - x = self.patch_embed(x) - x = x + self.pos_embed[:, 1:, :] - - x, mask, ids_restore = self.random_masking(x, mask_ratio) - - cls_token = self.cls_token + self.pos_embed[:, :1, :] - cls_tokens = cls_token.expand(x.shape[0], -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - - for blk in self.blocks: - x = blk(x) - x = self.norm(x) - - return x, mask, ids_restore - - def forward_decoder(self, x, ids_restore): - """ - Forward pass through the decoder. 
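# patchify() above maps (N, C, H, W) images to (N, L, p*p*C) patch rows; with
# the class defaults (224x224 input, patch 16, 1 channel) that is 14*14 = 196
# patches of 256 values. A runnable sketch on a deliberately tiny configuration
# (all hyperparameters below are illustrative, not defaults from this file),
# importing the class from its post-refactor home in vit.py:
import torch
from minerva.models.nets.vit import MaskedAutoencoderViT

mae = MaskedAutoencoderViT(img_size=32, patch_size=8, in_chans=1, embed_dim=64,
                           depth=1, num_heads=4, decoder_embed_dim=32,
                           decoder_depth=1, decoder_num_heads=4)
patches = mae.patchify(torch.rand(2, 1, 32, 32))
print(patches.shape)  # torch.Size([2, 16, 64]): (32/8)^2 patches of 8*8*1 values
# Note: unpatchify() reshapes with a hard-coded 3 channels, so the exact round
# trip only works for in_chans=3 as the code stands.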
- - Args: - x (torch.Tensor): Input tensor of shape (N, L, D). - ids_restore (torch.Tensor): Indices to restore the original order of patches. - - Returns: - torch.Tensor: Decoded output tensor of shape (N, L, patch_size^2 * in_chans). - """ - x = self.decoder_embed(x) - - mask_tokens = self.mask_token.repeat( - x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 - ) - x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) - x_ = torch.gather( - x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) - ) - x = torch.cat([x[:, :1, :], x_], dim=1) - - x = x + self.decoder_pos_embed - - for blk in self.decoder_blocks: - x = blk(x) - x = self.decoder_norm(x) - - x = self.decoder_pred(x) - x = x[:, 1:, :] - - return x - - def forward_loss(self, imgs, pred, mask): - """ - Calculate the loss. - - Args: - imgs (torch.Tensor): Input images of shape (N, C, H, W). - pred (torch.Tensor): Predicted output of shape (N, L, patch_size^2 * in_chans). - mask (torch.Tensor): Binary mask of shape (N, L). - - Returns: - torch.Tensor: Computed loss value. - """ - target = self.patchify(imgs) - if self.norm_pix_loss: - mean = target.mean(dim=-1, keepdim=True) - var = target.var(dim=-1, keepdim=True) - target = (target - mean) / (var + 1.0e-6) ** 0.5 - - loss = (pred - target) ** 2 - loss = loss.mean(dim=-1) - loss = (loss * mask).sum() / mask.sum() - return loss - - def forward(self, imgs, mask_ratio=0.75): - """ - Forward pass. - - Args: - imgs (torch.Tensor): Input images of shape (N, C, H, W). - mask_ratio (float): Ratio of values to mask. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Loss value, predicted output, binary mask. - """ - latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) - pred = self.forward_decoder(latent, ids_restore) - loss = self.forward_loss(imgs, pred, mask) - return loss, pred, mask - - def training_step(self, batch, batch_idx): - """ - Training step. - - Args: - batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels. - batch_idx (int): Index of the current batch. - - Returns: - Dict[str, torch.Tensor]: Dictionary containing the loss value for the current step. - """ - imgs, _ = batch - loss, _, _ = self(imgs) - self.log("train_loss", loss) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - """ - Validation step. - - Args: - batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels. - batch_idx (int): Index of the current batch. - - Returns: - Dict[str, torch.Tensor]: Dictionary containing the loss value for the current step. - """ - imgs, _ = batch - loss, _, _ = self(imgs) - self.log("val_loss", loss) - return {"val_loss": loss} - - def configure_optimizers(self): - """ - Configure optimizer. - - Returns: - torch.optim.Optimizer: Optimizer. 
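# forward_loss() above is a masked per-patch MSE: squared error averaged over
# the values inside each patch, then averaged over masked patches only
# (mask == 1 marks the patches that were hidden from the encoder). The
# arithmetic in isolation:
import torch

pred = torch.rand(2, 196, 256)               # decoder output, (N, L, p*p*C)
target = torch.rand(2, 196, 256)             # patchified input images
mask = (torch.rand(2, 196) > 0.25).float()   # ~75% ones, as with mask_ratio=0.75
loss = ((pred - target) ** 2).mean(dim=-1)   # per-patch MSE, shape (N, L)
loss = (loss * mask).sum() / mask.sum()      # average over masked patches only
print(loss)
# With norm_pix_loss=True, each target patch is first normalized by its own
# mean and variance before the squared error is taken.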
- """ - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) - return optimizer - - -# Define model architectures - -# mae_vit_small_patch16_dec512d8b -# decoder: 512 dim, 8 blocks, depth: 6 -mae_vit_small_patch16 = partial( - MaskedAutoencoderViT, - patch_size=16, - embed_dim=768, - depth=6, - num_heads=12, - decoder_embed_dim=512, - decoder_depth=4, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), -) - -# mae_vit_base_patch16_dec512d8b -# decoder: 512 dim, 8 blocks, -mae_vit_base_patch16 = partial( - MaskedAutoencoderViT, - patch_size=16, - embed_dim=768, - depth=12, - num_heads=12, - decoder_embed_dim=512, - decoder_depth=8, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), -) - -# mae_vit_large_patch16_dec512d8b -# decoder: 512 dim, 8 blocks -mae_vit_large_patch16 = partial( - MaskedAutoencoderViT, - patch_size=16, - embed_dim=1024, - depth=24, - num_heads=16, - decoder_embed_dim=512, - decoder_depth=8, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), -) - -# mae_vit_huge_patch14_dec512d8b -# decoder: 512 dim, 8 blocks -mae_vit_huge_patch14 = partial( - MaskedAutoencoderViT, - patch_size=14, - embed_dim=1280, - depth=32, - num_heads=16, - decoder_embed_dim=512, - decoder_depth=8, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), -) - -# mae_vit_large_patch16_dec256d4b -# decoder: 256 dim, 8 blocks -mae_vit_large_patch16D4d256 = partial( - MaskedAutoencoderViT, - patch_size=16, - embed_dim=1024, - depth=24, - num_heads=16, - decoder_embed_dim=256, - decoder_depth=4, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), -) - - -# mae_vit_base_patch16_dec256d4b -mae_vit_base_patch16D4d256 = partial( - MaskedAutoencoderViT, - patch_size=16, - embed_dim=768, - depth=12, - num_heads=12, - decoder_embed_dim=256, - decoder_depth=4, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), -) - -# import torch -# from torch.utils.data import DataLoader, TensorDataset -# import pytorch_lightning as pl -# import numpy as np - - -# def main(): -# # Create random data -# N = 32 # Batch size -# C, H, W = 1, 224, 224 # Image dimensions -# img_data = np.random.rand(N, C, H, W).astype(np.float32) -# target_data = np.random.randint(0, 10, size=N) # Random labels -# imgs = torch.tensor(img_data) -# targets = torch.tensor(target_data) - -# # Create a Lightning DataModule -# class RandomDataModule(L.LightningDataModule): -# def __init__(self, imgs, targets, batch_size=32): -# super().__init__() -# self.imgs = imgs -# self.targets = targets -# self.batch_size = batch_size - -# def train_dataloader(self): -# dataset = TensorDataset(self.imgs, self.targets) -# return DataLoader(dataset, batch_size=self.batch_size, shuffle=True) - -# def val_dataloader(self): -# dataset = TensorDataset(self.imgs, self.targets) -# return DataLoader(dataset, batch_size=self.batch_size) - -# # Instantiate Lightning DataModule -# data_module = RandomDataModule(imgs, targets) - -# # Instantiate the model -# model = MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=1, embed_dim=256, depth=24, num_heads=16) - -# # Instantiate the Lightning Trainer -# trainer = L.Trainer(max_epochs=5, devices=1, accelerator="cpu") - -# print("Forward pass...") -# # Perform a forward pass -# trainer.fit(model, datamodule=data_module) - - -# if __name__ == "__main__": -# main() diff --git a/minerva/models/nets/vit.py 
b/minerva/models/nets/vit.py index 5dfaf5e..d87465a 100644 --- a/minerva/models/nets/vit.py +++ b/minerva/models/nets/vit.py @@ -3,7 +3,10 @@ from functools import partial from typing import Callable, List, Optional +import lightning as L import torch +import torch.nn as nn +from timm.models.vision_transformer import Block, PatchEmbed from torch import nn from torchvision.models.vision_transformer import ( Conv2dNormActivation, @@ -12,6 +15,8 @@ _log_api_usage_once, ) +from minerva.utils.position_embedding import get_2d_sincos_pos_embed + class _Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation.""" @@ -322,3 +327,439 @@ def forward(self, x: torch.Tensor): x = x.reshape(B, n_h, n_w, C).permute(0, 3, 1, 2).contiguous() return x + + +class MaskedAutoencoderViT(L.LightningModule): + """ + Masked Autoencoder with VisionTransformer backbone. + + Args: + img_size (int): Size of input image. + patch_size (int): Size of image patch. + in_chans (int): Number of input channels. + embed_dim (int): Dimension of token embeddings. + depth (int): Number of transformer blocks. + num_heads (int): Number of attention heads. + decoder_embed_dim (int): Dimension of decoder embeddings. + decoder_depth (int): Number of decoder transformer blocks. + decoder_num_heads (int): Number of decoder attention heads. + mlp_ratio (float): Ratio of MLP hidden layer size to embedding size. + norm_layer (torch.nn.LayerNorm): Normalization layer. + norm_pix_loss (bool): Whether to normalize pixel loss. + + References: + - timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm + - DeiT: https://github.com/facebookresearch/deit + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=1, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4.0, + norm_layer=nn.LayerNorm, + norm_pix_loss=False, + ): + super().__init__() + + # MAE encoder specifics + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False + ) # fixed sin-cos embedding + self.in_chans = in_chans + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) + for _ in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, decoder_embed_dim), + requires_grad=False, + ) # fixed sin-cos embedding + + self.decoder_blocks = nn.ModuleList( + [ + Block( + decoder_embed_dim, + decoder_num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) + for _ in range(decoder_depth) + ] + ) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear( + decoder_embed_dim, patch_size**2 * in_chans, bias=True + ) # decoder to patch + + self.norm_pix_loss = norm_pix_loss + + self.initialize_weights() + + def initialize_weights(self): + # Initialization + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + 
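# initialize_weights() above fills the frozen pos_embed buffers with fixed 2D
# sine-cosine embeddings rather than learning them. The helper in isolation,
# with the encoder defaults (embed_dim=1024, a 14x14 patch grid for 224/16):
from minerva.utils.position_embedding import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(1024, 14, cls_token=True)
print(pos_embed.shape)  # (197, 1024): one row per patch plus the [CLS] slot;
                        # returned as a NumPy array, hence torch.from_numpy above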
decoder_pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + self.decoder_pos_embed.data.copy_( + torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) + ) + + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + torch.nn.init.normal_(self.cls_token, std=0.02) + torch.nn.init.normal_(self.mask_token, std=0.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): # input: (32, 1, 224, 224) + """ + Extract patches from input images. + + Args: + imgs (torch.Tensor): Input images of shape (N, C, H, W). + + Returns: + torch.Tensor: Patches of shape (N, num_patches, patch_size^2 * in_chans). + """ + p = self.patch_embed.patch_size[0] + assert ( + imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + ) # only square images are supported, and the size must be divisible by the patch size + + h = w = imgs.shape[2] // p + x = imgs.reshape( + (imgs.shape[0], self.in_chans, h, p, w, p) + ) # Transform images into (32, 1, 14, 16, 14, 16) + x = torch.einsum("nchpwq->nhwpqc", x) # reshape into (32, 14, 14, 16, 16, 1) + x = x.reshape( + (imgs.shape[0], h * w, p**2 * self.in_chans) + ) # Transform into (32, 196, 256) + return x + + def unpatchify(self, x): + """ + Reconstruct images from patches. + + Args: + x (torch.Tensor): Patches of shape (N, L, patch_size^2 * in_chans). + + Returns: + torch.Tensor: Reconstructed images of shape (N, C, H, W). + """ + p = self.patch_embed.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape((x.shape[0], h, w, p, p, 3)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape((x.shape[0], 3, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + + Args: + x (torch.Tensor): Input tensor of shape (N, L, D). + mask_ratio (float): Ratio of values to mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Masked input, binary mask, shuffled indices. + """ + N, L, D = x.shape + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) + + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + mask = torch.ones(N, L, device=x.device) + mask[:, :len_keep] = 0 + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward_encoder(self, x, mask_ratio): + """ + Forward pass through the encoder. + + Args: + x (torch.Tensor): Input tensor of shape (N, C, H, W). + mask_ratio (float): Ratio of values to mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Encoded representation, binary mask, shuffled indices. 
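# random_masking() above keeps a random subset of patch tokens per sample and
# returns the index bookkeeping needed to undo the shuffle in the decoder.
# Its core, re-traced on dummy tokens (L=16 patches, D=64, mask_ratio=0.75):
import torch

x = torch.rand(2, 16, 64)                     # (N, L, D) patch tokens
len_keep = int(16 * (1 - 0.75))               # keep 4 tokens per sample
noise = torch.rand(2, 16)
ids_shuffle = torch.argsort(noise, dim=1)     # a random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)
x_masked = torch.gather(x, 1, ids_shuffle[:, :len_keep].unsqueeze(-1).repeat(1, 1, 64))
mask = torch.ones(2, 16)
mask[:, :len_keep] = 0                        # 0 = kept, 1 = masked
mask = torch.gather(mask, 1, ids_restore)     # back in original patch order
print(x_masked.shape, mask.sum(dim=1))        # torch.Size([2, 4, 64]); 12.0 masked per sample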
+ """ + x = self.patch_embed(x) + x = x + self.pos_embed[:, 1:, :] + + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore + + def forward_decoder(self, x, ids_restore): + """ + Forward pass through the decoder. + + Args: + x (torch.Tensor): Input tensor of shape (N, L, D). + ids_restore (torch.Tensor): Indices to restore the original order of patches. + + Returns: + torch.Tensor: Decoded output tensor of shape (N, L, patch_size^2 * in_chans). + """ + x = self.decoder_embed(x) + + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 + ) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) + x_ = torch.gather( + x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) + x = torch.cat([x[:, :1, :], x_], dim=1) + + x = x + self.decoder_pos_embed + + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + x = self.decoder_pred(x) + x = x[:, 1:, :] + + return x + + def forward_loss(self, imgs, pred, mask): + """ + Calculate the loss. + + Args: + imgs (torch.Tensor): Input images of shape (N, C, H, W). + pred (torch.Tensor): Predicted output of shape (N, L, patch_size^2 * in_chans). + mask (torch.Tensor): Binary mask of shape (N, L). + + Returns: + torch.Tensor: Computed loss value. + """ + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) + loss = (loss * mask).sum() / mask.sum() + return loss + + def forward(self, imgs, mask_ratio=0.75): + """ + Forward pass. + + Args: + imgs (torch.Tensor): Input images of shape (N, C, H, W). + mask_ratio (float): Ratio of values to mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Loss value, predicted output, binary mask. + """ + latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) + pred = self.forward_decoder(latent, ids_restore) + loss = self.forward_loss(imgs, pred, mask) + return loss, pred, mask + + def training_step(self, batch, batch_idx): + """ + Training step. + + Args: + batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels. + batch_idx (int): Index of the current batch. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing the loss value for the current step. + """ + imgs, _ = batch + loss, _, _ = self(imgs) + self.log("train_loss", loss) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + """ + Validation step. + + Args: + batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels. + batch_idx (int): Index of the current batch. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing the loss value for the current step. + """ + imgs, _ = batch + loss, _, _ = self(imgs) + self.log("val_loss", loss) + return {"val_loss": loss} + + def configure_optimizers(self): + """ + Configure optimizer. + + Returns: + torch.optim.Optimizer: Optimizer. 
+ """ + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + + +# Define model architectures + +# mae_vit_small_patch16_dec512d8b +# decoder: 512 dim, 8 blocks, depth: 6 +mae_vit_small_patch16 = partial( + MaskedAutoencoderViT, + patch_size=16, + embed_dim=768, + depth=6, + num_heads=12, + decoder_embed_dim=512, + decoder_depth=4, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), +) + +# mae_vit_base_patch16_dec512d8b +# decoder: 512 dim, 8 blocks, +mae_vit_base_patch16 = partial( + MaskedAutoencoderViT, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), +) + +# mae_vit_large_patch16_dec512d8b +# decoder: 512 dim, 8 blocks +mae_vit_large_patch16 = partial( + MaskedAutoencoderViT, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), +) + +# mae_vit_huge_patch14_dec512d8b +# decoder: 512 dim, 8 blocks +mae_vit_huge_patch14 = partial( + MaskedAutoencoderViT, + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), +) + +# mae_vit_large_patch16_dec256d4b +# decoder: 256 dim, 8 blocks +mae_vit_large_patch16D4d256 = partial( + MaskedAutoencoderViT, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=256, + decoder_depth=4, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), +) + + +# mae_vit_base_patch16_dec256d4b +mae_vit_base_patch16D4d256 = partial( + MaskedAutoencoderViT, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + decoder_embed_dim=256, + decoder_depth=4, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), +) diff --git a/tests/models/nets/test_cnn_ha_etal.py b/tests/models/nets/test_cnn_ha_etal.py index 8678624..c37d166 100644 --- a/tests/models/nets/test_cnn_ha_etal.py +++ b/tests/models/nets/test_cnn_ha_etal.py @@ -1,5 +1,7 @@ import torch -from minerva.models.nets.cnn_ha_etal import CNN_HaEtAl_1D, CNN_HaEtAl_2D + +from minerva.models.nets.cnns import CNN_HaEtAl_1D, CNN_HaEtAl_2D + def test_cnn_ha_etal_1d_forward(): input_shape = (1, 6, 60) @@ -18,4 +20,4 @@ def test_cnn_ha_etal_2d_forward(): x = torch.rand(1, *input_shape) y = model(x) - assert y is not None \ No newline at end of file + assert y is not None diff --git a/tests/models/nets/test_cnn_pf.py b/tests/models/nets/test_cnn_pf.py index f8656e9..04602eb 100644 --- a/tests/models/nets/test_cnn_pf.py +++ b/tests/models/nets/test_cnn_pf.py @@ -1,5 +1,7 @@ import torch -from minerva.models.nets.cnn_pf import CNN_PF_2D, CNN_PFF_2D + +from minerva.models.nets.cnns import CNN_PF_2D, CNN_PFF_2D + def test_cnn_pf_forward(): input_shape = (1, 6, 60) @@ -18,4 +20,4 @@ def test_cnn_ha_pff_forward(): x = torch.rand(1, *input_shape) y = model(x) - assert y is not None \ No newline at end of file + assert y is not None diff --git a/tests/models/nets/test_resnet_1d.py b/tests/models/nets/test_resnet_1d.py index 885262c..86c23aa 100644 --- a/tests/models/nets/test_resnet_1d.py +++ b/tests/models/nets/test_resnet_1d.py @@ -1,5 +1,7 @@ import torch -from minerva.models.nets.resnet_1d import ResNet1D_8, ResNetSE1D_8, ResNetSE1D_5 + +from minerva.models.nets.resnet 
import ResNet1D_8, ResNetSE1D_5, ResNetSE1D_8
+

 def test_resnet_1d_8_forward():
     input_shape = (6, 60)
@@ -13,8 +15,8 @@ def test_resnet_1d_8_forward():
     x = torch.rand(1, *input_shape)
     y = model(x)
     assert y is not None
- 
- 
+
+
 def test_resnet_se_1d_8_forward():
     input_shape = (6, 60)
     model = ResNetSE1D_8(
@@ -27,7 +29,8 @@ def test_resnet_se_1d_8_forward():
     x = torch.rand(1, *input_shape)
     y = model(x)
     assert y is not None
- 
+
+
 def test_resnet_se_1d_5_forward():
     input_shape = (6, 60)
     model = ResNetSE1D_5(
@@ -39,4 +42,4 @@ def test_resnet_se_1d_5_forward():
 
     x = torch.rand(1, *input_shape)
     y = model(x)
-    assert y is not None
\ No newline at end of file
+    assert y is not None
diff --git a/tests/models/nets/test_sfm.py b/tests/models/nets/test_sfm.py
index 46f6ed2..f9cb5d0 100644
--- a/tests/models/nets/test_sfm.py
+++ b/tests/models/nets/test_sfm.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from minerva.models.nets.sfm import (
+from minerva.models.nets.vit import (
     mae_vit_base_patch16,
     mae_vit_base_patch16D4d256,
     mae_vit_huge_patch14,