Skip to content

Commit bf2480f

Browse files
committed
Update to version 0.0.5 - includes SciPy csc, csr, dok, lil matrices support.
1 parent 4fa0cc8 commit bf2480f

File tree

6 files changed

+99
-10
lines changed

6 files changed

+99
-10
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ dist/
99
*.egg-info/
1010
htmlcov/
1111
.idea/
12-
staging/
12+
staging/
13+
.vscode/

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and the versioning is mostly derived from [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [v0.0.5] - 2020-11-04
8+
### Added
9+
- Added support for SciPy csr, csc, lil, dok matrices.
10+
711
## [v0.0.4] - 2020-09-17
812
### Added
913
- Initial public release.
1014

15+
[v0.0.5]: https://github.com/interpretml/slicer/releases/tag/v0.0.5
1116
[v0.0.4]: https://github.com/interpretml/slicer/releases/tag/v0.0.4

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="slicer",
8-
version="0.0.4",
8+
version="0.0.5",
99
author="InterpretML",
1010
author_email="interpret@microsoft.com",
1111
description="A small package for big slicing.",

slicer/slicer_internal.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,15 +405,27 @@ def head_slice(cls, o, index_tup, max_dim):
405405

406406
# Process native array dimensions
407407
cut_index = index_tup[:cut]
408-
is_element = True if isinstance(cut_index[-1], int) else False
408+
is_element = any([True if isinstance(x, int) else False for x in cut_index])
409409
sliced_o = o[cut_index]
410410

411411
return is_element, sliced_o, cut
412412

413413
@classmethod
414414
def tail_slice(cls, o, tail_index, max_dim, flatten=True):
415415
if flatten:
416-
return AtomicSlicer(o, max_dim=max_dim)[tail_index]
416+
# NOTE: If we're dealing with a scipy matrix,
417+
# we have to manually flatten it ourselves
418+
# to keep consistent to the rest of slicer's API.
419+
if _safe_isinstance(o, "scipy.sparse.csc", "csc_matrix"):
420+
return AtomicSlicer(o.toarray().flatten(), max_dim=max_dim)[tail_index]
421+
elif _safe_isinstance(o, "scipy.sparse.csr", "csr_matrix"):
422+
return AtomicSlicer(o.toarray().flatten(), max_dim=max_dim)[tail_index]
423+
elif _safe_isinstance(o, "scipy.sparse.dok", "dok_matrix"):
424+
return AtomicSlicer(o.toarray().flatten(), max_dim=max_dim)[tail_index]
425+
elif _safe_isinstance(o, "scipy.sparse.lil", "lil_matrix"):
426+
return AtomicSlicer(o.toarray().flatten(), max_dim=max_dim)[tail_index]
427+
else:
428+
return AtomicSlicer(o, max_dim=max_dim)[tail_index]
417429
else:
418430
inner = [AtomicSlicer(e, max_dim=max_dim)[tail_index] for e in o]
419431
if _safe_isinstance(o, "numpy", "ndarray"):
@@ -427,6 +439,22 @@ def tail_slice(cls, o, tail_index, max_dim, flatten=True):
427439
return torch.stack(inner)
428440
else:
429441
return torch.tensor(inner)
442+
elif _safe_isinstance(o, "scipy.sparse.csc", "csc_matrix"):
443+
from scipy.sparse import vstack
444+
out = vstack(inner, format='csc')
445+
return out
446+
elif _safe_isinstance(o, "scipy.sparse.csr", "csr_matrix"):
447+
from scipy.sparse import vstack
448+
out = vstack(inner, format='csr')
449+
return out
450+
elif _safe_isinstance(o, "scipy.sparse.dok", "dok_matrix"):
451+
from scipy.sparse import vstack
452+
out = vstack(inner, format='dok')
453+
return out
454+
elif _safe_isinstance(o, "scipy.sparse.lil", "lil_matrix"):
455+
from scipy.sparse import vstack
456+
out = vstack(inner, format='lil')
457+
return out
430458
else:
431459
raise ValueError(f"Cannot handle type {type(o)}.") # pragma: no cover
432460

@@ -519,6 +547,10 @@ class UnifiedDataHandler:
519547
("builtins", "dict"): DictHandler,
520548
("torch", "Tensor"): ArrayHandler,
521549
("numpy", "ndarray"): ArrayHandler,
550+
("scipy.sparse.csc", "csc_matrix"): ArrayHandler,
551+
("scipy.sparse.csr", "csr_matrix"): ArrayHandler,
552+
("scipy.sparse.dok", "dok_matrix"): ArrayHandler,
553+
("scipy.sparse.lil", "lil_matrix"): ArrayHandler,
522554
("pandas.core.frame", "DataFrame"): DataFrameHandler,
523555
("pandas.core.series", "Series"): SeriesHandler,
524556
}

