Skip to content

Commit 25edcbd

Browse files
authored
implement bincode for tensor (#209)
* implement bincode for tensor * encode shape and strides as array
1 parent 2bc0585 commit 25edcbd

File tree

6 files changed

+95
-1
lines changed

6 files changed

+95
-1
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ kornia = { path = "crates/kornia", version = "0.1.8-rc.1" }
3636
# dev dependencies for workspace
3737
argh = "0.1"
3838
approx = "0.5"
39+
bincode = { version = "2.0.0-rc.3", features = ["serde"] }
3940
criterion = "0.5"
4041
env_logger = "0.11"
4142
faer = "0.20.1"

crates/kornia-tensor/Cargo.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,13 @@ version.workspace = true
1212

1313
[dependencies]
1414
num-traits = { workspace = true }
15-
serde = { workspace = true }
15+
serde = { workspace = true, optional = true }
16+
bincode = { workspace = true, optional = true }
1617
thiserror = { workspace = true }
18+
19+
[features]
20+
serde = ["dep:serde"]
21+
bincode = ["dep:bincode"]
22+
23+
[dev-dependencies]
24+
serde_json = "1"

crates/kornia-tensor/src/bincode.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use crate::{
2+
allocator::{CpuAllocator, TensorAllocator},
3+
storage::TensorStorage,
4+
Tensor,
5+
};
6+
7+
impl<T, const N: usize, A: TensorAllocator + 'static> bincode::enc::Encode for Tensor<T, N, A>
8+
where
9+
T: bincode::enc::Encode + 'static,
10+
{
11+
fn encode<E: bincode::enc::Encoder>(
12+
&self,
13+
encoder: &mut E,
14+
) -> Result<(), bincode::error::EncodeError> {
15+
bincode::Encode::encode(&self.shape, encoder)?;
16+
bincode::Encode::encode(&self.strides, encoder)?;
17+
bincode::Encode::encode(&self.storage.as_slice(), encoder)?;
18+
Ok(())
19+
}
20+
}
21+
22+
impl<T, const N: usize> bincode::de::Decode for Tensor<T, N, CpuAllocator>
23+
where
24+
T: bincode::de::Decode + 'static,
25+
{
26+
fn decode<D: bincode::de::Decoder>(
27+
decoder: &mut D,
28+
) -> Result<Self, bincode::error::DecodeError> {
29+
let shape = bincode::Decode::decode(decoder)?;
30+
let strides = bincode::Decode::decode(decoder)?;
31+
let data = bincode::Decode::decode(decoder)?;
32+
Ok(Self {
33+
shape,
34+
strides,
35+
storage: TensorStorage::from_vec(data, CpuAllocator),
36+
})
37+
}
38+
}
39+
40+
#[cfg(test)]
41+
mod tests {
42+
use super::*;
43+
44+
#[test]
45+
fn test_bincode() -> Result<(), Box<dyn std::error::Error>> {
46+
let tensor = Tensor::<u8, 2, CpuAllocator>::from_shape_vec(
47+
[2, 3],
48+
vec![1, 2, 3, 4, 5, 6],
49+
CpuAllocator,
50+
)?;
51+
let mut serialized = vec![0u8; 100];
52+
let config = bincode::config::standard();
53+
let length = bincode::encode_into_slice(&tensor, &mut serialized, config)?;
54+
let deserialized: (Tensor<u8, 2, CpuAllocator>, usize) =
55+
bincode::decode_from_slice(&serialized[..length], config)?;
56+
assert_eq!(tensor.as_slice(), deserialized.0.as_slice());
57+
Ok(())
58+
}
59+
}

crates/kornia-tensor/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,15 @@
44
/// allocator module containing the memory management utilities.
55
pub mod allocator;
66

7+
/// bincode module containing the serialization and deserialization utilities.
8+
#[cfg(feature = "bincode")]
9+
pub mod bincode;
10+
711
/// tensor module containing the tensor and storage implementations.
812
pub mod tensor;
913

1014
/// serde module containing the serialization and deserialization utilities.
15+
#[cfg(feature = "serde")]
1116
pub mod serde;
1217

1318
/// storage module containing the storage implementations.

crates/kornia-tensor/src/serde.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,19 @@ where
5959
})
6060
}
6161
}
62+
63+
#[cfg(test)]
64+
mod tests {
65+
use super::*;
66+
use crate::allocator::CpuAllocator;
67+
68+
#[test]
69+
fn test_serde() -> Result<(), Box<dyn std::error::Error>> {
70+
let data = vec![1, 2, 3, 4, 5, 6];
71+
let tensor = Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 3], data, CpuAllocator)?;
72+
let serialized = serde_json::to_string(&tensor)?;
73+
let deserialized: Tensor<u8, 2, CpuAllocator> = serde_json::from_str(&serialized)?;
74+
assert_eq!(tensor.as_slice(), deserialized.as_slice());
75+
Ok(())
76+
}
77+
}

justfile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ clean:
3434
test name='':
3535
@cargo test {{ name }}
3636

37+
# Test the code with all features
38+
test-all:
39+
@cargo test --all-features
40+
41+
3742
# ------------------------------------------------------------------------------
3843
# Recipes for the kornia-py project
3944
# ------------------------------------------------------------------------------

0 commit comments

Comments
 (0)