Skip to content

Commit

Permalink
change default solvent hyperparameters according to statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
minhuanli committed Oct 2, 2023
1 parent adba74f commit f65851e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
23 changes: 16 additions & 7 deletions SFC_Torch/Fmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,10 @@ def calc_fsolvent(
self,
solventpct=None,
gridsize=None,
dmin_mask=6.0,
Return=False,
dmin_mask=5.0,
dmin_nonzero=3.0,
exponent=10.0,
Return=False,
):
"""
Calculate the structure factor of solvent mask in a differentiable way
Expand Down Expand Up @@ -519,7 +520,7 @@ def calc_fsolvent(
)
rs_grid = reciprocal_grid(Hp1_array, Fp1_tensor, gridsize)
self.real_grid_mask = rsgrid2realmask(
rs_grid, solvent_percent=solventpct
rs_grid, solvent_percent=solventpct, exponent=exponent,
) # type: ignore
if not self.HKL_array is None:
self.Fmask_HKL = realmask2Fmask(self.real_grid_mask, self.HKL_array)
Expand Down Expand Up @@ -686,6 +687,13 @@ def init_scales(self, requires_grad=True):
if hasattr(self, "Fo"):
self.kmasks, self.kisos = self._init_kmask_kiso(requires_grad=requires_grad)
self.uanisos = self._init_uaniso(requires_grad=requires_grad)
Fmodel = self.calc_ftotal()
Fmodel_mag = torch.abs(Fmodel)
self.r_work, self.r_free = r_factor(
self.Fo[~self.Outlier],
Fmodel_mag[~self.Outlier],
self.free_flag[~self.Outlier],
)
else:
self._set_scales(requires_grad)

Expand Down Expand Up @@ -801,7 +809,7 @@ def get_scales_lbfgs(
ls_lr=0.1,
r_lr=0.1,
initialize=True,
verbose=True,
verbose=False,
):
self._get_scales_lbfgs_LS(ls_steps, ls_lr, verbose, initialize)
self._get_scales_lbfgs_r(r_steps, r_lr, verbose, initialize=False)
Expand Down Expand Up @@ -1005,10 +1013,11 @@ def calc_fsolvent_batch(
self,
solventpct=None,
gridsize=None,
dmin_mask=6,
dmin_mask=5,
dmin_nonzero=3.0,
exponent=10.0,
Return=False,
PARTITION=100,
dmin_nonzero=3.0,
):
"""
Should run after Calc_Fprotein_batch, calculate the solvent mask structure factors in batched manner
Expand Down Expand Up @@ -1058,7 +1067,7 @@ def calc_fsolvent_batch(
Hp1_array, Fp1_tensor_batch[start:end], gridsize, end - start
)
real_grid_mask = rsgrid2realmask(
rs_grid, solvent_percent=solventpct, Batch=True
rs_grid, solvent_percent=solventpct, exponent=exponent, Batch=True
) # type: ignore
Fmask_batch_j = realmask2Fmask(real_grid_mask, HKL_array, end - start)
if j == 0:
Expand Down
14 changes: 7 additions & 7 deletions SFC_Torch/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def reciprocal_grid(Hp1_array, Fp1_tensor, gridsize, batchsize=None):
Return:
Reciprocal space unit cell grid, as a torch.complex64 tensor
"""
grid = torch.zeros(gridsize, device=try_gpu(), dtype=torch.complex64)
tuple_index = tuple(torch.tensor(Hp1_array.T, device=try_gpu(), dtype=int)) # type: ignore
grid = torch.zeros(gridsize, device=Fp1_tensor.device, dtype=torch.complex64)
tuple_index = tuple(torch.tensor(Hp1_array.T, device=Fp1_tensor.device, dtype=int)) # type: ignore
if batchsize is not None:
for i in range(batchsize):
Fp1_tensor_i = Fp1_tensor[i]
Expand All @@ -38,7 +38,7 @@ def reciprocal_grid(Hp1_array, Fp1_tensor, gridsize, batchsize=None):
return grid


def rsgrid2realmask(rs_grid, solvent_percent=0.50, scale=50, Batch=False):
def rsgrid2realmask(rs_grid, solvent_percent=0.50, exponent=50.0, Batch=False):
"""
Convert reciprocal space grid to real space solvent mask grid, in a
fully differentiable way with torch
Expand All @@ -51,12 +51,12 @@ def rsgrid2realmask(rs_grid, solvent_percent=0.50, scale=50, Batch=False):
solvent_percent: 0 - 1 float
The approximate volume percentage of solvent in the system, to generate the cutoff
scale: int/float
exponent: int/float
The scale used in sigmoid function, to make the distribution binary
Return:
-------
tf.float32 tensor
torch.float32 tensor
The solvent mask grid in real space, solvent voxels have value close to 1, while protein voxels have value close to 0.
"""
real_grid = torch.real(torch.fft.fftn(rs_grid, dim=(-3, -2, -1)))
Expand All @@ -65,7 +65,7 @@ def rsgrid2realmask(rs_grid, solvent_percent=0.50, scale=50, Batch=False):
CUTOFF = torch.quantile(real_grid_norm[0], solvent_percent)
else:
CUTOFF = torch.quantile(real_grid_norm, solvent_percent)
real_grid_mask = torch.sigmoid((CUTOFF - real_grid_norm) * 50)
real_grid_mask = torch.sigmoid((CUTOFF - real_grid_norm) * exponent)
return real_grid_mask


Expand All @@ -88,7 +88,7 @@ def realmask2Fmask(real_grid_mask, H_array, batchsize=None):
Solvent mask structural factor corresponding to the HKL list in H_array
"""
Fmask_grid = torch.fft.ifftn(real_grid_mask, dim=(-3, -2, -1), norm="forward")
tuple_index = tuple(torch.tensor(H_array.T, device=try_gpu(), dtype=int)) # type: ignore
tuple_index = tuple(torch.tensor(H_array.T, device=Fmask_grid.device, dtype=int)) # type: ignore
if batchsize is not None:
Fmask = Fmask_grid[(slice(None), *tuple_index)]
else:
Expand Down

0 comments on commit f65851e

Please sign in to comment.