Skip to content

Commit

Permalink
Modified the behavior of named_arrays.histogram2d to return an unbroadcasted version of the edges. (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
byrdie authored Nov 15, 2024
1 parent 3d966b3 commit d12489a
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 73 deletions.
1 change: 1 addition & 0 deletions named_arrays/_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ def histogram2d(
The bin specification of the histogram:
* If `bins` is a dictionary, the keys are interpreted as the axis names
and the values are the number of bins along each axis.
This dictionary must have exactly two keys.
* If `bins` is a 2D Cartesian vector, each component of the vector
represents the bin edges in each dimension.
axis
Expand Down
143 changes: 76 additions & 67 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Sequence, Any, Literal
import dataclasses
import numpy as np
import numpy.typing as npt
import matplotlib.axes
Expand Down Expand Up @@ -280,102 +281,110 @@ def histogram2d(
shape_orthogonal = {a: shape[a] for a in shape if a not in axis}

if isinstance(bins, na.AbstractCartesian2dVectorArray):
shape_bins = bins.shape
if bins.x.ndim != 1: # pragma: nocover

axis_x = set(bins.x.shape) - set(shape_orthogonal)
if len(axis_x) != 1: # pragma: nocover
raise ValueError(
f"The x component of `bins` must have only one dimension, "
f"got {bins.x.shape}."
f"if `bins` is a vector, `bins.x` must have exactly one new axis, "
f"got {axis_x}"
)
if bins.y.ndim != 1: # pragma: nocover

axis_y = set(bins.y.shape) - set(shape_orthogonal)
if len(axis_y) != 1: # pragma: nocover
raise ValueError(
f"The y component of `bins` must have only one dimension, "
f"got {bins.y.shape}."
f"if `bins` is a vector, `bins.y` must have exactly one new axis, "
f"got {axis_y}"
)
bins = (
bins.x.ndarray,
bins.y.ndarray,
)
else:
shape_bins = {ax: bins[ax] + 1 for ax in bins}
bins = tuple(bins.values())

if set(shape_bins).issubset(shape_orthogonal): # pragma: nocover
raise ValueError(
f"The histogram axes, {shape_bins}, should not be a subset of "
f"the orthogonal axes, {shape_orthogonal}."
)
if axis_x == axis_y: # pragma: nocover
raise ValueError(
f"if `bins` is a vector, `bins.x` and `bins.y` must be separable "
f"along the new axes, found non-separable axis {axis_x}."
)

axis_x, axis_y = shape_bins
edges = bins

a = na.Cartesian2dVectorArray(x, y)
elif isinstance(bins, dict):

if min is None:
min = a.min(axis)
elif not isinstance(min, na.AbstractCartesian2dVectorArray):
min = na.broadcast_to(min, shape_orthogonal)
min = na.Cartesian2dVectorArray.from_scalar(min)
else:
min = na.broadcast_to(min, shape_orthogonal)
a = na.Cartesian2dVectorArray(x, y)

if max is None:
max = a.max(axis)
elif not isinstance(max, na.AbstractCartesian2dVectorArray):
max = na.broadcast_to(max, shape_orthogonal)
max = na.Cartesian2dVectorArray.from_scalar(max)
else:
max = na.broadcast_to(max, shape_orthogonal)
if min is None:
min = a.min(axis)
elif not isinstance(min, na.AbstractCartesian2dVectorArray):
min = na.Cartesian2dVectorArray(min, min)

shape_hist = {ax: shape_bins[ax] - 1 for ax in shape_bins}
shape_hist = na.broadcast_shapes(shape_orthogonal, shape_hist)
if max is None:
max = a.max(axis)
elif not isinstance(max, na.AbstractCartesian2dVectorArray):
max = na.Cartesian2dVectorArray(max, max)

if len(bins) != 2: # pragma: nocover
raise ValueError(
f"if `bins` is a dictionary, it must have exactly two keys, "
f"got {bins=}."
)

axis_x, axis_y = tuple(bins)

edges = na.Cartesian2dVectorLinearSpace(
start=min,
stop=max,
axis=na.Cartesian2dVectorArray(
x=axis_x,
y=axis_y,
),
num=na.Cartesian2dVectorArray(
x=bins[axis_x] + 1,
y=bins[axis_y] + 1,
),
)

else: # pragma: nocover
return NotImplemented

shape_edges_x = na.broadcast_shapes(shape_orthogonal, edges.x.shape)
shape_edges_y = na.broadcast_shapes(shape_orthogonal, edges.y.shape)

edges_broadcasted = dataclasses.replace(
edges.explicit,
x=edges.x.broadcast_to(shape_edges_x),
y=edges.y.broadcast_to(shape_edges_y),
)

shape_x = na.broadcast_shapes(shape_orthogonal, {axis_x: shape_bins[axis_x]})
shape_y = na.broadcast_shapes(shape_orthogonal, {axis_y: shape_bins[axis_y]})
shape_edges = edges.shape
shape_hist = {
ax: shape_edges[ax] - 1
for ax in shape_edges
if ax not in shape_orthogonal
}
shape_hist = na.broadcast_shapes(shape_orthogonal, shape_hist)

hist = na.ScalarArray.empty(shape_hist)
xedges = na.ScalarArray.empty(shape_x)
yedges = na.ScalarArray.empty(shape_y)

unit_weights = na.unit(weights)
unit_x = na.unit(x)
unit_y = na.unit(y)

hist = hist if unit_weights is None else hist << unit_weights
xedges = xedges if unit_x is None else xedges << unit_x
yedges = yedges if unit_y is None else yedges << unit_y

for i in na.ndindex(shape_orthogonal):
min_i = min[i]
max_i = max[i]
hist_i, xedges_i, yedges_i = np.histogram2d(
edges_x_i = edges_broadcasted.x[i]
edges_y_i = edges_broadcasted.y[i]
hist_i, _, _ = np.histogram2d(
x=x[i].ndarray_aligned(axis).reshape(-1),
y=y[i].ndarray_aligned(axis).reshape(-1),
bins=bins,
range=[
[min_i.x.ndarray, max_i.x.ndarray],
[min_i.y.ndarray, max_i.y.ndarray],
],
bins=(
edges_x_i.ndarray,
edges_y_i.ndarray,
),
density=density,
weights=weights[i].ndarray_aligned(axis).reshape(-1) if weights is not None else weights,
)

hist[i] = na.ScalarArray(
ndarray=hist_i,
axes=tuple(shape_bins),
)
xedges[i] = na.ScalarArray(
ndarray=xedges_i,
axes=axis_x
)
yedges[i] = na.ScalarArray(
ndarray=yedges_i,
axes=axis_y,
axes=edges_x_i.axes + edges_y_i.axes,
)

return na.FunctionArray(
inputs=na.Cartesian2dVectorArray(
x=xedges,
y=yedges,
),
inputs=edges,
outputs=hist,
)

Expand Down
5 changes: 5 additions & 0 deletions named_arrays/_scalars/tests/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ class TestNamedArrayFunctions(
-100,
100,
),
(
dict(hx=11, hy=12),
None,
None,
),
(
na.Cartesian2dVectorArray(
x=na.linspace(-100, 100, axis="hx", num=11),
Expand Down
18 changes: 12 additions & 6 deletions named_arrays/ndfilters.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ def mean_filter(
.. jupyter-execute::
:stderr:
import numpy as np
import matplotlib.pyplot as plt
import scipy.datasets
import named_arrays as na
img = na.ScalarArray(scipy.datasets.ascent(), axes=("y", "x"))
x = na.linspace(-5, 5, axis="x", num=201)
y = na.linspace(-5, 5, axis="y", num=201)
img = np.cos(np.square(x)) * np.cos(np.square(y))
img_filtered = na.ndfilters.mean_filter(img, size=dict(x=21, y=21))
Expand Down Expand Up @@ -116,11 +118,13 @@ def trimmed_mean_filter(
.. jupyter-execute::
:stderr:
import numpy as np
import matplotlib.pyplot as plt
import scipy.datasets
import named_arrays as na
img = na.ScalarArray(scipy.datasets.ascent(), axes=("y", "x"))
x = na.linspace(-5, 5, axis="x", num=201)
y = na.linspace(-5, 5, axis="y", num=201)
img = np.cos(np.square(x)) * np.cos(np.square(y))
img_filtered = na.ndfilters.trimmed_mean_filter(
array=img,
Expand Down Expand Up @@ -186,11 +190,13 @@ def variance_filter(
.. jupyter-execute::
:stderr:
import numpy as np
import matplotlib.pyplot as plt
import scipy.datasets
import named_arrays as na
img = na.ScalarArray(scipy.datasets.ascent(), axes=("y", "x"))
x = na.linspace(-5, 5, axis="x", num=201)
y = na.linspace(-5, 5, axis="y", num=201)
img = np.cos(np.square(x)) * np.cos(np.square(y))
img_filtered = na.ndfilters.variance_filter(img, size=dict(x=21, y=21))
Expand Down

0 comments on commit d12489a

Please sign in to comment.