From db2052192d8725a0670254c80c2699e7116a8f21 Mon Sep 17 00:00:00 2001 From: "Andrew S. Rosen" Date: Sun, 14 Jan 2024 07:51:16 -0800 Subject: [PATCH] Allow for writing of `Structure.site_properties` as `_atom_site_` flags in `CifWriter` (#3550) * fix mypy, fix ruff, tweak test_cif_writer_site_properties --------- Co-authored-by: Janosh Riebesell --- .pre-commit-config.yaml | 2 +- pymatgen/io/abinit/abitimer.py | 4 +-- pymatgen/io/abinit/netcdf.py | 2 +- pymatgen/io/cif.py | 62 ++++++++++++++++++++-------------- tests/io/test_cif.py | 11 ++++++ 5 files changed, 52 insertions(+), 29 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ddc36a1769f..2c46bcadc8f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.11 + rev: v0.1.13 hooks: - id: ruff args: [--fix, --unsafe-fixes] diff --git a/pymatgen/io/abinit/abitimer.py b/pymatgen/io/abinit/abitimer.py index ffcdb43c9d3..301174153f9 100644 --- a/pymatgen/io/abinit/abitimer.py +++ b/pymatgen/io/abinit/abitimer.py @@ -888,11 +888,11 @@ def scatter_hist(self, ax: plt.Axes = None, **kwargs): # axHistx.axis["bottom"].major_ticklabels.set_visible(False) axHistx.set_yticks([0, 50, 100]) for tl in axHistx.get_xticklabels(): - tl.set_visible(False) # noqa: FBT003 + tl.set_visible(False) # axHisty.axis["left"].major_ticklabels.set_visible(False) for tl in axHisty.get_yticklabels(): - tl.set_visible(False) # noqa: FBT003 + tl.set_visible(False) axHisty.set_xticks([0, 50, 100]) # plt.draw() diff --git a/pymatgen/io/abinit/netcdf.py b/pymatgen/io/abinit/netcdf.py index d1936fb5f99..499b6db4045 100644 --- a/pymatgen/io/abinit/netcdf.py +++ b/pymatgen/io/abinit/netcdf.py @@ -91,7 +91,7 @@ def __init__(self, path): # Slicing a ncvar returns a MaskedArrray and this is really annoying # because it can lead to unexpected behavior in e.g. calls to np.matmul! # See also https://github.com/Unidata/netcdf4-python/issues/785 - self.rootgrp.set_auto_mask(False) # noqa: FBT003 + self.rootgrp.set_auto_mask(False) def __enter__(self): """Activated when used in the with statement.""" diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 56d40e4299d..26a8c809a6f 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -7,7 +7,7 @@ import re import textwrap import warnings -from collections import deque +from collections import defaultdict, deque from datetime import datetime from functools import partial from inspect import getfullargspec as getargspec @@ -1313,13 +1313,14 @@ class CifWriter: def __init__( self, - struct, - symprec=None, - write_magmoms=False, - significant_figures=8, - angle_tolerance=5.0, - refine_struct=True, - ): + struct: Structure, + symprec: float | None = None, + write_magmoms: bool = False, + significant_figures: int = 8, + angle_tolerance: float = 5, + refine_struct: bool = True, + write_site_properties: bool = False, + ) -> None: """ Args: struct (Structure): structure to write @@ -1335,6 +1336,8 @@ def __init__( is not None. refine_struct: Used only if symprec is not None. If True, get_refined_structure is invoked to convert input structure from primitive to conventional. + write_site_properties (bool): Whether to write the Structure.site_properties + to the CIF as _atom_site_{property name}. Defaults to False. """ if write_magmoms and symprec: warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.") @@ -1342,7 +1345,7 @@ def __init__( format_str = f"{{:.{significant_figures}f}}" - block = {} + block: dict[str, Any] = {} loops = [] spacegroup = ("P 1", 1) if symprec is not None: @@ -1367,7 +1370,7 @@ def __init__( block["_chemical_formula_sum"] = no_oxi_comp.formula block["_cell_volume"] = format_str.format(lattice.volume) - _reduced_comp, fu = no_oxi_comp.get_reduced_composition_and_factor() + _, fu = no_oxi_comp.get_reduced_composition_and_factor() block["_cell_formula_units_Z"] = str(int(fu)) if symprec is None: @@ -1388,12 +1391,12 @@ def __init__( loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"]) try: - symbol_to_oxinum = {str(el): float(el.oxi_state) for el in sorted(comp.elements)} - block["_atom_type_symbol"] = list(symbol_to_oxinum) - block["_atom_type_oxidation_number"] = symbol_to_oxinum.values() + symbol_to_oxi_num = {str(el): float(el.oxi_state or 0) for el in sorted(comp.elements)} + block["_atom_type_symbol"] = list(symbol_to_oxi_num) + block["_atom_type_oxidation_number"] = symbol_to_oxi_num.values() loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"]) except (TypeError, AttributeError): - symbol_to_oxinum = {el.symbol: 0 for el in sorted(comp.elements)} + symbol_to_oxi_num = {el.symbol: 0 for el in sorted(comp.elements)} atom_site_type_symbol = [] atom_site_symmetry_multiplicity = [] @@ -1406,6 +1409,7 @@ def __init__( atom_site_moment_crystalaxis_x = [] atom_site_moment_crystalaxis_y = [] atom_site_moment_crystalaxis_z = [] + atom_site_properties: dict[str, list] = defaultdict(list) count = 0 if symprec is None: for site in struct: @@ -1437,6 +1441,10 @@ def __init__( atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) + if write_site_properties: + for key, val in site.properties.items(): + atom_site_properties[key].append(format_str.format(val)) + count += 1 else: # The following just presents a deterministic ordering. @@ -1475,17 +1483,21 @@ def __init__( block["_atom_site_fract_y"] = atom_site_fract_y block["_atom_site_fract_z"] = atom_site_fract_z block["_atom_site_occupancy"] = atom_site_occupancy - loops.append( - [ - "_atom_site_type_symbol", - "_atom_site_label", - "_atom_site_symmetry_multiplicity", - "_atom_site_fract_x", - "_atom_site_fract_y", - "_atom_site_fract_z", - "_atom_site_occupancy", - ] - ) + loop_labels = [ + "_atom_site_type_symbol", + "_atom_site_label", + "_atom_site_symmetry_multiplicity", + "_atom_site_fract_x", + "_atom_site_fract_y", + "_atom_site_fract_z", + "_atom_site_occupancy", + ] + if write_site_properties: + for key, vals in atom_site_properties.items(): + block[f"_atom_site_{key}"] = vals + loop_labels += [f"_atom_site_{key}"] + loops.append(loop_labels) + if write_magmoms: block["_atom_site_moment_label"] = atom_site_moment_label block["_atom_site_moment_crystalaxis_x"] = atom_site_moment_crystalaxis_x diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 399c4e595d5..5a4c037b2dc 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -870,6 +870,17 @@ def test_cif_writer_write_file(self): assert len(read_structs) == 2 assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"] + def test_cif_writer_site_properties(self): + struct = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") + struct.add_site_property(label := "hello", [1.0] * (len(struct) - 1) + [-1.0]) + out_path = f"{self.tmp_path}/test2.cif" + CifWriter(struct, write_site_properties=True).write_file(out_path) + with open(out_path) as file: + cif_str = file.read() + assert f"_atom_site_occupancy\n _atom_site_{label}\n" in cif_str + assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0" in cif_str + assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0" in cif_str + class TestMagCif(PymatgenTest): def setUp(self):