TCN.jl
# -----------------------------------------
# Temporal Convolutional Network in Flux.jl
# Author: Jonathan Chassot, May 17, 2022
# -----------------------------------------
# Reference:
# Shaojie Bai, J. Zico Kolter, Vladlen Koltun. (2018)
# An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling
# https://arxiv.org/abs/1803.01271
# -----------------------------------------
# Note that this gist uses batch normalization instead of the weight normalization used in the original paper.
# Also note that `SamePad()` pads symmetrically, so the convolutions below are centered rather than strictly
# causal; left-only padding would be needed to match the paper's causal convolutions exactly.
# -----------------------------------------

using Flux

# Temporal blocks which compose the layers of the TCN
function TemporalBlock(
chan_in::Int, chan_out::Int;
dilation::Int, kernel_size::Int,
residual::Bool = true, pad = SamePad()
)
# Causal convolutions
causal_conv = Chain(
Conv((1, kernel_size), chan_in => chan_out, dilation = dilation,
pad = pad),
BatchNorm(chan_out, relu),
Conv((1, kernel_size), chan_out => chan_out, dilation = dilation,
pad = pad),
BatchNorm(chan_out, relu),
)
residual || return causal_conv
# Skip connection (residual net)
residual_conv = Conv((1, 1), chan_in => chan_out)
Chain(
Parallel(+, causal_conv, residual_conv),
x -> relu.(x)
)
end
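
# Minimal usage sketch for a single block (the batch size, sequence length, and
# random input below are illustrative assumptions, not part of the original gist).
# `Conv((1, k), ...)` treats a sequence as a 1×T "image", so inputs use the
# (1, time, channels, batch) layout:
#
#     block = TemporalBlock(1, 8, dilation = 1, kernel_size = 3)
#     x = randn(Float32, 1, 100, 1, 16)  # 16 sequences of length 100, 1 channel
#     size(block(x))                     # (1, 100, 8, 16) with `SamePad()`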
# Temporal Convolutional Network with `length(channels) - 1` layers
# e.g., `TCN([1, 8, 8, 1], kernel_size = 3)` constructs a TCN with 3 TemporalBlock layers:
# 1.) 1 => 8, dilation = 2⁰ = 1
# 2.) 8 => 8, dilation = 2¹ = 2
# 3.) 8 => 1, dilation = 2² = 4
# each of them with `kernel_size = 3`
function TCN(
channels::AbstractVector{Int};
kernel_size::Int, dilation_factor::Int = 2,
residual::Bool = true, pad = SamePad()
)
Chain([TemporalBlock(chan_in, chan_out, dilation = dilation_factor ^ (i - 1),
kernel_size = kernel_size, residual = residual,
pad = pad)
for (i, (chan_in, chan_out)) ∈ enumerate(zip(channels[1:end-1], channels[2:end]))]...)
end
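
# Usage sketch for the full network (sizes are illustrative assumptions): a TCN
# mapping 1 input channel to 1 output channel through two hidden layers of 8
# channels each, matching the docstring example above:
#
#     model = TCN([1, 8, 8, 1], kernel_size = 3)
#     x = randn(Float32, 1, 100, 1, 16)  # (1, time, channels, batch)
#     size(model(x))                     # (1, 100, 1, 16)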
# Computes the receptive field size for a specified dilation, kernel size, and number of layers
receptive_field_size(dilation::Int, kernel_size::Int, layers::Int) =
1 + (kernel_size - 1) * (dilation ^ layers - 1) / (dilation - 1)
# Minimum number of layers necessary to achieve a specified receptive field size
# (take ceil(Int, necessary_layers(...)) for the final number of layers).
# This is the exact inverse of `receptive_field_size`:
# solving rf = 1 + (k - 1)(d^L - 1)/(d - 1) for L gives L = log_d(1 + (rf - 1)(d - 1)/(k - 1))
necessary_layers(dilation::Int, kernel_size::Int, receptive_field::Int) =
    log(dilation, 1 + (receptive_field - 1) * (dilation - 1) / (kernel_size - 1))
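
# Quick numeric sanity check (assumed example values): with dilation 2, kernel
# size 3, and 3 layers, the receptive field is 1 + 2 * (2³ - 1) = 15, and
# inverting recovers the layer count:
#
#     receptive_field_size(2, 3, 3)          # 15.0
#     ceil(Int, necessary_layers(2, 3, 15))  # 3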