Skip to content

Commit

Permalink
Handle Martini terms.
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtarzia committed Nov 21, 2023
1 parent b73a532 commit 78039f4
Showing 1 changed file with 230 additions and 11 deletions.
241 changes: 230 additions & 11 deletions src/cgexplore/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@
CosineAngle,
TargetAngle,
TargetCosineAngle,
TargetMartiniAngle,
find_angles,
)
from .assigned_system import AssignedSystem
from .assigned_system import AssignedSystem, MartiniSystem
from .beads import CgBead, get_cgbead_from_element
from .bonds import Bond, TargetBond
from .bonds import Bond, TargetBond, TargetMartiniBond
from .errors import ForcefieldUnitError
from .nonbonded import Nonbonded, TargetNonbonded
from .torsions import TargetTorsion, Torsion, find_torsions
from .torsions import (
TargetMartiniTorsion,
TargetTorsion,
Torsion,
find_torsions,
)
from .utilities import angle_between, convert_pyramid_angle

logging.basicConfig(
Expand Down Expand Up @@ -64,9 +70,8 @@ def add_torsion_range(self, torsion_range: tuple) -> None:
def add_nonbonded_range(self, nonbonded_range: tuple) -> None:
self._nonbonded_ranges += (nonbonded_range,)

def yield_forcefields(self):
def _get_iterations(self) -> list:
iterations = []

for bond_range in self._bond_ranges:
iterations.append(tuple(bond_range.yield_bonds()))

Expand All @@ -78,6 +83,10 @@ def yield_forcefields(self):

for nonbonded_range in self._nonbonded_ranges:
iterations.append(tuple(nonbonded_range.yield_nonbondeds()))
return iterations

def yield_forcefields(self):
iterations = self._get_iterations()

for i, parameter_set in enumerate(itertools.product(*iterations)):
bond_terms = tuple(
Expand Down Expand Up @@ -148,6 +157,9 @@ def __init__(
self._hrprefix = "ffhr"

def _assign_bond_terms(self, molecule: stk.Molecule) -> tuple:
found = set()
assigned = set()

bonds = list(molecule.get_bonds())
bond_terms = []
for bond in bonds:
Expand All @@ -157,21 +169,31 @@ def _assign_bond_terms(self, molecule: stk.Molecule) -> tuple:
get_cgbead_from_element(i, self.get_bead_set())
for i in atom_estrings
]
cgbead_string = tuple(i.bead_type[0] for i in cgbeads)

cgbead_string = tuple(i.bead_type for i in cgbeads)
found.add(cgbead_string)
found.add(tuple(reversed(cgbead_string)))
for target_term in self._bond_targets:
if (target_term.class1, target_term.class2) not in (
cgbead_string,
tuple(reversed(cgbead_string)),
):
continue
assigned.add(cgbead_string)
assigned.add(tuple(reversed(cgbead_string)))
try:
assert isinstance(target_term.bond_r, openmm.unit.Quantity)
assert isinstance(target_term.bond_k, openmm.unit.Quantity)
except AssertionError:
msg = f"{target_term} in bonds does not have units"
raise ForcefieldUnitError(msg)

if "Martini" in target_term.__class__.__name__:
force = "MartiniDefinedBond"
funct = target_term.funct
else:
force = "HarmonicBondForce"
funct = 0

bond_terms.append(
Bond(
atoms=atoms,
Expand All @@ -182,15 +204,23 @@ def _assign_bond_terms(self, molecule: stk.Molecule) -> tuple:
atom_ids=tuple(i.get_id() for i in atoms),
bond_k=target_term.bond_k,
bond_r=target_term.bond_r,
force="HarmonicBondForce",
force=force,
funct=funct,
)
)

logging.info(
"unassigned bond terms: "
f"{sorted((i for i in found if i not in assigned))}"
)
return tuple(bond_terms)

def _assign_angle_terms(self, molecule: stk.Molecule) -> tuple:
angle_terms = []
pos_mat = molecule.get_position_matrix()

found = set()
assigned = set()
pyramid_angles: dict[str, list] = {}
octahedral_angles: dict[str, list] = {}
for found_angle in find_angles(molecule):
Expand All @@ -205,7 +235,9 @@ def _assign_angle_terms(self, molecule: stk.Molecule) -> tuple:
f"Angle not assigned ({found_angle}; {atom_estrings})."
)

cgbead_string = tuple(i.bead_type[0] for i in cgbeads)
cgbead_string = tuple(i.bead_type for i in cgbeads)
found.add(cgbead_string)
found.add(tuple(reversed(cgbead_string)))
for target_angle in self._angle_targets:
search_string = (
target_angle.class1,
Expand All @@ -218,6 +250,8 @@ def _assign_angle_terms(self, molecule: stk.Molecule) -> tuple:
):
continue

assigned.add(cgbead_string)
assigned.add(tuple(reversed(cgbead_string)))
if isinstance(target_angle, TargetAngle):
try:
assert isinstance(
Expand Down Expand Up @@ -291,6 +325,40 @@ def _assign_angle_terms(self, molecule: stk.Molecule) -> tuple:
)
angle_terms.append(actual_angle)

elif isinstance(target_angle, TargetMartiniAngle):
try:
assert isinstance(
target_angle.angle, openmm.unit.Quantity
)
assert isinstance(
target_angle.angle_k, openmm.unit.Quantity
)
except AssertionError:
msg = (
f"{target_angle} in angles does not have units for"
" parameters"
)
raise ForcefieldUnitError(msg)

central_bead = cgbeads[1]
central_atom = list(found_angle.atoms)[1]
central_name = (
f"{atom_estrings[1]}{central_atom.get_id()+1}"
)
actual_angle = Angle(
atoms=found_angle.atoms,
atom_names=tuple(
f"{i.__class__.__name__}" f"{i.get_id()+1}"
for i in found_angle.atoms
),
atom_ids=found_angle.atom_ids,
angle=target_angle.angle,
angle_k=target_angle.angle_k,
force="MartiniDefinedAngle",
funct=target_angle.funct,
)
angle_terms.append(actual_angle)

# For four coordinate systems, apply standard angle theta to
# neighbouring atoms, then compute pyramid angle for opposing
# interaction.
Expand Down Expand Up @@ -380,13 +448,20 @@ def _assign_angle_terms(self, molecule: stk.Molecule) -> tuple:
force="HarmonicAngleForce",
),
)

