From f65851e2e142ae6dfeaaa05c7af7e85733f58e38 Mon Sep 17 00:00:00 2001 From: minhuanli Date: Mon, 2 Oct 2023 10:56:12 -0400 Subject: [PATCH] change default solvent hyperparameters according to statistics --- SFC_Torch/Fmodel.py | 23 ++++++++++++++++------- SFC_Torch/mask.py | 14 +++++++------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/SFC_Torch/Fmodel.py b/SFC_Torch/Fmodel.py index 1f95caa..85f901c 100644 --- a/SFC_Torch/Fmodel.py +++ b/SFC_Torch/Fmodel.py @@ -474,9 +474,10 @@ def calc_fsolvent( self, solventpct=None, gridsize=None, - dmin_mask=6.0, - Return=False, + dmin_mask=5.0, dmin_nonzero=3.0, + exponent=10.0, + Return=False, ): """ Calculate the structure factor of solvent mask in a differentiable way @@ -519,7 +520,7 @@ def calc_fsolvent( ) rs_grid = reciprocal_grid(Hp1_array, Fp1_tensor, gridsize) self.real_grid_mask = rsgrid2realmask( - rs_grid, solvent_percent=solventpct + rs_grid, solvent_percent=solventpct, exponent=exponent, ) # type: ignore if not self.HKL_array is None: self.Fmask_HKL = realmask2Fmask(self.real_grid_mask, self.HKL_array) @@ -686,6 +687,13 @@ def init_scales(self, requires_grad=True): if hasattr(self, "Fo"): self.kmasks, self.kisos = self._init_kmask_kiso(requires_grad=requires_grad) self.uanisos = self._init_uaniso(requires_grad=requires_grad) + Fmodel = self.calc_ftotal() + Fmodel_mag = torch.abs(Fmodel) + self.r_work, self.r_free = r_factor( + self.Fo[~self.Outlier], + Fmodel_mag[~self.Outlier], + self.free_flag[~self.Outlier], + ) else: self._set_scales(requires_grad) @@ -801,7 +809,7 @@ def get_scales_lbfgs( ls_lr=0.1, r_lr=0.1, initialize=True, - verbose=True, + verbose=False, ): self._get_scales_lbfgs_LS(ls_steps, ls_lr, verbose, initialize) self._get_scales_lbfgs_r(r_steps, r_lr, verbose, initialize=False) @@ -1005,10 +1013,11 @@ def calc_fsolvent_batch( self, solventpct=None, gridsize=None, - dmin_mask=6, + dmin_mask=5, + dmin_nonzero=3.0, + exponent=10.0, Return=False, PARTITION=100, - dmin_nonzero=3.0, ): """ Should run after Calc_Fprotein_batch, calculate the solvent mask structure factors in batched manner @@ -1058,7 +1067,7 @@ def calc_fsolvent_batch( Hp1_array, Fp1_tensor_batch[start:end], gridsize, end - start ) real_grid_mask = rsgrid2realmask( - rs_grid, solvent_percent=solventpct, Batch=True + rs_grid, solvent_percent=solventpct, exponent=exponent, Batch=True ) # type: ignore Fmask_batch_j = realmask2Fmask(real_grid_mask, HKL_array, end - start) if j == 0: diff --git a/SFC_Torch/mask.py b/SFC_Torch/mask.py index 07d20f9..1d0a82d 100644 --- a/SFC_Torch/mask.py +++ b/SFC_Torch/mask.py @@ -21,8 +21,8 @@ def reciprocal_grid(Hp1_array, Fp1_tensor, gridsize, batchsize=None): Return: Reciprocal space unit cell grid, as a torch.complex64 tensor """ - grid = torch.zeros(gridsize, device=try_gpu(), dtype=torch.complex64) - tuple_index = tuple(torch.tensor(Hp1_array.T, device=try_gpu(), dtype=int)) # type: ignore + grid = torch.zeros(gridsize, device=Fp1_tensor.device, dtype=torch.complex64) + tuple_index = tuple(torch.tensor(Hp1_array.T, device=Fp1_tensor.device, dtype=int)) # type: ignore if batchsize is not None: for i in range(batchsize): Fp1_tensor_i = Fp1_tensor[i] @@ -38,7 +38,7 @@ def reciprocal_grid(Hp1_array, Fp1_tensor, gridsize, batchsize=None): return grid -def rsgrid2realmask(rs_grid, solvent_percent=0.50, scale=50, Batch=False): +def rsgrid2realmask(rs_grid, solvent_percent=0.50, exponent=50.0, Batch=False): """ Convert reciprocal space grid to real space solvent mask grid, in a fully differentiable way with torch @@ -51,12 +51,12 @@ def rsgrid2realmask(rs_grid, solvent_percent=0.50, scale=50, Batch=False): solvent_percent: 0 - 1 float The approximate volume percentage of solvent in the system, to generate the cutoff - scale: int/float + exponent: int/float The scale used in sigmoid function, to make the distribution binary Return: ------- - tf.float32 tensor + torch.float32 tensor The solvent mask grid in real space, solvent voxels have value close to 1, while protein voxels have value close to 0. """ real_grid = torch.real(torch.fft.fftn(rs_grid, dim=(-3, -2, -1))) @@ -65,7 +65,7 @@ def rsgrid2realmask(rs_grid, solvent_percent=0.50, scale=50, Batch=False): CUTOFF = torch.quantile(real_grid_norm[0], solvent_percent) else: CUTOFF = torch.quantile(real_grid_norm, solvent_percent) - real_grid_mask = torch.sigmoid((CUTOFF - real_grid_norm) * 50) + real_grid_mask = torch.sigmoid((CUTOFF - real_grid_norm) * exponent) return real_grid_mask @@ -88,7 +88,7 @@ def realmask2Fmask(real_grid_mask, H_array, batchsize=None): Solvent mask structural factor corresponding to the HKL list in H_array """ Fmask_grid = torch.fft.ifftn(real_grid_mask, dim=(-3, -2, -1), norm="forward") - tuple_index = tuple(torch.tensor(H_array.T, device=try_gpu(), dtype=int)) # type: ignore + tuple_index = tuple(torch.tensor(H_array.T, device=Fmask_grid.device, dtype=int)) # type: ignore if batchsize is not None: Fmask = Fmask_grid[(slice(None), *tuple_index)] else: