Skip to content

Commit

Permalink
* Fix for #63
Browse files Browse the repository at this point in the history
  • Loading branch information
perara committed Feb 8, 2024
1 parent 4940ac7 commit e355b28
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@

for e in range(60):
start_training = time()
tm.fit(X_train, Y_train)
tm.fit(X_train.astype(np.uint32), Y_train)
stop_training = time()

start_testing = time()
Expand Down
2 changes: 1 addition & 1 deletion tmu/models/classification/vanilla_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def fit(
**kwargs
):
metrics = metrics or []
assert len(X) == len(Y), "X and Y must have the same length"
assert X.shape[0] == len(Y), "X and Y must have the same number of samples"
assert len(X.shape) >= 2, "X must be a 2D array"
assert len(Y.shape) == 1, "Y must be a 1D array"
assert X.dtype == np.uint32, "X must be of type uint32"
Expand Down
27 changes: 24 additions & 3 deletions tmu/util/encoded_data_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional
import xxhash
import numpy as np
from scipy.sparse import issparse


class DataEncoderCache:
def __init__(self, seed: int):
Expand All @@ -9,9 +11,28 @@ def __init__(self, seed: int):
self.array_hash: Optional[str] = None
self.encoded_data: Optional[np.ndarray] = None

def compute_hash(self, arr: np.ndarray) -> str:
"""Compute a hash for a numpy array."""
return xxhash.xxh3_64_hexdigest(arr)
def compute_hash_csr_matrix(self, csr_mat):
# Convert the components of the csr_matrix to bytes
data_bytes = csr_mat.data.tobytes()
indices_bytes = csr_mat.indices.tobytes()
indptr_bytes = csr_mat.indptr.tobytes()

# Concatenate the bytes representations
total_bytes = data_bytes + indices_bytes + indptr_bytes

# Compute the hash on the concatenated bytes
hash_value = xxhash.xxh3_64_hexdigest(total_bytes)

return hash_value

def compute_hash(self, arr):
"""Compute a hash for a numpy array or csr_matrix."""
if issparse(arr):
# It's a sparse matrix, handle specially
return self.compute_hash_csr_matrix(arr)
else:
# It's a dense array, proceed as before
return xxhash.xxh3_64_hexdigest(arr.tobytes())

def get_encoded_data(self, data: np.ndarray, encoder_func) -> np.ndarray:
"""Get encoded data for an array, using cache if available."""
Expand Down

0 comments on commit e355b28

Please sign in to comment.