diff --git a/src/layers/feedforward_layer.rs b/src/layers/feedforward_layer.rs
index a97efd8..b58afd4 100644
--- a/src/layers/feedforward_layer.rs
+++ b/src/layers/feedforward_layer.rs
@@ -143,6 +143,7 @@ impl FeedForwardLayer {
 }
 
 /// He initialization function.
+/*
 fn he_initialization(input_size: usize, output_size: usize) -> Array2<f32> {
     let mut rng = rand::rng();
     let scale = (2.0 / input_size as f32).sqrt();
@@ -151,6 +152,10 @@ fn he_initialization(input_size: usize, output_size: usize) -> Array2<f32> {
         .collect();
     Array2::from_shape_vec((input_size, output_size), values).unwrap()
 }
+*/
+fn he_initialization(input_size: usize, output_size: usize) -> Array2<f32> {
+    Array2::ones((input_size, output_size))
+}
 
 /// Initializes bias vectors with zeros.
 fn bias_initialization(size: usize) -> Array1<f32> {
diff --git a/src/main.rs b/src/main.rs
index 09603bd..543586f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -33,16 +33,16 @@ fn main() {
     let tokenizer = Tokenizer::new(vocab.clone());
     let embedding = Embedding::new(6, 12); // Example vocab size and embedding size
     // Input sentence
-    let sentence = "hello world model";
+    let sentence = "hello world rust";
 
     // Tokenize and embed the input
     let tokens = tokenizer.tokenize(sentence);
     let embeddings = embedding.forward(tokens.clone());
+    println!("embeddings: {:?}", embeddings);
 
-    // Convert embeddings to Array3 (batch_size, seq_length, embed_size)
     let input_tensor = Array3::from_shape_fn(
         (BATCH_SIZE, tokens.len(), EMBEDDING_SIZE),
-        |(batch, seq, _)| embeddings[[seq, batch]],
+        |(_, seq, embed)| embeddings[[seq, embed]],
     );
 
     println!("INPUT : {}", input_tensor.clone());
diff --git a/src/model/transformer_model.rs b/src/model/transformer_model.rs
index 94042d4..5f9f224 100644
--- a/src/model/transformer_model.rs
+++ b/src/model/transformer_model.rs
@@ -26,7 +26,7 @@ pub fn transformer_model(
     // Convert embeddings to Array3 (batch_size, seq_length, embed_size)
     let input_tensor = Array3::from_shape_fn(
         (BATCH_SIZE, tokens.len(), EMBEDDING_SIZE),
-        |(batch, seq, _)| embeddings[[seq, batch]],
+        |(_, seq, embed)| embeddings[[seq, embed]],
     );
 
     // Initialize gamma and beta for layer normalization
@@ -37,24 +37,30 @@
     let feed_forward_layer =
         FeedForwardLayer::new(BATCH_SIZE, INPUT_SIZE, OUTPUT_SIZE, DROPOUT_RATE);
 
-    // Perform encoding (transformer encoder)
-    let encoded = encoding(
-        input_tensor.clone(),
-        gamma.clone(),
-        beta.clone(),
-        EPSILON,
-        &feed_forward_layer,
-    );
+    // Perform encoding with N stacked layers
+    let mut encoded = input_tensor.clone();
+    for _ in 0..NUM_LAYERS {
+        encoded = encoding(
+            encoded,
+            gamma.clone(),
+            beta.clone(),
+            EPSILON,
+            &feed_forward_layer,
+        );
+    }
 
-    // Perform decoding (transformer decoder)
-    let decoded = decoding(
-        input_tensor,
-        encoded.clone(),
-        gamma,
-        beta,
-        EPSILON,
-        &feed_forward_layer,
-    );
+    // Perform decoding with N stacked layers
+    let mut decoded = input_tensor.clone();
+    for _ in 0..NUM_LAYERS {
+        decoded = decoding(
+            decoded,
+            encoded.clone(),
+            gamma.clone(),
+            beta.clone(),
+            EPSILON,
+            &feed_forward_layer,
+        );
+    }
 
     // Apply final linear transformation
     let output_projection = Array2::ones((OUTPUT_SIZE, vocab.len())); // All ones weights
diff --git a/src/settings.rs b/src/settings.rs
index 0d5d3dd..fdfe69e 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -11,6 +11,8 @@
 pub const EMBEDDING_SIZE: usize = 12;
 pub const INPUT_SIZE: usize = 12;
 pub const OUTPUT_SIZE: usize = 12;
 
+pub const NUM_LAYERS: usize = 6;
+
 pub const HIDDEN_SIZE: usize = 6;
 pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization
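
A note on two behaviors in this patch. First, stubbing `he_initialization` to `Array2::ones` makes every run deterministic, which reads as a temporary debugging aid (the new `println!("embeddings: ...")` points the same way); all-ones weights give gradient descent no symmetry to break, so the commented-out random initialization should come back before any training run. Second, the new stacking loops reuse one `FeedForwardLayer` and a single `(gamma, beta)` pair across all `NUM_LAYERS` iterations, so the six "layers" share weights. Below is a minimal sketch of the unshared variant for the encoder side, using only identifiers already present in `transformer_model.rs`; it is a hypothetical restructuring, not part of this diff:

    // Hypothetical: construct one FeedForwardLayer per stacked layer so each
    // encoder layer owns its own weights, rather than sharing one instance.
    let encoder_layers: Vec<FeedForwardLayer> = (0..NUM_LAYERS)
        .map(|_| FeedForwardLayer::new(BATCH_SIZE, INPUT_SIZE, OUTPUT_SIZE, DROPOUT_RATE))
        .collect();

    let mut encoded = input_tensor.clone();
    for layer in &encoder_layers {
        // Same call shape as in the patch; only the layer argument varies per iteration.
        encoded = encoding(encoded, gamma.clone(), beta.clone(), EPSILON, layer);
    }

The decoder loop would mirror this with its own `Vec<FeedForwardLayer>`.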