From b1e95cbedd5ad6887df878fa0ff651c831d81f23 Mon Sep 17 00:00:00 2001
From: relf
Date: Sun, 4 Feb 2024 21:09:45 +0100
Subject: [PATCH 1/2] Avoid reference on option

---
 ego/src/egor_solver.rs      |  8 ++++----
 gp/src/parameters.rs        |  4 ++--
 gp/src/sparse_parameters.rs |  8 ++++----
 moe/src/gp_algorithm.rs     |  2 +-
 moe/src/gp_parameters.rs    | 12 ++++++------
 moe/src/sgp_algorithm.rs    |  2 +-
 moe/src/sgp_parameters.rs   | 12 ++++++------
 7 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/ego/src/egor_solver.rs b/ego/src/egor_solver.rs
index 8984ac41..dcd7a4cb 100644
--- a/ego/src/egor_solver.rs
+++ b/ego/src/egor_solver.rs
@@ -541,7 +541,7 @@ where
         yt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
         init: bool,
         recluster: bool,
-        clustering: &Option<Clustering>,
+        clustering: Option<&Clustering>,
         model_name: &str,
     ) -> Box<dyn ClusteredSurrogate> {
         let mut builder = self.surrogate_builder.clone();
@@ -567,9 +567,9 @@ where
             );
             model
         } else {
-            let clustering = clustering.as_ref().unwrap().clone();
+            let clustering = clustering.unwrap();
            let model = builder
-                .train_on_clusters(&xt.view(), &yt.view(), &clustering)
+                .train_on_clusters(&xt.view(), &yt.view(), clustering)
                 .expect("GP training failure");
             model
         }
@@ -615,7 +615,7 @@ where
                     &yt.slice(s![.., k..k + 1]).to_owned(),
                     init && i == 0,
                     recluster,
-                    &clusterings[k],
+                    clusterings[k].as_ref(),
                     &name,
                 )
             })
diff --git a/gp/src/parameters.rs b/gp/src/parameters.rs
index 9da31ba8..b1dea91d 100644
--- a/gp/src/parameters.rs
+++ b/gp/src/parameters.rs
@@ -109,8 +109,8 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GpValidParam
     }

     /// Get number of components used by PLS
-    pub fn kpls_dim(&self) -> &Option<usize> {
-        &self.kpls_dim
+    pub fn kpls_dim(&self) -> Option<&usize> {
+        self.kpls_dim.as_ref()
     }

     /// Get the number of internal optimization restart
diff --git a/gp/src/sparse_parameters.rs b/gp/src/sparse_parameters.rs
index c58b5c9f..ac6719df 100644
--- a/gp/src/sparse_parameters.rs
+++ b/gp/src/sparse_parameters.rs
@@ -86,8 +86,8 @@ impl<F: Float, Corr: CorrelationModel<F>> SgpValidParams<F, Corr> {
     }

     /// Get the number of components used by PLS
-    pub fn kpls_dim(&self) -> &Option<usize> {
-        &self.gp_params.kpls_dim
+    pub fn kpls_dim(&self) -> Option<&usize> {
+        self.gp_params.kpls_dim.as_ref()
     }

     /// Get starting theta value for optimization
@@ -121,8 +121,8 @@ impl<F: Float, Corr: CorrelationModel<F>> SgpValidParams<F, Corr> {
     }

     /// Get seed
-    pub fn seed(&self) -> &Option<u64> {
-        &self.seed
+    pub fn seed(&self) -> Option<&u64> {
+        self.seed.as_ref()
     }
 }
diff --git a/moe/src/gp_algorithm.rs b/moe/src/gp_algorithm.rs
index d1dce6ee..ea452044 100644
--- a/moe/src/gp_algorithm.rs
+++ b/moe/src/gp_algorithm.rs
@@ -109,7 +109,7 @@ impl GpMixValidParams {
         let dataset = Dataset::from(training);

         let gmx = if self.gmx().is_some() {
-            *self.gmx().as_ref().unwrap().clone()
+            self.gmx().unwrap().clone()
         } else {
             trace!("GMM training...");
             let gmm = GaussianMixtureModel::params(n_clusters)
diff --git a/moe/src/gp_parameters.rs b/moe/src/gp_parameters.rs
index cfa63d5d..68f6edba 100644
--- a/moe/src/gp_parameters.rs
+++ b/moe/src/gp_parameters.rs
@@ -100,13 +100,13 @@ impl GpMixValidParams {

     /// An optional gaussian mixture to be fitted to generate multivariate normal
     /// in turns used to cluster
-    pub fn gmm(&self) -> &Option<Box<GaussianMixtureModel<f64>>> {
-        &self.gmm
+    pub fn gmm(&self) -> Option<GaussianMixtureModel<f64>> {
+        self.gmm.as_ref().map(|gmm| *gmm.clone())
     }

     /// An optional multivariate normal used to cluster (take precedence over gmm)
-    pub fn gmx(&self) -> &Option<Box<GaussianMixture<f64>>> {
-        &self.gmx
+    pub fn gmx(&self) -> Option<GaussianMixture<f64>> {
+        self.gmx.as_ref().map(|gmx| *gmx.clone())
     }

     /// The random generator
@@ -262,8 +262,8 @@ impl GpMixParams {
             kpls_dim: self.0.kpls_dim(),
             theta_tuning: self.0.theta_tuning().clone(),
             n_start: self.0.n_start(),
-            gmm: self.0.gmm().clone(),
-            gmx: self.0.gmx().clone(),
+            gmm: self.0.gmm().map(Box::new),
+            gmx: self.0.gmx().map(Box::new),
             rng,
         })
     }
diff --git a/moe/src/sgp_algorithm.rs b/moe/src/sgp_algorithm.rs
index 49936385..bd2c3993 100644
--- a/moe/src/sgp_algorithm.rs
+++ b/moe/src/sgp_algorithm.rs
@@ -110,7 +110,7 @@ impl SparseGpMixtureValidParams {
         let dataset = Dataset::from(training);

         let gmx = if self.gmx().is_some() {
-            *self.gmx().as_ref().unwrap().clone()
+            self.gmx().unwrap().clone()
         } else {
             debug!("GMM training...");
             let gmm = GaussianMixtureModel::params(n_clusters)
diff --git a/moe/src/sgp_parameters.rs b/moe/src/sgp_parameters.rs
index ea4c7c59..cf43a5b4 100644
--- a/moe/src/sgp_parameters.rs
+++ b/moe/src/sgp_parameters.rs
@@ -116,13 +116,13 @@ impl SparseGpMixtureValidParams {

     /// An optional gaussian mixture to be fitted to generate multivariate normal
     /// in turns used to cluster
-    pub fn gmm(&self) -> &Option<Box<GaussianMixtureModel<f64>>> {
-        &self.gmm
+    pub fn gmm(&self) -> Option<GaussianMixtureModel<f64>> {
+        self.gmm.as_ref().map(|gmm| *gmm.clone())
     }

     /// An optional multivariate normal used to cluster (take precedence over gmm)
-    pub fn gmx(&self) -> &Option<Box<GaussianMixture<f64>>> {
-        &self.gmx
+    pub fn gmx(&self) -> Option<GaussianMixture<f64>> {
+        self.gmx.as_ref().map(|gmx| *gmx.clone())
     }

     /// The random generator
@@ -288,8 +288,8 @@ impl SparseGpMixtureParams {
             n_start: self.0.n_start(),
             sparse_method: self.0.sparse_method(),
             inducings: self.0.inducings().clone(),
-            gmm: self.0.gmm().clone(),
-            gmx: self.0.gmx().clone(),
+            gmm: self.0.gmm().map(Box::new),
+            gmx: self.0.gmx().map(Box::new),
             rng,
         })
     }

From f4bd54a8ee3086b267e63c58e5069da02a0c1d5f Mon Sep 17 00:00:00 2001
From: relf
Date: Sun, 4 Feb 2024 21:33:46 +0100
Subject: [PATCH 2/2] Remove useless box indirection

---
 moe/src/clustering.rs     |  3 +--
 moe/src/gp_parameters.rs  | 24 +++++++++++-------------
 moe/src/sgp_parameters.rs | 20 +++++++++-----------
 3 files changed, 21 insertions(+), 26 deletions(-)

diff --git a/moe/src/clustering.rs b/moe/src/clustering.rs
index 53eb8c68..070151e3 100644
--- a/moe/src/clustering.rs
+++ b/moe/src/clustering.rs
@@ -119,7 +119,6 @@ pub fn find_best_number_of_clusters(
             .ok();

         if let Some(gmm) = maybe_gmm {
-            let gmm = Box::new(gmm);
             // Cross Validation
             for (train, valid) in dataset.fold(5).into_iter() {
                 if let Ok(mixture) = GpMixParams::default()
@@ -127,7 +126,7 @@ pub fn find_best_number_of_clusters(
                     .regression_spec(regression_spec)
                     .correlation_spec(correlation_spec)
                     .kpls_dim(kpls_dim)
-                    .gmm(Some(gmm.clone()))
+                    .gmm(gmm.clone())
                     .fit(&train)
                 {
                     let xytrain =
diff --git a/moe/src/gp_parameters.rs b/moe/src/gp_parameters.rs
index 68f6edba..0dc8d6a3 100644
--- a/moe/src/gp_parameters.rs
+++ b/moe/src/gp_parameters.rs
@@ -38,9 +38,9 @@ pub struct GpMixValidParams {
     /// Number of GP hyperparameters optimization restarts
     n_start: usize,
     /// Gaussian Mixture model used to cluster
-    gmm: Option<Box<GaussianMixtureModel<f64>>>,
+    gmm: Option<GaussianMixtureModel<f64>>,
     /// GaussianMixture preset
-    gmx: Option<Box<GaussianMixture<f64>>>,
+    gmx: Option<GaussianMixture<f64>>,
     /// Random number generator
     rng: R,
 }
@@ -100,13 +100,13 @@ impl GpMixValidParams {

     /// An optional gaussian mixture to be fitted to generate multivariate normal
     /// in turns used to cluster
-    pub fn gmm(&self) -> Option<GaussianMixtureModel<f64>> {
-        self.gmm.as_ref().map(|gmm| *gmm.clone())
+    pub fn gmm(&self) -> Option<&GaussianMixtureModel<f64>> {
+        self.gmm.as_ref()
     }

     /// An optional multivariate normal used to cluster (take precedence over gmm)
-    pub fn gmx(&self) -> Option<GaussianMixture<f64>> {
-        self.gmx.as_ref().map(|gmx| *gmx.clone())
+    pub fn gmx(&self) -> Option<&GaussianMixture<f64>> {
+        self.gmx.as_ref()
     }

     /// The random generator
@@ -236,8 +236,8 @@ impl GpMixParams {

     #[doc(hidden)]
     /// Sets the gaussian mixture (used to find the optimal number of clusters)
-    pub(crate) fn gmm(mut self, gmm: Option<Box<GaussianMixtureModel<f64>>>) -> Self {
-        self.0.gmm = gmm;
+    pub(crate) fn gmm(mut self, gmm: GaussianMixtureModel<f64>) -> Self {
+        self.0.gmm = Some(gmm);
         self
     }
@@ -246,9 +246,7 @@ impl GpMixParams {
     /// Warning: no consistency check is done on the given initialization data
     /// *Panic* if multivariate normal init data not sound
     pub fn gmx(mut self, weights: Array1<f64>, means: Array2<f64>, covariances: Array3<f64>) -> Self {
-        self.0.gmx = Some(Box::new(
-            GaussianMixture::new(weights, means, covariances).unwrap(),
-        ));
+        self.0.gmx = Some(GaussianMixture::new(weights, means, covariances).unwrap());
         self
     }
@@ -262,8 +260,8 @@ impl GpMixParams {
             kpls_dim: self.0.kpls_dim(),
             theta_tuning: self.0.theta_tuning().clone(),
             n_start: self.0.n_start(),
-            gmm: self.0.gmm().map(Box::new),
-            gmx: self.0.gmx().map(Box::new),
+            gmm: self.0.gmm().cloned(),
+            gmx: self.0.gmx().cloned(),
             rng,
         })
     }
diff --git a/moe/src/sgp_parameters.rs b/moe/src/sgp_parameters.rs
index cf43a5b4..6e113b5e 100644
--- a/moe/src/sgp_parameters.rs
+++ b/moe/src/sgp_parameters.rs
@@ -42,9 +42,9 @@ pub struct SparseGpMixtureValidParams {
     /// Inducings
     inducings: Inducings<f64>,
     /// Gaussian Mixture model used to cluster
-    gmm: Option<Box<GaussianMixtureModel<f64>>>,
+    gmm: Option<GaussianMixtureModel<f64>>,
     /// GaussianMixture preset
-    gmx: Option<Box<GaussianMixture<f64>>>,
+    gmx: Option<GaussianMixture<f64>>,
     /// Random number generator
     rng: R,
 }
@@ -116,13 +116,13 @@ impl SparseGpMixtureValidParams {

     /// An optional gaussian mixture to be fitted to generate multivariate normal
     /// in turns used to cluster
-    pub fn gmm(&self) -> Option<GaussianMixtureModel<f64>> {
-        self.gmm.as_ref().map(|gmm| *gmm.clone())
+    pub fn gmm(&self) -> Option<&GaussianMixtureModel<f64>> {
+        self.gmm.as_ref()
     }

     /// An optional multivariate normal used to cluster (take precedence over gmm)
-    pub fn gmx(&self) -> Option<GaussianMixture<f64>> {
-        self.gmx.as_ref().map(|gmx| *gmx.clone())
+    pub fn gmx(&self) -> Option<&GaussianMixture<f64>> {
+        self.gmx.as_ref()
     }

     /// The random generator
@@ -270,9 +270,7 @@ impl SparseGpMixtureParams {
     /// Warning: no consistency check is done on the given initialization data
     /// *Panic* if multivariate normal init data not sound
     pub fn gmx(mut self, weights: Array1<f64>, means: Array2<f64>, covariances: Array3<f64>) -> Self {
-        self.0.gmx = Some(Box::new(
-            GaussianMixture::new(weights, means, covariances).unwrap(),
-        ));
+        self.0.gmx = Some(GaussianMixture::new(weights, means, covariances).unwrap());
         self
     }
@@ -288,8 +286,8 @@ impl SparseGpMixtureParams {
             n_start: self.0.n_start(),
             sparse_method: self.0.sparse_method(),
             inducings: self.0.inducings().clone(),
-            gmm: self.0.gmm().map(Box::new),
-            gmx: self.0.gmx().map(Box::new),
+            gmm: self.0.gmm().cloned(),
+            gmx: self.0.gmx().cloned(),
             rng,
         })
     }
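
Outside the patches themselves, here is a minimal sketch of the pattern PATCH 1 applies. The `Config` and `label` names are made up for illustration; the point is the one the diff makes for `kpls_dim`, `seed`, and `clustering`: returning `Option<&T>` instead of `&Option<T>` composes directly with `Option` adapters, and taking `Option<&T>` as a parameter lets callers pass `value.as_ref()` without cloning.

```rust
// Hypothetical types, not from the crate: why `Option<&T>` beats
// `&Option<T>` as a getter return type.
struct Config {
    label: Option<String>,
}

impl Config {
    // Before: the caller gets a reference to the Option itself and must
    // still call `.as_ref()` before using any adapter.
    fn label_old(&self) -> &Option<String> {
        &self.label
    }

    // After: `Option::as_ref` converts `&Option<String>` into
    // `Option<&String>`, ready for `map`, `cloned`, `unwrap_or`, etc.
    fn label_new(&self) -> Option<&String> {
        self.label.as_ref()
    }
}

fn main() {
    let c = Config {
        label: Some("egobox".to_string()),
    };

    // With `Option<&T>`, adapters chain without touching the Option first.
    let upper = c.label_new().map(|s| s.to_uppercase());
    assert_eq!(upper.as_deref(), Some("EGOBOX"));

    // `&Option<T>` forces an extra `as_ref()` at every call site instead.
    let upper_old = c.label_old().as_ref().map(|s| s.to_uppercase());
    assert_eq!(upper, upper_old);
}
```

The same reasoning drives the `clusterings[k].as_ref()` call site in `egor_solver.rs`: the `Option<Clustering>` stays owned where it lives, and only a borrow crosses the function boundary.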
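Likewise for PATCH 2, a sketch with made-up types (`Mixture`, `ParamsBoxed`, `Params`) of what removing the `Option<Box<T>>` indirection buys: the boxed field forces the `*gmm.clone()` contortion seen in the old accessors, while the unboxed field borrows cleanly and `Option::cloned` recovers an owned value at the single point `build()` needs one.

```rust
// Hypothetical stand-in for GaussianMixtureModel<f64>.
#[derive(Clone, Debug, PartialEq)]
struct Mixture {
    n_clusters: usize,
}

struct ParamsBoxed {
    gmm: Option<Box<Mixture>>,
}

struct Params {
    gmm: Option<Mixture>,
}

impl ParamsBoxed {
    // The boxed field forces double indirection and an awkward accessor:
    // clone the Box, then deref it to get the value back out.
    fn gmm(&self) -> Option<Mixture> {
        self.gmm.as_ref().map(|g| *g.clone())
    }
}

impl Params {
    // Without the Box, a borrow suffices; callers decide whether to clone.
    fn gmm(&self) -> Option<&Mixture> {
        self.gmm.as_ref()
    }
}

fn main() {
    let boxed = ParamsBoxed {
        gmm: Some(Box::new(Mixture { n_clusters: 3 })),
    };
    let plain = Params {
        gmm: Some(Mixture { n_clusters: 3 }),
    };

    // `Option::cloned` turns `Option<&Mixture>` into `Option<Mixture>`,
    // the same shape as the `self.0.gmm().cloned()` pattern in `build()`.
    assert_eq!(boxed.gmm(), plain.gmm().cloned());
}
```

A `Box` only pays for itself when the value is large and moved around often; here the params struct is cloned wholesale during `build()` anyway, so the extra heap hop bought nothing, which is what the commit title calls a useless indirection.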