Skip to content

Commit

Permalink
Added named_arrays.ndfilters.variance_filter() function. (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
byrdie authored Sep 17, 2024
1 parent a49de33 commit 3878330
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 3 deletions.
1 change: 1 addition & 0 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
NDFILTER_FUNCTIONS = (
na.ndfilters.mean_filter,
na.ndfilters.trimmed_mean_filter,
na.ndfilters.variance_filter,
)
HANDLED_FUNCTIONS = dict()

Expand Down
66 changes: 66 additions & 0 deletions named_arrays/ndfilters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
__all__ = [
"mean_filter",
"trimmed_mean_filter",
"variance_filter",
]

ArrayT = TypeVar("ArrayT", bound="na.AbstractArray")
Expand Down Expand Up @@ -157,3 +158,68 @@ def trimmed_mean_filter(
mode=mode,
proportion=proportion,
)


def variance_filter(
array: ArrayT,
size: dict[str, int],
where: WhereT = True,
) -> ArrayT | WhereT:
"""
A thin wrapper around :func:`ndfilters.variance_filter` for named arrays.
Parameters
----------
array
The input array to be filtered.
size
The shape of the kernel over which the variance will be calculated.
where
A boolean mask used to select which elements of the input array are to
be filtered.
Examples
--------
Filter a sample image.
.. jupyter-execute::
:stderr:
import matplotlib.pyplot as plt
import scipy.datasets
import named_arrays as na
img = na.ScalarArray(scipy.datasets.ascent(), axes=("y", "x"))
img_filtered = na.ndfilters.variance_filter(img, size=dict(x=21, y=21))
fig, axs = plt.subplots(
ncols=2,
sharex=True,
sharey=True,
constrained_layout=True,
)
axs[0].set_title("original image");
na.plt.imshow(
X=img,
axis_x="x",
axis_y="y",
ax=axs[0],
cmap="gray",
);
axs[1].set_title("filtered image");
na.plt.imshow(
X=img_filtered,
axis_x="x",
axis_y="y",
ax=axs[1],
cmap="gray",
);
"""
return na._named_array_function(
func=variance_filter,
array=array,
size=size,
where=where,
)
6 changes: 4 additions & 2 deletions named_arrays/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,7 @@ def callback(i, x, f, c):
argvalues=[
na.ndfilters.mean_filter,
na.ndfilters.trimmed_mean_filter,
na.ndfilters.variance_filter,
]
)
class TestNdfilter:
Expand All @@ -1489,7 +1490,7 @@ def test_ndfilter(
array: na.AbstractArray,
):

size = dict(y=1)
size = dict(y=3)

kwargs = dict(
array=array,
Expand All @@ -1503,7 +1504,8 @@ def test_ndfilter(

result = function(**kwargs)

assert np.all(result == array)
assert result.type_abstract == array.type_abstract
assert result.shape == array.shape

@pytest.mark.parametrize("axis", [None, "y"])
class TestColorsynth:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"matplotlib",
"scipy",
'astropy',
"ndfilters==0.2.0",
"ndfilters==0.3.0",
"colorsynth==0.1.3",
]
dynamic = ["version"]
Expand Down

0 comments on commit 3878330

Please sign in to comment.