Skip to content

Commit

Permalink
add benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
droundy committed Jul 10, 2024
1 parent 854cb54 commit 5f0fa9e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
6 changes: 6 additions & 0 deletions algorithms/linfa-logistic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ linfa-datasets = { version = "0.7.0", path = "../../datasets", features = [
"winequality",
] }
rmp-serde = "1"
criterion = "0.4.0"
rand = "0.8.5"

[[bench]]
name = "logistic_bench"
harness = false
59 changes: 59 additions & 0 deletions algorithms/linfa-logistic/benches/logistic_bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use linfa::prelude::*;
use ndarray::{Array1, Ix1};
use rand::{Rng, SeedableRng};

const MAX_ITERATIONS: u64 = 2;

fn train_model(
dataset: &Dataset<f32, bool, Ix1>,
) -> linfa_logistic::FittedLogisticRegression<f32, bool> {
linfa_logistic::LogisticRegression::default()
.max_iterations(MAX_ITERATIONS)
.fit(dataset)
.unwrap()
}

fn generate_categorical_data(nfeatures: usize, nsamples: usize) -> Dataset<f32, bool, Ix1> {
let mut rng = rand::rngs::SmallRng::seed_from_u64(42);
let mut feature_rows: Vec<Vec<f32>> = Vec::new();
let mut label_rows: Vec<bool> = Vec::new();
for _ in 0..nsamples {
let mut features = Vec::new();
for _ in 0..nfeatures {
let value = if rng.gen() { 1.0 } else { 0.0 };
features.push(value);
}
feature_rows.push(features);
label_rows.push(rng.gen());
}
linfa::Dataset::new(
ndarray::Array2::from_shape_vec(
(nsamples, nfeatures),
feature_rows.into_iter().flatten().collect(),
)
.unwrap(),
Array1::from_shape_vec(label_rows.len(), label_rows).unwrap(),
)
}

fn bench(c: &mut Criterion) {
let mut group = c.benchmark_group("Logistic regression");
group.measurement_time(std::time::Duration::from_secs(10)).sample_size(10);
for nfeatures in [1_000] {
for nsamples in [1_000, 10_000, 100_000, 200_000, 500_000, 1_000_000] {
let input = generate_categorical_data(nfeatures, nsamples);
group.bench_with_input(
BenchmarkId::new("train_model", format!("{:e}x{:e}", nfeatures as f64, nsamples as f64)),
&input,
|b, dataset| {
b.iter(|| train_model(dataset));
},
);
}
}
group.finish();
}

criterion_group!(benches, bench);
criterion_main!(benches);

0 comments on commit 5f0fa9e

Please sign in to comment.