Skip to content

Commit

Permalink
feat: add a label extraction helper
Browse files Browse the repository at this point in the history
  • Loading branch information
stakach committed May 21, 2023
1 parent 49126dc commit f4b2f6e
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 1 deletion.
2 changes: 1 addition & 1 deletion shard.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: tensorflow_lite
version: 1.3.0
version: 1.4.0

development_dependencies:
ameba:
Expand Down
1 change: 1 addition & 0 deletions spec/spec_helper.cr
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
require "spec"
require "http"
require "../src/tensorflow_lite"
require "../src/tensorflow_lite/edge_tpu"
21 changes: 21 additions & 0 deletions spec/utilities_spec.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
require "./spec_helper"

module TensorflowLite
  # Location of the model used by the label extraction spec; it is downloaded
  # on first run and reused afterwards.
  SPEC_TF_L_MODEL = Path.new "./bin/mobilenet_v1.tflite"

  describe Utilities::ExtractLabels do
    unless File.exists? SPEC_TF_L_MODEL
      puts "downloading tensorflow model for spec..."

      # a fresh checkout may not have the target folder yet, which would make
      # File.write fail before any example gets to run
      Dir.mkdir_p SPEC_TF_L_MODEL.parent

      # details: https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2
      HTTP::Client.get("https://storage.googleapis.com/tfhub-lite-models/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2.tflite") do |response|
        raise "could not download tf model file" unless response.success?
        File.write(SPEC_TF_L_MODEL, response.body_io)
      end
    end

    it "extracts the tensorflow model label map" do
      labels = Utilities::ExtractLabels.from(SPEC_TF_L_MODEL)
      # `from` returns nil when no label file is found, the cast fails the spec in that case
      labels.as(Hash(Int32, String)).size.should eq 90
    end
  end
end
1 change: 1 addition & 0 deletions src/tensorflow_lite.cr
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ require "./tensorflow_lite/interpreter_options"
require "./tensorflow_lite/tensor"
require "./tensorflow_lite/interpreter"
require "./tensorflow_lite/client"
require "./tensorflow_lite/utilities/*"
71 changes: 71 additions & 0 deletions src/tensorflow_lite/utilities/extract_labels.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
require "compress/zip"
require "file_utils"

module TensorflowLite::Utilities::ExtractLabels
  # Signature that opens a ZIP local file header ("PK\x03\x04").
  #
  # File type detection references:
  # * https://github.com/sindresorhus/file-type/blob/main/core.js
  # * https://en.wikipedia.org/wiki/ZIP_(file_format)
  MAGIC_ZIP = Bytes[0x50, 0x4b, 0x03, 0x04]

  # Extracts the label names from the tensorflow lite model at the path specified.
  #
  # The file is scanned for ZIP signatures; every candidate is handed to
  # `Compress::Zip::Reader` and the first file entry whose name ends with
  # *metadata_file* is read line by line.
  #
  # Returns a hash of zero-based line index => line content, or `nil` when no
  # matching entry was found. Scanning stops after the first candidate that
  # could be read as an archive, whether or not it contained a match.
  #
  # Errors raised while opening or reading *input* itself are propagated.
  def self.from(input : Path, metadata_file : String = ".txt") : Hash(Int32, String)?
    # TODO:: we should update this to search the file more optimally
    # and work more memory efficiently
    #
    # block form so the file handle is released even if reading fails
    bytes = File.open(input) do |file|
      buffer = Bytes.new(file.size)
      file.read_fully buffer
      buffer
    end

    found = 0
    files = [] of String
    magic_size = MAGIC_ZIP.size
    last_offset = bytes.size - magic_size

    # run through the file looking for possible zip headers
    # then extract the zip file contents
    offset = 0
    while offset <= last_offset
      if bytes[offset, magic_size] == MAGIC_ZIP
        begin
          zip_data = IO::Memory.new(bytes[offset..-1])
          Compress::Zip::Reader.open(zip_data) do |zip|
            zip.each_entry do |entry|
              next unless entry.file?

              found += 1
              Log.debug { "found file -> #{entry.filename}" }

              unless entry.filename.ends_with?(metadata_file)
                files << entry.filename
                next
              end

              # loading the labels, keyed by their line number
              labels = {} of Int32 => String
              entry.io.each_line { |line| labels[labels.size] = line }
              return labels
            end
          end

          # a readable archive without a matching entry: stop here, the
          # offsets that follow would only re-discover the same entries
          break
        rescue Compress::Zip::Error | Compress::Deflate::Error | IO::EOFError
          # the signature was a coincidence in the model data (or the
          # candidate is truncated), keep scanning from the next byte
        end
      end

      offset += 1
    end

    Log.info { "found #{found} files, no matches: #{files.join(", ")}" }

    nil
  end
end

0 comments on commit f4b2f6e

Please sign in to comment.