Skip to content

Commit

Permalink
[fmt] Coordinates (#4856)
Browse files Browse the repository at this point in the history
  • Loading branch information
RMeli authored Dec 24, 2024
1 parent 29deccc commit c08cb79
Show file tree
Hide file tree
Showing 55 changed files with 5,819 additions and 3,670 deletions.
137 changes: 86 additions & 51 deletions package/MDAnalysis/coordinates/CRD.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ class CRDReader(base.SingleFrameReaderBase):
Now returns a ValueError instead of FormatError.
Frames now 0-based instead of 1-based.
"""
format = 'CRD'
units = {'time': None, 'length': 'Angstrom'}

format = "CRD"
units = {"time": None, "length": "Angstrom"}

def _read_first_frame(self):
# EXT:
Expand All @@ -62,37 +63,47 @@ def _read_first_frame(self):
extended = False
natoms = 0
for linenum, line in enumerate(crdfile):
if line.strip().startswith('*') or line.strip() == "":
if line.strip().startswith("*") or line.strip() == "":
continue # ignore TITLE and empty lines
fields = line.split()
if len(fields) <= 2:
# should be the natoms line
natoms = int(fields[0])
extended = (fields[-1] == 'EXT')
extended = fields[-1] == "EXT"
continue
# process coordinates
try:
if extended:
coords_list.append(np.array(line[45:100].split()[0:3], dtype=float))
coords_list.append(
np.array(line[45:100].split()[0:3], dtype=float)
)
else:
coords_list.append(np.array(line[20:50].split()[0:3], dtype=float))
coords_list.append(
np.array(line[20:50].split()[0:3], dtype=float)
)
except Exception:
errmsg = (f"Check CRD format at line {linenum}: "
f"{line.rstrip()}")
errmsg = (
f"Check CRD format at line {linenum}: "
f"{line.rstrip()}"
)
raise ValueError(errmsg) from None

self.n_atoms = len(coords_list)

self.ts = self._Timestep.from_coordinates(np.array(coords_list),
**self._ts_kwargs)
self.ts = self._Timestep.from_coordinates(
np.array(coords_list), **self._ts_kwargs
)
self.ts.frame = 0 # 0-based frame number
# if self.convert_units:
# self.convert_pos_from_native(self.ts._pos) # in-place !

# sanity check
if self.n_atoms != natoms:
raise ValueError("Found %d coordinates in %r but the header claims that there "
"should be %d coordinates." % (self.n_atoms, self.filename, natoms))
raise ValueError(
"Found %d coordinates in %r but the header claims that there "
"should be %d coordinates."
% (self.n_atoms, self.filename, natoms)
)

def Writer(self, filename, **kwargs):
"""Returns a CRDWriter for *filename*.
Expand Down Expand Up @@ -132,21 +143,26 @@ class CRDWriter(base.WriterBase):
Files are now written in `wt` mode, and keep extensions, allowing
for files to be written under compressed formats
"""
format = 'CRD'
units = {'time': None, 'length': 'Angstrom'}

format = "CRD"
units = {"time": None, "length": "Angstrom"}

fmt = {
#crdtype = 'extended'
#fortran_format = '(2I10,2X,A8,2X,A8,3F20.10,2X,A8,2X,A8,F20.10)'
"ATOM_EXT": ("{serial:10d}{totRes:10d} {resname:<8.8s} {name:<8.8s}"
"{pos[0]:20.10f}{pos[1]:20.10f}{pos[2]:20.10f} "
"{chainID:<8.8s} {resSeq:<8d}{tempfactor:20.10f}\n"),
# crdtype = 'extended'
# fortran_format = '(2I10,2X,A8,2X,A8,3F20.10,2X,A8,2X,A8,F20.10)'
"ATOM_EXT": (
"{serial:10d}{totRes:10d} {resname:<8.8s} {name:<8.8s}"
"{pos[0]:20.10f}{pos[1]:20.10f}{pos[2]:20.10f} "
"{chainID:<8.8s} {resSeq:<8d}{tempfactor:20.10f}\n"
),
"NUMATOMS_EXT": "{0:10d} EXT\n",
#crdtype = 'standard'
#fortran_format = '(2I5,1X,A4,1X,A4,3F10.5,1X,A4,1X,A4,F10.5)'
"ATOM": ("{serial:5d}{totRes:5d} {resname:<4.4s} {name:<4.4s}"
"{pos[0]:10.5f}{pos[1]:10.5f}{pos[2]:10.5f} "
"{chainID:<4.4s} {resSeq:<4d}{tempfactor:10.5f}\n"),
# crdtype = 'standard'
# fortran_format = '(2I5,1X,A4,1X,A4,3F10.5,1X,A4,1X,A4,F10.5)'
"ATOM": (
"{serial:5d}{totRes:5d} {resname:<4.4s} {name:<4.4s}"
"{pos[0]:10.5f}{pos[1]:10.5f}{pos[2]:10.5f} "
"{chainID:<4.4s} {resSeq:<4d}{tempfactor:10.5f}\n"
),
"TITLE": "* FRAME {frame} FROM {where}\n",
"NUMATOMS": "{0:5d}\n",
}
Expand All @@ -168,7 +184,7 @@ def __init__(self, filename, **kwargs):
.. versionadded:: 2.2.0
"""

self.filename = util.filename(filename, ext='crd', keep=True)
self.filename = util.filename(filename, ext="crd", keep=True)
self.crd = None

# account for explicit crd format, if requested
Expand Down Expand Up @@ -200,21 +216,22 @@ def write(self, selection, frame=None):
except AttributeError:
frame = 0 # should catch cases when we are analyzing a single PDB (?)


atoms = selection.atoms # make sure to use atoms (Issue 46)
coor = atoms.positions # can write from selection == Universe (Issue 49)
coor = (
atoms.positions
) # can write from selection == Universe (Issue 49)

n_atoms = len(atoms)
# Detect which format string we're using to output (EXT or not)
# *len refers to how to truncate various things,
# depending on output format!
if self.extended or n_atoms > 99999:
at_fmt = self.fmt['ATOM_EXT']
at_fmt = self.fmt["ATOM_EXT"]
serial_len = 10
resid_len = 8
totres_len = 10
else:
at_fmt = self.fmt['ATOM']
at_fmt = self.fmt["ATOM"]
serial_len = 5
resid_len = 4
totres_len = 5
Expand All @@ -223,11 +240,11 @@ def write(self, selection, frame=None):
attrs = {}
missing_topology = []
for attr, default in (
('resnames', itertools.cycle(('UNK',))),
# Resids *must* be an array because we index it later
('resids', np.ones(n_atoms, dtype=int)),
('names', itertools.cycle(('X',))),
('tempfactors', itertools.cycle((0.0,))),
("resnames", itertools.cycle(("UNK",))),
# Resids *must* be an array because we index it later
("resids", np.ones(n_atoms, dtype=int)),
("names", itertools.cycle(("X",))),
("tempfactors", itertools.cycle((0.0,))),
):
try:
attrs[attr] = getattr(atoms, attr)
Expand All @@ -236,48 +253,66 @@ def write(self, selection, frame=None):
missing_topology.append(attr)
# ChainIDs - Try ChainIDs first, fall back to Segids
try:
attrs['chainIDs'] = atoms.chainIDs
attrs["chainIDs"] = atoms.chainIDs
except (NoDataError, AttributeError):
# try looking for segids instead
try:
attrs['chainIDs'] = atoms.segids
attrs["chainIDs"] = atoms.segids
except (NoDataError, AttributeError):
attrs['chainIDs'] = itertools.cycle(('',))
attrs["chainIDs"] = itertools.cycle(("",))
missing_topology.append(attr)
if missing_topology:
warnings.warn(
"Supplied AtomGroup was missing the following attributes: "
"{miss}. These will be written with default values. "
"".format(miss=', '.join(missing_topology)))
"".format(miss=", ".join(missing_topology))
)

