forked from etiennedupont/simple-object-detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
transforms.py
36 lines (28 loc) · 1.1 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch import nn, Tensor
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
from typing import Tuple, Dict, Optional
class Compose(object):
    """Chain several paired ``(image, target)`` transforms into one callable.

    Each transform is called as ``transform(image, target)`` and must return
    a two-item result that is fed to the next transform in the sequence.
    """

    def __init__(self, transforms):
        """Store the iterable of transforms to apply, in order."""
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every stored transform in order and return the final pair.

        Args:
            image: Value passed as the first argument to the first transform.
            target: Value passed as the second argument to the first transform.

        Returns:
            An ``(image, target)`` tuple holding the output of the last
            transform, or the inputs unchanged when there are no transforms.
        """
        sample = (image, target)
        for transform in self.transforms:
            # Unpack explicitly so a transform that does not return exactly
            # two items fails right here rather than further downstream.
            next_image, next_target = transform(*sample)
            sample = (next_image, next_target)
        return sample
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    """Random horizontal flip applied jointly to an image and its target.

    With probability ``self.p`` (inherited from the torchvision base class)
    the image is flipped with ``F.hflip`` and, when a target is supplied,
    columns 0 and 2 of ``target["boxes"]`` are mirrored about the image width.
    """

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Flip ``image`` (and the boxes in ``target``) with probability ``p``.

        Args:
            image: Image to flip. Non-tensor inputs are assumed to be PIL-like
                objects exposing ``.size == (width, height)`` -- TODO confirm
                against callers.
            target: Optional annotation dict. When given it must contain a
                ``"boxes"`` tensor; presumably shaped ``(N, 4)`` in
                ``[x_min, y_min, x_max, y_max]`` order -- verify against the
                dataset that produces it.

        Returns:
            The ``(image, target)`` pair. ``target`` is the same dict object
            that was passed in; its ``"boxes"`` tensor is updated in place
            when a flip happens.

        Raises:
            KeyError: If a flip happens and ``target`` has no ``"boxes"`` key.
        """
        if torch.rand(1) < self.p:
            image = F.hflip(image)
            if target is not None:
                # Read the width from the image itself rather than through
                # torchvision's private ``F._get_image_size`` helper, which is
                # not a stable API across torchvision releases.
                if isinstance(image, Tensor):
                    width = image.shape[-1]
                else:
                    width = image.size[0]
                boxes = target["boxes"]
                # An image without objects may carry an empty boxes tensor
                # that lacks a second dimension; there is nothing to mirror,
                # so skip the column indexing that would otherwise raise.
                if boxes.numel() > 0:
                    # Column 2 becomes the new column 0 and vice versa, so the
                    # smaller coordinate stays first after mirroring.
                    boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        return image, target
class ToTensor(nn.Module):
    """Paired transform that converts the image and passes the target through."""

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Convert ``image`` with ``F.to_tensor`` and return it with ``target``.

        Args:
            image: Input handed to ``F.to_tensor``.
                NOTE(review): annotated as ``Tensor``, but torchvision
                documents ``to_tensor`` for PIL images / ndarrays -- confirm
                what callers actually pass.
            target: Optional annotation dict; returned untouched.

        Returns:
            ``(converted_image, target)``.
        """
        converted = F.to_tensor(image)
        return converted, target