
Commit

Implement FeedForwardLayer with forward pass, dropout, and He initialization

- Added forward pass with GELU activation
- Integrated dropout during training
- Implemented He initialization for weights
JakubSchwenkbeck committed Dec 11, 2024
1 parent 8cdac62 commit 0e84ccf
Showing 1 changed file with 59 additions and 8 deletions.
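For orientation, a minimal usage sketch of the layer this commit completes. The constructor name `new`, the import path, and all dimensions here are assumptions — the diff below shows only the parameter list:

use ndarray::Array2;
// Assumed import path; the diff only identifies src/layers/feedforward_layer.rs.
use crate::layers::feedforward_layer::FeedForwardLayer;

fn demo() {
    // Hypothetical transformer-ish sizes: model width 512, hidden width 2048,
    // dropout rate 0.1.
    let layer = FeedForwardLayer::new(512, 512, 2048, 0.1);

    // A batch of 4 rows, each a 512-wide activation vector.
    let input = Array2::<f32>::zeros((4, 512));

    // `train = true` applies dropout; pass `false` at inference time.
    let output = layer.forward_t(&input, true);
    assert_eq!(output.shape(), &[4, 512]);
}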
src/layers/feedforward_layer.rs (59 additions, 8 deletions)
@@ -2,6 +2,7 @@
 use crate::activation::activation_functions::gelu;
 use ndarray::Array2;
+use rand::Rng;
 use std::ops::Add;

 pub struct FeedForwardLayer {
@@ -19,27 +20,77 @@ impl FeedForwardLayer {
         input_size: usize,
         output_size: usize,
         hidden_dim: usize,
-        dropout: f32,
+        dropout_rate: f32,
     ) -> FeedForwardLayer {
-        let weights1 = Array2::<f32>::zeros((input_size, hidden_dim));
-        let bias1 = Array2::<f32>::zeros((input_size, input_size));
-        let weights2 = Array2::<f32>::zeros((hidden_dim, output_size));
-        let bias2 = Array2::<f32>::zeros((input_size, input_size));
+        // He (Kaiming) initialization for weights
+        let weights1 = he_initialization(input_size, hidden_dim);
+        let bias1 = bias_initialization(hidden_dim);
+
+        let weights2 = he_initialization(hidden_dim, output_size);
+        let bias2 = bias_initialization(output_size);

         FeedForwardLayer {
             weights1,
             bias1,
             weights2,
             bias2,
-            dropout_rate: dropout,
+            dropout_rate,
         }
     }

-    pub fn forward_t(&self, input: &Array2<f32>) -> Array2<f32> {
+    pub fn forward_t(&self, input: &Array2<f32>, train: bool) -> Array2<f32> {
+        // First linear layer
         let first_dot = input.dot(&self.weights1);
         let first_output = first_dot.add(&self.bias1);
         let first_activation = gelu(&first_output);

-        gelu(&first_activation.dot(&self.weights2).add(&self.bias2))
+        // Dropout (only during training)
+        let first_activation = if train {
+            self.apply_dropout(&first_activation)
+        } else {
+            first_activation
+        };
+
+        // Second linear layer
+        first_activation.dot(&self.weights2).add(&self.bias2)
     }
+
+    fn apply_dropout(&self, input: &Array2<f32>) -> Array2<f32> {
+        let mut rng = rand::rng();
+        input.map(|&x| {
+            if rng.random::<f32>() < self.dropout_rate {
+                0.0
+            } else {
+                x
+            }
+        })
+    }
 }

+fn he_initialization(input_size: usize, output_size: usize) -> Array2<f32> {
+    let mut rng = rand::rng();
+    // He initialization: scale by sqrt(2 / input_size)
+    let scale = (2.0 / input_size as f32).sqrt();
+    let values: Vec<f32> = (0..(input_size * output_size))
+        .map(|_| rng.random_range(-scale..scale))
+        .collect();
+
+    // Create an Array2 from the values vector
+    Array2::from_shape_vec((input_size, output_size), values).unwrap()
+}

+fn bias_initialization(size: usize) -> Array2<f32> {
+    // Biases are usually initialized to zero. Shape (1, size) lets the bias
+    // row broadcast across the batch dimension of a (batch, size) matrix;
+    // a (size, 1) column would fail to broadcast in the forward pass.
+    Array2::zeros((1, size))
+}
+
+#[test]
+fn test_bias_initialization() {
+    let size = 5;
+
+    let bias = bias_initialization(size);
+
+    // Check that the dimensions are correct (1 x size)
+    assert_eq!(bias.shape(), &[1, size]);
+
+    // Check that all values in the bias array are 0.0
+    for &value in bias.iter() {
+        assert_eq!(value, 0.0);
+    }
+}
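One caveat on `apply_dropout` as committed: surviving activations are not rescaled, so the expected magnitude of the hidden layer differs between training (`train = true`) and inference. The usual remedy is inverted dropout, which divides survivors by the keep probability. A sketch under the same signature, assuming `dropout_rate < 1.0` — this is an alternative, not part of the commit:

    fn apply_inverted_dropout(&self, input: &Array2<f32>) -> Array2<f32> {
        let mut rng = rand::rng();
        let keep_prob = 1.0 - self.dropout_rate; // assumed > 0
        input.map(|&x| {
            if rng.random::<f32>() < self.dropout_rate {
                0.0
            } else {
                // Scale survivors so the expected activation matches inference.
                x / keep_prob
            }
        })
    }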
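Also worth noting: `he_initialization` samples uniformly from [-sqrt(2/fan_in), sqrt(2/fan_in)], which has variance (2/fan_in)/3, one third of the 2/fan_in variance that He et al. prescribe. A Gaussian variant is sketched below; it assumes the `rand_distr` crate as an extra dependency and is not part of the commit:

use ndarray::Array2;
use rand_distr::{Distribution, Normal};

fn he_initialization_normal(input_size: usize, output_size: usize) -> Array2<f32> {
    let mut rng = rand::rng();
    // He et al. (2015): W ~ N(0, 2 / fan_in), i.e. std dev = sqrt(2 / fan_in).
    let std_dev = (2.0 / input_size as f32).sqrt();
    let normal = Normal::new(0.0_f32, std_dev).expect("std dev is finite and positive");
    let values: Vec<f32> = (0..(input_size * output_size))
        .map(|_| normal.sample(&mut rng))
        .collect();
    Array2::from_shape_vec((input_size, output_size), values).unwrap()
}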
