Skip to content

Added get_xlabel(), get_ylabel(), get_title(), get_xscale(), get_yscale(), get_aspect(), transAxes(), and transData() functions to the named_arrays.plt module. #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,9 @@
"ASARRAY_LIKE_FUNCTIONS",
"RANDOM_FUNCTIONS",
"PLT_PLOT_LIKE_FUNCTIONS",
"PLT_AXES_SETTERS",
"PLT_AXES_GETTERS",
"PLT_AXES_ATTRIBUTES",
"HANDLED_FUNCTIONS",
"random",
"jacobian",
@@ -43,6 +46,18 @@
na.plt.set_yscale,
na.plt.set_aspect,
)
PLT_AXES_GETTERS = (
na.plt.get_xlabel,
na.plt.get_ylabel,
na.plt.get_title,
na.plt.get_xscale,
na.plt.get_yscale,
na.plt.get_aspect,
)
PLT_AXES_ATTRIBUTES = (
na.plt.transAxes,
na.plt.transData,
)
NDFILTER_FUNCTIONS = (
na.ndfilters.mean_filter,
na.ndfilters.trimmed_mean_filter,
@@ -829,6 +844,7 @@
y = scalars._normalize(y)
s = scalars._normalize(s)
ax = scalars._normalize(ax)
kwargs = {k: scalars._normalize(kwargs[k]) for k in kwargs}
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

@@ -838,15 +854,17 @@
y = y.broadcast_to(shape)
s = s.broadcast_to(shape)
ax = ax.broadcast_to(shape)
kwargs = {k: kwargs[k].broadcast_to(shape) for k in kwargs}

result = na.ScalarArray.empty(shape, dtype=matplotlib.axes.Axes)

for index in na.ndindex(shape):
kwargs_index = {k: kwargs[k][index].ndarray for k in kwargs}
result[index] = ax[index].ndarray.text(
x=x[index].ndarray,
y=y[index].ndarray,
s=s[index].ndarray,
**kwargs,
**kwargs_index,
)

return result
@@ -880,6 +898,48 @@
getattr(ax[index].ndarray, method.__name__)(*args_index, **kwargs_index)


def plt_axes_getter(
method: str,
ax: na.AbstractScalarArray,
) -> na.ScalarArray:

try:
ax = scalars._normalize(ax)

Check warning on line 907 in named_arrays/_scalars/scalar_named_array_functions.py

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L906-L907

Added lines #L906 - L907 were not covered by tests
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

result = na.ScalarArray.empty(shape=ax.shape, dtype=object)

Check warning on line 911 in named_arrays/_scalars/scalar_named_array_functions.py

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L911

Added line #L911 was not covered by tests

for index in na.ndindex(ax.shape):
ax_index = ax[index].ndarray
if ax_index is None:
ax_index = plt.gca()
result[index] = getattr(ax_index, method.__name__)()

Check warning on line 917 in named_arrays/_scalars/scalar_named_array_functions.py

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L913-L917

Added lines #L913 - L917 were not covered by tests

return result

Check warning on line 919 in named_arrays/_scalars/scalar_named_array_functions.py

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L919

Added line #L919 was not covered by tests


def plt_axes_attribute(
method: str,
ax: na.AbstractScalarArray,
) -> na.ScalarArray:

try:
ax = scalars._normalize(ax)

Check warning on line 928 in named_arrays/_scalars/scalar_named_array_functions.py

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L927-L928

Added lines #L927 - L928 were not covered by tests
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

result = na.ScalarArray.empty(shape=ax.shape, dtype=object)

Check warning on line 932 in named_arrays/_scalars/scalar_named_array_functions.py

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L932

Added line #L932 was not covered by tests

for index in na.ndindex(ax.shape):
ax_index = ax[index].ndarray
if ax_index is None:
ax_index = plt.gca()
result[index] = getattr(ax_index, method.__name__)

Check warning on line 938 in named_arrays/_scalars/scalar_named_array_functions.py

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L934-L938

Added lines #L934 - L938 were not covered by tests

return result

Check warning on line 940 in named_arrays/_scalars/scalar_named_array_functions.py

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L940

Added line #L940 was not covered by tests


@_implements(na.jacobian)
def jacobian(
function: Callable[[na.AbstractScalar], na.AbstractScalar],
6 changes: 6 additions & 0 deletions named_arrays/_scalars/scalars.py
Original file line number Diff line number Diff line change
@@ -565,6 +565,12 @@
if func in scalar_named_array_functions.PLT_AXES_SETTERS:
return scalar_named_array_functions.plt_axes_setter(func, *args, **kwargs)

if func in scalar_named_array_functions.PLT_AXES_GETTERS:
return scalar_named_array_functions.plt_axes_getter(func, *args, **kwargs)

Check warning on line 569 in named_arrays/_scalars/scalars.py

Codecov / codecov/patch

named_arrays/_scalars/scalars.py#L569

Added line #L569 was not covered by tests

if func in scalar_named_array_functions.PLT_AXES_ATTRIBUTES:
return scalar_named_array_functions.plt_axes_attribute(func, *args, **kwargs)

Check warning on line 572 in named_arrays/_scalars/scalars.py

Codecov / codecov/patch

named_arrays/_scalars/scalars.py#L572

Added line #L572 was not covered by tests

if func in scalar_named_array_functions.NDFILTER_FUNCTIONS:
return scalar_named_array_functions.ndfilter(func, *args, **kwargs)

147 changes: 147 additions & 0 deletions named_arrays/plt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Literal, Any
import matplotlib.axes
import matplotlib.transforms
import matplotlib.animation
import matplotlib.pyplot as plt
import astropy.units as u
@@ -18,11 +20,19 @@
"text",
"brace_vertical",
"set_xlabel",
"get_xlabel",
"set_ylabel",
"get_ylabel",
"set_title",
"get_title",
"set_xscale",
"get_xscale",
"set_yscale",
"get_yscale",
"set_aspect",
"get_aspect",
"transAxes",
"transData",
]


@@ -885,6 +895,24 @@
)


