Merge pull request #7 from JakubSchwenkbeck/Transformer/Model
Implement Full Transformer Model in Rust
JakubSchwenkbeck authored Dec 19, 2024
2 parents cfcc03c + b368284 commit 23f9866
Showing 6 changed files with 108 additions and 31 deletions.
4 changes: 4 additions & 0 deletions src/attention/softmax.rs
@@ -7,6 +7,10 @@ pub fn softmax_vector(vec: ArrayView1<f32>) -> Array1<f32> {
let sum: f32 = exp_vec.sum();
exp_vec / sum
}
pub fn softmax_vec(vec: Vec<f32>) -> Array1<f32> {
let array = Array1::from(vec); // Convert Vec<f32> to Array1<f32>
softmax_vector(array.view())
}

pub fn softmax_matrix(mat: &Array2<f32>) -> Array2<f32> {
convert_to_array2(mat.map_axis(Axis(1), softmax_vector))
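
As a quick sanity check of the new softmax_vec helper, here is a minimal usage sketch (the input scores are made up; it assumes the function is called from outside the crate via Transformer::attention::softmax, mirroring the imports in main.rs):

use Transformer::attention::softmax::softmax_vec;

fn softmax_vec_demo() {
    // Softmax of three raw scores; larger scores map to larger probabilities.
    let probs = softmax_vec(vec![1.0_f32, 2.0, 3.0]);

    // The result is an Array1<f32> whose entries are positive and sum to ~1.0.
    assert!((probs.sum() - 1.0).abs() < 1e-5);
    assert!(probs[2] > probs[0]);
}
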
1 change: 1 addition & 0 deletions src/lib.rs
@@ -10,6 +10,7 @@ pub mod model {
pub mod decoder;
pub mod embedding;
pub mod encoder;
pub mod transformer_model;
}

// this lint makes a scene ....
7 changes: 6 additions & 1 deletion src/main.rs
@@ -7,6 +7,7 @@ use Transformer::math::linear_algebra::flatten_3d_array;
use Transformer::model::decoder::decoding;
use Transformer::model::embedding::Embedding;
use Transformer::model::encoder::encoding;
use Transformer::model::transformer_model::transformer_model;
use Transformer::settings::{BATCH_SIZE, DROPOUT_RATE, INPUT_SIZE, OUTPUT_SIZE};

fn main() {
@@ -32,7 +33,7 @@ fn main() {
let tokenizer = Tokenizer::new(vocab.clone());
let embedding = Embedding::new(6, 12); // Example vocab size and embedding size
// Input sentence
let sentence = "transformer transformer transformer";
let sentence = "hello world model";

// Tokenize and embed the input
let tokens = tokenizer.tokenize(sentence);
@@ -79,4 +80,8 @@ fn main() {
let tokens = embedding.retrieve_tokens(flatten_3d_array(decoded), &vocab);

println!("Tokens: {:?}", tokens);

let predicted_token = transformer_model(sentence, &vocab);

println!("Predicted Token: {:?}", predicted_token);
}
57 changes: 28 additions & 29 deletions src/model/embedding.rs
@@ -70,35 +70,6 @@ impl Embedding {
token_embeddings + pos_encodings
}

pub fn predict_tokens(
probabilities: ArrayView2<f32>,
vocab: &HashMap<String, usize>,
) -> Vec<String> {
// Reverse the vocab to get a mapping from index to token
let index_to_token: HashMap<usize, String> =
vocab.iter().map(|(k, &v)| (v, k.clone())).collect();

let mut predicted_tokens = Vec::new();

for probs in probabilities.axis_iter(Axis(0)) {
// Iterate over the rows (sequence tokens)
// Find the index of the maximum probability
let max_index = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;

// Map the index to the corresponding token
if let Some(token) = index_to_token.get(&max_index) {
predicted_tokens.push(token.clone());
}
}

predicted_tokens
}

pub fn retrieve_tokens(
&self,
decoded_embeddings: Array2<f32>,
@@ -143,3 +114,31 @@ impl Embedding {
pub fn norm(vector: ArrayView<f32, Ix1>) -> f32 {
vector.mapv(|x| x * x).sum().sqrt()
}
pub fn predict_tokens(
probabilities: ArrayView2<f32>,
vocab: &HashMap<String, usize>,
) -> Vec<String> {
// Reverse the vocab to get a mapping from index to token
let index_to_token: HashMap<usize, String> =
vocab.iter().map(|(k, &v)| (v, k.clone())).collect();

let mut predicted_tokens = Vec::new();

for probs in probabilities.axis_iter(Axis(0)) {
// Iterate over the rows (sequence tokens)
// Find the index of the maximum probability
let max_index = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;

// Map the index to the corresponding token
if let Some(token) = index_to_token.get(&max_index) {
predicted_tokens.push(token.clone());
}
}

predicted_tokens
}
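
To make the argmax behaviour of predict_tokens concrete, here is a small illustrative sketch (the three-word vocabulary and the probability rows are invented for the example): each row of the probability matrix is mapped to the token whose vocabulary index holds the highest probability.

use ndarray::array;
use std::collections::HashMap;
use Transformer::model::embedding::predict_tokens;

fn predict_tokens_demo() {
    // Toy vocabulary: token -> index.
    let vocab: HashMap<String, usize> = [("hello", 0), ("world", 1), ("model", 2)]
        .iter()
        .map(|(t, i)| (t.to_string(), *i))
        .collect();

    // One row per sequence position; row 0 peaks at index 2, row 1 at index 0.
    let probabilities = array![[0.1_f32, 0.2, 0.7], [0.8, 0.1, 0.1]];

    let predicted = predict_tokens(probabilities.view(), &vocab);
    assert_eq!(predicted, vec!["model".to_string(), "hello".to_string()]);
}
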
68 changes: 68 additions & 0 deletions src/model/transformer_model.rs
@@ -0,0 +1,68 @@
#![allow(warnings)]
use crate::attention::softmax::{softmax_matrix, softmax_vec, softmax_vector};
use crate::data::tokenizer::Tokenizer;
use crate::layers::feedforward_layer::FeedForwardLayer;
use crate::math::linear_algebra::flatten_3d_array;
use crate::model::decoder::decoding;
use crate::model::embedding::{predict_tokens, Embedding};
use crate::model::encoder::encoding;
use crate::settings::*;
use ndarray::{Array1, Array2, Array3};
use rand::Rng;
use std::collections::HashMap;

pub fn transformer_model(
sentence: &str, // Input sentence
vocab: &HashMap<String, usize>, // Vocabulary
) -> Vec<String> {
// Initialize Tokenizer and Embedding layer
let tokenizer = Tokenizer::new(vocab.clone());
let embedding = Embedding::new(vocab.len(), EMBEDDING_SIZE); // Initialize embedding layer

// Tokenize and embed the input sentence
let tokens = tokenizer.tokenize(sentence);
let embeddings = embedding.forward(tokens.clone());

// Convert embeddings to Array3 (batch_size, seq_length, embed_size)
let input_tensor = Array3::from_shape_fn(
(BATCH_SIZE, tokens.len(), EMBEDDING_SIZE),
|(_, seq, emb)| embeddings[[seq, emb]], // copy each token's full embedding into the (single) batch slot
);

// Initialize gamma and beta for layer normalization
let gamma = Array2::ones((1, EMBEDDING_SIZE)); // Example gamma (scale parameter)
let beta = Array2::zeros((1, EMBEDDING_SIZE)); // Example beta (shift parameter)

// Initialize the feed-forward layer with correct types
let feed_forward_layer =
FeedForwardLayer::new(BATCH_SIZE, INPUT_SIZE, OUTPUT_SIZE, DROPOUT_RATE);

// Perform encoding (transformer encoder)
let encoded = encoding(
input_tensor.clone(),
gamma.clone(),
beta.clone(),
EPSILON,
&feed_forward_layer,
);

// Perform decoding (transformer decoder)
let decoded = decoding(
input_tensor,
encoded.clone(),
gamma,
beta,
EPSILON,
&feed_forward_layer,
);

// Apply final linear transformation
let output_projection = Array2::ones((OUTPUT_SIZE, vocab.len())); // All ones weights
let logits = flatten_3d_array(decoded).dot(&output_projection); // Linear layer

// Apply softmax to logits
let probabilities = softmax_matrix(&logits);

// Convert probabilities back to text using the tokenizer
predict_tokens(probabilities.view(), vocab)
}
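
Finally, a minimal caller sketch for the new end-to-end entry point (the toy vocabulary is illustrative and assumed to be accepted by the Tokenizer; in the repository, main.rs builds the real vocab before calling it):

use std::collections::HashMap;
use Transformer::model::transformer_model::transformer_model;

fn transformer_model_demo() {
    // Illustrative vocabulary; token -> index.
    let vocab: HashMap<String, usize> = [("hello", 0), ("world", 1), ("model", 2)]
        .iter()
        .map(|(t, i)| (t.to_string(), *i))
        .collect();

    // Runs tokenize -> embed -> encode -> decode -> project -> softmax -> argmax.
    let predicted = transformer_model("hello world model", &vocab);
    println!("Predicted tokens: {:?}", predicted);
}
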
2 changes: 1 addition & 1 deletion src/settings.rs
@@ -7,7 +7,7 @@ pub const D_V: usize = 64; // Value dimension
pub const NUM_HEADS: usize = 8; // Number of attention heads
pub const SEQ_LENGTH: usize = 128; // Sequence length
pub const BATCH_SIZE: usize = 1; // Batch size
pub const EMBEDDING_SIZE: usize = 6;
pub const EMBEDDING_SIZE: usize = 12;
pub const INPUT_SIZE: usize = 12;
pub const OUTPUT_SIZE: usize = 12;

