Skip to content

Commit

Permalink
feat(client): add support for downloading models
Browse files Browse the repository at this point in the history
  • Loading branch information
stakach committed May 25, 2023
1 parent 5ca522a commit 14f64a0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
2 changes: 1 addition & 1 deletion shard.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: tensorflow_lite
version: 1.6.1
version: 1.6.2

development_dependencies:
ameba:
Expand Down
14 changes: 14 additions & 0 deletions spec/tensorflow_lite_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,19 @@ module TensorflowLite
client.outputs.size.should eq 1
client.output.should eq client.outputs[0]
end

it "downloads models if a URI is provided to the client" do
model = URI.parse "https://raw.githubusercontent.com/google-coral/test_data/master/ssdlite_mobiledet_coco_qat_postprocess.tflite"
labels = URI.parse "https://raw.githubusercontent.com/google-coral/test_data/master/coco_labels.txt"
last_error = ""

client = Client.new(model, labels: labels) do |error_msg|
last_error = error_msg
end

last_error.should eq ""
client.outputs.size.should eq 4
client.labels.as(Hash(Int32, String)).size.should eq 90
end
end
end
24 changes: 23 additions & 1 deletion src/tensorflow_lite/client.cr
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
require "uri"
require "../tensorflow_lite"

# provides a simplified way to load and manipulate the tensorflow interpreter
Expand All @@ -7,7 +8,7 @@ class TensorflowLite::Client
include Indexable(Tensor)

# Configures the tensorflow interpreter with the options provided
def initialize(model : Bytes | Path | Model | String, delegate : Delegate? = nil, threads : Int? = nil, @labels : Hash(Int32, String)? = nil, &on_error : String -> Nil)
def initialize(model : URI| Bytes | Path | Model | String, delegate : Delegate? = nil, threads : Int? = nil, labels : URI | Hash(Int32, String)? = nil, &on_error : String -> Nil)
@labels_fetched = !!@labels
@model = case model
in String, Path
Expand All @@ -18,8 +19,29 @@ class TensorflowLite::Client
Model.new(model)
in Model
model
in URI
HTTP::Client.get(model) do |response|
raise "model download failed with #{response.status} (#{response.status_code}) while fetching #{model}" unless response.success?
Model.new response.body_io.getb_to_end
end
end

@labels = case labels
in URI
response = HTTP::Client.get(labels)
raise "labels download failed with #{response.status} (#{response.status_code}) while fetching #{labels}" unless response.success?

labels_keys = {} of Int32 => String
idx = 0
response.body.each_line do |line|
labels_keys[idx] = line
idx += 1
end
labels_keys
in Hash(Int32, String)?
labels
end

@options = InterpreterOptions.new
@options.on_error(&on_error)
if threads
Expand Down

0 comments on commit 14f64a0

Please sign in to comment.