Skip to content

feat: speed up scanpy.get.aggregate summation#4041

Draft
ilan-gold wants to merge 11 commits intoig/two_pass_hvg_v3from
ig/numba_aggregate
Draft

feat: speed up scanpy.get.aggregate summation#4041
ilan-gold wants to merge 11 commits intoig/two_pass_hvg_v3from
ig/numba_aggregate

Conversation

@ilan-gold
Copy link
Copy Markdown
Contributor

@ilan-gold ilan-gold commented Apr 8, 2026

A nice testing script:

Details
# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "numba",
#   "fast-array-utils[accel,sparse]",
#   "scipy",
#   "numpy"
# ]
# ///
#
# This script benchmarks numba-compiled sparse group sums (CSR and CSC inputs) against sparse matrix multiplication

from __future__ import annotations

import time

import numba
import numpy as np
from fast_array_utils.numba import njit
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, random


@njit
def agg_sum_csr(
    indicator: csr_matrix,
    data: csr_matrix,
):
    """Sum the rows of a CSR matrix within each group given by `indicator`.

    Row ``c`` of `indicator` lists (through its ``indices``) the rows of
    `data` that belong to group ``c``.  Only the sparsity pattern of
    `indicator` is read; its stored values are never used.

    Returns a dense ``float64`` array of shape
    ``(indicator.shape[0], data.shape[1])``.
    """
    n_groups = indicator.shape[0]
    out = np.zeros((n_groups, data.shape[1]), dtype="float64")
    # Every iteration of the outer loop writes only to row `group` of `out`,
    # so iterations do not touch each other's output.
    for group in numba.prange(n_groups):
        for member in range(indicator.indptr[group], indicator.indptr[group + 1]):
            obs = indicator.indices[member]
            # Walk the stored entries of row `obs` of `data`.
            for pos in range(data.indptr[obs], data.indptr[obs + 1]):
                out[group, data.indices[pos]] += float(data.data[pos])
    return out


@njit
def agg_sum_csc(
    indicator: csr_matrix,
    data: csc_matrix,
):
    """Sum the rows of a CSC matrix within each group given by `indicator`.

    Row ``c`` of `indicator` lists (through its ``indices``) the rows of
    `data` that belong to group ``c``.  Only the sparsity pattern of
    `indicator` is read; its stored values are never used.

    The lookup table built below keeps a single group per observation: if an
    observation is listed under several rows of `indicator`, only the last
    such row is counted.  Observations listed under no row are skipped.

    Returns a dense ``float64`` array of shape
    ``(indicator.shape[0], data.shape[1])``.
    """
    n_groups = indicator.shape[0]
    n_cols = data.shape[1]
    out = np.zeros((n_groups, n_cols), dtype="float64")

    # Lookup table observation -> group; -1 marks "belongs to no group".
    group_of_obs = np.full(data.shape[0], -1, dtype=np.int64)
    for group in range(n_groups):
        for member in range(indicator.indptr[group], indicator.indptr[group + 1]):
            group_of_obs[indicator.indices[member]] = group

    # Each iteration of the outer loop writes only to column `col` of `out`.
    for col in numba.prange(n_cols):
        for pos in range(data.indptr[col], data.indptr[col + 1]):
            group = group_of_obs[data.indices[pos]]
            if group == -1:
                continue
            out[group, col] += float(data.data[pos])

    return out


# Benchmark fixture: a random sparse matrix plus a (category x observation)
# membership matrix with exactly one entry per observation.
N_CATEGORIES = 20  # single source of truth for the label range and matrix shape
rng = np.random.default_rng()
mat = random(70_000, 50_000, density=0.02, format="csr", rng=rng)
categories = rng.integers(0, N_CATEGORIES, size=mat.shape[0])

cols = np.arange(mat.shape[0])
data = np.ones(mat.shape[0], dtype=int)
membership_matrix = coo_matrix(
    (data, (categories, cols)), shape=(N_CATEGORIES, mat.shape[0])
).tocsr()
csc_mat = mat.tocsc()


