Skip to content

Commit

Permalink
Added named_arrays.plt.pcolormesh() function.
Browse files Browse the repository at this point in the history
  • Loading branch information
byrdie committed Feb 11, 2024
1 parent 6ce416a commit 6737bbb
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 0 deletions.
35 changes: 35 additions & 0 deletions named_arrays/_functions/function_named_array_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable
import numpy as np
import matplotlib
import astropy.units as u
import named_arrays as na
import named_arrays._scalars.scalar_named_array_functions
Expand Down Expand Up @@ -88,3 +89,37 @@ def unit_normalized(
squeeze: bool = True,
) -> u.UnitBase | na.AbstractArray:
return na.unit_normalized(a.outputs, squeeze=squeeze)


@_implements(na.plt.pcolormesh)
def pcolormesh(
*XY: na.AbstractArray,
C: na.AbstractFunctionArray,
components: None | tuple[str, str] = None,
axis_rgb: None | str = None,
ax: None | matplotlib.axes.Axes | na.AbstractArray = None,
cmap: None | str | matplotlib.colors.Colormap = None,
norm: None | str | matplotlib.colors.Normalize = None,
vmin: None | na.ArrayLike = None,
vmax: None | na.ArrayLike = None,
**kwargs,
) -> na.ScalarArray:

if len(XY) != 0:
raise ValueError(

Check warning on line 109 in named_arrays/_functions/function_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_functions/function_named_array_functions.py#L109

Added line #L109 was not covered by tests
"if `C` is an instance of `na.AbstractFunctionArray`, "
"`XY` must not be specified."
)

return na.plt.pcolormesh(
C.inputs,
C=C.outputs,
components=components,
axis_rgb=axis_rgb,
ax=ax,
cmap=cmap,
norm=norm,
vmin=vmin,
vmax=vmax,
**kwargs,
)
47 changes: 47 additions & 0 deletions named_arrays/_functions/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,53 @@ class TestPltScatter(
):
pass

class TestPltPcolormesh(
named_arrays.tests.test_core.AbstractTestAbstractArray.TestNamedArrayFunctions.TestPltPcolormesh,
):
@pytest.mark.parametrize("axis_rgb", [None, "rgb"])
def test_pcolormesh(
self,
array: na.AbstractScalarArray,
axis_rgb: None | str
):
if not isinstance(array.inputs, na.AbstractVectorArray):
return

components = list(array.inputs.components.keys())[:2]

kwargs = dict(
C=array,
axis_rgb=axis_rgb,
components=components,
)

if isinstance(array.outputs, na.AbstractVectorArray):
with pytest.raises(TypeError):
na.plt.pcolormesh(**kwargs)
return
elif isinstance(array.outputs, na.AbstractUncertainScalarArray):
with pytest.raises(TypeError):
na.plt.pcolormesh(**kwargs)
return

if any(isinstance(array.inputs.components[c], na.AbstractUncertainScalarArray) for c in components):
with pytest.raises(TypeError):
na.plt.pcolormesh(**kwargs)
return

Check warning on line 673 in named_arrays/_functions/tests/test_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_functions/tests/test_functions.py#L671-L673

Added lines #L671 - L673 were not covered by tests

if axis_rgb is not None:
with pytest.raises(ValueError):
na.plt.pcolormesh(**kwargs)
return

if array.ndim != 2:
with pytest.raises(ValueError):
na.plt.pcolormesh(**kwargs)
return

Check warning on line 683 in named_arrays/_functions/tests/test_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_functions/tests/test_functions.py#L681-L683

Added lines #L681 - L683 were not covered by tests

result = na.plt.pcolormesh(**kwargs)
assert isinstance(result, na.ScalarArray)

