Transformer Embedding Enhancements #6

Merged
merged 8 commits on Dec 19, 2024
10 changes: 6 additions & 4 deletions src/layers/normalization.rs
@@ -16,19 +16,21 @@ pub fn layer_norm(
beta: &Array2<f32>,
epsilon: f32,
) -> Array2<f32> {
// Step 1: Calculate mean and variance across the features (axis=1)
// Calculate mean and variance across the features (axis=1)
let mean = x.mean_axis(Axis(1)).unwrap();
let variance = x.var_axis(Axis(1), 0.0);
//println!("Mean: {:?}", mean);
// println!("Variance: {:?}", variance);

let expanded_mean = mean.insert_axis(Axis(1)); // Expands [6] to [6, 1]
let expanded_variance = variance.insert_axis(Axis(1)); // Expands [6] to [6, 1]
// println!("EXPMean: {:?}", expanded_mean);
//println!("EXPVariance: {:?}", expanded_variance);

// Add epsilon to expanded variance
let normalized = (x - &expanded_mean) / (expanded_variance + epsilon).mapv(f32::sqrt);

// Step 2: Normalize the input
//let normalized = (x - &reshaped_mean) / (reshaped_variance + epsilon).mapv(f32::sqrt);
// println!("Normalized {}", normalized);

// Step 3: Apply scaling (gamma) and shifting (beta)
normalized * gamma + beta
}
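
For reference, here is a minimal usage sketch of layer_norm (not part of the diff). It assumes the full signature is layer_norm(x: &Array2<f32>, gamma: &Array2<f32>, beta: &Array2<f32>, epsilon: f32) -> Array2<f32>, which is consistent with the visible parameters and body: each row of x is normalized to zero mean and unit variance, then scaled by gamma and shifted by beta.

use ndarray::{array, Array2};

fn layer_norm_demo() {
    // Hypothetical inputs: two rows of three features each.
    let x: Array2<f32> = array![[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]];
    let gamma: Array2<f32> = Array2::ones((1, 3)); // scale parameter
    let beta: Array2<f32> = Array2::zeros((1, 3)); // shift parameter
    let out = layer_norm(&x, &gamma, &beta, 1e-6);
    // With gamma = 1 and beta = 0, each row of `out` has roughly zero mean and unit variance.
    println!("{:?}", out);
}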
43 changes: 24 additions & 19 deletions src/main.rs
@@ -3,6 +3,7 @@ use std::collections::HashMap;
use Transformer::data::tokenizer::{example_tokens, Tokenizer};
use Transformer::example::example;
use Transformer::layers::feedforward_layer::FeedForwardLayer;
use Transformer::math::linear_algebra::flatten_3d_array;
use Transformer::model::decoder::decoding;
use Transformer::model::embedding::Embedding;
use Transformer::model::encoder::encoding;
@@ -17,35 +18,35 @@ fn main() {
example_tokens();

println!(" \n \n \n ENCODER/DECODER \n");
// Example vocabulary
let vocab = vec![
("hello".to_string(), 4),
("world".to_string(), 5),
("my".to_string(), 6),
("name".to_string(), 7),
("is".to_string(), 8),
]
.into_iter()
.collect::<HashMap<String, usize>>();

let vocab = HashMap::from([
("hello".to_string(), 0),
("world".to_string(), 1),
("rust".to_string(), 2),
("transformer".to_string(), 3),
("learning".to_string(), 4),
("model".to_string(), 5),
]);

// Initialize Tokenizer and Embedding layer
let tokenizer = Tokenizer::new(vocab);
let embedding = Embedding::new(10, 64); // Example vocab size and embedding size
// Input sentence
let sentence = "hello world";
let tokenizer = Tokenizer::new(vocab.clone());
let embedding = Embedding::new(6, 12); // Example vocab size and embedding size
// Input sentence
let sentence = "transformer transformer transformer";

// Tokenize and embed the input
let tokens = tokenizer.tokenize(sentence);
let embeddings = embedding.forward(tokens.clone());

// Convert embeddings to Array3 (batch_size, seq_length, embed_size)
let input_tensor = Array3::from_shape_fn((1, tokens.len(), 64), |(batch, seq, _)| {
let input_tensor = Array3::from_shape_fn((1, tokens.len(), 12), |(_, seq, embed)| {
embeddings[[seq, embed]] // copy each token's full embedding row into the batch tensor
});

println!("INPUT : {}", input_tensor.clone());
// Initialize gamma and beta for layer normalization
let gamma = Array2::ones((1, 64)); // Example gamma (scale parameter)
let beta = Array2::zeros((1, 64)); // Example beta (shift parameter)
let gamma = Array2::ones((1, 12)); // Example gamma (scale parameter)
let beta = Array2::zeros((1, 12)); // Example beta (shift parameter)

// Initialize the feed-forward layer with correct types

@@ -55,15 +56,15 @@ fn main() {
// Perform encoding (transformer layer)
let epsilon = 1e-6; // Small epsilon for numerical stability
let encoded = encoding(
input_tensor,
input_tensor.clone(),
gamma.clone(),
beta.clone(),
epsilon,
&feed_forward_layer,
);
// Perform decoding (transformer layer)
let decoded = decoding(
encoded.clone(),
input_tensor,
encoded.clone(),
gamma,
beta,
@@ -74,4 +75,8 @@
// Print the encoded and decoded output tensors
println!("Encoded: {:?}", encoded);
println!("Decoded: {:?}", decoded);

let tokens = embedding.retrieve_tokens(flatten_3d_array(decoded), &vocab);

println!("Tokens: {:?}", tokens);
}
8 changes: 8 additions & 0 deletions src/math/linear_algebra.rs
@@ -88,3 +88,11 @@ pub fn apply_projection(x: &Array3<f32>, w: &Array2<f32>) -> Array3<f32> {

result
}

pub fn flatten_3d_array(batch: Array3<f32>) -> Array2<f32> {
let (batch_size, seq_length, embed_size) = batch.dim();
batch
.to_shape((batch_size * seq_length, embed_size))
.unwrap()
.to_owned()
}
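
A shape-only sketch (not part of the diff) of how flatten_3d_array is meant to be used: a (batch, seq_len, embed_size) tensor collapses into (batch * seq_len, embed_size), so each row holds one token's embedding — the 2-D layout that Embedding::retrieve_tokens consumes in main.rs.

use ndarray::Array3;

fn flatten_demo() {
    // Hypothetical shapes matching the example in main.rs: batch 1, 3 tokens, 12 dims.
    let decoded = Array3::<f32>::zeros((1, 3, 12));
    let flat = flatten_3d_array(decoded);
    assert_eq!(flat.dim(), (3, 12)); // 1 * 3 rows, one per token
}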
123 changes: 116 additions & 7 deletions src/model/embedding.rs
@@ -1,6 +1,9 @@
#![allow(dead_code)]
use ndarray::{s, Array, Array2};
use rand::Rng;

use crate::math::positional_encoding::sinusoidal_pos_encoding;
use ndarray::{s, Array2, ArrayView, ArrayView2, Axis, Ix1};
//use rand::Rng;
use std::collections::HashMap;

pub struct Embedding {
vocab_size: usize,
@@ -11,18 +14,48 @@ pub struct Embedding {
impl Embedding {
pub fn new(vocab_size: usize, embed_size: usize) -> Self {
// Initialize with random values for simplicity
let mut rng = rand::rng();
let weights =
Array::from_shape_fn((vocab_size, embed_size), |_| rng.random_range(-1.0..1.0));
//let mut rng = rand::rng();
//let mut _weights =Array::from_shape_fn((vocab_size, embed_size), |_| rng.random_range(-1.0..1.0));
let weights = Array2::from_shape_vec(
(6, 12),
vec![
// Embedding for "hello"
50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
// Embedding for "world"
0.0, 0.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
// Embedding for "rust"
0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
// Embedding for "transformer"
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0,
// Embedding for "learning"
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 0.0, 0.0,
// Embedding for "model"
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0,
],
)
.unwrap();

Embedding {
vocab_size,
embed_size,
weights,
}
}
/// Compute the sinusoidal positional encodings for a given sequence length.
pub fn get_positional_encodings(&self, seq_len: usize) -> Array2<f32> {
let mut pos_encodings = Array2::zeros((seq_len, self.embed_size));

for pos in 0..seq_len {
for i in 0..self.embed_size {
pos_encodings[(pos, i)] = sinusoidal_pos_encoding(pos, i, self.embed_size);
}
}
pos_encodings
}

// Inside Embedding::forward
/// Forward pass through the embedding layer, adding positional encodings to token embeddings.
pub fn forward(&self, tokens: Vec<usize>) -> Array2<f32> {
let seq_len = tokens.len();
let mut token_embeddings: Vec<f32> = Vec::new();

// For each token, get the corresponding embedding and append it to the token_embeddings vector
@@ -31,6 +64,82 @@ impl Embedding {
}

// Create the Array2 from the flattened token_embeddings vector
Array2::from_shape_vec((tokens.len(), self.embed_size), token_embeddings).unwrap()
let token_embeddings =
Array2::from_shape_vec((seq_len, self.embed_size), token_embeddings).unwrap();
let pos_encodings = self.get_positional_encodings(seq_len);
token_embeddings + pos_encodings
}

pub fn predict_tokens(
probabilities: ArrayView2<f32>,
vocab: &HashMap<String, usize>,
) -> Vec<String> {
// Reverse the vocab to get a mapping from index to token
let index_to_token: HashMap<usize, String> =
vocab.iter().map(|(k, &v)| (v, k.clone())).collect();

let mut predicted_tokens = Vec::new();

for probs in probabilities.axis_iter(Axis(0)) {
// Iterate over the rows (sequence tokens)
// Find the index of the maximum probability
let max_index = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;

// Map the index to the corresponding token
if let Some(token) = index_to_token.get(&max_index) {
predicted_tokens.push(token.clone());
}
}

predicted_tokens
}

pub fn retrieve_tokens(
&self,
decoded_embeddings: Array2<f32>,
vocab: &HashMap<String, usize>, // Token to index mapping
) -> Vec<String> {
// Reverse the vocab to get a mapping from index to token
let index_to_token: HashMap<usize, String> =
vocab.iter().map(|(k, &v)| (v, k.clone())).collect();

let mut predicted_tokens = Vec::new();

for decoded in decoded_embeddings.axis_iter(Axis(0)) {
let decoded_norm = norm(decoded);
if decoded_norm < 1e-6 {
continue; // Skip null rows
}
// Compute cosine similarity with all embeddings in `weights`
let similarities: Vec<f32> = self
.weights
.axis_iter(Axis(0))
.map(|embedding| embedding.dot(&decoded) / (norm(embedding) * decoded_norm))
.collect();

// Find the index of the maximum similarity
let best_match = similarities
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap()
.0;

// Map index to the corresponding token
if let Some(token) = index_to_token.get(&best_match) {
predicted_tokens.push(token.clone());
}
}

predicted_tokens
}
}

pub fn norm(vector: ArrayView<f32, Ix1>) -> f32 {
vector.mapv(|x| x * x).sum().sqrt()
}
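
get_positional_encodings calls sinusoidal_pos_encoding, which is not shown in this diff. As a hedged sketch only, it presumably implements the standard sinusoidal formula (sin on even dimensions, cos on odd ones, with frequencies scaled by 10000^(2k/d_model)), matching the (pos, i, embed_size) call site above:

// Assumed behaviour of crate::math::positional_encoding::sinusoidal_pos_encoding;
// the real implementation may differ in details.
fn sinusoidal_pos_encoding_sketch(pos: usize, i: usize, d_model: usize) -> f32 {
    let k = (i / 2) as f32; // index of the sin/cos dimension pair
    let angle = pos as f32 / 10000_f32.powf(2.0 * k / d_model as f32);
    if i % 2 == 0 { angle.sin() } else { angle.cos() }
}

With the hand-crafted weights above (two 50.0 entries per vocabulary row), retrieve_tokens should map each sufficiently non-null decoded row to the vocabulary entry whose embedding direction is closest under cosine similarity.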
2 changes: 2 additions & 0 deletions src/model/encoder.rs
@@ -41,6 +41,7 @@ pub fn encoding(
dummy_learned_matrices.clone(), // W_O
);

//println!("Attention1 :{}", attention_output);
// Add & Normalize (Residual Connection + Layer Norm)
let attention_residual = attention_output.add(&input); // Residual connection
let reshaped_attention = attention_residual
@@ -59,6 +60,7 @@
// Feed-Forward Network
let feed_forward_output = feed_forward_layer.forward(attention_norm.clone());

//println!("feed_forward_output :{:?}", feed_forward_output);
// Add & Normalize (Residual Connection + Layer Norm)
let feed_forward_residual = feed_forward_output.add(&attention_norm); // Residual connection
let reshaped_ff_attention = feed_forward_residual
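
For orientation, a hedged sketch (not from the diff) of the Add & Norm step that encoding applies after both the attention and feed-forward sub-layers: a residual addition, a reshape of the (batch, seq, embed) tensor to 2-D, layer_norm, then a reshape back. The helper name, the exact reshape calls, and the layer_norm signature (as sketched earlier) are illustrative assumptions.

use ndarray::{Array2, Array3};

fn add_and_norm_sketch(
    sublayer_out: &Array3<f32>,
    input: &Array3<f32>,
    gamma: &Array2<f32>,
    beta: &Array2<f32>,
    epsilon: f32,
) -> Array3<f32> {
    let (b, s, e) = input.dim();
    let residual = sublayer_out + input; // residual connection
    let flat = residual.to_shape((b * s, e)).unwrap().to_owned();
    let normed = layer_norm(&flat, gamma, beta, epsilon); // per-row layer norm
    normed.to_shape((b, s, e)).unwrap().to_owned()
}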
7 changes: 4 additions & 3 deletions src/settings.rs
@@ -1,14 +1,15 @@
// Numerical constants with down-scaled real-application values
#![allow(unused)]
pub const EPSILON: f32 = 0.0001;
pub const D_MODEL: usize = 512; // Embedding size
pub const D_K: usize = 64; // Key/query dimension
pub const D_V: usize = 64; // Value dimension
pub const NUM_HEADS: usize = 8; // Number of attention heads
pub const SEQ_LENGTH: usize = 128; // Sequence length
pub const BATCH_SIZE: usize = 1; // Batch size

pub const INPUT_SIZE: usize = 64;
pub const OUTPUT_SIZE: usize = 64;
pub const EMBEDDING_SIZE: usize = 6;
pub const INPUT_SIZE: usize = 12;
pub const OUTPUT_SIZE: usize = 12;

pub const HIDDEN_SIZE: usize = 6;
