-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNNEnsemble.jl
38 lines (32 loc) · 1.49 KB
/
NNEnsemble.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
include("neuralnets.jl")
"""
    TCNEnsemble(models, optimizers)

Container holding a collection of models together with a collection of
optimizers.

Both fields are untyped, so any containers are accepted. `length` of an
ensemble is the length of `models`.
"""
struct TCNEnsemble
    models
    optimizers
end

"""
    length(m::TCNEnsemble)

Return the number of models held by the ensemble, i.e. `length(m.models)`.
"""
Base.length(m::TCNEnsemble) = length(m.models)

"""
    show(io::IO, m::TCNEnsemble)

Write a compact one-line summary of the ensemble to `io`, for example
`TCNEnsemble(N=3)`.
"""
function Base.show(io::IO, m::TCNEnsemble)
    # `io` must be annotated as `IO`: with an untyped first argument this method
    # is ambiguous with Base's generic `show(io::IO, x)` and every attempt to
    # display an ensemble raises an ambiguity `MethodError`.
    # `print` is used rather than `show` on a String, because `show` would wrap
    # the summary in double quotes.
    print(io, "TCNEnsemble(N=", length(m), ")")
    return nothing
end
"""
    train_ensemble!(ensemble::TCNEnsemble, dgp, dtY; kwargs...)

Train every member of `ensemble` in sequence.

The ensemble is first switched to training mode via `Flux.trainmode!`. Each
model is then passed, together with the optimizer at the same position, to
`train_cnn!` along with `dgp`, `dtY` and all keyword arguments.

When `validation_loss` is `true`, the second element of the value returned by
`train_cnn!` replaces the corresponding entry of `ensemble.models`
(presumably the best model found during validation; confirm against
`train_cnn!`). Otherwise the return value of `train_cnn!` is discarded.

# Keywords
All keywords are forwarded unchanged to `train_cnn!`: `epochs`, `batchsize`,
`passes_per_batch`, `dev`, `loss`, `validation_loss`, `validation_frequency`,
`validation_size`, `verbosity`, `transform`.

# Returns
`nothing`.
"""
function train_ensemble!(
    ensemble::TCNEnsemble, dgp, dtY;
    epochs=1_000, batchsize=32, passes_per_batch=2, dev=cpu, loss=rmse_conv,
    validation_loss=true, validation_frequency=10, validation_size=2_000, verbosity=1,
    transform=true
)
    Flux.trainmode!(ensemble)

    # Every member is trained with the same settings, so assemble them once.
    train_kwargs = (
        epochs=epochs,
        batchsize=batchsize,
        passes_per_batch=passes_per_batch,
        dev=dev,
        loss=loss,
        validation_loss=validation_loss,
        validation_frequency=validation_frequency,
        validation_size=validation_size,
        verbosity=verbosity,
        transform=transform,
    )

    for (idx, (net, opt)) in enumerate(zip(ensemble.models, ensemble.optimizers))
        outcome = train_cnn!(net, opt, dgp, dtY; train_kwargs...)
        if validation_loss
            # Only the second returned element is kept; it overwrites the member.
            _, best = outcome
            ensemble.models[idx] = best
        end
    end
    return nothing
end
"""
    Flux.trainmode!(e::TCNEnsemble)

Call `Flux.trainmode!` on each model in `e.models` and return the per-model
results collected into an array.
"""
function Flux.trainmode!(e::TCNEnsemble)
    return collect(Flux.trainmode!(net) for net in e.models)
end
"""
    Flux.testmode!(e::TCNEnsemble)

Call `Flux.testmode!` on each model in `e.models` and return the per-model
results collected into an array.
"""
function Flux.testmode!(e::TCNEnsemble)
    return collect(Flux.testmode!(net) for net in e.models)
end
"""
    (e::TCNEnsemble)(X)

Apply every model in `e.models` to `X` and return the `mean` of the outputs.
"""
function (e::TCNEnsemble)(X)
    member_outputs = (net(X) for net in e.models)
    return mean(member_outputs)
end