Skip to content

Commit

Permalink
implemented N = 6 stack of encoder and decoders
Browse files Browse the repository at this point in the history
implemented N = 6 stack of encoder and decoders
  • Loading branch information
JakubSchwenkbeck authored Dec 21, 2024
2 parents 55fa876 + 6f11bff commit f74118c
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 21 deletions.
5 changes: 5 additions & 0 deletions src/layers/feedforward_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl FeedForwardLayer {
}

/// He initialization function.
/*
fn he_initialization(input_size: usize, output_size: usize) -> Array2<f32> {
let mut rng = rand::rng();
let scale = (2.0 / input_size as f32).sqrt();
Expand All @@ -151,6 +152,10 @@ fn he_initialization(input_size: usize, output_size: usize) -> Array2<f32> {
.collect();
Array2::from_shape_vec((input_size, output_size), values).unwrap()
}
*/
fn he_initialization(input_size: usize, output_size: usize) -> Array2<f32> {
Array2::ones((input_size, output_size))
}

/// Initializes bias vectors with zeros.
fn bias_initialization(size: usize) -> Array1<f32> {
Expand Down
6 changes: 3 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ fn main() {
let tokenizer = Tokenizer::new(vocab.clone());
let embedding = Embedding::new(6, 12); // Example vocab size and embedding size
// Input sentence
let sentence = "hello world model";
let sentence = "hello world rust";

// Tokenize and embed the input
let tokens = tokenizer.tokenize(sentence);
let embeddings = embedding.forward(tokens.clone());
println!("embeddings: {:?}", embeddings);

// Convert embeddings to Array3 (batch_size, seq_length, embed_size)
let input_tensor = Array3::from_shape_fn(
(BATCH_SIZE, tokens.len(), EMBEDDING_SIZE),
|(batch, seq, _)| embeddings[[seq, batch]],
|(_, seq, embed)| embeddings[[seq, embed]],
);

println!("INPUT : {}", input_tensor.clone());
Expand Down
42 changes: 24 additions & 18 deletions src/model/transformer_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fn transformer_model(
// Convert embeddings to Array3 (batch_size, seq_length, embed_size)
let input_tensor = Array3::from_shape_fn(
(BATCH_SIZE, tokens.len(), EMBEDDING_SIZE),
|(batch, seq, _)| embeddings[[seq, batch]],
|(_, seq, embed)| embeddings[[seq, embed]],
);

// Initialize gamma and beta for layer normalization
Expand All @@ -37,24 +37,30 @@ pub fn transformer_model(
let feed_forward_layer =
FeedForwardLayer::new(BATCH_SIZE, INPUT_SIZE, OUTPUT_SIZE, DROPOUT_RATE);

// Perform encoding (transformer encoder)
let encoded = encoding(
input_tensor.clone(),
gamma.clone(),
beta.clone(),
EPSILON,
&feed_forward_layer,
);
// Perform encoding with N stacked layers
let mut encoded = input_tensor.clone();
for _ in 0..NUM_LAYERS {
encoded = encoding(
encoded,
gamma.clone(),
beta.clone(),
EPSILON,
&feed_forward_layer,
);
}

// Perform decoding (transformer decoder)
let decoded = decoding(
input_tensor,
encoded.clone(),
gamma,
beta,
EPSILON,
&feed_forward_layer,
);
// Perform decoding with N stacked layers
let mut decoded = input_tensor.clone();
for _ in 0..NUM_LAYERS {
decoded = decoding(
decoded,
encoded.clone(),
gamma.clone(),
beta.clone(),
EPSILON,
&feed_forward_layer,
);
}

// Apply final linear transformation
let output_projection = Array2::ones((OUTPUT_SIZE, vocab.len())); // All ones weights
Expand Down
2 changes: 2 additions & 0 deletions src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub const EMBEDDING_SIZE: usize = 12;
pub const INPUT_SIZE: usize = 12;
pub const OUTPUT_SIZE: usize = 12;

pub const NUM_LAYERS: usize = 6;

pub const HIDDEN_SIZE: usize = 6;

pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization
Expand Down

0 comments on commit f74118c

Please sign in to comment.