Skip to content

Commit 3da433f

Browse files
authored
Implement predict_proba for DecisionTreeClassifier (#287)
* Implement predict_proba for DecisionTreeClassifier * Some automated fixes suggested by cargo clippy --fix
1 parent 4523ac7 commit 3da433f

File tree

13 files changed

+166
-61
lines changed

13 files changed

+166
-61
lines changed

src/algorithm/neighbour/fastpair.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ mod tests_fastpair {
212212
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
213213

214214
/// Brute force algorithm, used only for comparison and testing
215-
pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
215+
pub fn closest_pair_brute(
216+
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
217+
) -> PairwiseDistance<f64> {
216218
use itertools::Itertools;
217219
let m = fastpair.samples.shape().0;
218220

src/linalg/basic/matrix.rs

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
9191
}
9292
}
9393

94-
impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'a, T> {
94+
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
9595
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
9696
writeln!(
9797
f,
@@ -142,7 +142,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
142142
}
143143
}
144144

145-
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &mut T> + 'b> {
145+
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
146146
let column_major = self.column_major;
147147
let stride = self.stride;
148148
let ptr = self.values.as_mut_ptr();
@@ -169,7 +169,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
169169
}
170170
}
171171

172-
impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'a, T> {
172+
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
173173
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174174
writeln!(
175175
f,
@@ -493,7 +493,7 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
493493
impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
494494
impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}
495495

496-
impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'a, T> {
496+
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
497497
fn get(&self, pos: (usize, usize)) -> &T {
498498
if self.column_major {
499499
&self.values[pos.0 + pos.1 * self.stride]
@@ -515,7 +515,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa
515515
}
516516
}
517517

518-
impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'a, T> {
518+
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
519519
fn get(&self, i: usize) -> &T {
520520
if self.nrows == 1 {
521521
if self.column_major {
@@ -553,11 +553,11 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<
553553
}
554554
}
555555

556-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'a, T> {}
556+
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
557557

558-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'a, T> {}
558+
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
559559

560-
impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'a, T> {
560+
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
561561
fn get(&self, pos: (usize, usize)) -> &T {
562562
if self.column_major {
563563
&self.values[pos.0 + pos.1 * self.stride]
@@ -579,9 +579,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa
579579
}
580580
}
581581

582-
impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
583-
for DenseMatrixMutView<'a, T>
584-
{
582+
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
585583
fn set(&mut self, pos: (usize, usize), x: T) {
586584
if self.column_major {
587585
self.values[pos.0 + pos.1 * self.stride] = x;
@@ -595,15 +593,16 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
595593
}
596594
}
597595

598-
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'a, T> {}
596+
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
599597

600-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'a, T> {}
598+
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
601599

602600
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
603601

604602
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
605603

606604
#[cfg(test)]
605+
#[warn(clippy::reversed_empty_ranges)]
607606
mod tests {
608607
use super::*;
609608
use approx::relative_eq;

src/linalg/basic/vector.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> {
119119
}
120120
}
121121

122-
impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T> {
122+
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> {
123123
fn get(&self, i: usize) -> &T {
124124
&self.ptr[i]
125125
}
@@ -138,7 +138,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T
138138
}
139139
}
140140

141-
impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a, T> {
141+
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> {
142142
fn set(&mut self, i: usize, x: T) {
143143
self.ptr[i] = x;
144144
}
@@ -149,10 +149,10 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a
149149
}
150150
}
151151

152-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'a, T> {}
153-
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'a, T> {}
152+
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {}
153+
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {}
154154

155-
impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
155+
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> {
156156
fn get(&self, i: usize) -> &T {
157157
&self.ptr[i]
158158
}
@@ -171,7 +171,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
171171
}
172172
}
173173

174-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'a, T> {}
174+
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {}
175175

176176
#[cfg(test)]
177177
mod tests {

src/linalg/ndarray/matrix.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T>
6868

6969
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
7070

71-
impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'a, T, Ix2> {
71+
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> {
7272
fn get(&self, pos: (usize, usize)) -> &T {
7373
&self[[pos.0, pos.1]]
7474
}
@@ -144,11 +144,9 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
144144
impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
145145
impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
146146

147-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'a, T, Ix2> {}
147+
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
148148

149-
impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
150-
for ArrayViewMut<'a, T, Ix2>
151-
{
149+
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
152150
fn get(&self, pos: (usize, usize)) -> &T {
153151
&self[[pos.0, pos.1]]
154152
}
@@ -175,9 +173,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
175173
}
176174
}
177175

178-
impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
179-
for ArrayViewMut<'a, T, Ix2>
180-
{
176+
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
181177
fn set(&mut self, pos: (usize, usize), x: T) {
182178
self[[pos.0, pos.1]] = x
183179
}
@@ -195,9 +191,9 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
195191
}
196192
}
197193

198-
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
194+
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
199195

200-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
196+
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
201197

202198
#[cfg(test)]
203199
mod tests {

src/linalg/ndarray/vector.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T>
4141

4242
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
4343

44-
impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a, T, Ix1> {
44+
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> {
4545
fn get(&self, i: usize) -> &T {
4646
&self[i]
4747
}
@@ -60,9 +60,9 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a
6060
}
6161
}
6262

63-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'a, T, Ix1> {}
63+
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {}
6464

65-
impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
65+
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
6666
fn get(&self, i: usize) -> &T {
6767
&self[i]
6868
}
@@ -81,7 +81,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut
8181
}
8282
}
8383

