Skip to content

Commit

Permalink
Fix matrix shapes for correct broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubSchwenkbeck committed Dec 18, 2024
1 parent be6ffe4 commit d5e5e0f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions src/layers/feedforward_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ pub struct FeedForwardLayer {
}
impl FeedForwardLayer {
// init with random values
pub fn new(batch_size: usize,input_size: usize, output_size: usize, dropout_rate: f32) -> FeedForwardLayer {

pub fn new(
_batch_size: usize,
input_size: usize,
output_size: usize,
dropout_rate: f32,
) -> FeedForwardLayer {
let hidden_size = input_size * 4; // Define the hidden layer size

// He (Kaiming) initialization for weights
Expand Down Expand Up @@ -72,7 +76,6 @@ impl FeedForwardLayer {
Ok(valid_reshaped_x) => {
let dot = valid_reshaped_x.dot(&self.weights1);


let add = dot + &self.bias1;

// First linear layer + gelu
Expand Down Expand Up @@ -128,13 +131,10 @@ fn he_initialization(input_size: usize, output_size: usize) -> Array2<f32> {
Array2::from_shape_vec((input_size, output_size), values).unwrap()
}


fn bias_initialization(size: usize) -> Array1<f32> {
Array1::zeros(size)
}



fn test_bias_initialization() {
let size = 5;

Expand Down Expand Up @@ -166,7 +166,7 @@ fn test_feedforward_forward() {
];

// Create a FeedForwardLayer instance
let feed_forward_layer = FeedForwardLayer::new(2,4, 4, 0.1);
let feed_forward_layer = FeedForwardLayer::new(2, 4, 4, 0.1);

// Feed forward through the layer
let feed_forward_output = feed_forward_layer.forward(input.clone());
Expand Down
2 changes: 1 addition & 1 deletion src/model/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ fn test_decoding() {
let beta = array![[0.0, 0.0, 0.0, 0.0]];

// Dummy FeedForwardLayer
let feed_forward_layer = FeedForwardLayer::new(2,4, 4, 0.1);
let feed_forward_layer = FeedForwardLayer::new(2, 4, 4, 0.1);
let epsilon = 1e-6;
let enc_out = encoding(
input.clone(),
Expand Down
2 changes: 1 addition & 1 deletion src/model/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ fn test_encoding() {
let beta = array![[0.0, 0.0, 0.0, 0.0]];

// Dummy FeedForwardLayer
let feed_forward_layer = FeedForwardLayer::new(2,4, 4, 0.1);
let feed_forward_layer = FeedForwardLayer::new(2, 4, 4, 0.1);

// Call the encoding function
let epsilon = 1e-6;
Expand Down

0 comments on commit d5e5e0f

Please sign in to comment.