
Commit 8121669

Format.
1 parent e6c1777 · commit 8121669

10 files changed: +46 −55 lines changed


src/bin.rs

Lines changed: 3 additions & 1 deletion
@@ -463,7 +463,9 @@ fn start(args: Args) -> Void {
             klib::ml::train::run_training::<Autodiff<NdArray<f32>>>(device, &config, true, true)?;
         }
         _ => {
-            return Err(anyhow::Error::msg("Invalid device (must choose either `gpu` [requires `ml_gpu` feature], `wgpu` [requires `ml_gpu` feature] or `cpu`)."));
+            return Err(anyhow::Error::msg(
+                "Invalid device (must choose either `gpu` [requires `ml_gpu` feature], `wgpu` [requires `ml_gpu` feature] or `cpu`).",
+            ));
         }
     }
 }

src/core/pitch.rs

Lines changed: 2 additions & 2 deletions
@@ -32,15 +32,15 @@ pub trait HasFrequency {
     /// Essentially, mid way between the frequency and the next frequency on either side.
     fn frequency_range(&self) -> (f32, f32) {
         let frequency = self.frequency();
-
+
         (frequency * (1.0 - 1.0 / 17.462 / 2.0), frequency * (1.0 + 1.0 / 16.8196 / 2.0))
     }
 
     /// Returns the tight frequency range of the type (usually a [`Note`]).
     /// Essentially, 1/8 the way between the frequency and the next frequency on either side.
     fn tight_frequency_range(&self) -> (f32, f32) {
         let frequency = self.frequency();
-
+
         (frequency * (1.0 - 1.0 / 17.462 / 8.0), frequency * (1.0 + 1.0 / 16.8196 / 8.0))
     }
 }
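
Note (added for context, not part of the commit): the divisors 17.462 and 16.8196 in these context lines roughly encode the relative size of one semitone step on either side (about 5.7–5.9%), so the range brackets a note about halfway to each neighbor, as the doc comment says. A minimal standalone sketch of the same arithmetic, assuming A4 = 440 Hz:

// Standalone sketch of the `frequency_range` arithmetic shown above.
fn frequency_range(frequency: f32) -> (f32, f32) {
    (frequency * (1.0 - 1.0 / 17.462 / 2.0), frequency * (1.0 + 1.0 / 16.8196 / 2.0))
}

fn main() {
    // For A4 = 440 Hz, the neighboring semitones are G#4 ≈ 415.3 Hz and A#4 ≈ 466.2 Hz,
    // so the printed bounds land near the midpoints (≈ 427.4 Hz and ≈ 453.1 Hz).
    let (low, high) = frequency_range(440.0);
    println!("{low:.1} Hz .. {high:.1} Hz");
}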

src/ml/base/data.rs

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 use burn::tensor::{backend::Backend, Data, Tensor};
 
 use super::{
-    helpers::{get_deterministic_guess, mel_filter_banks_from, u128_to_binary, note_binned_convolution},
+    helpers::{get_deterministic_guess, mel_filter_banks_from, note_binned_convolution, u128_to_binary},
     KordItem, INPUT_SPACE_SIZE, NUM_CLASSES,
 };

src/ml/base/gather.rs

Lines changed: 4 additions & 4 deletions
@@ -2,14 +2,14 @@
 
 use std::path::Path;
 
+use crate::core::{
+    base::{Parsable, Void},
+    note::{HasNoteId, Note},
+};
 use crate::{
     analyze::base::{get_frequency_space, get_smoothed_frequency_space},
     ml::base::{KordItem, FREQUENCY_SPACE_SIZE},
 };
-use crate::core::{
-    base::{Parsable, Void},
-    note::{HasNoteId, Note},
-};
 
 use crate::analyze::mic::get_audio_data_from_microphone;

src/ml/base/helpers.rs

Lines changed: 9 additions & 11 deletions
@@ -8,15 +8,19 @@ use std::{
     path::{Path, PathBuf},
 };
 
-use burn::{tensor::{backend::Backend, Tensor}, module::Module};
+use burn::{
+    module::Module,
+    tensor::{backend::Backend, Tensor},
+};
 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
 
 use crate::{
     analyze::base::get_notes_from_smoothed_frequency_space,
     core::{
         base::Res,
         helpers::{inv_mel, mel},
-        note::{HasNoteId, Note, ALL_PITCH_NOTES_WITH_FREQUENCY}, pitch::HasFrequency,
+        note::{HasNoteId, Note, ALL_PITCH_NOTES_WITH_FREQUENCY},
+        pitch::HasFrequency,
     },
 };
 
@@ -113,7 +117,7 @@ pub fn note_binned_convolution(spectrum: &[f32]) -> [f32; NUM_CLASSES] {
 
     for (note, _) in ALL_PITCH_NOTES_WITH_FREQUENCY.iter().skip(7).take(90) {
         let id_index = note.id_index();
-
+
         let (low, high) = note.tight_frequency_range();
         let low = low.round() as usize;
         let high = high.round() as usize;
@@ -137,13 +141,7 @@ pub fn note_binned_convolution(spectrum: &[f32]) -> [f32; NUM_CLASSES] {
 pub fn harmonic_convolution(spectrum: &[f32]) -> [f32; FREQUENCY_SPACE_SIZE] {
     let mut harmonic_convolution = [0f32; FREQUENCY_SPACE_SIZE];
 
-    let (peak, _) = spectrum.iter().enumerate().fold((0usize, 0f32), |(k, max), (j, x)| {
-        if *x > max {
-            (j, *x)
-        } else {
-            (k, max)
-        }
-    });
+    let (peak, _) = spectrum.iter().enumerate().fold((0usize, 0f32), |(k, max), (j, x)| if *x > max { (j, *x) } else { (k, max) });
 
     for center in (peak / 2)..4000 {
         let mut sum = spectrum[center];
@@ -239,4 +237,4 @@ impl<B: Backend> Sigmoid<B> {
         //let scaled = input;
         scaled.clone().exp().div(scaled.exp().add_scalar(1.0))
     }
-}
+}
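
Note (added for context, not part of the commit): the fold collapsed onto one line in `harmonic_convolution` is a single-pass argmax over the spectrum, carrying along the (index, value) of the largest bin seen so far. A minimal sketch with a made-up spectrum:

fn main() {
    // Hypothetical five-bin spectrum; the loudest bin is at index 3.
    let spectrum = [0.1f32, 0.7, 0.3, 0.9, 0.2];
    let (peak, _) = spectrum
        .iter()
        .enumerate()
        .fold((0usize, 0f32), |(k, max), (j, x)| if *x > max { (j, *x) } else { (k, max) });
    assert_eq!(peak, 3);
    println!("peak bin = {peak}");
}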

src/ml/base/mlp.rs

Lines changed: 4 additions & 8 deletions
@@ -1,8 +1,9 @@
 //! Multilayer Perceptron module.
 
 use burn::{
-    nn::{self, LayerNormConfig, LayerNorm},
-    tensor::{backend::Backend, Tensor}, module::Module,
+    module::Module,
+    nn::{self, LayerNorm, LayerNormConfig},
+    tensor::{backend::Backend, Tensor},
 };
 
 /// Multilayer Perceptron module.
@@ -28,12 +29,7 @@ impl<B: Backend> Mlp<B> {
         let dropout = nn::DropoutConfig::new(mlp_dropout).init();
         let activation = nn::ReLU::new();
 
-        Self {
-            linears,
-            norm,
-            dropout,
-            activation
-        }
+        Self { linears, norm, dropout, activation }
     }
 
     /// Applies the forward pass on the input tensor.

src/ml/base/model.rs

Lines changed: 12 additions & 15 deletions
@@ -3,20 +3,21 @@
 use core::f32;
 
 use burn::{
-    nn::{self, attention::{MultiHeadAttentionConfig, MultiHeadAttention, MhaInput}},
-    tensor::{backend::Backend, Tensor}, module::Module,
+    module::Module,
+    nn::{
+        self,
+        attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
+    },
+    tensor::{backend::Backend, Tensor},
 };
 
 use super::{helpers::Sigmoid, INPUT_SPACE_SIZE, NUM_CLASSES};
 
 #[cfg(feature = "ml_train")]
-use crate::ml::train::{
-    data::KordBatch,
-    helpers::KordClassificationOutput,
-};
+use crate::ml::train::{data::KordBatch, helpers::KordClassificationOutput};
 
 /// The Kord model.
-/// 
+///
 /// This model is a transformer model that uses multi-head attention to classify notes from a frequency space.
 #[derive(Module, Debug)]
 pub struct KordModel<B: Backend> {
@@ -32,11 +33,7 @@ impl<B: Backend> KordModel<B> {
         let output = nn::LinearConfig::new(INPUT_SPACE_SIZE, NUM_CLASSES).init::<B>();
         let sigmoid = Sigmoid::new(sigmoid_strength);
 
-        Self {
-            mha,
-            output,
-            sigmoid,
-        }
+        Self { mha, output, sigmoid }
     }
 
     /// Applies the forward pass on the input tensor.
@@ -50,14 +47,14 @@ impl<B: Backend> KordModel<B> {
 
         // Reshape the output to remove the sequence dimension.
         let mut x = attn.context.reshape([batch_size, input_size]);
-
+
         // Perform the final linear layer to map to the output dimensions.
         x = self.output.forward(x);
 
         // Apply the sigmoid function to the output to achieve multi-classification.
         x = self.sigmoid.forward(x);
 
-        x 
+        x
     }
 
     /// Applies the forward classification pass on the input tensor.
@@ -90,4 +87,4 @@ impl<B: Backend> KordModel<B> {
 
         KordClassificationOutput { loss, output, targets }
     }
-}
+}
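
Note (added for context, not part of the commit): per the context lines above, the final stage applies an element-wise sigmoid so that each of the NUM_CLASSES outputs becomes an independent note probability (multi-label rather than softmax classification); the crate's `Sigmoid` helper additionally takes a `sigmoid_strength` parameter. A minimal scalar sketch of that last step, with made-up logits:

// Element-wise logistic sigmoid, the same shape of computation as
// exp(x) / (exp(x) + 1) in the Sigmoid helper shown in helpers.rs.
fn sigmoid(x: f32) -> f32 {
    x.exp() / (x.exp() + 1.0)
}

fn main() {
    // Hypothetical logits for three note classes.
    let logits = [2.0f32, -1.5, 0.0];
    let probs: Vec<f32> = logits.iter().map(|&x| sigmoid(x)).collect();
    println!("{probs:?}"); // ≈ [0.88, 0.18, 0.50]
}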

src/ml/infer/execute.rs

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,8 @@
 use burn::{
     config::Config,
     module::Module,
-    tensor::backend::Backend, record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
+    record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
+    tensor::backend::Backend,
 };
 use burn_ndarray::{NdArray, NdArrayDevice};
 use serde::{de::DeserializeOwned, Serialize};

src/ml/train/execute.rs

Lines changed: 7 additions & 5 deletions
@@ -3,13 +3,15 @@
 use std::sync::Arc;
 
 use burn::{
+    backend::Autodiff,
     config::Config,
     data::dataloader::DataLoaderBuilder,
+    lr_scheduler::constant::ConstantLr,
     module::Module,
     optim::{decay::WeightDecayConfig, AdamConfig},
-    backend::Autodiff,
+    record::{BinFileRecorder, FullPrecisionSettings, Recorder},
     tensor::backend::{AutodiffBackend, Backend},
-    train::{metric::LossMetric, LearnerBuilder}, lr_scheduler::constant::ConstantLr, record::{BinFileRecorder, FullPrecisionSettings, Recorder},
+    train::{metric::LossMetric, LearnerBuilder},
 };
 use serde::{de::DeserializeOwned, Serialize};
 
@@ -31,8 +33,8 @@ use super::{
 use crate::ml::base::TrainConfig;
 
 /// Run the training.
-/// 
-/// Given the [`TrainConfig`], this function will run the training and return the overall accuracy on 
+///
+/// Given the [`TrainConfig`], this function will run the training and return the overall accuracy on
 /// the validation / test set.
 pub fn run_training<B: AutodiffBackend>(device: B::Device, config: &TrainConfig, print_accuracy_report: bool, save_model: bool) -> Res<f32>
 where
@@ -157,7 +159,7 @@ pub fn compute_overall_accuracy<B: Backend>(model_trained: &KordModel<B>, device
 }
 
 /// Run hyper parameter tuning.
-/// 
+///
 ///This method sweeps through the hyper parameters and runs training for each combination. The best
 /// hyper parameters are then printed at the end.
 #[coverage(off)]
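
Note (added for context, not part of the commit): the doc comment above describes the entry point invoked from src/bin.rs earlier in this diff. A hedged sketch of the same invocation from outside the crate follows; the public paths for `TrainConfig` and the error conversion via `?` are assumptions based on the `use crate::...` lines in this commit, and `config` is assumed to be populated elsewhere.

// Sketch only: mirrors the call site in src/bin.rs shown above.
use burn::backend::Autodiff;
use burn_ndarray::{NdArray, NdArrayDevice};

fn train(config: &klib::ml::base::TrainConfig) -> anyhow::Result<f32> {
    let device = NdArrayDevice::Cpu;
    // Returns the overall accuracy on the validation / test set.
    let accuracy = klib::ml::train::run_training::<Autodiff<NdArray<f32>>>(device, config, true, true)?;
    Ok(accuracy)
}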

src/ml/train/helpers.rs

Lines changed: 2 additions & 7 deletions
@@ -8,9 +8,8 @@ use burn::{
     },
     train::{
         metric::{
-            MetricMetadata,
             state::{FormatOptions, NumericMetricState},
-            Adaptor, LossInput, Metric, MetricEntry, Numeric,
+            Adaptor, LossInput, Metric, MetricEntry, MetricMetadata, Numeric,
         },
         TrainOutput, TrainStep, ValidStep,
     },
@@ -24,11 +23,7 @@ use crate::{
         note::{HasNoteId, Note, ALL_PITCH_NOTES},
         pitch::HasFrequency,
     },
-    ml::base::{
-        helpers::load_kord_item,
-        model::KordModel,
-        KordItem, FREQUENCY_SPACE_SIZE, NUM_CLASSES,
-    },
+    ml::base::{helpers::load_kord_item, model::KordModel, KordItem, FREQUENCY_SPACE_SIZE, NUM_CLASSES},
 };
 
 use super::data::KordBatch;
