Skip to content

Commit

Permalink
Refactored and linted code, still having mismatches regarding the gradient computations
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubSchwenkbeck committed Dec 22, 2024
1 parent a1ab397 commit 7127cd9
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 87 deletions.
6 changes: 2 additions & 4 deletions src/data/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub fn generate_input_target_pairs(
// Prepare input (same as sentence)
let input = tokenizer.pad_sequence(tokens.clone(), INPUT_SIZE);


// Prepare target (shifted version of the sentence)
let mut target = tokens.clone();
if i + 1 < sentences.len() {
Expand All @@ -25,7 +24,6 @@ pub fn generate_input_target_pairs(
if !next_tokens.is_empty() {
target.push(next_tokens[0]); // Add the first token of the next sentence
}

} else {
target.push(tokenizer.vocab["<EOS>"]); // Use EOS token for the last sentence
}
Expand Down Expand Up @@ -98,7 +96,7 @@ pub fn generate_staircase_pairs(

// Pad both input and target sequences to max_length
let staircase_input = pad_sequence_to_length(&staircase_input, INPUT_SIZE);
let staircase_target = pad_sequence_to_length(&staircase_target, INPUT_SIZE);
let staircase_target = pad_sequence_to_length(&staircase_target, INPUT_SIZE);
// Add this pair to the staircase pairs vector
staircase_pairs.push((staircase_input, staircase_target));
}
Expand All @@ -117,4 +115,4 @@ fn pad_sequence_to_length(seq: &[usize], max_length: usize) -> Vec<usize> {
}

padded_seq
}
}
18 changes: 9 additions & 9 deletions src/data/learnable.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ndarray::{Array1, Array2};
use crate::settings::*;
use ndarray::{Array1, Array2};

pub struct LearnableWeights {
// Embedding Layer
Expand Down Expand Up @@ -28,7 +28,7 @@ pub struct LearnableWeights {
impl LearnableWeights {
pub fn new(
output_size: usize,
hidden_size: usize,
_hidden_size: usize,
vocab_size: usize,
embedding_dim: usize,
attention_dim: usize,
Expand All @@ -55,18 +55,18 @@ impl LearnableWeights {
layer_norm_shift: vec![0.0; embedding_dim], // Initialize shift to 0

// Output Layer
output_projection_vocab: Array2::zeros((embedding_dim, vocab_size)),
output_projection_vocab: Array2::zeros((embedding_dim, output_size)),
}
}
}

pub fn initialize_weights() -> LearnableWeights {
LearnableWeights::new(
D_MODEL, // output_size = D_MODEL
FFN_DIM, // hidden_size = FFN_DIM
D_MODEL, // vocab_size
D_MODEL, // embedding_dim = D_MODEL
D_K, // attention_dim
FFN_DIM, // ffn_dim
D_MODEL, // output_size = D_MODEL
FFN_DIM, // hidden_size = FFN_DIM
D_MODEL, // vocab_size
D_MODEL, // embedding_dim = D_MODEL
D_K, // attention_dim
FFN_DIM, // ffn_dim
)
}
7 changes: 4 additions & 3 deletions src/data/tokenizer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#![allow(warnings)]

use crate::settings::INPUT_SIZE;
use regex::Regex;
use std::collections::{HashMap, HashSet};
use crate::settings::INPUT_SIZE;

#[derive(Debug)]
pub struct Tokenizer {
Expand Down Expand Up @@ -65,7 +65,7 @@ impl Tokenizer {
tokens.push(*token);
}

tokens.push(self.vocab[&self.eos_token]);// Add EOS token at the end
tokens.push(self.vocab[&self.eos_token]); // Add EOS token at the end
tokens
}

Expand All @@ -91,7 +91,8 @@ impl Tokenizer {
}

padded
}}
}
}