84-
impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
84+
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
8585
fn set(&mut self, i: usize, x: T) {
8686
self[i] = x;
8787
}
@@ -92,8 +92,8 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<
9292
}
9393
}
9494

95-
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
96-
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
95+
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
96+
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
9797

9898
impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> {
9999
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {

src/linalg/traits/stats.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
142142
///
143143
/// assert_eq!(a, expected);
144144
/// ```
145-
146145
fn binarize_mut(&mut self, threshold: T) {
147146
let (nrows, ncols) = self.shape();
148147
for row in 0..nrows {

src/linear/logistic_regression.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,8 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
258258
}
259259
}
260260

261-
impl<'a, T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
262-
for BinaryObjectiveFunction<'a, T, X>
261+
impl<T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
262+
for BinaryObjectiveFunction<'_, T, X>
263263
{
264264
fn f(&self, w_bias: &[T]) -> T {
265265
let mut f = T::zero();
@@ -313,8 +313,8 @@ struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
313313
_phantom_t: PhantomData<T>,
314314
}
315315

316-
impl<'a, T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
317-
for MultiClassObjectiveFunction<'a, T, X>
316+
impl<T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
317+
for MultiClassObjectiveFunction<'_, T, X>
318318
{
319319
fn f(&self, w_bias: &[T]) -> T {
320320
let mut f = T::zero();

src/naive_bayes/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ mod tests {
147147
#[derive(Debug, PartialEq, Clone)]
148148
struct TestDistribution<'d>(&'d Vec<i32>);
149149

150-
impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
150+
impl NBDistribution<i32, i32> for TestDistribution<'_> {
151151
fn prior(&self, _class_index: usize) -> f64 {
152152
1.
153153
}

src/preprocessing/numerical.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,18 +172,14 @@ where
172172
T: Number + RealNumber,
173173
M: Array2<T>,
174174
{
175-
if let Some(output_matrix) = columns.first().cloned() {
176-
return Some(
177-
columns
178-
.iter()
179-
.skip(1)
180-
.fold(output_matrix, |current_matrix, new_colum| {
181-
current_matrix.h_stack(new_colum)
182-
}),
183-
);
184-
} else {
185-
None
186-
}
175+
columns.first().cloned().map(|output_matrix| {
176+
columns
177+
.iter()
178+
.skip(1)
179+
.fold(output_matrix, |current_matrix, new_colum| {
180+
current_matrix.h_stack(new_colum)
181+
})
182+
})
187183
}
188184

189185
#[cfg(test)]

src/readers/csv.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub struct CSVDefinition<'a> {
3030
/// What seperates the fields in your csv-file?
3131
field_seperator: &'a str,
3232
}
33-
impl<'a> Default for CSVDefinition<'a> {
33+
impl Default for CSVDefinition<'_> {
3434
fn default() -> Self {
3535
Self {
3636
n_rows_header: 1,

src/svm/svc.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
360360
}
361361
}
362362

363-
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
364-
for SVC<'a, TX, TY, X, Y>
363+
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
364+
for SVC<'_, TX, TY, X, Y>
365365
{
366366
fn eq(&self, other: &Self) -> bool {
367367
if (self.b.unwrap().sub(other.b.unwrap())).abs() > TX::epsilon() * TX::two()
@@ -1110,7 +1110,7 @@ mod tests {
11101110
let svc = SVC::fit(&x, &y, &params).unwrap();
11111111

11121112
// serialization
1113-
let deserialized_svc: SVC<f64, i32, _, _> =
1113+
let deserialized_svc: SVC<'_, f64, i32, _, _> =
11141114
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
11151115

11161116
assert_eq!(svc, deserialized_svc);

src/svm/svr.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
281281
}
282282
}
283283

284-
impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
285-
for SVR<'a, T, X, Y>
284+
impl<T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
285+
for SVR<'_, T, X, Y>
286286
{
287287
fn eq(&self, other: &Self) -> bool {
288288
if (self.b - other.b).abs() > T::epsilon() * T::two()
@@ -702,7 +702,7 @@ mod tests {
702702

703703
let svr = SVR::fit(&x, &y, &params).unwrap();
704704

705-
let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
705+
let deserialized_svr: SVR<'_, f64, DenseMatrix<f64>, _> =
706706
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
707707

708708
assert_eq!(svr, deserialized_svr);

0 commit comments

Comments
 (0)