Skip to content

Commit

Permalink
wip: restarted gmres impl
Browse files Browse the repository at this point in the history
  • Loading branch information
wgurecky committed Mar 13, 2024
1 parent f2c0c3f commit caa3a25
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ name = "faer_gmres"
path = 'src/lib.rs'

[dependencies]
thiserror = "1.0"
assert_approx_eq = "1.1.0"
num-traits = "0.2.18"
faer = {version = "0.18.2"}
122 changes: 120 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Basic GMRES implementation from the wiki:
// https://en.wikipedia.org/wiki/Generalized_minimal_residual_method
//
// Includes restarted GMRES implementation for reduced memory requirements.
//
// Uses the Faer library for sparse matricies and sparse solver.
//
// Specifically the givens_rotation, apply_givens_rotation and part of the
Expand Down Expand Up @@ -33,6 +35,18 @@ use faer::prelude::*;
use faer::sparse::*;
use faer::mat;
use num_traits::Float;
use thiserror::Error;

#[derive(Error, Debug)]
pub struct GmresError<T>
where
T: faer::RealField + Float
{
cur_x: Mat<T>,
error: T,
tol: T,
msg: String,
}


/// Calculate the givens rotation matrix
Expand Down Expand Up @@ -100,7 +114,7 @@ pub fn gmres<T>(
x: MatRef<T>,
max_iter: usize,
threshold: T,
) -> Result<(Mat<T>, T, usize), String>
) -> Result<(Mat<T>, T, usize), GmresError<T>>
where
T: faer::RealField + Float
{
Expand Down Expand Up @@ -174,8 +188,59 @@ pub fn gmres<T>(
let h_qr = h_sprs.sp_qr().unwrap();
let y = h_qr.solve(&beta.get(0..k_iters+1, 0..1));

let sol = x.as_ref() + q_sprs * y;
if error <= threshold {
Ok((x.as_ref() + q_sprs * y, error, k_iters))
Ok((sol, error, k_iters))
} else {
Err(GmresError{
cur_x: sol,
error: error,
tol: threshold,
msg: "GMRES did not converge. Error: {:?}. Threshold: {:?}".to_string()}
)
}
}

/// Restarted Generalized minimal residual method
pub fn restarted_gmres<T>(
a: SparseColMatRef<usize, T>,
b: MatRef<T>,
x: MatRef<T>,
max_iter_inner: usize,
max_iter_outer: usize,
threshold: T,
) -> Result<(Mat<T>, T, usize), String>
where
T: faer::RealField + Float
{
let mut res_x = x.to_owned();
let mut error = T::from(1e20).unwrap();
let mut tot_iters = 0;
let mut iters = 0;
for _ko in 0..max_iter_outer {
let res = gmres(
a.as_ref(), b.as_ref(), res_x.as_ref(), max_iter_inner, threshold);
match res {
// done
Ok(res) => {
(res_x, error, iters) = res;
tot_iters += iters;
break;
}
// failed to converge move to next outer iter
// store current solution for next outer iter
Err(res) => {
res_x = res.cur_x;
error = res.error;
tot_iters += max_iter_inner;
}
}
if error <= threshold {
break;
}
}
if error <= threshold {
Ok((res_x, error, tot_iters))
} else {
Err(format!(
"GMRES did not converge. Error: {:?}. Threshold: {:?}",
Expand Down Expand Up @@ -336,6 +401,59 @@ mod test_faer_gmres {
assert_approx_eq!(res_x.read(4, 0), 0.292447, 1e-4);
}


#[test]
fn test_restarted_gmres_4() {
let a: Mat<f32> = faer::mat![
[0.888641, 0.477151, 0.764081, 0.244348, 0.662542],
[0.695741, 0.991383, 0.800932, 0.089616, 0.250400],
[0.149974, 0.584978, 0.937576, 0.870798, 0.990016],
[0.429292, 0.459984, 0.056629, 0.567589, 0.048561],
[0.454428, 0.253192, 0.173598, 0.321640, 0.632031],
];

let mut a_test_triplets = vec![];
for i in 0..a.nrows() {
for j in 0..a.ncols() {
a_test_triplets.push((i, j, a.read(i, j)));
}
}
let a_test = SparseColMat::<usize, f32>::try_new_from_triplets(
5, 5,
&a_test_triplets).unwrap();

// rhs
let b: Mat<f32> = faer::mat![
[0.104594],
[0.437549],
[0.040264],
[0.298842],
[0.254451]
];

// initia sol guess
let x0: Mat<f32> = faer::mat![
[0.0],
[0.0],
[0.0],
[0.0],
[0.0],
];

let (res_x, err, iters) = restarted_gmres(
a_test.as_ref(), b.as_ref(), x0.as_ref(), 3, 30, 1e-6).unwrap();
println!("Result x: {:?}", res_x);
println!("Error x: {:?}", err);
println!("Iters : {:?}", iters);
assert!(err < 1e-4);
assert!(iters < 100);
assert_approx_eq!(res_x.read(0, 0), 0.037919, 1e-4);
assert_approx_eq!(res_x.read(1, 0), 0.888551, 1e-4);
assert_approx_eq!(res_x.read(2, 0), -0.657575, 1e-4);
assert_approx_eq!(res_x.read(3, 0), -0.181680, 1e-4);
assert_approx_eq!(res_x.read(4, 0), 0.292447, 1e-4);
}

#[test]
fn test_arnoldi() {
}
Expand Down

0 comments on commit caa3a25

Please sign in to comment.