implemented fully functional positional (sinusoidal) embedding
JakubSchwenkbeck committed Dec 19, 2024
1 parent da46653 commit 86dc59a
Showing 1 changed file with 77 additions and 15 deletions.
92 changes: 77 additions & 15 deletions src/model/embedding.rs
@@ -1,8 +1,9 @@
#![allow(dead_code)]

use crate::math::positional_encoding::sinusoidal_pos_encoding;
use ndarray::{s, Array2, ArrayView, ArrayView2, Axis, Ix1};
//use rand::Rng;
use std::collections::HashMap;
use ndarray::{s, Array, Array2, ArrayView2, Axis};
use rand::Rng;

pub struct Embedding {
vocab_size: usize,
@@ -13,18 +14,48 @@ pub struct Embedding {
impl Embedding {
pub fn new(vocab_size: usize, embed_size: usize) -> Self {
// Initialize with random values for simplicity
let mut rng = rand::rng();
let weights =
Array::from_shape_fn((vocab_size, embed_size), |_| rng.random_range(-1.0..1.0));
//let mut rng = rand::rng();
//let mut _weights = Array::from_shape_fn((vocab_size, embed_size), |_| rng.random_range(-1.0..1.0));
let weights = Array2::from_shape_vec(
(6, 12),
vec![
// Embedding for "hello"
50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
// Embedding for "world"
0.0, 0.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
// Embedding for "rust"
0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
// Embedding for "transformer"
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0,
// Embedding for "learning"
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 0.0, 0.0,
// Embedding for "model"
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0,
],
)
.unwrap();

Embedding {
vocab_size,
embed_size,
weights,
}
}
/// Compute the sinusoidal positional encodings for a given sequence length.
pub fn get_positional_encodings(&self, seq_len: usize) -> Array2<f32> {
let mut pos_encodings = Array2::zeros((seq_len, self.embed_size));

for pos in 0..seq_len {
for i in 0..self.embed_size {
pos_encodings[(pos, i)] = sinusoidal_pos_encoding(pos, i, self.embed_size);
}
}
pos_encodings
}

// Inside Embedding::forward
/// Forward pass through the embedding layer, adding positional encodings to token embeddings.
pub fn forward(&self, tokens: Vec<usize>) -> Array2<f32> {
let seq_len = tokens.len();
let mut token_embeddings: Vec<f32> = Vec::new();

// For each token, get the corresponding embedding and append it to the token_embeddings vector
@@ -33,35 +64,62 @@ impl Embedding {
}

// Create the Array2 from the flattened token_embeddings vector
Array2::from_shape_vec((tokens.len(), self.embed_size), token_embeddings).unwrap()
let token_embeddings =
Array2::from_shape_vec((seq_len, self.embed_size), token_embeddings).unwrap();
let pos_encodings = self.get_positional_encodings(seq_len);
token_embeddings + pos_encodings
}

pub fn predict_tokens(probabilities: ArrayView2<f32>, vocab : &HashMap<String, usize> ) -> Array2<f32> {
pub fn predict_tokens(
probabilities: ArrayView2<f32>,
vocab: &HashMap<String, usize>,
) -> Vec<String> {
// Reverse the vocab to get a mapping from index to token
let index_to_token: HashMap<usize, String> =
vocab.iter().map(|(k, &v)| (v, k.clone())).collect();

let mut predicted_tokens = Vec::new();

for probs in probabilities.axis_iter(Axis(0)) { // Iterate over the rows (sequence tokens)
let max_index = probs.argmax().unwrap(); // Find the index of the maximum probability
predicted_tokens.push(vocab[max_index]); // Map the index to the vocabulary
for probs in probabilities.axis_iter(Axis(0)) {
// Iterate over the rows (sequence tokens)
// Find the index of the maximum probability
let max_index = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;

// Map the index to the corresponding token
if let Some(token) = index_to_token.get(&max_index) {
predicted_tokens.push(token.clone());
}
}

predicted_tokens
}

pub fn retrieve_tokens(
&self,
decoded_embeddings: Array2<f32>,
weights: &Array2<f32>, // Embedding matrix
vocab: &HashMap<String, usize>, // Token to index mapping
) -> Vec<String> {
// Reverse the vocab to get a mapping from index to token
let index_to_token: HashMap<usize, String> = vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
let index_to_token: HashMap<usize, String> =
vocab.iter().map(|(k, &v)| (v, k.clone())).collect();

let mut predicted_tokens = Vec::new();

for decoded in decoded_embeddings.axis_iter(Axis(0)) {
let decoded_norm = norm(decoded);
if decoded_norm < 1e-6 {
continue; // Skip null rows
}
// Compute cosine similarity with all embeddings in `weights`
let similarities: Vec<f32> = weights
let similarities: Vec<f32> = self
.weights
.axis_iter(Axis(0))
.map(|embedding| embedding.dot(&decoded) / (embedding.norm() * decoded.norm()))
.map(|embedding| embedding.dot(&decoded) / (norm(embedding) * decoded_norm))
.collect();

// Find the index of the maximum similarity
@@ -81,3 +139,7 @@ impl Embedding {
predicted_tokens
}
}

pub fn norm(vector: ArrayView<f32, Ix1>) -> f32 {
vector.mapv(|x| x * x).sum().sqrt()
}
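
The `sinusoidal_pos_encoding` helper imported from `crate::math::positional_encoding` is not part of this diff. For reference, a minimal sketch of the standard Transformer formulation (sine on even dimensions, cosine on odd ones), matching the `(pos, i, embed_size)` call used in `get_positional_encodings`, could look like the following; the actual implementation in `src/math/positional_encoding.rs` may differ.

// Hypothetical sketch of sinusoidal_pos_encoding; the real module is not shown in this commit.
// PE(pos, 2k)   = sin(pos / 10000^(2k / d_model))
// PE(pos, 2k+1) = cos(pos / 10000^(2k / d_model))
pub fn sinusoidal_pos_encoding(pos: usize, i: usize, embed_size: usize) -> f32 {
    let k = (i / 2) as f32; // dimensions 2k and 2k+1 share one frequency
    let angle = pos as f32 / 10000_f32.powf(2.0 * k / embed_size as f32);
    if i % 2 == 0 {
        angle.sin()
    } else {
        angle.cos()
    }
}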

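Since no calling code is included in this commit, here is a hedged usage sketch of the forward pass. The token ids assume the toy six-word vocabulary implied by the comments in `Embedding::new` (hello, world, rust, transformer, learning, model mapped to indices 0 through 5); that mapping lives outside this file and is an assumption.

// Hypothetical usage sketch; assumes `Embedding` is in scope and the toy vocab described above.
fn forward_demo() {
    let embedding = Embedding::new(6, 12);

    // Ids 0, 1, 2 stand for "hello world rust" under the assumed vocab.
    // forward() looks up each token's row of the weight matrix and adds the sinusoidal encodings.
    let encoded = embedding.forward(vec![0, 1, 2]);

    // One embed_size-dimensional row per input token.
    assert_eq!(encoded.dim(), (3, 12));
}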
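
Similarly, a small sketch of `predict_tokens`, which maps each row of a probability matrix to the vocabulary entry with the highest score; the probability values below are made up for illustration.

// Hypothetical sketch; the probabilities would normally come from a softmax over the vocab.
use ndarray::array;
use std::collections::HashMap;

fn predict_demo() {
    let vocab: HashMap<String, usize> = ["hello", "world", "rust", "transformer", "learning", "model"]
        .iter()
        .enumerate()
        .map(|(i, token)| (token.to_string(), i))
        .collect();

    // Two rows of (made-up) probabilities over the six-token vocab.
    let probabilities: ndarray::Array2<f32> = array![
        [0.10, 0.70, 0.05, 0.05, 0.05, 0.05],
        [0.80, 0.05, 0.05, 0.05, 0.03, 0.02]
    ];

    // Row-wise argmax mapped back through the reversed vocab.
    let tokens = Embedding::predict_tokens(probabilities.view(), &vocab);
    assert_eq!(tokens, ["world", "hello"]);
}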