Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added get_xlabel(), get_ylabel(), get_title(), get_xscale(), get_yscale(), get_aspect(), transAxes(), and transData() functions to the named_arrays.plt module. #80

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"ASARRAY_LIKE_FUNCTIONS",
"RANDOM_FUNCTIONS",
"PLT_PLOT_LIKE_FUNCTIONS",
"PLT_AXES_SETTERS",
"PLT_AXES_GETTERS",
"PLT_AXES_ATTRIBUTES",
"HANDLED_FUNCTIONS",
"random",
"jacobian",
Expand Down Expand Up @@ -43,6 +46,18 @@
na.plt.set_yscale,
na.plt.set_aspect,
)
PLT_AXES_GETTERS = (
na.plt.get_xlabel,
na.plt.get_ylabel,
na.plt.get_title,
na.plt.get_xscale,
na.plt.get_yscale,
na.plt.get_aspect,
)
PLT_AXES_ATTRIBUTES = (
na.plt.transAxes,
na.plt.transData,
)
NDFILTER_FUNCTIONS = (
na.ndfilters.mean_filter,
na.ndfilters.trimmed_mean_filter,
Expand Down Expand Up @@ -829,6 +844,7 @@
y = scalars._normalize(y)
s = scalars._normalize(s)
ax = scalars._normalize(ax)
kwargs = {k: scalars._normalize(kwargs[k]) for k in kwargs}
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

Expand All @@ -838,15 +854,17 @@
y = y.broadcast_to(shape)
s = s.broadcast_to(shape)
ax = ax.broadcast_to(shape)
kwargs = {k: kwargs[k].broadcast_to(shape) for k in kwargs}

result = na.ScalarArray.empty(shape, dtype=matplotlib.axes.Axes)

for index in na.ndindex(shape):
kwargs_index = {k: kwargs[k][index].ndarray for k in kwargs}
result[index] = ax[index].ndarray.text(
x=x[index].ndarray,
y=y[index].ndarray,
s=s[index].ndarray,
**kwargs,
**kwargs_index,
)

return result
Expand Down Expand Up @@ -880,6 +898,48 @@
getattr(ax[index].ndarray, method.__name__)(*args_index, **kwargs_index)


def plt_axes_getter(
method: str,
ax: na.AbstractScalarArray,
) -> na.ScalarArray:

try:
ax = scalars._normalize(ax)

Check warning on line 907 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L906-L907

Added lines #L906 - L907 were not covered by tests
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

result = na.ScalarArray.empty(shape=ax.shape, dtype=object)

Check warning on line 911 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L911

Added line #L911 was not covered by tests

for index in na.ndindex(ax.shape):
ax_index = ax[index].ndarray
if ax_index is None:
ax_index = plt.gca()
result[index] = getattr(ax_index, method.__name__)()

Check warning on line 917 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L913-L917

Added lines #L913 - L917 were not covered by tests

return result

Check warning on line 919 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L919

Added line #L919 was not covered by tests


def plt_axes_attribute(
method: str,
ax: na.AbstractScalarArray,
) -> na.ScalarArray:

try:
ax = scalars._normalize(ax)

Check warning on line 928 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L927-L928

Added lines #L927 - L928 were not covered by tests
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

result = na.ScalarArray.empty(shape=ax.shape, dtype=object)

Check warning on line 932 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L932

Added line #L932 was not covered by tests

for index in na.ndindex(ax.shape):
ax_index = ax[index].ndarray
if ax_index is None:
ax_index = plt.gca()
result[index] = getattr(ax_index, method.__name__)

Check warning on line 938 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L934-L938

Added lines #L934 - L938 were not covered by tests

return result

Check warning on line 940 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L940

Added line #L940 was not covered by tests


@_implements(na.jacobian)
def jacobian(
function: Callable[[na.AbstractScalar], na.AbstractScalar],
Expand Down
6 changes: 6 additions & 0 deletions named_arrays/_scalars/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,12 @@
if func in scalar_named_array_functions.PLT_AXES_SETTERS:
return scalar_named_array_functions.plt_axes_setter(func, *args, **kwargs)

if func in scalar_named_array_functions.PLT_AXES_GETTERS:
return scalar_named_array_functions.plt_axes_getter(func, *args, **kwargs)

Check warning on line 569 in named_arrays/_scalars/scalars.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalars.py#L569

Added line #L569 was not covered by tests

if func in scalar_named_array_functions.PLT_AXES_ATTRIBUTES:
return scalar_named_array_functions.plt_axes_attribute(func, *args, **kwargs)

Check warning on line 572 in named_arrays/_scalars/scalars.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalars.py#L572

Added line #L572 was not covered by tests

if func in scalar_named_array_functions.NDFILTER_FUNCTIONS:
return scalar_named_array_functions.ndfilter(func, *args, **kwargs)

Expand Down
147 changes: 147 additions & 0 deletions named_arrays/plt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Literal, Any
import matplotlib.axes
import matplotlib.transforms
import matplotlib.animation
import matplotlib.pyplot as plt
import astropy.units as u
Expand All @@ -18,11 +20,19 @@
"text",
"brace_vertical",
"set_xlabel",
"get_xlabel",
"set_ylabel",
"get_ylabel",
"set_title",
"get_title",
"set_xscale",
"get_xscale",
"set_yscale",
"get_yscale",
"set_aspect",
"get_aspect",
"transAxes",
"transData",
]


Expand Down Expand Up @@ -885,6 +895,24 @@
)


