Skip to content

Commit

Permalink
refactor: calibration for resources and accuracy over same scale range (
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Feb 11, 2024
1 parent e077168 commit 97d9832
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
5 changes: 1 addition & 4 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,10 +826,7 @@ pub(crate) fn calibrate(
let range = if let Some(scales) = scales {
scales
} else {
match target {
CalibrationTarget::Resources { .. } => (8..10).collect::<Vec<crate::Scale>>(),
CalibrationTarget::Accuracy => (10..14).collect::<Vec<crate::Scale>>(),
}
(10..14).collect::<Vec<crate::Scale>>()
};

let div_rebasing = if let Some(div_rebasing) = div_rebasing {
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ mod native_tests {
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0, false);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
test_dir.close().unwrap();
}

Expand Down
12 changes: 6 additions & 6 deletions tests/python/binding_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ def test_calibrate_over_user_range():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
output_path = os.path.join(
Expand Down Expand Up @@ -147,13 +147,13 @@ def test_calibrate():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
output_path = os.path.join(
Expand Down Expand Up @@ -183,7 +183,7 @@ def test_model_compile():
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
compiled_model_path = os.path.join(
Expand All @@ -205,7 +205,7 @@ def test_forward():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
Expand Down

0 comments on commit 97d9832

Please sign in to comment.