From a3a62e47a189b9b605c2337df16b9a40e5001f42 Mon Sep 17 00:00:00 2001
From: Roy Smart <roytsmart@gmail.com>
Date: Fri, 8 Nov 2024 18:18:46 -0700
Subject: [PATCH 1/3] Added `named_arrays.plt.rgbmovie()` function to plot 4D
 arrays.

---
 named_arrays/plt.py            | 213 +++++++++++++++++++++++++++++++++
 named_arrays/tests/test_plt.py | 109 +++++++++++++++++
 2 files changed, 322 insertions(+)

diff --git a/named_arrays/plt.py b/named_arrays/plt.py
index 66cc9cb..a0d7494 100644
--- a/named_arrays/plt.py
+++ b/named_arrays/plt.py
@@ -19,6 +19,7 @@
     "pcolormesh",
     "rgbmesh",
     "pcolormovie",
+    "rgbmovie",
     "text",
     "brace_vertical",
     "set_xlabel",
@@ -1083,6 +1084,218 @@ def pcolormovie(
     )
 
 
+def rgbmovie(
+    *TWXY: na.AbstractArray,
+    C: na.AbstractArray,
+    axis_time: str,
+    axis_wavelength: str,
+    ax: None | matplotlib.axes.Axes | na.AbstractArray = None,
+    norm: None | Callable= None,
+    vmin: None | na.ArrayLike = None,
+    vmax: None | na.ArrayLike = None,
+    wavelength_norm: None | Callable = None,
+    wavelength_min: None | float | u.Quantity | na.AbstractScalar = None,
+    wavelength_max: None | float | u.Quantity | na.AbstractScalar = None,
+    kwargs_pcolormesh: None | dict[str, Any] = None,
+    **kwargs_animation,
+) -> tuple[
+    matplotlib.animation.FuncAnimation,
+    na.FunctionArray[na.Cartesian2dVectorArray, na.AbstractScalar]
+]:
+    """
+    A convenience function that calls :func:`pcolormovie` with the outputs
+    from :func:`named_arrays.colorsynth.rgb` and returns an animation
+    instance and a colorbar.
+
+    This allows us to plot 4D cubes, with the third dimension being represented
+    by color, using a :func:`pcolormovie`-like interface.
+
+    Parameters
+    ----------
+    TWXY
+        The coordinates of the mesh to plot.
+        Allowed combinations are:
+        an instance of :class:`named_arrays.AbstractSpectralPositionalVectorArray`,
+        an instance of :class:`named_arrays.AbstractScalar` and an instance of
+        :class:`named_arrays.AbstractSpectralPositionalVectorArray`,
+        two instances of :class:`named_arrays.AbstractScalar` and an instance of
+        :class:`named_arrays.AbstractCartesian2dVectorArray`,
+        or four instances of :class:`named_arrays.AbstractScalar`.
+    C
+        The mesh data.
+    axis_time
+        The logical axis corresponding to the different frames in the animation.
+    axis_wavelength
+        The logical axis representing changing wavelength coordinate.
+    ax
+        The instances of :class:`matplotlib.axes.Axes` to use.
+        If :obj:`None`, calls :func:`matplotlib.pyplot.gca` to get the current axes.
+        If an instance of :class:`named_arrays.ScalarArray`, ``ax.shape`` should be a subset of the broadcasted shape of
+        ``*args``.
+    norm
+        An optional function that transforms the spectral power distribution
+        values before mapping to RGB.
+        Equivalent to the `spd_norm` argument of :func:`named_arrays.colorsynth.rgb`.
+    vmin
+        The value of the spectral power distribution representing minimum
+        intensity.
+        Equivalent to the `spd_min` argument of :func:`named_arrays.colorsynth.rgb`.
+    vmax
+        The value of the spectral power distribution representing maximum
+        intensity.
+        Equivalent to the `spd_max` argument of :func:`named_arrays.colorsynth.rgb`.
+    wavelength_norm
+        An optional function to transform the wavelength values before they
+        are mapped into the human visible color range.
+    wavelength_min
+        The wavelength value that is mapped to the minimum wavelength of the
+        human visible color range, 380 nm.
+    wavelength_max
+        The wavelength value that is mapped to the maximum wavelength of the
+        human visible color range, 700 nm
+    kwargs_pcolormesh
+        Additional keyword arguments accepted by :func:`pcolormesh`.
+    kwargs_animation
+        Additional keyword arguments accepted by
+        :class:`matplotlib.animation.FuncAnimation`.
+
+    Examples
+    --------
+
+    Plot a random, 4D cube.
+
+    .. jupyter-execute::
+
+        import IPython.display
+        import matplotlib.pyplot as plt
+        import astropy.units as u
+        import astropy.visualization
+        import named_arrays as na
+
+        # Define the size of the grid
+        shape = dict(
+            t=3,
+            w=11,
+            x=16,
+            y=16,
+        )
+
+        # Define a simple coordinate grid
+        t = na.linspace(0, 2, axis="t", num=shape["t"]) * u.s
+        w = na.linspace(-1, 1, axis="w", num=shape["w"]) * u.mm
+        x = na.linspace(-2, 2, axis="x", num=shape["x"]) * u.mm
+        y = na.linspace(-1, 1, axis="y", num=shape["y"]) * u.mm
+
+        # Define a random array of values to plot
+        a = na.random.uniform(-1, 1, shape_random=shape)
+
+        # Plot the coordinates and values using rgbmovie()
+        with astropy.visualization.quantity_support():
+            fig, ax = plt.subplots(
+                ncols=2,
+                gridspec_kw=dict(width_ratios=[.9, .1]),
+                constrained_layout=True,
+            )
+            ani, colorbar = na.plt.rgbmovie(
+                t, w, x, y,
+                C=a,
+                axis_time="t",
+                axis_wavelength="w",
+                ax=ax[0],
+            );
+            na.plt.pcolormesh(
+                C=colorbar,
+                axis_rgb="w",
+                ax=ax[1],
+            )
+            ax[1].yaxis.tick_right()
+            ax[1].yaxis.set_label_position("right")
+            plt.close(fig)
+            IPython.display.HTML(ani.to_jshtml())
+
+    """
+
+    if len(TWXY) == 0:
+        if isinstance(C, na.AbstractFunctionArray):
+            TWXY = (C.inputs,)
+            C = C.outputs
+        else:   # pragma: nocover
+            raise TypeError(
+                "if no positional arguments, `C` must be an instance of "
+                f"`na.AbstractFunctionArray`. got {type(C)}."
+            )
+
+    if len(TWXY) == 1:
+        TWXY, = TWXY
+        if isinstance(TWXY, na.AbstractTemporalSpectralPositionalVectorArray):
+            t = TWXY.time
+            w = TWXY.wavelength
+            x = TWXY.position.x
+            y = TWXY.position.y
+        else:   # pragma: nocover
+            raise TypeError(
+                "if one positional argument, it must be an instance of "
+                f"`na.AbstractTemporalSpectralPositionalVectorArray`, "
+                f"got {type(TWXY)}."
+            )
+    elif len(TWXY) == 2:
+        t, WXY = TWXY
+        if isinstance(WXY, na.AbstractSpectralPositionalVectorArray):
+            w = WXY.wavelength
+            x = WXY.position.x
+            y = WXY.position.y
+        else:  # pragma: nocover
+            raise TypeError(
+                "if two positional arguments, "
+                "the second argument must be an instance of "
+                "`na.AbstarctSpectralPositionalVectorArray`, "
+                f"got {type(WXY)}.`"
+            )
+    elif len(TWXY) == 3:
+        t, w, XY = TWXY
+        if isinstance(XY, na.AbstractCartesian2dVectorArray):
+            x = XY.x
+            y = XY.y
+        else:   # pragma: nocover
+            raise TypeError(
+                "if three positional arguments, "
+                "the third argument must be an instance of"
+                f"`na.AbstractCartesian2dVectorArray`, got {type(XY)}`."
+            )
+
+    elif len(TWXY) == 4:
+        t, w, x, y = TWXY
+    else:  # pragma: nocover
+        raise ValueError(
+            f"incorrect number of arguments, expected 0, 1, 3, or 4,"
+            f" got {len(TWXY)}."
+        )
+
+    rgb, colorbar = na.colorsynth.rgb_and_colorbar(
+        spd=C,
+        wavelength=w,
+        axis=axis_wavelength,
+        spd_min=vmin,
+        spd_max=vmax,
+        spd_norm=norm,
+        wavelength_min=wavelength_min,
+        wavelength_max=wavelength_max,
+        wavelength_norm=wavelength_norm,
+    )
+
+    animation = pcolormovie(
+        t, x, y,
+        C=rgb,
+        axis_time=axis_time,
+        axis_rgb=axis_wavelength,
+        ax=ax,
+        kwargs_pcolormesh=kwargs_pcolormesh,
+        kwargs_animation=kwargs_animation,
+    )
+
+    return animation, colorbar
+
+
 def text(
     x: float | u.Quantity | na.AbstractScalar,
     y: float | u.Quantity | na.AbstractScalar,
diff --git a/named_arrays/tests/test_plt.py b/named_arrays/tests/test_plt.py
index c6b3a11..c780e35 100644
--- a/named_arrays/tests/test_plt.py
+++ b/named_arrays/tests/test_plt.py
@@ -116,6 +116,115 @@ def test_pcolormovie(
     assert isinstance(result.to_jshtml(), str)
 
 
+@pytest.mark.parametrize(
+    argnames="T",
+    argvalues=[
+        na.linspace(-1, 1, axis="t", num=_num_t) * u.s,
+    ],
+)
+@pytest.mark.parametrize(
+    argnames="W",
+    argvalues=[
+        na.linspace(-1, 1, axis="w", num=_num_w) * u.mm,
+    ],
+)
+@pytest.mark.parametrize(
+    argnames="X",
+    argvalues=[
+        na.linspace(-2, 2, axis="x", num=_num_x),
+    ],
+)
+@pytest.mark.parametrize(
+    argnames="Y",
+    argvalues=[
+        na.linspace(-1, 1, axis="y", num=_num_y),
+    ],
+)
+@pytest.mark.parametrize(
+    argnames="C",
+    argvalues=[
+        na.random.uniform(
+            low=-1,
+            high=1,
+            shape_random=dict(t=_num_t, w=_num_w, x=_num_x, y=_num_y),
+        ),
+    ],
+)
+def test_rgbmovie(
+    T: na.AbstractScalar,
+    W: na.AbstractScalar,
+    X: na.AbstractScalar,
+    Y: na.AbstractScalar,
+    C: na.AbstractScalar,
+):
+    ani_1, cbar_1 = na.plt.rgbmovie(
+        T,
+        W,
+        X,
+        Y,
+        C=C,
+        axis_time="t",
+        axis_wavelength="w",
+    )
+    ani_2, cbar_2 = na.plt.rgbmovie(
+        T,
+        na.SpectralPositionalVectorArray(
+            wavelength=W,
+            position=na.Cartesian2dVectorArray(X, Y),
+        ),
+        C=C,
+        axis_time="t",
+        axis_wavelength="w",
+    )
+    ani_3, cbar_3 = na.plt.rgbmovie(
+        T,
+        W,
+        na.Cartesian2dVectorArray(X, Y),
+        C=C,
+        axis_time="t",
+        axis_wavelength="w",
+    )
+    ani_4, cbar_4 = na.plt.rgbmovie(
+        na.TemporalSpectralPositionalVectorArray(
+            time=T,
+            wavelength=W,
+            position=na.Cartesian2dVectorArray(X, Y),
+        ),
+        C=C,
+        axis_time="t",
+        axis_wavelength="w",
+    )
+    ani_5, cbar_5 = na.plt.rgbmovie(
+        C=na.FunctionArray(
+            inputs=na.TemporalSpectralPositionalVectorArray(
+                time=T,
+                wavelength=W,
+                position=na.Cartesian2dVectorArray(X, Y),
+            ),
+            outputs=C,
+        ),
+        axis_time="t",
+        axis_wavelength="w",
+    )
+
+    assert isinstance(ani_1, matplotlib.animation.FuncAnimation)
+    assert isinstance(ani_2, matplotlib.animation.FuncAnimation)
+    assert isinstance(ani_3, matplotlib.animation.FuncAnimation)
+    assert isinstance(ani_4, matplotlib.animation.FuncAnimation)
+    assert isinstance(ani_5, matplotlib.animation.FuncAnimation)
+
+    assert isinstance(ani_1.to_jshtml(), str)
+    assert isinstance(ani_2.to_jshtml(), str)
+    assert isinstance(ani_3.to_jshtml(), str)
+    assert isinstance(ani_4.to_jshtml(), str)
+    assert isinstance(ani_5.to_jshtml(), str)
+
+    assert np.all(cbar_1 == cbar_2)
+    assert np.all(cbar_1 == cbar_3)
+    assert np.all(cbar_1 == cbar_4)
+    assert np.all(cbar_1 == cbar_5)
+
+
 @pytest.mark.parametrize(
     argnames="xlabel,ax",
     argvalues=[

From d9f7632815be5ecc59c7c5a4b82492efefa4b4b0 Mon Sep 17 00:00:00 2001
From: Roy Smart <roytsmart@gmail.com>
Date: Fri, 8 Nov 2024 19:28:06 -0700
Subject: [PATCH 2/3] doc fixes

---
 named_arrays/plt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/named_arrays/plt.py b/named_arrays/plt.py
index a0d7494..e521a90 100644
--- a/named_arrays/plt.py
+++ b/named_arrays/plt.py
@@ -1166,8 +1166,8 @@ def rgbmovie(
 
     .. jupyter-execute::
 
-        import IPython.display
         import matplotlib.pyplot as plt
+        import IPython.display
         import astropy.units as u
         import astropy.visualization
         import named_arrays as na

From 78a5a8b0960a2eeb096d69791cb9d5125a171b6e Mon Sep 17 00:00:00 2001
From: Roy Smart <roytsmart@gmail.com>
Date: Fri, 8 Nov 2024 19:52:20 -0700
Subject: [PATCH 3/3] doc example fix

---
 named_arrays/plt.py | 45 +++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/named_arrays/plt.py b/named_arrays/plt.py
index e521a90..f98e01a 100644
--- a/named_arrays/plt.py
+++ b/named_arrays/plt.py
@@ -1190,28 +1190,29 @@ def rgbmovie(
         a = na.random.uniform(-1, 1, shape_random=shape)
 
         # Plot the coordinates and values using rgbmovie()
-        with astropy.visualization.quantity_support():
-            fig, ax = plt.subplots(
-                ncols=2,
-                gridspec_kw=dict(width_ratios=[.9, .1]),
-                constrained_layout=True,
-            )
-            ani, colorbar = na.plt.rgbmovie(
-                t, w, x, y,
-                C=a,
-                axis_time="t",
-                axis_wavelength="w",
-                ax=ax[0],
-            );
-            na.plt.pcolormesh(
-                C=colorbar,
-                axis_rgb="w",
-                ax=ax[1],
-            )
-            ax[1].yaxis.tick_right()
-            ax[1].yaxis.set_label_position("right")
-            plt.close(fig)
-            IPython.display.HTML(ani.to_jshtml())
+        astropy.visualization.quantity_support()
+        fig, ax = plt.subplots(
+            ncols=2,
+            gridspec_kw=dict(width_ratios=[.9, .1]),
+            constrained_layout=True,
+        )
+        ani, colorbar = na.plt.rgbmovie(
+            t, w, x, y,
+            C=a,
+            axis_time="t",
+            axis_wavelength="w",
+            ax=ax[0],
+        );
+        na.plt.pcolormesh(
+            C=colorbar,
+            axis_rgb="w",
+            ax=ax[1],
+        )
+        ax[1].yaxis.tick_right()
+        ax[1].yaxis.set_label_position("right")
+        ani = ani.to_jshtml()
+        plt.close(fig)
+        IPython.display.HTML(ani)
 
     """