From 403a21ff69993f6c2c001cbf046196b7bc852f8f Mon Sep 17 00:00:00 2001 From: Jakub Date: Sun, 8 Dec 2024 16:00:46 +0100 Subject: [PATCH] implemented a nice display for attention weight based on sample inputs --- src/main.rs | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/main.rs b/src/main.rs index b776082..5a2f14e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,42 @@ +use ndarray::{array, s, Array2, Array3}; +use Transformer::attention::scaled_dot_attention::scaled_dot_product; +use Transformer::attention::softmax::softmax_3d; + fn main() { println!("runs successfully!"); + + let a: Array3 = array![[ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + [1.0, 1.1, 1.2], + [0.1, 0.2, 0.3], + [1.3, 1.4, 1.5] + ]]; + + let scores = scaled_dot_product(a.clone(), a.clone(), a.clone(), false); + let sm_scores = softmax_3d(&scores); + // Words corresponding to the input + let words = ["the", "cat", "sat", "on", "the", "mat"]; + + display_attention_weights(sm_scores.slice(s![0, .., ..]).to_owned(), &words); +} +fn display_attention_weights(scores: Array2, words: &[&str]) { + println!("Attention Weights (Softmax Scores):\n"); + + // Print column headers + print!("{:<6}", ""); // Empty corner for alignment + for word in words { + print!(" {:<5}", word); + } + println!(); // New line for clarity + + // Iterate through rows and display each with the corresponding word + for (i, row) in scores.outer_iter().enumerate() { + print!("{:<6}", words[i]); // Row label + for &val in row.iter() { + print!("{:<6.3}", val); // Print score with 3 decimal places + } + println!(); // New line after each row + } }