Skip to content

Commit d7a670b

Browse files
committed
Update predicted
1 parent 6f063da commit d7a670b

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/calc.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,10 @@ pub fn logistic_regression_irls(xs: MatRef<'_, f64>, ys: &[f64]) -> LogisticMode
670670
let r2 = R2Simd::new(ys, &mu).calculate();
671671
let adj_r2 = calculate_adj_r2(r2, ys.len(), xs.ncols());
672672

673+
if should_disable_predicted() {
674+
mu = Vec::new();
675+
}
676+
673677
LogisticModel {
674678
slopes,
675679
intercept,
@@ -769,15 +773,20 @@ pub fn logistic_regression_newton_raphson(xs: MatRef<'_, f64>, ys: &[f64]) -> Lo
769773
}
770774
beta.copy_from_slice(beta_new.try_as_slice().unwrap());
771775
}
772-
let predicted = (&x * faer::col::from_slice(beta.as_slice()))
773-
.try_as_slice()
774-
.unwrap()
775-
.iter()
776-
.map(|x| logistic(*x))
777-
.collect();
778776
let r2 = R2Simd::new(ys, &mu).calculate();
779777
let adj_r2 = calculate_adj_r2(r2, ys.len(), xs.ncols());
780778

779+
let predicted = if should_disable_predicted() {
780+
Vec::new()
781+
} else {
782+
(&x * faer::col::from_slice(beta.as_slice()))
783+
.try_as_slice()
784+
.unwrap()
785+
.iter()
786+
.map(|x| logistic(*x))
787+
.collect()
788+
};
789+
781790
LogisticModel {
782791
predicted,
783792
intercept: beta[x.ncols() - 1],

0 commit comments

Comments
 (0)