diff --git a/src/attention/scaled_dot_attention.rs b/src/attention/scaled_dot_attention.rs index aad6d4c..50275e5 100644 --- a/src/attention/scaled_dot_attention.rs +++ b/src/attention/scaled_dot_attention.rs @@ -3,6 +3,16 @@ use crate::attention::softmax::softmax_3d; use crate::math::linear_algebra::{matmul, tensor_product}; use ndarray::{Array3, Axis, ShapeError}; +/// Computes the scaled dot-product attention. +/// +/// # Arguments +/// - `q`: The query tensor (e.g., [batch, seq_len, dim]). +/// - `k`: The key tensor (e.g., [batch, seq_len, dim]). +/// - `v`: The value tensor (e.g., [batch, seq_len, dim]). +/// - `mask`: Whether to apply a mask to the attention scores (useful for causal or padding masks). +/// +/// # Returns +/// The resulting attention output tensor (e.g., [batch, seq_len, dim]). pub fn scaled_dot_product_attention( q: Array3, // Query k: Array3, // Key @@ -13,7 +23,7 @@ pub fn scaled_dot_product_attention( let sm_scores = softmax_3d(&scores); tensor_product(&sm_scores, &v) } - +/// Performs the scaled dot product of query, key, and value tensors and optionally use mask. pub fn scaled_dot_product( q: Array3, // Shape: (B, L_Q, d_k) k: Array3, // Shape: (B, L_K, d_k) @@ -32,7 +42,8 @@ pub fn scaled_dot_product( scores /= d_k.sqrt(); if mask { let mask = Array3::from_shape_fn((batch_size, L_Q, L_K), |(b, i, j)| { - if i > j { + if i == j { + // for Training this should be i > j 0.0 } else { f32::NEG_INFINITY @@ -56,6 +67,8 @@ pub fn scaled_dot_product( } scores } + +/// Performs the dot product of query, key, and value tensors. pub fn query_key_product( q: Array3, // Shape: (B, L_Q, d_k) k: Array3, // Shape: (B, L_K, d_k) diff --git a/src/example.rs b/src/example.rs new file mode 100644 index 0000000..828f090 --- /dev/null +++ b/src/example.rs @@ -0,0 +1,38 @@ +use crate::attention::scaled_dot_attention::scaled_dot_product; +use crate::attention::softmax::softmax_3d; +use ndarray::{array, s, Array2}; + +pub fn example() { + let words = vec!["The", "cat", "sat", "on", "the", "mat"]; + let q = array![[ + [1.0, 0.8, 0.6, 0.4, 0.2, 0.1], + [0.8, 1.0, 0.9, 0.7, 0.3, 0.2], + [0.6, 0.9, 1.0, 0.8, 0.5, 0.3], + [0.4, 0.7, 0.8, 1.0, 0.7, 0.6], + [0.2, 0.3, 0.5, 0.7, 1.0, 0.9], + [0.1, 0.2, 0.3, 0.6, 0.9, 1.0] + ]]; + + let scores = scaled_dot_product(q.clone(), q.clone(), q.clone(), true); + let sm_scores = softmax_3d(&scores); + display_attention_weights(sm_scores.slice(s![0, .., ..]).to_owned(), &words); +} +fn display_attention_weights(scores: Array2, words: &[&str]) { + println!("Attention Weights (Softmax Scores):\n"); + + // Print column headers + print!("{:<6}", ""); // Empty corner for alignment + for word in words { + print!(" {:<5}", word); + } + println!(); // New line for clarity + + // Iterate through rows and display each with the corresponding word + for (i, row) in scores.outer_iter().enumerate() { + print!("{:<6}", words[i]); // Row label + for &val in row.iter() { + print!("{:<6.3}", val); // Print score with 3 decimal places + } + println!(); // New line after each row + } +} diff --git a/src/lib.rs b/src/lib.rs index 78f48ae..58611ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![allow(non_snake_case)] +pub mod example; mod settings; // this lint makes a scene .... diff --git a/src/main.rs b/src/main.rs index e8cf130..0774c10 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,42 +1,7 @@ -use ndarray::{array, s, Array2, Array3}; -use Transformer::attention::scaled_dot_attention::scaled_dot_product; -use Transformer::attention::softmax::softmax_3d; +use Transformer::example::example; fn main() { println!("runs successfully!"); - let a: Array3 = array![[ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9], - [1.0, 1.1, 1.2], - [0.1, 0.2, 0.3], - [1.3, 1.4, 1.5] - ]]; - - let scores = scaled_dot_product(a.clone(), a.clone(), a.clone(), true); - let sm_scores = softmax_3d(&scores); - // Words corresponding to the input - let words = ["the", "cat", "sat", "on", "the", "mat"]; - - display_attention_weights(sm_scores.slice(s![0, .., ..]).to_owned(), &words); -} -fn display_attention_weights(scores: Array2, words: &[&str]) { - println!("Attention Weights (Softmax Scores):\n"); - - // Print column headers - print!("{:<6}", ""); // Empty corner for alignment - for word in words { - print!(" {:<5}", word); - } - println!(); // New line for clarity - - // Iterate through rows and display each with the corresponding word - for (i, row) in scores.outer_iter().enumerate() { - print!("{:<6}", words[i]); // Row label - for &val in row.iter() { - print!("{:<6.3}", val); // Print score with 3 decimal places - } - println!(); // New line after each row - } + example(); } diff --git a/src/math/linear_algebra.rs b/src/math/linear_algebra.rs index bcf4800..82567fe 100644 --- a/src/math/linear_algebra.rs +++ b/src/math/linear_algebra.rs @@ -59,6 +59,14 @@ pub fn tensor_product(a: &Array3, b: &Array3) -> Array3 { result } +/// Applies a linear projection to a 3D tensor using a weight matrix. +/// +/// # Arguments +/// - `x`: The input 3D tensor (e.g., [batch, seq_len, input_dim]). +/// - `w`: The weight matrix for the projection (e.g., [input_dim, output_dim]). +/// +/// # Returns +/// A new 3D tensor with the projection applied (e.g., [batch, seq_len, output_dim]). pub fn apply_projection(x: &Array3, w: &Array2) -> Array3 { let batch_size = x.shape()[0]; let seq_len = x.shape()[1];