diff --git a/src/gemdat/plots/matplotlib/_autocorrelation.py b/src/gemdat/plots/matplotlib/_autocorrelation.py index 1d9770c7..e2aa9f89 100644 --- a/src/gemdat/plots/matplotlib/_autocorrelation.py +++ b/src/gemdat/plots/matplotlib/_autocorrelation.py @@ -11,6 +11,7 @@ def autocorrelation( *, orientations: Orientations, show_traces: bool = True, + show_shaded: bool = True, ) -> plt.Figure: """Plot the autocorrelation function of the unit vectors series. @@ -20,6 +21,8 @@ def autocorrelation( The unit vector trajectories show_traces : bool If True, show traces of individual trajectories + show_shaded : bool + If True, show standard deviation as shaded area Returns ------- @@ -45,11 +48,13 @@ def autocorrelation( label = 'Trajectories' if (i == 0) else None ax.plot(tgrid, ac_i, lw=0.1, c=last_color, label=label) - ax.fill_between(tgrid, - ac_mean - ac_std, - ac_mean + ac_std, - alpha=0.2, - label='Standard deviation') + if show_shaded: + ax.fill_between(tgrid, + ac_mean - ac_std, + ac_mean + ac_std, + alpha=0.2, + label='Standard deviation') + ax.set_xlabel('Time lag (ps)') ax.set_ylabel('Autocorrelation') ax.legend() diff --git a/src/gemdat/plots/matplotlib/_msd_per_element.py b/src/gemdat/plots/matplotlib/_msd_per_element.py index f60d2b3e..2655cfbe 100644 --- a/src/gemdat/plots/matplotlib/_msd_per_element.py +++ b/src/gemdat/plots/matplotlib/_msd_per_element.py @@ -10,6 +10,7 @@ def msd_per_element( *, trajectory: Trajectory, show_traces: bool = True, + show_shaded: bool = True, ) -> plt.Figure: """Plot mean squared displacement per element. @@ -18,7 +19,9 @@ def msd_per_element( trajectory : Trajectory Input trajectory show_traces : bool - If True, show individual traces for each element + If True, show traces of individual trajectories for each element + show_shaded : bool + If True, show standard deviation as shaded area Returns ------- @@ -49,12 +52,13 @@ def msd_per_element( label = f'{sp.symbol} trajectories' if (i == 0) else None ax.plot(t_values, traj, lw=0.1, c=last_color, label=label) - ax.fill_between(t_values, - msd_mean - msd_std, - msd_mean + msd_std, - color=last_color, - alpha=0.2, - label=f'{sp.symbol} std') + if show_shaded: + ax.fill_between(t_values, + msd_mean - msd_std, + msd_mean + msd_std, + color=last_color, + alpha=0.2, + label=f'{sp.symbol} std') ax.legend() ax.set(title='Mean squared displacement per element',