with util.openany(self.filename, 'wt') as crd:
with util.openany(self.filename, "wt") as crd:
# Write Title
crd.write(self.fmt['TITLE'].format(
frame=frame, where=u.trajectory.filename))
crd.write(
self.fmt["TITLE"].format(
frame=frame, where=u.trajectory.filename
)
)
crd.write("*\n")

# Write NUMATOMS
if self.extended or n_atoms > 99999:
crd.write(self.fmt['NUMATOMS_EXT'].format(n_atoms))
crd.write(self.fmt["NUMATOMS_EXT"].format(n_atoms))
else:
crd.write(self.fmt['NUMATOMS'].format(n_atoms))
crd.write(self.fmt["NUMATOMS"].format(n_atoms))

# Write all atoms

current_resid = 1
resids = attrs['resids']
resids = attrs["resids"]
for i, pos, resname, name, chainID, resid, tempfactor in zip(
range(n_atoms), coor, attrs['resnames'], attrs['names'],
attrs['chainIDs'], attrs['resids'], attrs['tempfactors']):
if not i == 0 and resids[i] != resids[i-1]:
range(n_atoms),
coor,
attrs["resnames"],
attrs["names"],
attrs["chainIDs"],
attrs["resids"],
attrs["tempfactors"],
):
if not i == 0 and resids[i] != resids[i - 1]:
current_resid += 1

