Skip to content

Commit

Permalink
feat(tensor): add a data extraction helper
Browse files Browse the repository at this point in the history
also adds a check to prevent seg faults
  • Loading branch information
stakach committed May 27, 2023
1 parent 4cdf46b commit c911cd3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion shard.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: tensorflow_lite
version: 2.0.0
version: 2.1.0

development_dependencies:
ameba:
Expand Down
29 changes: 29 additions & 0 deletions src/tensorflow_lite/tensor.cr
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ struct TensorflowLite::Tensor
# buffer that makes up the tensor input
def raw_data : Bytes
data_ptr = LibTensorflowLite.tensor_data(self)
raise "no tensor data allocated" if data_ptr.null?
Slice.new(data_ptr.as(Pointer(UInt8)), bytesize)
end

Expand Down Expand Up @@ -182,4 +183,32 @@ struct TensorflowLite::Tensor
def as_i64
to_type(Int64)
end

# returns a slice of the data in the correct type
def as_type
case type
when .float32?
as_f32
when .float64?
as_f64
when .u_int8?
as_u8
when .int8?
as_i8
when .u_int16?
as_u16
when .int16?
as_i16
when .u_int32?
as_u32
when .int32?
as_i32
when .u_int64?
as_u64
when .int64?
as_i64
else
raise NotImplementedError.new("no method for casting to type: #{type}")
end
end
end

0 comments on commit c911cd3

Please sign in to comment.