From 14f64a00945d016629e7721b67c26e9cd4dd071d Mon Sep 17 00:00:00 2001 From: Stephen von Takach Date: Fri, 26 May 2023 00:47:45 +1000 Subject: [PATCH] feat(client): add support for downloading models --- shard.yml | 2 +- spec/tensorflow_lite_spec.cr | 14 ++++++++++++++ src/tensorflow_lite/client.cr | 24 +++++++++++++++++++++++- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/shard.yml b/shard.yml index 8368550..a6d946b 100644 --- a/shard.yml +++ b/shard.yml @@ -1,5 +1,5 @@ name: tensorflow_lite -version: 1.6.1 +version: 1.6.2 development_dependencies: ameba: diff --git a/spec/tensorflow_lite_spec.cr b/spec/tensorflow_lite_spec.cr index 5529f3c..d683ca1 100644 --- a/spec/tensorflow_lite_spec.cr +++ b/spec/tensorflow_lite_spec.cr @@ -96,5 +96,19 @@ module TensorflowLite client.outputs.size.should eq 1 client.output.should eq client.outputs[0] end + + it "downloads models if a URI is provided to the client" do + model = URI.parse "https://raw.githubusercontent.com/google-coral/test_data/master/ssdlite_mobiledet_coco_qat_postprocess.tflite" + labels = URI.parse "https://raw.githubusercontent.com/google-coral/test_data/master/coco_labels.txt" + last_error = "" + + client = Client.new(model, labels: labels) do |error_msg| + last_error = error_msg + end + + last_error.should eq "" + client.outputs.size.should eq 4 + client.labels.as(Hash(Int32, String)).size.should eq 90 + end end end diff --git a/src/tensorflow_lite/client.cr b/src/tensorflow_lite/client.cr index a6c8569..80ceff3 100644 --- a/src/tensorflow_lite/client.cr +++ b/src/tensorflow_lite/client.cr @@ -1,3 +1,4 @@ +require "uri" require "../tensorflow_lite" # provides a simplified way to load and manipulate the tensorflow interpreter @@ -7,7 +8,7 @@ class TensorflowLite::Client include Indexable(Tensor) # Configures the tensorflow interpreter with the options provided - def initialize(model : Bytes | Path | Model | String, delegate : Delegate? = nil, threads : Int? = nil, @labels : Hash(Int32, String)? = nil, &on_error : String -> Nil) + def initialize(model : URI| Bytes | Path | Model | String, delegate : Delegate? = nil, threads : Int? = nil, labels : URI | Hash(Int32, String)? = nil, &on_error : String -> Nil) @labels_fetched = !!@labels @model = case model in String, Path @@ -18,8 +19,29 @@ class TensorflowLite::Client Model.new(model) in Model model + in URI + HTTP::Client.get(model) do |response| + raise "model download failed with #{response.status} (#{response.status_code}) while fetching #{model}" unless response.success? + Model.new response.body_io.getb_to_end + end end + @labels = case labels + in URI + response = HTTP::Client.get(labels) + raise "labels download failed with #{response.status} (#{response.status_code}) while fetching #{labels}" unless response.success? + + labels_keys = {} of Int32 => String + idx = 0 + response.body.each_line do |line| + labels_keys[idx] = line + idx += 1 + end + labels_keys + in Hash(Int32, String)? + labels + end + @options = InterpreterOptions.new @options.on_error(&on_error) if threads