Skip to content

Commit faeb203

Browse files
can take the min/max in channels aggregation (#147)
1 parent b7b4b5f commit faeb203

File tree

5 files changed

+71
-15
lines changed

5 files changed

+71
-15
lines changed

docs/api/aggregation.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
11
# Aggregation
22

3+
!!! tips "Recommendation"
4+
We recommend using the `sopa.aggregate` function below, which is a wrapper for all types of aggregation. Internally, it uses `aggregate_channels`, `count_transcripts`, and/or `aggregate_bins`, which are also documented below if needed.
5+
36
::: sopa.aggregate
7+
8+
::: sopa.aggregation.aggregate_channels
9+
10+
::: sopa.aggregation.count_transcripts
11+
12+
::: sopa.aggregation.aggregate_bins

sopa/aggregation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .bins import aggregate_bins
2-
from .channels import average_channels
2+
from .channels import average_channels, aggregate_channels
33
from .transcripts import count_transcripts
44
from .aggregation import aggregate, Aggregator
55
from .overlay import overlay_segmentation

sopa/aggregation/aggregation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
get_spatial_element,
1818
get_spatial_image,
1919
)
20-
from . import aggregate_bins, average_channels, count_transcripts
20+
from . import aggregate_bins
21+
from . import aggregate_channels as _aggregate_channels
22+
from . import count_transcripts
2123

2224
log = logging.getLogger(__name__)
2325

