Rework API to avoid awkward &Option #134

Merged 2 commits on Feb 5, 2024
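The theme of this PR: getters and parameters of type &Option<T> (or &Option<Box<T>>) force every caller through an extra as_ref()/unwrap dance, while Option<&T> composes directly with the Option combinators. A minimal, self-contained sketch of the idiom; the Config/label names are illustrative and not taken from the egobox crates:

struct Config {
    label: Option<String>,
}

impl Config {
    // Awkward: the caller receives a reference to the whole Option and
    // still has to call .as_ref() before using the value inside.
    fn label_old(&self) -> &Option<String> {
        &self.label
    }

    // Reworked: Option::as_ref() turns &Option<String> into Option<&String>,
    // which is directly usable with map, unwrap, cloned, etc.
    fn label_new(&self) -> Option<&String> {
        self.label.as_ref()
    }
}

fn main() {
    let cfg = Config { label: Some("gp".to_string()) };
    let via_old = cfg.label_old().as_ref().map(|s| s.len()); // extra as_ref()
    let via_new = cfg.label_new().map(|s| s.len());
    assert_eq!(via_old, via_new);
}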
ego/src/egor_solver.rs (8 changes: 4 additions & 4 deletions)
@@ -541,7 +541,7 @@ where
         yt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
         init: bool,
         recluster: bool,
-        clustering: &Option<Clustering>,
+        clustering: Option<&Clustering>,
         model_name: &str,
     ) -> Box<dyn ClusteredSurrogate> {
         let mut builder = self.surrogate_builder.clone();
@@ -567,9 +567,9 @@ where
             );
             model
         } else {
-            let clustering = clustering.as_ref().unwrap().clone();
+            let clustering = clustering.unwrap();
             let model = builder
-                .train_on_clusters(&xt.view(), &yt.view(), &clustering)
+                .train_on_clusters(&xt.view(), &yt.view(), clustering)
                 .expect("GP training failure");
             model
         }
@@ -615,7 +615,7 @@ where
                         &yt.slice(s![.., k..k + 1]).to_owned(),
                         init && i == 0,
                         recluster,
-                        &clusterings[k],
+                        clusterings[k].as_ref(),
                         &name,
                     )
                 })
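At the call site above, the surrogate builder now takes Option<&Clustering>, so an element of a Vec<Option<Clustering>> is passed with as_ref() instead of borrowing the whole Option and cloning inside. A self-contained sketch with Clustering stubbed out and make_surrogate standing in for the real method:

#[derive(Clone)]
struct Clustering;

// Stand-in for the builder call: it only needs to borrow the clustering,
// so Option<&Clustering> avoids any clone on this path.
fn make_surrogate(clustering: Option<&Clustering>) {
    if let Some(c) = clustering {
        let _ = c; // builder.train_on_clusters(..., c) in the real code
    }
}

fn main() {
    let clusterings: Vec<Option<Clustering>> = vec![Some(Clustering), None];
    for k in 0..clusterings.len() {
        // Option::as_ref: &Option<Clustering> -> Option<&Clustering>
        make_surrogate(clusterings[k].as_ref());
    }
}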
gp/src/parameters.rs (4 changes: 2 additions & 2 deletions)
@@ -109,8 +109,8 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GpValidParam
     }

     /// Get number of components used by PLS
-    pub fn kpls_dim(&self) -> &Option<usize> {
-        &self.kpls_dim
+    pub fn kpls_dim(&self) -> Option<&usize> {
+        self.kpls_dim.as_ref()
     }

     /// Get the number of internal optimization restart
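For a Copy payload like usize, callers of the new kpls_dim() can recover a plain value with Option::copied(). A sketch on a stand-in struct, not the real GpValidParams:

struct Params {
    kpls_dim: Option<usize>,
}

impl Params {
    // Same getter shape as kpls_dim() above.
    fn kpls_dim(&self) -> Option<&usize> {
        self.kpls_dim.as_ref()
    }
}

fn main() {
    let p = Params { kpls_dim: Some(3) };
    // Option<&usize> -> Option<usize> for Copy payloads.
    assert_eq!(p.kpls_dim().copied(), Some(3));
}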
gp/src/sparse_parameters.rs (8 changes: 4 additions & 4 deletions)
@@ -86,8 +86,8 @@ impl<F: Float, Corr: CorrelationModel<F>> SgpValidParams<F, Corr> {
     }

     /// Get the number of components used by PLS
-    pub fn kpls_dim(&self) -> &Option<usize> {
-        &self.gp_params.kpls_dim
+    pub fn kpls_dim(&self) -> Option<&usize> {
+        self.gp_params.kpls_dim.as_ref()
     }

     /// Get starting theta value for optimization
@@ -121,8 +121,8 @@ impl<F: Float, Corr: CorrelationModel<F>> SgpValidParams<F, Corr> {
     }

     /// Get seed
-    pub fn seed(&self) -> &Option<u64> {
-        &self.seed
+    pub fn seed(&self) -> Option<&u64> {
+        self.seed.as_ref()
     }
 }

moe/src/clustering.rs (3 changes: 1 addition & 2 deletions)
@@ -119,15 +119,14 @@ pub fn find_best_number_of_clusters<R: Rng + Clone>(
         .ok();

     if let Some(gmm) = maybe_gmm {
-        let gmm = Box::new(gmm);
         // Cross Validation
         for (train, valid) in dataset.fold(5).into_iter() {
             if let Ok(mixture) = GpMixParams::default()
                 .n_clusters(n_clusters)
                 .regression_spec(regression_spec)
                 .correlation_spec(correlation_spec)
                 .kpls_dim(kpls_dim)
-                .gmm(Some(gmm.clone()))
+                .gmm(gmm.clone())
                 .fit(&train)
             {
                 let xytrain =
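With the setter now taking the model by value, the cross-validation loop clones the fitted GMM once per fold and no longer wraps it in Some(Box::new(...)). A sketch of that builder shape, with Gmm standing in for GaussianMixtureModel<F>:

#[derive(Clone)]
struct Gmm;

#[derive(Default)]
struct Builder {
    gmm: Option<Gmm>,
}

impl Builder {
    // The Some(...) wrapping now lives inside the setter.
    fn gmm(mut self, gmm: Gmm) -> Self {
        self.gmm = Some(gmm);
        self
    }
}

fn main() {
    let gmm = Gmm; // previously Box::new(gmm)
    for _fold in 0..5 {
        // One clone per fold; no Some(..)/Box noise at the call site.
        let _b = Builder::default().gmm(gmm.clone());
    }
}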
moe/src/gp_algorithm.rs (2 changes: 1 addition & 1 deletion)
@@ -109,7 +109,7 @@ impl<R: Rng + SeedableRng + Clone> GpMixValidParams<f64, R> {
         let dataset = Dataset::from(training);

         let gmx = if self.gmx().is_some() {
-            *self.gmx().as_ref().unwrap().clone()
+            self.gmx().unwrap().clone()
         } else {
             trace!("GMM training...");
             let gmm = GaussianMixtureModel::params(n_clusters)
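The one-line simplification above falls out of the getter change: with gmx() returning Option<&T> rather than &Option<Box<T>>, getting an owned copy is unwrap() plus clone(), with no leading dereference. A sketch on stand-in types:

#[derive(Clone, Debug, PartialEq)]
struct Gmx;

struct Params {
    gmx: Option<Gmx>,
}

impl Params {
    fn gmx(&self) -> Option<&Gmx> {
        self.gmx.as_ref()
    }
}

fn main() {
    let p = Params { gmx: Some(Gmx) };
    let owned = if p.gmx().is_some() {
        // Was: *self.gmx().as_ref().unwrap().clone()
        p.gmx().unwrap().clone()
    } else {
        Gmx // train a fresh GMM in the real code
    };
    assert_eq!(owned, Gmx);
}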
moe/src/gp_parameters.rs (24 changes: 11 additions & 13 deletions)
@@ -38,9 +38,9 @@ pub struct GpMixValidParams<F: Float, R: Rng + Clone> {
     /// Number of GP hyperparameters optimization restarts
     n_start: usize,
     /// Gaussian Mixture model used to cluster
-    gmm: Option<Box<GaussianMixtureModel<F>>>,
+    gmm: Option<GaussianMixtureModel<F>>,
     /// GaussianMixture preset
-    gmx: Option<Box<GaussianMixture<F>>>,
+    gmx: Option<GaussianMixture<F>>,
     /// Random number generator
     rng: R,
 }
@@ -100,13 +100,13 @@ impl<F: Float, R: Rng + Clone> GpMixValidParams<F, R> {

     /// An optional gaussian mixture to be fitted to generate multivariate normal
     /// in turns used to cluster
-    pub fn gmm(&self) -> &Option<Box<GaussianMixtureModel<F>>> {
-        &self.gmm
+    pub fn gmm(&self) -> Option<&GaussianMixtureModel<F>> {
+        self.gmm.as_ref()
     }

     /// An optional multivariate normal used to cluster (take precedence over gmm)
-    pub fn gmx(&self) -> &Option<Box<GaussianMixture<F>>> {
-        &self.gmx
+    pub fn gmx(&self) -> Option<&GaussianMixture<F>> {
+        self.gmx.as_ref()
     }

     /// The random generator
@@ -236,8 +236,8 @@ impl<F: Float, R: Rng + Clone> GpMixParams<F, R> {

     #[doc(hidden)]
     /// Sets the gaussian mixture (used to find the optimal number of clusters)
-    pub(crate) fn gmm(mut self, gmm: Option<Box<GaussianMixtureModel<F>>>) -> Self {
-        self.0.gmm = gmm;
+    pub(crate) fn gmm(mut self, gmm: GaussianMixtureModel<F>) -> Self {
+        self.0.gmm = Some(gmm);
         self
     }

@@ -246,9 +246,7 @@ impl<F: Float, R: Rng + Clone> GpMixParams<F, R> {
     /// Warning: no consistency check is done on the given initialization data
     /// *Panic* if multivariate normal init data not sound
     pub fn gmx(mut self, weights: Array1<F>, means: Array2<F>, covariances: Array3<F>) -> Self {
-        self.0.gmx = Some(Box::new(
-            GaussianMixture::new(weights, means, covariances).unwrap(),
-        ));
+        self.0.gmx = Some(GaussianMixture::new(weights, means, covariances).unwrap());
         self
     }

@@ -262,8 +260,8 @@ impl<F: Float, R: Rng + Clone> GpMixParams<F, R> {
             kpls_dim: self.0.kpls_dim(),
             theta_tuning: self.0.theta_tuning().clone(),
             n_start: self.0.n_start(),
-            gmm: self.0.gmm().clone(),
-            gmx: self.0.gmx().clone(),
+            gmm: self.0.gmm().cloned(),
+            gmx: self.0.gmx().cloned(),
             rng,
         })
     }
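In the struct-building code above, Option::cloned() converts the Option<&T> returned by the new getters back into the owned Option<T> the struct stores; previously this was .clone() on a &Option<Box<T>>. A minimal demonstration:

fn main() {
    let gmm: Option<String> = Some("gmm".to_string());
    // What a getter like self.0.gmm() now returns:
    let borrowed: Option<&String> = gmm.as_ref();
    // Option::cloned(): Option<&String> -> Option<String>
    let owned: Option<String> = borrowed.cloned();
    assert_eq!(owned, gmm);
}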
moe/src/sgp_algorithm.rs (2 changes: 1 addition & 1 deletion)
@@ -110,7 +110,7 @@ impl<R: Rng + SeedableRng + Clone> SparseGpMixtureValidParams<f64, R> {
         let dataset = Dataset::from(training);

         let gmx = if self.gmx().is_some() {
-            *self.gmx().as_ref().unwrap().clone()
+            self.gmx().unwrap().clone()
         } else {
             debug!("GMM training...");
             let gmm = GaussianMixtureModel::params(n_clusters)
moe/src/sgp_parameters.rs (20 changes: 9 additions & 11 deletions)
@@ -42,9 +42,9 @@ pub struct SparseGpMixtureValidParams<F: Float, R: Rng + Clone> {
     /// Inducings
     inducings: Inducings<F>,
     /// Gaussian Mixture model used to cluster
-    gmm: Option<Box<GaussianMixtureModel<F>>>,
+    gmm: Option<GaussianMixtureModel<F>>,
     /// GaussianMixture preset
-    gmx: Option<Box<GaussianMixture<F>>>,
+    gmx: Option<GaussianMixture<F>>,
     /// Random number generator
     rng: R,
 }
@@ -116,13 +116,13 @@ impl<F: Float, R: Rng + Clone> SparseGpMixtureValidParams<F, R> {

     /// An optional gaussian mixture to be fitted to generate multivariate normal
     /// in turns used to cluster
-    pub fn gmm(&self) -> &Option<Box<GaussianMixtureModel<F>>> {
-        &self.gmm
+    pub fn gmm(&self) -> Option<&GaussianMixtureModel<F>> {
+        self.gmm.as_ref()
     }

     /// An optional multivariate normal used to cluster (take precedence over gmm)
-    pub fn gmx(&self) -> &Option<Box<GaussianMixture<F>>> {
-        &self.gmx
+    pub fn gmx(&self) -> Option<&GaussianMixture<F>> {
+        self.gmx.as_ref()
     }

     /// The random generator
@@ -270,9 +270,7 @@ impl<F: Float, R: Rng + SeedableRng + Clone> SparseGpMixtureParams<F, R> {
     /// Warning: no consistency check is done on the given initialization data
     /// *Panic* if multivariate normal init data not sound
     pub fn gmx(mut self, weights: Array1<F>, means: Array2<F>, covariances: Array3<F>) -> Self {
-        self.0.gmx = Some(Box::new(
-            GaussianMixture::new(weights, means, covariances).unwrap(),
-        ));
+        self.0.gmx = Some(GaussianMixture::new(weights, means, covariances).unwrap());
         self
     }

@@ -288,8 +286,8 @@ impl<F: Float, R: Rng + SeedableRng + Clone> SparseGpMixtureParams<F, R> {
             n_start: self.0.n_start(),
             sparse_method: self.0.sparse_method(),
             inducings: self.0.inducings().clone(),
-            gmm: self.0.gmm().clone(),
-            gmx: self.0.gmx().clone(),
+            gmm: self.0.gmm().cloned(),
+            gmx: self.0.gmx().cloned(),
             rng,
         })
     }