Skip to content

Commit e1becf2

Browse files
committed
[Important] Remove Generic SparseMatrix and make complex sparse solvers take SparseCooMatrix instead. Thus, eliminating the memoization problem
1 parent 261800f commit e1becf2

17 files changed

+174
-1086
lines changed

russell_ode/src/radau5.rs

+10-14
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ use crate::StrError;
22
use crate::{OdeSolverTrait, Params, System, Workspace};
33
use russell_lab::math::SQRT_6;
44
use russell_lab::{complex_vec_zip, cpx, format_fortran, vec_copy, Complex64, ComplexVector, Vector};
5-
use russell_sparse::{numerical_jacobian, ComplexCscMatrix, CooMatrix, CscMatrix};
6-
use russell_sparse::{ComplexLinSolver, ComplexSparseMatrix, Genie, LinSolver};
5+
use russell_sparse::{
6+
numerical_jacobian, ComplexCooMatrix, ComplexCscMatrix, ComplexLinSolver, CooMatrix, CscMatrix, Genie, LinSolver,
7+
};
78
use std::thread;
89

910
/// Implements the Radau5 method (Radau IIA) (implicit, order 5, embedded) for ODEs and DAEs
@@ -41,7 +42,7 @@ pub(crate) struct Radau5<'a, A> {
4142
kk_real: CooMatrix,
4243

4344
/// Coefficient matrix (for real system). K_comp = (α + βi) M - J
44-
kk_comp: ComplexSparseMatrix,
45+
kk_comp: ComplexCooMatrix,
4546