@pytest.mark.xfail
class TestJacobian(
named_arrays.tests.test_core.AbstractTestAbstractArray.TestNamedArrayFunctions.TestJacobian,
Expand Down
58 changes: 58 additions & 0 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,64 @@ def plt_imshow(
return result


@_implements(na.plt.pcolormesh)
def pcolormesh(
*XY: na.AbstractScalarArray,
C: na.AbstractScalarArray,
components: None | tuple[str, str] = None,
axis_rgb: None | str = None,
ax: None | matplotlib.axes.Axes | na.AbstractScalarArray = None,
cmap: None | str | matplotlib.colors.Colormap = None,
norm: None | str | matplotlib.colors.Normalize = None,
vmin: None | float | u.Quantity | na.AbstractScalarArray = None,
vmax: None | float | u.Quantity | na.AbstractScalarArray = None,
**kwargs,
) -> na.ScalarArray:

if components is not None:
raise ValueError(f"`components` should be `None` for scalars, got {components}")

Check warning on line 530 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L530

Added line #L530 was not covered by tests

try:
XY = tuple(scalars._normalize(arg) for arg in XY)
C = scalars._normalize(C)
vmin = scalars._normalize(vmin) if vmin is not None else vmin
vmax = scalars._normalize(vmax) if vmax is not None else vmax
except na.ScalarTypeError: # pragma: nocover
pass

if ax is None:
ax = plt.gca()
ax = na.as_named_array(ax)

if axis_rgb is not None: # pragma: nocover
if axis_rgb not in C.shape:
raise ValueError(f"`{axis_rgb=}` must be a member of `{C.shape=}`")

shape_C = na.shape_broadcasted(*XY, C, ax)
shape = {a: shape_C[a] for a in shape_C if a != axis_rgb}
shape_orthogonal = ax.shape

XY = tuple(arg.broadcast_to(shape) for arg in XY)
C = C.broadcast_to(shape_C)
vmin = vmin.broadcast_to(shape_orthogonal) if vmin is not None else vmin
vmax = vmax.broadcast_to(shape_orthogonal) if vmax is not None else vmax

result = na.ScalarArray.empty(shape_orthogonal, dtype=object)

for index in na.ndindex(shape_orthogonal):
result[index] = ax[index].ndarray.pcolormesh(
*[arg[index].ndarray for arg in XY],
C[index].ndarray,
cmap=cmap,
norm=norm,
vmin=vmin[index].ndarray if vmin is not None else vmin,
vmax=vmax[index].ndarray if vmax is not None else vmax,
**kwargs,
)

return result


@_implements(na.jacobian)
def jacobian(
function: Callable[[na.AbstractScalar], na.AbstractScalar],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,12 @@ class TestPltScatter(
):
pass

@pytest.mark.skip
class TestPltPcolormesh(
named_arrays._scalars.tests.test_scalars.AbstractTestAbstractScalar.TestNamedArrayFunctions.TestPltPcolormesh,
):
pass

@pytest.mark.parametrize(
argnames="function",
argvalues=[
Expand Down
6 changes: 6 additions & 0 deletions named_arrays/_vectors/tests/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,12 @@ class TestPltScatter(
):
pass

@pytest.mark.skip
class TestPltPcolormesh(
named_arrays.tests.test_core.AbstractTestAbstractArray.TestNamedArrayFunctions.TestPltPcolormesh,
):
pass

@pytest.mark.parametrize(
argnames="function",
argvalues=[
Expand Down
47 changes: 47 additions & 0 deletions named_arrays/_vectors/vector_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.axes
import astropy.units as u
import named_arrays as na
from named_arrays._scalars import scalars
import named_arrays._scalars.scalar_named_array_functions
from . import vectors

Expand Down Expand Up @@ -264,6 +265,52 @@ def plt_plot_like(
)


@_implements(na.plt.pcolormesh)
def pcolormesh(
*XY: na.AbstractVectorArray,
C: na.AbstractScalarArray,
components: None | tuple[str, str] = None,
axis_rgb: None | str = None,
ax: None | matplotlib.axes.Axes | na.AbstractScalarArray = None,
cmap: None | str | matplotlib.colors.Colormap = None,
norm: None | str | matplotlib.colors.Normalize = None,
vmin: None | float | u.Quantity | na.AbstractScalarArray = None,
vmax: None | float | u.Quantity | na.AbstractScalarArray = None,
**kwargs,
) -> na.ScalarArray:
try:
C = scalars._normalize(C)
vmin = scalars._normalize(vmin) if vmin is not None else vmin
vmax = scalars._normalize(vmax) if vmax is not None else vmax
except na.ScalarTypeError:
return NotImplemented

try:
prototype = vectors._prototype(*XY)
XY = tuple(vectors._normalize(arg, prototype) for arg in XY)
except na.VectorTypeError:
return NotImplemented

Check warning on line 292 in named_arrays/_vectors/vector_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_vectors/vector_named_array_functions.py#L291-L292

Added lines #L291 - L292 were not covered by tests

if len(XY) != 1:
raise ValueError("if any element of `XY` is a vector, `XY` must have a length of 1")

Check warning on line 295 in named_arrays/_vectors/vector_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_vectors/vector_named_array_functions.py#L295

Added line #L295 was not covered by tests
XY = XY[0]

components_XY = XY.components
components = [components_XY[c] for c in components]

return na.plt.pcolormesh(
*components,
C=C,
axis_rgb=axis_rgb,
ax=ax,
cmap=cmap,
norm=norm,
vmin=vmin,
vmax=vmax,
**kwargs,
)


@_implements(na.jacobian)
def jacobian(
function: Callable[[InputT], OutputT],
Expand Down
119 changes: 119 additions & 0 deletions named_arrays/plt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"fill",
"scatter",
"imshow",
"pcolormesh",
]


Expand Down Expand Up @@ -508,3 +509,121 @@ def imshow(
extent=extent,
**kwargs,
)


def pcolormesh(
*XY: na.AbstractArray,
C: na.AbstractArray,
components: None | tuple[str, str] = None,
axis_rgb: None | str = None,
ax: None | matplotlib.axes.Axes | na.AbstractArray = None,
cmap: None | str | matplotlib.colors.Colormap = None,
norm: None | str | matplotlib.colors.Normalize = None,
vmin: None | na.ArrayLike = None,
vmax: None | na.ArrayLike = None,
**kwargs,
) -> na.AbstractScalar:
"""
A thin wrapper around :func:`matplotlib.pyplot.pcolormesh` for named arrays.
Parameters
----------
XY
The coordinates of the mesh.
If `C` is a scalar, `XY` can either be two scalars or one vector .
If `C` is a function, `XY` is not specified.
If `XY` is not specified as two scalars, the `components` must be given,
see below.
C
The mesh data.
components
If `XY` is not specified as two scalars, this parameter should
be a tuple of two strings, specifying the vector components of `XY`
to use as the horizontal and vertical components of the mesh.
axis_rgb
The optional logical axis along which the RGB color channels are
distributed.
ax
The instances of :class:`matplotlib.axes.Axes` to use.
If :obj:`None`, calls :func:`matplotlib.pyplot.gca` to get the current axes.
If an instance of :class:`named_arrays.ScalarArray`, ``ax.shape`` should be a subset of the broadcasted shape of
``*args``.
cmap
The colormap used to map scalar data to colors.
norm
The normalization method used to scale data into the range [0, 1] before
mapping to colors.
vmin
The minimum value of the data range.
vmax
The maximum value of the data range.
kwargs
Additional keyword arguments accepted by `matplotlib.pyplot.pcolormesh`
Examples
--------
Plot a random 2D mesh
.. jupyter-execute::
import matplotlib.pyplot as plt
import named_arrays as na
# Define the size of the grid
shape = dict(x=16, y=16)
# Define a simple coordinate grid
x = na.linspace(-2, 2, axis="x", num=shape["x"])
y = na.linspace(-1, 1, axis="y", num=shape["y"])
# Define a random 2D array of values to plot
a = na.random.uniform(-1, 1, shape_random=shape)
# Plot the coordinates and values using pcolormesh
fig, ax = plt.subplots(constrained_layout=True)
na.plt.pcolormesh(x, y, C=a, ax=ax);
|
Plot a grid of random 2D meshes
.. jupyter-execute::
import named_arrays as na
# Define the size of the grid
shape = dict(row=2, col=3, x=16, y=16)
# Define a simple coordinate grid
x = na.linspace(-2, 2, axis="x", num=shape["x"])
y = na.linspace(-1, 1, axis="y", num=shape["y"])
# Define a random 2D array of values to plot
a = na.random.uniform(-1, 1, shape_random=shape)
# Plot the coordinates and values using pcolormesh
fig, ax = na.plt.subplots(
axis_rows="row",
nrows=shape["row"],
axis_cols="col",
ncols=shape["col"],
sharex=True,
sharey=True,
constrained_layout=True,
)
na.plt.pcolormesh(x, y, C=a, ax=ax);
"""
return na._named_array_function(
pcolormesh,
*XY,
C=C,
axis_rgb=axis_rgb,
ax=ax,
cmap=cmap,
norm=norm,
vmin=vmin,
vmax=vmax,
components=components,
**kwargs,
)
Loading

0 comments on commit 6737bbb

Please sign in to comment.