-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathResNet.R
84 lines (72 loc) · 2.63 KB
/
ResNet.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Residual MLP over categorical (and optionally numeric) features.
#
# Categorical feature ids are pooled by an nn_embedding_bag, optionally
# concatenated with numeric features, projected to `sizeHidden`, passed
# through `numLayers` ResLayer blocks, then normalized, activated and mapped
# to `d_out` outputs by a linear head.
#
# initialize() arguments:
#   catFeatures      number of categorical feature ids; the embedding table
#                    has catFeatures + 1 rows, row 1 being the padding row.
#   numFeatures      number of numeric features appended after the pooled
#                    embedding (0 = none).
#   sizeEmbedding    embedding width.
#   sizeHidden       width of the residual stream.
#   numLayers        number of ResLayer blocks (0 builds none).
#   hiddenFactor     multiplier giving each block's inner width
#                    (sizeHidden * hiddenFactor).
#   activation       module generator used for activations.
#   normalization    module generator used for normalization layers.
#   hiddenDropout    dropout probability inside each block, or NULL.
#   residualDropout  dropout probability on each block's branch, or NULL.
#   d_out            number of outputs produced by the head.
#
# forward() arguments:
#   x_cat  integer tensor of categorical ids where 0 means padding (ids are
#          shifted by +1 so 0 lands on padding_idx = 1); presumably
#          (batch, bag) -- TODO confirm against the caller.
#   x_num  optional numeric tensor with numFeatures columns, concatenated to
#          the pooled embedding along dim 2; needed when numFeatures > 0.
# Returns the head output with the trailing dimension squeezed.
ResNet <- torch::nn_module(
  name = "ResNet",
  initialize = function(catFeatures, numFeatures = 0, sizeEmbedding,
                        sizeHidden, numLayers, hiddenFactor,
                        activation = torch::nn_relu,
                        normalization = torch::nn_batch_norm1d,
                        hiddenDropout = NULL, residualDropout = NULL,
                        d_out = 1) {
    # Row 1 is reserved for padding; forward() shifts ids by +1 to match.
    self$embedding <- torch::nn_embedding_bag(
      num_embeddings = catFeatures + 1,
      embedding_dim = sizeEmbedding,
      padding_idx = 1
    )
    self$first_layer <- torch::nn_linear(
      sizeEmbedding + numFeatures,
      sizeHidden
    )
    resHidden <- sizeHidden * hiddenFactor
    # seq_len() so that numLayers = 0 yields no blocks (1:0 would yield two).
    self$layers <- torch::nn_module_list(lapply(
      seq_len(numLayers),
      function(i) {
        ResLayer(sizeHidden, resHidden, normalization, activation,
                 hiddenDropout, residualDropout)
      }
    ))
    self$lastNorm <- normalization(sizeHidden)
    self$head <- torch::nn_linear(sizeHidden, d_out)
    self$lastAct <- activation()
  },
  forward = function(x_cat, x_num = NULL) {
    x <- self$embedding(x_cat + 1L)
    # x_num is optional so forward(x_cat) callers keep working; when given,
    # numeric features are appended along the feature dimension.
    if (!is.null(x_num)) {
      x <- torch::torch_cat(list(x, x_num), dim = 2L)
    }
    x <- self$first_layer(x)
    # seq_len() keeps an empty layer list from indexing element 0.
    for (i in seq_len(length(self$layers))) {
      x <- self$layers[[i]](x)
    }
    x <- self$lastNorm(x)
    x <- self$lastAct(x)
    x <- self$head(x)
    x$squeeze(-1)
  }
)
# Pre-normalization residual block:
#   out = x + residualDropout(linear1(hiddenDropout(activation(linear0(norm(x))))))
# Each dropout step is skipped when its probability was NULL at construction.
#
# initialize() arguments:
#   sizeHidden       width of the block input/output (the residual stream).
#   resHidden        inner width between linear0 and linear1.
#   normalization    module generator, instantiated as normalization(sizeHidden).
#   activation       module generator, instantiated once as activation().
#   hiddenDropout    dropout probability applied after the activation, or NULL.
#   residualDropout  dropout probability applied before the residual add,
#                    or NULL.
ResLayer <- torch::nn_module(
  name = "ResLayer",
  initialize = function(sizeHidden, resHidden, normalization,
                        activation, hiddenDropout = NULL,
                        residualDropout = NULL) {
    self$norm <- normalization(sizeHidden)
    self$linear0 <- torch::nn_linear(sizeHidden, resHidden)
    self$linear1 <- torch::nn_linear(resHidden, sizeHidden)
    # Dropout modules are registered only when requested; forward() tests
    # for their presence with is.null().
    if (!is.null(hiddenDropout)) {
      self$hiddenDropout <- torch::nn_dropout(p = hiddenDropout)
    }
    if (!is.null(residualDropout)) {
      self$residualDropout <- torch::nn_dropout(p = residualDropout)
    }
    # Instantiated here (not stored as the bare generator) so forward() can
    # call it directly; kept last to preserve the original module order.
    self$activation <- activation()
  },
  forward = function(x) {
    z <- self$norm(x)
    z <- self$linear0(z)
    z <- self$activation(z)
    if (!is.null(self$hiddenDropout)) {
      z <- self$hiddenDropout(z)
    }
    z <- self$linear1(z)
    if (!is.null(self$residualDropout)) {
      z <- self$residualDropout(z)
    }
    z + x
  }
)