4647
/// Linear solver (for real system)
4748
solver_real: LinSolver<'a>,
@@ -148,7 +149,7 @@ impl<'a, A> Radau5<'a, A> {
148149
mass,
149150
jj: CooMatrix::new(ndim, ndim, jac_nnz, sym).unwrap(),
150151
kk_real: CooMatrix::new(ndim, ndim, nnz, sym).unwrap(),
151-
kk_comp: ComplexSparseMatrix::new_coo(ndim, ndim, nnz, sym).unwrap(),
152+
kk_comp: ComplexCooMatrix::new(ndim, ndim, nnz, sym).unwrap(),
152153
solver_real: LinSolver::new(params.newton.genie).unwrap(),
153154
solver_comp: ComplexLinSolver::new(params.newton.genie).unwrap(),
154155
reuse_jacobian: false,
@@ -195,7 +196,7 @@ impl<'a, A> Radau5<'a, A> {
195196
// auxiliary
196197
let jj = &mut self.jj; // J = df/dy
197198
let kk_real = &mut self.kk_real; // K_real = γ M - J
198-
let kk_comp = self.kk_comp.get_coo_mut().unwrap(); // K_comp = (α + βi) M - J
199+
let kk_comp = &mut self.kk_comp; // K_comp = (α + βi) M - J
199200

200201
// Jacobian matrix
201202
if self.reuse_jacobian {
@@ -261,7 +262,7 @@ impl<'a, A> Radau5<'a, A> {
261262
.factorize(&self.kk_real, self.params.newton.lin_sol_params)?;
262263
self.solver_comp
263264
.actual
264-
.factorize(&mut self.kk_comp, self.params.newton.lin_sol_params)
265+
.factorize(&self.kk_comp, self.params.newton.lin_sol_params)
265266
}
266267

267268
/// Factorizes the real and complex systems concurrently
@@ -276,7 +277,7 @@ impl<'a, A> Radau5<'a, A> {
276277
let handle_comp = scope.spawn(|| {
277278
self.solver_comp
278279
.actual
279-
.factorize(&mut self.kk_comp, self.params.newton.lin_sol_params)
280+
.factorize(&self.kk_comp, self.params.newton.lin_sol_params)
280281
.unwrap();
281282
});
282283
let err_real = handle_real.join();
@@ -296,9 +297,7 @@ impl<'a, A> Radau5<'a, A> {
296297
/// Solves the real and complex linear systems
297298
fn solve_lin_sys(&mut self) -> Result<(), StrError> {
298299
self.solver_real.actual.solve(&mut self.dw0, &self.v0, false)?;
299-
self.solver_comp
300-
.actual
301-
.solve(&mut self.dw12, &self.kk_comp, &self.v12, false)?;
300+
self.solver_comp.actual.solve(&mut self.dw12, &self.v12, false)?;
302301
Ok(())
303302
}
304303

@@ -309,10 +308,7 @@ impl<'a, A> Radau5<'a, A> {
309308
self.solver_real.actual.solve(&mut self.dw0, &self.v0, false).unwrap();
310309
});
311310
let handle_comp = scope.spawn(|| {
312-
self.solver_comp
313-
.actual
314-
.solve(&mut self.dw12, &self.kk_comp, &self.v12, false)
315-
.unwrap();
311+
self.solver_comp.actual.solve(&mut self.dw12, &self.v12, false).unwrap();
316312
});
317313
let err_real = handle_real.join();
318314
let err_comp = handle_comp.join();

russell_sparse/examples/complex_system.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ fn solve(genie: Genie) -> Result<(), StrError> {
4242
let nnz = 16; // number of non-zero values, including duplicates
4343

4444
// input matrix in Complex Triplet format
45-
let mut coo = ComplexSparseMatrix::new_coo(ndim, ndim, nnz, Sym::No)?;
45+
let mut coo = ComplexCooMatrix::new(ndim, ndim, nnz, Sym::No)?;
4646

4747
// first column
4848
coo.put(0, 0, cpx!(19.73, 0.00))?;
@@ -88,8 +88,8 @@ fn solve(genie: Genie) -> Result<(), StrError> {
8888

8989
// call factorize and solve
9090
let mut x = ComplexVector::new(ndim);
91-
solver.actual.factorize(&mut coo, Some(params))?;
92-
solver.actual.solve(&mut x, &coo, &b, false)?;
91+
solver.actual.factorize(&coo, Some(params))?;
92+
solver.actual.solve(&mut x, &b, false)?;
9393
println!("x =\n{}", x);
9494

9595
// check

russell_sparse/src/aliases.rs

+1-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::{NumCooMatrix, NumCscMatrix, NumCsrMatrix, NumSparseMatrix};
1+
use crate::{NumCooMatrix, NumCscMatrix, NumCsrMatrix};
22
use russell_lab::Complex64;
33

44
/// Defines an alias to NumCooMatrix with f64
@@ -10,9 +10,6 @@ pub type CscMatrix = NumCscMatrix<f64>;
1010
/// Defines an alias to NumCsrMatrix with f64
1111
pub type CsrMatrix = NumCsrMatrix<f64>;
1212

13-
/// Defines an alias to NumSparseMatrix with f64
14-
pub type SparseMatrix = NumSparseMatrix<f64>;
15-
1613
/// Defines an alias to NumCooMatrix with Complex64
1714
pub type ComplexCooMatrix = NumCooMatrix<Complex64>;
1815

@@ -21,6 +18,3 @@ pub type ComplexCscMatrix = NumCscMatrix<Complex64>;
2118

2219
/// Defines an alias to NumCsrMatrix with Complex64
2320
pub type ComplexCsrMatrix = NumCsrMatrix<Complex64>;
24-
25-
/// Defines an alias to NumSparseMatrix with Complex64
26-
pub type ComplexSparseMatrix = NumSparseMatrix<Complex64>;

russell_sparse/src/bin/mem_check.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ fn test_complex_solver(genie: Genie) {
7373
Genie::Mumps => Samples::complex_symmetric_3x3_lower().0,
7474
Genie::Umfpack => Samples::complex_symmetric_3x3_full().0,
7575
};
76-
let mut mat = ComplexSparseMatrix::from_coo(coo);
7776

78-
match solver.actual.factorize(&mut mat, None) {
77+
match solver.actual.factorize(&coo, None) {
7978
Err(e) => {
8079
println!("FAIL(factorize): {}", e);
8180
return;
@@ -86,15 +85,15 @@ fn test_complex_solver(genie: Genie) {
8685
let mut x = ComplexVector::new(3);
8786
let rhs = ComplexVector::from(&[cpx!(-3.0, 3.0), cpx!(2.0, -2.0), cpx!(9.0, 7.0)]);
8887

89-
match solver.actual.solve(&mut x, &mut mat, &rhs, false) {
88+
match solver.actual.solve(&mut x, &rhs, false) {
9089
Err(e) => {
9190
println!("FAIL(solve): {}", e);
9291
return;
9392
}
9493
_ => (),
9594
}
9695

97-
match solver.actual.solve(&mut x, &mat, &rhs, false) {
96+
match solver.actual.solve(&mut x, &rhs, false) {
9897
Err(e) => {
9998
println!("FAIL(solve again): {}", e);
10099
return;

russell_sparse/src/bin/solve_matrix_market.rs

+4-7
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,8 @@ fn main() -> Result<(), StrError> {
173173
csc.write_matrix_market("/tmp/russell_sparse/solve_matrix_market_complex.smat", true)?;
174174
}
175175

176-
// save the COO matrix as a generic SparseMatrix
177-
let mut mat = ComplexSparseMatrix::from_coo(coo);
178-
179176
// save information about the matrix
180-
let (nrow, ncol, nnz, sym) = mat.get_info();
177+
let (nrow, ncol, nnz, sym) = coo.get_info();
181178
stats.set_matrix_name_from_path(&opt.matrix_market_file);
182179
stats.matrix.nrow = nrow;
183180
stats.matrix.ncol = ncol;
@@ -189,18 +186,18 @@ fn main() -> Result<(), StrError> {
189186
let mut solver = ComplexLinSolver::new(genie)?;
190187

191188
// call factorize
192-
solver.actual.factorize(&mut mat, Some(params))?;
189+
solver.actual.factorize(&coo, Some(params))?;
193190

194191
// allocate vectors
195192
let mut x = ComplexVector::new(nrow);
196193
let rhs = ComplexVector::filled(nrow, cpx!(1.0, 1.0));
197194

198195
// solve linear system
199-
solver.actual.solve(&mut x, &mat, &rhs, opt.verbose)?;
196+
solver.actual.solve(&mut x, &rhs, opt.verbose)?;
200197

201198
// verify the solution
202199
sw.reset();
203-
stats.verify = VerifyLinSys::from_complex(&mat, &x, &rhs)?;
200+
stats.verify = VerifyLinSys::from_complex(&coo, &x, &rhs)?;
204201
stats.time_nanoseconds.verify = sw.stop();
205202

206203
// update stats

russell_sparse/src/complex_lin_solver.rs

+9-19
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#[cfg(feature = "with_mumps")]
22
use super::ComplexSolverMUMPS;
33

4-
use super::{ComplexSparseMatrix, Genie, LinSolParams, StatsLinSol};
4+
use super::{ComplexCooMatrix, ComplexSolverKLU, ComplexSolverUMFPACK, Genie, LinSolParams, StatsLinSol};
55
use crate::StrError;
6-
use crate::{ComplexSolverKLU, ComplexSolverUMFPACK};
76
use russell_lab::ComplexVector;
87

98
/// Defines a unified interface for complex linear system solvers
@@ -24,7 +23,7 @@ pub trait ComplexLinSolTrait: Send {
2423
/// kept the same for the next calls.
2524
/// 3. If the structure of the matrix needs to be changed, the solver must
2625
/// be "dropped" and a new solver allocated.
27-
fn factorize(&mut self, mat: &mut ComplexSparseMatrix, params: Option<LinSolParams>) -> Result<(), StrError>;
26+
fn factorize(&mut self, mat: &ComplexCooMatrix, params: Option<LinSolParams>) -> Result<(), StrError>;
2827

2928
/// Computes the solution of the linear system
3029
///
@@ -46,13 +45,7 @@ pub trait ComplexLinSolTrait: Send {
4645
/// * `verbose` -- shows messages
4746
///
4847
/// **Warning:** the matrix must be same one used in `factorize`.
49-
fn solve(
50-
&mut self,
51-
x: &mut ComplexVector,
52-
mat: &ComplexSparseMatrix,
53-
rhs: &ComplexVector,
54-
verbose: bool,
55-
) -> Result<(), StrError>;
48+
fn solve(&mut self, x: &mut ComplexVector, rhs: &ComplexVector, verbose: bool) -> Result<(), StrError>;
5649

5750
/// Updates the stats structure (should be called after solve)
5851
fn update_stats(&self, stats: &mut StatsLinSol);
@@ -125,14 +118,14 @@ impl<'a> ComplexLinSolver<'a> {
125118
pub fn compute(
126119
genie: Genie,
127120
x: &mut ComplexVector,
128-
mat: &mut ComplexSparseMatrix,
121+
mat: &ComplexCooMatrix,
129122
rhs: &ComplexVector,
130123
params: Option<LinSolParams>,
131124
) -> Result<Self, StrError> {
132125
let mut solver = ComplexLinSolver::new(genie)?;
133126
solver.actual.factorize(mat, params)?;
134127
let verbose = if let Some(p) = params { p.verbose } else { false };
135-
solver.actual.solve(x, mat, rhs, verbose)?;
128+
solver.actual.solve(x, rhs, verbose)?;
136129
Ok(solver)
137130
}
138131
}
@@ -142,7 +135,7 @@ impl<'a> ComplexLinSolver<'a> {
142135
#[cfg(test)]
143136
mod tests {
144137
use super::ComplexLinSolver;
145-
use crate::{ComplexSparseMatrix, Genie, Samples};
138+
use crate::{Genie, Samples};
146139
use russell_lab::{complex_vec_approx_eq, cpx, Complex64, ComplexVector};
147140

148141
#[cfg(feature = "with_mumps")]
@@ -151,10 +144,9 @@ mod tests {
151144
#[test]
152145
fn complex_lin_solver_compute_works_klu() {
153146
let (coo, _, _, _) = Samples::complex_symmetric_3x3_full();
154-
let mut mat = ComplexSparseMatrix::from_coo(coo);
155147
let mut x = ComplexVector::new(3);
156148
let rhs = ComplexVector::from(&[cpx!(-3.0, 3.0), cpx!(2.0, -2.0), cpx!(9.0, 7.0)]);
157-
ComplexLinSolver::compute(Genie::Klu, &mut x, &mut mat, &rhs, None).unwrap();
149+
ComplexLinSolver::compute(Genie::Klu, &mut x, &coo, &rhs, None).unwrap();
158150
let x_correct = &[cpx!(1.0, 1.0), cpx!(2.0, -2.0), cpx!(3.0, 3.0)];
159151
complex_vec_approx_eq(&x, x_correct, 1e-15);
160152
}
@@ -164,21 +156,19 @@ mod tests {
164156
#[cfg(feature = "with_mumps")]
165157
fn complex_lin_solver_compute_works_mumps() {
166158
let (coo, _, _, _) = Samples::complex_symmetric_3x3_lower();
167-
let mut mat = ComplexSparseMatrix::from_coo(coo);
168159
let mut x = ComplexVector::new(3);
169160
let rhs = ComplexVector::from(&[cpx!(-3.0, 3.0), cpx!(2.0, -2.0), cpx!(9.0, 7.0)]);
170-
ComplexLinSolver::compute(Genie::Mumps, &mut x, &mut mat, &rhs, None).unwrap();
161+
ComplexLinSolver::compute(Genie::Mumps, &mut x, &coo, &rhs, None).unwrap();
171162
let x_correct = &[cpx!(1.0, 1.0), cpx!(2.0, -2.0), cpx!(3.0, 3.0)];
172163
complex_vec_approx_eq(&x, x_correct, 1e-15);
173164
}
174165

175166
#[test]
176167
fn complex_lin_solver_compute_works_umfpack() {
177168
let (coo, _, _, _) = Samples::complex_symmetric_3x3_full();
178-
let mut mat = ComplexSparseMatrix::from_coo(coo);
179169
let mut x = ComplexVector::new(3);
180170
let rhs = ComplexVector::from(&[cpx!(-3.0, 3.0), cpx!(2.0, -2.0), cpx!(9.0, 7.0)]);
181-
ComplexLinSolver::compute(Genie::Umfpack, &mut x, &mut mat, &rhs, None).unwrap();
171+
ComplexLinSolver::compute(Genie::Umfpack, &mut x, &coo, &rhs, None).unwrap();
182172
let x_correct = &[cpx!(1.0, 1.0), cpx!(2.0, -2.0), cpx!(3.0, 3.0)];
183173
complex_vec_approx_eq(&x, x_correct, 1e-15);
184174
}

0 commit comments

Comments
 (0)