Skip to content

Commit

Permalink
Add figsaver for consistent figure saving
Browse files Browse the repository at this point in the history
  • Loading branch information
eachanjohnson committed Oct 5, 2024
1 parent c839f0a commit 1529926
Showing 1 changed file with 57 additions and 3 deletions.
60 changes: 57 additions & 3 deletions carabiner/mpl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any, Iterable, Mapping, Tuple, Optional, Union
from functools import wraps
import os

try:
import matplotlib.pyplot as plt
Expand All @@ -17,13 +18,15 @@
from tqdm.auto import tqdm

from ..cast import cast
from ..utils import colorblind_palette as utils_colorblind_palette
rcParams['axes.prop_cycle'] = cycler(color=utils_colorblind_palette())
from ..utils import colorblind_palette as utils_colorblind_palette, print_err

TFigAx = Tuple[figure.Figure, axes.Axes]

colorblind_palette = utils_colorblind_palette

# Set default color cycle on import
rcParams['axes.prop_cycle'] = cycler(color=colorblind_palette())

def grid(
nrow: int = 1,
ncol: int = 1,
Expand Down Expand Up @@ -207,4 +210,55 @@ def scattergrid(
)
if not dummy_group:
ax.legend(**_legend_opts)
return fig, axes
return fig, axes

def figsaver(
dir: str = ".",
prefix: Optional[str] = None,
dpi: int = 300,
format: str = 'png',
) -> Callable[[Figure, str, int, str, Optional[DataFrame]], None]:

"""Create a function to save figures in a predefined location.
Parameters
----------
dir : str, optional
Directory to save figures. Default: ".".
prefix : str, optional
Prefix for filenames. Default: no prefix.
dpi : int, optional
Resolution of saved figures. Default: 300.
format : str, optional
File format of figures. Default: "png".
Returns
-------
Callable
A function taking Figure, name, and optionally a Pandas
DataFrame as arguments. Saves as {dir}/{prefix}{name}.{format}.
If a DataFrame is provided, it as saved as {dir}/{prefix}{name}.csv.
"""

def _figsave(
fig: Figure,
name: str,
df: Optional[DataFrame] = None
) -> None:
"""
"""
figname = os.path.join(output_dir, f"{prefix}{name}.{format}")
print_err(f"Saving plot at {figname}")
fig.savefig(
figname,
dpi=dpi,
bbox_inches='tight',
)
if df is not None and isinstance(df, DataFrame):
dataname = os.path.join(output_dir, f"{prefix}{name}.csv"
print_err(f"Saving data at {dataname}")
df.to_csv(dataname, index=False)
return None
return _figave

0 comments on commit 1529926

Please sign in to comment.