After multi-head attention and projection from 3D to 2D, implemented an example for attention weights which displays the strength of this key Transformer technique.
JakubSchwenkbeck committed Dec 8, 2024
1 parent 6fca8d9 commit 57765e1
Showing 5 changed files with 64 additions and 39 deletions.
17 changes: 15 additions & 2 deletions src/attention/scaled_dot_attention.rs
@@ -3,6 +3,16 @@ use crate::attention::softmax::softmax_3d;
use crate::math::linear_algebra::{matmul, tensor_product};
use ndarray::{Array3, Axis, ShapeError};

/// Computes the scaled dot-product attention.
///
/// # Arguments
/// - `q`: The query tensor (e.g., [batch, seq_len, dim]).
/// - `k`: The key tensor (e.g., [batch, seq_len, dim]).
/// - `v`: The value tensor (e.g., [batch, seq_len, dim]).
/// - `mask`: Whether to apply a mask to the attention scores (useful for causal or padding masks).
///
/// # Returns
/// The resulting attention output tensor (e.g., [batch, seq_len, dim]).
pub fn scaled_dot_product_attention(
q: Array3<f32>, // Query
k: Array3<f32>, // Key
Expand All @@ -13,7 +23,7 @@ pub fn scaled_dot_product_attention(
let sm_scores = softmax_3d(&scores);
tensor_product(&sm_scores, &v)
}
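
A minimal usage sketch for the function above. It assumes the crate is named `Transformer` (as in `src/main.rs`), that the full signature ends with the `mask: bool` flag described in the doc comment, and that the output keeps the shape of `v`; the concrete dimensions are made up for illustration.

```rust
use ndarray::Array3;
use Transformer::attention::scaled_dot_attention::scaled_dot_product_attention;

fn main() {
    // Hypothetical shapes: batch = 1, seq_len = 4, dim = 8.
    let q: Array3<f32> = Array3::ones((1, 4, 8));
    let k = q.clone();
    let v = q.clone();
    // mask = false: every position may attend to every other position.
    let output = scaled_dot_product_attention(q, k, v, false);
    assert_eq!(output.shape(), &[1, 4, 8]);
}
```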

/// Performs the scaled dot product of the query, key, and value tensors and optionally applies a mask.
pub fn scaled_dot_product(
q: Array3<f32>, // Shape: (B, L_Q, d_k)
k: Array3<f32>, // Shape: (B, L_K, d_k)
@@ -32,7 +42,8 @@ pub fn scaled_dot_product(
scores /= d_k.sqrt();
if mask {
let mask = Array3::from_shape_fn((batch_size, L_Q, L_K), |(b, i, j)| {
if i > j {
if i == j {
// for training, this should be i > j
0.0
} else {
f32::NEG_INFINITY
@@ -56,6 +67,8 @@
}
scores
}
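
A small sketch of why the mask fills positions with `f32::NEG_INFINITY`: after the softmax, masked scores contribute essentially zero attention weight. This assumes the crate is named `Transformer` and that `softmax_3d` follows the usual exp-and-normalize formulation; the numbers are arbitrary.

```rust
use ndarray::array;
use Transformer::attention::softmax::softmax_3d;

fn main() {
    // One batch, two rows; the second score in the first row is masked out.
    let scores = array![[[0.5_f32, f32::NEG_INFINITY], [0.1, 0.3]]];
    let weights = softmax_3d(&scores);
    // exp(-inf) = 0, so the masked position gets ~0.0 and the rest renormalizes.
    assert!((weights[[0, 0, 0]] - 1.0).abs() < 1e-5);
    assert!(weights[[0, 0, 1]].abs() < 1e-5);
}
```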

/// Performs the dot product of the query and key tensors.
pub fn query_key_product(
q: Array3<f32>, // Shape: (B, L_Q, d_k)
k: Array3<f32>, // Shape: (B, L_K, d_k)
38 changes: 38 additions & 0 deletions src/example.rs
@@ -0,0 +1,38 @@
use crate::attention::scaled_dot_attention::scaled_dot_product;
use crate::attention::softmax::softmax_3d;
use ndarray::{array, s, Array2};

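/// Runs a small demo: computes masked scaled dot-product attention scores for a toy
/// sentence, applies softmax, and prints the resulting attention weights.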
pub fn example() {
let words = vec!["The", "cat", "sat", "on", "the", "mat"];
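// One batch of hand-crafted 6x6 similarity scores for the six words above;
// the same tensor is reused as Q, K, and V below.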
let q = array![[
[1.0, 0.8, 0.6, 0.4, 0.2, 0.1],
[0.8, 1.0, 0.9, 0.7, 0.3, 0.2],
[0.6, 0.9, 1.0, 0.8, 0.5, 0.3],
[0.4, 0.7, 0.8, 1.0, 0.7, 0.6],
[0.2, 0.3, 0.5, 0.7, 1.0, 0.9],
[0.1, 0.2, 0.3, 0.6, 0.9, 1.0]
]];

let scores = scaled_dot_product(q.clone(), q.clone(), q.clone(), true);
let sm_scores = softmax_3d(&scores);
display_attention_weights(sm_scores.slice(s![0, .., ..]).to_owned(), &words);
}
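/// Pretty-prints a 2D attention-weight matrix, labelling rows and columns with the input words.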
fn display_attention_weights(scores: Array2<f32>, words: &[&str]) {
println!("Attention Weights (Softmax Scores):\n");

// Print column headers
print!("{:<6}", ""); // Empty corner for alignment
for word in words {
print!(" {:<5}", word);
}
println!(); // New line for clarity

// Iterate through rows and display each with the corresponding word
for (i, row) in scores.outer_iter().enumerate() {
print!("{:<6}", words[i]); // Row label
for &val in row.iter() {
print!("{:<6.3}", val); // Print score with 3 decimal places
}
println!(); // New line after each row
}
}
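
One property worth keeping in mind when reading the table printed by `display_attention_weights`: each row of the softmaxed scores is a probability distribution over the input words, so it should sum to roughly 1. A minimal check, again assuming the crate is named `Transformer`; the input values are arbitrary.

```rust
use ndarray::{array, s};
use Transformer::attention::scaled_dot_attention::scaled_dot_product;
use Transformer::attention::softmax::softmax_3d;

fn main() {
    let q = array![[[1.0_f32, 0.8], [0.8, 1.0]]];
    let scores = scaled_dot_product(q.clone(), q.clone(), q.clone(), false);
    let weights = softmax_3d(&scores);
    // Every row of the attention matrix should sum to approximately 1.0.
    for row in weights.slice(s![0, .., ..]).outer_iter() {
        assert!((row.sum() - 1.0).abs() < 1e-5);
    }
}
```
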
1 change: 1 addition & 0 deletions src/lib.rs
@@ -1,5 +1,6 @@
#![allow(non_snake_case)]

pub mod example;
mod settings;

// this lint makes a scene ....
39 changes: 2 additions & 37 deletions src/main.rs
@@ -1,42 +1,7 @@
use ndarray::{array, s, Array2, Array3};
use Transformer::attention::scaled_dot_attention::scaled_dot_product;
use Transformer::attention::softmax::softmax_3d;
use Transformer::example::example;

fn main() {
println!("runs successfully!");

let a: Array3<f32> = array![[
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9],
[1.0, 1.1, 1.2],
[0.1, 0.2, 0.3],
[1.3, 1.4, 1.5]
]];

let scores = scaled_dot_product(a.clone(), a.clone(), a.clone(), true);
let sm_scores = softmax_3d(&scores);
// Words corresponding to the input
let words = ["the", "cat", "sat", "on", "the", "mat"];

display_attention_weights(sm_scores.slice(s![0, .., ..]).to_owned(), &words);
}
fn display_attention_weights(scores: Array2<f32>, words: &[&str]) {
println!("Attention Weights (Softmax Scores):\n");

// Print column headers
print!("{:<6}", ""); // Empty corner for alignment
for word in words {
print!(" {:<5}", word);
}
println!(); // New line for clarity

// Iterate through rows and display each with the corresponding word
for (i, row) in scores.outer_iter().enumerate() {
print!("{:<6}", words[i]); // Row label
for &val in row.iter() {
print!("{:<6.3}", val); // Print score with 3 decimal places
}
println!(); // New line after each row
}
example();
}
8 changes: 8 additions & 0 deletions src/math/linear_algebra.rs
@@ -59,6 +59,14 @@ pub fn tensor_product(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> {

result
}
/// Applies a linear projection to a 3D tensor using a weight matrix.
///
/// # Arguments
/// - `x`: The input 3D tensor (e.g., [batch, seq_len, input_dim]).
/// - `w`: The weight matrix for the projection (e.g., [input_dim, output_dim]).
///
/// # Returns
/// A new 3D tensor with the projection applied (e.g., [batch, seq_len, output_dim]).
pub fn apply_projection(x: &Array3<f32>, w: &Array2<f32>) -> Array3<f32> {
let batch_size = x.shape()[0];
let seq_len = x.shape()[1];
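
A hedged usage sketch for `apply_projection`, matching only what the doc comment above states about shapes. The module path `Transformer::math::linear_algebra` and the concrete dimensions are assumptions for illustration.

```rust
use ndarray::{Array2, Array3};
use Transformer::math::linear_algebra::apply_projection;

fn main() {
    // Hypothetical shapes: batch = 1, seq_len = 6, input_dim = 6, output_dim = 3.
    let x: Array3<f32> = Array3::ones((1, 6, 6));
    let w: Array2<f32> = Array2::ones((6, 3));
    let projected = apply_projection(&x, &w);
    // Per the doc comment, the result should be [batch, seq_len, output_dim].
    assert_eq!(projected.shape(), &[1, 6, 3]);
}
```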
