From 84d268ea3a3393c25e4a7955e550437fb22b355a Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Tue, 14 Jan 2025 12:28:03 -0800
Subject: [PATCH] Debugging...

---
 src/dartsort/cluster/gaussian_mixture.py | 41 ++++++++++--------------
 src/dartsort/util/sparse_util.py         | 10 +++---
 tests/test_dartsort.py                   |  2 +-
 3 files changed, 24 insertions(+), 29 deletions(-)

diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py
index 10a72675..f5cd5d1a 100644
--- a/src/dartsort/cluster/gaussian_mixture.py
+++ b/src/dartsort/cluster/gaussian_mixture.py
@@ -454,7 +454,7 @@ def m_step(
         if self.use_proportions and likelihoods is not None:
             self.update_proportions(likelihoods)
         if self.log_proportions is not None:
-            assert len(self.log_proportions) > unit_ids.max() + self.with_noise_unit
+            assert len(self.log_proportions) == unit_ids.max() + 1 + self.with_noise_unit
 
         fit_full_indices, fit_split_indices = quick_indices(
             self.rg,
@@ -541,21 +541,11 @@ def log_likelihoods(
                 unit_neighb_info.append((j, neighbs, ns_unit))
             else:
                 assert previous_logliks is not None
-                if hasattr(previous_logliks, "row_nnz"):
-                    rnnz = previous_logliks.row_nnz[j]
-                    six, data = csc_sparse_getrow(previous_logliks, j, rnnz)
-                else:
-                    row = previous_logliks[[j]].tocoo(copy=True)
-                    six = row.coords[1]
-                    ns_unit = row.nnz
-                    data = row.data
-                ns_unit = len(six)
-                if "covered_neighbs" in unit.annotations:
-                    covered_neighbs = unit.annotations["covered_neighbs"]
-                else:
-                    covered_neighbs = full_core_neighborhoods.neighborhood_ids[six]
-                    covered_neighbs = torch.unique(covered_neighbs)
-                unit_neighb_info.append((j, six, data, ns_unit))
+                assert hasattr(previous_logliks, 'row_nnz')
+                assert 'covered_neighbs' in unit.annotations
+                ns_unit = previous_logliks.row_nnz[j]
+                unit_neighb_info.append((j, ns_unit))
+                covered_neighbs = unit.annotations['covered_neighbs']
                 core_overlaps[covered_neighbs] += 1
                 nnz += ns_unit
 
@@ -574,9 +564,10 @@ def log_likelihoods(
 
         @delayed
         def _ll_job(args):
-            if len(args) == 4:
-                j, coo, data, ns = args
-                return j, coo, data
+            if len(args) == 2:
+                j, ns = args
+                six, data = csc_sparse_getrow(previous_logliks, j, ns)
+                return j, six, data
             else:
                 assert len(args) == 3
                 j, neighbs, ns = args
@@ -603,8 +594,7 @@ def _ll_job(args):
         write_offsets = indptr[:-1].copy()
         pool = Parallel(self.n_threads, backend="threading", return_as="generator")
        results = pool(_ll_job(ninfo) for ninfo in unit_neighb_info)
-        nrows = j + 1 + with_noise_unit
-        row_nnz = np.zeros(nrows, dtype=int)
+        row_nnz = np.zeros(max(unit_ids) + 1, dtype=int)
         if show_progress:
             results = tqdm(
                 results,
@@ -617,17 +607,17 @@ def _ll_job(args):
         for j, inds, liks in results:
             if inds is None:
                 continue
-            row_nnz[j] = len(inds)
+            row_nnz[j] = len(liks)
             csc_insert(j, write_offsets, inds, csc_indices, csc_data, liks)
 
         if with_noise_unit:
             liks = self.noise_log_likelihoods(indices=split_indices)
             data_ixs = write_offsets[split_indices]
             # assert np.array_equal(data_ixs, ccol_indices[1:] - 1)  # just fyi
-            row_nnz[j + 1] = len(data_ixs)
             csc_indices[data_ixs] = j + 1
             csc_data[data_ixs] = liks
 
+        nrows = j + 1 + with_noise_unit
         shape = (nrows, self.data.n_spikes)
         log_liks = csc_array((csc_data, csc_indices, indptr), shape=shape)
         log_liks.has_canonical_format = True
@@ -669,7 +659,6 @@ def update_proportions(self, log_liks):
         )
 
     def reassign(self, log_liks):
-        n_units = log_liks.shape[0] - self.with_noise_unit
         spike_ix, assignments, spike_logliks, log_liks_csc = loglik_reassign(
             log_liks,
             has_noise_unit=self.with_noise_unit,
@@ -690,6 +679,7 @@ def reassign(self, log_liks):
         (kept0,) = (original >= 0).nonzero(as_tuple=True)
 
         # intersection
+        n_units = max(log_liks.shape[0] - self.with_noise_unit, original.max() + 1)
         intersection = torch.zeros(n_units, dtype=int)
         spiketorch.add_at_(intersection, assignments[kept], original[kept])
 
@@ -773,7 +763,9 @@ def cleanup(
         if isinstance(log_liks, coo_array):
             log_liks = coo_sparse_mask_rows(log_liks, keep_ll)
         elif isinstance(log_liks, csc_array):
+            row_nnz = log_liks.row_nnz[keep]
             log_liks = csc_sparse_mask_rows(log_liks, keep_ll, in_place=True)
+            log_liks.row_nnz = row_nnz
         else:
             assert False
 
@@ -2166,6 +2158,7 @@ def merge_units(
             sym_function=self.merge_sym_function,
             show_progress=show_progress,
         )
+        print(f"{group_ids.shape=} {distances.shape=}")
         if debug_info is not None:
             debug_info["Z"] = Z
             debug_info["improvements"] = improvements
diff --git a/src/dartsort/util/sparse_util.py b/src/dartsort/util/sparse_util.py
index 23065548..4c4c045a 100644
--- a/src/dartsort/util/sparse_util.py
+++ b/src/dartsort/util/sparse_util.py
@@ -187,20 +187,20 @@ def csc_sparse_getrow(csc, row, rowcount):
     columns_out = np.empty(rowcount, dtype=rowix_dtype)
     data_out = np.empty(rowcount, dtype=csc.data.dtype)
     _csc_sparse_getrow(
-        csc.indices, csc.indptr, csc.data, columns_out, data_out, rowix_dtype.type(row)
+        csc.indices, csc.indptr, csc.data, columns_out, data_out, rowix_dtype.type(row), rowcount
     )
     return columns_out, data_out
 
 
 sigs = [
-    "void(i8[::1], i8[::1], f4[::1], i8[::1], f4[::1], i8)",
-    "void(i4[::1], i4[::1], f4[::1], i4[::1], f4[::1], i4)",
+    "void(i8[::1], i8[::1], f4[::1], i8[::1], f4[::1], i8, i8)",
+    "void(i4[::1], i4[::1], f4[::1], i4[::1], f4[::1], i4, i8)",
 ]
 
 
 @numba.njit(sigs, error_model="numpy", nogil=True)
-def _csc_sparse_getrow(indices, indptr, data, columns_out, data_out, the_row):
+def _csc_sparse_getrow(indices, indptr, data, columns_out, data_out, the_row, count):
     write_ix = 0
 
     column = 0
@@ -218,6 +218,8 @@ def _csc_sparse_getrow(indices, indptr, data, columns_out, data_out, the_row):
             column_end = indptr[column + 1]
         columns_out[write_ix] = column
         write_ix += 1
+        if write_ix >= count:
+            return
 
 
 # @numba.njit(sigs, error_model="numpy", nogil=True)
diff --git a/tests/test_dartsort.py b/tests/test_dartsort.py
index cf99443d..4c016619 100644
--- a/tests/test_dartsort.py
+++ b/tests/test_dartsort.py
@@ -88,7 +88,7 @@ def test_fakedata():
             )
         ),
         refinement_config=dartsort.RefinementConfig(
-            min_count=10, channels_strategy="count_fuzzcore"
+            min_count=10, channels_strategy="count"
         ),
         featurization_config=dartsort.FeaturizationConfig(n_residual_snips=512),
         motion_estimation_config=dartsort.MotionEstimationConfig(
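-- 
The sparse_util change threads the row's known nonzero count through to the
numba kernel so the scan over csc.indices can stop as soon as the whole row
has been collected, instead of always walking every stored element. Below is
a minimal pure-NumPy sketch of that scan, assuming a scipy csc_array input;
getrow_sketch is a hypothetical name for illustration, and the patched kernel
is the numba-compiled _csc_sparse_getrow above.

    import numpy as np
    from scipy.sparse import csc_array

    def getrow_sketch(csc, row, rowcount):
        # Collect the column indices and values of `row` by scanning the
        # CSC indices array, tracking the owning column via indptr.
        columns_out = np.empty(rowcount, dtype=csc.indices.dtype)
        data_out = np.empty(rowcount, dtype=csc.data.dtype)
        write_ix = 0
        column = 0
        column_end = csc.indptr[1]
        for read_ix, r in enumerate(csc.indices):
            # Advance `column` until it owns position read_ix.
            while read_ix >= column_end:
                column += 1
                column_end = csc.indptr[column + 1]
            if r == row:
                columns_out[write_ix] = column
                data_out[write_ix] = csc.data[read_ix]
                write_ix += 1
                # Stop once all `rowcount` entries are found: the early
                # exit this patch adds to the numba kernel.
                if write_ix >= rowcount:
                    break
        return columns_out, data_out

    x = csc_array(np.array([[0, 1, 0], [2, 0, 3]], dtype=np.float32))
    # Per-row nonzero counts, playing the role of the cached row_nnz.
    row_nnz = np.bincount(x.indices, minlength=x.shape[0])
    cols, vals = getrow_sketch(x, 1, row_nnz[1])  # cols -> [0, 2], vals -> [2., 3.]

The early exit only pays off because callers now always know rowcount exactly,
which is what the gaussian_mixture changes enforce: log_likelihoods fills
row_nnz as each unit's row is inserted, and cleanup reindexes the cached
row_nnz whenever rows are masked out.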