diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py
index 72bed64c..3ffce68d 100644
--- a/src/dartsort/cluster/gaussian_mixture.py
+++ b/src/dartsort/cluster/gaussian_mixture.py
@@ -22,6 +22,7 @@
     coo_to_scipy,
     csc_sparse_mask_rows,
     coo_sparse_mask_rows,
+    csc_sparse_getrow,
 )
 from .cluster_util import agglomerate, combine_distances, leafsets
 from .kmeans import kmeans
@@ -68,6 +69,7 @@ def __init__(
         ppca_rank: int = 0,
         ppca_inner_em_iter: int = 25,
         ppca_atol: float = 0.05,
+        ppca_warm_start: bool = True,
         n_threads: int = 4,
         min_count: int = 50,
         n_em_iters: int = 25,
@@ -174,6 +176,7 @@
             ppca_rank=ppca_rank,
             ppca_inner_em_iter=ppca_inner_em_iter,
             ppca_atol=ppca_atol,
+            ppca_warm_start=ppca_warm_start,
         )
         if ppca_in_split:
             self.split_unit_args = self.unit_args
@@ -538,7 +541,11 @@ def log_likelihoods(
                 unit_neighb_info.append((j, neighbs, ns_unit))
             else:
                 assert previous_logliks is not None
-                row = previous_logliks[[j]].tocoo(copy=True)
+                if hasattr(previous_logliks, "row_nnz"):
+                    # fast path: cached per-row counts let us extract row j directly
+                    rnnz = previous_logliks.row_nnz[j]
+                    row = csc_sparse_getrow(previous_logliks, j, rnnz).tocoo(copy=False)
+                else:
+                    row = previous_logliks[[j]].tocoo(copy=True)
                 six = row.coords[1]
                 ns_unit = row.nnz
                 if "covered_neighbs" in unit.annotations:
@@ -594,6 +601,8 @@ def _ll_job(args):
         write_offsets = indptr[:-1].copy()
         pool = Parallel(self.n_threads, backend="threading", return_as="generator")
         results = pool(_ll_job(ninfo) for ninfo in unit_neighb_info)
+        nrows = j + 1 + with_noise_unit
+        row_nnz = np.zeros(nrows, dtype=int)
         if show_progress:
             results = tqdm(
                 results,
@@ -606,18 +615,21 @@
         for j, inds, liks in results:
             if inds is None:
                 continue
+            row_nnz[j] = len(inds)
             csc_insert(j, write_offsets, inds, csc_indices, csc_data, liks)
 
         if with_noise_unit:
             liks = self.noise_log_likelihoods(indices=split_indices)
             data_ixs = write_offsets[split_indices]
             # assert np.array_equal(data_ixs, ccol_indices[1:] - 1)  # just fyi
+            row_nnz[j + 1] = len(data_ixs)
             csc_indices[data_ixs] = j + 1
             csc_data[data_ixs] = liks
 
-        shape = (j + 1 + with_noise_unit, self.data.n_spikes)
+        shape = (nrows, self.data.n_spikes)
         log_liks = csc_array((csc_data, csc_indices, indptr), shape=shape)
         log_liks.has_canonical_format = True
+        log_liks.row_nnz = row_nnz
 
         return log_liks
 
@@ -708,7 +720,8 @@ def cleanup(
         label_ids = label_ids[label_ids >= 0]
         big_enough = counts >= min_count
 
-        keep = torch.zeros(self.n_units(), dtype=bool)
+        n_units = max(label_ids.max().item() + 1, len(self._units))
+        keep = torch.zeros(n_units, dtype=bool)
         keep[label_ids] = big_enough
 
         self._stack = None
@@ -726,7 +739,7 @@
         kept_ids = label_ids[big_enough]
         new_ids = torch.arange(kept_ids.numel())
         old2new = dict(zip(kept_ids, new_ids))
-        self._relabel(kept_ids)
+        self._relabel(kept_ids, split=split)
 
         if self.log_proportions is not None:
             lps = self.log_proportions.numpy(force=True)
@@ -1988,7 +2001,7 @@ def train_extract_buffer(self):
         )
         return self.storage._train_extract_buffer
 
-    def _relabel(self, old_labels, new_labels=None, flat=False):
+    def _relabel(self, old_labels, new_labels=None, flat=False, split=None):
         """Re-label units
 
         !! This could invalidate self._units and props.
@@ -2008,18 +2021,33 @@ def _relabel(self, old_labels, new_labels=None, flat=False):
         label in old_labels and index new_labels with that, so that
         cleanup can call relabel with unit_ids[big_enough].
""" + split_indices = slice(None) + if split is not None: + split_indices = self.data.split_indices[split] + if new_labels is None: new_labels = torch.arange(len(old_labels)) + original = self.labels[split_indices] + if flat: - kept = self.labels >= 0 - label_indices = self.labels[kept] + kept = original >= 0 + label_indices = original[kept] else: - kept = torch.isin(self.labels, old_labels) - label_indices = torch.searchsorted(old_labels, self.labels[kept]) + label_indices = torch.searchsorted(old_labels, original, right=True) - 1 + kept = old_labels[label_indices] == original + label_indices = label_indices[kept] - self.labels[kept] = new_labels.to(self.labels)[label_indices] - self.labels[torch.logical_not(kept)] = -1 + unkept = torch.logical_not(kept) + if split is not None: + unkept = split_indices[unkept] + kept = split_indices[kept] + + if new_labels is not None: + label_indices = new_labels.to(self.labels)[label_indices] + + self.labels[kept] = label_indices + self.labels[unkept] = -1 self._stack = None def stack_units( @@ -2253,6 +2281,7 @@ def __init__( ppca_atol=0.05, ppca_rank=0, scale_mean: float = 0.1, + ppca_warm_start: bool = True, **annotations, ): super().__init__() @@ -2273,6 +2302,7 @@ def __init__( self.ppca_inner_em_iter = ppca_inner_em_iter self.ppca_atol = ppca_atol self.annotations = annotations + self.ppca_warm_start = ppca_warm_start @classmethod def from_features( @@ -2294,6 +2324,7 @@ def from_features( scale_mean: float = 0.1, core_neighborhoods=None, core_neighborhood_ids=None, + ppca_warm_start=True, **annotations, ): self = cls( @@ -2311,6 +2342,7 @@ def from_features( ppca_rank=ppca_rank, ppca_inner_em_iter=ppca_inner_em_iter, ppca_atol=ppca_atol, + ppca_warm_start=ppca_warm_start, **annotations, ) self.fit( @@ -2396,7 +2428,7 @@ def fit( active_mean = active_W = None if hasattr(self, "mean"): active_mean = self.mean[:, achans] - if hasattr(self, "W"): + if hasattr(self, "W") and self.ppca_warm_start: active_W = self.W[:, achans] if je_suis: diff --git a/src/dartsort/util/sparse_util.py b/src/dartsort/util/sparse_util.py index d8f2204c..1920cd30 100644 --- a/src/dartsort/util/sparse_util.py +++ b/src/dartsort/util/sparse_util.py @@ -182,6 +182,54 @@ def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask): return write_ix +def csc_sparse_getrow(csc, row, rowcount): + rowix_dtype = csc.indices.dtype + indptr_out = np.empty(rowcount + 1, dtype=rowix_dtype) + data_out = np.empty(rowcount, dtype=rowix_dtype) + indices_out = np.full(rowcount, row, dtype=rowix_dtype) + _csc_sparse_getrow( + csc.indices, csc.indptr, csc.data, indptr_out, data_out, rowix_dtype.type(row) + ) + + return csc_array( + (data_out, indices_out, indptr_out), + shape=(len(kept_row_inds), csc.shape[1]), + ) + + +sigs = [ + "void(i8[::1], i8[::1], f4[::1], i8[::1], i8[::1], i8)", + "void(i4[::1], i4[::1], f4[::1], i4[::1], i4[::1], i4)", +] + + +@numba.njit(sigs, error_model="numpy", nogil=True) +def _csc_sparse_getrow(indices, indptr, data, indptr_out, data_out, the_row): + write_ix = 0 + + column = 0 + column_end = indptr[1] + + for read_ix in range(len(indices)): + row = indices[read_ix] + if row != the_row: + continue + + # write data for this sample + data_out[write_ix] = data[read_ix] + write_ix += 1 + + while read_ix >= column_end: + indptr_out[column + 1] = write_ix - 1 + column += 1 + column_end = indptr[column + 1] + + while column < len(indptr) - 1: + indptr_out[column + 1] = write_ix + column += 1 + column_end = indptr[column + 1] + + # @numba.njit(sigs, 
error_model="numpy", nogil=True) # def _csc_sparse_mask_rows(indices, indptr, data, oldrow_to_newrow, keep_mask): # write_ix = 0
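For reference, a toy illustration of the searchsorted membership test that replaces torch.isin in _relabel. Hypothetical values; it relies on old_labels being sorted, as the searchsorted call requires:

    import torch

    old_labels = torch.tensor([2, 5, 9])     # sorted ids of surviving units
    labels = torch.tensor([5, -1, 9, 3, 2])  # spike labels; -1 means unassigned
    ix = torch.searchsorted(old_labels, labels, right=True) - 1
    kept = old_labels[ix] == labels          # membership via one binary search
    # kept     -> [True, False, True, False, True]
    # ix[kept] -> [1, 2, 0], the compressed new unit ids of the kept spikes

Spikes whose label is absent from old_labels (including -1) fail the equality check, which is why the caller can then blank them out with self.labels[unkept] = -1.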