#[cfg(test)]
mod tests {
Expand Down
35 changes: 17 additions & 18 deletions src/model/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

use crate::math::positional_encoding::sinusoidal_pos_encoding;
use ndarray::{s, Array, Array2, ArrayView, ArrayView2, Axis, Ix1};
use std::cmp::Ordering;
//use rand::Rng;
use rand::Rng;
use std::cmp::Ordering;
use std::collections::HashMap;

pub struct Embedding {
Expand Down Expand Up @@ -75,20 +74,20 @@ impl Embedding {
let similarities: Vec<f32> = self
.weights
.axis_iter(Axis(0))
.map(|embedding| embedding.dot(&decoded) / (norm(embedding) * norm(embedding)))
.map(|embedding| embedding.dot(&decoded) / (norm(embedding) * norm(decoded)))
.filter(|&similarity| !similarity.is_nan()) // Filter out NaNs
.collect();

// Find the index of the maximum similarity
let best_match = similarities
if let Some((best_match, _)) = similarities
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap()
.0;

// Map index to the corresponding token
if let Some(token) = index_to_token.get(&best_match) {
predicted_tokens.push(token.clone());
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
{
// Map index to the corresponding token
if let Some(token) = index_to_token.get(&best_match) {
predicted_tokens.push(token.clone());
}
}
}

Expand All @@ -100,9 +99,7 @@ pub fn norm(vector: ArrayView<f32, Ix1>) -> f32 {
vector.mapv(|x| x * x).sum().sqrt()
}

pub fn predict_index( probabilities: ArrayView2<f32>,
vocab: &HashMap<String, usize>, ) -> Vec<usize>{

pub fn predict_index(probabilities: ArrayView2<f32>, vocab: &HashMap<String, usize>) -> Vec<usize> {
// Reverse the vocab to get a mapping from index to token
let _index_to_token: HashMap<usize, String> =
vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
Expand All @@ -115,16 +112,17 @@ pub fn predict_index( probabilities: ArrayView2<f32>,
let max_index = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better NaN handling than asserting equal
.filter(|(_, &prob)| !prob.is_nan()) // Filter out NaNs
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
.unwrap()
.0;

predicted_tokens.push(max_index.clone());

predicted_tokens.push(max_index);
}

predicted_tokens
}

pub fn predict_tokens(
probabilities: ArrayView2<f32>,
vocab: &HashMap<String, usize>,
Expand All @@ -141,7 +139,8 @@ pub fn predict_tokens(
let max_index = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) // TODO : Need better NaN handling than asserting equal
.filter(|(_, &prob)| !prob.is_nan()) // Filter out NaNs
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
.unwrap()
.0;

Expand Down
29 changes: 19 additions & 10 deletions src/model/train_transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn train_model(
learning_rate: f32, // Learning rate
) -> Vec<String> {
let vocab_size = tokenizer.vocab.len(); // Vocabulary size
let mut outputs = Vec::new(); // To store outputs for progress tracking
let mut outputs = Vec::new(); // To store outputs for progress tracking

// Loop over the number of epochs
for epoch in 0..num_epochs {
Expand Down Expand Up @@ -66,10 +66,10 @@ fn train_model(
|(_, seq, embed)| logits[[seq, embed]],
);

let targets = Array2::from_shape_fn(
(target.len(), logits.shape()[1]),
|(seq, embed)| logits[[seq, embed]],
);
let targets =
Array2::from_shape_fn((target.len(), logits.shape()[1]), |(seq, embed)| {
logits[[seq, embed]]
});

let predictions = logits.clone();

Expand All @@ -79,7 +79,7 @@ fn train_model(
&inputs,
&targets,
&predictions,
vocab_size
vocab_size,
);

// Update weights
Expand All @@ -101,7 +101,11 @@ fn train_model(

// End of epoch: Print average loss and track improvement
let avg_loss = total_loss / num_batches as f32;
println!("Epoch {} completed with average loss: {:.4}", epoch + 1, avg_loss);
println!(
"Epoch {} completed with average loss: {:.4}",
epoch + 1,
avg_loss
);

// For debugging or tracking, we could save weights periodically here.
}
Expand Down Expand Up @@ -195,10 +199,15 @@ pub fn training_model(
// Apply final linear transformation
let output_projection = &learnable_weights.output_projection; // All ones weights
println!("Decoded shape: {:?}", decoded.dim());
println!("Flattened decoded shape: {:?}", flatten_3d_array(decoded.clone()).dim());
println!(
"Flattened decoded shape: {:?}",
flatten_3d_array(decoded.clone()).dim()
);
println!("Output projection shape: {:?}", output_projection.dim());
println!("Transposed output projection shape: {:?}", output_projection.t().dim());

println!(
"Transposed output projection shape: {:?}",
output_projection.t().dim()
);

let logits = flatten_3d_array(decoded).dot(&output_projection.to_owned()); // Linear layer

Expand Down
56 changes: 28 additions & 28 deletions src/settings.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
#![allow(unused)]
#![allow(unused)]

// Numerical constants with downscaled real-application values
pub const EPSILON: f32 = 0.0001;
// Numerical constants with downscaled real-application values
pub const EPSILON: f32 = 0.0001;

// Embedding size
pub const D_MODEL: usize = 32; // Reduced to 32 for a smaller model size
// Embedding size
pub const D_MODEL: usize = 32; // Reduced to 32 for a smaller model size

// Attention mechanism dimensions
pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity)
pub const D_V: usize = 32; // Value dimension (same as D_K)
pub const NUM_HEADS: usize = 4; // Reduced the number of attention heads for smaller model
// Attention mechanism dimensions
pub const D_K: usize = 32; // Key/query dimension (same as D_V for simplicity)
pub const D_V: usize = 32; // Value dimension (same as D_K)
pub const NUM_HEADS: usize = 4; // Reduced the number of attention heads for smaller model

// Sequence and batch size
pub const SEQ_LENGTH: usize = 64; // Reduced sequence length
pub const BATCH_SIZE: usize = 1; // Reduced batch size for smaller model training
// Sequence and batch size
pub const SEQ_LENGTH: usize = 64; // Reduced sequence length
pub const BATCH_SIZE: usize = 1; // Reduced batch size for smaller model training

// Embedding size and dimensions
pub const EMBEDDING_SIZE: usize = D_MODEL; // Matches D_MODEL for consistency
// Embedding size and dimensions
pub const EMBEDDING_SIZE: usize = D_MODEL; // Matches D_MODEL for consistency

// Input/Output sizes
pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs
pub const OUTPUT_SIZE: usize = D_MODEL; // Consistent with D_MODEL for output
// Input/Output sizes
pub const INPUT_SIZE: usize = D_MODEL; // Typically equals D_MODEL for transformer inputs
pub const OUTPUT_SIZE: usize = D_MODEL; // Consistent with D_MODEL for output

// Number of layers
pub const NUM_LAYERS: usize = 4; // Reduced to 4 layers for a smaller architecture
// Number of layers
pub const NUM_LAYERS: usize = 4; // Reduced to 4 layers for a smaller architecture

// Feedforward network dimension (FFN_DIM)
pub const FFN_DIM: usize = 128; // Smaller FFN dimension
// Feedforward network dimension (FFN_DIM)
pub const FFN_DIM: usize = 128; // Smaller FFN dimension

// Hidden size (used for biases and other layer parameters)
pub const HIDDEN_SIZE: usize = 128; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM
// Hidden size (used for biases and other layer parameters)
pub const HIDDEN_SIZE: usize = 128; // Adjusted for a smaller hidden layer size, consistent with FFN_DIM

// Dropout rate and learning rate
pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization
pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate
// Dropout rate and learning rate
pub const DROPOUT_RATE: f32 = 0.1; // Dropout rate for regularization
pub const LEARNING_RATE: f32 = 1e-4; // Optimizer learning rate

// Positional encoding parameters
pub const MAX_SEQ_LENGTH: usize = 128; // Maximum sequence length for positional encoding
// Positional encoding parameters
pub const MAX_SEQ_LENGTH: usize = 128; // Maximum sequence length for positional encoding
11 changes: 5 additions & 6 deletions src/training/loss_function.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
use crate::attention::softmax::softmax_vector;
use ndarray::{s, Array1, Array2};

pub fn cross_entropy_loss(
logits: &Array2<f32>,
targets: &Array1<usize>,
vocab_size: usize,
) -> f32 {
pub fn cross_entropy_loss(logits: &Array2<f32>, targets: &Array1<usize>, vocab_size: usize) -> f32 {
let mut loss = 0.0;

// Ensure that the number of targets matches the batch size
Expand All @@ -15,7 +11,10 @@ pub fn cross_entropy_loss(
for (i, &target) in targets.iter().enumerate() {
// Ensure target index is within valid range
if target >= vocab_size {
panic!("Target index {} is out of bounds for vocab_size {}", target, vocab_size);
panic!(
"Target index {} is out of bounds for vocab_size {}",
target, vocab_size
);
}

// Get the logits for the current target (batch_size x vocab_size)
Expand Down
18 changes: 9 additions & 9 deletions src/training/train.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
use crate::data::learnable::LearnableWeights;
use crate::settings::*;
use ndarray::{Array1, Array2, Array3, Axis};
use crate::attention::softmax::softmax_matrix;
use crate::math::linear_algebra::flatten_3d_array;
use crate::settings::*;
use ndarray::{Array2, Array3};

/// Compute gradients for the transformer model's learnable weights.
pub fn compute_gradients(
weights: &mut LearnableWeights,
inputs: &Array3<f32>,
targets: &Array2<f32>,
predictions: &Array2<f32>,
vocabsize : usize,
vocabsize: usize,
) -> LearnableWeights {
let mut gradients = LearnableWeights::new(
OUTPUT_SIZE,
HIDDEN_SIZE,
vocabsize, // Ensure the vocab size is correct
vocabsize, // Ensure the vocab size is correct
D_MODEL,
D_K,
FFN_DIM,
Expand Down Expand Up @@ -60,15 +59,16 @@ pub fn compute_gradients(
gradients
}



pub fn update_weights(
model: &mut LearnableWeights,
gradients: &LearnableWeights,
learning_rate: f32,
) {

println!("EMBEDDING OLD :{:?}, EMBEDDING NEW: {:?}",model.embedding.shape(),gradients.embedding.shape());
println!(
"EMBEDDING OLD :{:?}, EMBEDDING NEW: {:?}",
model.embedding.shape(),
gradients.embedding.shape()
);
// Ensure the gradients and model weights have compatible shapes (reshape if necessary)
model.embedding = &model.embedding - &(&gradients.embedding * learning_rate);
model.query_weights = &model.query_weights - &(&gradients.query_weights * learning_rate);
Expand Down

0 comments on commit 7127cd9

Please sign in to comment.