Skip to content

Commit

Permalink
use rayon to hopefully speed up linfa-logistic
Browse files Browse the repository at this point in the history
  • Loading branch information
droundy committed Jul 9, 2024
1 parent b807674 commit 854cb54
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
2 changes: 1 addition & 1 deletion algorithms/linfa-logistic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ optional = true
version = "1.0"

[dependencies]
ndarray = { version = "0.15", features = ["approx"] }
ndarray = { version = "0.15", features = ["rayon", "approx"] }
ndarray-stats = "0.5.0"
num-traits = "0.2"
argmin = { version = "0.9.0", default-features = false }
Expand Down
11 changes: 5 additions & 6 deletions algorithms/linfa-logistic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,9 @@ fn log_sum_exp<F: linfa::Float, A: Data<Elem = F>>(
/// Computes `exp(n - max) / sum(exp(n- max))`, which is a numerically stable version of softmax
fn softmax_inplace<F: linfa::Float, A: DataMut<Elem = F>>(v: &mut ArrayBase<A, Ix1>) {
let max = v.iter().copied().reduce(F::max).unwrap();
v.mapv_inplace(|n| (n - max).exp());
v.par_mapv_inplace(|n| (n - max).exp());
let sum = v.sum();
v.mapv_inplace(|n| n / sum);
v.par_mapv_inplace(|n| n / sum);
}

/// Computes the logistic loss assuming the training labels $y \in {-1, 1}$
Expand All @@ -479,7 +479,7 @@ fn logistic_loss<F: Float, A: Data<Elem = F>>(
let yz = x.dot(&params.into_shape((params.len(), 1)).unwrap()) + intercept;
let len = yz.len();
let mut yz = yz.into_shape(len).unwrap() * y;
yz.mapv_inplace(log_logistic);
yz.par_mapv_inplace(log_logistic);
-yz.sum() + F::cast(0.5) * alpha * params.dot(&params)
}

Expand All @@ -495,8 +495,7 @@ fn logistic_grad<F: Float, A: Data<Elem = F>>(
let yz = x.dot(&params.into_shape((params.len(), 1)).unwrap()) + intercept;
let len = yz.len();
let mut yz = yz.into_shape(len).unwrap() * y;
yz.mapv_inplace(logistic);
yz -= F::one();
yz.par_mapv_inplace(|v| logistic(v) - F::one());
yz *= y;
if w.len() == n_features + 1 {
let mut grad = Array::zeros(w.len());
Expand Down Expand Up @@ -624,7 +623,7 @@ impl<F: Float, C: PartialOrd + Clone> FittedLogisticRegression<F, C> {
/// model was fitted.
pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array1<F> {
let mut probs = x.dot(&self.params) + self.intercept;
probs.mapv_inplace(logistic);
probs.par_mapv_inplace(logistic);
probs
}
}
Expand Down

0 comments on commit 854cb54

Please sign in to comment.