Skip to content

Commit

Permalink
A more robust to io set spacegroup and unitcell
Browse files Browse the repository at this point in the history
  • Loading branch information
minhuanli committed Apr 18, 2024
1 parent 91a91f7 commit 1e64352
Showing 1 changed file with 72 additions and 44 deletions.
116 changes: 72 additions & 44 deletions SFC_Torch/Fmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,27 +89,74 @@ def __init__(
self.anomalous = anomalous
self.device = device
self.init_pdb(pdbmodel)
# Generate ASU HKL array and Corresponding d*^2 array
if mtzdata is not None:
self.init_mtz(mtzdata, n_bins, expcolumns, set_experiment, freeflag, testset_value)
else:
self.init_withoutmtz(dmin, n_bins)
self._init_spacegroup()
self._init_cell()
self.init_atomic_scattering()
self.inspected = False

def init_pdb(self, pdbmodel: str | PDBParser):
    """
    Load the structure model and cache its per-atom parameters as tensors.

    Sets self._pdb, self._atom_pos_orth, self._atom_aniso_uw,
    self._atom_b_iso and self._atom_occ. When self.anomalous is true, it
    also tries to read the wavelength from the PDB header, filling
    self.wavelength if it is still None. This method does not set up any
    space group or unit cell tensors.

    Args:
        pdbmodel: path to a PDB file, or an already constructed PDBParser.

    Raises:
        TypeError: if pdbmodel is neither a str nor a PDBParser.
    """
    if isinstance(pdbmodel, str):
        self._pdb = PDBParser(pdbmodel)  # sfc.PDBparser object
    elif isinstance(pdbmodel, PDBParser):
        self._pdb = pdbmodel
    else:
        raise TypeError("pdbmodel should be PDBparser instance or path str to a pdb file!")

    # set molecule related property
    # Tensor atom's Positions in orthogonal space, [Nc,3]
    self._atom_pos_orth = assert_tensor(self._pdb.atom_pos, device=self.device, arr_type=torch.float32)
    # Tensor of anisotropic B Factor in matrix form, [Nc,3,3]
    self._atom_aniso_uw = assert_tensor(self._pdb.atom_b_aniso, device=self.device, arr_type=torch.float32)
    # Tensor of isotropic B Factor [B1,B2,...], [Nc]
    self._atom_b_iso = assert_tensor(self._pdb.atom_b_iso, device=self.device, arr_type=torch.float32)
    # Tensor of occupancy [P1,P2,....], [Nc]
    self._atom_occ = assert_tensor(self._pdb.atom_occ, device=self.device, arr_type=torch.float32)

    if self.anomalous:
        # Try to get the wavelength from PDB remarks. This is best effort:
        # any failure (no record, unparsable value, mismatch with the
        # user-supplied wavelength) is reported and otherwise ignored.
        try:
            line_index = np.argwhere(
                ["WAVELENGTH OR RANGE" in i for i in self._pdb.pdb_header]
            )
            # The header is file content, so the last token is parsed with
            # float() rather than eval(), which would execute arbitrary text.
            pdb_wavelength = float(
                self._pdb.pdb_header[line_index[0, 0]].split()[-1]
            )
            if self.wavelength is None:
                self.wavelength = pdb_wavelength
            elif not np.isclose(pdb_wavelength, self.wavelength, atol=0.05):
                # Explicit raise instead of assert, so the check survives -O
                raise ValueError("PDB wavelength does not match the input wavelength")
        except Exception:
            print(
                "Can't find wavelength record in the PDB file, or it doesn't match your input wavelength!"
            )

@property
def space_group(self):
    """Space group of the model, read from the underlying PDB parser."""
    pdb = self._pdb
    return pdb.spacegroup

@space_group.setter
def space_group(self, spacegroup):
    """Store a new space group on the PDB parser and rebuild what depends on it."""
    # Update the parser first: _init_spacegroup reads the value back
    # through this property.
    pdb = self._pdb
    pdb.set_spacegroup(spacegroup)
    self._init_spacegroup()

@property
def unit_cell(self):
    """Unit cell of the model, read from the underlying PDB parser."""
    pdb = self._pdb
    return pdb.cell

@unit_cell.setter
def unit_cell(self, cell):
    """Store a new unit cell on the PDB parser and rebuild what depends on it."""
    # Update the parser first: _init_cell reads the value back through
    # this property.
    pdb = self._pdb
    pdb.set_unitcell(cell)
    self._init_cell()

def _init_spacegroup(self):
# Set up spacegroup related property
self.operations = self.space_group.operations() # gemmi.GroupOps object
self.R_G_tensor_stack = assert_tensor(
np.array([np.array(sym_op.rot) / sym_op.DEN for sym_op in self.operations]),
Expand All @@ -123,9 +170,9 @@ def init_pdb(self, pdbmodel: str | PDBParser):
arr_type=torch.float32,
device=self.device,
)

# set unit cell related properties
self.unit_cell = self._pdb.cell # gemmi.UnitCell object
def _init_cell(self):
# Set up unit cell related property
self.orth2frac_tensor = torch.tensor(
self.unit_cell.fractionalization_matrix.tolist(), device=self.device
).type(torch.float32)
Expand All @@ -145,34 +192,7 @@ def init_pdb(self, pdbmodel: str | PDBParser):
],
device=self.device,
).type(torch.float32)