def _report(label: str, run) -> None:
    """Call `run` once and print the elapsed wall-clock seconds under `label`.

    `run` is a zero-argument callable; its return value is discarded.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    run()
    end = time.perf_counter()
    print(f"{label} = {end - start}")


# Warm-up calls so that compilation is not part of the timings below.
agg_sum_csr(membership_matrix, mat)
agg_sum_csc(membership_matrix, csc_mat)

_report("numba csr time", lambda: agg_sum_csr(membership_matrix, mat))
# The two "->" cases keep the format conversion inside the timed region on
# purpose: they measure convert-then-aggregate.
_report("numba csr->csc time", lambda: agg_sum_csc(membership_matrix, mat.tocsc()))
_report("numba csc->csr time", lambda: agg_sum_csr(membership_matrix, csc_mat.tocsr()))
_report("numba csc time", lambda: agg_sum_csc(membership_matrix, csc_mat))
# Baseline: sparse matrix multiplication followed by densification.
_report("mul time", lambda: (membership_matrix @ mat).toarray())

For me locally, even if I use numba's njit (i.e., parallel=False), this is still faster than multiplication (the current state of things). This means that things will probably also be faster for dask. It is not worth it to convert formats ever.

  • Closes #
  • Tests included or not required because:

@ilan-gold ilan-gold changed the title feat: speed up numba sums feat: speed up scanpy.get.aggregate summation Apr 8, 2026
@codecov
Copy link
Copy Markdown

codecov bot commented Apr 8, 2026

❌ 1 Tests Failed:

Tests completed Failed Passed Skipped
2402 1 2401 339
View the top 1 failed test(s) by shortest run time
tests/test_highly_variable_genes.py::test_compare_to_seurat_v3
Stack Traces | 729s run time
@needs.skmisc
    def test_compare_to_seurat_v3():
        ### test without batch
        seurat_hvg_info = pd.read_csv(FILE_V3)

        pbmc = pbmc3k()
        sc.pp.filter_cells(pbmc, min_genes=200)  # this doesnt do anything btw
        sc.pp.filter_genes(pbmc, min_cells=3)

        pbmc_dense = pbmc.copy()
        pbmc_dense.X = pbmc_dense.X.toarray()

        sc.pp.highly_variable_genes(pbmc, n_top_genes=1000, flavor="seurat_v3")
        sc.pp.highly_variable_genes(pbmc_dense, n_top_genes=1000, flavor="seurat_v3")

        np.testing.assert_allclose(
            seurat_hvg_info["variance"],
            pbmc.var["variances"],
            rtol=2e-05,
            atol=2e-05,
        )
        np.testing.assert_allclose(
            seurat_hvg_info["variance.standardized"],
            pbmc.var["variances_norm"],
            rtol=2e-05,
            atol=2e-05,
        )
        np.testing.assert_allclose(
            pbmc_dense.var["variances_norm"],
            pbmc.var["variances_norm"],
            rtol=2e-05,
            atol=2e-05,
        )

        ### test with batch
        # introduce a dummy "technical covariate"; this is used in Seurat's SelectIntegrationFeatures
        pbmc.obs["dummy_tech"] = (
            "source_" + pd.array([*range(1, 6), 5]).repeat(500).astype("string")
        )[: pbmc.n_obs]

        seurat_v3_paper = sc.pp.highly_variable_genes(
            pbmc,
            n_top_genes=2000,
            flavor="seurat_v3_paper",
            batch_key="dummy_tech",
            inplace=False,
        )

        seurat_v3 = sc.pp.highly_variable_genes(
            pbmc,
            n_top_genes=2000,
            flavor="seurat_v3",
            batch_key="dummy_tech",
            inplace=False,
        )

        seurat_hvg_info_batch = pd.read_csv(FILE_V3_BATCH)
        seu = pd.Index(seurat_hvg_info_batch["x"].to_numpy())

        gene_intersection_paper = seu.intersection(
            seurat_v3_paper[seurat_v3_paper["highly_variable"]].index
        )
        gene_intersection_impl = seu.intersection(
            seurat_v3[seurat_v3["highly_variable"]].index
        )
>       assert len(gene_intersection_paper) / 2000 > 0.95
E       AssertionError: assert (1867 / 2000) > 0.95
E        +  where 1867 = len(Index(['LYZ', 'GNLY', 'S100A9', 'FTL', 'FTH1', 'S100A8', 'HLA-DRA', 'CST3',\n       'CD74', 'NKG7',\n       ...\n       'CR1', 'SLC20A2', 'HBEGF', 'IL15', 'DNAJC9', 'LRRC59', 'ALG12', 'DPP4',\n       'CLUAP1', 'HCFC2'],\n      dtype='object', length=1867))

tests/test_highly_variable_genes.py:492: AssertionError

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

@scverse-benchmark
Copy link
Copy Markdown

scverse-benchmark bot commented Apr 8, 2026

No changes in benchmarks.

Warning

Some benchmarks failed

Comparison: https://github.com/scverse/scanpy/compare/8ad893d22c112bc6f149660ca6b5c711668d76c0..6fe8b08f76c4a9ae5d45bfd14d3fd16a1f8acb74
Last changed: Wed, 8 Apr 2026 23:55:52 +0000

More details: https://github.com/scverse/scanpy/pull/4041/checks?check_run_id=70488844488

@ilan-gold ilan-gold force-pushed the ig/numba_aggregate branch from f4fa89b to febc107 Compare April 8, 2026 18:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant