Skip to content

Commit

Permalink
implemented 3D Tensor Multiplication using slices iterating over each…
Browse files Browse the repository at this point in the history
… batch in batchsize
  • Loading branch information
JakubSchwenkbeck committed Dec 5, 2024
1 parent b774014 commit 5549879
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
40 changes: 39 additions & 1 deletion src/math/linear_algebra.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ndarray::linalg::general_mat_mul;
use ndarray::{Array1, Array2};
use ndarray::{s, Array1, Array2, Array3};

/// Performs matrix multiplication between two 2D arrays.
///
Expand All @@ -21,3 +21,41 @@ pub fn matmul(a: &Array2<f32>, b: &Array2<f32>) -> Result<Array2<f32>, &'static
pub fn dotproduct(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
a.dot(b)
}

/// Computes the tensor product (batched matrix multiplication) of two 3D tensors `a` and `b`.
///
/// Every batch entry is multiplied independently: `result[i] = a[i] · b[i]`
/// for each batch index `i`.
///
/// # Parameters:
/// - `a`: 3D tensor of shape (batch_size, m, k).
/// - `b`: 3D tensor of shape (batch_size, k, n).
///
/// # Returns:
/// A 3D tensor of shape (batch_size, m, n) containing the result of the batched matrix multiplication.
///
/// # Panics:
/// - If the batch sizes of `a` and `b` don't match.
/// - If the inner dimensions (`k` in `a` and `b`) don't align for matrix multiplication.
pub fn tensor_product(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> {
    // Validate that the two tensors are compatible before doing any work.
    assert_eq!(a.shape()[0], b.shape()[0], "Batch sizes must match");
    assert_eq!(a.shape()[2], b.shape()[1], "Inner dimensions must align");

    let (batches, rows) = (a.shape()[0], a.shape()[1]); // rows of each lhs matrix
    let cols = b.shape()[2]; // columns of each rhs matrix

    // Zero-initialized output of shape (batches, rows, cols); with beta = 0.0
    // below, `general_mat_mul` fully overwrites each slice anyway.
    let mut product = Array3::<f32>::zeros((batches, rows, cols));

    (0..batches).for_each(|batch| {
        // `s![batch, .., ..]` picks the 2D matrix belonging to this batch entry.
        let lhs = a.slice(s![batch, .., ..]);
        let rhs = b.slice(s![batch, .., ..]);
        let mut dst = product.slice_mut(s![batch, .., ..]);

        // dst = 1.0 * (lhs × rhs) + 0.0 * dst
        general_mat_mul(1.0, &lhs, &rhs, 0.0, &mut dst);
    });

    product
}
12 changes: 12 additions & 0 deletions src/math/positional_encoding.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
/// Computes the sinusoidal positional encoding for a given position and dimension.
///
/// This encoding is used in Transformer models to represent token positions
/// in a sequence. It alternates between sine and cosine based on the dimension index.
///
/// # Arguments
/// - `pos` - Token position in the sequence (must be >= 0).
/// - `index` - Dimension index.
/// - `embedding_size` - Dimensionality of the embedding space.
///
/// # Returns
/// The positional encoding value (as `f32`).
pub fn sinusoidal_pos_encoding(pos: usize, index: usize, embedding_size: usize) -> f32 {
if pos == 0 {
return 0.0;
Expand Down

0 comments on commit 5549879

Please sign in to comment.