Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to create a graph from jumps #325

Merged
merged 13 commits into from
Jun 6, 2024
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ jumps.jump_diffusivity(dimensions=3)
To calculate different metrics, such as tracer diffusivity:

```python
from gemdat import SimulationMetrics
from gemdat import TrajectoryMetrics

metrics = SimulationMetrics(diff_trajectory)
metrics = TrajectoryMetrics(diff_trajectory)

metrics.tracer_diffusivity(dimensions=3)
metrics.haven_ratio(dimensions=3)
Expand Down
2 changes: 1 addition & 1 deletion docs/api/gemdat.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
- [gemdat.read_cif][gemdat.io.read_cif]
- [gemdat.load_known_material][gemdat.io.load_known_material]
- [gemdat.SimulationMetrics][gemdat.simulation_metrics.SimulationMetrics]
- [gemdat.TrajectoryMetrics][gemdat.metrics.TrajectoryMetrics]
- [gemdat.Transitions][gemdat.transitions.Transitions]
- [gemdat.Jumps][gemdat.jumps.Jumps]
- [gemdat.Trajectory][gemdat.trajectory.Trajectory]
Expand Down
2 changes: 1 addition & 1 deletion docs/api/gemdat_simulation_metrics.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
::: gemdat.simulation_metrics
::: gemdat.metrics
options:
show_root_heading: false
show_root_toc_entry: false
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ nav:
- gemdat.io: api/gemdat_io.md
- gemdat.plots: api/gemdat_plots.md
- gemdat.rdf: api/gemdat_rdf.md
- gemdat.simulation_metrics: api/gemdat_simulation_metrics.md
- gemdat.metrics: api/gemdat_metrics.md
- gemdat.trajectory: api/gemdat_trajectory.md
- gemdat.transitions: api/gemdat_transitions.md
- gemdat.jumps: api/gemdat_jumps.md
Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from .io import load_known_material, read_cif
from .jumps import Jumps
from .metrics import TrajectoryMetrics
from .orientations import Orientations
from .rdf import radial_distribution
from .shape import ShapeAnalyzer
from .simulation_metrics import SimulationMetrics
from .trajectory import Trajectory
from .transitions import Transitions
from .volume import Volume, trajectory_to_volume
Expand All @@ -18,7 +18,7 @@
'radial_distribution',
'read_cif',
'ShapeAnalyzer',
'SimulationMetrics',
'TrajectoryMetrics',
'Trajectory',
'trajectory_to_volume',
'Transitions',
Expand Down
123 changes: 104 additions & 19 deletions src/gemdat/jumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from math import ceil
from typing import TYPE_CHECKING, Callable

import networkx as nx
import numpy as np
import pandas as pd
from pymatgen.core.units import FloatWithUnit
Expand All @@ -13,7 +14,7 @@
from ._plot_backend import plot_backend
from .caching import weak_lru_cache
from .collective import Collective
from .simulation_metrics import SimulationMetrics
from .metrics import TrajectoryMetrics
from .transitions import Transitions, _calculate_transitions_matrix

if TYPE_CHECKING:
Expand Down Expand Up @@ -223,7 +224,7 @@ def collective(self, max_dist: float = 1) -> Collective:
sites = self.transitions.sites

time_step = trajectory.time_step
attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
attempt_freq, _ = TrajectoryMetrics(trajectory).attempt_frequency()

max_steps = ceil(1.0 / (attempt_freq * time_step))

Expand All @@ -237,7 +238,7 @@ def collective(self, max_dist: float = 1) -> Collective:

@weak_lru_cache()
def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
"""Calculate activation energies for jumps (UNITS?).
"""Calculate activation energies for jumps in eV.

Parameters
----------
Expand All @@ -251,7 +252,7 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
between site pairs.
"""
trajectory = self.trajectory
attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
attempt_freq, _ = TrajectoryMetrics(trajectory).attempt_frequency()

dct = {}

Expand All @@ -260,13 +261,13 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
atom_locations_parts = [
part.atom_locations() for part in self.transitions.split(n_parts)
]
jumps_counter_parts = [part.jumps_counter() for part in self.split(n_parts)]
counter_parts = [part.counter() for part in self.split(n_parts)]
n_floating = self.n_floating

for site_pair in self.site_pairs:
site_start, site_stop = site_pair

n_jumps = np.array([part[site_pair] for part in jumps_counter_parts])
n_jumps = np.array([part[site_pair] for part in counter_parts])

part_time = trajectory.total_time / n_parts

Expand All @@ -292,22 +293,106 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:

return df

def jumps_counter(self) -> Counter:
"""Calculate number of jumps between sites.
@weak_lru_cache()
def counter(self) -> Counter[tuple[str, str]]:
"""Count number of jumps between sites.

Returns
-------
jumps : dict[tuple[str, str], int]
Dictionary with number of jumps per sites combination
counter : Counter[tuple[str, str]]
Dictionary with site pairs as keys and corresponding
number of jumps as dictionary values
"""
labels = self.sites.labels
jumps = Counter(
[
(labels[i], labels[j])
for _, (i, j) in self.data[['start site', 'destination site']].iterrows()
]
)
return jumps
counter: Counter[tuple[str, str]] = Counter()
for (i, j), val in self._counter().items():
counter[labels[i], labels[j]] += val
return counter

@weak_lru_cache()
def _counter(self) -> Counter[tuple[int, int]]:
"""Count number of jumps between sites. Keys are site indices.

Returns
-------
counter : Counter[tuple[int, int]]
Dictionary with site pairs as keys and corresponding
number of jumps as dictionary values
"""
counter = Counter(zip(self.data['start site'], self.data['destination site']))
return counter

def activation_energy_between_sites(self, start: str, stop: str) -> float:
"""Returns activation energy between two sites.

Uses `Jumps.to_graph()` in the background. For a large number of operations,
it is more efficient to query the graph directly.

Parameters
----------
start : str
Label of the start site
stop : str
Label of the stop site

Returns
-------
e_act : float
Activation energy in eV
"""
G = self.to_graph()
edge_data = G.get_edge_data(start, stop)
if not edge_data:
raise IndexError(f'No jumps between ({start}) and ({stop})')
return edge_data['e_act']

@weak_lru_cache()
def to_graph(
self, min_e_act: float | None = None, max_e_act: float | None = None
) -> nx.DiGraph:
"""Create a graph from jumps data.

The edges are weighted by the activation energy. The nodes are indices that
correspond to `Jumps.sites`.

Parameters
----------
min_e_act : float
Reject edges with activation energy below this threshold
max_e_act : float
Reject edges with activation energy above this threshold

Returns
-------
G : nx.DiGraph
A networkx DiGraph object.
"""
min_e_act = min_e_act if min_e_act else float('-inf')
max_e_act = max_e_act if max_e_act else float('inf')

atom_percentage = [site.species.num_atoms for site in self.transitions.occupancy()]

attempt_freq, _ = self.trajectory.metrics().attempt_frequency()
temperature = self.trajectory.metadata['temperature']
kBT = Boltzmann * temperature

G = nx.DiGraph()

for i, site in enumerate(self.sites):
G.add_node(i, label=site.label)

for (start, stop), n_jumps in self._counter().items():
time_perc = atom_percentage[start] * self.trajectory.total_time

eff_rate = n_jumps / time_perc

e_act = -np.log(eff_rate / attempt_freq) * kBT
e_act /= elementary_charge

if min_e_act <= e_act <= max_e_act:
G.add_edge(start, stop, e_act=e_act)

return G

def split(self, n_parts: int) -> list[Jumps]:
"""Split the jumps into parts.
Expand Down Expand Up @@ -336,12 +421,12 @@ def rates(self, n_parts: int = 10) -> pd.DataFrame:
"""
dct = {}

parts = [part.jumps_counter() for part in self.split(n_parts)]
parts = [part.counter() for part in self.split(n_parts)]
part_time = self.trajectory.total_time / n_parts

for site_pair in self.site_pairs:
n_jumps = [part[site_pair] for part in parts]

part_time = self.trajectory.total_time / n_parts
denom = self.n_floating * part_time

jump_freq_mean = np.mean(n_jumps) / denom
Expand Down
8 changes: 4 additions & 4 deletions src/gemdat/simulation_metrics.py → src/gemdat/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from trajectory import Trajectory


class SimulationMetrics:
class TrajectoryMetrics:
"""Class for calculating different metrics and properties from a molecular
dynamics simulation."""

Expand Down Expand Up @@ -115,7 +115,7 @@ def tracer_diffusivity_center_of_mass(
"""
center_of_mass = self.trajectory.center_of_mass()

metrics = SimulationMetrics(center_of_mass)
metrics = TrajectoryMetrics(center_of_mass)

return metrics.tracer_diffusivity(dimensions=dimensions)

Expand Down Expand Up @@ -230,7 +230,7 @@ def amplitudes(self) -> np.ndarray:
return np.asarray(amplitudes)


class SimulationMetricsStd:
class TrajectoryMetricsStd:
"""Class for calculating different metrics and properties from a molecular
dynamics simulation.

Expand All @@ -246,7 +246,7 @@ def __init__(self, trajectories: list[Trajectory]):
trajectories: list[Trajectory]
Input trajectories
"""
self.metrics = [SimulationMetrics(trajectory) for trajectory in trajectories]
self.metrics = [TrajectoryMetrics(trajectory) for trajectory in trajectories]

def speed(self) -> tuple[np.ndarray, np.ndarray]:
"""Calculate mean speed and standard deviations.
Expand Down
4 changes: 1 addition & 3 deletions src/gemdat/plots/matplotlib/_frequency_vs_occurence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import matplotlib.pyplot as plt
import numpy as np

from gemdat.simulation_metrics import SimulationMetrics

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory

Expand All @@ -24,7 +22,7 @@ def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure:
fig : matplotlib.figure.Figure
Output figure
"""
metrics = SimulationMetrics(trajectory)
metrics = trajectory.metrics()
speed = metrics.speed()

length = speed.shape[1]
Expand Down
4 changes: 1 addition & 3 deletions src/gemdat/plots/matplotlib/_vibrational_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import numpy as np
from scipy import stats

from gemdat.simulation_metrics import SimulationMetrics

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory

Expand All @@ -25,7 +23,7 @@ def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
fig : matplotlib.figure.Figure
Output figure
"""
metrics = SimulationMetrics(trajectory)
metrics = trajectory.metrics()

fig, ax = plt.subplots()
ax.hist(metrics.amplitudes(), bins=100, density=True)
Expand Down
4 changes: 1 addition & 3 deletions src/gemdat/plots/plotly/_frequency_vs_occurence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import numpy as np
import plotly.graph_objects as go

from gemdat.simulation_metrics import SimulationMetrics

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory

Expand All @@ -24,7 +22,7 @@ def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
fig : plotly.graph_objects.Figure.Figure
Output figure
"""
metrics = SimulationMetrics(trajectory)
metrics = trajectory.metrics()
speed = metrics.speed()

length = speed.shape[1]
Expand Down
6 changes: 2 additions & 4 deletions src/gemdat/plots/plotly/_vibrational_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import plotly.graph_objects as go
from scipy import stats

from gemdat.simulation_metrics import SimulationMetrics

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory

Expand All @@ -33,8 +31,8 @@ def vibrational_amplitudes(
"""

trajectories = trajectory.split(n_parts)
single_metrics = SimulationMetrics(trajectory)
metrics = [SimulationMetrics(trajectory).amplitudes() for trajectory in trajectories]
single_metrics = trajectory.metrics()
metrics = [trajectory.metrics().amplitudes() for trajectory in trajectories]

max_amp = max(max(metric) for metric in metrics)
min_amp = min(min(metric) for metric in metrics)
Expand Down
7 changes: 7 additions & 0 deletions src/gemdat/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
if TYPE_CHECKING:
from pymatgen.core import Structure

from .metrics import TrajectoryMetrics
from .transitions import Transitions
from .volume import Volume

Expand Down Expand Up @@ -522,6 +523,12 @@ def transitions_between_sites(
site_inner_fraction=site_inner_fraction,
)

def metrics(self) -> TrajectoryMetrics:
"""See [gemdat.TrajectoryMetrics][] for more info."""
from .metrics import TrajectoryMetrics

return TrajectoryMetrics(trajectory=self)

@plot_backend
def plot_displacement_per_atom(self, *, module, **kwargs):
"""See [gemdat.plots.displacement_per_atom][] for more info."""
Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pymatgen.core import Structure

from .caching import weak_lru_cache
from .simulation_metrics import SimulationMetrics
from .metrics import TrajectoryMetrics
from .utils import bfill, ffill, integer_remap

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -108,7 +108,7 @@ def from_trajectory(
diff_trajectory = trajectory.filter(floating_specie)

if site_radius is None:
vibration_amplitude = SimulationMetrics(diff_trajectory).vibration_amplitude()
vibration_amplitude = TrajectoryMetrics(diff_trajectory).vibration_amplitude()

site_radius = _compute_site_radius(
trajectory=trajectory,
Expand Down
Loading
Loading