
Commit

removed debugging prints
JakubSchwenkbeck committed Dec 19, 2024
1 parent 86dc59a commit 39fdea5
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions src/layers/normalization.rs
@@ -19,18 +19,18 @@ pub fn layer_norm(
     // Step 1: Calculate mean and variance across the features (axis=1)
     let mean = x.mean_axis(Axis(1)).unwrap();
     let variance = x.var_axis(Axis(1), 0.0);
-    println!("Mean: {:?}", mean);
-    println!("Variance: {:?}", variance);
+    //println!("Mean: {:?}", mean);
+    // println!("Variance: {:?}", variance);
 
     let expanded_mean = mean.insert_axis(Axis(1)); // Expands [6] to [6, 1]
     let expanded_variance = variance.insert_axis(Axis(1)); // Expands [6] to [6, 1]
-    println!("EXPMean: {:?}", expanded_mean);
-    println!("EXPVariance: {:?}", expanded_variance);
+    // println!("EXPMean: {:?}", expanded_mean);
+    //println!("EXPVariance: {:?}", expanded_variance);
 
     // Add epsilon to expanded variance
     let normalized = (x - &expanded_mean) / (expanded_variance + epsilon).mapv(f32::sqrt);
 
-    println!("Normalized {}", normalized);
+    // println!("Normalized {}", normalized);
     // Step 2: Normalize the input
     //let normalized = (x - &reshaped_mean) / (reshaped_variance + epsilon).mapv(f32::sqrt);
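
For context, the touched function computes a row-wise layer normalization with ndarray: mean and variance are reduced over the feature axis, re-expanded so they broadcast against the input, and used to normalize with an epsilon for numerical stability. The sketch below is illustrative only and not part of the commit; the function name, input shape, and epsilon value are assumptions, but it follows the same pattern as the hunk above.

// Illustrative sketch (not part of the commit): row-wise layer norm with ndarray.
// The function name, input shape, and epsilon are assumptions for the example.
use ndarray::{Array2, Axis};

fn layer_norm_sketch(x: &Array2<f32>, epsilon: f32) -> Array2<f32> {
    // Mean and population variance across the feature axis (axis = 1).
    let mean = x.mean_axis(Axis(1)).unwrap();      // shape [rows]
    let variance = x.var_axis(Axis(1), 0.0);       // shape [rows]

    // Re-insert the reduced axis so the statistics broadcast against x.
    let mean = mean.insert_axis(Axis(1));          // shape [rows, 1]
    let variance = variance.insert_axis(Axis(1));  // shape [rows, 1]

    // Normalize: subtract the mean, divide by sqrt(variance + epsilon).
    (x - &mean) / (variance + epsilon).mapv(f32::sqrt)
}

fn main() {
    let x = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    println!("{}", layer_norm_sketch(&x, 1e-5));
}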
4 changes: 2 additions & 2 deletions src/model/encoder.rs
@@ -41,7 +41,7 @@ pub fn encoding(
         dummy_learned_matrices.clone(), // W_O
     );
 
-    println!("Attention1 :{}", attention_output);
+    //println!("Attention1 :{}", attention_output);
     // Add & Normalize (Residual Connection + Layer Norm)
     let attention_residual = attention_output.add(&input); // Residual connection
     let reshaped_attention = attention_residual
@@ -60,7 +60,7 @@ pub fn encoding(
     // Feed-Forward Network
     let feed_forward_output = feed_forward_layer.forward(attention_norm.clone());
 
-    println!("feed_forward_output :{:?}", feed_forward_output);
+    //println!("feed_forward_output :{:?}", feed_forward_output);
     // Add & Normalize (Residual Connection + Layer Norm)
    let feed_forward_residual = feed_forward_output.add(&attention_norm); // Residual connection
    let reshaped_ff_attention = feed_forward_residual
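
Both encoder hunks sit inside the usual "Add & Norm" step of a Transformer encoder block: the sub-layer output (attention or feed-forward) is summed with its input and the result is layer-normalized. The following is a rough sketch of that pattern only; the layer_norm helper, its signature, and the epsilon handling are assumptions for illustration, not the repository's exact API.

// Illustrative sketch (not part of the commit): residual connection followed by
// layer normalization, the pattern surrounding both commented-out prints above.
use ndarray::{Array2, Axis};

// Stand-in for the project's layer_norm; signature and epsilon handling are assumed.
fn layer_norm(x: &Array2<f32>, epsilon: f32) -> Array2<f32> {
    let mean = x.mean_axis(Axis(1)).unwrap().insert_axis(Axis(1));
    let variance = x.var_axis(Axis(1), 0.0).insert_axis(Axis(1));
    (x - &mean) / (variance + epsilon).mapv(f32::sqrt)
}

fn add_and_norm(sublayer_output: &Array2<f32>, input: &Array2<f32>, epsilon: f32) -> Array2<f32> {
    // Residual connection: element-wise sum of the sub-layer output and its input.
    let residual = sublayer_output + input;
    // Layer norm over the summed activations.
    layer_norm(&residual, epsilon)
}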
