2 changes: 1 addition & 1 deletion model/__init__.py
@@ -1,2 +1,2 @@
from model.normalizer import NormMode
from model.alexnet import AlexNet
from model.alexnet import AlexNet, AlexNetConfig
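With AlexNetConfig re-exported from the package, the configuration can now be imported alongside the model and built directly as a dataclass. A minimal sketch with illustrative field values; the NormMode(1) call assumes the enum follows the (0) Contrast / (1) Local convention from utils/args.py:

```python
from model import AlexNet, AlexNetConfig, NormMode

config = AlexNetConfig(
    in_channels=3,            # illustrative values, not taken from this PR
    num_classes=1000,
    dropout=0.5,
    normalization_method=NormMode(1),  # assumption: 1 = local response normalization
    local_size=5,
)
model = AlexNet(config)
```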
154 changes: 96 additions & 58 deletions model/alexnet.py
@@ -1,10 +1,12 @@
from dataclasses import dataclass
from typing import Mapping, Optional, Union

import torch
from torch import nn
from torch.nn.modules.module import T
from torchvision.models.alexnet import AlexNet_Weights

from model.normalizer import Normalizer
from model.normalizer import Normalizer, NormMode
from utils import Config, get_device


@@ -14,15 +16,13 @@ class MaxPoolingResults:
feature_size: torch.Size


@dataclass
class AlexNetConfig:
# pylint: disable=too-few-public-methods

def __init__(self, config: Config):
self.in_channels = config.in_channels
self.num_classes = config.num_classes
self.dropout = config.dropout
self.normalization_method = config.normalization_method
self.local_size = config.local_size
in_channels: int
num_classes: int
dropout: float
normalization_method: NormMode
local_size: int


class AlexNet(nn.Module):
@@ -31,31 +31,32 @@ class AlexNet(nn.Module):
def __init__(self, config: AlexNetConfig):
super().__init__()
self._deconv_eval = False
self.pooling_indicies = []
self.pooling_indices = []
self.conv_activations = []
self.relu = nn.ReLU(inplace=True)
self.norm = Normalizer(config.normalization_method, config.local_size)

self.conv1 = nn.Conv2d(config.in_channels, 96, kernel_size=7, stride=2, padding=2)
self.conv1 = nn.Conv2d(config.in_channels, 64, kernel_size=11, stride=4, padding=2)
self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)

self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=2, padding=2)
self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2)
self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)

self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)
self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.pool5 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)

self.unpool5 = nn.MaxUnpool2d(kernel_size=3, stride=2)
self.deconv5 = nn.ConvTranspose2d(256, 384, kernel_size=3, stride=1, padding=1)
self.deconv4 = nn.ConvTranspose2d(384, 384, kernel_size=3, stride=1, padding=1)
self.deconv3 = nn.ConvTranspose2d(384, 256, kernel_size=3, stride=1, padding=1)
self.deconv5 = nn.ConvTranspose2d(256, 256, kernel_size=3, padding=1)
self.deconv4 = nn.ConvTranspose2d(256, 384, kernel_size=3, padding=1)
self.deconv3 = nn.ConvTranspose2d(384, 192, kernel_size=3, padding=1)

self.unpool2 = nn.MaxUnpool2d(kernel_size=3, stride=2)
self.deconv2 = nn.ConvTranspose2d(256, 96, kernel_size=5, stride=2, padding=2)
self.deconv2 = nn.ConvTranspose2d(192, 64, kernel_size=5, padding=2)

self.unpool1 = nn.MaxUnpool2d(kernel_size=3, stride=2)
self.deconv1 = nn.ConvTranspose2d(96, config.in_channels, kernel_size=7, stride=2, padding=2)
self.deconv1 = nn.ConvTranspose2d(64, config.in_channels, kernel_size=11, stride=4, padding=2)

self.classifier = nn.Sequential(
nn.Dropout(p=config.dropout),
@@ -83,41 +84,51 @@ def __initialize_deconv_layer(deconv: nn.Module, conv: nn.Module):
deconv.weight = conv.weight

def __clear_forward_cache(self):
self.pooling_indicies.clear()
self.conv_activations.clear()
self.pooling_indices.clear()

def __add_forward_cache_entry(self, idx, size):
def __add_forward_cache_entry(self, activations: torch.Tensor, idx: Optional[torch.Tensor] = None,
size: Optional[torch.Size] = None):
if self._deconv_eval:
self.pooling_indicies.append(MaxPoolingResults(idx, size))
self.conv_activations.append(activations)
if idx is not None and size is not None:
self.pooling_indices.append(MaxPoolingResults(idx, size))

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> Union[torch.Tensor, list[torch.Tensor]]:
self.__clear_forward_cache()

feat = self.conv1(x)
feat = self.relu(feat)
size1 = feat.size()
feat, idx1 = self.pool1(feat)
self.__add_forward_cache_entry(feat, idx1, size1)
feat = self.norm(feat)
self.__add_forward_cache_entry(idx1, size1)

feat = self.conv2(feat)
feat = self.relu(feat)
size2 = feat.size()
feat, idx2 = self.pool2(feat)
self.__add_forward_cache_entry(feat, idx2, size2)
feat = self.norm(feat)
self.__add_forward_cache_entry(idx2, size2)

feat = self.conv3(feat)
feat = self.relu(feat)
self.__add_forward_cache_entry(feat)

feat = self.conv4(feat)
feat = self.relu(feat)
self.__add_forward_cache_entry(feat)

feat = self.conv5(feat)
feat = self.relu(feat)
size5 = feat.size()
feat, idx5 = self.pool5(feat)
self.__add_forward_cache_entry(feat, idx5, size5)

if self._deconv_eval:
return self.conv_activations

feat = self.norm(feat)
self.__add_forward_cache_entry(idx5, size5)

feat = torch.flatten(feat, 1)
feat = self.classifier(feat)
@@ -129,42 +140,47 @@ def log_results(feat: torch.Tensor, deconv_layer: nn.ConvTranspose2d) -> dict:
layer_weights = deconv_layer.weight.clone()
return {"activation_maps": feat_maps.detach().cpu(), "kernel_weights": layer_weights.detach().cpu()}

def deconv_forward(self, x: torch.Tensor) -> dict:
if not self._deconv_eval or len(self.pooling_indicies) == 0:
def deconv_forward(self, x: torch.Tensor, layer_level: int = 5) -> dict:
if not self._deconv_eval or len(self.pooling_indices) == 0:
raise ValueError("Model not in deconv mode")

deconv_results = {}
with torch.no_grad():
feat = x

results5 = self.pooling_indicies[2]
idx5, size5 = results5.pooling_indices, results5.feature_size
feat = self.unpool5(feat, idx5, output_size=size5)
feat = self.relu(feat)
feat = self.deconv5(feat)
deconv_results["DeConv5"] = self.log_results(feat, self.deconv5)

feat = self.relu(feat)
feat = self.deconv4(feat)
deconv_results["DeConv4"] = self.log_results(feat, self.deconv4)

feat = self.relu(feat)
feat = self.deconv3(feat)
deconv_results["DeConv3"] = self.log_results(feat, self.deconv3)

results2 = self.pooling_indicies[1]
idx2, size2 = results2.pooling_indices, results2.feature_size
feat = self.unpool2(feat, idx2, output_size=size2)
feat = self.relu(feat)
feat = self.deconv2(feat)
deconv_results["DeConv2"] = self.log_results(feat, self.deconv2)

results1 = self.pooling_indicies[0]
idx1, size1 = results1.pooling_indices, results1.feature_size
feat = self.unpool2(feat, idx1, output_size=size1)
feat = self.relu(feat)
feat = self.deconv1(feat)
deconv_results["DeConv1"] = self.log_results(feat, self.deconv1)
if layer_level > 4:
results5 = self.pooling_indices[2]
idx5, size5 = results5.pooling_indices, results5.feature_size
feat = self.unpool5(feat, idx5, output_size=size5)
feat = self.relu(feat)
feat = self.deconv5(feat)
deconv_results["DeConv5"] = self.log_results(feat, self.deconv5)

if layer_level > 3:
feat = self.relu(feat)
feat = self.deconv4(feat)
deconv_results["DeConv4"] = self.log_results(feat, self.deconv4)

if layer_level > 2:
feat = self.relu(feat)
feat = self.deconv3(feat)
deconv_results["DeConv3"] = self.log_results(feat, self.deconv3)

if layer_level > 1:
results2 = self.pooling_indices[1]
idx2, size2 = results2.pooling_indices, results2.feature_size
feat = self.unpool2(feat, idx2, output_size=size2)
feat = self.relu(feat)
feat = self.deconv2(feat)
deconv_results["DeConv2"] = self.log_results(feat, self.deconv2)

if layer_level > 0:
results1 = self.pooling_indices[0]
idx1, size1 = results1.pooling_indices, results1.feature_size
feat = self.unpool2(feat, idx1, output_size=size1)
feat = self.relu(feat)
feat = self.deconv1(feat)
deconv_results["DeConv1"] = self.log_results(feat, self.deconv1)

return deconv_results

@@ -191,9 +207,31 @@ def save(self, path: str):
def load(self, path: str):
self.load_state_dict(torch.load(path))

def initialize_from_pretrained(self):
pretrained_weights = AlexNet_Weights.DEFAULT.get_state_dict()
self.__initialize_layer(self.conv1, pretrained_weights, "features.0")
self.__initialize_layer(self.conv2, pretrained_weights, "features.3")
self.__initialize_layer(self.conv3, pretrained_weights, "features.6")
self.__initialize_layer(self.conv4, pretrained_weights, "features.8")
self.__initialize_layer(self.conv5, pretrained_weights, "features.10")
self.__initialize_layer(self.classifier[1], pretrained_weights, "classifier.1")
self.__initialize_layer(self.classifier[4], pretrained_weights, "classifier.4")
self.__initialize_layer(self.classifier[6], pretrained_weights, "classifier.6")

@staticmethod
def __initialize_layer(layer: nn.Module, pretrained_state_dict: Mapping[str, torch.Tensor], pretrained_key: str):
layer.weight = nn.Parameter(pretrained_state_dict[pretrained_key + ".weight"])
layer.bias = nn.Parameter(pretrained_state_dict[pretrained_key + ".bias"])

@staticmethod
def get_model_from_config(config: Config) -> T:
alexnet_config = AlexNetConfig(config)
alexnet_config = AlexNetConfig(
in_channels=config.in_channels,
num_classes=config.num_classes,
dropout=config.dropout,
normalization_method=config.normalization_method,
local_size=config.local_size
)
model = AlexNet(alexnet_config)

if config.model_file:
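Taken together, the changes above let a model be built from a parsed Config, seeded with torchvision's pretrained AlexNet weights, and probed stage by stage through deconv_forward. A rough sketch of that workflow; the parse_args import path, the 224x224 input size, and the direct _deconv_eval toggle are assumptions, not something this diff confirms:

```python
import torch

from model import AlexNet
from utils.args import parse_args  # assumed import path for the argument parser

config = parse_args()
model = AlexNet.get_model_from_config(config)
model.initialize_from_pretrained()  # assumes the 1000-class torchvision head

model.eval()
model._deconv_eval = True  # assumption: the repo may expose a dedicated toggle for this
with torch.no_grad():
    x = torch.randn(1, config.in_channels, 224, 224, device=model.device)
    activations = model(x)  # in deconv mode, forward() returns the cached conv activations
    # layer_level=2 skips the DeConv5/4/3 stages and reconstructs from the pool2 output only
    maps = model.deconv_forward(activations[1], layer_level=2)

print(sorted(maps))  # ['DeConv1', 'DeConv2'], each holding activation maps and kernel weights
```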
4 changes: 4 additions & 0 deletions train/train.py
@@ -10,6 +10,7 @@
from model import AlexNet
from train.eval import test_model
from utils import Config
from torch.optim.lr_scheduler import StepLR


def train(config: Config):
@@ -19,6 +20,7 @@ def train(config: Config):
writer = SummaryWriter(config.result_path)
model = AlexNet.get_model_from_config(config)
optimizer = Adam(model.parameters())
lr_scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
device = model.device

train_loader = get_data_loader(
@@ -44,6 +46,8 @@

total_loss += loss.item()

lr_scheduler.step()

avg_loss = total_loss / len(train_loader.dataset)
writer.add_scalar("Loss/train", avg_loss, epoch)

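The scheduler added here decays the learning rate by a factor of 10 every 30 epochs. A small, self-contained sketch of that behaviour (the single dummy parameter and the bare loop are stand-ins, not the training loop above):

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = Adam(params)  # default lr=1e-3, as in train()
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(90):
    optimizer.step()      # stand-in for one epoch of training
    scheduler.step()
    if (epoch + 1) % 30 == 0:
        print(epoch + 1, scheduler.get_last_lr())  # prints ~1e-4, then ~1e-5, then ~1e-6
```

Note that with the default of 30 epochs in utils/args.py, the first decay only happens after the final epoch, so the schedule only takes effect on longer runs.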
4 changes: 2 additions & 2 deletions utils/args.py
@@ -38,7 +38,7 @@ def parse_args() -> Config:

parser.add_argument("--data", type=int, default=1, help="Dataset (1) IMAGENET, (2) CIFAR10, (3) CIFAR100")
parser.add_argument("--batch_size", type=int, default=128, help="Batch size to use for training and testing")
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for dataloader")
parser.add_argument("--num_workers", type=int, default=6, help="Number of workers for dataloader")
parser.add_argument("--epochs", type=int, default=30, help="Number of epochs for training")
parser.add_argument("--seed", type=int, default=0, help="Random seed used for reproducibility")
parser.add_argument("--train", type=bool, default=False, help="Train or test the model")
@@ -50,7 +50,7 @@ def parse_args() -> Config:
parser.add_argument("--dropout", type=float, default=0.5, help="Dropout probability used for training")
parser.add_argument("--normalization_method", type=int, default=1,
help="Normalization method (0) Contrast, (1) Local")
parser.add_argument("--local_size", type=int, default=2, help="Local size for local response normalization")
parser.add_argument("--local_size", type=int, default=5, help="Local size for local response normalization")

return init_config(parser.parse_args())
