diff --git a/carabiner/mpl/utils.py b/carabiner/mpl/utils.py index 643d16c..7dc11a6 100644 --- a/carabiner/mpl/utils.py +++ b/carabiner/mpl/utils.py @@ -2,6 +2,7 @@ from typing import Any, Iterable, Mapping, Tuple, Optional, Union from functools import wraps +import os try: import matplotlib.pyplot as plt @@ -17,13 +18,15 @@ from tqdm.auto import tqdm from ..cast import cast -from ..utils import colorblind_palette as utils_colorblind_palette -rcParams['axes.prop_cycle'] = cycler(color=utils_colorblind_palette()) +from ..utils import colorblind_palette as utils_colorblind_palette, print_err TFigAx = Tuple[figure.Figure, axes.Axes] colorblind_palette = utils_colorblind_palette +# Set default color cycle on import +rcParams['axes.prop_cycle'] = cycler(color=colorblind_palette()) + def grid( nrow: int = 1, ncol: int = 1, @@ -207,4 +210,55 @@ def scattergrid( ) if not dummy_group: ax.legend(**_legend_opts) - return fig, axes \ No newline at end of file + return fig, axes + +def figsaver( + dir: str = ".", + prefix: Optional[str] = None, + dpi: int = 300, + format: str = 'png', +) -> Callable[[Figure, str, int, str, Optional[DataFrame]], None]: + + """Create a function to save figures in a predefined location. + + Parameters + ---------- + dir : str, optional + Directory to save figures. Default: ".". + prefix : str, optional + Prefix for filenames. Default: no prefix. + dpi : int, optional + Resolution of saved figures. Default: 300. + format : str, optional + File format of figures. Default: "png". + + Returns + ------- + Callable + A function taking Figure, name, and optionally a Pandas + DataFrame as arguments. Saves as {dir}/{prefix}{name}.{format}. + If a DataFrame is provided, it as saved as {dir}/{prefix}{name}.csv. + + """ + + def _figsave( + fig: Figure, + name: str, + df: Optional[DataFrame] = None + ) -> None: + """ + + """ + figname = os.path.join(output_dir, f"{prefix}{name}.{format}") + print_err(f"Saving plot at {figname}") + fig.savefig( + figname, + dpi=dpi, + bbox_inches='tight', + ) + if df is not None and isinstance(df, DataFrame): + dataname = os.path.join(output_dir, f"{prefix}{name}.csv" + print_err(f"Saving data at {dataname}") + df.to_csv(dataname, index=False) + return None + return _figave \ No newline at end of file