# Truncate numbers
serial = util.ltruncate_int(i + 1, serial_len)
resid = util.ltruncate_int(resid, resid_len)
current_resid = util.ltruncate_int(current_resid, totres_len)

crd.write(at_fmt.format(
serial=serial, totRes=current_resid, resname=resname,
name=name, pos=pos, chainID=chainID,
resSeq=resid, tempfactor=tempfactor))
crd.write(
at_fmt.format(
serial=serial,
totRes=current_resid,
resname=resname,
name=name,
pos=pos,
chainID=chainID,
resSeq=resid,
tempfactor=tempfactor,
)
)
34 changes: 21 additions & 13 deletions package/MDAnalysis/coordinates/DMS.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
#
# MDAnalysis --- https://www.mdanalysis.org
# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors
Expand Down Expand Up @@ -46,29 +46,30 @@ class DMSReader(base.SingleFrameReaderBase):
.. versionchanged:: 0.11.0
Frames now 0-based instead of 1-based
"""
format = 'DMS'
units = {'time': None, 'length': 'A', 'velocity': 'A/ps'}

format = "DMS"
units = {"time": None, "length": "A", "velocity": "A/ps"}

def get_coordinates(self, cur):
cur.execute('SELECT * FROM particle')
cur.execute("SELECT * FROM particle")
particles = cur.fetchall()
return [(p['x'], p['y'], p['z']) for p in particles]
return [(p["x"], p["y"], p["z"]) for p in particles]

def get_particle_by_columns(self, cur, columns=None):
if columns is None:
columns = ['x', 'y', 'z']
cur.execute('SELECT * FROM particle')
columns = ["x", "y", "z"]
cur.execute("SELECT * FROM particle")
particles = cur.fetchall()
return [tuple([p[c] for c in columns]) for p in particles]

def get_global_cell(self, cur):
cur.execute('SELECT * FROM global_cell')
cur.execute("SELECT * FROM global_cell")
rows = cur.fetchall()
assert len(rows) == 3
x = [row["x"] for row in rows]
y = [row["y"] for row in rows]
z = [row["z"] for row in rows]
return {'x': x, 'y': y, 'z': z}
return {"x": x, "y": y, "z": z}

def _read_first_frame(self):
coords_list = None
Expand All @@ -85,7 +86,9 @@ def dict_factory(cursor, row):
con.row_factory = dict_factory
cur = con.cursor()
coords_list = self.get_coordinates(cur)
velocities_list = self.get_particle_by_columns(cur, columns=['vx', 'vy', 'vz'])
velocities_list = self.get_particle_by_columns(
cur, columns=["vx", "vy", "vz"]
)
unitcell = self.get_global_cell(cur)

if not coords_list:
Expand All @@ -99,15 +102,20 @@ def dict_factory(cursor, row):
self.ts = self._Timestep.from_coordinates(
np.array(coords_list, dtype=np.float32),
velocities=velocities,
**self._ts_kwargs)
**self._ts_kwargs,
)
self.ts.frame = 0 # 0-based frame number

self.ts.dimensions = triclinic_box(unitcell['x'], unitcell['y'], unitcell['z'])
self.ts.dimensions = triclinic_box(
unitcell["x"], unitcell["y"], unitcell["z"]
)

if self.convert_units:
self.convert_pos_from_native(self.ts._pos) # in-place !
if self.ts.dimensions is not None:
self.convert_pos_from_native(self.ts.dimensions[:3]) # in-place !
self.convert_pos_from_native(
self.ts.dimensions[:3]
) # in-place !
if self.ts.has_velocities:
# converts nm/ps to A/ps units
self.convert_velocities_from_native(self.ts._velocities)
Loading

0 comments on commit c08cb79

Please sign in to comment.