Can't deserialize mnist pkl data from kaggle #32

Open
inferrna opened this issue Jan 1, 2025 · 1 comment
inferrna commented Jan 1, 2025

https://www.kaggle.com/datasets/fedesoriano/qmnist-the-extended-mnist-dataset-120k-images/data

Possibly related to #26

In Python, the data looks like this:

>>> import pickle
>>> d = pickle.load(open("/media/Data/Data/Datasets/qmnist.pkl", "rb"), encoding='bytes')
>>> d.keys()
dict_keys(['data', 'labels'])
>>> d['data'].dtype
dtype('uint8')
>>> d['data'].shape
(120000, 28, 28)
>>> d['labels'].dtype
dtype('int64')
>>> d['labels'].shape
(120000, 1)

So I tried

#[derive(Deserialize)]
struct Qmnist {
    data: Vec<Vec<u8>>,
    labels: Vec<i64>,
}

and

#[derive(Deserialize)]
struct Qmnist {
    data: Vec<[[u8; 28]; 28]>,
    labels: Vec<[i64; 1]>,
}

and many other similar variants. They all failed with errors such as "invalid type: integer `1`, expected a sequence" or "invalid type: integer `1`, expected an array of length 28".
I even tried flattening everything:

#[derive(Deserialize)]
struct Qmnist {
    data: Vec<u8>,
    labels: Vec<i64>,
}

But that failed with "invalid type: sequence, expected u8".

inferrna commented Jan 2, 2025

I found the reason: the pickle stores each array in NumPy's internal representation rather than as nested lists. Here is my solution
(I'm using num-traits, but it doesn't seem to be strictly necessary):

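// Imports assumed from context (the original comment omitted them);
// `Value` is taken to be `serde_pickle::Value`.
use num_traits::FromBytes;
use serde::{Deserialize, Deserializer};
use serde_pickle::Value;
use std::mem::size_of;

#[derive(Deserialize, Debug)]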
struct NpArraySpecs {
    _py_ver: i64,
    sym: char,
    a: Option<i64>,
    b: Option<i64>,
    c: Option<i64>,
    d: i64,
    e: i64,
    f: i64,
}

#[derive(Deserialize, Debug)]
struct NpArray<T: FromBytes + FromSlicedBytes> {
    _fmt_id: i64,
    dimensions: Vec<i64>,
    specs: NpArraySpecs,
    b: bool,
    #[serde(deserialize_with = "from_bytes")]
    data: Vec<T>,
}

trait FromSlicedBytes {
    fn from_sliced_le_bytes(data: &[u8]) -> Self;
}

impl FromSlicedBytes for u8 {
    fn from_sliced_le_bytes(data: &[u8]) -> Self {
        data[0]
    }
}
impl FromSlicedBytes for i64 {
    fn from_sliced_le_bytes(data: &[u8]) -> Self {
        i64::from_le_bytes(data.try_into().unwrap())
    }
}

fn from_bytes<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
where
    D: Deserializer<'de>,
    T: FromBytes + FromSlicedBytes,
{
    let bytes: Value = Deserialize::deserialize(deserializer)?;

    // The raw NumPy buffer is expected to arrive as a pickle bytes object.
    let Value::Bytes(bytes) = bytes else {
        return Err(serde::de::Error::custom("expected a bytes object"));
    };

    let element_size = size_of::<T>();

    if bytes.len() % element_size != 0 {
        return Err(serde::de::Error::custom("Byte array length is not a multiple of element size"));
    }

    let mut result = Vec::with_capacity(bytes.len() / element_size);
    for chunk in bytes.chunks(element_size) {
        result.push(T::from_sliced_le_bytes(chunk));
    }
    Ok(result)
}
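
After deserialization, `data` is a flat buffer and `dimensions` carries the shape, so individual images still have to be sliced out by hand. The following is a minimal, std-only sketch of that step, not part of the solution above. It assumes the buffer is in row-major (C) order, which is NumPy's default but is not checked against the pickle here. The helper names `element_count` and `image_at` are hypothetical.

```rust
use std::convert::TryFrom;

/// Number of elements implied by a NumPy-style shape, or `None` if a
/// dimension is negative or the product overflows `usize`.
fn element_count(dimensions: &[i64]) -> Option<usize> {
    dimensions.iter().try_fold(1usize, |acc, &d| {
        acc.checked_mul(usize::try_from(d).ok()?)
    })
}

/// Borrows the `idx`-th image from a flat buffer whose shape is
/// `[n_images, height, width]`.
/// Assumption: the buffer is laid out in row-major (C) order.
fn image_at<'a>(data: &'a [u8], dimensions: &[i64], idx: usize) -> Option<&'a [u8]> {
    // Only three-dimensional shapes are handled in this sketch.
    let [n, h, w] = <[i64; 3]>::try_from(dimensions).ok()?;
    let n = usize::try_from(n).ok()?;
    // The buffer must hold exactly as many elements as the shape implies.
    if element_count(dimensions)? != data.len() || idx >= n {
        return None;
    }
    let len = usize::try_from(h).ok()?.checked_mul(usize::try_from(w).ok()?)?;
    let start = idx.checked_mul(len)?;
    // Cannot overflow: start + len <= n * len == data.len().
    data.get(start..start + len)
}
```

For the shape reported in the issue, `element_count(&[120000, 28, 28])` gives 94,080,000, which is one byte per `uint8` pixel, so it can double as a sanity check on the length of the decoded `data` vector.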
