Skip to content

Commit

Permalink
Impl linear_fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmech committed Jul 31, 2023
1 parent 6318341 commit 67b7f4d
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
2 changes: 2 additions & 0 deletions russell_lab/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ mod constants;
mod enums;
mod formatters;
mod generators;
mod linear_fitting;
pub mod math;
mod matrix;
mod matvec;
Expand All @@ -54,6 +55,7 @@ use crate::constants::*;
pub use crate::enums::*;
pub use crate::formatters::*;
pub use crate::generators::*;
pub use crate::linear_fitting::*;
pub use crate::matrix::*;
pub use crate::matvec::*;
pub use crate::read_table::*;
Expand Down
109 changes: 109 additions & 0 deletions russell_lab/src/linear_fitting.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use crate::{StrError, Vector};

/// Calculates the parameters of a linear model using least squares fitting
///
/// # Input
///
/// `x` -- the X-data vector with dimension n
/// `y` -- the Y-data vector with dimension n
/// `pass_through_zero` -- compute the parameters such that the line passes through zero (c = 0)
///
/// # Output
///
/// * `(c, m)` -- the y(x=0)=c intersect and the slope m
///
/// NOTE: this function returns `(0.0, f64::INFINITY)` in two situations:
///
/// * If `pass_through_zero == True` and `sum(X) == 0`
/// * If `pass_through_zero == False` and the line is vertical (null denominator)
pub fn linear_fitting(x: &Vector, y: &Vector, pass_through_zero: bool) -> Result<(f64, f64), StrError> {
// dimension
let nn = x.dim();
if y.dim() != nn {
return Err("vectors must have the same dimension");
}

// sums
let mut sum_x = 0.0;
let mut sum_y = 0.0;
let mut sum_xy = 0.0;
let mut sum_xx = 0.0;
for i in 0..nn {
sum_x += x[i];
sum_y += y[i];
sum_xy += x[i] * y[i];
sum_xx += x[i] * x[i];
}

// calculate parameters
let c;
let m;
let n = nn as f64;
if pass_through_zero {
if sum_xx == 0.0 {
return Ok((0.0, f64::INFINITY));
}
c = 0.0;
m = sum_xy / sum_xx;
} else {
let den = sum_x * sum_x - n * sum_xx;
println!("den = {}", den);
if den == 0.0 {
return Ok((0.0, f64::INFINITY));
}
c = (sum_x * sum_xy - sum_xx * sum_y) / den;
m = (sum_x * sum_y - n * sum_xy) / den;
}

// results
Ok((c, m))
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

#[cfg(test)]
mod tests {
use super::linear_fitting;
use crate::Vector;
use russell_chk::approx_eq;

#[test]
fn linear_fitting_handles_errors() {
let x = Vector::from(&[1.0, 2.0]);
let y = Vector::from(&[6.0, 5.0, 7.0, 10.0]);
assert_eq!(
linear_fitting(&x, &y, false).err(),
Some("vectors must have the same dimension")
);
}

#[test]
fn linear_fitting_works() {
let x = Vector::from(&[1.0, 2.0, 3.0, 4.0]);
let y = Vector::from(&[6.0, 5.0, 7.0, 10.0]);

let (c, m) = linear_fitting(&x, &y, false).unwrap();
assert_eq!(c, 3.5);
assert_eq!(m, 1.4);

let (c, m) = linear_fitting(&x, &y, true).unwrap();
assert_eq!(c, 0.0);
approx_eq(m, 2.566666666666667, 1e-16);
}

#[test]
fn linear_fitting_handles_division_by_zero() {
let x = Vector::from(&[1.0, 1.0, 1.0, 1.0]);
let y = Vector::from(&[1.0, 2.0, 3.0, 4.0]);

let (c, m) = linear_fitting(&x, &y, false).unwrap();
assert_eq!(c, 0.0);
assert_eq!(m, f64::INFINITY);

let x = Vector::from(&[0.0, 0.0, 0.0, 0.0]);
let y = Vector::from(&[1.0, 2.0, 3.0, 4.0]);
let (c, m) = linear_fitting(&x, &y, true).unwrap();
assert_eq!(c, 0.0);
assert_eq!(m, f64::INFINITY);
}
}

0 comments on commit 67b7f4d

Please sign in to comment.