Skip to content

Commit

Permalink
Impl validate and mat_vec_mul for CSC and CSR
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmech committed Sep 24, 2023
1 parent 4d8ccc7 commit af35186
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 95 deletions.
179 changes: 131 additions & 48 deletions russell_sparse/src/csc_matrix.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{handle_umfpack_error_code, to_i32, CooMatrix, Symmetry};
use crate::StrError;
use russell_lab::Matrix;
use russell_lab::{Matrix, Vector};

extern "C" {
fn umfpack_coo_to_csc(
Expand Down Expand Up @@ -81,39 +81,39 @@ pub struct CscMatrix {
}

impl CscMatrix {
/// Allocates an empty CscMatrix
/// Validates the dimension of the arrays in the CSC matrix
///
/// This function simply allocates the following arrays:
/// The following conditions must be satisfied:
///
/// * col_pointers: `vec![0; ncol + 1]`
/// * row_indices: `vec![0; nnz]`
/// * values: `vec![0.0; nnz]`
///
/// # Examples
///
/// ```
/// use russell_sparse::prelude::*;
/// use russell_sparse::StrError;
///
/// fn main() {
/// let csc = CscMatrix::new(None, 3, 3, 4);
/// assert_eq!(csc.symmetry, None);
/// assert_eq!(csc.nrow, 3);
/// assert_eq!(csc.ncol, 3);
/// assert_eq!(csc.col_pointers, &[0, 0, 0, 0]);
/// assert_eq!(csc.row_indices, &[0, 0, 0, 0]);
/// assert_eq!(csc.values, &[0.0, 0.0, 0.0, 0.0]);
/// }
/// ```text
/// nrow ≥ 1
/// ncol ≥ 1
/// nnz = col_pointers[ncol] ≥ 1
/// col_pointers.len() == ncol + 1
/// row_indices.len() == nnz
/// values.len() == nnz
/// ```
pub fn new(symmetry: Option<Symmetry>, nrow: usize, ncol: usize, nnz: usize) -> Self {
CscMatrix {
symmetry,
nrow,
ncol,
col_pointers: vec![0; ncol + 1],
row_indices: vec![0; nnz],
values: vec![0.0; nnz],
pub fn validate(&self) -> Result<(), StrError> {
if self.nrow < 1 {
return Err("nrow must be ≥ 1");
}
if self.ncol < 1 {
return Err("ncol must be ≥ 1");
}
if self.col_pointers.len() != self.ncol + 1 {
return Err("col_pointers.len() must be = ncol + 1");
}
let nnz = self.col_pointers[self.ncol];
if nnz < 1 {
return Err("nnz = col_pointers[ncol] must be ≥ 1");
}
if self.row_indices.len() != nnz as usize {
return Err("row_indices.len() must be = nnz");
}
if self.values.len() != nnz as usize {
return Err("values.len() must be = nnz");
}
Ok(())
}

/// Creates a new CscMatrix from a CooMatrix
Expand Down Expand Up @@ -336,6 +336,7 @@ impl CscMatrix {
/// }
/// ```
pub fn to_matrix(&self, a: &mut Matrix) -> Result<(), StrError> {
self.validate()?;
let (m, n) = a.dims();
if m != self.nrow || n != self.ncol {
return Err("wrong matrix dimensions");
Expand All @@ -356,6 +357,46 @@ impl CscMatrix {
}
Ok(())
}

/// Performs the matrix-vector multiplication
///
/// ```text
/// v := α ⋅ a ⋅ u
/// (m) (m,n) (n)
/// ```
///
/// # Input
///
/// * `u` -- Vector with dimension equal to the number of columns of the matrix
///
/// # Output
///
/// * `v` -- Vector with dimension equal to the number of rows of the matrix
pub fn mat_vec_mul(&self, v: &mut Vector, alpha: f64, u: &Vector) -> Result<(), StrError> {
self.validate()?;
if u.dim() != self.ncol {
return Err("u.ndim must equal ncol");
}
if v.dim() != self.nrow {
return Err("v.ndim must equal nrow");
}
let mirror_required = match self.symmetry {
Some(sym) => sym.triangular(),
None => false,
};
v.fill(0.0);
for j in 0..self.ncol {
for p in self.col_pointers[j]..self.col_pointers[j + 1] {
let i = self.row_indices[p as usize] as usize;
let aij = self.values[p as usize];
v[i] += alpha * aij * u[j];
if mirror_required && i != j {
v[j] += alpha * aij * u[i];
}
}
}
Ok(())
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -365,18 +406,7 @@ mod tests {
use super::CscMatrix;
use crate::{CooMatrix, Storage, Symmetry};
use russell_chk::vec_approx_eq;
use russell_lab::Matrix;

#[test]
fn new_works() {
let csc = CscMatrix::new(None, 3, 3, 4);
assert_eq!(csc.symmetry, None);
assert_eq!(csc.nrow, 3);
assert_eq!(csc.ncol, 3);
assert_eq!(csc.col_pointers, &[0, 0, 0, 0]);
assert_eq!(csc.row_indices, &[0, 0, 0, 0]);
assert_eq!(csc.values, &[0.0, 0.0, 0.0, 0.0]);
}
use russell_lab::{Matrix, Vector};

#[test]
fn csc_matrix_first_triplet_with_shuffled_entries() {
Expand Down Expand Up @@ -784,16 +814,39 @@ mod tests {

#[test]
fn to_matrix_fails_on_wrong_dims() {
let sym = Some(Symmetry::General(Storage::Upper));
let csc = CscMatrix::new(sym, 2, 2, 3);
let mut a_2x1 = Matrix::new(3, 1);
let mut a_1x2 = Matrix::new(1, 3);
assert_eq!(csc.to_matrix(&mut a_2x1), Err("wrong matrix dimensions"));
assert_eq!(csc.to_matrix(&mut a_1x2), Err("wrong matrix dimensions"));
// 10.0 20.0 << (1 x 2) matrix
let csc = CscMatrix {
symmetry: None,
nrow: 1,
ncol: 2,
col_pointers: vec![0, 1, 2],
row_indices: vec![0, 0],
values: vec![10.0, 20.0],
};
let mut a_3x1 = Matrix::new(3, 1);
let mut a_1x3 = Matrix::new(1, 3);
assert_eq!(csc.to_matrix(&mut a_3x1), Err("wrong matrix dimensions"));
assert_eq!(csc.to_matrix(&mut a_1x3), Err("wrong matrix dimensions"));
}

#[test]
fn to_matrix_and_as_matrix_work() {
// 10.0 20.0 << (1 x 2) matrix
let csc = CscMatrix {
symmetry: None,
nrow: 1,
ncol: 2,
col_pointers: vec![0, 1, 2],
row_indices: vec![0, 0],
values: vec![10.0, 20.0],
};
let mut a = Matrix::new(1, 2);
csc.to_matrix(&mut a).unwrap();
let correct = "┌ ┐\n\
│ 10 20 │\n\
└ ┘";
assert_eq!(format!("{}", a), correct);

let csc = CscMatrix {
symmetry: None,
nrow: 5,
Expand Down Expand Up @@ -878,4 +931,34 @@ mod tests {
└ ┘";
assert_eq!(format!("{}", a), correct);
}

#[test]
fn mat_vec_mul_works() {
// 5.0, -2.0, 0.0, 1.0,
// 10.0, -4.0, 0.0, 2.0,
// 15.0, -6.0, 0.0, 3.0,
let csc = CscMatrix {
symmetry: None,
nrow: 3,
ncol: 4,
col_pointers: vec![0, 3, 6, 6, 9],
row_indices: vec![
0, 1, 2, // j=0, p=(0),1,2
0, 1, 2, // j=1, p=(3),4,5
// j=2, p=(6)
0, 1, 2, // j=3, p=(6),7,8
], // (9)
values: vec![
5.0, 10.0, 15.0, // j=0, p=(0),1,2
-2.0, -4.0, -6.0, // j=1, p=(3),4,5
// j=2, p=(6)
1.0, 2.0, 3.0, // j=3, p=(6),7,8
], // (9)
};
let u = Vector::from(&[1.0, 3.0, 8.0, 5.0]);
let mut v = Vector::new(csc.nrow);
csc.mat_vec_mul(&mut v, 1.0, &u).unwrap();
let correct = &[4.0, 8.0, 12.0];
vec_approx_eq(v.as_data(), correct, 1e-15);
}
}
Loading

0 comments on commit af35186

Please sign in to comment.