From a3a62e47a189b9b605c2337df16b9a40e5001f42 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 8 Nov 2024 18:18:46 -0700 Subject: [PATCH 1/3] Added `named_arrays.plt.rgbmovie()` function to plot 4D arrays. --- named_arrays/plt.py | 213 +++++++++++++++++++++++++++++++++ named_arrays/tests/test_plt.py | 109 +++++++++++++++++ 2 files changed, 322 insertions(+) diff --git a/named_arrays/plt.py b/named_arrays/plt.py index 66cc9cb..a0d7494 100644 --- a/named_arrays/plt.py +++ b/named_arrays/plt.py @@ -19,6 +19,7 @@ "pcolormesh", "rgbmesh", "pcolormovie", + "rgbmovie", "text", "brace_vertical", "set_xlabel", @@ -1083,6 +1084,218 @@ def pcolormovie( ) +def rgbmovie( + *TWXY: na.AbstractArray, + C: na.AbstractArray, + axis_time: str, + axis_wavelength: str, + ax: None | matplotlib.axes.Axes | na.AbstractArray = None, + norm: None | Callable= None, + vmin: None | na.ArrayLike = None, + vmax: None | na.ArrayLike = None, + wavelength_norm: None | Callable = None, + wavelength_min: None | float | u.Quantity | na.AbstractScalar = None, + wavelength_max: None | float | u.Quantity | na.AbstractScalar = None, + kwargs_pcolormesh: None | dict[str, Any] = None, + **kwargs_animation, +) -> tuple[ + matplotlib.animation.FuncAnimation, + na.FunctionArray[na.Cartesian2dVectorArray, na.AbstractScalar] +]: + """ + A convenience function that calls :func:`pcolormovie` with the outputs + from :func:`named_arrays.colorsynth.rgb` and returns an animation + instance and a colorbar. + + This allows us to plot 4D cubes, with the third dimension being represented + by color, using a :func:`pcolormovie`-like interface. + + Parameters + ---------- + TWXY + The coordinates of the mesh to plot. + Allowed combinations are: + an instance of :class:`named_arrays.AbstractSpectralPositionalVectorArray`, + an instance of :class:`named_arrays.AbstractScalar` and an instance of + :class:`named_arrays.AbstractSpectralPositionalVectorArray`, + two instances of :class:`named_arrays.AbstractScalar` and an instance of + :class:`named_arrays.AbstractCartesian2dVectorArray`, + or four instances of :class:`named_arrays.AbstractScalar`. + C + The mesh data. + axis_time + The logical axis corresponding to the different frames in the animation. + axis_wavelength + The logical axis representing changing wavelength coordinate. + ax + The instances of :class:`matplotlib.axes.Axes` to use. + If :obj:`None`, calls :func:`matplotlib.pyplot.gca` to get the current axes. + If an instance of :class:`named_arrays.ScalarArray`, ``ax.shape`` should be a subset of the broadcasted shape of + ``*args``. + norm + An optional function that transforms the spectral power distribution + values before mapping to RGB. + Equivalent to the `spd_norm` argument of :func:`named_arrays.colorsynth.rgb`. + vmin + The value of the spectral power distribution representing minimum + intensity. + Equivalent to the `spd_min` argument of :func:`named_arrays.colorsynth.rgb`. + vmax + The value of the spectral power distribution representing maximum + intensity. + Equivalent to the `spd_max` argument of :func:`named_arrays.colorsynth.rgb`. + wavelength_norm + An optional function to transform the wavelength values before they + are mapped into the human visible color range. + wavelength_min + The wavelength value that is mapped to the minimum wavelength of the + human visible color range, 380 nm. + wavelength_max + The wavelength value that is mapped to the maximum wavelength of the + human visible color range, 700 nm + kwargs_pcolormesh + Additional keyword arguments accepted by :func:`pcolormesh`. + kwargs_animation + Additional keyword arguments accepted by + :class:`matplotlib.animation.FuncAnimation`. + + Examples + -------- + + Plot a random, 4D cube. + + .. jupyter-execute:: + + import IPython.display + import matplotlib.pyplot as plt + import astropy.units as u + import astropy.visualization + import named_arrays as na + + # Define the size of the grid + shape = dict( + t=3, + w=11, + x=16, + y=16, + ) + + # Define a simple coordinate grid + t = na.linspace(0, 2, axis="t", num=shape["t"]) * u.s + w = na.linspace(-1, 1, axis="w", num=shape["w"]) * u.mm + x = na.linspace(-2, 2, axis="x", num=shape["x"]) * u.mm + y = na.linspace(-1, 1, axis="y", num=shape["y"]) * u.mm + + # Define a random array of values to plot + a = na.random.uniform(-1, 1, shape_random=shape) + + # Plot the coordinates and values using rgbmovie() + with astropy.visualization.quantity_support(): + fig, ax = plt.subplots( + ncols=2, + gridspec_kw=dict(width_ratios=[.9, .1]), + constrained_layout=True, + ) + ani, colorbar = na.plt.rgbmovie( + t, w, x, y, + C=a, + axis_time="t", + axis_wavelength="w", + ax=ax[0], + ); + na.plt.pcolormesh( + C=colorbar, + axis_rgb="w", + ax=ax[1], + ) + ax[1].yaxis.tick_right() + ax[1].yaxis.set_label_position("right") + plt.close(fig) + IPython.display.HTML(ani.to_jshtml()) + + """ + + if len(TWXY) == 0: + if isinstance(C, na.AbstractFunctionArray): + TWXY = (C.inputs,) + C = C.outputs + else: # pragma: nocover + raise TypeError( + "if no positional arguments, `C` must be an instance of " + f"`na.AbstractFunctionArray`. got {type(C)}." + ) + + if len(TWXY) == 1: + TWXY, = TWXY + if isinstance(TWXY, na.AbstractTemporalSpectralPositionalVectorArray): + t = TWXY.time + w = TWXY.wavelength + x = TWXY.position.x + y = TWXY.position.y + else: # pragma: nocover + raise TypeError( + "if one positional argument, it must be an instance of " + f"`na.AbstractTemporalSpectralPositionalVectorArray`, " + f"got {type(TWXY)}." + ) + elif len(TWXY) == 2: + t, WXY = TWXY + if isinstance(WXY, na.AbstractSpectralPositionalVectorArray): + w = WXY.wavelength + x = WXY.position.x + y = WXY.position.y + else: # pragma: nocover + raise TypeError( + "if two positional arguments, " + "the second argument must be an instance of " + "`na.AbstarctSpectralPositionalVectorArray`, " + f"got {type(WXY)}.`" + ) + elif len(TWXY) == 3: + t, w, XY = TWXY + if isinstance(XY, na.AbstractCartesian2dVectorArray): + x = XY.x + y = XY.y + else: # pragma: nocover + raise TypeError( + "if three positional arguments, " + "the third argument must be an instance of" + f"`na.AbstractCartesian2dVectorArray`, got {type(XY)}`." + ) + + elif len(TWXY) == 4: + t, w, x, y = TWXY + else: # pragma: nocover + raise ValueError( + f"incorrect number of arguments, expected 0, 1, 3, or 4," + f" got {len(TWXY)}." + ) + + rgb, colorbar = na.colorsynth.rgb_and_colorbar( + spd=C, + wavelength=w, + axis=axis_wavelength, + spd_min=vmin, + spd_max=vmax, + spd_norm=norm, + wavelength_min=wavelength_min, + wavelength_max=wavelength_max, + wavelength_norm=wavelength_norm, + ) + + animation = pcolormovie( + t, x, y, + C=rgb, + axis_time=axis_time, + axis_rgb=axis_wavelength, + ax=ax, + kwargs_pcolormesh=kwargs_pcolormesh, + kwargs_animation=kwargs_animation, + ) + + return animation, colorbar + + def text( x: float | u.Quantity | na.AbstractScalar, y: float | u.Quantity | na.AbstractScalar, diff --git a/named_arrays/tests/test_plt.py b/named_arrays/tests/test_plt.py index c6b3a11..c780e35 100644 --- a/named_arrays/tests/test_plt.py +++ b/named_arrays/tests/test_plt.py @@ -116,6 +116,115 @@ def test_pcolormovie( assert isinstance(result.to_jshtml(), str) +@pytest.mark.parametrize( + argnames="T", + argvalues=[ + na.linspace(-1, 1, axis="t", num=_num_t) * u.s, + ], +) +@pytest.mark.parametrize( + argnames="W", + argvalues=[ + na.linspace(-1, 1, axis="w", num=_num_w) * u.mm, + ], +) +@pytest.mark.parametrize( + argnames="X", + argvalues=[ + na.linspace(-2, 2, axis="x", num=_num_x), + ], +) +@pytest.mark.parametrize( + argnames="Y", + argvalues=[ + na.linspace(-1, 1, axis="y", num=_num_y), + ], +) +@pytest.mark.parametrize( + argnames="C", + argvalues=[ + na.random.uniform( + low=-1, + high=1, + shape_random=dict(t=_num_t, w=_num_w, x=_num_x, y=_num_y), + ), + ], +) +def test_rgbmovie( + T: na.AbstractScalar, + W: na.AbstractScalar, + X: na.AbstractScalar, + Y: na.AbstractScalar, + C: na.AbstractScalar, +): + ani_1, cbar_1 = na.plt.rgbmovie( + T, + W, + X, + Y, + C=C, + axis_time="t", + axis_wavelength="w", + ) + ani_2, cbar_2 = na.plt.rgbmovie( + T, + na.SpectralPositionalVectorArray( + wavelength=W, + position=na.Cartesian2dVectorArray(X, Y), + ), + C=C, + axis_time="t", + axis_wavelength="w", + ) + ani_3, cbar_3 = na.plt.rgbmovie( + T, + W, + na.Cartesian2dVectorArray(X, Y), + C=C, + axis_time="t", + axis_wavelength="w", + ) + ani_4, cbar_4 = na.plt.rgbmovie( + na.TemporalSpectralPositionalVectorArray( + time=T, + wavelength=W, + position=na.Cartesian2dVectorArray(X, Y), + ), + C=C, + axis_time="t", + axis_wavelength="w", + ) + ani_5, cbar_5 = na.plt.rgbmovie( + C=na.FunctionArray( + inputs=na.TemporalSpectralPositionalVectorArray( + time=T, + wavelength=W, + position=na.Cartesian2dVectorArray(X, Y), + ), + outputs=C, + ), + axis_time="t", + axis_wavelength="w", + ) + + assert isinstance(ani_1, matplotlib.animation.FuncAnimation) + assert isinstance(ani_2, matplotlib.animation.FuncAnimation) + assert isinstance(ani_3, matplotlib.animation.FuncAnimation) + assert isinstance(ani_4, matplotlib.animation.FuncAnimation) + assert isinstance(ani_5, matplotlib.animation.FuncAnimation) + + assert isinstance(ani_1.to_jshtml(), str) + assert isinstance(ani_2.to_jshtml(), str) + assert isinstance(ani_3.to_jshtml(), str) + assert isinstance(ani_4.to_jshtml(), str) + assert isinstance(ani_5.to_jshtml(), str) + + assert np.all(cbar_1 == cbar_2) + assert np.all(cbar_1 == cbar_3) + assert np.all(cbar_1 == cbar_4) + assert np.all(cbar_1 == cbar_5) + + @pytest.mark.parametrize( argnames="xlabel,ax", argvalues=[ From d9f7632815be5ecc59c7c5a4b82492efefa4b4b0 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 8 Nov 2024 19:28:06 -0700 Subject: [PATCH 2/3] doc fixes --- named_arrays/plt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/named_arrays/plt.py b/named_arrays/plt.py index a0d7494..e521a90 100644 --- a/named_arrays/plt.py +++ b/named_arrays/plt.py @@ -1166,8 +1166,8 @@ def rgbmovie( .. jupyter-execute:: - import IPython.display import matplotlib.pyplot as plt + import IPython.display import astropy.units as u import astropy.visualization import named_arrays as na From 78a5a8b0960a2eeb096d69791cb9d5125a171b6e Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Fri, 8 Nov 2024 19:52:20 -0700 Subject: [PATCH 3/3] doc example fix --- named_arrays/plt.py | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/named_arrays/plt.py b/named_arrays/plt.py index e521a90..f98e01a 100644 --- a/named_arrays/plt.py +++ b/named_arrays/plt.py @@ -1190,28 +1190,29 @@ def rgbmovie( a = na.random.uniform(-1, 1, shape_random=shape) # Plot the coordinates and values using rgbmovie() - with astropy.visualization.quantity_support(): - fig, ax = plt.subplots( - ncols=2, - gridspec_kw=dict(width_ratios=[.9, .1]), - constrained_layout=True, - ) - ani, colorbar = na.plt.rgbmovie( - t, w, x, y, - C=a, - axis_time="t", - axis_wavelength="w", - ax=ax[0], - ); - na.plt.pcolormesh( - C=colorbar, - axis_rgb="w", - ax=ax[1], - ) - ax[1].yaxis.tick_right() - ax[1].yaxis.set_label_position("right") - plt.close(fig) - IPython.display.HTML(ani.to_jshtml()) + astropy.visualization.quantity_support() + fig, ax = plt.subplots( + ncols=2, + gridspec_kw=dict(width_ratios=[.9, .1]), + constrained_layout=True, + ) + ani, colorbar = na.plt.rgbmovie( + t, w, x, y, + C=a, + axis_time="t", + axis_wavelength="w", + ax=ax[0], + ); + na.plt.pcolormesh( + C=colorbar, + axis_rgb="w", + ax=ax[1], + ) + ax[1].yaxis.tick_right() + ax[1].yaxis.set_label_position("right") + ani = ani.to_jshtml() + plt.close(fig) + IPython.display.HTML(ani) """