def get_xlabel(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_xlabel` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the horizontal axis label from.
"""
return na._named_array_function(

Check warning on line 909 in named_arrays/plt.py

Codecov / codecov/patch

named_arrays/plt.py#L909

Added line #L909 was not covered by tests
get_xlabel,
ax=na.as_named_array(ax),
)



def set_ylabel(
ylabel: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
@@ -908,6 +936,23 @@
)


def get_ylabel(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_ylabel` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the vertical axis label from.
"""
return na._named_array_function(

Check warning on line 950 in named_arrays/plt.py

Codecov / codecov/patch

named_arrays/plt.py#L950

Added line #L950 was not covered by tests
get_ylabel,
ax=na.as_named_array(ax),
)


def set_title(
label: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
@@ -931,6 +976,23 @@
)


def get_title(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_title` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the title label from.
"""
return na._named_array_function(

Check warning on line 990 in named_arrays/plt.py

Codecov / codecov/patch

named_arrays/plt.py#L990

Added line #L990 was not covered by tests
get_title,
ax=na.as_named_array(ax),
)


def set_xscale(
value: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
@@ -954,6 +1016,23 @@
)


def get_xscale(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_xscale` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the horizontal axis scale from.
"""
return na._named_array_function(

Check warning on line 1030 in named_arrays/plt.py

Codecov / codecov/patch

named_arrays/plt.py#L1030

Added line #L1030 was not covered by tests
get_xscale,
ax=na.as_named_array(ax),
)


def set_yscale(
value: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
@@ -977,6 +1056,23 @@
)


def get_yscale(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_yscale` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the vertical axis scale from.
"""
return na._named_array_function(

Check warning on line 1070 in named_arrays/plt.py

Codecov / codecov/patch

named_arrays/plt.py#L1070

Added line #L1070 was not covered by tests
get_yscale,
ax=na.as_named_array(ax),
)


def set_aspect(
aspect: float | str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
@@ -1000,6 +1096,57 @@
)


def get_aspect(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_aspect` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the aspect ratio from.
"""
return na._named_array_function(

Check warning on line 1110 in named_arrays/plt.py

Codecov / codecov/patch

named_arrays/plt.py#L1110

Added line #L1110 was not covered by tests
get_aspect,
ax=na.as_named_array(ax),
)


def transAxes(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> matplotlib.transforms.Transform | na.AbstractScalar:
"""
A thin wrapper around :attr:`matplotlib.axes.Axes.transAxes` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the axes transformation from.
"""
return na._named_array_function(

Check warning on line 1127 in named_arrays/plt.py

Codecov / codecov/patch

named_arrays/plt.py#L1127

Added line #L1127 was not covered by tests
transAxes,
ax=na.as_named_array(ax),
)


def transData(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> matplotlib.transforms.Transform | na.AbstractScalar:
"""
A thin wrapper around :attr:`matplotlib.axes.Axes.transData` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the axes transformation from.
"""
return na._named_array_function(

Check warning on line 1144 in named_arrays/plt.py

Codecov / codecov/patch

named_arrays/plt.py#L1144

Added line #L1144 was not covered by tests
transData,
ax=na.as_named_array(ax),
)


def brace_vertical(
x: float | u.Quantity | na.AbstractScalar,
width: float | u.Quantity | na.AbstractScalar,
47 changes: 45 additions & 2 deletions named_arrays/tests/test_plt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import numpy as np
import matplotlib.axes
import matplotlib.animation
import named_arrays as na
@@ -58,6 +59,8 @@
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_xlabel(xlabel, ax=ax)
result = na.plt.get_xlabel(ax)
assert np.all(result == xlabel)

Check warning on line 63 in named_arrays/tests/test_plt.py

Codecov / codecov/patch

named_arrays/tests/test_plt.py#L62-L63

Added lines #L62 - L63 were not covered by tests


@pytest.mark.parametrize(
@@ -72,6 +75,8 @@
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_ylabel(ylabel, ax=ax)
result = na.plt.get_ylabel(ax)
assert np.all(result == ylabel)

Check warning on line 79 in named_arrays/tests/test_plt.py

Codecov / codecov/patch

named_arrays/tests/test_plt.py#L78-L79

Added lines #L78 - L79 were not covered by tests


@pytest.mark.parametrize(
@@ -86,6 +91,8 @@
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_title(label, ax=ax)
result = na.plt.get_title(ax)
assert np.all(result == label)

Check warning on line 95 in named_arrays/tests/test_plt.py

Codecov / codecov/patch

named_arrays/tests/test_plt.py#L94-L95

Added lines #L94 - L95 were not covered by tests


@pytest.mark.parametrize(
@@ -100,6 +107,8 @@
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_xscale(value, ax=ax)
result = na.plt.get_xscale(ax)
assert np.all(result == value)

Check warning on line 111 in named_arrays/tests/test_plt.py

Codecov / codecov/patch

named_arrays/tests/test_plt.py#L110-L111

Added lines #L110 - L111 were not covered by tests


@pytest.mark.parametrize(
@@ -114,13 +123,15 @@
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_yscale(value, ax=ax)
result = na.plt.get_yscale(ax)
assert np.all(result == value)

Check warning on line 127 in named_arrays/tests/test_plt.py

Codecov / codecov/patch

named_arrays/tests/test_plt.py#L126-L127

Added lines #L126 - L127 were not covered by tests


@pytest.mark.parametrize(
argnames="aspect,ax",
argvalues=[
("equal", None),
("equal", na.plt.subplots(ncols=3)[1]),
(1, None),
(1, na.plt.subplots(ncols=3)[1]),
(2, na.plt.subplots(ncols=3)[1]),
]
)
@@ -129,3 +140,35 @@
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
na.plt.set_aspect(aspect, ax=ax)
result = na.plt.get_aspect(ax)
assert np.all(result == aspect)

Check warning on line 144 in named_arrays/tests/test_plt.py

Codecov / codecov/patch

named_arrays/tests/test_plt.py#L143-L144

Added lines #L143 - L144 were not covered by tests


@pytest.mark.parametrize(
argnames="ax",
argvalues=[
None,
na.plt.subplots(ncols=3)[1]
]
)
def test_transAxes(
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
result = na.plt.transAxes(ax)
assert isinstance(result, na.AbstractArray)
assert result.shape == na.shape(ax)

Check warning on line 159 in named_arrays/tests/test_plt.py

Codecov / codecov/patch

named_arrays/tests/test_plt.py#L157-L159

Added lines #L157 - L159 were not covered by tests


@pytest.mark.parametrize(
argnames="ax",
argvalues=[
None,
na.plt.subplots(ncols=3)[1]
]
)
def test_transData(
ax: None | matplotlib.axes.Axes | na.AbstractScalar,
):
result = na.plt.transData(ax)
assert isinstance(result, na.AbstractArray)
assert result.shape == na.shape(ax)

Check warning on line 174 in named_arrays/tests/test_plt.py

Codecov / codecov/patch

named_arrays/tests/test_plt.py#L172-L174

Added lines #L172 - L174 were not covered by tests