models.py
import torch
from torch import nn
from torchvision import models
import torch.nn.functional as F
# Reference: https://arxiv.org/pdf/1804.06208
class CardPoseCNN(nn.Module):
    def __init__(self, filters=(64, 128, 256, 384), map_size=(15, 20), temperature=0.5, n_bins=52):
        super().__init__()
        self.n_bins = n_bins
        self.temperature = temperature
        self.map_size = map_size
        # Down path (encoder): plain 3x3 convolutions
        self.conv1 = nn.Conv2d(3, filters[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1)
        # Up path (decoder): bilinear upsampling + 1x1 convolutions over skip connections
        self.upsample = nn.Upsample(mode='bilinear', scale_factor=2, align_corners=True)
        self.up_conv1 = nn.Conv2d(filters[3] + filters[1], filters[2], kernel_size=1)
        self.up_conv2 = nn.Conv2d(filters[2] + filters[0], filters[1], kernel_size=1)
        self.up_conv3 = nn.Conv2d(filters[1], 8, kernel_size=1)
        # Pooling and activations
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.adaptAvgPool = nn.AdaptiveAvgPool2d(1)
        self.adaptMaxPool = nn.AdaptiveMaxPool2d(self.map_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)
        self.flat = nn.Flatten()
        # Regression head: pooled encoder features + pooled mask features
        self.fc1 = nn.Linear((map_size[0] * map_size[1] * 8) + filters[3], 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc_out_detect = nn.Linear(256, 8)
        self.fc_out_ncards = nn.Linear(256, self.n_bins) if n_bins > 0 else nn.Linear(256, 1)

    def forward(self, x):
        # Encoder
        x = self.conv1(x)
        x_conv1 = self.relu(x)
        x_conv1 = self.maxpool(x_conv1)  # H/2
        x = self.conv2(x_conv1)
        x_conv2 = self.relu(x)
        x_conv2 = self.maxpool(x_conv2)  # H/4
        x = self.conv3(x_conv2)
        x_conv3 = self.relu(x)
        x_conv3 = self.maxpool(x_conv3)  # H/8
        x = self.conv4(x_conv3)
        x_conv4 = self.relu(x)
        # Decoder: predict the mask of the card pack (an 8-channel, half-resolution map, e.g. 32 x 48)
        x_mask = self.upsample(x_conv4)
        x_mask = torch.cat([x_mask, x_conv2], dim=1)
        x_mask = self.up_conv1(x_mask)
        x_mask = self.upsample(x_mask)
        x_mask = torch.cat([x_mask, x_conv1], dim=1)
        x_mask = self.up_conv2(x_mask)
        x_mask = self.up_conv3(x_mask)
        # Temperature-scaled spatial softmax over each of the 8 channels
        x_mask = x_mask / self.temperature
        # x_mask = self.sigmoid(x_mask)  # alternative activation, depending on the loss used
        x_mask = self.softmax(x_mask.view(*x_mask.size()[:2], -1)).view_as(x_mask)
        # Point visibility and number of cards:
        # combine pooled encoder features with the detached, pooled mask prediction
        x_reg = torch.cat([self.flat(self.adaptAvgPool(x_conv4)), self.flat(self.adaptMaxPool(x_mask.detach()))], dim=1)
        # Nonlinearities so the stacked FC layers do not collapse into a single linear map
        x_reg = self.relu(self.fc1(x_reg))
        x_reg = self.relu(self.fc2(x_reg))
        # Point visibility (per-point detected: yes/no)
        out_visibility = self.sigmoid(self.fc_out_detect(x_reg))
        if self.n_bins > 0:
            # Number of cards as temperature-scaled logits over n_bins (a distribution rather than a single scalar)
            # Based on:
            # - https://indatalabs.com/blog/head-pose-estimation-with-cv
            # - https://arxiv.org/pdf/1710.00925.pdf
            out_ncards = self.fc_out_ncards(x_reg) / self.temperature
        else:
            # A single number: the fraction of cards, between 0 and 1
            out_ncards = self.sigmoid(self.fc_out_ncards(x_reg))
        return x_mask, out_visibility, out_ncards
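

# --- Sketch: decoding the raw outputs at inference time ----------------------
# The helper below is not part of the original model; it is one way the three
# outputs could be decoded, following the soft-argmax / binned-expectation
# ideas referenced in the comments above. The input size (3 x 64 x 96) and the
# mapping "bin index i <-> i cards" are assumptions for illustration only.
def decode_outputs(heatmaps, visibility, ncard_logits, n_bins=52):
    """Turn raw network outputs into keypoint coordinates, visibility flags and a card count."""
    b, k, h, w = heatmaps.shape
    # Expected (x, y) per channel: each channel already sums to 1 thanks to the spatial softmax
    ys = torch.arange(h, dtype=heatmaps.dtype, device=heatmaps.device)
    xs = torch.arange(w, dtype=heatmaps.dtype, device=heatmaps.device)
    exp_y = (heatmaps.sum(dim=3) * ys).sum(dim=2)    # (b, k)
    exp_x = (heatmaps.sum(dim=2) * xs).sum(dim=2)    # (b, k)
    keypoints = torch.stack([exp_x, exp_y], dim=-1)  # (b, k, 2), in heatmap pixel units
    # Visibility: the network already applies a sigmoid, so 0.5 is a natural threshold
    visible = visibility > 0.5
    # Card count: expectation over the bin distribution (assumes bin i <-> i cards)
    probs = torch.softmax(ncard_logits, dim=1)
    n_cards = (probs * torch.arange(n_bins, dtype=probs.dtype, device=probs.device)).sum(dim=1)
    return keypoints, visible, n_cards


if __name__ == "__main__":
    model = CardPoseCNN()
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 64, 96)  # H and W assumed divisible by 8 so the skip connections align
        mask, vis, ncards_logits = model(dummy)
    print(mask.shape, vis.shape, ncards_logits.shape)  # (1, 8, 32, 48), (1, 8), (1, 52)
    kps, visible, n_cards = decode_outputs(mask, vis, ncards_logits)
    print(kps.shape, visible.shape, n_cards.shape)     # (1, 8, 2), (1, 8), (1,)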
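

# --- Sketch: one possible loss combination -----------------------------------
# The commented-out sigmoid in forward() hints that the loss choice is still
# open. A plausible pairing for the current activations (spatial softmax on the
# mask, sigmoid on visibility, raw logits on the bins) is sketched here; the
# target shapes and the unweighted sum are assumptions, not from this file.
def compute_loss(mask_pred, vis_pred, ncards_logits, mask_target, vis_target, ncards_target):
    # mask_pred / mask_target: (B, 8, H, W), each target channel normalised to sum to 1
    # vis_pred / vis_target:   (B, 8), values in [0, 1]
    # ncards_logits:           (B, n_bins); ncards_target: (B,) integer bin indices
    eps = 1e-8
    # Cross-entropy between target and predicted spatial distributions, per channel
    mask_loss = -(mask_target * torch.log(mask_pred + eps)).sum(dim=(2, 3)).mean()
    # Visibility: plain BCE because the model already applies a sigmoid
    vis_loss = F.binary_cross_entropy(vis_pred, vis_target)
    # Card count: standard classification loss over the temperature-scaled logits
    ncards_loss = F.cross_entropy(ncards_logits, ncards_target)
    return mask_loss + vis_loss + ncards_loss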