@@ -157,7 +159,7 @@ def compute_table(
157159
self.filter_cells(self.table.X.sum(axis=1) < min_transcripts)
158160

159161
if aggregate_channels:
160-
mean_intensities = average_channels(
162+
mean_intensities = _aggregate_channels(
161163
self.sdata,
162164
image_key=self.image_key,
163165
shapes_key=self.shapes_key,

sopa/aggregation/channels.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import dask
66
import geopandas as gpd
77
import numpy as np
8+
import numpy.ma as ma
89
import shapely
910
from dask.diagnostics import ProgressBar
1011
from shapely.geometry import Polygon, box
@@ -16,35 +17,47 @@
1617

1718
log = logging.getLogger(__name__)
1819

20+
AVAILABLE_MODES = ["average", "min", "max"]
21+
1922

2023
def average_channels(
24+
sdata: SpatialData, image_key: str = None, shapes_key: str = None, expand_radius_ratio: float = 0
25+
) -> np.ndarray:
26+
log.warning("average_channels is deprecated, use `aggregate_channels` instead")
27+
return aggregate_channels(sdata, image_key, shapes_key, expand_radius_ratio, mode="average")
28+
29+
30+
def aggregate_channels(
2131
sdata: SpatialData,
2232
image_key: str = None,
2333
shapes_key: str = None,
2434
expand_radius_ratio: float = 0,
35+
mode: str = "average",
2536
) -> np.ndarray:
26-
"""Average channel intensities per cell.
37+
"""Aggregate the channel intensities per cell (either `"average"`, or take the `"min"` / `"max"`).
2738
2839
Args:
2940
sdata: A `SpatialData` object
3041
image_key: Key of `sdata` containing the image. If only one `images` element, this does not have to be provided.
3142
shapes_key: Key of `sdata` containing the cell boundaries. If only one `shapes` element, this does not have to be provided.
3243
expand_radius_ratio: Cells polygons will be expanded by `expand_radius_ratio * mean_radius`. This help better aggregate boundary stainings.
44+
mode: Aggregation mode. One of `"average"`, `"min"`, `"max"`. By default, average the intensity inside the cell mask.
3345
3446
Returns:
3547
A numpy `ndarray` of shape `(n_cells, n_channels)`
3648
"""
49+
assert mode in AVAILABLE_MODES, f"Invalid {mode=}. Available modes are {AVAILABLE_MODES}"
50+
3751
image = get_spatial_image(sdata, image_key)
3852

3953
geo_df = get_boundaries(sdata, key=shapes_key)
4054
geo_df = to_intrinsic(sdata, geo_df, image)
4155
geo_df = expand_radius(geo_df, expand_radius_ratio)
4256

43-
log.info(f"Averaging channels intensity over {len(geo_df)} cells with expansion {expand_radius_ratio=}")
44-
return _average_channels_aligned(image, geo_df)
57+
return _aggregate_channels_aligned(image, geo_df, mode)
4558

4659

47-
def _average_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[Polygon]) -> np.ndarray:
60+
def _aggregate_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[Polygon], mode: str) -> np.ndarray:
4861
"""Average channel intensities per cell. The image and cells have to be aligned, i.e. be on the same coordinate system.
4962
5063
Args:
@@ -54,11 +67,17 @@ def _average_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[
5467
Returns:
5568
A numpy `ndarray` of shape `(n_cells, n_channels)`
5669
"""
70+
log.info(f"Aggregating channels intensity over {len(geo_df)} cells with {mode=}")
71+
5772
cells = geo_df if isinstance(geo_df, list) else list(geo_df.geometry)
5873
tree = shapely.STRtree(cells)
5974

60-
intensities = np.zeros((len(cells), len(image.coords["c"])))
75+
n_channels = len(image.coords["c"])
6176
areas = np.zeros(len(cells))
77+
if mode == "min":
78+
aggregation = np.full((len(cells), n_channels), fill_value=np.inf)
79+
else:
80+
aggregation = np.zeros((len(cells), n_channels))
6281

6382
chunk_sizes = image.data.chunks
6483
offsets_y = np.cumsum(np.pad(chunk_sizes[1], (1, 0), "constant"))
@@ -86,9 +105,20 @@ def _average_chunk_inside_cells(chunk, iy, ix):
86105

87106
mask = rasterize(cell, sub_image.shape[1:], bounds)
88107

89-
intensities[index] += np.sum(sub_image * mask, axis=(1, 2))
90108
areas[index] += np.sum(mask)
91109

110+
if mode == "min":
111+
masked_image = ma.masked_array(sub_image, 1 - np.repeat(mask[None], n_channels, axis=0))
112+
aggregation[index] = np.minimum(aggregation[index], masked_image.min(axis=(1, 2)))
113+
elif mode in ["average", "max"]:
114+
func = np.sum if mode == "average" else np.max
115+
values = func(sub_image * mask, axis=(1, 2))
116+
117+
if mode == "average":
118+
aggregation[index] += values
119+
else:
120+
aggregation[index] = np.maximum(aggregation[index], values)
121+
92122
with ProgressBar():
93123
tasks = [
94124
dask.delayed(_average_chunk_inside_cells)(chunk, iy, ix)
@@ -97,4 +127,7 @@ def _average_chunk_inside_cells(chunk, iy, ix):
97127
]
98128
dask.compute(tasks)
99129

100-
return intensities / areas[:, None].clip(1)
130+
if mode == "average":
131+
return aggregation / areas[:, None].clip(1)
132+
else:
133+
return aggregation

tests/test_aggregation.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
import xarray as xr
77
from shapely.geometry import Polygon, box
88

9-
from sopa.aggregation.channels import _average_channels_aligned
9+
from sopa.aggregation.channels import _aggregate_channels_aligned
1010
from sopa.aggregation.transcripts import _count_transcripts_aligned
1111

1212
dask.config.set({"dataframe.query-planning": False})
1313
import dask.dataframe as dd # noqa
1414

1515

16-
def test_average_channels_aligned():
16+
def test_aggregate_channels_aligned():
1717
image = np.random.randint(1, 10, size=(3, 8, 16))
1818
arr = da.from_array(image, chunks=(1, 8, 8))
1919
xarr = xr.DataArray(arr, dims=["c", "y", "x"])
@@ -24,11 +24,23 @@ def test_average_channels_aligned():
2424
# One cell is on the first block, one is overlapping on both blocks, and one is on the last block
2525
cells = [box(x, y, x + cell_size - 1, y + cell_size - 1) for x, y in cell_start]
2626

27-
means = _average_channels_aligned(xarr, cells)
27+
mean_intensities = _aggregate_channels_aligned(xarr, cells, "average")
28+
min_intensities = _aggregate_channels_aligned(xarr, cells, "min")
29+
max_intensities = _aggregate_channels_aligned(xarr, cells, "max")
2830

29-
true_means = np.stack([image[:, y : y + cell_size, x : x + cell_size].mean(axis=(1, 2)) for x, y in cell_start])
31+
true_mean_intensities = np.stack(
32+
[image[:, y : y + cell_size, x : x + cell_size].mean(axis=(1, 2)) for x, y in cell_start]
33+
)
34+
true_min_intensities = np.stack(
35+
[image[:, y : y + cell_size, x : x + cell_size].min(axis=(1, 2)) for x, y in cell_start]
36+
)
37+
true_max_intensities = np.stack(
38+
[image[:, y : y + cell_size, x : x + cell_size].max(axis=(1, 2)) for x, y in cell_start]
39+
)
3040

31-
assert (means == true_means).all()
41+
assert (mean_intensities == true_mean_intensities).all()
42+
assert (min_intensities == true_min_intensities).all()
43+
assert (max_intensities == true_max_intensities).all()
3244

3345

3446
def test_count_transcripts():

0 commit comments

Comments
 (0)