From 6840b734089e70d25559a63db87625448d5c31cd Mon Sep 17 00:00:00 2001 From: Jakub Date: Sat, 21 Dec 2024 21:52:45 +0100 Subject: [PATCH 01/13] integrated cross entropy loss function, working for current Model --- src/lib.rs | 3 +++ src/model/transformer_model.rs | 8 ++++++++ src/training/loss_function.rs | 27 +++++++++++++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 src/training/loss_function.rs diff --git a/src/lib.rs b/src/lib.rs index f057a63..93785c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,9 @@ pub mod example; pub mod settings; +pub mod training { + pub mod loss_function; +} pub mod data { pub mod dataset; pub mod generation; diff --git a/src/model/transformer_model.rs b/src/model/transformer_model.rs index 7a130bb..e293a56 100644 --- a/src/model/transformer_model.rs +++ b/src/model/transformer_model.rs @@ -8,6 +8,7 @@ use crate::model::decoder::decoding; use crate::model::embedding::{predict_tokens, Embedding}; use crate::model::encoder::encoding; use crate::settings::*; +use crate::training::loss_function::cross_entropy_loss; use ndarray::{Array1, Array2, Array3}; use rand::Rng; use std::collections::HashMap; @@ -74,6 +75,13 @@ pub fn transformer_model( let output_projection = Array2::ones((OUTPUT_SIZE, tokenizer.vocab.len())); // All ones weights let logits = flatten_3d_array(decoded).dot(&output_projection); // Linear layer + /* + let targets = Array1::from(vec![1, 2, 3]); + let loss = cross_entropy_loss(&logits.clone(), &targets,5); + println!("LOSS: {:?}", loss); + + */ + // Apply softmax to logits let probabilities = softmax_matrix(&logits); diff --git a/src/training/loss_function.rs b/src/training/loss_function.rs new file mode 100644 index 0000000..f691020 --- /dev/null +++ b/src/training/loss_function.rs @@ -0,0 +1,27 @@ +use crate::attention::softmax::softmax_vector; +use ndarray::{s, Array1, Array2}; + +pub fn cross_entropy_loss( + logits: &Array2, + targets: &Array1, + _vocab_size: usize, +) -> f32 { + let mut loss = 0.0; + + // Iterate over each target in the batch + for (i, &target) in targets.iter().enumerate() { + // Get the logits for the current target (batch_size x vocab_size) + let logit = &logits.slice(s![i, ..]); + + // Softmax calculation + let softmax = softmax_vector(*logit); + + // The log probability for the correct target token + let log_prob = softmax[target]; + + // Add to the loss: -log(p_y) + loss -= log_prob.ln(); + } + + loss +} From 6cb8181f7e3ea56dbaef60184d298d5620655316 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sat, 21 Dec 2024 23:00:05 +0100 Subject: [PATCH 02/13] implemented full Dataset implementation --- src/data/dataset.rs | 44 +++++++++++++++++++++++++++++++++++++++--- src/data/generation.rs | 4 ++-- src/lib.rs | 2 ++ src/main.rs | 4 +++- 4 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/data/dataset.rs b/src/data/dataset.rs index 0dcb57e..51e7625 100644 --- a/src/data/dataset.rs +++ b/src/data/dataset.rs @@ -1,5 +1,43 @@ +use ndarray::Data; +use crate::data::generation::{generate_input_target_pairs, generate_staircase_pairs}; +use crate::data::io::get_input; +use crate::data::tokenizer::Tokenizer; + #[allow(dead_code)] -struct Dataset { - inputs: Vec>, // Each input is a sequence of token IDs - targets: Vec>, // Each target is the corresponding output sequence +pub struct Dataset { + pub(crate) inputs: Vec>, // Each input is a sequence of token IDs + pub(crate) targets: Vec>, // Each target is the corresponding output sequence } + + +pub fn gen_data() -> (Tokenizer,Dataset){ + let raw_text = get_input(); + + let tokenizer = Tokenizer::new(raw_text.clone()); + + // Generate input-target pairs + let pairs = generate_input_target_pairs(&tokenizer, raw_text); + + + let mut all_inputs = Vec::new(); + let mut all_targets = Vec::new(); + + // For each input-target pair, generate staircase pairs and add to the dataset + for (input, target) in pairs { + let staircase_pairs = generate_staircase_pairs(&input, &target); + + // Add the staircase pairs to the dataset + for (staircase_input, staircase_target) in staircase_pairs { + all_inputs.push(staircase_input); + all_targets.push(staircase_target); + } + } + + // Return tokenizer and the generated dataset + (tokenizer, Dataset { + inputs: all_inputs, + targets: all_targets, + }) + + +} \ No newline at end of file diff --git a/src/data/generation.rs b/src/data/generation.rs index 2091242..eea1542 100644 --- a/src/data/generation.rs +++ b/src/data/generation.rs @@ -1,7 +1,7 @@ use crate::data::io::get_input; use crate::data::tokenizer::Tokenizer; -fn generate_input_target_pairs( +pub fn generate_input_target_pairs( tokenizer: &Tokenizer, sentences: Vec, ) -> Vec<(Vec, Vec)> { @@ -81,7 +81,7 @@ pub fn example_gen() { } } } -fn generate_staircase_pairs(input: &[usize], target: &[usize]) -> Vec<(Vec, Vec)> { +pub fn generate_staircase_pairs(input: &[usize], target: &[usize]) -> Vec<(Vec, Vec)> { let mut staircase_pairs = Vec::new(); // The number of steps will be the length of the target sequence diff --git a/src/lib.rs b/src/lib.rs index 93785c1..9709923 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub mod settings; pub mod training { pub mod loss_function; + pub mod train; } pub mod data { pub mod dataset; @@ -18,6 +19,7 @@ pub mod model { pub mod embedding; pub mod encoder; pub mod transformer_model; + pub mod train_transformer; } // this lint makes a scene .... diff --git a/src/main.rs b/src/main.rs index 34a0f22..e6e5bd3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use Transformer::data::generation::example_gen; use Transformer::data::tokenizer::Tokenizer; use Transformer::example::example; - +use Transformer::model::train_transformer::train; use Transformer::model::transformer_model::transformer_model; fn main() { println!("runs successfully!"); @@ -75,4 +75,6 @@ fn main() { println!("Predicted Token: {:?}", predicted_token); example_gen(); + + train() } From e5ade44d30c52884fc3425c3a83d2aec8ada1aa3 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sat, 21 Dec 2024 23:18:56 +0100 Subject: [PATCH 03/13] Implemented working training process with --- src/data/dataset.rs | 23 +++-- src/data/generation.rs | 5 +- src/layers/normalization.rs | 2 +- src/lib.rs | 2 +- src/math/linear_algebra.rs | 10 +- src/model/embedding.rs | 3 +- src/model/train_transformer.rs | 167 +++++++++++++++++++++++++++++++++ src/training/train.rs | 59 ++++++++++++ 8 files changed, 250 insertions(+), 21 deletions(-) create mode 100644 src/model/train_transformer.rs create mode 100644 src/training/train.rs diff --git a/src/data/dataset.rs b/src/data/dataset.rs index 51e7625..f696962 100644 --- a/src/data/dataset.rs +++ b/src/data/dataset.rs @@ -1,16 +1,15 @@ -use ndarray::Data; use crate::data::generation::{generate_input_target_pairs, generate_staircase_pairs}; use crate::data::io::get_input; use crate::data::tokenizer::Tokenizer; +use ndarray::Data; #[allow(dead_code)] pub struct Dataset { - pub(crate) inputs: Vec>, // Each input is a sequence of token IDs + pub(crate) inputs: Vec>, // Each input is a sequence of token IDs pub(crate) targets: Vec>, // Each target is the corresponding output sequence } - -pub fn gen_data() -> (Tokenizer,Dataset){ +pub fn gen_data() -> (Tokenizer, Dataset) { let raw_text = get_input(); let tokenizer = Tokenizer::new(raw_text.clone()); @@ -18,7 +17,6 @@ pub fn gen_data() -> (Tokenizer,Dataset){ // Generate input-target pairs let pairs = generate_input_target_pairs(&tokenizer, raw_text); - let mut all_inputs = Vec::new(); let mut all_targets = Vec::new(); @@ -34,10 +32,11 @@ pub fn gen_data() -> (Tokenizer,Dataset){ } // Return tokenizer and the generated dataset - (tokenizer, Dataset { - inputs: all_inputs, - targets: all_targets, - }) - - -} \ No newline at end of file + ( + tokenizer, + Dataset { + inputs: all_inputs, + targets: all_targets, + }, + ) +} diff --git a/src/data/generation.rs b/src/data/generation.rs index eea1542..538c48b 100644 --- a/src/data/generation.rs +++ b/src/data/generation.rs @@ -81,7 +81,10 @@ pub fn example_gen() { } } } -pub fn generate_staircase_pairs(input: &[usize], target: &[usize]) -> Vec<(Vec, Vec)> { +pub fn generate_staircase_pairs( + input: &[usize], + target: &[usize], +) -> Vec<(Vec, Vec)> { let mut staircase_pairs = Vec::new(); // The number of steps will be the length of the target sequence diff --git a/src/layers/normalization.rs b/src/layers/normalization.rs index fc5d70f..a521c31 100644 --- a/src/layers/normalization.rs +++ b/src/layers/normalization.rs @@ -17,7 +17,7 @@ use ndarray::{Array2, Axis}; #[requires(beta.shape().len() == 2 && beta.shape()[0] == 1, "Beta must be a 2-dimensional array with a single row")] #[requires(epsilon > 0.0, "Epsilon must be positive and non-zero")] #[ensures(ret.shape() == x.shape(), "The resulting array must have the same shape as the input array")] -#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting array must be finite")] +//#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting array must be finite")] pub fn layer_norm( x: &Array2, gamma: &Array2, diff --git a/src/lib.rs b/src/lib.rs index 9709923..ae7c812 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,8 +18,8 @@ pub mod model { pub mod decoder; pub mod embedding; pub mod encoder; - pub mod transformer_model; pub mod train_transformer; + pub mod transformer_model; } // this lint makes a scene .... diff --git a/src/math/linear_algebra.rs b/src/math/linear_algebra.rs index 3764c01..33e2efb 100644 --- a/src/math/linear_algebra.rs +++ b/src/math/linear_algebra.rs @@ -44,8 +44,8 @@ pub fn dotproduct(a: &Array1, b: &Array1) -> f32 { #[requires(b.shape().len() == 3, "Input tensor b must have 3 dimensions")] #[requires(a.shape()[0] == b.shape()[0], "Batch sizes must match")] #[requires(a.shape()[2] == b.shape()[1], "Inner dimensions must align for matrix multiplication")] -#[ensures(ret.shape().len() == 3, "The resulting tensor must have 3 dimensions.")] -#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")] +//#[ensures(ret.shape().len() == 3, "The resulting tensor must have 3 dimensions.")] +//#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")] pub fn tensor_product(a: &Array3, b: &Array3) -> Array3 { // Check that batch sizes match and if dimension align assert_eq!(a.shape()[0], b.shape()[0], "Batch sizes must match"); @@ -83,8 +83,8 @@ pub fn tensor_product(a: &Array3, b: &Array3) -> Array3 { #[requires(x.shape().len() == 3, "Input tensor x must have 3 dimensions")] #[requires(w.shape().len() == 2, "Weight matrix w must have 2 dimensions")] #[requires(x.shape()[2] == w.shape()[0], "Input feature size must match the weight matrix's rows")] -#[ensures(ret.shape().len() == 3, "The resulting tensor must have 3 dimensions.")] -#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")] +//#[ensures(ret.shape().len() == 3, "The resulting tensor must have 3 dimensions.")] +//#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")] pub fn apply_projection(x: &Array3, w: &Array2) -> Array3 { let batch_size = x.shape()[0]; let seq_len = x.shape()[1]; @@ -116,7 +116,7 @@ pub fn apply_projection(x: &Array3, w: &Array2) -> Array3 { /// A 2D tensor of shape (batch_size * seq_length, embed_size). #[requires(batch.shape().len() == 3, "Input tensor must have 3 dimensions")] #[ensures(ret.shape().len() == 2, "The resulting tensor must have 2 dimensions.")] -#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")] +//#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")] pub fn flatten_3d_array(batch: Array3) -> Array2 { let (batch_size, seq_length, embed_size) = batch.dim(); batch diff --git a/src/model/embedding.rs b/src/model/embedding.rs index 60de7b1..9430902 100644 --- a/src/model/embedding.rs +++ b/src/model/embedding.rs @@ -2,6 +2,7 @@ use crate::math::positional_encoding::sinusoidal_pos_encoding; use ndarray::{s, Array, Array2, ArrayView, ArrayView2, Axis, Ix1}; +use std::cmp::Ordering; //use rand::Rng; use rand::Rng; use std::collections::HashMap; @@ -114,7 +115,7 @@ pub fn predict_tokens( let max_index = probs .iter() .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))// TODO : Need better MaM handling than asserting equal .unwrap() .0; diff --git a/src/model/train_transformer.rs b/src/model/train_transformer.rs new file mode 100644 index 0000000..87893c4 --- /dev/null +++ b/src/model/train_transformer.rs @@ -0,0 +1,167 @@ +#![allow(warnings)] +use crate::attention::softmax::{softmax_matrix, softmax_vec, softmax_vector}; +use crate::data::dataset::{gen_data, Dataset}; +use crate::data::learnable::{initialize_weights, LearnableWeights}; +use crate::data::tokenizer::Tokenizer; +use crate::layers::feedforward_layer::FeedForwardLayer; +use crate::math::linear_algebra::flatten_3d_array; +use crate::model::decoder::decoding; +use crate::model::embedding::{predict_tokens, Embedding}; +use crate::model::encoder::encoding; +use crate::settings::*; +use crate::training::loss_function::cross_entropy_loss; +use crate::training::train::{compute_gradients, update_weights}; +use ndarray::{Array1, Array2, Array3}; +use rand::prelude::SliceRandom; +use rand::Rng; +use std::collections::HashMap; + +fn train_model( + dataset: &Dataset, // The training data + tokenizer: Tokenizer, // Vocabulary + mut learnable_weights: LearnableWeights, // Initial weights + num_epochs: usize, // Number of training epochs + learning_rate: f32, // Learning rate +) -> Vec { + let vocab_size = tokenizer.vocab.len(); + let mut outputs = Vec::new(); + + // Loop over the number of epochs + for epoch in 0..num_epochs { + println!("Epoch {}/{}", epoch + 1, num_epochs); + + // Shuffle the dataset (for better training generalization) + let mut data_indices: Vec = (0..dataset.inputs.len()).collect(); + data_indices.shuffle(&mut rand::rng()); + + // Loop over the training data + for idx in data_indices.iter() { + let input = &dataset.inputs[*idx]; + let target = &dataset.targets[*idx]; + + // Convert to Array1 for processing + let input_seq = Array1::from(input.clone()); + let target_seq = Array1::from(target.clone()); + + // Forward pass, loss computation, and backward pass + let (out, logits) = training_model( + input, + target_seq.clone(), + &mut learnable_weights, + vocab_size, + tokenizer.vocab.clone(), + ); + let loss = cross_entropy_loss(&logits, &target_seq, vocab_size); + println!("Loss for this step: {}", loss); + // Update the model weights using the computed gradients + let gradients = compute_gradients(&logits, &target_seq, vocab_size, &learnable_weights); + update_weights(&mut learnable_weights, &gradients, learning_rate); + + // For demonstration, collect outputs every 100 steps + if idx % 100 == 0 { + let decoded_output = tokenizer.detokenize(input_seq.to_vec()); + outputs.push(decoded_output); + } + } + + // Print loss for each epoch (optional) + println!("Epoch {} completed", epoch + 1); + } + + outputs +} + +pub fn train() { + let (tokenizer, dataset) = gen_data(); + + let learnable_weights = LearnableWeights::new( + OUTPUT_SIZE, + HIDDEN_SIZE, + tokenizer.vocab.len(), + EMBEDDING_SIZE, + EMBEDDING_SIZE, + HIDDEN_SIZE, + ); + + // Define the number of epochs and learning rate + let num_epochs = 10; + let learning_rate = 0.001; + + // Train the model + let outputs = train_model( + &dataset, + tokenizer, + learnable_weights, + num_epochs, + learning_rate, + ); + + // Print some of the outputs after training + for output in outputs.iter().take(5) { + println!("Output: {}", output); + } +} + +pub fn training_model( + tokens: &Vec, + target_seq: Array1, + learnable_weights: &mut LearnableWeights, + vocab_size: usize, + vocab: HashMap, +) -> (Vec, Array2) { + // Initialize Tokenizer and Embedding layer + let embedding = Embedding::new(vocab_size, EMBEDDING_SIZE); // Initialize embedding layer + + // Embed the input sentence + let embeddings = embedding.forward(tokens.clone()); + + // Convert embeddings to Array3 (batch_size, seq_length, embed_size) + let input_tensor = Array3::from_shape_fn( + (BATCH_SIZE, tokens.len(), EMBEDDING_SIZE), + |(_, seq, embed)| embeddings[[seq, embed]], + ); + + // Initialize gamma and beta for layer normalization + let gamma = Array2::ones((1, EMBEDDING_SIZE)); // Example gamma (scale parameter) + let beta = Array2::zeros((1, EMBEDDING_SIZE)); // Example beta (shift parameter) + + // Initialize the feed-forward layer with correct types + let feed_forward_layer = FeedForwardLayer::new(&learnable_weights, DROPOUT_RATE); + + // Perform encoding with N stacked layers + let mut encoded = input_tensor.clone(); + for _ in 0..NUM_LAYERS { + encoded = encoding( + encoded, + gamma.clone(), + beta.clone(), + EPSILON, + &feed_forward_layer, + ); + } + + // Perform decoding with N stacked layers + let mut decoded = input_tensor.clone(); + for _ in 0..NUM_LAYERS { + decoded = decoding( + decoded, + encoded.clone(), + gamma.clone(), + beta.clone(), + EPSILON, + &feed_forward_layer, + ); + } + + // Apply final linear transformation + let output_projection = Array2::ones((OUTPUT_SIZE, vocab_size)); // All ones weights + let logits = flatten_3d_array(decoded).dot(&output_projection); // Linear layer + + // Apply softmax to logits + let probabilities = softmax_matrix(&logits); + + // Convert probabilities back to text using the tokenizer + let tokens = predict_tokens(probabilities.view(), &vocab); + + (tokens, logits.clone()) +} diff --git a/src/training/train.rs b/src/training/train.rs new file mode 100644 index 0000000..be49662 --- /dev/null +++ b/src/training/train.rs @@ -0,0 +1,59 @@ +use crate::data::learnable::LearnableWeights; +use crate::settings::*; +use crate::training::loss_function::cross_entropy_loss; +use ndarray::{Array1, Array2}; + +pub fn compute_gradients( + logits: &Array2, + target_sequence: &Array1, + vocab_size: usize, + model: &LearnableWeights, +) -> LearnableWeights { + // compute gradients for all the learnable weights in the model + + // FOR NOW ONLY COPY WEIGHTS... + LearnableWeights::new( + OUTPUT_SIZE, + HIDDEN_SIZE, + vocab_size, + EMBEDDING_SIZE, + EMBEDDING_SIZE, + HIDDEN_SIZE, + ) +} + +pub fn update_weights( + model: &mut LearnableWeights, + gradients: &LearnableWeights, + learning_rate: f32, +) { + // Ensure the gradients and model weights have compatible shapes (reshape if necessary) + model.embedding = &model.embedding - &(&gradients.embedding * learning_rate); + model.query_weights = &model.query_weights - &(&gradients.query_weights * learning_rate); + model.key_weights = &model.key_weights - &(&gradients.key_weights * learning_rate); + model.value_weights = &model.value_weights - &(&gradients.value_weights * learning_rate); + model.output_projection = + &model.output_projection - &(&gradients.output_projection * learning_rate); + model.linear1_weights = &model.linear1_weights - &(&gradients.linear1_weights * learning_rate); + model.linear2_weights = &model.linear2_weights - &(&gradients.linear2_weights * learning_rate); + + // Handle potential shape mismatches with bias updates + model.bias1 = &model.bias1 - &(&gradients.bias1 * learning_rate); + model.bias2 = &model.bias2 - &(&gradients.bias2 * learning_rate); + + // Handle Layer Norm scales and shifts (ensure correct dimensions) + model + .layer_norm_scale + .iter_mut() + .zip(gradients.layer_norm_scale.iter()) + .for_each(|(a, g)| *a -= g * learning_rate); + model + .layer_norm_shift + .iter_mut() + .zip(gradients.layer_norm_shift.iter()) + .for_each(|(a, g)| *a -= g * learning_rate); + + // Update output projection vocabulary weights (handle shapes) + model.output_projection_vocab = + &model.output_projection_vocab - &(&gradients.output_projection_vocab * learning_rate); +} From f0e1adee2bdf843f90f66a2bc79a7749783c0618 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 22 Dec 2024 17:11:43 +0100 Subject: [PATCH 04/13] refactored clippy warnings --- src/data/dataset.rs | 2 -- src/model/embedding.rs | 2 +- src/training/train.rs | 11 +++++------ 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/data/dataset.rs b/src/data/dataset.rs index f696962..a8a015a 100644 --- a/src/data/dataset.rs +++ b/src/data/dataset.rs @@ -1,9 +1,7 @@ use crate::data::generation::{generate_input_target_pairs, generate_staircase_pairs}; use crate::data::io::get_input; use crate::data::tokenizer::Tokenizer; -use ndarray::Data; -#[allow(dead_code)] pub struct Dataset { pub(crate) inputs: Vec>, // Each input is a sequence of token IDs pub(crate) targets: Vec>, // Each target is the corresponding output sequence diff --git a/src/model/embedding.rs b/src/model/embedding.rs index 9430902..9e75e86 100644 --- a/src/model/embedding.rs +++ b/src/model/embedding.rs @@ -115,7 +115,7 @@ pub fn predict_tokens( let max_index = probs .iter() .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))// TODO : Need better MaM handling than asserting equal + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better MaM handling than asserting equal .unwrap() .0; diff --git a/src/training/train.rs b/src/training/train.rs index be49662..e65a136 100644 --- a/src/training/train.rs +++ b/src/training/train.rs @@ -1,17 +1,16 @@ use crate::data::learnable::LearnableWeights; use crate::settings::*; -use crate::training::loss_function::cross_entropy_loss; use ndarray::{Array1, Array2}; pub fn compute_gradients( - logits: &Array2, - target_sequence: &Array1, + _logits: &Array2, + _target_sequence: &Array1, vocab_size: usize, - model: &LearnableWeights, + _model: &LearnableWeights, ) -> LearnableWeights { - // compute gradients for all the learnable weights in the model + // TODO: compute gradients for all the learnable weights in the model - // FOR NOW ONLY COPY WEIGHTS... + // FTEMPO!!! :OR NOW ONLY COPY WEIGHTS... LearnableWeights::new( OUTPUT_SIZE, HIDDEN_SIZE, From ff58fddfdb315a26eecb29ce44bfb64e6dcaf403 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 22 Dec 2024 21:27:36 +0100 Subject: [PATCH 05/13] added padding and a predict index function for data gen and retrieval --- src/data/generation.rs | 23 +++++++++++++++++++++-- src/model/embedding.rs | 28 +++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/data/generation.rs b/src/data/generation.rs index 538c48b..7448735 100644 --- a/src/data/generation.rs +++ b/src/data/generation.rs @@ -1,5 +1,6 @@ use crate::data::io::get_input; use crate::data::tokenizer::Tokenizer; +use crate::settings::INPUT_SIZE; pub fn generate_input_target_pairs( tokenizer: &Tokenizer, @@ -12,7 +13,8 @@ pub fn generate_input_target_pairs( let tokens = tokenizer.tokenize(sentence); // Prepare input (same as sentence) - let input = tokens.clone(); + let input = tokenizer.pad_sequence(tokens.clone(), INPUT_SIZE); + // Prepare target (shifted version of the sentence) let mut target = tokens.clone(); @@ -23,6 +25,7 @@ pub fn generate_input_target_pairs( if !next_tokens.is_empty() { target.push(next_tokens[0]); // Add the first token of the next sentence } + } else { target.push(tokenizer.vocab[""]); // Use EOS token for the last sentence } @@ -31,7 +34,7 @@ pub fn generate_input_target_pairs( if !target.is_empty() { target.remove(0); } - + let target = tokenizer.pad_sequence(target, INPUT_SIZE); // Add the input-target pair to the result pairs.push((input, target)); } @@ -93,9 +96,25 @@ pub fn generate_staircase_pairs( let staircase_input = input.iter().take(i).cloned().collect::>(); let staircase_target = target.iter().take(i).cloned().collect::>(); + // Pad both input and target sequences to max_length + let staircase_input = pad_sequence_to_length(&staircase_input, INPUT_SIZE); + let staircase_target = pad_sequence_to_length(&staircase_target, INPUT_SIZE); // Add this pair to the staircase pairs vector staircase_pairs.push((staircase_input, staircase_target)); } staircase_pairs } +fn pad_sequence_to_length(seq: &[usize], max_length: usize) -> Vec { + let mut padded_seq = seq.to_vec(); + + // Pad with token if the sequence is shorter than max_length + if padded_seq.len() < max_length { + padded_seq.resize(max_length, 0); // 0 is the token index + } else if padded_seq.len() > max_length { + // Truncate if the sequence is too long + padded_seq.truncate(max_length); + } + + padded_seq +} \ No newline at end of file diff --git a/src/model/embedding.rs b/src/model/embedding.rs index 9e75e86..0f9a176 100644 --- a/src/model/embedding.rs +++ b/src/model/embedding.rs @@ -99,6 +99,32 @@ impl Embedding { pub fn norm(vector: ArrayView) -> f32 { vector.mapv(|x| x * x).sum().sqrt() } + +pub fn predict_index( probabilities: ArrayView2, + vocab: &HashMap, ) -> Vec{ + + // Reverse the vocab to get a mapping from index to token + let _index_to_token: HashMap = + vocab.iter().map(|(k, &v)| (v, k.clone())).collect(); + + let mut predicted_tokens = Vec::new(); + + for probs in probabilities.axis_iter(Axis(0)) { + // Iterate over the rows (sequence tokens) + // Find the index of the maximum probability + let max_index = probs + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better NaN handling than asserting equal + .unwrap() + .0; + + predicted_tokens.push(max_index.clone()); + + } + + predicted_tokens +} pub fn predict_tokens( probabilities: ArrayView2, vocab: &HashMap, @@ -115,7 +141,7 @@ pub fn predict_tokens( let max_index = probs .iter() .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better MaM handling than asserting equal + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better NaN handling than asserting equal .unwrap() .0; From 91f7e7e4b43fe0ea9a870ff5d7c75b7c3ba61aeb Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 22 Dec 2024 21:27:59 +0100 Subject: [PATCH 06/13] added to sequences --- src/data/tokenizer.rs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/data/tokenizer.rs b/src/data/tokenizer.rs index 0ffc109..1478a73 100644 --- a/src/data/tokenizer.rs +++ b/src/data/tokenizer.rs @@ -2,6 +2,7 @@ use regex::Regex; use std::collections::{HashMap, HashSet}; +use crate::settings::INPUT_SIZE; #[derive(Debug)] pub struct Tokenizer { @@ -64,7 +65,7 @@ impl Tokenizer { tokens.push(*token); } - tokens.push(self.vocab[&self.eos_token]); // Add EOS token at the end + tokens.push(self.vocab[&self.eos_token]);// Add EOS token at the end tokens } @@ -78,8 +79,19 @@ impl Tokenizer { .join(" ") } - // Helper function to split sentence into words using an improved regex -} + pub fn pad_sequence(&self, tokens: Vec, max_length: usize) -> Vec { + let mut padded = tokens.clone(); + + // Pad with the PAD token if the sequence is too short + if tokens.len() < max_length { + padded.resize(max_length, self.vocab[&self.pad_token]); + } else if tokens.len() > max_length { + // Truncate if the sequence is too long + padded.truncate(max_length); + } + + padded + }} #[cfg(test)] mod tests { From 50a2bc0be62c4eae6b9e48cdd4232e22c2bbbce4 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 22 Dec 2024 21:28:24 +0100 Subject: [PATCH 07/13] implemented gradient for weight calculations --- src/data/learnable.rs | 2 +- src/model/train_transformer.rs | 68 ++++++++++++++++++++--------- src/settings.rs | 2 +- src/training/train.rs | 78 +++++++++++++++++++++++++++------- 4 files changed, 112 insertions(+), 38 deletions(-) diff --git a/src/data/learnable.rs b/src/data/learnable.rs index 7cf7208..db8b148 100644 --- a/src/data/learnable.rs +++ b/src/data/learnable.rs @@ -42,7 +42,7 @@ impl LearnableWeights { query_weights: Array2::ones((embedding_dim, attention_dim)), key_weights: Array2::ones((embedding_dim, attention_dim)), value_weights: Array2::ones((embedding_dim, attention_dim)), - output_projection: Array2::ones((attention_dim, embedding_dim)), + output_projection: Array2::ones((attention_dim, vocab_size)), // Feedforward Network linear1_weights: Array2::ones((embedding_dim, ffn_dim)), diff --git a/src/model/train_transformer.rs b/src/model/train_transformer.rs index 87893c4..c392abd 100644 --- a/src/model/train_transformer.rs +++ b/src/model/train_transformer.rs @@ -6,7 +6,7 @@ use crate::data::tokenizer::Tokenizer; use crate::layers::feedforward_layer::FeedForwardLayer; use crate::math::linear_algebra::flatten_3d_array; use crate::model::decoder::decoding; -use crate::model::embedding::{predict_tokens, Embedding}; +use crate::model::embedding::{predict_index, predict_tokens, Embedding}; use crate::model::encoder::encoding; use crate::settings::*; use crate::training::loss_function::cross_entropy_loss; @@ -23,27 +23,30 @@ fn train_model( num_epochs: usize, // Number of training epochs learning_rate: f32, // Learning rate ) -> Vec { - let vocab_size = tokenizer.vocab.len(); - let mut outputs = Vec::new(); + let vocab_size = tokenizer.vocab.len(); // Vocabulary size + let mut outputs = Vec::new(); // To store outputs for progress tracking // Loop over the number of epochs for epoch in 0..num_epochs { - println!("Epoch {}/{}", epoch + 1, num_epochs); + println!("\n=== Epoch {}/{} ===", epoch + 1, num_epochs); - // Shuffle the dataset (for better training generalization) + // Shuffle the dataset indices let mut data_indices: Vec = (0..dataset.inputs.len()).collect(); - data_indices.shuffle(&mut rand::rng()); + data_indices.shuffle(&mut rand::thread_rng()); + + let mut total_loss = 0.0; // Accumulate loss for this epoch + let mut num_batches = 0; // Loop over the training data - for idx in data_indices.iter() { - let input = &dataset.inputs[*idx]; - let target = &dataset.targets[*idx]; + for (step, &idx) in data_indices.iter().enumerate() { + let input = &dataset.inputs[idx]; + let target = &dataset.targets[idx]; // Convert to Array1 for processing let input_seq = Array1::from(input.clone()); let target_seq = Array1::from(target.clone()); - // Forward pass, loss computation, and backward pass + // Forward pass: Model prediction let (out, logits) = training_model( input, target_seq.clone(), @@ -51,23 +54,40 @@ fn train_model( vocab_size, tokenizer.vocab.clone(), ); + + // Compute loss let loss = cross_entropy_loss(&logits, &target_seq, vocab_size); - println!("Loss for this step: {}", loss); - // Update the model weights using the computed gradients + total_loss += loss; // Accumulate loss for averaging + num_batches += 1; + + // Backward pass: Compute gradients let gradients = compute_gradients(&logits, &target_seq, vocab_size, &learnable_weights); + + // Update weights update_weights(&mut learnable_weights, &gradients, learning_rate); - // For demonstration, collect outputs every 100 steps - if idx % 100 == 0 { - let decoded_output = tokenizer.detokenize(input_seq.to_vec()); + // Periodically log training progress + if step % 10 == 0 { + let decoded_output = tokenizer.detokenize(out.to_vec()); + println!( + "Step {}: Loss = {:.4}, Output = {:?}, Expected = {:?}", + step, + loss, + decoded_output, + tokenizer.detokenize(target.to_vec()) + ); outputs.push(decoded_output); } } - // Print loss for each epoch (optional) - println!("Epoch {} completed", epoch + 1); + // End of epoch: Print average loss and track improvement + let avg_loss = total_loss / num_batches as f32; + println!("Epoch {} completed with average loss: {:.4}", epoch + 1, avg_loss); + + // For debugging or tracking, we could save weights periodically here. } + println!("\nTraining completed!"); outputs } @@ -108,7 +128,7 @@ pub fn training_model( learnable_weights: &mut LearnableWeights, vocab_size: usize, vocab: HashMap, -) -> (Vec, Array2) { +) -> (Vec, Array2) { // Initialize Tokenizer and Embedding layer let embedding = Embedding::new(vocab_size, EMBEDDING_SIZE); // Initialize embedding layer @@ -154,14 +174,20 @@ pub fn training_model( } // Apply final linear transformation - let output_projection = Array2::ones((OUTPUT_SIZE, vocab_size)); // All ones weights - let logits = flatten_3d_array(decoded).dot(&output_projection); // Linear layer + let output_projection = &learnable_weights.output_projection; // All ones weights + println!("Decoded shape: {:?}", decoded.dim()); + println!("Flattened decoded shape: {:?}", flatten_3d_array(decoded.clone()).dim()); + println!("Output projection shape: {:?}", output_projection.dim()); + println!("Transposed output projection shape: {:?}", output_projection.t().dim()); + + + let logits = flatten_3d_array(decoded).dot(&output_projection.to_owned()); // Linear layer // Apply softmax to logits let probabilities = softmax_matrix(&logits); // Convert probabilities back to text using the tokenizer - let tokens = predict_tokens(probabilities.view(), &vocab); + let tokens = predict_index(probabilities.view(), &vocab); (tokens, logits.clone()) } diff --git a/src/settings.rs b/src/settings.rs index e03e971..288030d 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -8,7 +8,7 @@ pub const NUM_HEADS: usize = 8; // Number of attention heads pub const SEQ_LENGTH: usize = 128; // Sequence length pub const BATCH_SIZE: usize = 1; // Batch size pub const EMBEDDING_SIZE: usize = 12; -pub const INPUT_SIZE: usize = 12; +pub const INPUT_SIZE: usize = 25; pub const OUTPUT_SIZE: usize = 12; pub const NUM_LAYERS: usize = 6; diff --git a/src/training/train.rs b/src/training/train.rs index e65a136..714b740 100644 --- a/src/training/train.rs +++ b/src/training/train.rs @@ -1,26 +1,74 @@ use crate::data::learnable::LearnableWeights; use crate::settings::*; -use ndarray::{Array1, Array2}; +use ndarray::{Array1, Array2, Axis}; +use crate::attention::softmax::softmax_matrix; pub fn compute_gradients( - _logits: &Array2, - _target_sequence: &Array1, + logits: &Array2, + target_sequence: &Array1, vocab_size: usize, - _model: &LearnableWeights, + model: &LearnableWeights, ) -> LearnableWeights { - // TODO: compute gradients for all the learnable weights in the model - - // FTEMPO!!! :OR NOW ONLY COPY WEIGHTS... - LearnableWeights::new( - OUTPUT_SIZE, - HIDDEN_SIZE, - vocab_size, - EMBEDDING_SIZE, - EMBEDDING_SIZE, - HIDDEN_SIZE, - ) + // Compute the softmax probabilities + let probabilities = softmax_matrix(logits); + + // One-hot encode the target sequence (row = token in sequence, col is vocab index) + println!("Target seqeuence: {:?}",target_sequence); + let mut target_one_hot = Array2::::zeros((target_sequence.len(), vocab_size)); + for (i, &target_idx) in target_sequence.iter().enumerate() { + target_one_hot[[i, target_idx]] = 1.0; + } + + // Compute the loss gradient with respect to logits + let d_logits = &probabilities - &target_one_hot; + println!("d_logits shape: {:?}", logits.dim()); + + // Backpropagate the gradient through the output projection layer + let d_output_projection = d_logits.dot(&model.output_projection.t()); + // Compute gradients for output_projection_vocab + let d_output_projection_vocab = d_logits.sum_axis(Axis(0)).insert_axis(Axis(1)); + + // Backpropagate through layer normalization + // Compute gradients for layer normalization parameters + let d_hidden = d_logits.dot(&model.linear2_weights.t()); + let d_layer_norm_scale = d_hidden.mean_axis(Axis(0)).unwrap().to_vec(); + let d_layer_norm_shift = d_hidden.std_axis(Axis(0), 0.0).to_vec(); + + // Compute gradients for the embedding + let d_embedding = d_hidden.dot(&model.embedding.t()); + + // Compute gradients for attention weights (query, key, value) + let d_query_weights = d_hidden.dot(&model.query_weights.t()); + let d_key_weights = d_hidden.dot(&model.key_weights.t()); + let d_value_weights = d_hidden.dot(&model.value_weights.t()); + + // Compute gradients for linear layers + let d_linear1_weights = d_hidden.dot(&model.linear1_weights.t()); + let d_linear2_weights = model.linear2_weights.t().dot(&d_logits); + + // Compute biases + let d_bias1 = d_hidden.sum_axis(Axis(0)); + let d_bias2 = d_logits.sum_axis(Axis(0)); + + // Package gradients into a LearnableWeights structure + LearnableWeights { + embedding: d_embedding, + query_weights: d_query_weights, + key_weights: d_key_weights, + value_weights: d_value_weights, + output_projection: d_output_projection, + linear1_weights: d_linear1_weights, + linear2_weights: d_linear2_weights, + bias1: d_bias1, + bias2: d_bias2, + layer_norm_scale: d_layer_norm_scale, // Converted to Vec + layer_norm_shift: d_layer_norm_shift, // Converted to Vec + output_projection_vocab: d_output_projection_vocab, // Correct shape as Array2 + } } + + pub fn update_weights( model: &mut LearnableWeights, gradients: &LearnableWeights, From 7818c2cbdb1bdceea74e7c64945014546a26a876 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 22 Dec 2024 21:30:15 +0100 Subject: [PATCH 08/13] Adjusted weights to eliminate every shape mismatch --- src/settings.rs | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/settings.rs b/src/settings.rs index 288030d..24a5d8e 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,20 +1,37 @@ -// Numerical constants with down scaled real-application values #![allow(unused)] + +// Numerical constants with down scaled real-application values pub const EPSILON: f32 = 0.0001; -pub const D_MODEL: usize = 12; // Embedding size -pub const D_K: usize = 64; // Key/query dimension -pub const D_V: usize = 64; // Value dimension + +// Embedding size +pub const D_MODEL: usize = 64; // Changed from 12 to a larger, more standard dimension for transformers + +// Attention mechanism dimensions +pub const D_K: usize = 64; // Key/query dimension (same as D_V for simplicity) +pub const D_V: usize = 64; // Value dimension (same as D_K) pub const NUM_HEADS: usize = 8; // Number of attention heads -pub const SEQ_LENGTH: usize = 128; // Sequence length -pub const BATCH_SIZE: usize = 1; // Batch size -pub const EMBEDDING_SIZE: usize = 12; -pub const INPUT_SIZE: usize = 25; -pub const OUTPUT_SIZE: usize = 12; -pub const NUM_LAYERS: usize = 6; +// Sequence and batch size +pub const SEQ_LENGTH: usize = 128; // Sequence length (adjustable depending on your data) +pub const BATCH_SIZE: usize = 32; // Increased batch size for practical usage + +// Embedding size and dimensions +pub const EMBEDDING_SIZE: usize = D_MODEL; // Should match D_MODEL for consistency + +// Input/Output sizes +pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs +pub const OUTPUT_SIZE: usize = D_MODEL; // Should be consistent with D_MODEL for output + +// Number of layers +pub const NUM_LAYERS: usize = 6; // Number of layers (standard for many transformer architectures) + +// Feedforward network dimension (FFN_DIM) +pub const FFN_DIM: usize = 256; // A common size for the feedforward dimension -pub const HIDDEN_SIZE: usize = 6; +// Hidden size (used for biases and other layer parameters) +pub const HIDDEN_SIZE: usize = 256; // Adjusted for a larger hidden layer size, consistent with FFN_DIM +// Dropout rate and learning rate pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate From a1ab397b6bfbdd372fefd1c819ea30c76aac0826 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 22 Dec 2024 22:36:56 +0100 Subject: [PATCH 09/13] new ompute loss function which for now produces wrong shapes --- src/data/learnable.rs | 25 +++---- src/model/train_transformer.rs | 21 +++++- src/settings.rs | 56 ++++++++-------- src/training/loss_function.rs | 18 +++-- src/training/train.rs | 117 ++++++++++++++++----------------- 5 files changed, 130 insertions(+), 107 deletions(-) diff --git a/src/data/learnable.rs b/src/data/learnable.rs index db8b148..3397115 100644 --- a/src/data/learnable.rs +++ b/src/data/learnable.rs @@ -1,5 +1,5 @@ -use crate::settings::*; use ndarray::{Array1, Array2}; +use crate::settings::*; pub struct LearnableWeights { // Embedding Layer @@ -14,8 +14,8 @@ pub struct LearnableWeights { // Feedforward Network pub linear1_weights: Array2, // (embedding_dim, ffn_dim) pub linear2_weights: Array2, // (ffn_dim, embedding_dim) - pub bias1: Array1, - pub bias2: Array1, + pub bias1: Array1, // (ffn_dim) + pub bias2: Array1, // (embedding_dim) // Layer Normalization pub layer_norm_scale: Vec, // (embedding_dim,) @@ -42,13 +42,13 @@ impl LearnableWeights { query_weights: Array2::ones((embedding_dim, attention_dim)), key_weights: Array2::ones((embedding_dim, attention_dim)), value_weights: Array2::ones((embedding_dim, attention_dim)), - output_projection: Array2::ones((attention_dim, vocab_size)), + output_projection: Array2::ones((attention_dim, embedding_dim)), // Feedforward Network linear1_weights: Array2::ones((embedding_dim, ffn_dim)), linear2_weights: Array2::ones((ffn_dim, embedding_dim)), - bias1: Array1::zeros(hidden_size), - bias2: Array1::zeros(output_size), + bias1: Array1::zeros(ffn_dim), + bias2: Array1::zeros(embedding_dim), // Layer Normalization layer_norm_scale: vec![1.0; embedding_dim], // Initialize scale to 1 @@ -59,13 +59,14 @@ impl LearnableWeights { } } } + pub fn initialize_weights() -> LearnableWeights { LearnableWeights::new( - OUTPUT_SIZE, // output_size - HIDDEN_SIZE, // hidden_size - D_MODEL, // vocab_size - D_MODEL, // embedding_dim - D_K, // attention_dim (for keys, queries) - D_V, // ffn_dim (could align with embedding_dim or specific) + D_MODEL, // output_size = D_MODEL + FFN_DIM, // hidden_size = FFN_DIM + D_MODEL, // vocab_size + D_MODEL, // embedding_dim = D_MODEL + D_K, // attention_dim + FFN_DIM, // ffn_dim ) } diff --git a/src/model/train_transformer.rs b/src/model/train_transformer.rs index c392abd..edf2369 100644 --- a/src/model/train_transformer.rs +++ b/src/model/train_transformer.rs @@ -60,8 +60,27 @@ fn train_model( total_loss += loss; // Accumulate loss for averaging num_batches += 1; + // Prepare inputs, targets, and predictions for gradient computation + let inputs = Array3::from_shape_fn( + (BATCH_SIZE, input.len(), EMBEDDING_SIZE), + |(_, seq, embed)| logits[[seq, embed]], + ); + + let targets = Array2::from_shape_fn( + (target.len(), logits.shape()[1]), + |(seq, embed)| logits[[seq, embed]], + ); + + let predictions = logits.clone(); + // Backward pass: Compute gradients - let gradients = compute_gradients(&logits, &target_seq, vocab_size, &learnable_weights); + let gradients = compute_gradients( + &mut learnable_weights, + &inputs, + &targets, + &predictions, + vocab_size + ); // Update weights update_weights(&mut learnable_weights, &gradients, learning_rate); diff --git a/src/settings.rs b/src/settings.rs index 24a5d8e..8290d5b 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,39 +1,39 @@ -#![allow(unused)] + #![allow(unused)] -// Numerical constants with down scaled real-application values -pub const EPSILON: f32 = 0.0001; + // Numerical constants with downscaled real-application values + pub const EPSILON: f32 = 0.0001; -// Embedding size -pub const D_MODEL: usize = 64; // Changed from 12 to a larger, more standard dimension for transformers + // Embedding size + pub const D_MODEL: usize = 32; // Reduced to 32 for a smaller model size -// Attention mechanism dimensions -pub const D_K: usize = 64; // Key/query dimension (same as D_V for simplicity) -pub const D_V: usize = 64; // Value dimension (same as D_K) -pub const NUM_HEADS: usize = 8; // Number of attention heads + // Attention mechanism dimensions + pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity) + pub const D_V: usize = 32; // Value dimension (same as D_K) + pub const NUM_HEADS: usize = 4; // Reduced the number of attention heads for smaller model -// Sequence and batch size -pub const SEQ_LENGTH: usize = 128; // Sequence length (adjustable depending on your data) -pub const BATCH_SIZE: usize = 32; // Increased batch size for practical usage + // Sequence and batch size + pub const SEQ_LENGTH: usize = 64; // Reduced sequence length + pub const BATCH_SIZE: usize = 1; // Reduced batch size for smaller model training -// Embedding size and dimensions -pub const EMBEDDING_SIZE: usize = D_MODEL; // Should match D_MODEL for consistency + // Embedding size and dimensions + pub const EMBEDDING_SIZE: usize = D_MODEL; // Matches D_MODEL for consistency -// Input/Output sizes -pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs -pub const OUTPUT_SIZE: usize = D_MODEL; // Should be consistent with D_MODEL for output + // Input/Output sizes + pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs + pub const OUTPUT_SIZE: usize = D_MODEL; // Consistent with D_MODEL for output -// Number of layers -pub const NUM_LAYERS: usize = 6; // Number of layers (standard for many transformer architectures) + // Number of layers + pub const NUM_LAYERS: usize = 4; // Reduced to 4 layers for a smaller architecture -// Feedforward network dimension (FFN_DIM) -pub const FFN_DIM: usize = 256; // A common size for the feedforward dimension + // Feedforward network dimension (FFN_DIM) + pub const FFN_DIM: usize = 128; // Smaller FFN dimension -// Hidden size (used for biases and other layer parameters) -pub const HIDDEN_SIZE: usize = 256; // Adjusted for a larger hidden layer size, consistent with FFN_DIM + // Hidden size (used for biases and other layer parameters) + pub const HIDDEN_SIZE: usize = 128; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM -// Dropout rate and learning rate -pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization -pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate + // Dropout rate and learning rate + pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization + pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate -// Positional encoding parameters -pub const MAX_SEQ_LENGTH: usize = 512; // Maximum sequence length for positional encoding + // Positional encoding parameters + pub const MAX_SEQ_LENGTH: usize = 128; // Maximum sequence length for positional encoding diff --git a/src/training/loss_function.rs b/src/training/loss_function.rs index f691020..6259d78 100644 --- a/src/training/loss_function.rs +++ b/src/training/loss_function.rs @@ -4,22 +4,30 @@ use ndarray::{s, Array1, Array2}; pub fn cross_entropy_loss( logits: &Array2, targets: &Array1, - _vocab_size: usize, + vocab_size: usize, ) -> f32 { let mut loss = 0.0; + // Ensure that the number of targets matches the batch size + assert_eq!(logits.dim().0, targets.len(), "Batch size mismatch"); + // Iterate over each target in the batch for (i, &target) in targets.iter().enumerate() { + // Ensure target index is within valid range + if target >= vocab_size { + panic!("Target index {} is out of bounds for vocab_size {}", target, vocab_size); + } + // Get the logits for the current target (batch_size x vocab_size) - let logit = &logits.slice(s![i, ..]); + let logit = &logits.slice(s![i, ..]); // Get the logits for the i-th sample - // Softmax calculation + // Softmax calculation: convert logits to probabilities let softmax = softmax_vector(*logit); // The log probability for the correct target token - let log_prob = softmax[target]; + let log_prob = softmax[i]; - // Add to the loss: -log(p_y) + // Add to the loss: -log(p_y) for cross-entropy loss -= log_prob.ln(); } diff --git a/src/training/train.rs b/src/training/train.rs index 714b740..568b061 100644 --- a/src/training/train.rs +++ b/src/training/train.rs @@ -1,70 +1,63 @@ use crate::data::learnable::LearnableWeights; use crate::settings::*; -use ndarray::{Array1, Array2, Axis}; +use ndarray::{Array1, Array2, Array3, Axis}; use crate::attention::softmax::softmax_matrix; +use crate::math::linear_algebra::flatten_3d_array; +/// Compute gradients for the transformer model's learnable weights. pub fn compute_gradients( - logits: &Array2, - target_sequence: &Array1, - vocab_size: usize, - model: &LearnableWeights, + weights: &mut LearnableWeights, + inputs: &Array3, + targets: &Array2, + predictions: &Array2, + vocabsize : usize, ) -> LearnableWeights { - // Compute the softmax probabilities - let probabilities = softmax_matrix(logits); - - // One-hot encode the target sequence (row = token in sequence, col is vocab index) - println!("Target seqeuence: {:?}",target_sequence); - let mut target_one_hot = Array2::::zeros((target_sequence.len(), vocab_size)); - for (i, &target_idx) in target_sequence.iter().enumerate() { - target_one_hot[[i, target_idx]] = 1.0; - } - - // Compute the loss gradient with respect to logits - let d_logits = &probabilities - &target_one_hot; - println!("d_logits shape: {:?}", logits.dim()); - - // Backpropagate the gradient through the output projection layer - let d_output_projection = d_logits.dot(&model.output_projection.t()); - // Compute gradients for output_projection_vocab - let d_output_projection_vocab = d_logits.sum_axis(Axis(0)).insert_axis(Axis(1)); - - // Backpropagate through layer normalization - // Compute gradients for layer normalization parameters - let d_hidden = d_logits.dot(&model.linear2_weights.t()); - let d_layer_norm_scale = d_hidden.mean_axis(Axis(0)).unwrap().to_vec(); - let d_layer_norm_shift = d_hidden.std_axis(Axis(0), 0.0).to_vec(); - - // Compute gradients for the embedding - let d_embedding = d_hidden.dot(&model.embedding.t()); - - // Compute gradients for attention weights (query, key, value) - let d_query_weights = d_hidden.dot(&model.query_weights.t()); - let d_key_weights = d_hidden.dot(&model.key_weights.t()); - let d_value_weights = d_hidden.dot(&model.value_weights.t()); - - // Compute gradients for linear layers - let d_linear1_weights = d_hidden.dot(&model.linear1_weights.t()); - let d_linear2_weights = model.linear2_weights.t().dot(&d_logits); - - // Compute biases - let d_bias1 = d_hidden.sum_axis(Axis(0)); - let d_bias2 = d_logits.sum_axis(Axis(0)); - - // Package gradients into a LearnableWeights structure - LearnableWeights { - embedding: d_embedding, - query_weights: d_query_weights, - key_weights: d_key_weights, - value_weights: d_value_weights, - output_projection: d_output_projection, - linear1_weights: d_linear1_weights, - linear2_weights: d_linear2_weights, - bias1: d_bias1, - bias2: d_bias2, - layer_norm_scale: d_layer_norm_scale, // Converted to Vec - layer_norm_shift: d_layer_norm_shift, // Converted to Vec - output_projection_vocab: d_output_projection_vocab, // Correct shape as Array2 - } + let mut gradients = LearnableWeights::new( + OUTPUT_SIZE, + HIDDEN_SIZE, + vocabsize, // Ensure the vocab size is correct + D_MODEL, + D_K, + FFN_DIM, + ); + + // Compute the loss and its derivative + let loss = predictions - targets; + let d_loss = &loss * 2.0 / (BATCH_SIZE as f32); + + // Compute gradients for the output projection weights + gradients.output_projection_vocab = predictions.t().dot(&d_loss); + + // Flattened inputs for further computations + let flattened_inputs = flatten_3d_array(inputs.clone()); + + // Compute gradients for the feedforward network weights + let d_linear2 = d_loss.dot(&weights.linear2_weights.t()); + gradients.linear2_weights = flattened_inputs.t().dot(&d_linear2); + gradients.bias2 = d_linear2.sum_axis(ndarray::Axis(0)); + + let d_linear1 = d_linear2.dot(&weights.linear1_weights.t()); + gradients.linear1_weights = flattened_inputs.t().dot(&d_linear1); + gradients.bias1 = d_linear1.sum_axis(ndarray::Axis(0)); + + // Compute gradients for the attention mechanism weights + let d_attention_output = d_loss.dot(&weights.output_projection.t()); + gradients.output_projection = flattened_inputs.t().dot(&d_attention_output); + let d_value = d_attention_output.dot(&weights.value_weights.t()); + gradients.value_weights = flattened_inputs.t().dot(&d_value); + let d_key = d_attention_output.dot(&weights.key_weights.t()); + gradients.key_weights = flattened_inputs.t().dot(&d_key); + let d_query = d_attention_output.dot(&weights.query_weights.t()); + gradients.query_weights = flattened_inputs.t().dot(&d_query); + + // Compute gradients for the embedding layer + gradients.embedding = inputs.mean_axis(ndarray::Axis(0)).unwrap(); // Ensure shape consistency with model.embedding + + // Compute gradients for layer normalization parameters (scale and shift) + gradients.layer_norm_scale = d_linear1.mean_axis(ndarray::Axis(0)).unwrap().to_vec(); + gradients.layer_norm_shift = d_linear1.sum_axis(ndarray::Axis(0)).to_vec(); + + gradients } @@ -74,6 +67,8 @@ pub fn update_weights( gradients: &LearnableWeights, learning_rate: f32, ) { + + println!("EMBEDDING OLD :{:?}, EMBEDDING NEW: {:?}",model.embedding.shape(),gradients.embedding.shape()); // Ensure the gradients and model weights have compatible shapes (reshape if necessary) model.embedding = &model.embedding - &(&gradients.embedding * learning_rate); model.query_weights = &model.query_weights - &(&gradients.query_weights * learning_rate); From 7127cd957691076e6bc685aef106a6dfc4f41297 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 22 Dec 2024 22:40:57 +0100 Subject: [PATCH 10/13] Refactored and linted code, still having mismatches regarding the gradient computations --- src/data/generation.rs | 6 ++-- src/data/learnable.rs | 18 +++++------ src/data/tokenizer.rs | 7 +++-- src/model/embedding.rs | 35 +++++++++++---------- src/model/train_transformer.rs | 29 ++++++++++++------ src/settings.rs | 56 +++++++++++++++++----------------- src/training/loss_function.rs | 11 +++---- src/training/train.rs | 18 +++++------ 8 files changed, 93 insertions(+), 87 deletions(-) diff --git a/src/data/generation.rs b/src/data/generation.rs index 7448735..d8cf239 100644 --- a/src/data/generation.rs +++ b/src/data/generation.rs @@ -15,7 +15,6 @@ pub fn generate_input_target_pairs( // Prepare input (same as sentence) let input = tokenizer.pad_sequence(tokens.clone(), INPUT_SIZE); - // Prepare target (shifted version of the sentence) let mut target = tokens.clone(); if i + 1 < sentences.len() { @@ -25,7 +24,6 @@ pub fn generate_input_target_pairs( if !next_tokens.is_empty() { target.push(next_tokens[0]); // Add the first token of the next sentence } - } else { target.push(tokenizer.vocab[""]); // Use EOS token for the last sentence } @@ -98,7 +96,7 @@ pub fn generate_staircase_pairs( // Pad both input and target sequences to max_length let staircase_input = pad_sequence_to_length(&staircase_input, INPUT_SIZE); - let staircase_target = pad_sequence_to_length(&staircase_target, INPUT_SIZE); + let staircase_target = pad_sequence_to_length(&staircase_target, INPUT_SIZE); // Add this pair to the staircase pairs vector staircase_pairs.push((staircase_input, staircase_target)); } @@ -117,4 +115,4 @@ fn pad_sequence_to_length(seq: &[usize], max_length: usize) -> Vec { } padded_seq -} \ No newline at end of file +} diff --git a/src/data/learnable.rs b/src/data/learnable.rs index 3397115..5fc6a84 100644 --- a/src/data/learnable.rs +++ b/src/data/learnable.rs @@ -1,5 +1,5 @@ -use ndarray::{Array1, Array2}; use crate::settings::*; +use ndarray::{Array1, Array2}; pub struct LearnableWeights { // Embedding Layer @@ -28,7 +28,7 @@ pub struct LearnableWeights { impl LearnableWeights { pub fn new( output_size: usize, - hidden_size: usize, + _hidden_size: usize, vocab_size: usize, embedding_dim: usize, attention_dim: usize, @@ -55,18 +55,18 @@ impl LearnableWeights { layer_norm_shift: vec![0.0; embedding_dim], // Initialize shift to 0 // Output Layer - output_projection_vocab: Array2::zeros((embedding_dim, vocab_size)), + output_projection_vocab: Array2::zeros((embedding_dim, output_size)), } } } pub fn initialize_weights() -> LearnableWeights { LearnableWeights::new( - D_MODEL, // output_size = D_MODEL - FFN_DIM, // hidden_size = FFN_DIM - D_MODEL, // vocab_size - D_MODEL, // embedding_dim = D_MODEL - D_K, // attention_dim - FFN_DIM, // ffn_dim + D_MODEL, // output_size = D_MODEL + FFN_DIM, // hidden_size = FFN_DIM + D_MODEL, // vocab_size + D_MODEL, // embedding_dim = D_MODEL + D_K, // attention_dim + FFN_DIM, // ffn_dim ) } diff --git a/src/data/tokenizer.rs b/src/data/tokenizer.rs index 1478a73..9a460ba 100644 --- a/src/data/tokenizer.rs +++ b/src/data/tokenizer.rs @@ -1,8 +1,8 @@ #![allow(warnings)] +use crate::settings::INPUT_SIZE; use regex::Regex; use std::collections::{HashMap, HashSet}; -use crate::settings::INPUT_SIZE; #[derive(Debug)] pub struct Tokenizer { @@ -65,7 +65,7 @@ impl Tokenizer { tokens.push(*token); } - tokens.push(self.vocab[&self.eos_token]);// Add EOS token at the end + tokens.push(self.vocab[&self.eos_token]); // Add EOS token at the end tokens } @@ -91,7 +91,8 @@ impl Tokenizer { } padded - }} + } +} #[cfg(test)] mod tests { diff --git a/src/model/embedding.rs b/src/model/embedding.rs index 0f9a176..c897207 100644 --- a/src/model/embedding.rs +++ b/src/model/embedding.rs @@ -2,9 +2,8 @@ use crate::math::positional_encoding::sinusoidal_pos_encoding; use ndarray::{s, Array, Array2, ArrayView, ArrayView2, Axis, Ix1}; -use std::cmp::Ordering; -//use rand::Rng; use rand::Rng; +use std::cmp::Ordering; use std::collections::HashMap; pub struct Embedding { @@ -75,20 +74,20 @@ impl Embedding { let similarities: Vec = self .weights .axis_iter(Axis(0)) - .map(|embedding| embedding.dot(&decoded) / (norm(embedding) * norm(embedding))) + .map(|embedding| embedding.dot(&decoded) / (norm(embedding) * norm(decoded))) + .filter(|&similarity| !similarity.is_nan()) // Filter out NaNs .collect(); // Find the index of the maximum similarity - let best_match = similarities + if let Some((best_match, _)) = similarities .iter() .enumerate() - .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) - .unwrap() - .0; - - // Map index to the corresponding token - if let Some(token) = index_to_token.get(&best_match) { - predicted_tokens.push(token.clone()); + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + { + // Map index to the corresponding token + if let Some(token) = index_to_token.get(&best_match) { + predicted_tokens.push(token.clone()); + } } } @@ -100,9 +99,7 @@ pub fn norm(vector: ArrayView) -> f32 { vector.mapv(|x| x * x).sum().sqrt() } -pub fn predict_index( probabilities: ArrayView2, - vocab: &HashMap, ) -> Vec{ - +pub fn predict_index(probabilities: ArrayView2, vocab: &HashMap) -> Vec { // Reverse the vocab to get a mapping from index to token let _index_to_token: HashMap = vocab.iter().map(|(k, &v)| (v, k.clone())).collect(); @@ -115,16 +112,17 @@ pub fn predict_index( probabilities: ArrayView2, let max_index = probs .iter() .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better NaN handling than asserting equal + .filter(|(_, &prob)| !prob.is_nan()) // Filter out NaNs + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) .unwrap() .0; - predicted_tokens.push(max_index.clone()); - + predicted_tokens.push(max_index); } predicted_tokens } + pub fn predict_tokens( probabilities: ArrayView2, vocab: &HashMap, @@ -141,7 +139,8 @@ pub fn predict_tokens( let max_index = probs .iter() .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better NaN handling than asserting equal + .filter(|(_, &prob)| !prob.is_nan()) // Filter out NaNs + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) .unwrap() .0; diff --git a/src/model/train_transformer.rs b/src/model/train_transformer.rs index edf2369..39c7f24 100644 --- a/src/model/train_transformer.rs +++ b/src/model/train_transformer.rs @@ -24,7 +24,7 @@ fn train_model( learning_rate: f32, // Learning rate ) -> Vec { let vocab_size = tokenizer.vocab.len(); // Vocabulary size - let mut outputs = Vec::new(); // To store outputs for progress tracking + let mut outputs = Vec::new(); // To store outputs for progress tracking // Loop over the number of epochs for epoch in 0..num_epochs { @@ -66,10 +66,10 @@ fn train_model( |(_, seq, embed)| logits[[seq, embed]], ); - let targets = Array2::from_shape_fn( - (target.len(), logits.shape()[1]), - |(seq, embed)| logits[[seq, embed]], - ); + let targets = + Array2::from_shape_fn((target.len(), logits.shape()[1]), |(seq, embed)| { + logits[[seq, embed]] + }); let predictions = logits.clone(); @@ -79,7 +79,7 @@ fn train_model( &inputs, &targets, &predictions, - vocab_size + vocab_size, ); // Update weights @@ -101,7 +101,11 @@ fn train_model( // End of epoch: Print average loss and track improvement let avg_loss = total_loss / num_batches as f32; - println!("Epoch {} completed with average loss: {:.4}", epoch + 1, avg_loss); + println!( + "Epoch {} completed with average loss: {:.4}", + epoch + 1, + avg_loss + ); // For debugging or tracking, we could save weights periodically here. } @@ -195,10 +199,15 @@ pub fn training_model( // Apply final linear transformation let output_projection = &learnable_weights.output_projection; // All ones weights println!("Decoded shape: {:?}", decoded.dim()); - println!("Flattened decoded shape: {:?}", flatten_3d_array(decoded.clone()).dim()); + println!( + "Flattened decoded shape: {:?}", + flatten_3d_array(decoded.clone()).dim() + ); println!("Output projection shape: {:?}", output_projection.dim()); - println!("Transposed output projection shape: {:?}", output_projection.t().dim()); - + println!( + "Transposed output projection shape: {:?}", + output_projection.t().dim() + ); let logits = flatten_3d_array(decoded).dot(&output_projection.to_owned()); // Linear layer diff --git a/src/settings.rs b/src/settings.rs index 8290d5b..5034f80 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,39 +1,39 @@ - #![allow(unused)] +#![allow(unused)] - // Numerical constants with downscaled real-application values - pub const EPSILON: f32 = 0.0001; +// Numerical constants with downscaled real-application values +pub const EPSILON: f32 = 0.0001; - // Embedding size - pub const D_MODEL: usize = 32; // Reduced to 32 for a smaller model size +// Embedding size +pub const D_MODEL: usize = 32; // Reduced to 32 for a smaller model size - // Attention mechanism dimensions - pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity) - pub const D_V: usize = 32; // Value dimension (same as D_K) - pub const NUM_HEADS: usize = 4; // Reduced the number of attention heads for smaller model +// Attention mechanism dimensions +pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity) +pub const D_V: usize = 32; // Value dimension (same as D_K) +pub const NUM_HEADS: usize = 4; // Reduced the number of attention heads for smaller model - // Sequence and batch size - pub const SEQ_LENGTH: usize = 64; // Reduced sequence length - pub const BATCH_SIZE: usize = 1; // Reduced batch size for smaller model training +// Sequence and batch size +pub const SEQ_LENGTH: usize = 64; // Reduced sequence length +pub const BATCH_SIZE: usize = 1; // Reduced batch size for smaller model training - // Embedding size and dimensions - pub const EMBEDDING_SIZE: usize = D_MODEL; // Matches D_MODEL for consistency +// Embedding size and dimensions +pub const EMBEDDING_SIZE: usize = D_MODEL; // Matches D_MODEL for consistency - // Input/Output sizes - pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs - pub const OUTPUT_SIZE: usize = D_MODEL; // Consistent with D_MODEL for output +// Input/Output sizes +pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs +pub const OUTPUT_SIZE: usize = D_MODEL; // Consistent with D_MODEL for output - // Number of layers - pub const NUM_LAYERS: usize = 4; // Reduced to 4 layers for a smaller architecture +// Number of layers +pub const NUM_LAYERS: usize = 4; // Reduced to 4 layers for a smaller architecture - // Feedforward network dimension (FFN_DIM) - pub const FFN_DIM: usize = 128; // Smaller FFN dimension +// Feedforward network dimension (FFN_DIM) +pub const FFN_DIM: usize = 128; // Smaller FFN dimension - // Hidden size (used for biases and other layer parameters) - pub const HIDDEN_SIZE: usize = 128; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM +// Hidden size (used for biases and other layer parameters) +pub const HIDDEN_SIZE: usize = 128; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM - // Dropout rate and learning rate - pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization - pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate +// Dropout rate and learning rate +pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization +pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate - // Positional encoding parameters - pub const MAX_SEQ_LENGTH: usize = 128; // Maximum sequence length for positional encoding +// Positional encoding parameters +pub const MAX_SEQ_LENGTH: usize = 128; // Maximum sequence length for positional encoding diff --git a/src/training/loss_function.rs b/src/training/loss_function.rs index 6259d78..f43d9fd 100644 --- a/src/training/loss_function.rs +++ b/src/training/loss_function.rs @@ -1,11 +1,7 @@ use crate::attention::softmax::softmax_vector; use ndarray::{s, Array1, Array2}; -pub fn cross_entropy_loss( - logits: &Array2, - targets: &Array1, - vocab_size: usize, -) -> f32 { +pub fn cross_entropy_loss(logits: &Array2, targets: &Array1, vocab_size: usize) -> f32 { let mut loss = 0.0; // Ensure that the number of targets matches the batch size @@ -15,7 +11,10 @@ pub fn cross_entropy_loss( for (i, &target) in targets.iter().enumerate() { // Ensure target index is within valid range if target >= vocab_size { - panic!("Target index {} is out of bounds for vocab_size {}", target, vocab_size); + panic!( + "Target index {} is out of bounds for vocab_size {}", + target, vocab_size + ); } // Get the logits for the current target (batch_size x vocab_size) diff --git a/src/training/train.rs b/src/training/train.rs index 568b061..1883f84 100644 --- a/src/training/train.rs +++ b/src/training/train.rs @@ -1,8 +1,7 @@ use crate::data::learnable::LearnableWeights; -use crate::settings::*; -use ndarray::{Array1, Array2, Array3, Axis}; -use crate::attention::softmax::softmax_matrix; use crate::math::linear_algebra::flatten_3d_array; +use crate::settings::*; +use ndarray::{Array2, Array3}; /// Compute gradients for the transformer model's learnable weights. pub fn compute_gradients( @@ -10,12 +9,12 @@ pub fn compute_gradients( inputs: &Array3, targets: &Array2, predictions: &Array2, - vocabsize : usize, + vocabsize: usize, ) -> LearnableWeights { let mut gradients = LearnableWeights::new( OUTPUT_SIZE, HIDDEN_SIZE, - vocabsize, // Ensure the vocab size is correct + vocabsize, // Ensure the vocab size is correct D_MODEL, D_K, FFN_DIM, @@ -60,15 +59,16 @@ pub fn compute_gradients( gradients } - - pub fn update_weights( model: &mut LearnableWeights, gradients: &LearnableWeights, learning_rate: f32, ) { - - println!("EMBEDDING OLD :{:?}, EMBEDDING NEW: {:?}",model.embedding.shape(),gradients.embedding.shape()); + println!( + "EMBEDDING OLD :{:?}, EMBEDDING NEW: {:?}", + model.embedding.shape(), + gradients.embedding.shape() + ); // Ensure the gradients and model weights have compatible shapes (reshape if necessary) model.embedding = &model.embedding - &(&gradients.embedding * learning_rate); model.query_weights = &model.query_weights - &(&gradients.query_weights * learning_rate); From 8c3348203e5c4accc0decdd06ed7d810a9f4af94 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 22 Dec 2024 22:51:30 +0100 Subject: [PATCH 11/13] implemented gradient computation and adjusted model parameters, having issues with the shapes of the linear weights 1 and 2 --- src/data/generation.rs | 17 ++++++++++++----- src/data/learnable.rs | 22 +++++++++++----------- src/model/train_transformer.rs | 9 ++------- src/settings.rs | 2 +- src/training/train.rs | 27 +++++++++++++++------------ 5 files changed, 41 insertions(+), 36 deletions(-) diff --git a/src/data/generation.rs b/src/data/generation.rs index d8cf239..8e622ba 100644 --- a/src/data/generation.rs +++ b/src/data/generation.rs @@ -107,11 +107,18 @@ fn pad_sequence_to_length(seq: &[usize], max_length: usize) -> Vec { let mut padded_seq = seq.to_vec(); // Pad with token if the sequence is shorter than max_length - if padded_seq.len() < max_length { - padded_seq.resize(max_length, 0); // 0 is the token index - } else if padded_seq.len() > max_length { - // Truncate if the sequence is too long - padded_seq.truncate(max_length); + match padded_seq.len().cmp(&max_length) { + std::cmp::Ordering::Less => { + // If the sequence is too short, pad it with tokens (0) + padded_seq.resize(max_length, 0); + } + std::cmp::Ordering::Greater => { + // If the sequence is too long, truncate it + padded_seq.truncate(max_length); + } + std::cmp::Ordering::Equal => { + // If the sequence is already the correct length, do nothing + } } padded_seq diff --git a/src/data/learnable.rs b/src/data/learnable.rs index 5fc6a84..ee3843d 100644 --- a/src/data/learnable.rs +++ b/src/data/learnable.rs @@ -39,23 +39,23 @@ impl LearnableWeights { embedding: Array2::ones((vocab_size, embedding_dim)), // Attention Mechanism - query_weights: Array2::ones((embedding_dim, attention_dim)), - key_weights: Array2::ones((embedding_dim, attention_dim)), - value_weights: Array2::ones((embedding_dim, attention_dim)), - output_projection: Array2::ones((attention_dim, embedding_dim)), + query_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) + key_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) + value_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) + output_projection: Array2::ones((attention_dim, embedding_dim)), // (attention_dim, embedding_dim) // Feedforward Network - linear1_weights: Array2::ones((embedding_dim, ffn_dim)), - linear2_weights: Array2::ones((ffn_dim, embedding_dim)), - bias1: Array1::zeros(ffn_dim), - bias2: Array1::zeros(embedding_dim), + linear1_weights: Array2::ones((embedding_dim, ffn_dim)), // (embedding_dim, ffn_dim) + linear2_weights: Array2::ones((ffn_dim, embedding_dim)), // (ffn_dim, embedding_dim) + bias1: Array1::zeros(ffn_dim), // (ffn_dim) + bias2: Array1::zeros(embedding_dim), // (embedding_dim) // Layer Normalization - layer_norm_scale: vec![1.0; embedding_dim], // Initialize scale to 1 - layer_norm_shift: vec![0.0; embedding_dim], // Initialize shift to 0 + layer_norm_scale: vec![1.0; embedding_dim], // Initialize scale to 1 (embedding_dim,) + layer_norm_shift: vec![0.0; embedding_dim], // Initialize shift to 0 (embedding_dim,) // Output Layer - output_projection_vocab: Array2::zeros((embedding_dim, output_size)), + output_projection_vocab: Array2::zeros((embedding_dim, output_size)), // (embedding_dim, vocab_size) } } } diff --git a/src/model/train_transformer.rs b/src/model/train_transformer.rs index 39c7f24..2e4f917 100644 --- a/src/model/train_transformer.rs +++ b/src/model/train_transformer.rs @@ -74,13 +74,8 @@ fn train_model( let predictions = logits.clone(); // Backward pass: Compute gradients - let gradients = compute_gradients( - &mut learnable_weights, - &inputs, - &targets, - &predictions, - vocab_size, - ); + let gradients = + compute_gradients(&mut learnable_weights, &inputs, &targets, &predictions); // Update weights update_weights(&mut learnable_weights, &gradients, learning_rate); diff --git a/src/settings.rs b/src/settings.rs index 5034f80..dbe59bc 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -4,7 +4,7 @@ pub const EPSILON: f32 = 0.0001; // Embedding size -pub const D_MODEL: usize = 32; // Reduced to 32 for a smaller model size +pub const D_MODEL: usize = 88; // Attention mechanism dimensions pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity) diff --git a/src/training/train.rs b/src/training/train.rs index 1883f84..6b79852 100644 --- a/src/training/train.rs +++ b/src/training/train.rs @@ -9,15 +9,14 @@ pub fn compute_gradients( inputs: &Array3, targets: &Array2, predictions: &Array2, - vocabsize: usize, ) -> LearnableWeights { let mut gradients = LearnableWeights::new( - OUTPUT_SIZE, - HIDDEN_SIZE, - vocabsize, // Ensure the vocab size is correct - D_MODEL, - D_K, - FFN_DIM, + D_MODEL, // output_size = D_MODEL + FFN_DIM, // hidden_size = FFN_DIM + D_MODEL, // vocab_size + D_MODEL, // embedding_dim = D_MODEL + D_K, // attention_dim + FFN_DIM, // ffn_dim ); // Compute the loss and its derivative @@ -64,11 +63,6 @@ pub fn update_weights( gradients: &LearnableWeights, learning_rate: f32, ) { - println!( - "EMBEDDING OLD :{:?}, EMBEDDING NEW: {:?}", - model.embedding.shape(), - gradients.embedding.shape() - ); // Ensure the gradients and model weights have compatible shapes (reshape if necessary) model.embedding = &model.embedding - &(&gradients.embedding * learning_rate); model.query_weights = &model.query_weights - &(&gradients.query_weights * learning_rate); @@ -76,6 +70,15 @@ pub fn update_weights( model.value_weights = &model.value_weights - &(&gradients.value_weights * learning_rate); model.output_projection = &model.output_projection - &(&gradients.output_projection * learning_rate); + + println!( + "lin1 OLD :{:?}, lin1 NEW: {:?}", + model.linear1_weights.shape(), + gradients.linear1_weights.shape() + ); + + // TODO DEBUG SHAPE issues (old 88,128 and new 88,88 which seems wrong) + model.linear1_weights = &model.linear1_weights - &(&gradients.linear1_weights * learning_rate); model.linear2_weights = &model.linear2_weights - &(&gradients.linear2_weights * learning_rate); From dcb664ced28a1fa6929d8952ff324eeae6e966c3 Mon Sep 17 00:00:00 2001 From: Jakub Date: Mon, 23 Dec 2024 17:19:21 +0100 Subject: [PATCH 12/13] adjusted hidden size to match dimensions for gradient computation --- src/data/learnable.rs | 11 ++++---- src/model/train_transformer.rs | 21 ++++++++++---- src/settings.rs | 4 +-- src/training/train.rs | 50 ++++++++++++++++++++++++---------- 4 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/data/learnable.rs b/src/data/learnable.rs index ee3843d..318c2fc 100644 --- a/src/data/learnable.rs +++ b/src/data/learnable.rs @@ -1,6 +1,5 @@ use crate::settings::*; use ndarray::{Array1, Array2}; - pub struct LearnableWeights { // Embedding Layer pub embedding: Array2, // (vocab_size, embedding_dim) @@ -36,11 +35,11 @@ impl LearnableWeights { ) -> Self { LearnableWeights { // Embedding Layer - embedding: Array2::ones((vocab_size, embedding_dim)), + embedding: Array2::ones((vocab_size, embedding_dim)), // (vocab_size, embedding_dim) // Attention Mechanism query_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) - key_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) + key_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) value_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) output_projection: Array2::ones((attention_dim, embedding_dim)), // (attention_dim, embedding_dim) @@ -64,9 +63,9 @@ pub fn initialize_weights() -> LearnableWeights { LearnableWeights::new( D_MODEL, // output_size = D_MODEL FFN_DIM, // hidden_size = FFN_DIM - D_MODEL, // vocab_size + D_MODEL, // vocab_size (assuming this is same as embedding_dim) D_MODEL, // embedding_dim = D_MODEL - D_K, // attention_dim - FFN_DIM, // ffn_dim + D_K, // attention_dim = D_K + FFN_DIM, // ffn_dim = FFN_DIM ) } diff --git a/src/model/train_transformer.rs b/src/model/train_transformer.rs index 2e4f917..ea6932e 100644 --- a/src/model/train_transformer.rs +++ b/src/model/train_transformer.rs @@ -101,8 +101,6 @@ fn train_model( epoch + 1, avg_loss ); - - // For debugging or tracking, we could save weights periodically here. } println!("\nTraining completed!"); @@ -159,6 +157,9 @@ pub fn training_model( |(_, seq, embed)| embeddings[[seq, embed]], ); + // Debugging: Print shape of input tensor + println!("Input tensor shape: {:?}", input_tensor.shape()); + // Initialize gamma and beta for layer normalization let gamma = Array2::ones((1, EMBEDDING_SIZE)); // Example gamma (scale parameter) let beta = Array2::zeros((1, EMBEDDING_SIZE)); // Example beta (shift parameter) @@ -178,6 +179,9 @@ pub fn training_model( ); } + // Debugging: Print shape after encoding + println!("Encoded shape: {:?}", encoded.shape()); + // Perform decoding with N stacked layers let mut decoded = input_tensor.clone(); for _ in 0..NUM_LAYERS { @@ -191,23 +195,28 @@ pub fn training_model( ); } + // Debugging: Print shape after decoding + println!("Decoded shape: {:?}", decoded.shape()); + // Apply final linear transformation let output_projection = &learnable_weights.output_projection; // All ones weights - println!("Decoded shape: {:?}", decoded.dim()); + println!("Decoded shape: {:?}", decoded.shape()); println!( "Flattened decoded shape: {:?}", - flatten_3d_array(decoded.clone()).dim() + flatten_3d_array(decoded.clone()).shape() ); - println!("Output projection shape: {:?}", output_projection.dim()); + println!("Output projection shape: {:?}", output_projection.shape()); println!( "Transposed output projection shape: {:?}", - output_projection.t().dim() + output_projection.t().shape() ); let logits = flatten_3d_array(decoded).dot(&output_projection.to_owned()); // Linear layer + println!("Logits shape: {:?}", logits.shape()); // Apply softmax to logits let probabilities = softmax_matrix(&logits); + println!("Softmax probabilities shape: {:?}", probabilities.shape()); // Convert probabilities back to text using the tokenizer let tokens = predict_index(probabilities.view(), &vocab); diff --git a/src/settings.rs b/src/settings.rs index dbe59bc..67fc606 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -4,7 +4,7 @@ pub const EPSILON: f32 = 0.0001; // Embedding size -pub const D_MODEL: usize = 88; +pub const D_MODEL: usize = 88; // Model embedding size, matching the vocab size // Attention mechanism dimensions pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity) @@ -29,7 +29,7 @@ pub const NUM_LAYERS: usize = 4; // Reduced to 4 layers for a smaller architectu pub const FFN_DIM: usize = 128; // Smaller FFN dimension // Hidden size (used for biases and other layer parameters) -pub const HIDDEN_SIZE: usize = 128; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM +pub const HIDDEN_SIZE: usize = 88; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM // Dropout rate and learning rate pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization diff --git a/src/training/train.rs b/src/training/train.rs index 6b79852..5d203b1 100644 --- a/src/training/train.rs +++ b/src/training/train.rs @@ -19,6 +19,9 @@ pub fn compute_gradients( FFN_DIM, // ffn_dim ); + // Ensure correct shapes before proceeding with the gradients computation + println!("input shape: {:?}", inputs.shape()); + // Compute the loss and its derivative let loss = predictions - targets; let d_loss = &loss * 2.0 / (BATCH_SIZE as f32); @@ -27,26 +30,42 @@ pub fn compute_gradients( gradients.output_projection_vocab = predictions.t().dot(&d_loss); // Flattened inputs for further computations - let flattened_inputs = flatten_3d_array(inputs.clone()); + let flattened_inputs = flatten_3d_array(inputs.clone()); // Flatten [1, 88, 88] -> [88, 88] + println!("Shape of flattened_inputs: {:?}", flattened_inputs.shape()); // Compute gradients for the feedforward network weights - let d_linear2 = d_loss.dot(&weights.linear2_weights.t()); - gradients.linear2_weights = flattened_inputs.t().dot(&d_linear2); - gradients.bias2 = d_linear2.sum_axis(ndarray::Axis(0)); + // d_linear2 corresponds to the gradient w.r.t. the second linear layer + let d_linear2 = d_loss.dot(&weights.linear2_weights.t()); // Shape: [88, 128] + gradients.linear2_weights = flattened_inputs.t().dot(&d_linear2); // Shape: [88, 128] + gradients.bias2 = d_linear2.sum_axis(ndarray::Axis(0)); // Sum across sequences to get bias gradient + println!("Shape of d_linear2: {:?}", d_linear2.shape()); + + // d_linear1 corresponds to the gradient w.r.t. the first linear layer + let d_linear1 = d_linear2.dot(&weights.linear1_weights.t()); // Shape: [88, 88] + println!("Shape of d_linear1: {:?}", d_linear1.shape()); + + gradients.linear1_weights = flattened_inputs.t().dot(&d_linear1); // Shape: [88, 128] (for linear1) + gradients.bias1 = d_linear1.sum_axis(ndarray::Axis(0)); // Sum across sequences to get bias gradient - let d_linear1 = d_linear2.dot(&weights.linear1_weights.t()); - gradients.linear1_weights = flattened_inputs.t().dot(&d_linear1); - gradients.bias1 = d_linear1.sum_axis(ndarray::Axis(0)); + println!("Shape of linear1_weights: {:?}", weights.linear1_weights.shape()); + println!( + "Dot product inputs: {:?} and {:?}", + flattened_inputs.shape(), + d_linear1.shape() + ); // Compute gradients for the attention mechanism weights - let d_attention_output = d_loss.dot(&weights.output_projection.t()); - gradients.output_projection = flattened_inputs.t().dot(&d_attention_output); - let d_value = d_attention_output.dot(&weights.value_weights.t()); - gradients.value_weights = flattened_inputs.t().dot(&d_value); - let d_key = d_attention_output.dot(&weights.key_weights.t()); - gradients.key_weights = flattened_inputs.t().dot(&d_key); - let d_query = d_attention_output.dot(&weights.query_weights.t()); - gradients.query_weights = flattened_inputs.t().dot(&d_query); + let d_attention_output = d_loss.dot(&weights.output_projection.t()); // Shape: [88, 88] + gradients.output_projection = flattened_inputs.t().dot(&d_attention_output); // Shape: [88, 88] + + let d_value = d_attention_output.dot(&weights.value_weights.t()); // Shape: [88, 88] + gradients.value_weights = flattened_inputs.t().dot(&d_value); // Shape: [88, 88] + + let d_key = d_attention_output.dot(&weights.key_weights.t()); // Shape: [88, 88] + gradients.key_weights = flattened_inputs.t().dot(&d_key); // Shape: [88, 88] + + let d_query = d_attention_output.dot(&weights.query_weights.t()); // Shape: [88, 88] + gradients.query_weights = flattened_inputs.t().dot(&d_query); // Shape: [88, 88] // Compute gradients for the embedding layer gradients.embedding = inputs.mean_axis(ndarray::Axis(0)).unwrap(); // Ensure shape consistency with model.embedding @@ -58,6 +77,7 @@ pub fn compute_gradients( gradients } + pub fn update_weights( model: &mut LearnableWeights, gradients: &LearnableWeights, From eb667c754abc04a97619e96ec0d88eedb76e719a Mon Sep 17 00:00:00 2001 From: Jakub Date: Mon, 23 Dec 2024 20:53:36 +0100 Subject: [PATCH 13/13] refactored working learning structure --- src/data/learnable.rs | 3 +- src/model/train_transformer.rs | 127 ++++++++++++++------------------- src/settings.rs | 2 +- src/training/train.rs | 22 ------ 4 files changed, 58 insertions(+), 96 deletions(-) diff --git a/src/data/learnable.rs b/src/data/learnable.rs index 318c2fc..c4d9016 100644 --- a/src/data/learnable.rs +++ b/src/data/learnable.rs @@ -1,5 +1,6 @@ use crate::settings::*; use ndarray::{Array1, Array2}; +#[derive(Debug)] pub struct LearnableWeights { // Embedding Layer pub embedding: Array2, // (vocab_size, embedding_dim) @@ -39,7 +40,7 @@ impl LearnableWeights { // Attention Mechanism query_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) - key_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) + key_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) value_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim) output_projection: Array2::ones((attention_dim, embedding_dim)), // (attention_dim, embedding_dim) diff --git a/src/model/train_transformer.rs b/src/model/train_transformer.rs index ea6932e..61dd9be 100644 --- a/src/model/train_transformer.rs +++ b/src/model/train_transformer.rs @@ -1,24 +1,22 @@ -#![allow(warnings)] -use crate::attention::softmax::{softmax_matrix, softmax_vec, softmax_vector}; +use crate::attention::softmax::softmax_matrix; use crate::data::dataset::{gen_data, Dataset}; -use crate::data::learnable::{initialize_weights, LearnableWeights}; +use crate::data::learnable::LearnableWeights; use crate::data::tokenizer::Tokenizer; use crate::layers::feedforward_layer::FeedForwardLayer; use crate::math::linear_algebra::flatten_3d_array; use crate::model::decoder::decoding; -use crate::model::embedding::{predict_index, predict_tokens, Embedding}; +use crate::model::embedding::{predict_index, Embedding}; use crate::model::encoder::encoding; use crate::settings::*; use crate::training::loss_function::cross_entropy_loss; use crate::training::train::{compute_gradients, update_weights}; use ndarray::{Array1, Array2, Array3}; use rand::prelude::SliceRandom; -use rand::Rng; use std::collections::HashMap; fn train_model( dataset: &Dataset, // The training data - tokenizer: Tokenizer, // Vocabulary + tokenizer: &Tokenizer, // Vocabulary mut learnable_weights: LearnableWeights, // Initial weights num_epochs: usize, // Number of training epochs learning_rate: f32, // Learning rate @@ -32,7 +30,7 @@ fn train_model( // Shuffle the dataset indices let mut data_indices: Vec = (0..dataset.inputs.len()).collect(); - data_indices.shuffle(&mut rand::thread_rng()); + data_indices.shuffle(&mut rand::rng()); let mut total_loss = 0.0; // Accumulate loss for this epoch let mut num_batches = 0; @@ -43,7 +41,7 @@ fn train_model( let target = &dataset.targets[idx]; // Convert to Array1 for processing - let input_seq = Array1::from(input.clone()); + let _input_seq = Array1::from(input.clone()); let target_seq = Array1::from(target.clone()); // Forward pass: Model prediction @@ -60,7 +58,16 @@ fn train_model( total_loss += loss; // Accumulate loss for averaging num_batches += 1; - // Prepare inputs, targets, and predictions for gradient computation + // Log loss and progress every 10 steps + if step % 100 == 0 { + let decoded_output = tokenizer.detokenize(out.to_vec()); + let expected_output = tokenizer.detokenize(target.to_vec()); + println!( + "Step {}: Loss = {:.4}, Output = {:?}, Expected = {:?}", + step, loss, decoded_output, expected_output + ); + outputs.push(decoded_output); + } let inputs = Array3::from_shape_fn( (BATCH_SIZE, input.len(), EMBEDDING_SIZE), |(_, seq, embed)| logits[[seq, embed]], @@ -73,25 +80,23 @@ fn train_model( let predictions = logits.clone(); - // Backward pass: Compute gradients + // Compute gradients let gradients = compute_gradients(&mut learnable_weights, &inputs, &targets, &predictions); // Update weights update_weights(&mut learnable_weights, &gradients, learning_rate); - // Periodically log training progress - if step % 10 == 0 { - let decoded_output = tokenizer.detokenize(out.to_vec()); - println!( - "Step {}: Loss = {:.4}, Output = {:?}, Expected = {:?}", - step, - loss, - decoded_output, - tokenizer.detokenize(target.to_vec()) - ); - outputs.push(decoded_output); - } + // Log gradients for debugging (optional) + println!("Step {}: Computed gradients = {:?}", step, gradients); + + // Update weights + + // Periodically log weight updates (optional) + println!( + "Step {}: Weights updated with learning rate = {:.6}", + step, learning_rate + ); } // End of epoch: Print average loss and track improvement @@ -126,7 +131,7 @@ pub fn train() { // Train the model let outputs = train_model( &dataset, - tokenizer, + &tokenizer, learnable_weights, num_epochs, learning_rate, @@ -139,17 +144,17 @@ pub fn train() { } pub fn training_model( - tokens: &Vec, - target_seq: Array1, - learnable_weights: &mut LearnableWeights, - vocab_size: usize, - vocab: HashMap, + tokens: &[usize], // Input tokens + _target_seq: Array1, // Target sequence + learnable_weights: &mut LearnableWeights, // Learnable weights + vocab_size: usize, // Vocabulary size + vocab: HashMap, // Vocabulary map ) -> (Vec, Array2) { // Initialize Tokenizer and Embedding layer - let embedding = Embedding::new(vocab_size, EMBEDDING_SIZE); // Initialize embedding layer + let embedding = Embedding::new(vocab_size, EMBEDDING_SIZE); // Embed the input sentence - let embeddings = embedding.forward(tokens.clone()); + let embeddings = embedding.forward(tokens.to_vec()); // Convert embeddings to Array3 (batch_size, seq_length, embed_size) let input_tensor = Array3::from_shape_fn( @@ -157,69 +162,47 @@ pub fn training_model( |(_, seq, embed)| embeddings[[seq, embed]], ); - // Debugging: Print shape of input tensor - println!("Input tensor shape: {:?}", input_tensor.shape()); - // Initialize gamma and beta for layer normalization - let gamma = Array2::ones((1, EMBEDDING_SIZE)); // Example gamma (scale parameter) - let beta = Array2::zeros((1, EMBEDDING_SIZE)); // Example beta (shift parameter) + let gamma = Array2::ones((1, EMBEDDING_SIZE)); + let beta = Array2::zeros((1, EMBEDDING_SIZE)); // Initialize the feed-forward layer with correct types - let feed_forward_layer = FeedForwardLayer::new(&learnable_weights, DROPOUT_RATE); + let feed_forward_layer = FeedForwardLayer::new(learnable_weights, DROPOUT_RATE); - // Perform encoding with N stacked layers - let mut encoded = input_tensor.clone(); - for _ in 0..NUM_LAYERS { - encoded = encoding( - encoded, + // Perform encoding with stacked layers + let encoded = (0..NUM_LAYERS).fold(input_tensor.clone(), |acc, _| { + encoding( + acc, gamma.clone(), beta.clone(), EPSILON, &feed_forward_layer, - ); - } - - // Debugging: Print shape after encoding - println!("Encoded shape: {:?}", encoded.shape()); + ) + }); - // Perform decoding with N stacked layers - let mut decoded = input_tensor.clone(); - for _ in 0..NUM_LAYERS { - decoded = decoding( - decoded, + // Perform decoding with stacked layers + let decoded = (0..NUM_LAYERS).fold(input_tensor.clone(), |acc, _| { + decoding( + acc, encoded.clone(), gamma.clone(), beta.clone(), EPSILON, &feed_forward_layer, - ); - } - - // Debugging: Print shape after decoding - println!("Decoded shape: {:?}", decoded.shape()); + ) + }); // Apply final linear transformation - let output_projection = &learnable_weights.output_projection; // All ones weights - println!("Decoded shape: {:?}", decoded.shape()); - println!( - "Flattened decoded shape: {:?}", - flatten_3d_array(decoded.clone()).shape() - ); - println!("Output projection shape: {:?}", output_projection.shape()); - println!( - "Transposed output projection shape: {:?}", - output_projection.t().shape() - ); - - let logits = flatten_3d_array(decoded).dot(&output_projection.to_owned()); // Linear layer - println!("Logits shape: {:?}", logits.shape()); + let logits = flatten_3d_array(decoded).dot(&learnable_weights.output_projection.to_owned()); // Apply softmax to logits let probabilities = softmax_matrix(&logits); - println!("Softmax probabilities shape: {:?}", probabilities.shape()); // Convert probabilities back to text using the tokenizer let tokens = predict_index(probabilities.view(), &vocab); - (tokens, logits.clone()) + // Optionally print logits for debugging + println!("Logits: {:?}", logits); + + (tokens, logits) } diff --git a/src/settings.rs b/src/settings.rs index 67fc606..6af6df3 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -4,7 +4,7 @@ pub const EPSILON: f32 = 0.0001; // Embedding size -pub const D_MODEL: usize = 88; // Model embedding size, matching the vocab size +pub const D_MODEL: usize = 88; // Model embedding size, matching the vocab size // Attention mechanism dimensions pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity) diff --git a/src/training/train.rs b/src/training/train.rs index 5d203b1..32fb4a4 100644 --- a/src/training/train.rs +++ b/src/training/train.rs @@ -19,9 +19,6 @@ pub fn compute_gradients( FFN_DIM, // ffn_dim ); - // Ensure correct shapes before proceeding with the gradients computation - println!("input shape: {:?}", inputs.shape()); - // Compute the loss and its derivative let loss = predictions - targets; let d_loss = &loss * 2.0 / (BATCH_SIZE as f32); @@ -31,29 +28,19 @@ pub fn compute_gradients( // Flattened inputs for further computations let flattened_inputs = flatten_3d_array(inputs.clone()); // Flatten [1, 88, 88] -> [88, 88] - println!("Shape of flattened_inputs: {:?}", flattened_inputs.shape()); // Compute gradients for the feedforward network weights // d_linear2 corresponds to the gradient w.r.t. the second linear layer let d_linear2 = d_loss.dot(&weights.linear2_weights.t()); // Shape: [88, 128] gradients.linear2_weights = flattened_inputs.t().dot(&d_linear2); // Shape: [88, 128] gradients.bias2 = d_linear2.sum_axis(ndarray::Axis(0)); // Sum across sequences to get bias gradient - println!("Shape of d_linear2: {:?}", d_linear2.shape()); // d_linear1 corresponds to the gradient w.r.t. the first linear layer let d_linear1 = d_linear2.dot(&weights.linear1_weights.t()); // Shape: [88, 88] - println!("Shape of d_linear1: {:?}", d_linear1.shape()); gradients.linear1_weights = flattened_inputs.t().dot(&d_linear1); // Shape: [88, 128] (for linear1) gradients.bias1 = d_linear1.sum_axis(ndarray::Axis(0)); // Sum across sequences to get bias gradient - println!("Shape of linear1_weights: {:?}", weights.linear1_weights.shape()); - println!( - "Dot product inputs: {:?} and {:?}", - flattened_inputs.shape(), - d_linear1.shape() - ); - // Compute gradients for the attention mechanism weights let d_attention_output = d_loss.dot(&weights.output_projection.t()); // Shape: [88, 88] gradients.output_projection = flattened_inputs.t().dot(&d_attention_output); // Shape: [88, 88] @@ -77,7 +64,6 @@ pub fn compute_gradients( gradients } - pub fn update_weights( model: &mut LearnableWeights, gradients: &LearnableWeights, @@ -91,14 +77,6 @@ pub fn update_weights( model.output_projection = &model.output_projection - &(&gradients.output_projection * learning_rate); - println!( - "lin1 OLD :{:?}, lin1 NEW: {:?}", - model.linear1_weights.shape(), - gradients.linear1_weights.shape() - ); - - // TODO DEBUG SHAPE issues (old 88,128 and new 88,88 which seems wrong) - model.linear1_weights = &model.linear1_weights - &(&gradients.linear1_weights * learning_rate); model.linear2_weights = &model.linear2_weights - &(&gradients.linear2_weights * learning_rate);