Skip to content

Commit

Permalink
Add option to switch between plotting backends (#326)
Browse files Browse the repository at this point in the history
* Implement backend switching decorator

* Switch plot methods to use plot_backend decorator

* Fix circular imports

* Update matplotlib tests

* Update plotly tests
  • Loading branch information
stefsmeets authored Jun 5, 2024
1 parent 4972d1a commit f18ae50
Show file tree
Hide file tree
Showing 32 changed files with 250 additions and 145 deletions.
44 changes: 44 additions & 0 deletions src/gemdat/_plot_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from types import ModuleType

from gemdat import plots as plots_default
from gemdat.plots import matplotlib as plots_matplotlib
from gemdat.plots import plotly as plots_plotly


def plot_backend(func):
"""Decorator to switch plotting backend."""

def wrap(*args, backend: str | None = None, **kwargs):
module: ModuleType

if backend is None:
module = plots_default
elif backend in ('mpl', 'matplotlib'):
module = plots_matplotlib
elif backend == 'plotly':
module = plots_plotly
else:
raise ValueError(f'No such backend: {backend}')

result = func(*args, module=module, **kwargs)

return result

wrap.__doc__ = func.__doc__
wrap.__doc__ += """
Parameters
---------
backend : str
Choose plotting backend. Options: matplotlib, mpl, plotly
Defaults to plotly unless the plot is only available in matplotlib.
Returns
-------
fig : plotly.graph_objects.Figure or matplotlib.figure.Figure depending on backend.
Output figure
"""

return wrap
36 changes: 16 additions & 20 deletions src/gemdat/jumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pymatgen.core.units import FloatWithUnit
from scipy.constants import Boltzmann, angstrom, elementary_charge

from ._plot_backend import plot_backend
from .caching import weak_lru_cache
from .collective import Collective
from .simulation_metrics import SimulationMetrics
Expand Down Expand Up @@ -353,32 +354,27 @@ def rates(self, n_parts: int = 10) -> pd.DataFrame:

return df

def plot_jumps_vs_distance(self, **kwargs):
@plot_backend
def plot_jumps_vs_distance(self, *, module, **kwargs):
"""See [gemdat.plots.jumps_vs_distance][] for more information."""
from gemdat import plots
return module.jumps_vs_distance(jumps=self, **kwargs)

return plots.jumps_vs_distance(jumps=self, **kwargs)

def plot_jumps_vs_time(self, **kwargs):
@plot_backend
def plot_jumps_vs_time(self, *, module, **kwargs):
"""See [gemdat.plots.jumps_vs_time][] for more information."""
from gemdat import plots

return plots.jumps_vs_time(jumps=self, **kwargs)
return module.jumps_vs_time(jumps=self, **kwargs)

def plot_collective_jumps(self, **kwargs):
@plot_backend
def plot_collective_jumps(self, *, module, **kwargs):
"""See [gemdat.plots.collective_jumps][] for more information."""
from gemdat import plots

return plots.collective_jumps(jumps=self, **kwargs)
return module.collective_jumps(jumps=self, **kwargs)

def plot_jumps_3d(self, **kwargs):
@plot_backend
def plot_jumps_3d(self, *, module, **kwargs):
"""See [gemdat.plots.jumps_3d][] for more information."""
from gemdat import plots
return module.jumps_3d(jumps=self, **kwargs)

return plots.jumps_3d(jumps=self, **kwargs)

def plot_jumps_3d_animation(self, **kwargs):
@plot_backend
def plot_jumps_3d_animation(self, *, module, **kwargs):
"""See [gemdat.plots.jumps_3d_animation][] for more information."""
from gemdat import plots

return plots.jumps_3d_animation(jumps=self, **kwargs)
return module.jumps_3d_animation(jumps=self, **kwargs)
28 changes: 15 additions & 13 deletions src/gemdat/orientations.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

from dataclasses import InitVar, dataclass, field, replace
from typing import TYPE_CHECKING

import numpy as np
from pymatgen.symmetry.groups import PointGroup

from gemdat.trajectory import Trajectory
from gemdat.utils import cartesian_to_spherical, fft_autocorrelation

from ._plot_backend import plot_backend

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


@dataclass
class Orientations:
Expand Down Expand Up @@ -259,23 +264,20 @@ def autocorrelation(self):
"""Compute the autocorrelation of the orientation vectors using FFT."""
return fft_autocorrelation(self.vectors)

def plot_rectilinear(self, **kwargs):
@plot_backend
def plot_rectilinear(self, *, module, **kwargs):
"""See [gemdat.plots.rectilinear][] for more info."""
from gemdat import plots

return plots.rectilinear(orientations=self, **kwargs)
return module.rectilinear(orientations=self, **kwargs)

def plot_bond_length_distribution(self, **kwargs):
@plot_backend
def plot_bond_length_distribution(self, *, module, **kwargs):
"""See [gemdat.plots.bond_length_distribution][] for more info."""
from gemdat import plots
return module.bond_length_distribution(orientations=self, **kwargs)

return plots.bond_length_distribution(orientations=self, **kwargs)

def plot_autocorrelation(self, **kwargs):
@plot_backend
def plot_autocorrelation(self, *, module, **kwargs):
"""See [gemdat.plots.unit_vector_autocorrelation][] for more info."""
from gemdat import plots

return plots.autocorrelation(orientations=self, **kwargs)
return module.autocorrelation(orientations=self, **kwargs)


def calculate_spherical_areas(shape: tuple[int, int], radius: float = 1) -> np.ndarray:
Expand Down
15 changes: 7 additions & 8 deletions src/gemdat/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from gemdat.volume import FreeEnergyVolume

from ._plot_backend import plot_backend
from .utils import nearest_structure_reference

if TYPE_CHECKING:
Expand Down Expand Up @@ -142,17 +143,15 @@ def stop_site(self) -> tuple[int, int, int]:
"""Return stop site."""
return self.sites[-1]

def plot_energy_along_path(self, **kwargs):
@plot_backend
def plot_energy_along_path(self, module, **kwargs):
"""See [gemdat.plots.energy_along_path][] for more info."""
from gemdat import plots
return module.energy_along_path(path=self, **kwargs)

return plots.energy_along_path(path=self, **kwargs)

def plot_path_on_grid(self, **kwargs):
@plot_backend
def plot_path_on_grid(self, module, **kwargs):
"""See [gemdat.plots.path_on_grid][] for more info."""
from gemdat import plots

return plots.path_on_grid(path=self, **kwargs)
return module.path_on_grid(path=self, **kwargs)


def free_energy_graph(
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_autocorrelation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from gemdat.orientations import Orientations
if TYPE_CHECKING:
from gemdat.orientations import Orientations


def autocorrelation(
Expand Down
7 changes: 5 additions & 2 deletions src/gemdat/plots/matplotlib/_bond_length_distribution.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import matplotlib.pyplot as plt
from typing import TYPE_CHECKING

from gemdat.orientations import Orientations
import matplotlib.pyplot as plt

from .._shared import _fit_skewnorm_to_hist, _orientations_to_histogram

if TYPE_CHECKING:
from gemdat.orientations import Orientations


def bond_length_distribution(*, orientations: Orientations, bins: int = 50) -> plt.Figure:
"""Plot the bond length probability distribution.
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_histogram.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

from gemdat.trajectory import Trajectory
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def displacement_histogram(trajectory: Trajectory) -> plt.Figure:
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_per_atom.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

from gemdat.trajectory import Trajectory
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure:
Expand Down
6 changes: 5 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_per_element.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

from gemdat.plots._shared import _mean_displacements_per_element
from gemdat.trajectory import Trajectory

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure:
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_energy_along_path.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
from pymatgen.core import Structure

from gemdat.path import Pathway
if TYPE_CHECKING:
from gemdat.path import Pathway


def energy_along_path(
Expand Down
6 changes: 5 additions & 1 deletion src/gemdat/plots/matplotlib/_frequency_vs_occurence.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from gemdat.simulation_metrics import SimulationMetrics
from gemdat.trajectory import Trajectory

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure:
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_msd_per_element.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from gemdat.trajectory import Trajectory
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def msd_per_element(
Expand Down
10 changes: 6 additions & 4 deletions src/gemdat/plots/matplotlib/_rectilinear.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from gemdat.orientations import (
Orientations,
calculate_spherical_areas,
)
if TYPE_CHECKING:
from gemdat.orientations import Orientations


def rectilinear(
Expand All @@ -32,6 +32,8 @@ def rectilinear(
fig : matplotlib.figure.Figure
Output figure
"""
from gemdat.orientations import calculate_spherical_areas

az, el, _ = orientations.vectors_spherical.T
az = az.flatten()
el = el.flatten()
Expand Down
6 changes: 5 additions & 1 deletion src/gemdat/plots/matplotlib/_vibrational_amplitudes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

from gemdat.simulation_metrics import SimulationMetrics
from gemdat.trajectory import Trajectory

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/plotly/_autocorrelation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import plotly.graph_objects as go

from gemdat.orientations import Orientations
if TYPE_CHECKING:
from gemdat.orientations import Orientations


def autocorrelation(
Expand Down
7 changes: 5 additions & 2 deletions src/gemdat/plots/plotly/_bond_length_distribution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import plotly.express as px
import plotly.graph_objects as go

from gemdat.orientations import Orientations

from .._shared import _fit_skewnorm_to_hist, _orientations_to_histogram

if TYPE_CHECKING:
from gemdat.orientations import Orientations


def bond_length_distribution(*, orientations: Orientations, bins: int = 50) -> go.Figure:
"""Plot the bond length probability distribution.
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/plotly/_displacement_histogram.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from gemdat.trajectory import Trajectory
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def _trajectory_to_dataframe(trajectory: Trajectory) -> pd.DataFrame:
Expand Down
Loading

0 comments on commit f18ae50

Please sign in to comment.