def get_xlabel(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_xlabel` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the horizontal axis label from.
"""
return na._named_array_function(

Check warning on line 909 in named_arrays/plt.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/plt.py#L909

Added line #L909 was not covered by tests
get_xlabel,
ax=na.as_named_array(ax),
)



def set_ylabel(
ylabel: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
Expand All @@ -908,6 +936,23 @@
)


def get_ylabel(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_ylabel` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the vertical axis label from.
"""
return na._named_array_function(

Check warning on line 950 in named_arrays/plt.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/plt.py#L950

Added line #L950 was not covered by tests
get_ylabel,
ax=na.as_named_array(ax),
)


def set_title(
label: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
Expand All @@ -931,6 +976,23 @@
)


def get_title(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_title` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the title label from.
"""
return na._named_array_function(

Check warning on line 990 in named_arrays/plt.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/plt.py#L990

Added line #L990 was not covered by tests
get_title,
ax=na.as_named_array(ax),
)


def set_xscale(
value: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
Expand All @@ -954,6 +1016,23 @@
)


def get_xscale(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_xscale` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the horizontal axis scale from.
"""
return na._named_array_function(

Check warning on line 1030 in named_arrays/plt.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/plt.py#L1030

Added line #L1030 was not covered by tests
get_xscale,
ax=na.as_named_array(ax),
)


def set_yscale(
value: str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
Expand All @@ -977,6 +1056,23 @@
)


def get_yscale(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_yscale` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the vertical axis scale from.
"""
return na._named_array_function(

Check warning on line 1070 in named_arrays/plt.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/plt.py#L1070

Added line #L1070 was not covered by tests
get_yscale,
ax=na.as_named_array(ax),
)


def set_aspect(
aspect: float | str | na.AbstractScalar,
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
Expand All @@ -1000,6 +1096,57 @@
)


def get_aspect(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> str | na.AbstractScalar:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.get_aspect` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the aspect ratio from.
"""
return na._named_array_function(

Check warning on line 1110 in named_arrays/plt.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/plt.py#L1110

Added line #L1110 was not covered by tests
get_aspect,
ax=na.as_named_array(ax),
)


def transAxes(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> matplotlib.transforms.Transform | na.AbstractScalar:
"""
A thin wrapper around :attr:`matplotlib.axes.Axes.transAxes` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the axes transformation from.
"""
return na._named_array_function(

Check warning on line 1127 in named_arrays/plt.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/plt.py#L1127

Added line #L1127 was not covered by tests
transAxes,
ax=na.as_named_array(ax),
)


def transData(
ax: None | matplotlib.axes.Axes | na.AbstractScalar = None,
) -> matplotlib.transforms.Transform | na.AbstractScalar:
"""
A thin wrapper around :attr:`matplotlib.axes.Axes.transData` for named arrays.

Parameters
----------
ax
The matplotlib axes instance(s) to get the axes transformation from.
"""
return na._named_array_function(

Check warning on line 1144 in named_arrays/plt.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/plt.py#L1144

Added line #L1144 was not covered by tests
transData,
ax=na.as_named_array(ax),
)


def brace_vertical(
x: float | u.Quantity | na.AbstractScalar,
width: float | u.Quantity | na.AbstractScalar,
Expand Down
Loading
Loading