
Commit

Debugging...
cwindolf committed Jan 14, 2025
1 parent 2bcc511 commit 84d268e
Showing 3 changed files with 24 additions and 29 deletions.
41 changes: 17 additions & 24 deletions src/dartsort/cluster/gaussian_mixture.py
@@ -454,7 +454,7 @@ def m_step(
         if self.use_proportions and likelihoods is not None:
             self.update_proportions(likelihoods)
         if self.log_proportions is not None:
-            assert len(self.log_proportions) > unit_ids.max() + self.with_noise_unit
+            assert len(self.log_proportions) == unit_ids.max() + 1 + self.with_noise_unit
 
         fit_full_indices, fit_split_indices = quick_indices(
             self.rg,
@@ -541,21 +541,11 @@ def log_likelihoods(
                 unit_neighb_info.append((j, neighbs, ns_unit))
             else:
                 assert previous_logliks is not None
-                if hasattr(previous_logliks, "row_nnz"):
-                    rnnz = previous_logliks.row_nnz[j]
-                    six, data = csc_sparse_getrow(previous_logliks, j, rnnz)
-                else:
-                    row = previous_logliks[[j]].tocoo(copy=True)
-                    six = row.coords[1]
-                    ns_unit = row.nnz
-                    data = row.data
-                ns_unit = len(six)
-                if "covered_neighbs" in unit.annotations:
-                    covered_neighbs = unit.annotations["covered_neighbs"]
-                else:
-                    covered_neighbs = full_core_neighborhoods.neighborhood_ids[six]
-                    covered_neighbs = torch.unique(covered_neighbs)
-                unit_neighb_info.append((j, six, data, ns_unit))
+                assert hasattr(previous_logliks, 'row_nnz')
+                assert 'covered_neighbs' in unit.annotations
+                ns_unit = previous_logliks.row_nnz[j]
+                unit_neighb_info.append((j, ns_unit))
+                covered_neighbs = unit.annotations['covered_neighbs']
             core_overlaps[covered_neighbs] += 1
             nnz += ns_unit
 
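Note (reviewer sketch, not part of this commit): the rewritten else branch above asserts that two caches already exist — a row_nnz array attached to the previous csc_array of log-likelihoods, and a covered_neighbs tensor stored in each unit's annotations. A minimal sketch of the producer side of that contract, with hypothetical variable names:

    # after assembling log_liks (csc_array) and its per-row stored-entry counts
    log_liks.row_nnz = row_nnz  # reused on the next call instead of re-slicing rows
    for j, unit in enumerate(units):
        # neighborhood ids touched by unit j's row, cached so the next call
        # can skip the full_core_neighborhoods lookup
        unit.annotations["covered_neighbs"] = covered_neighbs_per_unit[j]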
@@ -574,9 +564,10 @@
 
         @delayed
         def _ll_job(args):
-            if len(args) == 4:
-                j, coo, data, ns = args
-                return j, coo, data
+            if len(args) == 2:
+                j, ns = args
+                six, data = csc_sparse_getrow(previous_logliks, j, ns)
+                return j, six, data
             else:
                 assert len(args) == 3
                 j, neighbs, ns = args
@@ -603,8 +594,7 @@ def _ll_job(args):
         write_offsets = indptr[:-1].copy()
         pool = Parallel(self.n_threads, backend="threading", return_as="generator")
         results = pool(_ll_job(ninfo) for ninfo in unit_neighb_info)
-        nrows = j + 1 + with_noise_unit
-        row_nnz = np.zeros(nrows, dtype=int)
+        row_nnz = np.zeros(max(unit_ids) + 1, dtype=int)
         if show_progress:
             results = tqdm(
                 results,
@@ -617,17 +607,17 @@
         for j, inds, liks in results:
             if inds is None:
                 continue
-            row_nnz[j] = len(inds)
+            row_nnz[j] = len(liks)
             csc_insert(j, write_offsets, inds, csc_indices, csc_data, liks)
 
         if with_noise_unit:
             liks = self.noise_log_likelihoods(indices=split_indices)
             data_ixs = write_offsets[split_indices]
             # assert np.array_equal(data_ixs, ccol_indices[1:] - 1) # just fyi
-            row_nnz[j + 1] = len(data_ixs)
             csc_indices[data_ixs] = j + 1
             csc_data[data_ixs] = liks
 
+        nrows = j + 1 + with_noise_unit
         shape = (nrows, self.data.n_spikes)
         log_liks = csc_array((csc_data, csc_indices, indptr), shape=shape)
         log_liks.has_canonical_format = True
@@ -669,7 +659,6 @@ def update_proportions(self, log_liks):
         )
 
     def reassign(self, log_liks):
-        n_units = log_liks.shape[0] - self.with_noise_unit
         spike_ix, assignments, spike_logliks, log_liks_csc = loglik_reassign(
             log_liks,
             has_noise_unit=self.with_noise_unit,
@@ -690,6 +679,7 @@ def reassign(self, log_liks):
         (kept0,) = (original >= 0).nonzero(as_tuple=True)
 
         # intersection
+        n_units = max(log_liks.shape[0] - self.with_noise_unit, original.max() + 1)
         intersection = torch.zeros(n_units, dtype=int)
         spiketorch.add_at_(intersection, assignments[kept], original[kept])
 
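Note (reviewer sketch, not part of this commit): n_units is now computed here and widened with max() so the intersection buffer can be indexed by original labels that exceed the current number of log-likelihood rows (e.g. after units were dropped). A toy illustration of the sizing under that assumption:

    import torch

    n_ll_rows, with_noise_unit = 5, True
    original = torch.tensor([0, 1, 7, 3])  # stale labels can exceed the row count
    n_units = max(n_ll_rows - with_noise_unit, int(original.max()) + 1)  # -> 8
    intersection = torch.zeros(n_units, dtype=torch.long)  # safe to index by original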
@@ -773,7 +763,9 @@ def cleanup(
         if isinstance(log_liks, coo_array):
             log_liks = coo_sparse_mask_rows(log_liks, keep_ll)
         elif isinstance(log_liks, csc_array):
+            row_nnz = log_liks.row_nnz[keep]
             log_liks = csc_sparse_mask_rows(log_liks, keep_ll, in_place=True)
+            log_liks.row_nnz = row_nnz
         else:
             assert False
 
@@ -2166,6 +2158,7 @@ def merge_units(
             sym_function=self.merge_sym_function,
             show_progress=show_progress,
         )
+        print(f"{group_ids.shape=} {distances.shape=}")
         if debug_info is not None:
             debug_info["Z"] = Z
             debug_info["improvements"] = improvements
10 changes: 6 additions & 4 deletions src/dartsort/util/sparse_util.py
@@ -187,20 +187,20 @@ def csc_sparse_getrow(csc, row, rowcount):
     columns_out = np.empty(rowcount, dtype=rowix_dtype)
     data_out = np.empty(rowcount, dtype=csc.data.dtype)
     _csc_sparse_getrow(
-        csc.indices, csc.indptr, csc.data, columns_out, data_out, rowix_dtype.type(row)
+        csc.indices, csc.indptr, csc.data, columns_out, data_out, rowix_dtype.type(row), rowcount
     )
 
     return columns_out, data_out
 
 
 sigs = [
-    "void(i8[::1], i8[::1], f4[::1], i8[::1], f4[::1], i8)",
-    "void(i4[::1], i4[::1], f4[::1], i4[::1], f4[::1], i4)",
+    "void(i8[::1], i8[::1], f4[::1], i8[::1], f4[::1], i8, i8)",
+    "void(i4[::1], i4[::1], f4[::1], i4[::1], f4[::1], i4, i8)",
 ]
 
 
 @numba.njit(sigs, error_model="numpy", nogil=True)
-def _csc_sparse_getrow(indices, indptr, data, columns_out, data_out, the_row):
+def _csc_sparse_getrow(indices, indptr, data, columns_out, data_out, the_row, count):
     write_ix = 0
 
     column = 0
@@ -218,6 +218,8 @@ def _csc_sparse_getrow(indices, indptr, data, columns_out, data_out, the_row):
         column_end = indptr[column + 1]
         columns_out[write_ix] = column
         write_ix += 1
+        if write_ix >= count:
+            return
 
 
 # @numba.njit(sigs, error_model="numpy", nogil=True)
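Note (reviewer sketch, not part of this commit): the new count argument lets the numba kernel return as soon as it has written rowcount entries, instead of scanning the remaining columns. A small usage sketch, assuming csc_sparse_getrow is called on a SciPy csc_array with float32 data and that the row's entry count is known (e.g. from the cached row_nnz):

    import numpy as np
    from scipy.sparse import csc_array
    from dartsort.util.sparse_util import csc_sparse_getrow

    m = csc_array(np.array([[0.0, 1.5, 0.0],
                            [2.0, 0.0, 3.0]], dtype=np.float32))
    # row 1 stores two entries, so the kernel can stop after finding both
    cols, vals = csc_sparse_getrow(m, row=1, rowcount=2)
    # cols -> [0, 2], vals -> [2.0, 3.0]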
2 changes: 1 addition & 1 deletion tests/test_dartsort.py
@@ -88,7 +88,7 @@ def test_fakedata():
             )
         ),
         refinement_config=dartsort.RefinementConfig(
-            min_count=10, channels_strategy="count_fuzzcore"
+            min_count=10, channels_strategy="count"
         ),
         featurization_config=dartsort.FeaturizationConfig(n_residual_snips=512),
         motion_estimation_config=dartsort.MotionEstimationConfig(
