Skip to content

Commit 2bcc511

Browse files
committed
Row with known nnz
1 parent 81cc9f1 commit 2bcc511

File tree

3 files changed

+37
-24
lines changed

3 files changed

+37
-24
lines changed

src/dartsort/cluster/gaussian_mixture.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -541,19 +541,21 @@ def log_likelihoods(
541541
unit_neighb_info.append((j, neighbs, ns_unit))
542542
else:
543543
assert previous_logliks is not None
544-
if hasattr(previous_logliks, 'row_nnz'):
544+
if hasattr(previous_logliks, "row_nnz"):
545545
rnnz = previous_logliks.row_nnz[j]
546-
row = csc_sparse_getrow(previous_logliks, j, rnnz).tocoo(copy=False)
546+
six, data = csc_sparse_getrow(previous_logliks, j, rnnz)
547547
else:
548548
row = previous_logliks[[j]].tocoo(copy=True)
549-
six = row.coords[1]
550-
ns_unit = row.nnz
549+
six = row.coords[1]
550+
ns_unit = row.nnz
551+
data = row.data
552+
ns_unit = len(six)
551553
if "covered_neighbs" in unit.annotations:
552554
covered_neighbs = unit.annotations["covered_neighbs"]
553555
else:
554556
covered_neighbs = full_core_neighborhoods.neighborhood_ids[six]
555557
covered_neighbs = torch.unique(covered_neighbs)
556-
unit_neighb_info.append((j, six, row.data, ns_unit))
558+
unit_neighb_info.append((j, six, data, ns_unit))
557559
core_overlaps[covered_neighbs] += 1
558560
nnz += ns_unit
559561

src/dartsort/util/sparse_util.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -184,27 +184,23 @@ def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask):
184184

185185
def csc_sparse_getrow(csc, row, rowcount):
186186
rowix_dtype = csc.indices.dtype
187-
indptr_out = np.empty(rowcount + 1, dtype=rowix_dtype)
188-
data_out = np.empty(rowcount, dtype=rowix_dtype)
189-
indices_out = np.full(rowcount, row, dtype=rowix_dtype)
187+
columns_out = np.empty(rowcount, dtype=rowix_dtype)
188+
data_out = np.empty(rowcount, dtype=csc.data.dtype)
190189
_csc_sparse_getrow(
191-
csc.indices, csc.indptr, csc.data, indptr_out, data_out, rowix_dtype.type(row)
190+
csc.indices, csc.indptr, csc.data, columns_out, data_out, rowix_dtype.type(row)
192191
)
193192

194-
return csc_array(
195-
(data_out, indices_out, indptr_out),
196-
shape=(len(kept_row_inds), csc.shape[1]),
197-
)
193+
return columns_out, data_out
198194

199195

200196
sigs = [
201-
"void(i8[::1], i8[::1], f4[::1], i8[::1], i8[::1], i8)",
202-
"void(i4[::1], i4[::1], f4[::1], i4[::1], i4[::1], i4)",
197+
"void(i8[::1], i8[::1], f4[::1], i8[::1], f4[::1], i8)",
198+
"void(i4[::1], i4[::1], f4[::1], i4[::1], f4[::1], i4)",
203199
]
204200

205201

206202
@numba.njit(sigs, error_model="numpy", nogil=True)
207-
def _csc_sparse_getrow(indices, indptr, data, indptr_out, data_out, the_row):
203+
def _csc_sparse_getrow(indices, indptr, data, columns_out, data_out, the_row):
208204
write_ix = 0
209205

210206
column = 0
@@ -217,17 +213,11 @@ def _csc_sparse_getrow(indices, indptr, data, indptr_out, data_out, the_row):
217213

218214
# write data for this sample
219215
data_out[write_ix] = data[read_ix]
220-
write_ix += 1
221-
222216
while read_ix >= column_end:
223-
indptr_out[column + 1] = write_ix - 1
224217
column += 1
225218
column_end = indptr[column + 1]
226-
227-
while column < len(indptr) - 1:
228-
indptr_out[column + 1] = write_ix
229-
column += 1
230-
column_end = indptr[column + 1]
219+
columns_out[write_ix] = column
220+
write_ix += 1
231221

232222

233223
# @numba.njit(sigs, error_model="numpy", nogil=True)

tests/test_sparse.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,27 @@ def test_csc_mask():
8080
assert np.array_equal(x1.data, x0.data)
8181

8282

83+
def test_csc_getrow():
84+
rg = np.random.default_rng(10)
85+
ij = rg.integers(low=((0, 0),), high=(shape,), size=(nnz, 2))
86+
ij = np.unique(ij, axis=0)
87+
assert (np.diff(ij[:, 0]) >= 0).all()
88+
assert not (np.diff(ij[:, 1]) >= 0).all()
89+
vals = rg.normal(size=len(ij)).astype(np.float32)
90+
91+
x = coo_array((vals, ij.T), shape).tocsc()
92+
93+
for row in range(x.shape[0]):
94+
x0 = x[[row]]
95+
columns, data = sparse_util.csc_sparse_getrow(x, row, x0.nnz)
96+
97+
assert len(columns) == len(data) == x0.nnz
98+
x0coo = x0.tocoo()
99+
assert np.array_equal(columns, x0coo.coords[1])
100+
assert np.array_equal(data, x0coo.data)
101+
102+
83103
if __name__ == "__main__":
104+
test_csc_getrow()
84105
test_csc_insert()
85106
test_csc_mask()

0 commit comments

Comments
 (0)