Skip to content

Commit

Permalink
refactor gmres to accept general faer linop as input
Browse files Browse the repository at this point in the history
  • Loading branch information
wgurecky committed Aug 8, 2024
1 parent 23b3a10 commit d6b8430
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,6 @@ impl <T> fmt::Display for GmresError<T>
}
}

// pub trait LinOp<T>
// where
// T: faer::RealField + Float
// {
// fn apply_linop_to_vec(&self, target: MatRef<T>) -> Mat<T>;
// }

#[derive(Clone,Debug)]
pub struct JacobiPreconLinOp<'a, T>
where
Expand Down Expand Up @@ -189,28 +182,31 @@ fn apply_givens_rotation<T>(h: &mut Vec<T>, cs: &mut Vec<T>, sn: &mut Vec<T>, k:
/// * `m`- An optional preconditioner that is applied to the original system such that
/// the new krylov subspace built is [M^{-1}k, M^{-1}Ak, M^{-1}A^2k, ...].
/// If None, no preconditioner is applied.
fn arnoldi<'a, T>(
a: SparseColMatRef<'a, usize, T>,
fn arnoldi<'a, T, Lop: LinOp<T>>(
a: &Lop,
q: &Vec<Mat<T>>,
k: usize,
m: Option<&dyn LinOp<T>>
) -> (Vec<T>, Mat<T>)
where
T: faer::RealField + Float
{
// unused in faer LinOp apply() method
let mut _dummy_podstack: [u8;1] = [0u8;1];

// Krylov vector
let q_col: MatRef<T> = q[k].as_ref();

// let mut qv: Mat<f64> = a * q_col;
// parallel version of above
let mut qv: Mat<T> = faer::Mat::zeros(q_col.nrows(), 1);
linalg::matmul::sparse_dense_matmul(
qv.as_mut(), a.as_ref(), q_col.as_ref(), None, T::from(1.0).unwrap(), faer::get_global_parallelism());
//linalg::matmul::sparse_dense_matmul(
// qv.as_mut(), a.as_ref(), q_col.as_ref(), None, T::from(1.0).unwrap(), faer::get_global_parallelism());
a.apply(qv.as_mut(), q_col.as_ref(), faer::get_global_parallelism(), PodStack::new(&mut _dummy_podstack));

// Apply left preconditioner if supplied
match m {
Some(m) => {
let mut _dummy_podstack: [u8;1] = [0u8;1];
let mut lp_out = faer::Mat::zeros(qv.nrows(), qv.ncols());
m.apply(lp_out.as_mut(), qv.as_ref(), faer::get_global_parallelism(), PodStack::new(&mut _dummy_podstack));
qv = lp_out;
Expand All @@ -233,8 +229,8 @@ fn arnoldi<'a, T>(


/// Generalized minimal residual method
pub fn gmres<'a, T>(
a: SparseColMatRef<'a, usize, T>,
pub fn gmres<'a, T, Lop: LinOp<T>>(
a: Lop,
b: MatRef<T>,
mut x: MatMut<T>,
max_iter: usize,
Expand All @@ -245,11 +241,14 @@ pub fn gmres<'a, T>(
T: faer::RealField + Float
{
// compute initial residual
let mut r = b - a * x.as_ref();
// let mut a_x = a * x.as_ref();
let mut _dummy_podstack: [u8;1] = [0u8;1];
let mut a_x = faer::Mat::zeros(b.nrows(), b.ncols());
a.apply(a_x.as_mut(), x.as_ref(), faer::get_global_parallelism(), PodStack::new(&mut _dummy_podstack));
let mut r = b - a_x;

match &m {
Some(m) => {
let mut _dummy_podstack: [u8;1] = [0u8;1];
let mut lp_out = faer::Mat::zeros(r.nrows(), r.ncols());
(&m).apply(lp_out.as_mut(), r.as_ref(), faer::get_global_parallelism(), PodStack::new(&mut _dummy_podstack));
r = lp_out;
Expand Down Expand Up @@ -277,7 +276,7 @@ pub fn gmres<'a, T>(

let mut k_iters = 0;
for k in 0..max_iter {
let (mut hk, qk) = arnoldi(a, &qs, k, m);
let (mut hk, qk) = arnoldi(&a, &qs, k, m);
apply_givens_rotation(&mut hk, &mut cs, &mut sn, k);
hs.push(hk);
qs.push(qk);
Expand Down Expand Up @@ -328,8 +327,8 @@ pub fn gmres<'a, T>(
}

/// Restarted Generalized minimal residual method
pub fn restarted_gmres<'a, T>(
a: SparseColMatRef<'a, usize, T>,
pub fn restarted_gmres<'a, T, Lop: LinOp<T>>(
a: Lop,
b: MatRef<T>,
mut x: MatMut<T>,
max_iter_inner: usize,
Expand All @@ -345,7 +344,7 @@ pub fn restarted_gmres<'a, T>(
let mut iters = 0;
for _ko in 0..max_iter_outer {
let res = gmres(
a.as_ref(), b.as_ref(), x.rb_mut(),
&a, b.as_ref(), x.rb_mut(),
max_iter_inner, threshold, m);
match res {
// done
Expand Down

0 comments on commit d6b8430

Please sign in to comment.