RecurrentLayers.jl extends Flux.jl's recurrent-layer offering by providing implementations of bleeding-edge recurrent layers not commonly available in base deep learning libraries. It is designed for seamless integration with the larger Flux ecosystem, enabling researchers and practitioners to leverage the latest developments in recurrent neural networks.
Currently available layers, along with short-term work in progress (a minimal usage sketch follows the list):
- Minimal gated unit (MGU) arxiv
- Light gated recurrent unit (LiGRU) arxiv
- Independently recurrent neural networks (IndRNN) arxiv
- Recurrent additive networks (RAN) arxiv
- Recurrent highway network (RHN) arxiv
- Light recurrent unit (LightRU) pub
- Neural architecture search unit (NAS) arxiv
- Minimal gated recurrent unit (minGRU) and minimal long short-term memory (minLSTM) arxiv
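All of these cells follow the standard Flux recurrent interface. As a minimal sketch, mirroring the MGU usage from the workflow example further down (the in => out constructor, the (features, sequence length, batch) input layout, and the initial-state argument are taken from that example; the output shape is an assumption inferred from the example's indexing of the returned states):
using RecurrentLayers
using Flux

mgu = MGU(3 => 16)              # input size 3, hidden size 16
x = rand(Float32, 3, 20, 8)     # (features, sequence length, batch)
h0 = zeros(Float32, 16)         # initial hidden state
states = mgu(x, h0)             # hidden states for every step, assumed (16, 20, 8)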
RecurrentLayers.jl is not yet registered. You can install it directly from the GitHub repository:
using Pkg
Pkg.add(url="https://github.com/MartinuzziFrancesco/RecurrentLayers.jl")
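Equivalently, from the Pkg REPL mode (press ] at the Julia prompt):
pkg> add https://github.com/MartinuzziFrancesco/RecurrentLayers.jl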
The workflow is identical to that of any other recurrent Flux layer:
using RecurrentLayers
using Flux
using MLUtils: DataLoader
using Statistics
using Random
# Create dataset
function create_data(input_size, seq_length::Int, num_samples::Int)
data = randn(input_size, seq_length, num_samples) #(input_size, seq_length, num_samples)
labels = sum(data, dims=(1, 2)) .>= 0 # label 1 if the sequence sums to a nonnegative value
labels = Int.(labels)
labels = dropdims(labels, dims=1) # shape (1, num_samples), matching the model's (1, batch) output
return data, labels
end
function create_dataset(input_size, seq_length, n_train::Int, n_test::Int, batch_size)
train_data, train_labels = create_data(input_size, seq_length, n_train)
train_loader = DataLoader((train_data, train_labels), batchsize=batch_size, shuffle=true)
test_data, test_labels = create_data(input_size, seq_length, n_test)
test_loader = DataLoader((test_data, test_labels), batchsize=batch_size, shuffle=false)
return train_loader, test_loader
end
struct RecurrentModel{H,C,D}
h0::H
rnn::C
dense::D
end
Flux.@layer RecurrentModel trainable=(rnn, dense)
function RecurrentModel(input_size::Int, hidden_size::Int)
return RecurrentModel(
zeros(Float32, hidden_size),
MGU(input_size => hidden_size),
Dense(hidden_size => 1, sigmoid))
end
function (model::RecurrentModel)(inp)
state = model.rnn(inp, model.h0) # hidden states for the whole sequence: (hidden_size, seq_length, batch)
state = state[:, end, :] # keep only the last time step
output = model.dense(state)
return output
end
function criterion(model, batch_data, batch_labels)
y_pred = model(batch_data)
loss = Flux.binarycrossentropy(y_pred, batch_labels)
return loss
end
function train_recurrent!(epoch, num_epochs, train_loader, opt, model, criterion)
total_loss = 0.0
for (batch_data, batch_labels) in train_loader
# Compute gradients and update parameters
grads = gradient(() -> criterion(model, batch_data, batch_labels), Flux.params(model))
Flux.Optimise.update!(opt, Flux.params(model), grads)
# Accumulate loss
total_loss += criterion(model, batch_data, batch_labels)
end
avg_loss = total_loss / length(train_loader)
println("Epoch $epoch/$num_epochs, Loss: $(round(avg_loss, digits=4))")
end
function test_recurrent(test_loader, model)
# Evaluation
correct = 0
total = 0
for (batch_data, batch_labels) in test_loader
# Forward pass
predicted = model(batch_data)
# Decode predictions: convert probabilities to class labels (0 or 1)
predicted_labels = vec(predicted .>= 0.5) # Threshold at 0.5 for binary classification
# Compare predicted labels to actual labels
correct += sum(predicted_labels .== vec(batch_labels))
total += length(batch_labels)
end
accuracy = correct / total
println("Accuracy: ", accuracy * 100, "%")
end
function main(;
input_size = 1, # Each element in the sequence is a scalar
hidden_size = 64, # Size of the hidden state
seq_length = 10, # Length of each sequence
batch_size = 16, # Batch size
num_epochs = 50, # Number of epochs for training
n_train = 1000, # Number of samples in train dataset
n_test = 200 # Number of samples in test dataset
)
model = RecurrentModel(input_size, hidden_size)
# Generate train and test data loaders
train_loader, test_loader = create_dataset(input_size, seq_length, n_train, n_test, batch_size)
# Define the optimizer
opt = Adam(0.001)
for epoch in 1:num_epochs
train_recurrent!(epoch, num_epochs, train_loader, opt, model, criterion)
end
test_recurrent(test_loader, model)
end
main()
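Since every cell in the package exposes the same interface, trying a different recurrent layer only requires changing the model constructor. A hedged sketch, assuming LiGRU accepts the same in => out constructor as MGU:
function RecurrentModel(input_size::Int, hidden_size::Int)
    return RecurrentModel(
        zeros(Float32, hidden_size),
        LiGRU(input_size => hidden_size), # swap MGU for LiGRU (same in => out constructor assumed)
        Dense(hidden_size => 1, sigmoid))
end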
This project is licensed under the MIT License, except for nas_cell.jl, which is licensed under the Apache License, Version 2.0.
- nas_cell.jl is a reimplementation of the NASCell from TensorFlow and is licensed under the Apache License 2.0. See the file header and LICENSE-APACHE for details.
- All other files are licensed under the MIT License. See LICENSE-MIT for details.