Skip to content

Commit 4fa23f4

Browse files
committed
Row subsetting
1 parent 4b64155 commit 4fa23f4

File tree

3 files changed

+302
-123
lines changed

3 files changed

+302
-123
lines changed

src/dartsort/cluster/gaussian_mixture.py

Lines changed: 9 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
from tqdm.auto import tqdm, trange
1616

1717
from ..util import more_operators, noise_util, spiketorch
18+
from ..util.sparse_util import (
19+
csc_insert,
20+
get_csc_storage,
21+
coo_to_torch,
22+
coo_to_scipy,
23+
csc_sparse_mask_rows,
24+
coo_sparse_mask_rows,
25+
)
1826
from .cluster_util import agglomerate, combine_distances, leafsets
1927
from .kmeans import kmeans
2028
from .modes import smoothed_dipscore_at
@@ -736,9 +744,7 @@ def cleanup(
736744
if isinstance(log_liks, coo_array):
737745
log_liks = coo_sparse_mask_rows(log_liks, keep_ll)
738746
elif isinstance(log_liks, csc_array):
739-
keep_ll = np.flatnonzero(keep_ll)
740-
assert keep_ll.max() <= log_liks.shape[0] - self.with_noise_unit
741-
log_liks = log_liks[keep_ll]
747+
log_liks = csc_sparse_mask_rows(log_liks, keep_ll, in_place=True)
742748
else:
743749
assert False
744750

@@ -2689,72 +2695,6 @@ def get_nans(target, storage, name, shape):
26892695
return buffer
26902696

26912697

2692-
def get_coo_storage(ns_total, storage, use_storage):
2693-
if not use_storage:
2694-
# coo_uix = np.empty(ns_total, dtype=int)
2695-
coo_six = np.empty(ns_total, dtype=int)
2696-
coo_data = np.empty(ns_total, dtype=np.float32)
2697-
return coo_six, coo_data
2698-
2699-
if hasattr(storage, "coo_data"):
2700-
if storage.coo_data.size < ns_total:
2701-
# del storage.coo_uix
2702-
del storage.coo_six
2703-
del storage.coo_data
2704-
# storage.coo_uix = np.empty(ns_total, dtype=int)
2705-
storage.coo_six = np.empty(ns_total, dtype=int)
2706-
storage.coo_data = np.empty(ns_total, dtype=np.float32)
2707-
else:
2708-
# storage.coo_uix = np.empty(ns_total, dtype=int)
2709-
storage.coo_six = np.empty(ns_total, dtype=int)
2710-
storage.coo_data = np.empty(ns_total, dtype=np.float32)
2711-
2712-
# return storage.coo_uix, storage.coo_six, storage.coo_data
2713-
return storage.coo_six, storage.coo_data
2714-
2715-
2716-
def get_csc_storage(ns_total, storage, use_storage):
2717-
if not use_storage:
2718-
csc_row_indices = np.empty(ns_total, dtype=int)
2719-
csc_data = np.empty(ns_total, dtype=np.float32)
2720-
return csc_row_indices, csc_data
2721-
2722-
if hasattr(storage, "csc_data"):
2723-
if storage.csc_data.size < ns_total:
2724-
del storage.csc_row_indices
2725-
del storage.csc_data
2726-
storage.csc_row_indices = np.empty(ns_total, dtype=int)
2727-
storage.csc_data = np.empty(ns_total, dtype=np.float32)
2728-
else:
2729-
storage.csc_row_indices = np.empty(ns_total, dtype=int)
2730-
storage.csc_data = np.empty(ns_total, dtype=np.float32)
2731-
2732-
return storage.csc_row_indices, storage.csc_data
2733-
2734-
2735-
def coo_to_torch(coo_array, dtype, transpose=False, is_coalesced=True, copy_data=False):
2736-
coo = (
2737-
torch.from_numpy(coo_array.coords[int(transpose)]),
2738-
torch.from_numpy(coo_array.coords[1 - int(transpose)]),
2739-
)
2740-
s0, s1 = coo_array.shape
2741-
if transpose:
2742-
s0, s1 = s1, s0
2743-
res = torch.sparse_coo_tensor(
2744-
torch.row_stack(coo),
2745-
torch.asarray(coo_array.data, dtype=torch.float, copy=copy_data),
2746-
size=(s0, s1),
2747-
is_coalesced=is_coalesced,
2748-
)
2749-
return res
2750-
2751-
2752-
def coo_to_scipy(coo_tensor):
2753-
data = coo_tensor.values().numpy(force=True)
2754-
coords = coo_tensor.indices().numpy(force=True)
2755-
return coo_array((data, coords), shape=coo_tensor.shape)
2756-
2757-
27582698
def marginal_loglik(
27592699
indices, log_proportions, log_likelihoods, unit_ids=None, reduce="mean"
27602700
):
@@ -2897,60 +2837,6 @@ def hot_argmax_loop(
28972837
scores[i] = mx + np.log(np.exp(dx - mx).sum())
28982838

28992839

2900-
sig = "void(i8, i8[::1], i8[::1], i8[::1], f4[::1], f4[::1])"
2901-
2902-
2903-
@numba.njit(sig, error_model="numpy", nogil=True, parallel=True)
2904-
def csc_insert(row, write_offsets, inds, csc_indices, csc_data, liks):
2905-
"""Insert elements into a CSC sparse array
2906-
2907-
To use this, you need to know the indptr in advance. Then this function
2908-
can help you to insert a row into the array. You have to insert all nz
2909-
elements for that row at once in a single call to this function, and
2910-
rows must be written in order.
2911-
2912-
(However, the columns within each row can be unordered.)
2913-
2914-
write_offsets should be initialized at the indptr -- so, they are
2915-
each column's "write head", indicating how many rows have been written
2916-
for that column so far.
2917-
2918-
Then, this updates the row indices (csc_indices) array with the correct
2919-
row for each column (inds), and adds the data in the right place. The
2920-
"write heads" are incremented, so that when this fn is called for the next
2921-
row things are in the right place.
2922-
2923-
This would be equivalent to:
2924-
data_ixs = write_offsets[inds]
2925-
csc_indices[data_ixs] = row
2926-
csc_data[data_ixs] = liks
2927-
write_offsets[inds] += 1
2928-
"""
2929-
for j in numba.prange(inds.shape[0]):
2930-
ind = inds[j]
2931-
data_ix = write_offsets[ind]
2932-
csc_indices[data_ix] = row
2933-
csc_data[data_ix] = liks[j]
2934-
write_offsets[ind] += 1
2935-
2936-
2937-
def coo_sparse_mask_rows(coo, keep_mask):
2938-
"""Row indexing with a boolean mask."""
2939-
if keep_mask.all():
2940-
return coo
2941-
2942-
kept_label_indices = np.flatnonzero(keep_mask)
2943-
ii, jj = coo.coords
2944-
ixs = np.searchsorted(kept_label_indices, ii)
2945-
ixs.clip(0, kept_label_indices.size - 1, out=ixs)
2946-
valid = np.flatnonzero(kept_label_indices[ixs] == ii)
2947-
coo = coo_array(
2948-
(coo.data[valid], (ixs[valid], jj[valid])),
2949-
shape=(kept_label_indices.size, coo.shape[1]),
2950-
)
2951-
return coo
2952-
2953-
29542840
def bimodalities_dense(
29552841
log_liks,
29562842
labels,

src/dartsort/util/sparse_util.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import numpy as np
2+
import torch
3+
import numba
4+
from scipy.sparse import coo_array, csc_array
5+
6+
7+
def get_coo_storage(ns_total, storage, use_storage):
8+
if not use_storage:
9+
# coo_uix = np.empty(ns_total, dtype=int)
10+
coo_six = np.empty(ns_total, dtype=int)
11+
coo_data = np.empty(ns_total, dtype=np.float32)
12+
return coo_six, coo_data
13+
14+
if hasattr(storage, "coo_data"):
15+
if storage.coo_data.size < ns_total:
16+
# del storage.coo_uix
17+
del storage.coo_six
18+
del storage.coo_data
19+
# storage.coo_uix = np.empty(ns_total, dtype=int)
20+
storage.coo_six = np.empty(ns_total, dtype=int)
21+
storage.coo_data = np.empty(ns_total, dtype=np.float32)
22+
else:
23+
# storage.coo_uix = np.empty(ns_total, dtype=int)
24+
storage.coo_six = np.empty(ns_total, dtype=int)
25+
storage.coo_data = np.empty(ns_total, dtype=np.float32)
26+
27+
# return storage.coo_uix, storage.coo_six, storage.coo_data
28+
return storage.coo_six, storage.coo_data
29+
30+
31+
def coo_to_torch(coo_array, dtype, transpose=False, is_coalesced=True, copy_data=False):
32+
coo = (
33+
torch.from_numpy(coo_array.coords[int(transpose)]),
34+
torch.from_numpy(coo_array.coords[1 - int(transpose)]),
35+
)
36+
s0, s1 = coo_array.shape
37+
if transpose:
38+
s0, s1 = s1, s0
39+
res = torch.sparse_coo_tensor(
40+
torch.row_stack(coo),
41+
torch.asarray(coo_array.data, dtype=torch.float, copy=copy_data),
42+
size=(s0, s1),
43+
is_coalesced=is_coalesced,
44+
)
45+
return res
46+
47+
48+
def coo_to_scipy(coo_tensor):
49+
data = coo_tensor.values().numpy(force=True)
50+
coords = coo_tensor.indices().numpy(force=True)
51+
return coo_array((data, coords), shape=coo_tensor.shape)
52+
53+
54+
def get_csc_storage(ns_total, storage, use_storage):
55+
if not use_storage:
56+
csc_row_indices = np.empty(ns_total, dtype=int)
57+
csc_data = np.empty(ns_total, dtype=np.float32)
58+
return csc_row_indices, csc_data
59+
60+
if hasattr(storage, "csc_data"):
61+
if storage.csc_data.size < ns_total:
62+
del storage.csc_row_indices
63+
del storage.csc_data
64+
storage.csc_row_indices = np.empty(ns_total, dtype=int)
65+
storage.csc_data = np.empty(ns_total, dtype=np.float32)
66+
else:
67+
storage.csc_row_indices = np.empty(ns_total, dtype=int)
68+
storage.csc_data = np.empty(ns_total, dtype=np.float32)
69+
70+
return storage.csc_row_indices, storage.csc_data
71+
72+
73+
sig = "void(i8, i8[::1], i8[::1], i8[::1], f4[::1], f4[::1])"
74+
75+
76+
@numba.njit(sig, error_model="numpy", nogil=True, parallel=True)
77+
def csc_insert(row, write_offsets, inds, csc_indices, csc_data, liks):
78+
"""Insert elements into a CSC sparse array
79+
80+
To use this, you need to know the indptr in advance. Then this function
81+
can help you to insert a row into the array. You have to insert all nz
82+
elements for that row at once in a single call to this function, and
83+
rows must be written in order.
84+
85+
(However, the columns within each row can be unordered.)
86+
87+
write_offsets should be initialized at the indptr -- so, they are
88+
each column's "write head", indicating how many rows have been written
89+
for that column so far.
90+
91+
Then, this updates the row indices (csc_indices) array with the correct
92+
row for each column (inds), and adds the data in the right place. The
93+
"write heads" are incremented, so that when this fn is called for the next
94+
row things are in the right place.
95+
96+
This would be equivalent to:
97+
data_ixs = write_offsets[inds]
98+
csc_indices[data_ixs] = row
99+
csc_data[data_ixs] = liks
100+
write_offsets[inds] += 1
101+
"""
102+
for j in numba.prange(inds.shape[0]):
103+
ind = inds[j]
104+
data_ix = write_offsets[ind]
105+
csc_indices[data_ix] = row
106+
csc_data[data_ix] = liks[j]
107+
write_offsets[ind] += 1
108+
109+
110+
def coo_sparse_mask_rows(coo, keep_mask):
111+
"""Row indexing with a boolean mask."""
112+
if keep_mask.all():
113+
return coo
114+
115+
kept_label_indices = np.flatnonzero(keep_mask)
116+
ii, jj = coo.coords
117+
ixs = np.searchsorted(kept_label_indices, ii)
118+
ixs.clip(0, kept_label_indices.size - 1, out=ixs)
119+
valid = np.flatnonzero(kept_label_indices[ixs] == ii)
120+
coo = coo_array(
121+
(coo.data[valid], (ixs[valid], jj[valid])),
122+
shape=(kept_label_indices.size, coo.shape[1]),
123+
)
124+
return coo
125+
126+
127+
def csc_sparse_mask_rows(csc, keep_mask, in_place=False):
128+
if keep_mask.all():
129+
return csc
130+
131+
if not in_place:
132+
csc = csc.copy()
133+
134+
rowix_dtype = csc.indices.dtype
135+
kept_row_inds = np.flatnonzero(keep_mask).astype(rowix_dtype)
136+
oldrow_to_newrow = np.zeros(len(keep_mask), dtype=rowix_dtype)
137+
oldrow_to_newrow[kept_row_inds] = np.arange(len(kept_row_inds))
138+
nnz = _csc_sparse_mask_rows(
139+
csc.indices, csc.indptr, csc.data, oldrow_to_newrow, keep_mask
140+
)
141+
142+
return csc_array(
143+
(csc.data[:nnz], csc.indices[:nnz], csc.indptr),
144+
shape=(len(kept_row_inds), csc.shape[1]),
145+
)
146+
147+
148+
sigs = [
149+
"i8(i8[::1], i8[::1], f4[::1], i8[::1], bool_[::1])",
150+
"i8(i4[::1], i4[::1], f4[::1], i4[::1], bool_[::1])",
151+
]
152+
153+
154+
@numba.njit(sigs, error_model="numpy", nogil=True)
155+
def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask):
156+
write_ix = 0
157+
158+
column = 0
159+
column_kept_count = 0
160+
column_end = indptr[1]
161+
162+
for read_ix in range(len(indices)):
163+
row = indices[read_ix]
164+
if not keep_mask[row]:
165+
continue
166+
167+
# write data for this sample
168+
indices[write_ix] = oldrow_to_newrow[row]
169+
data[write_ix] = data[read_ix]
170+
write_ix += 1
171+
172+
while read_ix >= column_end:
173+
indptr[column + 1] = write_ix - 1
174+
column += 1
175+
column_end = indptr[column + 1]
176+
177+
while column < len(indptr) - 1:
178+
indptr[column + 1] = write_ix
179+
column += 1
180+
column_end = indptr[column + 1]
181+
182+
return write_ix
183+
184+
185+
# @numba.njit(sigs, error_model="numpy", nogil=True)
186+
# def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask):
187+
# write_ix = 0
188+
189+
# column_start = indptr[0]
190+
# for column in range(len(indptr) - 1):
191+
# column_kept_count = 0
192+
# column_end = indptr[column + 1]
193+
194+
# for read_ix in range(column_start, column_end):
195+
# row = indices[read_ix]
196+
# if not keep_mask[row]:
197+
# continue
198+
199+
# indices[write_ix] = oldrow_to_newrow[row]
200+
# data[write_ix] = data[read_ix]
201+
# column_kept_count += 1
202+
# write_ix += 1
203+
204+
# # indptr[column] is not column_start.
205+
# indptr[column + 1] = indptr[column] + column_kept_count
206+
# column_start = column_end
207+
208+
# return write_ix

0 commit comments

Comments
 (0)