8
8
import warnings
9
9
from fnmatch import fnmatch
10
10
from pathlib import Path
11
+ from tempfile import NamedTemporaryFile
11
12
from typing import TYPE_CHECKING , TypeAlias , cast
12
13
13
14
import numpy as np
14
15
from monty .io import zopen
15
16
from monty .json import MSONable
16
17
17
18
from pymatgen .core .structure import Composition , DummySpecies , Element , Lattice , Molecule , Species , Structure
19
+ from pymatgen .io .ase import NO_ASE_ERR , AseAtomsAdaptor
20
+
21
+ if NO_ASE_ERR is None :
22
+ from ase .io .trajectory import Trajectory as AseTrajectory
23
+ else :
24
+ AseTrajectory = None
25
+
18
26
19
27
if TYPE_CHECKING :
20
- from collections .abc import Iterator
28
+ from collections .abc import Iterator , Sequence
21
29
from typing import Any
22
30
23
31
from typing_extensions import Self
24
32
25
33
from pymatgen .util .typing import Matrix3D , PathLike , SitePropsType , Vector3D
26
34
27
-
28
35
__author__ = "Eric Sivonxay, Shyam Dwaraknath, Mingjian Wen, Evan Spotte-Smith"
29
36
__version__ = "0.1"
30
37
__date__ = "Jun 29, 2022"
@@ -563,8 +570,6 @@ def from_file(cls, filename: str | Path, constant_lattice: bool = True, **kwargs
563
570
Trajectory: containing the structures or molecules in the file.
564
571
"""
565
572
filename = str (Path (filename ).expanduser ().resolve ())
566
- is_mol = False
567
- molecules = []
568
573
structures = []
569
574
570
575
if fnmatch (filename , "*XDATCAR*" ):
@@ -578,31 +583,24 @@ def from_file(cls, filename: str | Path, constant_lattice: bool = True, **kwargs
578
583
structures = Vasprun (filename ).structures
579
584
580
585
elif fnmatch (filename , "*.traj" ):
581
- try :
582
- from ase . io . trajectory import Trajectory as AseTrajectory
583
-
584
- from pymatgen . io . ase import AseAtomsAdaptor
585
-
586
- ase_traj = AseTrajectory ( filename )
587
- # Periodic boundary conditions should be the same for all frames so just check the first
588
- pbc = ase_traj [ 0 ]. pbc
586
+ if NO_ASE_ERR is None :
587
+ return cls . from_ase (
588
+ filename ,
589
+ constant_lattice = constant_lattice ,
590
+ store_frame_properties = True ,
591
+ additional_fields = None ,
592
+ )
593
+ raise ImportError ( "ASE is required to read .traj files. pip install ase" )
589
594
590
- if any (pbc ):
591
- structures = [AseAtomsAdaptor .get_structure (atoms ) for atoms in ase_traj ]
592
- else :
593
- molecules = [AseAtomsAdaptor .get_molecule (atoms ) for atoms in ase_traj ]
594
- is_mol = True
595
+ elif fnmatch (filename , "*.json*" ):
596
+ from monty .serialization import loadfn
595
597
596
- except ImportError as exc :
597
- raise ImportError ("ASE is required to read .traj files. pip install ase" ) from exc
598
+ return loadfn (filename , ** kwargs )
598
599
599
600
else :
600
- supported_file_types = ("XDATCAR" , "vasprun.xml" , "*.traj" )
601
+ supported_file_types = ("XDATCAR" , "vasprun.xml" , "*.traj" , ".json" )
601
602
raise ValueError (f"Expect file to be one of { supported_file_types } ; got { filename } ." )
602
603
603
- if is_mol :
604
- return cls .from_molecules (molecules , ** kwargs )
605
-
606
604
return cls .from_structures (structures , constant_lattice = constant_lattice , ** kwargs )
607
605
608
606
@staticmethod
@@ -734,3 +732,148 @@ def _get_site_props(self, frames: ValidIndex) -> SitePropsType | None:
734
732
return [self .site_properties [idx ] for idx in frames ]
735
733
raise ValueError ("Unexpected frames type." )
736
734
raise ValueError ("Unexpected site_properties type." )
735
+
736
+ @classmethod
737
+ def from_ase (
738
+ cls ,
739
+ trajectory : str | Path | AseTrajectory ,
740
+ constant_lattice : bool | None = None ,
741
+ store_frame_properties : bool = True ,
742
+ property_map : dict [str , str ] | None = None ,
743
+ lattice_match_tol : float = 1.0e-6 ,
744
+ additional_fields : Sequence [str ] | None = ["temperature" , "velocities" ],
745
+ ) -> Trajectory :
746
+ """
747
+ Convert an ASE trajectory to a pymatgen trajectory.
748
+
749
+ Args:
750
+ trajectory (str, .Path, or ASE .Trajectory) : the ASE trajectory, or a file path to it if a str or .Path
751
+ constant_lattice (bool or None) : if a bool, whether the lattice is constant in the .Trajectory.
752
+ If `None`, this is determined on the fly.
753
+ store_frame_properties (bool) : Whether to store pymatgen .Trajectory `frame_properties` as
754
+ ASE calculator properties. Defaults to True
755
+ property_map (dict[str,str]) : A mapping between ASE calculator properties and
756
+ pymatgen .Trajectory `frame_properties` keys. Ex.:
757
+ property_map = {"energy": "e_0_energy"}
758
+ would map `e_0_energy` in the pymatgen .Trajectory `frame_properties`
759
+ to ASE's `get_potential_energy` function.
760
+ See `ase.calculators.calculator.all_properties` for a list of acceptable calculator properties.
761
+ lattice_match_tol (float = 1.0e-6) : tolerance to which lattices are matched if
762
+ `constant_lattice = None`.
763
+ additional_fields (Sequence of str, defaults to ["temperature", "velocities"]) :
764
+ Optional other fields to save in the pymatgen .Trajectory.
765
+ Valid options are "temperature" and "velocities".
766
+
767
+ Returns:
768
+ pymatgen .Trajectory
769
+ """
770
+ if isinstance (trajectory , str | Path ):
771
+ trajectory = AseTrajectory (trajectory , "r" )
772
+
773
+ property_map = property_map or {
774
+ "energy" : "energy" ,
775
+ "forces" : "forces" ,
776
+ "stress" : "stress" ,
777
+ }
778
+ additional_fields = additional_fields or []
779
+
780
+ adaptor = AseAtomsAdaptor ()
781
+
782
+ structures = []
783
+ frame_properties = []
784
+ converter = adaptor .get_structure if (is_pbc := any (trajectory [0 ].pbc )) else adaptor .get_molecule
785
+
786
+ for atoms in trajectory :
787
+ site_properties = {}
788
+ if "velocities" in additional_fields :
789
+ site_properties ["velocities" ] = atoms .get_velocities ()
790
+
791
+ structures .append (converter (atoms , site_properties = site_properties ))
792
+
793
+ if store_frame_properties and atoms .calc :
794
+ props = {v : atoms .calc .get_property (k ) for k , v in property_map .items ()}
795
+ if "temperature" in additional_fields :
796
+ props ["temperature" ] = atoms .get_temperature ()
797
+
798
+ frame_properties .append (props )
799
+
800
+ if constant_lattice is None :
801
+ constant_lattice = all (
802
+ np .all (np .abs (ref_struct .lattice .matrix - structures [j ].lattice .matrix )) < lattice_match_tol
803
+ for i , ref_struct in enumerate (structures )
804
+ for j in range (i + 1 , len (structures ))
805
+ )
806
+
807
+ if is_pbc :
808
+ return cls .from_structures (structures , constant_lattice = constant_lattice , frame_properties = frame_properties )
809
+ return cls .from_molecules (
810
+ structures ,
811
+ constant_lattice = constant_lattice ,
812
+ frame_properties = frame_properties ,
813
+ )
814
+
815
+ def to_ase (
816
+ self ,
817
+ property_map : dict [str , str ] | None = None ,
818
+ ase_traj_file : str | Path | None = None ,
819
+ ) -> AseTrajectory :
820
+ """
821
+ Convert a pymatgen .Trajectory to an ASE .Trajectory.
822
+
823
+ Args:
824
+ trajectory (pymatgen .Trajectory) : trajectory to convert
825
+ property_map (dict[str,str]) : A mapping between ASE calculator properties and
826
+ pymatgen .Trajectory `frame_properties` keys. Ex.:
827
+ property_map = {"energy": "e_0_energy"}
828
+ would map `e_0_energy` in the pymatgen .Trajectory `frame_properties`
829
+ to ASE's `get_potential_energy` function.
830
+ See `ase.calculators.calculator.all_properties` for a list of acceptable calculator properties.
831
+ ase_traj_file (str, Path, or None (default) ) : If not None, the name of
832
+ the file to write the ASE trajectory to.
833
+
834
+ Returns:
835
+ ase .Trajectory
836
+ """
837
+ if NO_ASE_ERR is not None :
838
+ raise ImportError ("ASE is required to write .traj files. pip install ase" )
839
+
840
+ from ase .calculators .calculator import all_properties
841
+ from ase .calculators .singlepoint import SinglePointCalculator
842
+
843
+ property_map = property_map or {
844
+ "energy" : "energy" ,
845
+ "forces" : "forces" ,
846
+ "stress" : "stress" ,
847
+ }
848
+
849
+ if (unrecognized_props := set (property_map ).difference (set (all_properties ))) != set ():
850
+ raise ValueError (f"Unrecognized ASE calculator properties:\n { ', ' .join (unrecognized_props )} " )
851
+
852
+ adaptor = AseAtomsAdaptor ()
853
+
854
+ temp_file = None
855
+ if ase_traj_file is None :
856
+ temp_file = NamedTemporaryFile (delete = False ) # noqa: SIM115
857
+ ase_traj_file = temp_file .name
858
+
859
+ frame_props = self .frame_properties or [{} for _ in range (len (self ))]
860
+ for idx , structure in enumerate (self ):
861
+ atoms = adaptor .get_atoms (structure , msonable = False , velocities = structure .site_properties .get ("velocities" ))
862
+
863
+ props : dict [str , Any ] = {k : frame_props [idx ][v ] for k , v in property_map .items () if v in frame_props [idx ]}
864
+
865
+ # Ensure that `charges` and `magmoms` are not lost from AseAtomsAdaptor
866
+ for k in ("charges" , "magmoms" ):
867
+ if k in atoms .calc .implemented_properties or k in atoms .calc .results :
868
+ props [k ] = atoms .calc .get_property (k )
869
+
870
+ atoms .calc = SinglePointCalculator (atoms = atoms , ** props )
871
+
872
+ with AseTrajectory (ase_traj_file , "a" if idx > 0 else "w" , atoms = atoms ) as _traj_file :
873
+ _traj_file .write ()
874
+
875
+ ase_traj = AseTrajectory (ase_traj_file , "r" )
876
+ if temp_file is not None :
877
+ temp_file .close ()
878
+
879
+ return ase_traj
0 commit comments