Optimize Fprotein batch memory usage
minhuanli committed Oct 12, 2023
1 parent 09c9493 commit 829f3df
Showing 2 changed files with 16 additions and 12 deletions.
26 changes: 15 additions & 11 deletions SFC_Torch/Fmodel.py
@@ -978,7 +978,8 @@ def calc_fprotein_batch(self, atoms_position_batch, Return=False, PARTITION=20):
     PARTITION: Int, default 20
         To reduce the memory cost during the computation, we divide the batch into several partitions and loop through them.
-        Larger PARTITION will require larger GPU memory. Default 20 will take around 4GB, if N_atoms~1600 and N_HKLs~13000.
+        Larger PARTITION will require larger GPU memory. Default 20 will require around 21GB, if N_atoms~2400 and N_HKLs~24000.
+        Memory scales linearly with PARTITION, N_atoms, and N_HKLs.
         But larger PARTITION will give a smaller wall time, so this is a trade-off.
     """
     # Read and tensor-fy necessary information
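To make the PARTITION trade-off above concrete, here is a minimal, self-contained sketch of the chunking pattern; expensive_op and the toy sizes are invented for illustration, not the library's kernel:

import torch

def expensive_op(chunk):
    # Stand-in for the per-chunk structure-factor kernel: it materializes a
    # large complex intermediate of shape [len(chunk), N_atoms, N_HKLs]
    # before reducing over atoms; this intermediate is what PARTITION caps.
    n_atoms, n_hkls = 100, 500  # toy sizes
    inter = torch.randn(chunk.shape[0], n_atoms, n_hkls, dtype=torch.complex64)
    return inter.sum(dim=1)  # [len(chunk), N_HKLs]

def batched_op_in_partitions(batch, partition=20):
    # Peak memory ~ partition * N_atoms * N_HKLs complex64 elements (8 bytes
    # each), hence the linear scaling noted in the docstring; a larger
    # partition means fewer loop iterations (less wall time) but more memory.
    outputs = []
    for j in range(0, batch.shape[0], partition):
        outputs.append(expensive_op(batch[j : j + partition]))
    return torch.cat(outputs, dim=0)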
@@ -1253,7 +1254,7 @@ def F_protein_batch(
 # F_calc = sum_G sum_j { f0_sj * DWF * exp[2*pi*i*(h,k,l)*(R_G*(x1,x2,x3)+T_G)] }, fractional position, Rupp's Book P279
 # G runs over the symmetry operations of the spacegroup and j over the atoms
 # DWF is the Debye-Waller factor; it has an isotropic and an anisotropic version, chosen based on the PDB file input, Rupp's Book P641
-HKL_tensor = torch.tensor(HKL_array, device=try_gpu()).type(torch.float32)
+HKL_tensor = torch.tensor(HKL_array).to(fullsf_tensor)
 batchsize = atom_pos_frac_batch.shape[0]

 oc_sf = fullsf_tensor * atom_occ[..., None]  # [N_atoms, N_HKLs]
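For reference, the structure-factor sum in the comments above, written out (here \mathbf{h} = (h, k, l) is the Miller index, \mathbf{x}_j = (x_1, x_2, x_3) the fractional position of atom j, and (R_G, T_G) the rotation and translation of symmetry operation G):

F_{\mathrm{calc}}(\mathbf{h}) = \sum_{G}\sum_{j} f_{0,sj}\,\mathrm{DWF}_{j,G}(\mathbf{h})\,\exp\!\bigl[\,2\pi i\,\mathbf{h}\cdot\bigl(R_G\,\mathbf{x}_j + T_G\bigr)\bigr]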
@@ -1282,25 +1283,28 @@ def F_protein_batch(
 start = j * PARTITION
 end = min((j + 1) * PARTITION, batchsize)
 for i in range(N_ops):  # Loop through symmetry operations to reduce memory cost
+    # [N_atoms, N_HKLs]
+    dwf_aniso = DWF_aniso(
+        atom_aniso_uw, orth2frac_tensor, sym_oped_hkl[:, i, :]
+    )
+    # [N_atoms, N_HKLs]
+    dwf_all = torch.where(
+        mask_vec[:, None], dwf_iso, dwf_aniso
+    )
     # Shape [PARTITION, N_atoms, N_HKLs]
-    phase_ij = (
+    exp_phase_ij = dwf_all * torch.exp(1j * (
         2
         * torch.pi
         * torch.tensordot(
             sym_oped_pos_frac[start:end, :, :, i], HKL_tensor.T, 1
         )
-    )
-    dwf_aniso = DWF_aniso(
-        atom_aniso_uw, orth2frac_tensor, sym_oped_hkl[:, i, :]
-    )  # [N_atoms, N_HKLs]
-    dwf_all = torch.where(
-        mask_vec[:, None], dwf_iso, dwf_aniso
-    )  # [N_atoms, N_HKLs]
-    exp_phase_ij = dwf_all * torch.exp(1j * phase_ij)
+    ))
     # Shape [PARTITION, N_HKLs], sum over atoms
     Fcalc_ij = torch.sum(exp_phase_ij * oc_sf, dim=1)
+    del exp_phase_ij  # release the memory
     # Shape [PARTITION, N_HKLs], sum over symmetry operations
     Fcalc_j = Fcalc_j + Fcalc_ij
+    del Fcalc_ij
 if j == 0:
     F_calc = Fcalc_j
 else:
2 changes: 1 addition & 1 deletion SFC_Torch/utils.py
@@ -107,7 +107,7 @@ def DWF_iso(b_iso, dr2_array):
     -------
     A 2D [N_atoms, N_HKLs] float32 tensor with the DWF corresponding to different atoms and different HKLs
     """
-    dr2_tensor = torch.tensor(dr2_array, device=try_gpu())
+    dr2_tensor = torch.tensor(dr2_array).to(b_iso)
     return torch.exp(-b_iso.view([-1, 1]) * dr2_tensor / 4.0).type(torch.float32)
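Both files make the same device fix: rather than pinning new tensors with device=try_gpu(), they now use Tensor.to(other), which returns a tensor matching other's device and dtype. A tiny sketch of the idiom with toy tensors:

import torch

ref = torch.zeros(3, dtype=torch.float32)   # lives on CUDA in real use
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
y = x.to(ref)  # y now matches ref's device and dtype (float32 here)
assert y.dtype == ref.dtype and y.device == ref.device

The new tensors simply follow wherever fullsf_tensor or b_iso already live, so the hot path no longer hard-codes GPU placement at construction time.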

