diff --git a/Cargo.lock b/Cargo.lock
index 4219d18..171de24 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -55,6 +55,12 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "either"
+version = "1.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
+
 [[package]]
 name = "getrandom"
 version = "0.3.3"
@@ -67,6 +73,15 @@ dependencies = [
  "wasi",
 ]
 
+[[package]]
+name = "itertools"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itoa"
 version = "1.0.15"
@@ -91,6 +106,7 @@ version = "0.1.0"
 dependencies = [
  "bincode",
  "csv",
+ "itertools",
  "ndarray",
  "rand",
  "rand_distr",
diff --git a/Cargo.toml b/Cargo.toml
index f88cc2c..54673d7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 csv = "1.3"
+itertools = "0.14.0"
 
 [dev-dependencies]
 # Add any test-specific dependencies here if needed
diff --git a/src/batch_mode/feed_forward.rs b/src/batch_mode/feed_forward.rs
new file mode 100644
index 0000000..6f796f8
--- /dev/null
+++ b/src/batch_mode/feed_forward.rs
@@ -0,0 +1,156 @@
+use ndarray::{Array2, Array3, Axis};
+use rand_distr::{Distribution, Normal};
+use itertools::izip;
+use crate::adam::Adam;
+use crate::BatchLayer;
+
+pub struct FeedForward {
+    w1: Array2<f32>,
+    b1: Array2<f32>,
+    w2: Array2<f32>,
+    b2: Array2<f32>,
+
+    // Cached values for the backward pass
+    cached_input: Option<Array3<f32>>,
+    hidden_pre_activation: Option<Array3<f32>>,
+    hidden_post_activation: Option<Array3<f32>>,
+
+    optimizer_w1: Adam,
+    optimizer_b1: Adam,
+    optimizer_w2: Adam,
+    optimizer_b2: Adam,
+}
+
+impl FeedForward {
+    /// Initialize a feed-forward layer with random weights
+    pub fn new(embedding_dim: usize, hidden_dim: usize) -> Self {
+        let mut rng = rand::rng();
+
+        // He initialization for w1: std = sqrt(2 / fan_in)
+        let std_w1 = (2.0 / embedding_dim as f32).sqrt();
+        let normal_w1 = Normal::new(0.0, std_w1).unwrap();
+
+        // He initialization for w2: std = sqrt(2 / fan_in)
+        let std_w2 = (2.0 / hidden_dim as f32).sqrt();
+        let normal_w2 = Normal::new(0.0, std_w2).unwrap();
+
+        FeedForward {
+            w1: Array2::from_shape_fn((embedding_dim, hidden_dim), |_| normal_w1.sample(&mut rng)),
+            b1: Array2::zeros((1, hidden_dim)), // Bias initialized to 0
+            w2: Array2::from_shape_fn((hidden_dim, embedding_dim), |_| normal_w2.sample(&mut rng)),
+            b2: Array2::zeros((1, embedding_dim)), // Bias initialized to 0
+            cached_input: None,
+            hidden_pre_activation: None,
+            hidden_post_activation: None,
+            optimizer_w1: Adam::new((embedding_dim, hidden_dim)),
+            optimizer_b1: Adam::new((1, hidden_dim)),
+            optimizer_w2: Adam::new((hidden_dim, embedding_dim)),
+            optimizer_b2: Adam::new((1, embedding_dim)),
+        }
+    }
+}
+
+impl BatchLayer for FeedForward {
+    fn layer_type(&self) -> &str {
+        "FeedForward"
+    }
+
+    fn forward(&mut self, input: &Array3<f32>) -> Array3<f32> {
+        let batch_size = input.shape()[0];
+        let seq_len = input.shape()[1];
+
+        // Allocate output, hidden_pre, hidden_post
+        let mut hidden_pre_activation = Array3::<f32>::zeros((batch_size, seq_len, self.b1.shape()[1]));
+        let mut hidden_post_activation = Array3::<f32>::zeros((batch_size, seq_len, self.b1.shape()[1]));
+        let mut output = Array3::<f32>::zeros((batch_size, seq_len, self.b2.shape()[1]));
+
+        for (((mut out_slice, mut pre_slice), mut post_slice), in_slice) in output
+            .outer_iter_mut()
+            .zip(hidden_pre_activation.outer_iter_mut())
+            .zip(hidden_post_activation.outer_iter_mut())
+            .zip(input.outer_iter())
+        {
+            let hidden_pre_slice = in_slice.dot(&self.w1) + &self.b1;
+            let hidden_post_slice = hidden_pre_slice.mapv(|x| x.max(0.0)); // ReLU
+            pre_slice.assign(&hidden_pre_slice);
+            post_slice.assign(&hidden_post_slice);
+            out_slice.assign(&(hidden_post_slice.dot(&self.w2) + &self.b2));
+        }
+
+        // Cache values for the backward pass
+        self.cached_input = Some(input.clone());
+        self.hidden_pre_activation = Some(hidden_pre_activation);
+        self.hidden_post_activation = Some(hidden_post_activation);
+
+        output + input // residual connection (no LayerNorm here)
+    }
+
+    fn backward(&mut self, grads: &Array3<f32>, lr: f32) -> Array3<f32> {
+        // Unwrap cached values
+        let input = self.cached_input.as_ref().expect("forward must be run first");
+        let batch_size = input.shape()[0];
+        let hidden_pre_activation = self.hidden_pre_activation.as_ref().unwrap();
+        let hidden_post_activation = self.hidden_post_activation.as_ref().unwrap();
+
+        // Set up gradient accumulators
+        let mut grad_input = Array3::<f32>::zeros(input.raw_dim()); // [batch, seq_len, input_dim]
+        let mut grad_w1 = Array2::<f32>::zeros(self.w1.raw_dim());
+        let mut grad_w2 = Array2::<f32>::zeros(self.w2.raw_dim());
+        let mut grad_b1 = Array2::<f32>::zeros(self.b1.raw_dim());
+        let mut grad_b2 = Array2::<f32>::zeros(self.b2.raw_dim());
+
+        // Compute the gradients for w1, w2, b1, b2, then update the parameters via Adam.
+        for (i, (in_slice, grad_slice, hidden_pre_slice, hidden_post_slice)) in izip!(
+            input.outer_iter(),
+            grads.outer_iter(),
+            hidden_pre_activation.outer_iter(),
+            hidden_post_activation.outer_iter()
+        )
+        .enumerate()
+        {
+            grad_w2 += &hidden_post_slice.t().dot(&grad_slice);
+            grad_b2 += &grad_slice.sum_axis(Axis(0)).insert_axis(Axis(0)); // Shape: [1, embedding_dim]
+
+            // Gradient w.r.t. hidden_post_activation
+            let grad_hidden_post_activation = &grad_slice.dot(&self.w2.t());
+
+            // Gradient through ReLU
+            let relu_grad = &hidden_pre_slice.mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
+            let grad_hidden_pre_activation = grad_hidden_post_activation * relu_grad;
+
+            // Gradient w.r.t. w1 and b1
+            grad_w1 += &in_slice.t().dot(&grad_hidden_pre_activation);
+            grad_b1 += &grad_hidden_pre_activation
+                .sum_axis(Axis(0))
+                .insert_axis(Axis(0)); // Shape: [1, hidden_dim]
+
+            // Gradient w.r.t. input (through the feed-forward computation)
+            let grad_input_feedforward = grad_hidden_pre_activation.dot(&self.w1.t());
+
+            // Add gradient from the residual connection (for each batch element)
+            // Forward:  output = W2(ReLU(W1*input + b1)) + b2 + input
+            // Backward: grad_input = grad_feedforward + grad_residual
+            grad_input
+                .slice_mut(ndarray::s![i, .., ..])
+                .assign(&(&grad_input_feedforward + &grad_slice));
+        }
+
+        grad_input /= batch_size as f32;
+        grad_w1 /= batch_size as f32;
+        grad_w2 /= batch_size as f32;
+        grad_b1 /= batch_size as f32;
+        grad_b2 /= batch_size as f32;
+
+        // Update parameters via the Adam optimizer
+        self.optimizer_w2.step(&mut self.w2, &grad_w2, lr);
+        self.optimizer_b2.step(&mut self.b2, &grad_b2, lr);
+        self.optimizer_w1.step(&mut self.w1, &grad_w1, lr);
+        self.optimizer_b1.step(&mut self.b1, &grad_b1, lr);
+
+        grad_input
+    }
+
+    fn parameters(&self) -> usize {
+        self.b1.len() + self.b2.len() + self.w1.len() + self.w2.len()
+    }
+}
diff --git a/src/batch_mode/llm.rs b/src/batch_mode/llm.rs
new file mode 100644
index 0000000..a71beff
--- /dev/null
+++ b/src/batch_mode/llm.rs
@@ -0,0 +1,18 @@
+use ndarray::Array3;
+use crate::Vocab;
+
+pub trait Layer {
+    fn layer_type(&self) -> &str;
+
+    fn forward(&mut self, input: &Array3<f32>) -> Array3<f32>;
+
+    fn backward(&mut self, grads: &Array3<f32>, lr: f32) -> Array3<f32>;
+
+    fn parameters(&self) -> usize;
+}
+
+#[allow(clippy::upper_case_acronyms)]
+pub struct LLM {
+    pub vocab: Vocab,
+    pub network: Vec<Box<dyn Layer>>,
+}
\ No newline at end of file
diff --git a/src/batch_mode/mod.rs b/src/batch_mode/mod.rs
new file mode 100644
index 0000000..ad1ed52
--- /dev/null
+++ b/src/batch_mode/mod.rs
@@ -0,0 +1,3 @@
+pub mod llm;
+pub mod output_projection;
+pub mod feed_forward;
\ No newline at end of file
diff --git a/src/batch_mode/output_projection.rs b/src/batch_mode/output_projection.rs
new file mode 100644
index 0000000..4b06c34
--- /dev/null
+++ b/src/batch_mode/output_projection.rs
@@ -0,0 +1,84 @@
+use ndarray::{Array2, Array3, Axis};
+use rand_distr::{Distribution, Normal};
+
+use crate::adam::Adam;
+use crate::BatchLayer;
+
+pub struct OutputProjection {
+    pub w_out: Array2<f32>, // Weight matrix
+    pub b_out: Array2<f32>, // Bias vector
+    pub optimizer: Adam,
+    pub cached_input: Option<Array3<f32>>,
+}
+
+impl OutputProjection {
+    /// Initialize the output layer with random weights and zero bias
+    pub fn new(embedding_dim: usize, vocab_size: usize) -> Self {
+        let mut rng = rand::rng();
+        // He initialization: std = sqrt(2 / fan_in)
+        let std = (2.0 / embedding_dim as f32).sqrt();
+        let normal = Normal::new(0.0, std).unwrap();
+
+        OutputProjection {
+            w_out: Array2::from_shape_fn((embedding_dim, vocab_size), |_| normal.sample(&mut rng)),
+            b_out: Array2::zeros((1, vocab_size)),
+            optimizer: Adam::new((embedding_dim, vocab_size)),
+            cached_input: None,
+        }
+    }
+}
+
+impl BatchLayer for OutputProjection {
+    fn layer_type(&self) -> &str {
+        "OutputProjection"
+    }
+
+    /// Forward pass for batched input: [batch_size, seq_len, embedding_dim]
+    fn forward(&mut self, input: &Array3<f32>) -> Array3<f32> {
+        let batch_size = input.shape()[0];
+        let seq_len = input.shape()[1];
+
+        self.cached_input = Some(input.clone());
+
+        // Allocate output: [batch_size, seq_len, vocab_size]
+        let mut output = Array3::<f32>::zeros((batch_size, seq_len, self.b_out.shape()[1]));
+
+        for (mut out_slice, in_slice) in output.outer_iter_mut().zip(input.outer_iter()) {
+            // in_slice shape:  [seq_len, embedding_dim]
+            // out_slice shape: [seq_len, vocab_size]
+            out_slice.assign(&(in_slice.dot(&self.w_out) + &self.b_out));
+        }
+
+        output
+    }
+
+    /// Backward pass for batched input
+    fn backward(&mut self, grads: &Array3<f32>, lr: f32) -> Array3<f32> {
+        let input = self.cached_input.as_ref().unwrap();
+        let batch_size = input.shape()[0];
+
+        let mut grad_input = Array3::<f32>::zeros(input.raw_dim());
+        let mut grad_w_out = Array2::<f32>::zeros(self.w_out.raw_dim());
+        let mut grad_b_out = Array2::<f32>::zeros(self.b_out.raw_dim());
+
+        for (i, (in_slice, grad_slice)) in input.outer_iter().zip(grads.outer_iter()).enumerate() {
+            // Compute gradients for weights and bias
+            grad_w_out += &in_slice.t().dot(&grad_slice);
+            grad_b_out += &grad_slice.mean_axis(Axis(0)).unwrap();
+
+            // Compute the gradient w.r.t. the input slice and write it into grad_input
+            grad_input
+                .slice_mut(ndarray::s![i, .., ..])
+                .assign(&grad_slice.dot(&self.w_out.t()));
+        }
+
+        grad_w_out /= batch_size as f32;
+        grad_b_out /= batch_size as f32;
+
+        self.optimizer.step(&mut self.w_out, &grad_w_out, lr);
+        self.b_out -= &(lr * &grad_b_out);
+
+        grad_input
+    }
+
+    fn parameters(&self) -> usize {
+        self.w_out.len() + self.b_out.len()
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index a13d7dc..7471b5e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod adam;
+pub mod batch_mode;
 pub mod dataset_loader;
 pub mod embeddings;
 pub mod feed_forward;
@@ -14,7 +15,14 @@ pub use embeddings::Embeddings;
 pub use llm::{LLM, Layer};
 pub use vocab::Vocab;
 
+// Re-export batch_mode structs
+pub use batch_mode::llm::LLM as BatchLLM;
+pub use batch_mode::llm::Layer as BatchLayer;
+pub use batch_mode::output_projection::OutputProjection as BatchOutputProjection;
+pub use batch_mode::feed_forward::FeedForward as BatchFeedForward;
+
 // Constants
 pub const MAX_SEQ_LEN: usize = 80;
 pub const EMBEDDING_DIM: usize = 128;
 pub const HIDDEN_DIM: usize = 256;
+pub const BATCH_SIZE: usize = 4;
diff --git a/tests/feed_forward_batch_test.rs b/tests/feed_forward_batch_test.rs
new file mode 100644
index 0000000..20f07db
--- /dev/null
+++ b/tests/feed_forward_batch_test.rs
@@ -0,0 +1,56 @@
+use llm::{BATCH_SIZE, EMBEDDING_DIM, HIDDEN_DIM, BatchLayer, BatchFeedForward};
+use ndarray::Array3;
+
+#[test]
+fn test_feed_forward_forward() {
+    // Create feed-forward module
+    let mut feed_forward = BatchFeedForward::new(EMBEDDING_DIM, HIDDEN_DIM);
+
+    // Create input tensor (batch_size=BATCH_SIZE, seq_len=3, embedding_dim=EMBEDDING_DIM)
+    let input = Array3::ones((BATCH_SIZE, 3, EMBEDDING_DIM));
+
+    // Test forward pass
+    let output = feed_forward.forward(&input);
+
+    // Check output shape - should be same as input
+    assert_eq!(output.shape(), input.shape());
+}
+
+#[test]
+fn test_feed_forward_with_different_sequence_lengths() {
+    // Create feed-forward module
+    let mut feed_forward = BatchFeedForward::new(EMBEDDING_DIM, HIDDEN_DIM);
+
+    // Test with different sequence lengths
+    for seq_len in 1..5 {
+        // Create input tensor
+        let input = Array3::ones((BATCH_SIZE, seq_len, EMBEDDING_DIM));
+
+        // Test forward pass
+        let output = feed_forward.forward(&input);
+
+        // Check output shape
+        assert_eq!(output.shape(), [BATCH_SIZE, seq_len, EMBEDDING_DIM]);
+    }
+}
+
+#[test]
+fn test_feed_forward_and_backward() {
+    // Create feed-forward module
+    let mut feed_forward = BatchFeedForward::new(EMBEDDING_DIM, HIDDEN_DIM);
+
+    // Create input tensor (batch_size=BATCH_SIZE, seq_len=3, embedding_dim=EMBEDDING_DIM)
+    let input = Array3::ones((BATCH_SIZE, 3, EMBEDDING_DIM));
+
+    // Test forward pass
+    let output = feed_forward.forward(&input);
+
+    let grads = Array3::ones((BATCH_SIZE, 3, EMBEDDING_DIM));
+
+    // Test backward pass
+    let grad_input = feed_forward.backward(&grads, 0.01);
+
+    // The gradient w.r.t. the input should differ from the forward output
+    assert_ne!(output, grad_input);
+}
diff --git a/tests/output_projection_batch_test.rs b/tests/output_projection_batch_test.rs
new file mode 100644
index 0000000..56fb803
--- /dev/null
+++ b/tests/output_projection_batch_test.rs
@@ -0,0 +1,111 @@
+use llm::{BATCH_SIZE, EMBEDDING_DIM, BatchLayer, BatchOutputProjection};
+use ndarray::Array3;
+
+#[test]
+fn test_output_projection_creation() {
+    let vocab_size = 10;
+    let output_proj = BatchOutputProjection::new(EMBEDDING_DIM, vocab_size);
+
+    // Check weight matrix dimensions
+    assert_eq!(output_proj.w_out.shape(), [EMBEDDING_DIM, vocab_size]);
+
+    // Check bias vector dimensions
+    assert_eq!(output_proj.b_out.shape(), [1, vocab_size]);
+
+    // Check optimizer dimensions
+    assert_eq!(output_proj.optimizer.m.shape(), [EMBEDDING_DIM, vocab_size]);
+    assert_eq!(output_proj.optimizer.v.shape(), [EMBEDDING_DIM, vocab_size]);
+}
+
+#[test]
+fn test_output_projection_forward() {
+    let vocab_size = 10;
+    let mut output_proj = BatchOutputProjection::new(EMBEDDING_DIM, vocab_size);
+
+    // Create input tensor (batch_size=BATCH_SIZE, seq_len=3, embedding_dim=EMBEDDING_DIM)
+    let input = Array3::ones((BATCH_SIZE, 3, EMBEDDING_DIM));
+
+    // Test forward pass
+    let output = output_proj.forward(&input);
+
+    // Check output shape - should be [BATCH_SIZE, seq_len, vocab_size]
+    assert_eq!(output.shape(), [BATCH_SIZE, 3, vocab_size]);
+}
+
+#[test]
+fn test_output_projection_with_different_sequence_lengths() {
+    let vocab_size = 10;
+    let mut output_proj = BatchOutputProjection::new(EMBEDDING_DIM, vocab_size);
+
+    // Test with different sequence lengths
+    for seq_len in 1..5 {
+        // Create input tensor
+        let input = Array3::ones((BATCH_SIZE, seq_len, EMBEDDING_DIM));
+
+        // Test forward pass
+        let output = output_proj.forward(&input);
+
+        // Check output shape
+        assert_eq!(output.shape(), [BATCH_SIZE, seq_len, vocab_size]);
+    }
+}
+
+#[test]
+fn test_output_projection_backward() {
+    let vocab_size = 10;
+    let mut output_proj = BatchOutputProjection::new(EMBEDDING_DIM, vocab_size);
+
+    // Create input tensor
+    let input = Array3::ones((BATCH_SIZE, 3, EMBEDDING_DIM));
+
+    // Forward pass first (required to cache input)
+    let _output = output_proj.forward(&input);
+
+    // Create gradient tensor
+    let grads = Array3::ones((BATCH_SIZE, 3, vocab_size));
+
+    // Test backward pass
+    let grad_input = output_proj.backward(&grads, 0.01);
+
+    // Check gradient input shape
+    assert_eq!(grad_input.shape(), [BATCH_SIZE, 3, EMBEDDING_DIM]);
+
+    // Verify that parameters were updated
+    let w_out_before = output_proj.w_out.clone();
+    let b_out_before = output_proj.b_out.clone();
+
+    // Run another forward and backward pass
+    let _output = output_proj.forward(&input);
+    let _grad_input = output_proj.backward(&grads, 0.01);
+
+    // Check that parameters changed
+    assert_ne!(output_proj.w_out, w_out_before);
+    assert_ne!(output_proj.b_out, b_out_before);
+}
+
+#[test]
+fn test_output_projection_training() {
+    let vocab_size = 10;
+    let mut output_proj = BatchOutputProjection::new(EMBEDDING_DIM, vocab_size);
+
+    // Create input tensor
+    let input = Array3::ones((BATCH_SIZE, 3, EMBEDDING_DIM));
+
+    // Run multiple training steps
+    for _ in 0..5 {
+        // Forward pass
+        let _output = output_proj.forward(&input);
+
+        // Create gradient tensor (simulating cross-entropy loss gradients)
+        let mut grads = Array3::zeros((BATCH_SIZE, 3, vocab_size));
+        grads[[0, 0, 0]] = 1.0; // Set gradient for first token
+
+        // Backward pass
+        let _grad_input = output_proj.backward(&grads, 0.01);
+    }
+
+    // Verify that parameters were updated
+    assert_ne!(output_proj.w_out.sum(), 0.0);
+    assert_ne!(output_proj.b_out.sum(), 0.0);
+}
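
The commit only adds the two batch-mode layers and their unit tests; it never shows them chained together. For orientation, here is a minimal usage sketch (not part of the diff) of how the re-exported pieces compose through the BatchLayer trait, assuming the `llm` crate name used in the tests, a placeholder vocab_size of 10, and that this would live in something like an examples/ file:

// Usage sketch only -- not part of the commit above.
use llm::{BATCH_SIZE, EMBEDDING_DIM, HIDDEN_DIM, BatchLayer, BatchFeedForward, BatchOutputProjection};
use ndarray::Array3;

fn main() {
    let vocab_size = 10; // hypothetical vocabulary size for illustration

    // Stack the batched layers behind the BatchLayer trait, as BatchLLM's `network` field would.
    let mut network: Vec<Box<dyn BatchLayer>> = vec![
        Box::new(BatchFeedForward::new(EMBEDDING_DIM, HIDDEN_DIM)),
        Box::new(BatchOutputProjection::new(EMBEDDING_DIM, vocab_size)),
    ];

    // Forward pass over a whole batch at once:
    // [BATCH_SIZE, seq_len, EMBEDDING_DIM] -> [BATCH_SIZE, seq_len, vocab_size]
    let mut activations = Array3::<f32>::ones((BATCH_SIZE, 3, EMBEDDING_DIM));
    for layer in network.iter_mut() {
        activations = layer.forward(&activations);
    }
    assert_eq!(activations.shape(), [BATCH_SIZE, 3, vocab_size]);

    // Backward pass in reverse order; each layer updates its own parameters via Adam.
    // Placeholder gradients stand in for real cross-entropy gradients here.
    let mut grads = Array3::<f32>::ones((BATCH_SIZE, 3, vocab_size));
    for layer in network.iter_mut().rev() {
        grads = layer.backward(&grads, 0.01);
    }
    assert_eq!(grads.shape(), [BATCH_SIZE, 3, EMBEDDING_DIM]);
}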