From 8718add7600a2aae17387485e5cba136e81331d6 Mon Sep 17 00:00:00 2001 From: minhuanli Date: Sun, 12 Nov 2023 18:06:08 -0500 Subject: [PATCH] better typing hints, support multiple input constructor, better moduleized --- SFC_Torch/Fmodel.py | 399 +++++++++++++++++++++++-------------------- SFC_Torch/io.py | 11 +- SFC_Torch/utils.py | 13 ++ tests/test_Fmodel.py | 29 +++- tests/test_io.py | 6 +- tests/test_voxel.py | 4 +- 6 files changed, 260 insertions(+), 202 deletions(-) diff --git a/SFC_Torch/Fmodel.py b/SFC_Torch/Fmodel.py index 1f725f1..7485269 100644 --- a/SFC_Torch/Fmodel.py +++ b/SFC_Torch/Fmodel.py @@ -12,6 +12,8 @@ __author__ = "Minhuan Li" __email__ = "minhuanli@g.harvard.edu" +from typing import Optional, List + import gemmi import time import numpy as np @@ -22,8 +24,9 @@ from .mask import reciprocal_grid, rsgrid2realmask, realmask2Fmask from .utils import try_gpu, DWF_aniso, DWF_iso, diff_array, asu2HKL, aniso_scaling from .utils import vdw_rad_tensor, unitcell_grid_center, bin_by_logarithmic +from .utils import r_factor, assert_numpy, assert_tensor from .packingscore import packingscore_voxelgrid_torch -from .utils import r_factor, assert_numpy +from .io import PDBParser class SFcalculator(object): @@ -33,43 +36,47 @@ class SFcalculator(object): def __init__( self, - PDBfile_dir, - mtzfile_dir=None, - dmin=None, - anomalous=False, - wavelength=None, - set_experiment=True, - expcolumns=["FP", "SIGFP"], - freeflag="FreeR_flag", - testset_value=0, - device=try_gpu() - ): + pdbmodel: str | PDBParser, + mtzdata: str | rs.DataSet = None, + n_bins: int = 10, + dmin: Optional[float] = None, + anomalous: bool=False, + wavelength: Optional[float]=None, + set_experiment: bool=True, + expcolumns: List[str]=["FP", "SIGFP"], + freeflag: str="FreeR_flag", + testset_value: int=0, + device: torch.device=try_gpu() + ) -> None: """ Initialize with necessary reusable information, like spacegroup, unit cell info, HKL_list, et.c. Parameters: ----------- - model_dir: path, str - path to the PDB model file, will use its unit cell info, space group info, atom name info, + pdbmodel: str | PDBParser + path to the PDB file or an instance of PDBparser. Will use its unit cell info, space group info, atom name info, atom position info, atoms B-factor info and atoms occupancy info to initialize the instance. - mtz_file_dir: path, str, default None - path to the mtz_file_dir, will use the HKL list in the mtz instead, override dmin with an inference - - dmin: float, default None + mtzdata: str | rs.Dataset + path to the mtz data or instance of rs.Dataset, will use the HKL list in the mtz instead, override dmin with an inference + + n_bins: str, default 10 + Number of resolution bins used in the reciprocal space + + dmin: float | None highest resolution in the map in Angstrom, to generate Miller indices in recirpocal ASU - anomalous: boolean, default False + anomalous: bool, default False Whether or not to include anomalous scattering in the calculation - wavelength: None or float + wavelength: float | None The wavelength of scattering source in A - set_experiment: boolean, default True + set_experiment: bool, default True Whether or not to set Fo, SigF, free_flag and Outlier from the experimental mtz file. It only works when the mtzfile_dir is not None - expcolumns: list of str, default ['FP', 'SIGFP'] + expcolumns: List[str], default ['FP', 'SIGFP'] list of column names used as expeimental data freeflag: str, default "FreeR_flag" @@ -80,113 +87,54 @@ def __init__( device: torch.device """ - structure = gemmi.read_pdb(PDBfile_dir) # gemmi.Structure object - self.unit_cell = structure.cell # gemmi.UnitCell object - self.space_group = gemmi.SpaceGroup( - structure.spacegroup_hm - ) # gemmi.SpaceGroup object - self.operations = self.space_group.operations() # gemmi.GroupOps object + self.wavelength = wavelength self.anomalous = anomalous self.device = device + self.set_pdb(pdbmodel) + # Generate ASU HKL array and Corresponding d*^2 array + if mtzdata is not None: + self.set_mtz(mtzdata, n_bins, expcolumns, set_experiment, freeflag, testset_value) + else: + self.set_withoutmtz(dmin, n_bins) + self._set_atomic_scattering() + self.inspected = False - if anomalous: - # Try to get the wavelength from PDB remarks - try: - line_index = np.argwhere( - ["WAVELENGTH OR RANGE" in i for i in structure.raw_remarks] - ) - pdb_wavelength = eval( - structure.raw_remarks[line_index[0, 0]].split()[-1] - ) - if wavelength is not None: - assert np.isclose(pdb_wavelength, wavelength, atol=0.05) - else: - self.wavelength = pdb_wavelength - except: - print( - "Can't find wavelength record in the PDB file, or it doesn't match your input wavelength!" - ) - - self.R_G_tensor_stack = torch.tensor( + def set_pdb(self, pdbmodel: str | PDBParser): + """ + set pdb topology, symmetry operations, unit_cell properties, and initialize model coordinates + """ + if type(pdbmodel) == str: + self._pdb = PDBParser(pdbmodel) # sfc.PDBparser object + elif type(pdbmodel) == PDBParser: + self._pdb = pdbmodel + else: + raise TypeError("pdbmodel should be PDBparser instance or path str to a pdb file!") + + # set spacegroup related properties + self.space_group = self._pdb.spacegroup # gemmi.SpaceGroup object + self.operations = self.space_group.operations() # gemmi.GroupOps object + self.R_G_tensor_stack = assert_tensor( np.array([np.array(sym_op.rot) / sym_op.DEN for sym_op in self.operations]), + arr_type=torch.float32, device=self.device, - ).type(torch.float32) - self.T_G_tensor_stack = torch.tensor( + ) + self.T_G_tensor_stack = assert_tensor( np.array( [np.array(sym_op.tran) / sym_op.DEN for sym_op in self.operations] ), + arr_type=torch.float32, device=self.device, - ).type(torch.float32) - - # Generate ASU HKL array and Corresponding d*^2 array - if mtzfile_dir: - mtz_reference = rs.read_mtz(mtzfile_dir) - try: - mtz_reference.dropna(axis=0, subset=expcolumns, inplace=True) - except: - raise ValueError(f"{expcolumns} columns not included in the mtz file!") - if anomalous: - # Try to get the wavelength from MTZ file - try: - mtz_wavelength = mtz_reference.dataset(0).wavelength - assert mtz_wavelength > 0.05 - if self.wavelength is not None: - assert np.isclose(mtz_wavelength, self.wavelength, atol=0.05) - else: - self.wavelength = mtz_wavelength - except: - print( - "Can't find wavelength record in the MTZ file, or it doesn't match with other sources" - ) - # HKL array from the reference mtz file, [N,3] - self.HKL_array = mtz_reference.get_hkls() - self.dHKL = self.unit_cell.calculate_d_array(self.HKL_array).astype( - "float32" - ) - self.dmin = self.dHKL.min() - assert ( - mtz_reference.cell == self.unit_cell - ), "Unit cell from mtz file does not match that in PDB file!" - assert mtz_reference.spacegroup.hm == self.space_group.hm, "Space group from mtz file does not match that in PDB file!" # type: ignore - self.Hasu_array = generate_reciprocal_asu( - self.unit_cell, self.space_group, self.dmin, anomalous=anomalous - ) - assert ( - diff_array(self.HKL_array, self.Hasu_array) == set() - ), "HKL_array should be equal or subset of the Hasu_array!" - self.asu2HKL_index = asu2HKL(self.Hasu_array, self.HKL_array) - # d*^2 array according to the HKL list, [N] - self.dr2asu_array = self.unit_cell.calculate_1_d2_array(self.Hasu_array) - self.dr2HKL_array = self.unit_cell.calculate_1_d2_array(self.HKL_array) - # assign reslution bins - self.assign_resolution_bins() - if set_experiment: - self.set_experiment(mtz_reference, expcolumns, freeflag, testset_value) - else: - if not dmin: - raise ValueError( - "high_resolution dmin OR a reference mtz file should be provided!" - ) - else: - self.dmin = dmin - self.Hasu_array = generate_reciprocal_asu( - self.unit_cell, self.space_group, self.dmin - ) - self.dHasu = self.unit_cell.calculate_d_array(self.Hasu_array).astype( - "float32" - ) - self.dr2asu_array = self.unit_cell.calculate_1_d2_array(self.Hasu_array) - self.HKL_array = None - self.assign_resolution_bins() + ) + # set unit cell related properties + self.unit_cell = self._pdb.cell # gemmi.UnitCell object self.orth2frac_tensor = torch.tensor( self.unit_cell.fractionalization_matrix.tolist(), device=self.device ).type(torch.float32) self.frac2orth_tensor = torch.tensor( self.unit_cell.orthogonalization_matrix.tolist(), device=self.device ).type(torch.float32) - self.reciprocal_cell = self.unit_cell.reciprocal() # gemmi.UnitCell object # [ar, br, cr, cos(alpha_r), cos(beta_r), cos(gamma_r)] self.reciprocal_cell_paras = torch.tensor( @@ -201,80 +149,111 @@ def __init__( device=self.device, ).type(torch.float32) - self.atom_name = [] - self.atom_pos_orth = [] - self.atom_pos_frac = [] - self.atom_aniso_uw = [] - self.atom_b_iso = [] - self.atom_occ = [] - model = structure[0] # gemmi.Model object - for cra in model.all(): - # A list of atom name like ['O','C','N','C', ...], [Nc] - self.atom_name.append(cra.atom.element.name) - # A list of atom's Positions in orthogonal space, [Nc,3] - self.atom_pos_orth.append(cra.atom.pos.tolist()) - # A list of atom's Positions in fractional space, [Nc,3] - self.atom_pos_frac.append( - self.unit_cell.fractionalize(cra.atom.pos).tolist() - ) - # A list of anisotropic B Factor in matrix form [[U11,U22,U33,U12,U13,U23],..], [Nc,3,3] - self.atom_aniso_uw.append(cra.atom.aniso.as_mat33().tolist()) - # A list of isotropic B Factor [B1,B2,...], [Nc] - self.atom_b_iso.append(cra.atom.b_iso) - # A list of occupancy [P1,P2,....], [Nc] - self.atom_occ.append(cra.atom.occ) - - self.atom_pos_orth = torch.tensor(self.atom_pos_orth, device=self.device).type( - torch.float32 - ) - self.atom_pos_frac = torch.tensor(self.atom_pos_frac, device=self.device).type( - torch.float32 - ) - self.atom_aniso_uw = torch.tensor(self.atom_aniso_uw, device=self.device).type( - torch.float32 - ) - - self.atom_b_iso = torch.tensor(self.atom_b_iso, device=self.device).type( - torch.float32 - ) - self.atom_occ = torch.tensor(self.atom_occ, device=self.device).type( - torch.float32 - ) + # set molecule related property + # Tensor atom's Positions in orthogonal space, [Nc,3] + self.atom_pos_orth = assert_tensor(self._pdb.atom_pos, device=self.device, arr_type=torch.float32) + # Tensor of anisotropic B Factor in matrix form, [Nc,3,3] + self.atom_aniso_uw = assert_tensor(self._pdb.atom_b_aniso, device=self.device, arr_type=torch.float32) + # Tensor of isotropic B Factor [B1,B2,...], [Nc] + self.atom_b_iso = assert_tensor(self._pdb.atom_b_iso, device=self.device, arr_type=torch.float32) + # Tensor of occupancy [P1,P2,....], [Nc] + self.atom_occ = assert_tensor(self._pdb.atom_occ, device=self.device, arr_type=torch.float32) + self.n_atoms = len(self.atom_name) self.unique_atom = list(set(self.atom_name)) - # A dictionary of atomic structural factor f0_sj of different atom types at different HKL Rupp's Book P280 - # f0_sj = [sum_{i=1}^4 {a_ij*exp(-b_ij* d*^2/4)} ] + c_j - if anomalous: - assert self.wavelength is not None, ValueError( - "If you need anomalous scattering contribution, provide the wavelength info from input, pbd or mtz file!" - ) - - self.full_atomic_sf_asu = {} - for atom_type in self.unique_atom: - element = gemmi.Element(atom_type) - f0 = np.array( - [element.it92.calculate_sf(dr2 / 4.0) for dr2 in self.dr2asu_array] - ) - if anomalous: - fp, fpp = gemmi.cromer_liberman( - z=element.atomic_number, energy=gemmi.hc / self.wavelength + if self.anomalous: + # Try to get the wavelength from PDB remarks + try: + line_index = np.argwhere( + ["WAVELENGTH OR RANGE" in i for i in self._pdb.pdb_header] ) - self.full_atomic_sf_asu[atom_type] = f0 + fp + 1j * fpp - else: - self.full_atomic_sf_asu[atom_type] = f0 + pdb_wavelength = eval( + self._pdb.pdb_header[line_index[0, 0]].split()[-1] + ) + if self.wavelength is not None: + assert np.isclose(pdb_wavelength, self.wavelength, atol=0.05) + else: + self.wavelength = pdb_wavelength + except: + print( + "Can't find wavelength record in the PDB file, or it doesn't match your input wavelength!" + ) + + @property + def atom_pos_frac(self): + """ + Tensor of atom's Positions in fractional space, [Nc,3] + """ + return torch.tensordot(self.atom_pos_orth, self.orth2frac_tensor.T, 1) - if anomalous: - self.fullsf_tensor = torch.tensor( - np.array([self.full_atomic_sf_asu[atom] for atom in self.atom_name]), - device=self.device, - ).type(torch.complex64) + @property + def cra_name(self): + """ + A list of Chain-Residue-Atom name, ['A-0-ALA-CA', ...] + """ + return self._pdb.cra_name + + @property + def atom_name(self): + """ + A list of element name, ['C', 'N', 'H', ...] + """ + return self._pdb.atom_name + + def set_mtz(self, mtzdata, N_bins, expcolumns, set_experiment, freeflag, testset_value): + """ + set mtz file for HKL list, resolution and experimental related properties + """ + + if type(mtzdata) == str: + mtz_reference = rs.read_mtz(mtzdata) + elif type(mtzdata) == rs.DataSet: + mtz_reference = mtzdata else: - self.fullsf_tensor = torch.tensor( - np.array([self.full_atomic_sf_asu[atom] for atom in self.atom_name]), - device=self.device, - ).type(torch.float32) - self.inspected = False + raise TypeError("mtzdata should be rs.Dataset object or path str to a mtz file!") + + try: + mtz_reference.dropna(axis=0, subset=expcolumns, inplace=True) + except: + raise ValueError(f"{expcolumns} columns not included in the mtz file!") + if self.anomalous: + # Try to get the wavelength from MTZ file + try: + mtz_wavelength = mtz_reference.dataset(0).wavelength + assert mtz_wavelength > 0.05 + if self.wavelength is not None: + assert np.isclose(mtz_wavelength, self.wavelength, atol=0.05) + else: + self.wavelength = mtz_wavelength + except: + print( + "Can't find wavelength record in the MTZ file, or it doesn't match with other sources" + ) + # HKL array from the reference mtz file, [N,3] + self.HKL_array = mtz_reference.get_hkls() + self.dHKL = self.unit_cell.calculate_d_array(self.HKL_array).astype( + "float32" + ) + self.dmin = self.dHKL.min() + assert ( + mtz_reference.cell == self.unit_cell + ), "Unit cell from mtz file does not match that in PDB file!" + assert mtz_reference.spacegroup.hm == self.space_group.hm, "Space group from mtz file does not match that in PDB file!" # type: ignore + self.Hasu_array = generate_reciprocal_asu( + self.unit_cell, self.space_group, self.dmin, anomalous=self.anomalous + ) + assert ( + diff_array(self.HKL_array, self.Hasu_array) == set() + ), "HKL_array should be equal or subset of the Hasu_array!" + self.asu2HKL_index = asu2HKL(self.Hasu_array, self.HKL_array) + # d*^2 array according to the HKL list, [N] + self.dr2asu_array = self.unit_cell.calculate_1_d2_array(self.Hasu_array) + self.dr2HKL_array = self.unit_cell.calculate_1_d2_array(self.HKL_array) + # assign reslution bins + self.assign_resolution_bins(bins=N_bins) + if set_experiment: + self.set_experiment(mtz_reference, expcolumns, freeflag, testset_value) def set_experiment(self, exp_mtz, expcolumns=["FP", "SIGFP"], freeflag="FreeR_flag", testset_value=0): """ @@ -319,6 +298,56 @@ def set_experiment(self, exp_mtz, expcolumns=["FP", "SIGFP"], freeflag="FreeR_fl except: self.Outlier = np.zeros(len(self.Fo)).astype(bool) print("No outlier detection, will use all reflections!") + + def set_withoutmtz(self, dmin, n_bins): + if not dmin: + raise ValueError( + "high_resolution dmin OR a reference mtz file should be provided!" + ) + else: + self.dmin = dmin + self.Hasu_array = generate_reciprocal_asu( + self.unit_cell, self.space_group, self.dmin + ) + self.dHasu = self.unit_cell.calculate_d_array(self.Hasu_array).astype( + "float32" + ) + self.dr2asu_array = self.unit_cell.calculate_1_d2_array(self.Hasu_array) + self.HKL_array = None + self.assign_resolution_bins(n_bins) + + def _set_atomic_scattering(self): + # A dictionary of atomic structural factor f0_sj of different atom types at different HKL Rupp's Book P280 + # f0_sj = [sum_{i=1}^4 {a_ij*exp(-b_ij* d*^2/4)} ] + c_j + if self.anomalous: + assert self.wavelength is not None, ValueError( + "If you need anomalous scattering contribution, provide the wavelength info from input, pdb or mtz file!" + ) + + self.full_atomic_sf_asu = {} + for atom_type in self.unique_atom: + element = gemmi.Element(atom_type) + f0 = np.array( + [element.it92.calculate_sf(dr2 / 4.0) for dr2 in self.dr2asu_array] + ) + if self.anomalous: + fp, fpp = gemmi.cromer_liberman( + z=element.atomic_number, energy=gemmi.hc / self.wavelength + ) + self.full_atomic_sf_asu[atom_type] = f0 + fp + 1j * fpp + else: + self.full_atomic_sf_asu[atom_type] = f0 + + if self.anomalous: + self.fullsf_tensor = torch.tensor( + np.array([self.full_atomic_sf_asu[atom] for atom in self.atom_name]), + device=self.device, + ).type(torch.complex64) + else: + self.fullsf_tensor = torch.tensor( + np.array([self.full_atomic_sf_asu[atom] for atom in self.atom_name]), + device=self.device, + ).type(torch.float32) def assign_resolution_bins( @@ -412,7 +441,7 @@ def calc_fprotein( Parameters ---------- atoms_positions_tensor: 2D [N_atoms, 3] tensor or default None - Positions of atoms in the model, in unit of angstrom; If not given, the model stored in attribute `atom_pos_frac` will be used + Positions of atoms in the model, in unit of angstrom; If not given, the model stored in attribute `atom_pos_orth` will be used atoms_biso_tensor: 1D [N_atoms,] tensor or default None Isotropic B factors of each atoms in the model; If not given, the info stored in attribute `atoms_b_iso` will be used @@ -438,9 +467,7 @@ def calc_fprotein( assert ( len(atoms_position_tensor) == self.n_atoms ), "Atoms in atoms_positions_tensor should be consistent with atom names in PDB model!" - self.atom_pos_frac = torch.tensordot( - atoms_position_tensor, self.orth2frac_tensor.T, 1 - ) + self.atom_pos_orth = atoms_position_tensor if not atoms_aniso_uw_tensor is None: assert len(atoms_aniso_uw_tensor) == len( @@ -826,9 +853,9 @@ def get_scales_lbfgs( def get_scales_adam( self, - lr=0.1, + lr=0.01, n_steps=100, - sub_ratio=0.3, + sub_ratio=0.7, initialize=True, verbose=False ): diff --git a/SFC_Torch/io.py b/SFC_Torch/io.py index 63ca64b..8f0ac45 100644 --- a/SFC_Torch/io.py +++ b/SFC_Torch/io.py @@ -41,8 +41,8 @@ def hier2array(structure, as_tensor=False): ) # A list of atom's Positions in orthogonal space, [Nc,3] atom_pos.append(atom.pos.tolist()) - # A list of anisotropic B Factor [[U11,U22,U33,U12,U13,U23],..], [Nc,6] - atom_b_aniso.append(atom.aniso.elements_pdb()) + # A list of anisotropic B Factor matrix, [Nc,3,3] + atom_b_aniso.append(atom.aniso.as_mat33().tolist()) # A list of isotropic B Factor [B1,B2,...], [Nc] atom_b_iso.append(atom.b_iso) # A list of occupancy [P1,P2,....], [Nc] @@ -87,7 +87,8 @@ def array2hier( current_atom = gemmi.Atom() current_atom.name = atomname_i current_atom.element = gemmi.Element(atom_name[i]) - current_atom.aniso = gemmi.SMat33f(*atom_b_aniso[i]) + Ui = atom_b_aniso[i] + current_atom.aniso = gemmi.SMat33f(*[Ui[0,0], Ui[1,1], Ui[2,2], Ui[0,1], Ui[0,2], Ui[1,2]]) current_atom.b_iso = atom_b_iso[i] current_atom.pos = gemmi.Position(*atom_pos[i]) current_atom.occ = atom_occ[i] @@ -205,10 +206,10 @@ def set_baniso(self, baniso): """ Set the Anisotropic B-factors with an array - baniso: array-like, [Nc,6] + baniso: array-like, [Nc,3,3] """ assert len(baniso) == len(self.atom_b_aniso), "Different atom number!" - assert len(baniso[0]) == 6, "Provide 6 baniso parameters per atom!" + assert baniso[0].shape == (3,3), "Provide a 3*3 matrix per atom!" self.atom_b_aniso = baniso def set_occ(self, occ): diff --git a/SFC_Torch/utils.py b/SFC_Torch/utils.py index 967c496..f8baf86 100644 --- a/SFC_Torch/utils.py +++ b/SFC_Torch/utils.py @@ -11,6 +11,7 @@ "vdw_rad_tensor", "nonH_index", "assert_numpy", + "assert_tensor", "bin_by_logarithmic", "aniso_scaling", ] @@ -43,6 +44,18 @@ def assert_numpy(x, arr_type=None): return x +def assert_tensor(x, arr_type=None, device=try_gpu()): + if isinstance(x, np.ndarray): + x = torch.tensor(x, device=device) + if is_list_or_tuple(x): + x = np.array(x) + x = torch.tensor(x, device=device) + assert isinstance(x, torch.Tensor) + if arr_type is not None: + x = x.to(arr_type) + return x + + def r_factor(Fo, Fmodel, free_flag): """ A function to calculate R_work and R_free diff --git a/tests/test_Fmodel.py b/tests/test_Fmodel.py index ef8f442..d00b681 100644 --- a/tests/test_Fmodel.py +++ b/tests/test_Fmodel.py @@ -5,6 +5,7 @@ import torch from scipy.stats import pearsonr +from SFC_Torch.io import PDBParser from SFC_Torch.Fmodel import SFcalculator from SFC_Torch.utils import assert_numpy @@ -13,7 +14,7 @@ def test_constructor_SFcalculator(data_pdb, data_mtz_exp, case): if case == 1: sfcalculator = SFcalculator( - data_pdb, mtzfile_dir=data_mtz_exp, set_experiment=True) + data_pdb, mtzdata=data_mtz_exp, set_experiment=True) sfcalculator.inspect_data() bins_labels = sfcalculator.assign_resolution_bins(return_labels=True) assert sfcalculator.inspected @@ -26,7 +27,7 @@ def test_constructor_SFcalculator(data_pdb, data_mtz_exp, case): assert len(bins_labels) == 10 else: sfcalculator = SFcalculator( - data_pdb, mtzfile_dir=None, dmin=2.5, set_experiment=True) + data_pdb, mtzdata=None, dmin=2.5, set_experiment=True) sfcalculator.inspect_data() assert sfcalculator.inspected assert np.isclose(assert_numpy(sfcalculator.solventpct), 0.1667, 1e-3) @@ -36,11 +37,27 @@ def test_constructor_SFcalculator(data_pdb, data_mtz_exp, case): assert len(sfcalculator.atom_name) == 488 +def test_constructor_SFcalculator_obj(data_pdb, data_mtz_exp): + pdbmodel = PDBParser(data_pdb) + mtzdata = rs.read_mtz(data_mtz_exp) + sfcalculator = SFcalculator( + pdbmodel, mtzdata=mtzdata, set_experiment=True) + sfcalculator.inspect_data() + bins_labels = sfcalculator.assign_resolution_bins(return_labels=True) + assert sfcalculator.inspected + assert np.isclose(assert_numpy(sfcalculator.solventpct), 0.1667, 1e-3) + assert sfcalculator.gridsize == [48, 60, 60] + assert len(sfcalculator.HKL_array) == 3197 + assert len(sfcalculator.Hasu_array) == 3255 + assert len(sfcalculator.bins) == 3197 + assert np.all(np.sort(np.unique(sfcalculator.bins)) == np.arange(0,10)) + assert len(bins_labels) == 10 + @pytest.mark.parametrize("Return", [True, False]) @pytest.mark.parametrize("Anomalous", [True, False]) def test_calc_fall(data_pdb, data_mtz_exp, data_mtz_fmodel_ksol0, data_mtz_fmodel_ksol1, Return, Anomalous): sfcalculator = SFcalculator( - data_pdb, mtzfile_dir=data_mtz_exp, set_experiment=True, anomalous=Anomalous) + data_pdb, mtzdata=data_mtz_exp, set_experiment=True, anomalous=Anomalous) sfcalculator.inspect_data() Fprotein = sfcalculator.calc_fprotein(Return=Return) Fsolvent = sfcalculator.calc_fsolvent( @@ -99,7 +116,7 @@ def test_calc_fall(data_pdb, data_mtz_exp, data_mtz_fmodel_ksol0, data_mtz_fmode def test_calc_ftotal_nodata(data_pdb): sfcalculator = SFcalculator( - data_pdb, mtzfile_dir=None, dmin=2.5, set_experiment=False) + data_pdb, mtzdata=None, dmin=2.5, set_experiment=False) sfcalculator.inspect_data() sfcalculator.calc_fprotein(Return=False) sfcalculator.calc_fsolvent( @@ -113,7 +130,7 @@ def test_calc_ftotal_nodata(data_pdb): @pytest.mark.parametrize("Anomalous", [True, False]) def test_calc_fall_batch(data_pdb, data_mtz_exp, Anomalous, partition_size): sfcalculator = SFcalculator( - data_pdb, mtzfile_dir=data_mtz_exp, set_experiment=True, anomalous=Anomalous) + data_pdb, mtzdata=data_mtz_exp, set_experiment=True, anomalous=Anomalous) sfcalculator.inspect_data() atoms_pos_batch = torch.tile(sfcalculator.atom_pos_orth, [5, 1, 1]) @@ -149,7 +166,7 @@ def test_calc_fall_batch(data_pdb, data_mtz_exp, Anomalous, partition_size): def test_prepare_dataset(data_pdb, data_mtz_exp): sfcalculator = SFcalculator( - data_pdb, mtzfile_dir=data_mtz_exp, set_experiment=True) + data_pdb, mtzdata=data_mtz_exp, set_experiment=True) sfcalculator.inspect_data() sfcalculator.calc_fprotein(Return=False) sfcalculator.calc_fsolvent( diff --git a/tests/test_io.py b/tests/test_io.py index f157638..1e9d50e 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -25,10 +25,10 @@ def test_setdata(data_pdb): # test set_baniso new_baniso = a.atom_b_aniso.copy() - new_baniso[10, 3] = 7.88 + new_baniso[10, 1, 2] = 7.88 a.set_baniso(new_baniso) - assert a.atom_b_aniso[10, 3] == 7.88 - + assert a.atom_b_aniso[10, 1, 2] == 7.88 + # test set_occ new_occ = a.atom_occ.copy() new_occ[10] = 7.88 diff --git a/tests/test_voxel.py b/tests/test_voxel.py index 635ea41..9c75cb5 100644 --- a/tests/test_voxel.py +++ b/tests/test_voxel.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize("binary", [True, False]) def test_voxelvalue_torch_p1_sm(data_pdb, data_mtz_exp, binary): sfcalculator = SFcalculator( - data_pdb, mtzfile_dir=data_mtz_exp, set_experiment=True) + data_pdb, mtzdata=data_mtz_exp, set_experiment=True) vdw_rad = vdw_rad_tensor(sfcalculator.atom_name) uc_grid_orth_tensor = unitcell_grid_center(sfcalculator.unit_cell, spacing=4.5, return_tensor=True) @@ -41,7 +41,7 @@ def test_voxelvalue_torch_p1_sm(data_pdb, data_mtz_exp, binary): def test_voxel_1dto3d(data_pdb, data_mtz_exp): sfcalculator = SFcalculator( - data_pdb, mtzfile_dir=data_mtz_exp, set_experiment=True) + data_pdb, mtzdata=data_mtz_exp, set_experiment=True) vdw_rad = vdw_rad_tensor(sfcalculator.atom_name) uc_grid_orth_tensor = unitcell_grid_center(sfcalculator.unit_cell, spacing=4.5, return_tensor=True)