Merge pull request #61 from LovePelmeni/schedulers
Schedulers
LovePelmeni authored Jun 13, 2024
2 parents 6b37dd1 + 3e724bd commit 932fc65
Showing 6 changed files with 517 additions and 23 deletions.
55 changes: 51 additions & 4 deletions src/preprocessing/crop_augmentations.py
@@ -4,8 +4,10 @@

cv2.setNumThreads(0)


def get_train_crop_augmentations(CROP_IMAGE_SIZE: int) -> albumentations.Compose:
def get_train_crop_augmentations(
crop_image_size: int,
normalization_means: typing.Tuple[float, float, float],
normalization_stds: typing.Tuple[float, float, float]) -> albumentations.Compose:
"""
Returns augmentations for training data
@@ -51,11 +53,18 @@ def get_train_crop_augmentations(CROP_IMAGE_SIZE: int) -> albumentations.Compose
                target_size=crop_image_size
),
], p=1),
albumentations.Normalize(
mean=normalization_means,
std=normalization_stds
)
]
)


def get_validation_crop_augmentations(CROP_IMAGE_SIZE: int) -> albumentations.Compose:
def get_validation_crop_augmentations(
crop_image_size: int,
normalization_means: typing.Tuple[float, float, float],
normalization_stds: typing.Tuple[float, float, float]) -> albumentations.Compose:
"""
    Returns augmentations for validation data
@@ -76,6 +85,44 @@ def get_validation_crop_augmentations(CROP_IMAGE_SIZE: int) -> albumentations.Compose
interpolation_up=cv2.INTER_LINEAR,
                target_size=crop_image_size
),
albumentations.HorizontalFlip(p=0.5)
albumentations.HorizontalFlip(p=0.5),
albumentations.Normalize(
mean=normalization_means,
std=normalization_stds
)
]
)

def get_inference_augmentations(
    crop_image_size: int,
    normalization_means: typing.Tuple[float, float, float],
    normalization_stds: typing.Tuple[float, float, float]) -> albumentations.Compose:
    """
    Set of inference preprocessing transformations.
    Parameters:
    -----------
    crop_image_size (int) - size of the cropped image
    normalization_means (tuple) - per-channel means used for normalization
    normalization_stds (tuple) - per-channel stds used for normalization
    """
    return albumentations.Compose(
        transforms=[
            # pick exactly one resize strategy at random,
            # then always apply normalization
            albumentations.OneOf(
                transforms=[
                    resize.IsotropicResize(
                        target_size=crop_image_size,
                        interpolation_up=cv2.INTER_LINEAR,
                        interpolation_down=cv2.INTER_CUBIC
                    ),
                    resize.IsotropicResize(
                        target_size=crop_image_size,
                        interpolation_up=cv2.INTER_LINEAR,
                        interpolation_down=cv2.INTER_NEAREST
                    ),
                    resize.IsotropicResize(
                        target_size=crop_image_size,
                        interpolation_up=cv2.INTER_CUBIC,
                        interpolation_down=cv2.INTER_NEAREST
                    ),
                ], p=1),
            albumentations.Normalize(
                mean=normalization_means,
                std=normalization_stds
            )
        ]
    )
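
A minimal usage sketch for these augmentation factories. The 224-pixel crop size, ImageNet normalization statistics, image path, and import path below are assumptions for illustration, not part of the commit:

import cv2
from src.preprocessing.crop_augmentations import get_train_crop_augmentations

# ImageNet statistics, a common default for pretrained encoders (assumed here)
means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
transform = get_train_crop_augmentations(
    crop_image_size=224,
    normalization_means=means,
    normalization_stds=stds
)
image = cv2.imread("face_crop.jpg")  # placeholder path
augmented = transform(image=image)["image"]  # albumentations keyword API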
147 changes: 128 additions & 19 deletions src/training/classifiers/classifiers.py
@@ -6,6 +6,13 @@
tf_efficientnetv2_b2,
_cfg
)
from torchvision.models import (
resnet50,
resnet101,
ResNet50_Weights,
ResNet101_Weights
)

from functools import partial
import typing

@@ -21,6 +28,17 @@
}
}

resnet_encoder_params = {
    "resnet_101": {
        "features": 2048,
        # weights must be a concrete enum member, not the enum class itself
        "encoder": partial(resnet101, weights=ResNet101_Weights.DEFAULT),
    },
    "resnet_50": {
        "features": 2048,
        "encoder": partial(resnet50, weights=ResNet50_Weights.DEFAULT)
    }
}


class DeepfakeClassifier(nn.Module):
"""
@@ -41,35 +59,39 @@ class DeepfakeClassifier(nn.Module):

def __init__(self, # 256x256
input_channels: int,
encoder_name: str
encoder_name: str,
device: str = 'cpu'
):
super(DeepfakeClassifier, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=input_channels,
            kernel_size=3,  # assumed value; nn.Conv2d requires an explicit kernel_size
            bias=False
        ).to(device)
self.encoder = encoder_params[encoder_name]['encoder']()
self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1))
self.dropout1 = nn.Dropout()
self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1)).to(device)
self.dropout1 = nn.Dropout().to(device)
self.dense1 = nn.Linear(
in_features=encoder_params[encoder_name]['features'],
out_features=128,
bias=True
bias=True,
device=device
)
self.relu1 = nn.ReLU()
self.relu1 = nn.ReLU().to(device)
self.dense2 = nn.Linear(
in_features=128,
out_features=64,
bias=True
bias=True,
device=device
)
self.relu2 = nn.ReLU()
self.relu2 = nn.ReLU().to(device)
self.dense3 = nn.Linear(
in_features=64,
out_features=1,
bias=True
bias=True,
device=device
)
self.sigmoid = nn.Sigmoid()
self.sigmoid = nn.Sigmoid().to(device)

def forward(self, input_map: torch.Tensor):
output = self.conv1(input_map)
@@ -96,7 +118,8 @@ def __init__(self,
encoder_name: str = list(encoder_params.keys())[-1],
num_classes: int = 2,
encoder_pretrained_config: typing.Dict = None,
dropout_rate: float = 0.5
dropout_rate: float = 0.5,
device: str = 'cpu'
):
super(DeepfakeClassifierSRM, self).__init__()

@@ -113,29 +136,32 @@ def __init__(self,
pretrained_cfg = None

self.encoder_name = encoder_name
self.srm_conv = srm_conv.SRMConv(in_channels=input_channels)
self.srm_conv = srm_conv.SRMConv(in_channels=input_channels).to(device)
self.encoder = encoder_params[encoder_name]['encoder'](
pretrained_cfg=_cfg(pretrained_cfg))
self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1)) # x x 1 x 1
self.dropout1 = nn.Dropout(p=dropout_rate)
self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1)).to(device) # x x 1 x 1
self.dropout1 = nn.Dropout(p=dropout_rate).to(device)
self.dense1 = nn.Linear(
in_features=encoder_params[encoder_name]['features'],
out_features=64,
bias=True,
device=device
)
self.relu1 = nn.ReLU()
self.relu1 = nn.ReLU().to(device)
self.dense2 = nn.Linear(
in_features=64,
out_features=32,
bias=True
bias=True,
device=device
)
self.relu2 = nn.ReLU()
self.relu2 = nn.ReLU().to(device)
self.dense3 = nn.Linear(
in_features=32,
out_features=num_classes,
bias=True
bias=True,
device=device
)
self.softmax = nn.Softmax(dim=1)
self.softmax = nn.Softmax(dim=1).to(device)

def forward(self, input_map: torch.Tensor):
noise = self.srm_conv(input_map)
@@ -248,3 +274,86 @@ def forward(self, image: torch.Tensor):
output = self.dense2(output)
probs = self.sigmoid(output)
return probs

class DistilDeepfakeClassifierSRM(nn.Module):
"""
Implementation of the light weight
encoder network with reduced number of parameters.
This architecture is designed to serve as a distilled
version of the original DeepfakeClassifierSRM.
Parameters:
-----------
encoder_name - name of the resnet encoder.
dropout_prob - dropout probability.
num_classes - number of output classes.
input_channels - number of channels in the input images.
either 1 (for grayscale or binary images) or 3 (for RGB images).
NOTE:
the output of this model contains raw logits.
If you want to convert them to probability distribution,
make sure to apply nn.Softmax(dim=-1).
"""
def __init__(self,
input_channels: int,
encoder_name: str,
num_classes: int,
dropout_prob: float = 0.5
):
super(DistilDeepfakeClassifierSRM, self).__init__()
self.encoder_name = encoder_name
self.encoder_out_features = resnet_encoder_params[encoder_name]['features']

self.srm_conv = srm_conv.SRMConv(in_channels=input_channels)
self.encoder = resnet_encoder_params[encoder_name]['encoder']()

        if not hasattr(self.encoder, 'fc'):
            raise ValueError(
                "invalid feature extraction model: "
                "expected a torchvision.models resnet with an 'fc' head"
            )
        # replace the classification head so the encoder returns
        # pooled, flattened features of shape (batch, features)
        self.encoder.fc = nn.Identity()

self.dense_block1 = nn.Sequential(
nn.Dropout(p=dropout_prob),
nn.Linear(
in_features=self.encoder_out_features,
out_features=self.encoder_out_features//2,
bias=True
),
nn.BatchNorm1d(num_features=self.encoder_out_features//2),
nn.ReLU()
)
self.dense_block2 = nn.Sequential(
nn.Dropout(p=dropout_prob),
nn.Linear(
in_features=self.encoder_out_features//2,
out_features=self.encoder_out_features//4,
bias=True
),
nn.BatchNorm1d(
num_features=self.encoder_out_features//4,
track_running_stats=True
),
nn.ReLU()
)
self.dropout3 = nn.Dropout(p=dropout_prob)
self.proj_head = nn.Linear(
in_features=self.encoder_out_features//4,
out_features=num_classes,
bias=False
)

    def forward(self, input_images: torch.Tensor):
        out_srm_features = self.srm_conv(input_images)
        # torchvision resnets already apply global average pooling and
        # flattening before the (identity) fc head, so the encoder output
        # can be passed straight to the dense blocks
        encoder_features = self.encoder(out_srm_features)
        hidden = self.dense_block1(encoder_features)
        hidden = self.dense_block2(hidden)
        hidden = self.dropout3(hidden)
        proj_features = self.proj_head(hidden)
        return proj_features
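
A minimal smoke-test sketch for the distilled classifier; the batch shape and class count are illustrative, and the example assumes the class is importable from this module:

import torch

model = DistilDeepfakeClassifierSRM(
    input_channels=3,
    encoder_name="resnet_50",
    num_classes=2
)
model.eval()  # BatchNorm layers use running statistics at inference
with torch.no_grad():
    batch = torch.randn(4, 3, 224, 224)   # illustrative batch of RGB crops
    logits = model(batch)                  # raw logits, shape (4, 2)
    probs = torch.softmax(logits, dim=-1)  # see the NOTE in the docstring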
7 changes: 7 additions & 0 deletions src/training/datasets/datasets.py
@@ -5,6 +5,8 @@
import torch
import logging

from src.exceptions import exceptions

logger = logging.getLogger("dataset_logger")


@@ -46,6 +48,11 @@ def __init__(self,
self.image_labels = image_labels
self.data_type = data_type

        # eagerly decode every image once up front: cv2.imread returns
        # None for missing or unreadable files, so bad paths fail fast
        if any(cv2.imread(img, cv2.IMREAD_UNCHANGED) is None for img in image_paths):
            raise exceptions.InvalidSourceError(
                msg='some of the image paths provided are not valid'
            )

def __len__(self):
return len(self.image_paths)

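A sketch of how the new eager path check surfaces bad inputs. `DeepfakeDataset` is a placeholder name for the dataset class whose __init__ is shown above, and the paths are illustrative:

from src.exceptions import exceptions

try:
    dataset = DeepfakeDataset(  # hypothetical class name
        image_paths=["crops/real_0.png", "crops/fake_0.png"],  # placeholder paths
        image_labels=[0, 1],
        data_type="train"
    )
except exceptions.InvalidSourceError:
    print("at least one image path could not be decoded")
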
43 changes: 43 additions & 0 deletions src/training/schedulers/lr_schedulers.py
@@ -117,3 +117,46 @@ def get_lr(self):
return self.base_lrs

return [base_lr * torch.exp(-self.gamma * self.last_epoch) for base_lr in self.base_lrs]

# -------------------
# Warmup LR schedulers.

class LinearWarmup(_LRScheduler):
"""
Implementation of the Linear Warmup
Learning Rate Scheduler.
Parameters:
-----------
num_iters: int - total number of iterations.
min_lr: float - minimum learning rate.
t_mult: int - number of iterations * iters[-1], which determines
next time update is going to be made.
"""
def __init__(self,
        optimizer: torch.optim.Optimizer,
num_iters: int,
init_lr: float,
target_lr: float,
):
self.num_iters: int = num_iters
self.init_lr: float = init_lr
self.target_lr: float = target_lr
self.curr_iter: int = 0
super(LinearWarmup, self).__init__(optimizer=optimizer)

def step(self, curr_step: int = 1):
self.curr_iter = curr_step
super(LinearWarmup, self).step()

    def get_lr(self):
        # clamp progress at 1.0 so the learning rate stops growing
        # once the warmup phase is complete
        progress = min(self.curr_iter, self.num_iters) / self.num_iters
        return [
            base_lr + progress * (self.target_lr - self.init_lr)
            for base_lr in self.base_lrs
        ]
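
A minimal sketch of driving LinearWarmup once per training iteration; the model, optimizer settings, and iteration count are illustrative:

import torch

model = torch.nn.Linear(10, 2)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
scheduler = LinearWarmup(
    optimizer=optimizer,
    num_iters=1000,
    init_lr=1e-4,
    target_lr=1e-2
)
for it in range(1, 1001):
    # ... forward pass, loss.backward(), optimizer.step() ...
    scheduler.step(curr_step=it)  # pass the global iteration explicitly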
