Training Pipeline Implementation

JakubSchwenkbeck authored Dec 23, 2024
2 parents 31d3369 + eb667c7 commit 2b0e81d
Showing 14 changed files with 540 additions and 62 deletions.
43 changes: 39 additions & 4 deletions src/data/dataset.rs
@@ -1,5 +1,40 @@
#[allow(dead_code)]
struct Dataset {
inputs: Vec<Vec<usize>>, // Each input is a sequence of token IDs
targets: Vec<Vec<usize>>, // Each target is the corresponding output sequence
use crate::data::generation::{generate_input_target_pairs, generate_staircase_pairs};
use crate::data::io::get_input;
use crate::data::tokenizer::Tokenizer;

pub struct Dataset {
pub(crate) inputs: Vec<Vec<usize>>, // Each input is a sequence of token IDs
pub(crate) targets: Vec<Vec<usize>>, // Each target is the corresponding output sequence
}

pub fn gen_data() -> (Tokenizer, Dataset) {
let raw_text = get_input();

let tokenizer = Tokenizer::new(raw_text.clone());

// Generate input-target pairs
let pairs = generate_input_target_pairs(&tokenizer, raw_text);

let mut all_inputs = Vec::new();
let mut all_targets = Vec::new();

// For each input-target pair, generate staircase pairs and add to the dataset
for (input, target) in pairs {
let staircase_pairs = generate_staircase_pairs(&input, &target);

// Add the staircase pairs to the dataset
for (staircase_input, staircase_target) in staircase_pairs {
all_inputs.push(staircase_input);
all_targets.push(staircase_target);
}
}

// Return tokenizer and the generated dataset
(
tokenizer,
Dataset {
inputs: all_inputs,
targets: all_targets,
},
)
}
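For context, a minimal sketch of how gen_data's output could be consumed from inside the crate; the function below is illustrative only and not part of this diff.

use crate::data::dataset::gen_data;

// Illustration, not part of this commit: inputs and targets are parallel
// lists of padded token-ID sequences of equal length.
pub fn inspect_dataset() {
    let (_tokenizer, dataset) = gen_data();
    assert_eq!(dataset.inputs.len(), dataset.targets.len());
    for (input, target) in dataset.inputs.iter().zip(&dataset.targets).take(3) {
        println!("input:  {:?}", input);
        println!("target: {:?}", target);
    }
}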
35 changes: 31 additions & 4 deletions src/data/generation.rs
@@ -1,7 +1,8 @@
use crate::data::io::get_input;
use crate::data::tokenizer::Tokenizer;
use crate::settings::INPUT_SIZE;

fn generate_input_target_pairs(
pub fn generate_input_target_pairs(
tokenizer: &Tokenizer,
sentences: Vec<String>,
) -> Vec<(Vec<usize>, Vec<usize>)> {
@@ -12,7 +13,7 @@ fn generate_input_target_pairs(
let tokens = tokenizer.tokenize(sentence);

// Prepare input (same as sentence)
let input = tokens.clone();
let input = tokenizer.pad_sequence(tokens.clone(), INPUT_SIZE);

// Prepare target (shifted version of the sentence)
let mut target = tokens.clone();
@@ -31,7 +32,7 @@ fn generate_input_target_pairs(
if !target.is_empty() {
target.remove(0);
}

let target = tokenizer.pad_sequence(target, INPUT_SIZE);
// Add the input-target pair to the result
pairs.push((input, target));
}
@@ -81,7 +82,10 @@ pub fn example_gen() {
}
}
}
fn generate_staircase_pairs(input: &[usize], target: &[usize]) -> Vec<(Vec<usize>, Vec<usize>)> {
pub fn generate_staircase_pairs(
input: &[usize],
target: &[usize],
) -> Vec<(Vec<usize>, Vec<usize>)> {
let mut staircase_pairs = Vec::new();

// The number of steps will be the length of the target sequence
@@ -90,9 +94,32 @@ fn generate_staircase_pairs(input: &[usize], target: &[usize]) -> Vec<(Vec<usize
let staircase_input = input.iter().take(i).cloned().collect::<Vec<usize>>();
let staircase_target = target.iter().take(i).cloned().collect::<Vec<usize>>();

// Pad both input and target sequences to max_length
let staircase_input = pad_sequence_to_length(&staircase_input, INPUT_SIZE);
let staircase_target = pad_sequence_to_length(&staircase_target, INPUT_SIZE);
// Add this pair to the staircase pairs vector
staircase_pairs.push((staircase_input, staircase_target));
}

staircase_pairs
}
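For intuition, the staircase expansion behaves roughly as follows; this is an illustration only, since the loop bounds are outside the shown hunk and are assumed here to be 1..=target.len(), with INPUT_SIZE taken as 5.

// Illustration, not part of this commit. Given
//   input  = [7, 8, 9]    target = [8, 9, 3]
// and INPUT_SIZE = 5 (0 = <PAD>), the staircase pairs would be:
//   ([7, 0, 0, 0, 0], [8, 0, 0, 0, 0])
//   ([7, 8, 0, 0, 0], [8, 9, 0, 0, 0])
//   ([7, 8, 9, 0, 0], [8, 9, 3, 0, 0])
// Each training example exposes one more token of the prefix and asks the
// model for the correspondingly longer prefix of the shifted target.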
fn pad_sequence_to_length(seq: &[usize], max_length: usize) -> Vec<usize> {
let mut padded_seq = seq.to_vec();

// Pad with <PAD> token if the sequence is shorter than max_length
match padded_seq.len().cmp(&max_length) {
std::cmp::Ordering::Less => {
// If the sequence is too short, pad it with <PAD> tokens (0)
padded_seq.resize(max_length, 0);
}
std::cmp::Ordering::Greater => {
// If the sequence is too long, truncate it
padded_seq.truncate(max_length);
}
std::cmp::Ordering::Equal => {
// If the sequence is already the correct length, do nothing
}
}

padded_seq
}
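A quick behavioural check of the helper, plus one caveat; illustration only, not part of this diff.

// Illustration, not part of this commit.
assert_eq!(pad_sequence_to_length(&[4, 2], 4), vec![4, 2, 0, 0]);          // pad
assert_eq!(pad_sequence_to_length(&[4, 2, 7, 1, 9], 4), vec![4, 2, 7, 1]); // truncate
// Note: this helper hard-codes 0 as the padding id, while Tokenizer::pad_sequence
// below looks it up via self.vocab[&self.pad_token]; the two only agree while
// <PAD> maps to id 0 in the vocabulary.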
45 changes: 23 additions & 22 deletions src/data/learnable.rs
@@ -1,6 +1,6 @@
use crate::settings::*;
use ndarray::{Array1, Array2};

#[derive(Debug)]
pub struct LearnableWeights {
// Embedding Layer
pub embedding: Array2<f32>, // (vocab_size, embedding_dim)
@@ -14,8 +14,8 @@ pub struct LearnableWeights {
// Feedforward Network
pub linear1_weights: Array2<f32>, // (embedding_dim, ffn_dim)
pub linear2_weights: Array2<f32>, // (ffn_dim, embedding_dim)
pub bias1: Array1<f32>,
pub bias2: Array1<f32>,
pub bias1: Array1<f32>, // (ffn_dim)
pub bias2: Array1<f32>, // (embedding_dim)

// Layer Normalization
pub layer_norm_scale: Vec<f32>, // (embedding_dim,)
@@ -28,44 +28,45 @@ pub struct LearnableWeights {
impl LearnableWeights {
pub fn new(
output_size: usize,
hidden_size: usize,
_hidden_size: usize,
vocab_size: usize,
embedding_dim: usize,
attention_dim: usize,
ffn_dim: usize,
) -> Self {
LearnableWeights {
// Embedding Layer
embedding: Array2::ones((vocab_size, embedding_dim)),
embedding: Array2::ones((vocab_size, embedding_dim)), // (vocab_size, embedding_dim)

// Attention Mechanism
query_weights: Array2::ones((embedding_dim, attention_dim)),
key_weights: Array2::ones((embedding_dim, attention_dim)),
value_weights: Array2::ones((embedding_dim, attention_dim)),
output_projection: Array2::ones((attention_dim, embedding_dim)),
query_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim)
key_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim)
value_weights: Array2::ones((embedding_dim, attention_dim)), // (embedding_dim, attention_dim)
output_projection: Array2::ones((attention_dim, embedding_dim)), // (attention_dim, embedding_dim)

// Feedforward Network
linear1_weights: Array2::ones((embedding_dim, ffn_dim)),
linear2_weights: Array2::ones((ffn_dim, embedding_dim)),
bias1: Array1::zeros(hidden_size),
bias2: Array1::zeros(output_size),
linear1_weights: Array2::ones((embedding_dim, ffn_dim)), // (embedding_dim, ffn_dim)
linear2_weights: Array2::ones((ffn_dim, embedding_dim)), // (ffn_dim, embedding_dim)
bias1: Array1::zeros(ffn_dim), // (ffn_dim)
bias2: Array1::zeros(embedding_dim), // (embedding_dim)

// Layer Normalization
layer_norm_scale: vec![1.0; embedding_dim], // Initialize scale to 1
layer_norm_shift: vec![0.0; embedding_dim], // Initialize shift to 0
layer_norm_scale: vec![1.0; embedding_dim], // Initialize scale to 1 (embedding_dim,)
layer_norm_shift: vec![0.0; embedding_dim], // Initialize shift to 0 (embedding_dim,)

// Output Layer
output_projection_vocab: Array2::zeros((embedding_dim, vocab_size)),
output_projection_vocab: Array2::zeros((embedding_dim, output_size)), // (embedding_dim, output_size)
}
}
}

pub fn initialize_weights() -> LearnableWeights {
LearnableWeights::new(
OUTPUT_SIZE, // output_size
HIDDEN_SIZE, // hidden_size
D_MODEL, // vocab_size
D_MODEL, // embedding_dim
D_K, // attention_dim (for keys, queries)
D_V, // ffn_dim (could align with embedding_dim or specific)
D_MODEL, // output_size = D_MODEL
FFN_DIM, // hidden_size = FFN_DIM
D_MODEL, // vocab_size (assuming this is same as embedding_dim)
D_MODEL, // embedding_dim = D_MODEL
D_K, // attention_dim = D_K
FFN_DIM, // ffn_dim = FFN_DIM
)
}
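The bias fix matters because the feed-forward shapes now chain: (embedding_dim, ffn_dim) with bias1 of length ffn_dim, then (ffn_dim, embedding_dim) with bias2 of length embedding_dim. A shape check with made-up dimensions, illustration only:

use ndarray::{Array1, Array2};

// Illustration, not part of this commit; dimension values are hypothetical.
fn ffn_shape_check() {
    let (embedding_dim, ffn_dim, seq_len) = (8, 32, 4);
    let x = Array2::<f32>::ones((seq_len, embedding_dim));
    let w1 = Array2::<f32>::ones((embedding_dim, ffn_dim)); // linear1_weights
    let b1 = Array1::<f32>::zeros(ffn_dim);                 // bias1 (previously hidden_size)
    let w2 = Array2::<f32>::ones((ffn_dim, embedding_dim)); // linear2_weights
    let b2 = Array1::<f32>::zeros(embedding_dim);           // bias2 (previously output_size)
    let hidden = x.dot(&w1) + &b1;   // (seq_len, ffn_dim)
    let out = hidden.dot(&w2) + &b2; // (seq_len, embedding_dim)
    assert_eq!(out.shape(), &[seq_len, embedding_dim]);
}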
15 changes: 14 additions & 1 deletion src/data/tokenizer.rs
@@ -1,5 +1,6 @@
#![allow(warnings)]

use crate::settings::INPUT_SIZE;
use regex::Regex;
use std::collections::{HashMap, HashSet};

@@ -78,7 +79,19 @@ impl Tokenizer {
.join(" ")
}

// Helper function to split sentence into words using an improved regex
pub fn pad_sequence(&self, tokens: Vec<usize>, max_length: usize) -> Vec<usize> {
let mut padded = tokens.clone();

// Pad with the PAD token if the sequence is too short
if tokens.len() < max_length {
padded.resize(max_length, self.vocab[&self.pad_token]);
} else if tokens.len() > max_length {
// Truncate if the sequence is too long
padded.truncate(max_length);
}

padded
}
}

#[cfg(test)]
2 changes: 1 addition & 1 deletion src/layers/normalization.rs
@@ -17,7 +17,7 @@ use ndarray::{Array2, Axis};
#[requires(beta.shape().len() == 2 && beta.shape()[0] == 1, "Beta must be a 2-dimensional array with a single row")]
#[requires(epsilon > 0.0, "Epsilon must be positive and non-zero")]
#[ensures(ret.shape() == x.shape(), "The resulting array must have the same shape as the input array")]
#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting array must be finite")]
//#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting array must be finite")]
pub fn layer_norm(
x: &Array2<f32>,
gamma: &Array2<f32>,
5 changes: 5 additions & 0 deletions src/lib.rs
@@ -3,6 +3,10 @@
pub mod example;
pub mod settings;

pub mod training {
pub mod loss_function;
pub mod train;
}
pub mod data {
pub mod dataset;
pub mod generation;
@@ -14,6 +18,7 @@ pub mod model {
pub mod decoder;
pub mod embedding;
pub mod encoder;
pub mod train_transformer;
pub mod transformer_model;
}

4 changes: 3 additions & 1 deletion src/main.rs
@@ -1,7 +1,7 @@
use Transformer::data::generation::example_gen;
use Transformer::data::tokenizer::Tokenizer;
use Transformer::example::example;

use Transformer::model::train_transformer::train;
use Transformer::model::transformer_model::transformer_model;
fn main() {
println!("runs successfully!");
@@ -75,4 +75,6 @@ fn main() {
println!("Predicted Token: {:?}", predicted_token);

example_gen();

train()
}
10 changes: 5 additions & 5 deletions src/math/linear_algebra.rs
@@ -44,8 +44,8 @@ pub fn dotproduct(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
#[requires(b.shape().len() == 3, "Input tensor b must have 3 dimensions")]
#[requires(a.shape()[0] == b.shape()[0], "Batch sizes must match")]
#[requires(a.shape()[2] == b.shape()[1], "Inner dimensions must align for matrix multiplication")]
#[ensures(ret.shape().len() == 3, "The resulting tensor must have 3 dimensions.")]
#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")]
//#[ensures(ret.shape().len() == 3, "The resulting tensor must have 3 dimensions.")]
//#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")]
pub fn tensor_product(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> {
// Check that batch sizes match and that dimensions align
assert_eq!(a.shape()[0], b.shape()[0], "Batch sizes must match");
@@ -83,8 +83,8 @@ pub fn tensor_product(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> {
#[requires(x.shape().len() == 3, "Input tensor x must have 3 dimensions")]
#[requires(w.shape().len() == 2, "Weight matrix w must have 2 dimensions")]
#[requires(x.shape()[2] == w.shape()[0], "Input feature size must match the weight matrix's rows")]
#[ensures(ret.shape().len() == 3, "The resulting tensor must have 3 dimensions.")]
#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")]
//#[ensures(ret.shape().len() == 3, "The resulting tensor must have 3 dimensions.")]
//#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")]
pub fn apply_projection(x: &Array3<f32>, w: &Array2<f32>) -> Array3<f32> {
let batch_size = x.shape()[0];
let seq_len = x.shape()[1];
@@ -116,7 +116,7 @@ pub fn apply_projection(x: &Array3<f32>, w: &Array2<f32>) -> Array3<f32> {
/// A 2D tensor of shape (batch_size * seq_length, embed_size).
#[requires(batch.shape().len() == 3, "Input tensor must have 3 dimensions")]
#[ensures(ret.shape().len() == 2, "The resulting tensor must have 2 dimensions.")]
#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")]
//#[ensures(ret.iter().all(|&x| x.is_finite()), "All elements in the resulting tensor must be finite.")]
pub fn flatten_3d_array(batch: Array3<f32>) -> Array2<f32> {
let (batch_size, seq_length, embed_size) = batch.dim();
batch
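The remaining #[requires] clauses still encode the usual batched-matmul shape rule. A small reference sketch, illustration only, assuming the module path crate::math::linear_algebra:

use crate::math::linear_algebra::tensor_product;
use ndarray::Array3;

// Illustration, not part of this commit; dimensions are made up.
fn batched_matmul_shapes() {
    let a = Array3::<f32>::ones((2, 3, 4)); // (batch, m, k)
    let b = Array3::<f32>::ones((2, 4, 5)); // (batch, k, n)
    let c = tensor_product(&a, &b);         // (batch, m, n)
    assert_eq!(c.shape(), &[2, 3, 5]);
}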
48 changes: 37 additions & 11 deletions src/model/embedding.rs
@@ -2,8 +2,8 @@

use crate::math::positional_encoding::sinusoidal_pos_encoding;
use ndarray::{s, Array, Array2, ArrayView, ArrayView2, Axis, Ix1};
//use rand::Rng;
use rand::Rng;
use std::cmp::Ordering;
use std::collections::HashMap;

pub struct Embedding {
@@ -74,20 +74,20 @@ impl Embedding {
let similarities: Vec<f32> = self
.weights
.axis_iter(Axis(0))
.map(|embedding| embedding.dot(&decoded) / (norm(embedding) * norm(embedding)))
.map(|embedding| embedding.dot(&decoded) / (norm(embedding) * norm(decoded)))
.filter(|&similarity| !similarity.is_nan()) // Filter out NaNs
.collect();

// Find the index of the maximum similarity
let best_match = similarities
if let Some((best_match, _)) = similarities
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap()
.0;

// Map index to the corresponding token
if let Some(token) = index_to_token.get(&best_match) {
predicted_tokens.push(token.clone());
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
{
// Map index to the corresponding token
if let Some(token) = index_to_token.get(&best_match) {
predicted_tokens.push(token.clone());
}
}
}
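Why the one-line change in the similarity computation matters:

// Note on the change above: the old expression divided by
// norm(embedding) * norm(embedding), i.e. (e . d) / ||e||^2, which is the
// projection coefficient of the decoded vector onto each row, not a
// similarity. The new (e . d) / (||e|| * ||d||) is cosine similarity, so the
// argmax is no longer biased toward embedding rows with small norms.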

@@ -98,6 +98,31 @@
pub fn norm(vector: ArrayView<f32, Ix1>) -> f32 {
vector.mapv(|x| x * x).sum().sqrt()
}

pub fn predict_index(probabilities: ArrayView2<f32>, vocab: &HashMap<String, usize>) -> Vec<usize> {
// Reverse the vocab to get a mapping from index to token
let _index_to_token: HashMap<usize, String> =
vocab.iter().map(|(k, &v)| (v, k.clone())).collect();

let mut predicted_tokens = Vec::new();

for probs in probabilities.axis_iter(Axis(0)) {
// Iterate over the rows (sequence tokens)
// Find the index of the maximum probability
let max_index = probs
.iter()
.enumerate()
.filter(|(_, &prob)| !prob.is_nan()) // Filter out NaNs
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
.unwrap()
.0;

predicted_tokens.push(max_index);
}

predicted_tokens
}
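A usage sketch for the new predict_index helper; illustration only, with made-up probabilities and vocab, assuming it sits next to predict_index in src/model/embedding.rs:

use ndarray::array;
use std::collections::HashMap;

// Illustration, not part of this commit: predict_index is a row-wise argmax
// over a (seq_len, vocab_size) probability matrix.
fn greedy_pick() {
    let probs = array![[0.1_f32, 0.7, 0.2], [0.5, 0.4, 0.1]];
    let vocab: HashMap<String, usize> = [("a", 0usize), ("b", 1), ("c", 2)]
        .into_iter()
        .map(|(k, v)| (k.to_string(), v))
        .collect();
    // Row maxima sit at indices 1 and 0 respectively.
    assert_eq!(predict_index(probs.view(), &vocab), vec![1, 0]);
}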

pub fn predict_tokens(
probabilities: ArrayView2<f32>,
vocab: &HashMap<String, usize>,
@@ -114,7 +139,8 @@ pub fn predict_tokens(
let max_index = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.filter(|(_, &prob)| !prob.is_nan()) // Filter out NaNs
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
.unwrap()
.0;
