Skip to content

Commit

Permalink
Merge pull request #75 from nbigaouette/update-dependencies
Browse files Browse the repository at this point in the history
Upgrade dependencies
  • Loading branch information
nbigaouette authored Apr 11, 2021
2 parents 850d8f1 + 9253687 commit e6125ad
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Fix Windows i686 compilation ([#70](https://github.com/nbigaouette/onnxruntime-rs/pull/70))
- Upgrade dependencies ([#75](https://github.com/nbigaouette/onnxruntime-rs/pull/75))

## [0.0.11] - 2021-02-22

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ keywords = ["neuralnetworks", "onnx", "bindings"]

[build-dependencies]
bindgen = { version = "0.55", optional = true }
ureq = "1.5.1"
ureq = "2.1"

# Used on Windows
zip = "0.5"
Expand Down
8 changes: 2 additions & 6 deletions onnxruntime-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,9 @@ where
P: AsRef<Path>,
{
let resp = ureq::get(source_url)
.timeout_connect(1_000) // 1 second
.timeout(std::time::Duration::from_secs(300))
.call();

if resp.error() {
panic!("ERROR: Failed to download {}: {:#?}", source_url, resp);
}
.call()
.unwrap_or_else(|err| panic!("ERROR: Failed to download {}: {:?}", source_url, err));

let len = resp
.header("Content-Length")
Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ name = "integration_tests"
required-features = ["model-fetching"]

[dependencies]
onnxruntime-sys = {version = "0.0.11", path = "../onnxruntime-sys"}
onnxruntime-sys = { version = "0.0.11", path = "../onnxruntime-sys" }

lazy_static = "1.4"
ndarray = "0.13"
ndarray = "0.15"
thiserror = "1.0"
tracing = "0.1"

# Enabled with 'model-fetching' feature
ureq = {version = "1.5.1", optional = true}
ureq = { version = "2.1", optional = true }

[dev-dependencies]
image = "0.23"
test-env-log = {version = "0.2", default-features = false, features = ["trace"]}
test-env-log = { version = "0.2", default-features = false, features = ["trace"] }
tracing-subscriber = "0.2"
ureq = "1.5.1"
ureq = "2.1"

[features]
# Fetch model from ONNX Model Zoo (https://github.com/onnx/models)
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ impl AvailableOnnxModel {
);

let resp = ureq::get(url)
.timeout_connect(1_000) // 1 second
.timeout(Duration::from_secs(180)) // 3 minutes
.call();
.call()
.map_err(Box::new)
.map_err(OrtDownloadError::UreqError)?;

assert!(resp.has("Content-Length"));
let len = resp
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ pub enum OrtDownloadError {
/// Generic input/output error
#[error("Error downloading data to file: {0}")]
IoError(#[from] io::Error),
#[cfg(feature = "model-fetching")]
/// Download error by ureq
#[error("Error downloading data to file: {0}")]
UreqError(#[from] Box<ureq::Error>),
/// Error getting content-length from an HTTP GET request
#[error("Error getting content-length")]
ContentLengthError,
Expand Down
13 changes: 9 additions & 4 deletions onnxruntime/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::{
time::Duration,
};

use onnxruntime::error::OrtDownloadError;

mod download {
use super::*;

Expand Down Expand Up @@ -294,16 +296,17 @@ mod download {
}
}

fn get_imagenet_labels() -> Result<Vec<String>, io::Error> {
fn get_imagenet_labels() -> Result<Vec<String>, OrtDownloadError> {
// Download the ImageNet class labels, matching SqueezeNet's classes.
let labels_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("synset.txt");
if !labels_path.exists() {
let url = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt";
println!("Downloading {:?} to {:?}...", url, labels_path);
let resp = ureq::get(url)
.timeout_connect(1_000) // 1 second
.timeout(Duration::from_secs(180)) // 3 minutes
.call();
.call()
.map_err(Box::new)
.map_err(OrtDownloadError::UreqError)?;

assert!(resp.has("Content-Length"));
let len = resp
Expand All @@ -323,5 +326,7 @@ fn get_imagenet_labels() -> Result<Vec<String>, io::Error> {
}
let file = BufReader::new(fs::File::open(labels_path).unwrap());

file.lines().collect()
file.lines()
.map(|line| line.map_err(|io_err| OrtDownloadError::IoError(io_err)))
.collect()
}

0 comments on commit e6125ad

Please sign in to comment.