# set molecule related property
# Tensor atom's Positions in orthogonal space, [Nc,3]
self._atom_pos_orth = assert_tensor(self._pdb.atom_pos, device=self.device, arr_type=torch.float32)
# Tensor of anisotropic B Factor in matrix form, [Nc,3,3]
self._atom_aniso_uw = assert_tensor(self._pdb.atom_b_aniso, device=self.device, arr_type=torch.float32)
# Tensor of isotropic B Factor [B1,B2,...], [Nc]
self._atom_b_iso = assert_tensor(self._pdb.atom_b_iso, device=self.device, arr_type=torch.float32)
# Tensor of occupancy [P1,P2,....], [Nc]
self._atom_occ = assert_tensor(self._pdb.atom_occ, device=self.device, arr_type=torch.float32)

if self.anomalous:
# Try to get the wavelength from PDB remarks
try:
line_index = np.argwhere(
["WAVELENGTH OR RANGE" in i for i in self._pdb.pdb_header]
)
pdb_wavelength = eval(
self._pdb.pdb_header[line_index[0, 0]].split()[-1]
)
if self.wavelength is not None:
assert np.isclose(pdb_wavelength, self.wavelength, atol=0.05)
else:
self.wavelength = pdb_wavelength
except:
print(
"Can't find wavelength record in the PDB file, or it doesn't match your input wavelength!"
)

@property
def atom_pos_orth(self):
    """Return the atom positions cached in self._atom_pos_orth."""
    positions = self._atom_pos_orth
    return positions
Expand Down Expand Up @@ -258,6 +278,7 @@ def init_mtz(self, mtzdata, N_bins, expcolumns, set_experiment, freeflag, testse
mtz_reference.dropna(axis=0, subset=expcolumns, inplace=True)
except:
raise ValueError(f"{expcolumns} columns not included in the mtz file!")

if self.anomalous:
# Try to get the wavelength from MTZ file
try:
Expand All @@ -271,20 +292,26 @@ def init_mtz(self, mtzdata, N_bins, expcolumns, set_experiment, freeflag, testse
print(
"Can't find wavelength record in the MTZ file, or it doesn't match with other sources"
)

if (mtz_reference.cell == self._pdb.cell):
pass
else:
print("Unit cell from mtz file does not match that in PDB file! Using the cell info from MTZ file!")
self._pdb.set_unitcell(mtz_reference.cell)

if (mtz_reference.spacegroup.hm == self._pdb.spacegroup.hm):
pass
else:
print("Space group from mtz file does not match that in PDB file! Using the spacegroup from MTZ file!") # type: ignore
self._pdb.set_spacegroup(mtz_reference.spacegroup)

# HKL array from the reference mtz file, [N,3]
self.HKL_array = mtz_reference.get_hkls()
self.dHKL = self.unit_cell.calculate_d_array(self.HKL_array).astype(
"float32"
)
self.dmin = self.dHKL.min()
try:
assert (
mtz_reference.cell == self.unit_cell
)
except:
print("Unit cell from mtz file does not match that in PDB file! Using the cell info from MTZ file!")
self.unit_cell = mtz_reference.cell
assert mtz_reference.spacegroup.hm == self.space_group.hm, "Space group from mtz file does not match that in PDB file!" # type: ignore

self.Hasu_array = generate_reciprocal_asu(
self.unit_cell, self.space_group, self.dmin, anomalous=self.anomalous
)
Expand Down Expand Up @@ -357,6 +384,7 @@ def init_withoutmtz(self, dmin, n_bins):
)
else:
self.dmin = dmin
# Generate ASU HKL array and Corresponding d*^2 array
self.Hasu_array = generate_reciprocal_asu(
self.unit_cell, self.space_group, self.dmin
)
Expand Down

0 comments on commit 1e64352

Please sign in to comment.