diff --git a/src/data/generation.rs b/src/data/generation.rs
index 7448735..d8cf239 100644
--- a/src/data/generation.rs
+++ b/src/data/generation.rs
@@ -15,7 +15,6 @@ pub fn generate_input_target_pairs(
         // Prepare input (same as sentence)
         let input = tokenizer.pad_sequence(tokens.clone(), INPUT_SIZE);
-
         // Prepare target (shifted version of the sentence)
         let mut target = tokens.clone();
         if i + 1 < sentences.len() {
@@ -25,7 +24,6 @@ pub fn generate_input_target_pairs(
             if !next_tokens.is_empty() {
                 target.push(next_tokens[0]); // Add the first token of the next sentence
             }
-
         } else {
            target.push(tokenizer.vocab["<EOS>"]); // Use EOS token for the last sentence
        }
@@ -98,7 +96,7 @@ pub fn generate_staircase_pairs(
         // Pad both input and target sequences to max_length
         let staircase_input = pad_sequence_to_length(&staircase_input, INPUT_SIZE);
-        let staircase_target = pad_sequence_to_length(&staircase_target, INPUT_SIZE); 
+        let staircase_target = pad_sequence_to_length(&staircase_target, INPUT_SIZE);

         // Add this pair to the staircase pairs vector
         staircase_pairs.push((staircase_input, staircase_target));
     }
@@ -117,4 +115,4 @@ fn pad_sequence_to_length(seq: &[usize], max_length: usize) -> Vec<usize> {
     }

     padded_seq
-}
\ No newline at end of file
+}
diff --git a/src/data/learnable.rs b/src/data/learnable.rs
index 3397115..5fc6a84 100644
--- a/src/data/learnable.rs
+++ b/src/data/learnable.rs
@@ -1,5 +1,5 @@
-use ndarray::{Array1, Array2};
 use crate::settings::*;
+use ndarray::{Array1, Array2};

 pub struct LearnableWeights {
     // Embedding Layer
@@ -28,7 +28,7 @@ pub struct LearnableWeights {
 impl LearnableWeights {
     pub fn new(
         output_size: usize,
-        hidden_size: usize,
+        _hidden_size: usize,
         vocab_size: usize,
         embedding_dim: usize,
         attention_dim: usize,
@@ -55,18 +55,18 @@ impl LearnableWeights {
             layer_norm_shift: vec![0.0; embedding_dim], // Initialize shift to 0

             // Output Layer
-            output_projection_vocab: Array2::zeros((embedding_dim, vocab_size)),
+            output_projection_vocab: Array2::zeros((embedding_dim, output_size)),
         }
     }
 }

 pub fn initialize_weights() -> LearnableWeights {
     LearnableWeights::new(
-        D_MODEL,   // output_size = D_MODEL
-        FFN_DIM,   // hidden_size = FFN_DIM
-        D_MODEL,   // vocab_size
-        D_MODEL,   // embedding_dim = D_MODEL
-        D_K,       // attention_dim
-        FFN_DIM,   // ffn_dim
+        D_MODEL, // output_size = D_MODEL
+        FFN_DIM, // hidden_size = FFN_DIM
+        D_MODEL, // vocab_size
+        D_MODEL, // embedding_dim = D_MODEL
+        D_K,     // attention_dim
+        FFN_DIM, // ffn_dim
     )
 }
diff --git a/src/data/tokenizer.rs b/src/data/tokenizer.rs
index 1478a73..9a460ba 100644
--- a/src/data/tokenizer.rs
+++ b/src/data/tokenizer.rs
@@ -1,8 +1,8 @@
 #![allow(warnings)]

+use crate::settings::INPUT_SIZE;
 use regex::Regex;
 use std::collections::{HashMap, HashSet};
-use crate::settings::INPUT_SIZE;

 #[derive(Debug)]
 pub struct Tokenizer {
@@ -65,7 +65,7 @@ impl Tokenizer {
             tokens.push(*token);
         }

-        tokens.push(self.vocab[&self.eos_token]);// Add EOS token at the end
+        tokens.push(self.vocab[&self.eos_token]); // Add EOS token at the end

         tokens
     }
@@ -91,7 +91,8 @@ impl Tokenizer {
         }

         padded
-    }}
+    }
+}

 #[cfg(test)]
 mod tests {
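// A minimal sketch of the behavior pad_sequence_to_length (patched above) is
// expected to have, assuming a PAD id of 0 and truncation past max_length;
// the real PAD id comes from the tokenizer's vocab, so both are assumptions,
// not the crate's actual implementation.
fn pad_sequence_to_length(seq: &[usize], max_length: usize) -> Vec<usize> {
    let mut padded: Vec<usize> = seq.iter().copied().take(max_length).collect();
    padded.resize(max_length, 0); // fill the tail with the PAD id
    padded
}

fn main() {
    assert_eq!(pad_sequence_to_length(&[5, 7, 9], 5), vec![5, 7, 9, 0, 0]);
    assert_eq!(pad_sequence_to_length(&[1, 2, 3, 4, 5, 6], 5), vec![1, 2, 3, 4, 5]);
}
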
diff --git a/src/model/embedding.rs b/src/model/embedding.rs
index 0f9a176..c897207 100644
--- a/src/model/embedding.rs
+++ b/src/model/embedding.rs
@@ -2,9 +2,8 @@
 use crate::math::positional_encoding::sinusoidal_pos_encoding;
 use ndarray::{s, Array, Array2, ArrayView, ArrayView2, Axis, Ix1};
-use std::cmp::Ordering;
-//use rand::Rng;
 use rand::Rng;
+use std::cmp::Ordering;
 use std::collections::HashMap;

 pub struct Embedding {
@@ -75,20 +74,20 @@ impl Embedding {
         let similarities: Vec<f32> = self
             .weights
             .axis_iter(Axis(0))
-            .map(|embedding| embedding.dot(&decoded) / (norm(embedding) * norm(embedding)))
+            .map(|embedding| embedding.dot(&decoded) / (norm(embedding) * norm(decoded)))
+            .map(|sim| if sim.is_nan() { f32::NEG_INFINITY } else { sim }) // Demote NaNs; filtering them out would misalign indices with vocab rows
             .collect();

         // Find the index of the maximum similarity
-        let best_match = similarities
+        if let Some((best_match, _)) = similarities
             .iter()
             .enumerate()
-            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
-            .unwrap()
-            .0;
-
-        // Map index to the corresponding token
-        if let Some(token) = index_to_token.get(&best_match) {
-            predicted_tokens.push(token.clone());
+            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
+        {
+            // Map index to the corresponding token
+            if let Some(token) = index_to_token.get(&best_match) {
+                predicted_tokens.push(token.clone());
+            }
         }
     }

@@ -100,9 +99,7 @@ pub fn norm(vector: ArrayView<f32, Ix1>) -> f32 {
     vector.mapv(|x| x * x).sum().sqrt()
 }

-pub fn predict_index(
-    probabilities: ArrayView2<f32>,
-    vocab: &HashMap<String, usize>,
-) -> Vec<usize> {
-
+pub fn predict_index(probabilities: ArrayView2<f32>, vocab: &HashMap<String, usize>) -> Vec<usize> {
     // Reverse the vocab to get a mapping from index to token
     let _index_to_token: HashMap<usize, String> = vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
@@ -115,16 +112,17 @@ pub fn predict_index(
         let max_index = probs
             .iter()
             .enumerate()
-            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better NaN handling than asserting equal
+            .filter(|(_, &prob)| !prob.is_nan()) // Filter out NaNs
+            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
             .unwrap()
             .0;

-        predicted_tokens.push(max_index.clone());
-
+        predicted_tokens.push(max_index);
     }

     predicted_tokens
 }
+
 pub fn predict_tokens(
     probabilities: ArrayView2<f32>,
     vocab: &HashMap<String, usize>,
@@ -141,7 +139,8 @@ pub fn predict_tokens(
         let max_index = probs
             .iter()
             .enumerate()
-            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better NaN handling than asserting equal
+            .filter(|(_, &prob)| !prob.is_nan()) // Filter out NaNs
+            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
             .unwrap()
             .0;
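// Why the denominator fix in decode above matters: cosine similarity divides
// by the norms of both vectors, while the old code divided by norm(embedding)
// squared. A minimal standalone sketch with ndarray (names and values are
// illustrative, not the crate's API):
use ndarray::Array1;

fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    let norm_a = a.mapv(|x| x * x).sum().sqrt();
    let norm_b = b.mapv(|x| x * x).sum().sqrt();
    a.dot(b) / (norm_a * norm_b) // NaN when either norm is 0, hence the NaN guard in the patch
}

fn main() {
    let embedding = Array1::from(vec![1.0_f32, 0.0]);
    let decoded = Array1::from(vec![10.0_f32, 0.0]);
    // Parallel vectors should score 1.0; the old norm(embedding) * norm(embedding)
    // denominator would have returned 10.0 here, which is not a similarity at all.
    assert!((cosine_similarity(&embedding, &decoded) - 1.0).abs() < 1e-6);
}
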
diff --git a/src/model/train_transformer.rs b/src/model/train_transformer.rs
index edf2369..39c7f24 100644
--- a/src/model/train_transformer.rs
+++ b/src/model/train_transformer.rs
@@ -24,7 +24,7 @@ fn train_model(
     learning_rate: f32, // Learning rate
 ) -> Vec<String> {
     let vocab_size = tokenizer.vocab.len(); // Vocabulary size
-    let mut outputs = Vec::new();   // To store outputs for progress tracking
+    let mut outputs = Vec::new(); // To store outputs for progress tracking

     // Loop over the number of epochs
     for epoch in 0..num_epochs {
@@ -66,10 +66,10 @@ fn train_model(
                 |(_, seq, embed)| logits[[seq, embed]],
             );

-            let targets = Array2::from_shape_fn(
-                (target.len(), logits.shape()[1]),
-                |(seq, embed)| logits[[seq, embed]],
-            );
+            let targets =
+                Array2::from_shape_fn((target.len(), logits.shape()[1]), |(seq, embed)| {
+                    logits[[seq, embed]]
+                });

             let predictions = logits.clone();

@@ -79,7 +79,7 @@ fn train_model(
                 &inputs,
                 &targets,
                 &predictions,
-                vocab_size
+                vocab_size,
             );

             // Update weights
@@ -101,7 +101,11 @@ fn train_model(

         // End of epoch: Print average loss and track improvement
         let avg_loss = total_loss / num_batches as f32;
-        println!("Epoch {} completed with average loss: {:.4}", epoch + 1, avg_loss);
+        println!(
+            "Epoch {} completed with average loss: {:.4}",
+            epoch + 1,
+            avg_loss
+        );

         // For debugging or tracking, we could save weights periodically here.
     }
@@ -195,10 +199,15 @@ pub fn training_model(
     // Apply final linear transformation
     let output_projection = &learnable_weights.output_projection; // All ones weights
     println!("Decoded shape: {:?}", decoded.dim());
-    println!("Flattened decoded shape: {:?}", flatten_3d_array(decoded.clone()).dim());
+    println!(
+        "Flattened decoded shape: {:?}",
+        flatten_3d_array(decoded.clone()).dim()
+    );
     println!("Output projection shape: {:?}", output_projection.dim());
-    println!("Transposed output projection shape: {:?}", output_projection.t().dim());
-
+    println!(
+        "Transposed output projection shape: {:?}",
+        output_projection.t().dim()
+    );

     let logits = flatten_3d_array(decoded).dot(&output_projection.to_owned()); // Linear layer
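// Array2::from_shape_fn, reformatted above, fills a matrix by calling a closure
// with each (row, col) index. A tiny standalone sketch of that API plus the
// end-of-epoch bookkeeping from train_model (dimensions and losses are made up):
use ndarray::Array2;

fn main() {
    // 2 x 3 matrix whose entries encode their own coordinates
    let m: Array2<f32> = Array2::from_shape_fn((2, 3), |(row, col)| (row * 10 + col) as f32);
    assert_eq!(m[[1, 2]], 12.0);

    // Same average-loss summary as the training loop's epoch report
    let total_loss = 7.85_f32;
    let num_batches = 4;
    println!("Epoch {} completed with average loss: {:.4}", 1, total_loss / num_batches as f32);
}
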
diff --git a/src/settings.rs b/src/settings.rs
index 8290d5b..5034f80 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -1,39 +1,39 @@
-    #![allow(unused)]
+#![allow(unused)]

-    // Numerical constants with downscaled real-application values
-    pub const EPSILON: f32 = 0.0001;
+// Numerical constants with downscaled real-application values
+pub const EPSILON: f32 = 0.0001;

-    // Embedding size
-    pub const D_MODEL: usize = 32; // Reduced to 32 for a smaller model size
+// Embedding size
+pub const D_MODEL: usize = 32; // Reduced to 32 for a smaller model size

-    // Attention mechanism dimensions
-    pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity)
-    pub const D_V: usize = 32; // Value dimension (same as D_K)
-    pub const NUM_HEADS: usize = 4; // Reduced the number of attention heads for smaller model
+// Attention mechanism dimensions
+pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity)
+pub const D_V: usize = 32; // Value dimension (same as D_K)
+pub const NUM_HEADS: usize = 4; // Reduced the number of attention heads for smaller model

-    // Sequence and batch size
-    pub const SEQ_LENGTH: usize = 64; // Reduced sequence length
-    pub const BATCH_SIZE: usize = 1; // Reduced batch size for smaller model training
+// Sequence and batch size
+pub const SEQ_LENGTH: usize = 64; // Reduced sequence length
+pub const BATCH_SIZE: usize = 1; // Reduced batch size for smaller model training

-    // Embedding size and dimensions
-    pub const EMBEDDING_SIZE: usize = D_MODEL; // Matches D_MODEL for consistency
+// Embedding size and dimensions
+pub const EMBEDDING_SIZE: usize = D_MODEL; // Matches D_MODEL for consistency

-    // Input/Output sizes
-    pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs
-    pub const OUTPUT_SIZE: usize = D_MODEL; // Consistent with D_MODEL for output
+// Input/Output sizes
+pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs
+pub const OUTPUT_SIZE: usize = D_MODEL; // Consistent with D_MODEL for output

-    // Number of layers
-    pub const NUM_LAYERS: usize = 4; // Reduced to 4 layers for a smaller architecture
+// Number of layers
+pub const NUM_LAYERS: usize = 4; // Reduced to 4 layers for a smaller architecture

-    // Feedforward network dimension (FFN_DIM)
-    pub const FFN_DIM: usize = 128; // Smaller FFN dimension
+// Feedforward network dimension (FFN_DIM)
+pub const FFN_DIM: usize = 128; // Smaller FFN dimension

-    // Hidden size (used for biases and other layer parameters)
-    pub const HIDDEN_SIZE: usize = 128; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM
+// Hidden size (used for biases and other layer parameters)
+pub const HIDDEN_SIZE: usize = 128; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM

-    // Dropout rate and learning rate
-    pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization
-    pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate
+// Dropout rate and learning rate
+pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization
+pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate

-    // Positional encoding parameters
-    pub const MAX_SEQ_LENGTH: usize = 128; // Maximum sequence length for positional encoding
+// Positional encoding parameters
+pub const MAX_SEQ_LENGTH: usize = 128; // Maximum sequence length for positional encoding
diff --git a/src/training/loss_function.rs b/src/training/loss_function.rs
index 6259d78..f43d9fd 100644
--- a/src/training/loss_function.rs
+++ b/src/training/loss_function.rs
@@ -1,11 +1,7 @@
 use crate::attention::softmax::softmax_vector;
 use ndarray::{s, Array1, Array2};

-pub fn cross_entropy_loss(
-    logits: &Array2<f32>,
-    targets: &Array1<usize>,
-    vocab_size: usize,
-) -> f32 {
+pub fn cross_entropy_loss(logits: &Array2<f32>, targets: &Array1<usize>, vocab_size: usize) -> f32 {
     let mut loss = 0.0;

     // Ensure that the number of targets matches the batch size
@@ -15,7 +11,10 @@ pub fn cross_entropy_loss(
     for (i, &target) in targets.iter().enumerate() {
         // Ensure target index is within valid range
         if target >= vocab_size {
-            panic!("Target index {} is out of bounds for vocab_size {}", target, vocab_size);
+            panic!(
+                "Target index {} is out of bounds for vocab_size {}",
+                target, vocab_size
+            );
         }

         // Get the logits for the current target (batch_size x vocab_size)
diff --git a/src/training/train.rs b/src/training/train.rs
index 568b061..1883f84 100644
--- a/src/training/train.rs
+++ b/src/training/train.rs
@@ -1,8 +1,7 @@
 use crate::data::learnable::LearnableWeights;
-use crate::settings::*;
-use ndarray::{Array1, Array2, Array3, Axis};
-use crate::attention::softmax::softmax_matrix;
 use crate::math::linear_algebra::flatten_3d_array;
+use crate::settings::*;
+use ndarray::{Array2, Array3};

 /// Compute gradients for the transformer model's learnable weights.
 pub fn compute_gradients(
@@ -10,12 +9,12 @@ pub fn compute_gradients(
     inputs: &Array3<f32>,
     targets: &Array2<f32>,
     predictions: &Array2<f32>,
-    vocabsize : usize,
+    vocabsize: usize,
 ) -> LearnableWeights {
     let mut gradients = LearnableWeights::new(
         OUTPUT_SIZE,
         HIDDEN_SIZE,
-        vocabsize,  // Ensure the vocab size is correct
+        vocabsize, // Ensure the vocab size is correct
         D_MODEL,
         D_K,
         FFN_DIM,
@@ -60,15 +59,16 @@ pub fn compute_gradients(
     gradients
 }

-
-
 pub fn update_weights(
     model: &mut LearnableWeights,
     gradients: &LearnableWeights,
     learning_rate: f32,
 ) {
-
-    println!("EMBEDDING OLD :{:?}, EMBEDDING NEW: {:?}",model.embedding.shape(),gradients.embedding.shape());
+    println!(
+        "EMBEDDING OLD :{:?}, EMBEDDING NEW: {:?}",
+        model.embedding.shape(),
+        gradients.embedding.shape()
+    );
     // Ensure the gradients and model weights have compatible shapes (reshape if necessary)
     model.embedding = &model.embedding - &(&gradients.embedding * learning_rate);
     model.query_weights = &model.query_weights - &(&gradients.query_weights * learning_rate);
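// update_weights above applies plain SGD, w <- w - lr * grad, tensor by tensor.
// A self-contained sketch of the same rule on a single ndarray matrix; the
// helper name sgd_step and all shapes/values here are illustrative, not part
// of the crate:
use ndarray::{Array2, array};

fn sgd_step(weights: &mut Array2<f32>, gradients: &Array2<f32>, learning_rate: f32) {
    *weights = &*weights - &(gradients * learning_rate);
}

fn main() {
    let mut embedding: Array2<f32> = array![[0.5, -0.2], [0.1, 0.3]];
    let grads: Array2<f32> = array![[0.1, 0.1], [-0.2, 0.0]];
    sgd_step(&mut embedding, &grads, 1e-4); // LEARNING_RATE in settings.rs is 1e-4
    // 0.5 - 1e-4 * 0.1 = 0.49999
    assert!((embedding[[0, 0]] - 0.49999).abs() < 1e-6);
}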