logging.info(
"unassigned angle terms: "
f"{sorted((i for i in found if i not in assigned))}"
)
return tuple(angle_terms)

def _assign_torsion_terms(
self,
molecule: stk.Molecule,
) -> tuple:
torsion_terms = []
found = set()
assigned = set()

# Iterate over the different path lengths, and find all torsions
# for that lengths.
Expand All @@ -400,14 +475,18 @@ def _assign_torsion_terms(
get_cgbead_from_element(i, self.get_bead_set())
for i in atom_estrings
]
cgbead_string = tuple(i.bead_type[0] for i in cgbeads)
cgbead_string = tuple(i.bead_type for i in cgbeads)
found.add(cgbead_string)
found.add(tuple(reversed(cgbead_string)))
for target_torsion in self._torsion_targets:
if target_torsion.search_string not in (
cgbead_string,
tuple(reversed(cgbead_string)),
):
continue

assigned.add(cgbead_string)
assigned.add(tuple(reversed(cgbead_string)))
try:
assert isinstance(
target_torsion.phi0, openmm.unit.Quantity
Expand All @@ -421,6 +500,16 @@ def _assign_torsion_terms(
)
raise ForcefieldUnitError(msg)

if "Martini" in target_torsion.__class__.__name__:
force = "MartiniDefinedTorsion"
funct = target_torsion.funct
else:
force = "PeriodicTorsionForce"
funct = 0
print(target_torsion)
print(force)
raise SystemExit

torsion_terms.append(
Torsion(
atom_names=tuple(
Expand All @@ -435,9 +524,15 @@ def _assign_torsion_terms(
phi0=target_torsion.phi0,
torsion_n=target_torsion.torsion_n,
torsion_k=target_torsion.torsion_k,
force="PeriodicTorsionForce",
force=force,
funct=funct,
)
)

logging.info(
"unassigned torsion terms: "
f"{sorted((i for i in found if i not in assigned))}"
)
return tuple(torsion_terms)

def _assign_nonbonded_terms(
Expand Down Expand Up @@ -573,3 +668,127 @@ def __str__(self) -> str:

def __repr__(self) -> str:
return str(self)


class MartiniForceFieldLibrary(ForceFieldLibrary):
def __init__(
self,
bead_library: tuple[CgBead],
vdw_bond_cutoff: int,
prefix: str,
) -> None:
self._bead_library = bead_library
self._vdw_bond_cutoff = vdw_bond_cutoff
self._prefix = prefix
self._bond_ranges: tuple = ()
self._angle_ranges: tuple = ()
self._torsion_ranges: tuple = ()
self._constraints: tuple = ()

def _get_iterations(self) -> list:
iterations = []
for bond_range in self._bond_ranges:
iterations.append(tuple(bond_range.yield_bonds()))

for angle_range in self._angle_ranges:
iterations.append(tuple(angle_range.yield_angles()))

for torsion_range in self._torsion_ranges:
iterations.append(tuple(torsion_range.yield_torsions()))

return iterations

def yield_forcefields(self):
iterations = self._get_iterations()

for i, parameter_set in enumerate(itertools.product(*iterations)):
bond_terms = tuple(
i for i in parameter_set if "Bond" in i.__class__.__name__
)
angle_terms = tuple(
i
for i in parameter_set
if "Angle" in i.__class__.__name__
# and "Pyramid" not in i.__class__.__name__
)
torsion_terms = tuple(
i
for i in parameter_set
if "Torsion" in i.__class__.__name__
# if len(i.search_string) == 4
)
yield MartiniForceField(
identifier=str(i),
prefix=self._prefix,
present_beads=self._bead_library,
bond_targets=bond_terms,
angle_targets=angle_terms,
torsion_targets=torsion_terms,
constraints=self._constraints,
vdw_bond_cutoff=self._vdw_bond_cutoff,
)

def __str__(self) -> str:
return (
f"{self.__class__.__name__}(\n"
f" bead_library={self._bead_library},\n"
f" bond_ranges={self._bond_ranges},\n"
f" angle_ranges={self._angle_ranges},\n"
f" torsion_ranges={self._torsion_ranges},\n"
"\n)"
)

def __repr__(self) -> str:
return str(self)


class MartiniForceField(Forcefield):
def __init__(
self,
identifier: str,
prefix: str,
present_beads: tuple[CgBead, ...],
bond_targets: tuple[TargetBond | TargetMartiniBond, ...],
angle_targets: tuple[TargetAngle | TargetMartiniAngle, ...],
torsion_targets: tuple[TargetTorsion | TargetMartiniTorsion, ...],
constraints: tuple[tuple],
vdw_bond_cutoff: int,
) -> None:
self._identifier = identifier
self._prefix = prefix
self._present_beads = present_beads
self._bond_targets = bond_targets
self._angle_targets = angle_targets
self._torsion_targets = torsion_targets
self._vdw_bond_cutoff = vdw_bond_cutoff
self._constraints = constraints
self._hrprefix = "mffhr"

def assign_terms(
self,
molecule: stk.Molecule,
name: str,
output_dir: pathlib.Path,
) -> MartiniSystem:
assigned_terms = {
"bond": self._assign_bond_terms(molecule),
"angle": self._assign_angle_terms(molecule),
"torsion": self._assign_torsion_terms(molecule),
"nonbonded": (),
"constraints": self._constraints,
}

return MartiniSystem(
molecule=molecule,
force_field_terms=assigned_terms,
system_xml=(
output_dir
/ f"{name}_{self._prefix}_{self._identifier}_syst.xml"
),
topology_itp=(
output_dir
/ f"{name}_{self._prefix}_{self._identifier}_topo.itp"
),
bead_set=self.get_bead_set(),
vdw_bond_cutoff=self._vdw_bond_cutoff,
)

0 comments on commit 78039f4

Please sign in to comment.