slicer/test_slicer.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,23 @@
22
An unholy balance of use cases and test coverage.
33
"""
44

5+
import pytest
6+
57
from .slicer import AtomicSlicer
68

79
from . import Slicer as S
810
from . import Alias as A
911
from . import Obj as O
1012

11-
import pytest
1213
import pandas as pd
13-
import torch
1414
import numpy as np
15+
import torch
16+
from scipy.sparse import csc_matrix
17+
from scipy.sparse import csr_matrix
18+
from scipy.sparse import dok_matrix
19+
from scipy.sparse import lil_matrix
20+
21+
1522
from .utils_testing import ctr_eq
1623

1724

@@ -214,6 +221,38 @@ def test_slicer_simple_di():
214221
assert ctr_eq(actual, 3)
215222

216223

224+
def test_slicer_sparse():
225+
array = np.array([[1, 0, 4], [0, 0, 5], [2, 3, 6]])
226+
csc_array = csc_matrix(array)
227+
csr_array = csr_matrix(array)
228+
dok_array = dok_matrix(array)
229+
lil_array = lil_matrix(array)
230+
231+
candidates = [csc_array, csr_array, dok_array, lil_array]
232+
for candidate in candidates:
233+
slicer = S(candidate)
234+
actual = slicer[0, 0]
235+
assert ctr_eq(actual.o, 1)
236+
actual = slicer[1, 1]
237+
assert ctr_eq(actual.o, 0)
238+
239+
actual = slicer[0]
240+
expected = np.array([1, 0, 4])
241+
assert ctr_eq(actual.o, expected)
242+
243+
actual = slicer[:, 1]
244+
expected = np.array([0, 0, 3])
245+
assert ctr_eq(actual.o, expected)
246+
247+
actual = slicer[:, :]
248+
expected = np.array([[1, 0, 4], [0, 0, 5], [2, 3, 6]])
249+
assert ctr_eq(actual.o, expected)
250+
251+
actual = slicer[0, :]
252+
expected = np.array([1, 0, 4])
253+
assert ctr_eq(actual.o, expected)
254+
255+
217256
def test_slicer_torch():
218257
import torch
219258

@@ -277,7 +316,7 @@ def test_tracked_dim_arg_smoke():
277316
assert True
278317

279318

280-
def test_atomic_1d():
319+
def test_operations_1d():
281320
elements = [1, 2, 3, 4]
282321
li = elements
283322
tup = tuple(elements)
@@ -305,12 +344,17 @@ def test_atomic_1d():
305344
assert ctr_eq(slicer[0:3:2], elements[0:3:2])
306345

307346

308-
def test_atomic_2d():
347+
def test_operations_2d():
309348
elements = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
310349
li = elements
311350
df = pd.DataFrame(elements, columns=["A", "B", "C"])
312351

313-
containers = [li, df]
352+
sparse_csc = csc_matrix(elements)
353+
sparse_csr = csr_matrix(elements)
354+
sparse_dok = dok_matrix(elements)
355+
sparse_lil = lil_matrix(elements)
356+
357+
containers = [li, df, sparse_csc, sparse_csr, sparse_dok, sparse_lil]
314358
for _, ctr in enumerate(containers):
315359
slicer = AtomicSlicer(ctr)
316360

@@ -337,7 +381,7 @@ def test_atomic_2d():
337381
assert ctr_eq(slicer[..., 0], [elements[i][0] for i, _ in enumerate(elements)])
338382

339383

340-
def test_atomic_3d():
384+
def test_operations_3d():
341385
# 3-dimensional fixed dimension case
342386
elements = [
343387
[[1, 2, 3], [4, 5, 6]],

slicer/utils_testing.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,16 @@
77
import numpy as np
88
import torch
99
import pandas as pd
10+
from scipy.sparse import csc_matrix
11+
from scipy.sparse import csr_matrix
12+
from scipy.sparse import dok_matrix
13+
from scipy.sparse import lil_matrix
1014

1115

1216
def coerced(o: Any):
17+
if isinstance(o, (csc_matrix, csr_matrix, dok_matrix, lil_matrix)):
18+
o = o.toarray()
19+
1320
to_list_collections = tuple([np.ndarray, torch.Tensor, pd.core.series.Series])
1421
if isinstance(o, (list, tuple)):
1522
return o

0 commit comments

Comments
 (0)