Skip to content

Commit

Permalink
Implement realnum::rand (#251)
Browse files Browse the repository at this point in the history
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>

* Implement rand. Use the new derive attribute `#[default]`
* Use custom range
* Use range seed
* Bump version
* Add array length checks for `fit` (number of rows in X must equal the length of y)
  • Loading branch information
Mec-iS authored Mar 20, 2023
1 parent 7d059c4 commit f498f96
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "smartcore"
description = "Machine Learning in Rust."
homepage = "https://smartcorelib.org"
version = "0.3.0"
version = "0.3.1"
authors = ["smartcore Developers"]
edition = "2021"
license = "Apache-2.0"
Expand Down
9 changes: 2 additions & 7 deletions src/algorithm/neighbour/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,15 @@ pub mod linear_search;
/// Both the KNN classifier and the KNN regressor benefit from underlying search algorithms that help to speed up queries.
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
///
/// `Default` is derived; the variant tagged `#[default]` (`CoverTree`) is the one
/// returned by `KNNAlgorithmName::default()`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
pub enum KNNAlgorithmName {
    /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
    LinearSearch,
    /// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
    #[default]
    CoverTree,
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {
Expand Down
2 changes: 1 addition & 1 deletion src/cluster/dbscan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//!
//! Example:
//!
//! ```
//! ```ignore
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::basic::arrays::Array2;
//! use smartcore::cluster::dbscan::*;
Expand Down
30 changes: 29 additions & 1 deletion src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,12 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
y: &Y,
parameters: RandomForestClassifierParameters,
) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
let (_, num_attributes) = x.shape();
let (x_nrows, num_attributes) = x.shape();
let y_ncols = y.shape();
if x_nrows != y_ncols {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}

let mut yi: Vec<usize> = vec![0; y_ncols];
let classes = y.unique();

Expand Down Expand Up @@ -678,6 +682,30 @@ mod tests {
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
}

/// Fitting must fail when the feature matrix has a different number of rows
/// than there are labels.
#[test]
fn test_random_matrix_with_wrong_rownum() {
    // 21 rows of random features ...
    let x_rand = DenseMatrix::<f64>::rand(21, 200);
    // ... against only 20 labels: 8 zeros followed by 12 ones.
    let y: Vec<u32> = [vec![0; 8], vec![1; 12]].concat();

    let params = RandomForestClassifierParameters {
        criterion: SplitCriterion::Gini,
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 2,
        n_trees: 100,
        m: None,
        keep_samples: false,
        seed: 87,
    };

    let outcome = RandomForestClassifier::fit(&x_rand, &y, params);

    assert!(outcome.is_err());
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
Expand Down
30 changes: 30 additions & 0 deletions src/ensemble/random_forest_regressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,10 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
let (n_rows, num_attributes) = x.shape();

if n_rows != y.shape() {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}

let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
Expand Down Expand Up @@ -595,6 +599,32 @@ mod tests {
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
}

/// Fitting must fail when the feature matrix has a different number of rows
/// than there are target values.
#[test]
fn test_random_matrix_with_wrong_rownum() {
    // 17 rows of random features ...
    let x_rand = DenseMatrix::<f64>::rand(17, 200);

    // ... against only 16 target values.
    let y = vec![
        83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0,
        101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
    ];

    let params = RandomForestRegressorParameters {
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 2,
        n_trees: 1000,
        m: None,
        keep_samples: false,
        seed: 87,
    };

    let outcome = RandomForestRegressor::fit(&x_rand, &y, params);

    assert!(outcome.is_err());
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
Expand Down
2 changes: 1 addition & 1 deletion src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub enum FailedError {
DecompositionFailed,
/// Can't solve for x
SolutionFailed,
/// Error in input parameters
ParametersError,
}

Expand Down
9 changes: 2 additions & 7 deletions src/linear/logistic_regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,14 @@ use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder;

/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
///
/// `Default` is derived; `LBFGS` carries the `#[default]` marker and is therefore
/// the value of `LogisticRegressionSolverName::default()`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub enum LogisticRegressionSolverName {
    /// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
    #[default]
    LBFGS,
}

/// Logistic Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
Expand Down
9 changes: 2 additions & 7 deletions src/linear/ridge_regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,16 @@ use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;

/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
///
/// `Default` is derived; `Cholesky` carries the `#[default]` marker and is therefore
/// the value of `RidgeRegressionSolverName::default()`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub enum RidgeRegressionSolverName {
    /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
    #[default]
    Cholesky,
    /// SVD decomposition, see [SVD](../../linalg/svd/index.html)
    SVD,
}

/// Ridge Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
Expand Down
9 changes: 2 additions & 7 deletions src/neighbors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,15 @@ pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;

/// Weight function that is used to determine estimated value.
///
/// `Default` is derived; `Uniform` carries the `#[default]` marker and is therefore
/// the value of `KNNWeightFunction::default()`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
pub enum KNNWeightFunction {
    /// All k nearest points are weighted equally
    #[default]
    Uniform,
    /// k nearest points are weighted by the inverse of their distance. Closer neighbors will have a greater influence than neighbors which are further away.
    Distance,
}

impl KNNWeightFunction {
fn calc_weights(&self, distances: Vec<f64>) -> std::vec::Vec<f64> {
match *self {
Expand Down
29 changes: 26 additions & 3 deletions src/numbers/realnum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
//! Most algorithms in `smartcore` rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, ℝ.
//! This module defines real number and some useful functions that are used in [Linear Algebra](../../linalg/index.html) module.
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use num_traits::Float;

use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;

/// Defines real number
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
Expand Down Expand Up @@ -63,8 +67,12 @@ impl RealNumber for f64 {
}

fn rand() -> f64 {
// TODO: to be implemented, see issue smartcore#214
1.0
let mut small_rng = get_rng_impl(None);

let mut rngs: Vec<SmallRng> = (0..3)
.map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
.collect();
rngs[0].gen::<f64>()
}

fn two() -> Self {
Expand Down Expand Up @@ -108,7 +116,12 @@ impl RealNumber for f32 {
}

fn rand() -> f32 {
1.0
let mut small_rng = get_rng_impl(None);

let mut rngs: Vec<SmallRng> = (0..3)
.map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
.collect();
rngs[0].gen::<f32>()
}

fn two() -> Self {
Expand Down Expand Up @@ -149,4 +162,14 @@ mod tests {
fn f64_from_string() {
assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111)
}

/// `f64::rand()` must return without panicking and yield a value in `[0, 1)`,
/// the range `rand` documents for floats sampled with `Rng::gen`.
#[test]
fn f64_rand() {
    let sample = f64::rand();
    assert!(
        (0.0..1.0).contains(&sample),
        "f64::rand() returned {sample}, expected a value in [0, 1)"
    );
}

/// `f32::rand()` must return without panicking and yield a value in `[0, 1)`,
/// the range `rand` documents for floats sampled with `Rng::gen`.
#[test]
fn f32_rand() {
    let sample = f32::rand();
    assert!(
        (0.0..1.0).contains(&sample),
        "f32::rand() returned {sample}, expected a value in [0, 1)"
    );
}
}
26 changes: 18 additions & 8 deletions src/tree/decision_tree_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,29 +137,24 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
self.classes.as_ref()
}
/// Get depth of tree
fn depth(&self) -> u16 {
pub fn depth(&self) -> u16 {
self.depth
}
}

/// The function to measure the quality of a split.
///
/// `Default` is derived; `Gini` carries the `#[default]` marker and is therefore
/// the value of `SplitCriterion::default()`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
pub enum SplitCriterion {
    /// [Gini index](../decision_tree_classifier/index.html)
    #[default]
    Gini,
    /// [Entropy](../decision_tree_classifier/index.html)
    Entropy,
    /// [Classification error](../decision_tree_classifier/index.html)
    ClassificationError,
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct Node {
Expand Down Expand Up @@ -543,6 +538,10 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
parameters: DecisionTreeClassifierParameters,
) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
if x_nrows != y.shape() {
return Err(Failed::fit("Size of x should equal size of y"));
}

let samples = vec![1; x_nrows];
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
Expand Down Expand Up @@ -968,6 +967,17 @@ mod tests {
);
}

/// Fitting must fail when the feature matrix has a different number of rows
/// than there are labels.
#[test]
fn test_random_matrix_with_wrong_rownum() {
    // 21 rows of random features ...
    let x_rand = DenseMatrix::<f64>::rand(21, 200);
    // ... against only 20 labels: 8 zeros followed by 12 ones.
    let y: Vec<u32> = [vec![0; 8], vec![1; 12]].concat();

    let outcome = DecisionTreeClassifier::fit(&x_rand, &y, Default::default());

    assert!(outcome.is_err());
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
Expand Down
5 changes: 4 additions & 1 deletion src/tree/decision_tree_regressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
//! Example:
//!
//! ```
//! use rand::thread_rng;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::tree::decision_tree_regressor::*;
//!
Expand Down Expand Up @@ -422,6 +421,10 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
parameters: DecisionTreeRegressorParameters,
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
if x_nrows != y.shape() {
return Err(Failed::fit("Size of x should equal size of y"));
}

let samples = vec![1; x_nrows];
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
Expand Down

0 comments on commit f498f96

Please sign in to comment.