diff --git a/README.md b/README.md index 7e543a2..1c7ce41 100644 --- a/README.md +++ b/README.md @@ -42,19 +42,34 @@ import numpy as np import cosmoplots import matplotlib as mpl +# Setup mpl.style.use("cosmoplots.default") -a = np.exp(np.linspace(-3, 5, 100)) +a = np.exp(np.linspace(-3, 1, 100)) + +# Plotting +fig = plt.figure() +ax1 = plt.gca() +ax1.set_xlabel("X Axis") +ax1.set_ylabel("Y Axis") +base = 2 # Default is 10, but 2 works equally well +# Do plotting ... +ax1.semilogx(a) +# It is recommended to call the change_log_axis_base function after doing all the +# plotting. By default, it will try to infer the scaling used for the axis and only +# adjust accordingly. +cosmoplots.change_log_axis_base(ax1, base=base) +# Plotting fig = plt.figure() -ax = plt.gca() -ax.set_xlabel("X Axis") -ax.set_ylabel("Y Axis") +ax2 = plt.gca() +ax2.set_xlabel("X Axis") +ax2.set_ylabel("Y Axis") base = 2 # Default is 10, but 2 works equally well -cosmoplots.change_log_axis_base(ax, "x", base=base) +cosmoplots.change_log_axis_base(ax2, "x", base=base) # Do plotting ... # If you use "plot", the change_log_axis_base can be called at the top (along with add_axes # etc.), but using loglog, semilogx, semilogy will re-set, and the change_log_axis_base # function must be called again. -ax.plot(a) +ax2.plot(a) plt.show() ``` diff --git a/cosmoplots/axes.py b/cosmoplots/axes.py index 005456d..ba51c12 100644 --- a/cosmoplots/axes.py +++ b/cosmoplots/axes.py @@ -1,16 +1,42 @@ """Module for modifying the axis properties of plots.""" -import warnings - +from typing import List, Tuple, Union import matplotlib.pyplot as plt import numpy as np from matplotlib import ticker -def change_log_axis_base(axes: plt.Axes, which: str, base: float = 10) -> plt.Axes: +def _convert_scale_name(scale: str, axis: str) -> str: + """Convert the scale name to a more readable format.""" + # All possible scale names: + # ['asinh', 'function', 'functionlog', 'linear', 'log', 'logit', 'mercator', 'symlog'] + return f"{axis}axis" if scale == "log" else "none" + + +def _check_axes_scales(axes: plt.Axes) -> Tuple[List[str], str]: + xscale, yscale = axes.get_xscale(), axes.get_yscale() + xs, ys = _convert_scale_name(xscale, "x"), _convert_scale_name(yscale, "y") + if xs == "xaxis" and ys == "yaxis": + scales = [xs, ys] + pltype = "loglog" + elif xs == "xaxis": + scales = [xs] + pltype = "semilogx" + elif ys == "yaxis": + scales = [ys] + pltype = "semilogy" + else: + scales = [] + pltype = "linear" + return scales, pltype + + +def change_log_axis_base( + axes: plt.Axes, which: Union[str, None] = None, base: float = 10 +) -> plt.Axes: """Change the tick formatter to not use powers 0 and 1 in logarithmic plots. - Change the logarithmic axes `10^0 -> 1` and `10^1 -> 10` (or the given base), i.e. + Change the logarithmic axis `10^0 -> 1` and `10^1 -> 10` (or the given base), i.e. without power, otherwise use the base to some power. For more robust and less error prone results, the plotting type is also re-set with the same base ('loglog', 'semilogx' and 'semilogy'). @@ -19,28 +45,21 @@ def change_log_axis_base(axes: plt.Axes, which: str, base: float = 10) -> plt.Ax Parameters ---------- - axes: plt.Axes + axes : plt.Axes An axes object - which: str - Whether to update both x and y axis, or just one of them. - base: float + which : str | None, optional + Whether to update both x and y axis, or just one of them ("both", "x" or "y"). + If no value is given, it defaults to None and the function will try to infer the + axis from the current plotting type. If the axis are already linear, the + function will return the axes object without any changes. Defaults to None. + base : float The base of the logarithm. Defaults to base 10 (same as loglog, etc.) Returns ------- plt.Axes The updated axes object. - - Raises - ------ - ValueError - If the axes given in `which` is not `x`, `y` or `both`. """ - warnings.warn( - "The 'change_log_axis_base' function is deprecated and will be removed in the" - " next major version release of cosmoplots, v1.0.0. Instead, use the `mplstyle`" - " files to set the figure dimensions: \nplt.style.use('cosmoplots.default')" - ) if which == "both": axs, pltype = ["xaxis", "yaxis"], "loglog" elif which == "x": @@ -48,9 +67,10 @@ def change_log_axis_base(axes: plt.Axes, which: str, base: float = 10) -> plt.Ax elif which == "y": axs, pltype = ["yaxis"], "semilogy" else: - raise ValueError( - "No valid axis found. 'which' must be either of 'both', 'x' or 'y'." - ) + axs, pltype = _check_axes_scales(axes) + if not axs and pltype == "linear": + # If both the axes are already linear, just return the axes object silently + return axes getattr(axes, pltype)(base=base) for ax in axs: f = getattr(axes, ax) diff --git a/poetry.lock b/poetry.lock index 1960c1e..b53cd37 100644 --- a/poetry.lock +++ b/poetry.lock @@ -223,8 +223,8 @@ files = [ [package.extras] all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] graphite = ["lz4 (>=1.7.4.2)"] -interpolatable = ["munkres", "scipy"] -lxml = ["lxml (>=4.0,<5)"] +interpolatable = ["munkres", "pycairo", "scipy"] +lxml = ["lxml (>=4.0)"] pathops = ["skia-pathops (>=0.5.0)"] plot = ["matplotlib"] repacker = ["uharfbuzz (>=0.23.0)"] @@ -577,7 +577,11 @@ files = [ [package.extras] docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] +fpx = ["olefile"] +mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +typing = ["typing-extensions"] +xmp = ["defusedxml"] [[package]] name = "pluggy" @@ -632,13 +636,13 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no [[package]] name = "python-dateutil" -version = "2.8.2" +version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] [package.dependencies]