Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files — Browse the repository at this point in the history
  • Loading branch information
quietlychris authored Mar 30, 2024
2 parents 6b9c2a4 + a29abe8 commit 23407e3
Show file tree
Hide file tree
Showing 17 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checking.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- 1.67.0
- 1.70.0
- stable
- nightly
os:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/codequality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
toolchain:
- 1.67.0
- 1.70.0
- stable

steps:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- 1.67.0
- 1.70.0
- stable
os:
- ubuntu-latest
Expand All @@ -35,7 +35,7 @@ jobs:
fail-fast: false
matrix:
toolchain:
- 1.67.0
- 1.70.0
- stable
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion algorithms/linfa-bayes/src/gaussian_nb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ where
let nclass = xclass.nrows();

// We compute the update of the gaussian mean and variance
let mut class_info = model
let class_info = model
.class_info
.entry(class)
.or_insert_with(GaussianClassInfo::default);
Expand Down
2 changes: 1 addition & 1 deletion algorithms/linfa-bayes/src/multinomial_nb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ where
let nclass = xclass.nrows();

// We compute the feature log probabilities and feature counts on the slice corresponding to the current class
let mut class_info = model
let class_info = model
.class_info
.entry(class)
.or_insert_with(MultinomialClassInfo::default);
Expand Down
14 changes: 7 additions & 7 deletions algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::gaussian_mixture::errors::{GmmError, Result};
use crate::gaussian_mixture::errors::GmmError;
use crate::gaussian_mixture::hyperparams::{
GmmCovarType, GmmInitMethod, GmmParams, GmmValidParams,
};
Expand Down Expand Up @@ -126,7 +126,7 @@ impl<F: Float> GaussianMixtureModel<F> {
hyperparameters: &GmmValidParams<F, R>,
dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
mut rng: R,
) -> Result<GaussianMixtureModel<F>> {
) -> Result<GaussianMixtureModel<F>, GmmError> {
let observations = dataset.records().view();
let n_samples = observations.nrows();

Expand Down Expand Up @@ -216,7 +216,7 @@ impl<F: Float> GaussianMixtureModel<F> {
resp: &Array2<F>,
_covar_type: &GmmCovarType,
reg_covar: F,
) -> Result<(Array1<F>, Array2<F>, Array3<F>)> {
) -> Result<(Array1<F>, Array2<F>, Array3<F>), GmmError> {
let nk = resp.sum_axis(Axis(0));
if nk.min()? < &(F::cast(10.) * F::epsilon()) {
return Err(GmmError::EmptyCluster(format!(
Expand Down Expand Up @@ -255,7 +255,7 @@ impl<F: Float> GaussianMixtureModel<F> {

fn compute_precisions_cholesky_full<D: Data<Elem = F>>(
covariances: &ArrayBase<D, Ix3>,
) -> Result<Array3<F>> {
) -> Result<Array3<F>, GmmError> {
let n_clusters = covariances.shape()[0];
let n_features = covariances.shape()[1];
let mut precisions_chol = Array::zeros((n_clusters, n_features, n_features));
Expand Down Expand Up @@ -290,7 +290,7 @@ impl<F: Float> GaussianMixtureModel<F> {
fn e_step<D: Data<Elem = F>>(
&self,
observations: &ArrayBase<D, Ix2>,
) -> Result<(F, Array2<F>)> {
) -> Result<(F, Array2<F>), GmmError> {
let (log_prob_norm, log_resp) = self.estimate_log_prob_resp(observations);
let log_mean = log_prob_norm.mean().unwrap();
Ok((log_mean, log_resp))
Expand All @@ -301,7 +301,7 @@ impl<F: Float> GaussianMixtureModel<F> {
reg_covar: F,
observations: &ArrayBase<D, Ix2>,
log_resp: &Array2<F>,
) -> Result<()> {
) -> Result<(), GmmError> {
let n_samples = observations.nrows();
let (weights, means, covariances) = Self::estimate_gaussian_parameters(
observations,
Expand Down Expand Up @@ -407,7 +407,7 @@ impl<F: Float, R: Rng + Clone, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, G
{
type Object = GaussianMixtureModel<F>;

fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, GmmError> {
let observations = dataset.records().view();
let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;

Expand Down
1 change: 0 additions & 1 deletion algorithms/linfa-clustering/src/gaussian_mixture/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use linfa_linalg::LinalgError;
#[cfg(feature = "blas")]
use ndarray_linalg::error::LinalgError;
use thiserror::Error;
pub type Result<T> = std::result::Result<T, GmmError>;

/// An error when modeling a GMM algorithm
#[derive(Error, Debug)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::gaussian_mixture::errors::{GmmError, Result};
use crate::gaussian_mixture::errors::GmmError;
use ndarray_rand::rand::{Rng, SeedableRng};
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
Expand Down Expand Up @@ -170,7 +170,7 @@ impl<F: Float, R: Rng> ParamGuard for GmmParams<F, R> {
type Checked = GmmValidParams<F, R>;
type Error = GmmError;

fn check_ref(&self) -> Result<&Self::Checked> {
fn check_ref(&self) -> Result<&Self::Checked, GmmError> {
if self.0.n_clusters == 0 {
Err(GmmError::InvalidValue(
"`n_clusters` cannot be 0!".to_string(),
Expand All @@ -194,7 +194,7 @@ impl<F: Float, R: Rng> ParamGuard for GmmParams<F, R> {
}
}

fn check(self) -> Result<Self::Checked> {
fn check(self) -> Result<Self::Checked, GmmError> {
self.check_ref()?;
Ok(self.0)
}
Expand Down
1 change: 0 additions & 1 deletion algorithms/linfa-clustering/src/optics/errors.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use thiserror::Error;
pub type Result<T> = std::result::Result<T, OpticsError>;

/// An error when performing OPTICS Analysis
#[derive(Error, Debug)]
Expand Down
6 changes: 3 additions & 3 deletions algorithms/linfa-clustering/src/optics/hyperparams.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::optics::errors::{OpticsError, Result};
use crate::optics::errors::OpticsError;
use linfa::{param_guard::TransformGuard, Float, ParamGuard};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
Expand Down Expand Up @@ -91,7 +91,7 @@ impl<F: Float, D, N> ParamGuard for OpticsParams<F, D, N> {
type Checked = OpticsValidParams<F, D, N>;
type Error = OpticsError;

fn check_ref(&self) -> Result<&Self::Checked> {
fn check_ref(&self) -> Result<&Self::Checked, OpticsError> {
if self.0.tolerance <= F::zero() {
Err(OpticsError::InvalidValue(
"`tolerance` must be greater than 0!".to_string(),
Expand All @@ -106,7 +106,7 @@ impl<F: Float, D, N> ParamGuard for OpticsParams<F, D, N> {
}
}

fn check(self) -> Result<Self::Checked> {
fn check(self) -> Result<Self::Checked, OpticsError> {
self.check_ref()?;
Ok(self.0)
}
Expand Down
4 changes: 3 additions & 1 deletion src/benchmarks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#[cfg(feature = "benchmarks")]
pub mod config {
use criterion::{measurement::WallTime, BenchmarkGroup, Criterion};
#[cfg(not(target_os = "windows"))]
use criterion::Criterion;
use criterion::{measurement::WallTime, BenchmarkGroup};
#[cfg(not(target_os = "windows"))]
use pprof::criterion::{Output, PProfProfiler};
use std::time::Duration;
Expand Down

0 comments on commit 23407e3

Please sign in to comment.