From 554987945c8f0ca847d300023c36c7637e064542 Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 5 Dec 2024 21:48:20 +0100 Subject: [PATCH] implemented 3D Tensor Multiplication using slices iterating over each batch in batchsize --- src/math/linear_algebra.rs | 40 ++++++++++++++++++++++++++++++++- src/math/positional_encoding.rs | 12 ++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/math/linear_algebra.rs b/src/math/linear_algebra.rs index 8cf4595..2d82645 100644 --- a/src/math/linear_algebra.rs +++ b/src/math/linear_algebra.rs @@ -1,5 +1,5 @@ use ndarray::linalg::general_mat_mul; -use ndarray::{Array1, Array2}; +use ndarray::{s, Array1, Array2, Array3}; /// Performs matrix multiplication between two 2D arrays. /// @@ -21,3 +21,41 @@ pub fn matmul(a: &Array2<f32>, b: &Array2<f32>) -> Result<Array2<f32>, &'static pub fn dotproduct(a: &Array1<f32>, b: &Array1<f32>) -> f32 { a.dot(b) } + +/// Computes the tensor product (batched matrix multiplication) of two 3D tensors `a` and `b`. +/// +/// # Parameters: +/// - `a`: 3D tensor of shape (batch_size, m, k). +/// - `b`: 3D tensor of shape (batch_size, k, n). +/// +/// # Returns: +/// A 3D tensor of shape (batch_size, m, n) containing the result of the batched matrix multiplication. +/// +/// # Panics: +/// - If the batch sizes of `a` and `b` don't match. +/// - If the inner dimensions (`k` in `a` and `b`) don't align for matrix multiplication. +pub fn tensor_product(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> { + // Check that batch sizes match and if dimension align + assert_eq!(a.shape()[0], b.shape()[0], "Batch sizes must match"); + assert_eq!(a.shape()[2], b.shape()[1], "Inner dimensions must align"); + + let batch_size = a.shape()[0]; + let m = a.shape()[1]; // Number of rows in each matrix in a. + let n = b.shape()[2]; // Number of columns in each matrix in b. + + // Initialize a 3D tensor for the result, filled with zeros. + // Its shape corresponds to (batch_size, m, n). 
+ let mut result = Array3::<f32>::zeros((batch_size, m, n)); + + for i in 0..batch_size { + // - `s![i, .., ..]` selects the `i`th matrix (2D slice) in the batch. + + let a_slice = a.slice(s![i, .., ..]); + let b_slice = b.slice(s![i, .., ..]); + let mut result_slice = result.slice_mut(s![i, .., ..]); // Mutable slice of the result matrix for this batch. + + general_mat_mul(1.0, &a_slice, &b_slice, 0.0, &mut result_slice); + } + + result +} diff --git a/src/math/positional_encoding.rs b/src/math/positional_encoding.rs index e96b8ae..190466d 100644 --- a/src/math/positional_encoding.rs +++ b/src/math/positional_encoding.rs @@ -1,3 +1,15 @@ +/// Computes the sinusoidal positional encoding for a given position and dimension. +/// +/// This encoding is used in Transformer models to represent token positions +/// in a sequence. It alternates between sine and cosine based on the dimension index. +/// +/// # Arguments +/// - `pos` - Token position in the sequence (must be >= 0). +/// - `index` - Dimension index. +/// - `embedding_size` - Dimensionality of the embedding space. +/// +/// # Returns +/// The positional encoding value (as `f32`). pub fn sinusoidal_pos_encoding(pos: usize, index: usize, embedding_size: usize) -> f32 { if pos == 0 { return 0.0;