Skip to content

Commit

Permalink
Add option to show shaded area for std (#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
stefsmeets authored May 21, 2024
1 parent 5a870b9 commit feb4cf8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
15 changes: 10 additions & 5 deletions src/gemdat/plots/matplotlib/_autocorrelation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def autocorrelation(
*,
orientations: Orientations,
show_traces: bool = True,
show_shaded: bool = True,
) -> plt.Figure:
"""Plot the autocorrelation function of the unit vectors series.
Expand All @@ -20,6 +21,8 @@ def autocorrelation(
The unit vector trajectories
show_traces : bool
If True, show traces of individual trajectories
show_shaded : bool
If True, show standard deviation as shaded area
Returns
-------
Expand All @@ -45,11 +48,13 @@ def autocorrelation(
label = 'Trajectories' if (i == 0) else None
ax.plot(tgrid, ac_i, lw=0.1, c=last_color, label=label)

ax.fill_between(tgrid,
ac_mean - ac_std,
ac_mean + ac_std,
alpha=0.2,
label='Standard deviation')
if show_shaded:
ax.fill_between(tgrid,
ac_mean - ac_std,
ac_mean + ac_std,
alpha=0.2,
label='Standard deviation')

ax.set_xlabel('Time lag (ps)')
ax.set_ylabel('Autocorrelation')
ax.legend()
Expand Down
18 changes: 11 additions & 7 deletions src/gemdat/plots/matplotlib/_msd_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def msd_per_element(
*,
trajectory: Trajectory,
show_traces: bool = True,
show_shaded: bool = True,
) -> plt.Figure:
"""Plot mean squared displacement per element.
Expand All @@ -18,7 +19,9 @@ def msd_per_element(
trajectory : Trajectory
Input trajectory
show_traces : bool
If True, show individual traces for each element
If True, show traces of individual trajectories for each element
show_shaded : bool
If True, show standard deviation as shaded area
Returns
-------
Expand Down Expand Up @@ -49,12 +52,13 @@ def msd_per_element(
label = f'{sp.symbol} trajectories' if (i == 0) else None
ax.plot(t_values, traj, lw=0.1, c=last_color, label=label)

ax.fill_between(t_values,
msd_mean - msd_std,
msd_mean + msd_std,
color=last_color,
alpha=0.2,
label=f'{sp.symbol} std')
if show_shaded:
ax.fill_between(t_values,
msd_mean - msd_std,
msd_mean + msd_std,
color=last_color,
alpha=0.2,
label=f'{sp.symbol} std')

ax.legend()
ax.set(title='Mean squared displacement per element',
Expand Down

0 comments on commit feb4cf8

Please sign in to comment.