Commit

WIP row
cwindolf committed Jan 14, 2025
1 parent 1e481f3 commit 81cc9f1
Showing 2 changed files with 92 additions and 12 deletions.
56 changes: 44 additions & 12 deletions src/dartsort/cluster/gaussian_mixture.py
@@ -22,6 +22,7 @@
coo_to_scipy,
csc_sparse_mask_rows,
coo_sparse_mask_rows,
csc_sparse_getrow,
)
from .cluster_util import agglomerate, combine_distances, leafsets
from .kmeans import kmeans
@@ -68,6 +69,7 @@ def __init__(
ppca_rank: int = 0,
ppca_inner_em_iter: int = 25,
ppca_atol: float = 0.05,
ppca_warm_start: bool = True,
n_threads: int = 4,
min_count: int = 50,
n_em_iters: int = 25,
@@ -174,6 +176,7 @@ def __init__(
ppca_rank=ppca_rank,
ppca_inner_em_iter=ppca_inner_em_iter,
ppca_atol=ppca_atol,
ppca_warm_start=ppca_warm_start,
)
if ppca_in_split:
self.split_unit_args = self.unit_args
@@ -538,7 +541,11 @@ def log_likelihoods(
unit_neighb_info.append((j, neighbs, ns_unit))
else:
assert previous_logliks is not None
- row = previous_logliks[[j]].tocoo(copy=True)
+ if hasattr(previous_logliks, "row_nnz"):
+     rnnz = previous_logliks.row_nnz[j]
+     row = csc_sparse_getrow(previous_logliks, j, rnnz).tocoo(copy=False)
+ else:
+     row = previous_logliks[[j]].tocoo(copy=True)
six = row.coords[1]
ns_unit = row.nnz
if "covered_neighbs" in unit.annotations:
@@ -594,6 +601,8 @@ def _ll_job(args):
write_offsets = indptr[:-1].copy()
pool = Parallel(self.n_threads, backend="threading", return_as="generator")
results = pool(_ll_job(ninfo) for ninfo in unit_neighb_info)
nrows = j + 1 + with_noise_unit
row_nnz = np.zeros(nrows, dtype=int)
if show_progress:
results = tqdm(
results,
@@ -606,18 +615,21 @@ def _ll_job(args):
for j, inds, liks in results:
if inds is None:
continue
row_nnz[j] = len(inds)
csc_insert(j, write_offsets, inds, csc_indices, csc_data, liks)

if with_noise_unit:
liks = self.noise_log_likelihoods(indices=split_indices)
data_ixs = write_offsets[split_indices]
# assert np.array_equal(data_ixs, ccol_indices[1:] - 1) # just fyi
row_nnz[j + 1] = len(data_ixs)
csc_indices[data_ixs] = j + 1
csc_data[data_ixs] = liks

- shape = (j + 1 + with_noise_unit, self.data.n_spikes)
+ shape = (nrows, self.data.n_spikes)
log_liks = csc_array((csc_data, csc_indices, indptr), shape=shape)
log_liks.has_canonical_format = True
log_liks.row_nnz = row_nnz

return log_liks
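A note on the new row_nnz bookkeeping above: attaching the per-row nonzero counts to the returned csc_array lets a later call size a unit's row buffers exactly, instead of paying for scipy's generic fancy-indexed slice. A minimal sketch on a toy matrix (shapes and values hypothetical; csc_sparse_getrow is the helper added to sparse_util.py below):

import numpy as np
from scipy.sparse import csc_array

# toy log-likelihood matrix: 3 units x 5 spikes
rows = np.array([0, 0, 1, 2, 2, 2])
cols = np.array([0, 3, 1, 0, 2, 4])
vals = np.ones(6, dtype=np.float32)
log_liks = csc_array((vals, (rows, cols)), shape=(3, 5))

# what log_likelihoods() now attaches alongside the matrix
log_liks.row_nnz = np.array([2, 1, 3])

# a later pass can then allocate exact-size buffers for unit j's row
j = 2
assert log_liks.row_nnz[j] == log_liks[[j]].nnz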

@@ -708,7 +720,8 @@ def cleanup(
label_ids = label_ids[label_ids >= 0]
big_enough = counts >= min_count

- keep = torch.zeros(self.n_units(), dtype=bool)
+ n_units = max(label_ids.max().item() + 1, len(self._units))
+ keep = torch.zeros(n_units, dtype=bool)
keep[label_ids] = big_enough
self._stack = None

@@ -726,7 +739,7 @@ def cleanup(
kept_ids = label_ids[big_enough]
new_ids = torch.arange(kept_ids.numel())
old2new = dict(zip(kept_ids, new_ids))
- self._relabel(kept_ids)
+ self._relabel(kept_ids, split=split)

if self.log_proportions is not None:
lps = self.log_proportions.numpy(force=True)
@@ -1988,7 +2001,7 @@ def train_extract_buffer(self):
)
return self.storage._train_extract_buffer

- def _relabel(self, old_labels, new_labels=None, flat=False):
+ def _relabel(self, old_labels, new_labels=None, flat=False, split=None):
"""Re-label units
!! This could invalidate self._units and props.
@@ -2008,18 +2021,33 @@ def _relabel(self, old_labels, new_labels=None, flat=False):
label in old_labels and index new_labels with that, so that
cleanup can call relabel with unit_ids[big_enough].
"""
split_indices = slice(None)
if split is not None:
split_indices = self.data.split_indices[split]

if new_labels is None:
new_labels = torch.arange(len(old_labels))

original = self.labels[split_indices]

if flat:
- kept = self.labels >= 0
- label_indices = self.labels[kept]
+ kept = original >= 0
+ label_indices = original[kept]
else:
- kept = torch.isin(self.labels, old_labels)
- label_indices = torch.searchsorted(old_labels, self.labels[kept])
+ label_indices = torch.searchsorted(old_labels, original, right=True) - 1
+ kept = old_labels[label_indices] == original
+ label_indices = label_indices[kept]

- self.labels[kept] = new_labels.to(self.labels)[label_indices]
- self.labels[torch.logical_not(kept)] = -1
+ unkept = torch.logical_not(kept)
+ if split is not None:
+     unkept = split_indices[unkept]
+     kept = split_indices[kept]
+
+ label_indices = new_labels.to(self.labels)[label_indices]
+
+ self.labels[kept] = label_indices
+ self.labels[unkept] = -1
self._stack = None
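For intuition, a tiny worked example of the searchsorted-based relabeling above, in the flat=False, split=None case (the labels here are made up):

import torch

old_labels = torch.tensor([2, 5, 9])     # sorted ids of kept units
new_labels = torch.arange(3)             # their compacted ids 0, 1, 2
labels = torch.tensor([5, 9, 3, 2, -1])  # per-spike labels; unit 3 was dropped

ix = torch.searchsorted(old_labels, labels, right=True) - 1
# ix = [1, 2, 0, 0, -1]; the -1 wraps around when used as an index below,
# but its equality check fails, so it lands in the unkept bucket
kept = old_labels[ix] == labels          # [True, True, False, True, False]
label_indices = ix[kept]

labels[kept] = new_labels[label_indices]
labels[torch.logical_not(kept)] = -1
# labels is now [1, 2, -1, 0, -1]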

def stack_units(
@@ -2253,6 +2281,7 @@ def __init__(
ppca_atol=0.05,
ppca_rank=0,
scale_mean: float = 0.1,
ppca_warm_start: bool = True,
**annotations,
):
super().__init__()
@@ -2273,6 +2302,7 @@ def __init__(
self.ppca_inner_em_iter = ppca_inner_em_iter
self.ppca_atol = ppca_atol
self.annotations = annotations
self.ppca_warm_start = ppca_warm_start

@classmethod
def from_features(
@@ -2294,6 +2324,7 @@ def from_features(
scale_mean: float = 0.1,
core_neighborhoods=None,
core_neighborhood_ids=None,
ppca_warm_start: bool = True,
**annotations,
):
self = cls(
@@ -2311,6 +2342,7 @@ def from_features(
ppca_rank=ppca_rank,
ppca_inner_em_iter=ppca_inner_em_iter,
ppca_atol=ppca_atol,
ppca_warm_start=ppca_warm_start,
**annotations,
)
self.fit(
@@ -2396,7 +2428,7 @@ def fit(
active_mean = active_W = None
if hasattr(self, "mean"):
active_mean = self.mean[:, achans]
- if hasattr(self, "W"):
+ if hasattr(self, "W") and self.ppca_warm_start:
active_W = self.W[:, achans]
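Why make the warm start optional: reusing the previous loadings W typically lets the inner EM converge in far fewer than ppca_inner_em_iter steps, but a stale W can also trap a refit near the old solution. A self-contained toy sketch of the trade-off, not dartsort's actual PPCA solver, with the noise variance fixed at 1 for brevity:

import torch

def ppca_em_step(X, W):
    # one EM step for the model x ~ N(0, W W^T + I)
    M = W.T @ W + torch.eye(W.shape[1])
    Ez = torch.linalg.solve(M, W.T @ X.T).T           # E[z | x]
    S = Ez.T @ Ez + X.shape[0] * torch.linalg.inv(M)  # sum of E[z z^T]
    return (X.T @ Ez) @ torch.linalg.inv(S)

torch.manual_seed(0)
X = torch.randn(500, 10) @ torch.randn(10, 10)
W = torch.randn(10, 2)                   # cold start (ppca_warm_start=False)
for it in range(25):                     # cf. ppca_inner_em_iter
    W_prev, W = W, ppca_em_step(X, W)
    if (W - W_prev).abs().max() < 0.05:  # cf. ppca_atol
        break                            # a warm-started W breaks much sooner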

if je_suis:
48 changes: 48 additions & 0 deletions src/dartsort/util/sparse_util.py
@@ -182,6 +182,54 @@ def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask):
return write_ix


def csc_sparse_getrow(csc, row, rowcount):
    """Extract row `row` of a csc_array as a new csc_array of the same shape.

    rowcount must equal the number of nonzeros in that row (e.g. from the
    row_nnz vector cached by log_likelihoods), so the output buffers can be
    allocated exactly before the numba kernel runs.
    """
    rowix_dtype = csc.indices.dtype
    # one pointer per column boundary, not per nonzero
    indptr_out = np.empty(csc.shape[1] + 1, dtype=rowix_dtype)
    # data holds log-likelihoods: use the data dtype, not the index dtype
    data_out = np.empty(rowcount, dtype=csc.data.dtype)
    indices_out = np.full(rowcount, row, dtype=rowix_dtype)
    _csc_sparse_getrow(
        csc.indices, csc.indptr, csc.data, indptr_out, data_out, rowix_dtype.type(row)
    )

    # keep the original shape so the extracted row keeps its row index
    return csc_array((data_out, indices_out, indptr_out), shape=csc.shape)


sigs = [
    "void(i8[::1], i8[::1], f4[::1], i8[::1], f4[::1], i8)",
    "void(i4[::1], i4[::1], f4[::1], i4[::1], f4[::1], i4)",
]


@numba.njit(sigs, error_model="numpy", nogil=True)
def _csc_sparse_getrow(indices, indptr, data, indptr_out, data_out, the_row):
    write_ix = 0
    indptr_out[0] = 0  # buffers come from np.empty; the first pointer is always 0

    column = 0
    column_end = indptr[1]

    for read_ix in range(len(indices)):
        row = indices[read_ix]
        if row != the_row:
            continue

        # write data for this sample
        data_out[write_ix] = data[read_ix]
        write_ix += 1

        # close out any columns which ended before this element; their
        # cumulative counts exclude the element just written
        while read_ix >= column_end:
            indptr_out[column + 1] = write_ix - 1
            column += 1
            column_end = indptr[column + 1]

    # remaining columns hold no further entries of the_row
    while column < len(indptr) - 1:
        indptr_out[column + 1] = write_ix
        column += 1
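A quick sanity check of the kernel against scipy's generic row slicing on a random matrix (a hypothetical test, assuming the wrapper above, which returns an array of the same shape with only the requested row populated):

import numpy as np
import scipy.sparse
from scipy.sparse import csc_array

m = csc_array(scipy.sparse.random(
    20, 50, density=0.1, format="csc", dtype=np.float32, random_state=0
))
for row in range(20):
    rowcount = int((m.indices == row).sum())
    got = csc_sparse_getrow(m, row, rowcount)
    assert np.array_equal(got[[row]].toarray(), m[[row]].toarray())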


# @numba.njit(sigs, error_model="numpy", nogil=True)
# def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask):
# write_ix = 0
