From 829f3df268061d2a4d07e491b6b93b2441ef4d92 Mon Sep 17 00:00:00 2001
From: minhuanli
Date: Thu, 12 Oct 2023 17:35:32 -0400
Subject: [PATCH] Optimize Fprotein batch memory usage

---
 SFC_Torch/Fmodel.py | 26 +++++++++++++++-----------
 SFC_Torch/utils.py  |  2 +-
 2 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/SFC_Torch/Fmodel.py b/SFC_Torch/Fmodel.py
index 85f901c..1575ed6 100644
--- a/SFC_Torch/Fmodel.py
+++ b/SFC_Torch/Fmodel.py
@@ -978,7 +978,8 @@ def calc_fprotein_batch(self, atoms_position_batch, Return=False, PARTITION=20):
         PARTITION: Int, default 20
             To reduce the memory cost during the computation, we divide the batch into several partitions and loop through them.
-            Larger PARTITION will require larger GPU memory. Default 20 will take around 4GB, if N_atoms~1600 and N_HKLs~13000.
+            Larger PARTITION will require more GPU memory. Default 20 will require around 21GB if N_atoms~2400 and N_HKLs~24000.
+            Memory scales linearly with PARTITION, N_atoms, and N_HKLs.
             But larger PARTITION will give a smaller wall time, so this is a trade-off.
         """
 
         # Read and tensor-fy necessary information
@@ -1253,7 +1254,7 @@
     # F_calc = sum_G sum_j { f0_sj * DWF * exp(2*pi*i*(h,k,l)*(R_G*(x1,x2,x3)+T_G)) }, fractional position, Rupp's Book P279
     # G runs over the symmetry operations of the spacegroup and j over the atoms
     # DWF is the Debye-Waller Factor; it has isotropic and anisotropic versions, chosen based on the PDB file input, Rupp's Book P641
-    HKL_tensor = torch.tensor(HKL_array, device=try_gpu()).type(torch.float32)
+    HKL_tensor = torch.tensor(HKL_array).to(fullsf_tensor)
     batchsize = atom_pos_frac_batch.shape[0]
 
     oc_sf = fullsf_tensor * atom_occ[..., None]  # [N_atom, N_HKLs]
@@ -1282,25 +1283,28 @@
             start = j * PARTITION
             end = min((j + 1) * PARTITION, batchsize)
             for i in range(N_ops):  # Loop through symmetry operations to reduce memory cost
+                # [N_atoms, N_HKLs]
+                dwf_aniso = DWF_aniso(
+                    atom_aniso_uw, orth2frac_tensor, sym_oped_hkl[:, i, :]
+                )
+                # [N_atoms, N_HKLs]
+                dwf_all = torch.where(
+                    mask_vec[:, None], dwf_iso, dwf_aniso
+                )
                 # Shape [PARTITION, N_atoms, N_HKLs]
-                phase_ij = (
+                exp_phase_ij = dwf_all * torch.exp(1j * (
                     2
                     * torch.pi
                     * torch.tensordot(
                         sym_oped_pos_frac[start:end, :, :, i], HKL_tensor.T, 1
                     )
-                )
-                dwf_aniso = DWF_aniso(
-                    atom_aniso_uw, orth2frac_tensor, sym_oped_hkl[:, i, :]
-                )  # [N_atoms, N_HKLs]
-                dwf_all = torch.where(
-                    mask_vec[:, None], dwf_iso, dwf_aniso
-                )  # [N_atoms, N_HKLs]
-                exp_phase_ij = dwf_all * torch.exp(1j * phase_ij)
+                ))
                 # Shape [PARTITION, N_HKLs], sum over atoms
                 Fcalc_ij = torch.sum(exp_phase_ij * oc_sf, dim=1)
+                del exp_phase_ij  # release the memory
                 # Shape [PARTITION, N_HKLs], sum over symmetry operations
                 Fcalc_j = Fcalc_j + Fcalc_ij
+                del Fcalc_ij
                 if j == 0:
                     F_calc = Fcalc_j
                 else:
diff --git a/SFC_Torch/utils.py b/SFC_Torch/utils.py
index 3d2d74f..a4add92 100644
--- a/SFC_Torch/utils.py
+++ b/SFC_Torch/utils.py
@@ -107,7 +107,7 @@ def DWF_iso(b_iso, dr2_array):
     -------
     A 2D [N_atoms, N_HKLs] float32 tensor with DWF corresponding to different atoms and different HKLs
     """
-    dr2_tensor = torch.tensor(dr2_array, device=try_gpu())
+    dr2_tensor = torch.tensor(dr2_array).to(b_iso)
    return torch.exp(-b_iso.view([-1, 1]) * dr2_tensor / 4.0).type(torch.float32)
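
Notes

The ~21GB estimate in the new docstring is consistent with a quick
back-of-envelope count (assuming float32 phases promoted to complex64
exponentials, PyTorch's default promotion; the true peak also depends on
how eagerly the caching allocator recycles temporaries):

    # Rough sizing of the dominant [PARTITION, N_atoms, N_HKLs] block
    PARTITION, N_atoms, N_HKLs = 20, 2400, 24000
    n = PARTITION * N_atoms * N_HKLs   # 1.152e9 elements
    phase_gb = n * 4 / 1e9             # ~4.6 GB, real float32 tensordot result
    cplx_gb = n * 8 / 1e9              # ~9.2 GB per complex64 temporary
    peak_gb = phase_gb + 2 * cplx_gb   # ~23 GB transient peak, same ballpark as ~21GB

That brief coexistence of the real phase buffer with two complex-sized
temporaries is also why the patch fuses the phase into a single expression
(the unnamed float32 buffer can be dropped as soon as it is promoted to
complex) and del-s exp_phase_ij right after the atom reduction.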
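
Both files swap torch.tensor(..., device=try_gpu()) for
torch.tensor(...).to(reference). Tensor.to(other) is the standard PyTorch
overload that copies onto other's device AND casts to other's dtype in one
call, so HKL_tensor and dr2_tensor now follow whatever precision and device
fullsf_tensor / b_iso already live on. A minimal, self-contained
illustration (ref and hkl below are stand-ins, not the repo's data):

    import numpy as np
    import torch

    ref = torch.zeros(3, dtype=torch.float32,
                      device="cuda" if torch.cuda.is_available() else "cpu")
    hkl = np.array([[1, 0, 0], [0, 2, 0]])   # int64 ndarray, like HKL_array

    hkl_t = torch.tensor(hkl).to(ref)        # inherits ref's device and dtype
    assert hkl_t.device == ref.device
    assert hkl_t.dtype == ref.dtype          # float32, no separate .type() call

Besides being shorter, this future-proofs the code: if the model ever runs
in float64 or on CPU, the derived tensors follow automatically instead of
being pinned to try_gpu() and float32.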
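
On the reordering inside the symmetry loop: torch.where is not lazy, so
both dwf_iso and dwf_aniso ([N_atoms, N_HKLs] each) are fully materialized
either way, but they are small next to the [PARTITION, N_atoms, N_HKLs]
phase block. Computing dwf_all first lets the expensive complex exponential
be scaled in one fused expression. The per-atom selection pattern, sketched
with toy shapes (all names below are illustrative):

    import torch

    N_atoms, N_HKLs = 4, 6
    mask_vec = torch.tensor([True, False, True, True])  # per-atom: isotropic B-factor only?
    dwf_iso = torch.rand(N_atoms, N_HKLs)               # stand-in for DWF_iso output
    dwf_aniso = torch.rand(N_atoms, N_HKLs)             # stand-in for DWF_aniso output

    # Broadcast the per-atom mask across the HKL axis; True rows take dwf_iso.
    dwf_all = torch.where(mask_vec[:, None], dwf_iso, dwf_aniso)  # [N_atoms, N_HKLs]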
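
Putting it together, the memory strategy is: slice the batch into
PARTITION-sized chunks, reduce over atoms while the big intermediate is
alive, and del it before the next iteration. A stripped-down,
single-symmetry-op sketch of the same pattern (function and variable names
here are illustrative, not the repo's API):

    import torch

    def chunked_structure_factors(pos_frac, hkl, weights, partition=20):
        """Toy analogue of the partitioned sum in F_protein_batch.

        pos_frac : [B, N_atoms, 3] fractional coordinates (float32)
        hkl      : [N_HKLs, 3] Miller indices (float32)
        weights  : [N_atoms, N_HKLs] occupancy-scaled scattering factors
        Returns  : [B, N_HKLs] complex structure factors.
        """
        B = pos_frac.shape[0]
        chunks = []
        for j in range(-(-B // partition)):  # ceil(B / partition) chunks
            sl = slice(j * partition, min((j + 1) * partition, B))
            # The only big intermediate: [chunk, N_atoms, N_HKLs]
            phase = 2 * torch.pi * torch.tensordot(pos_frac[sl], hkl.T, dims=1)
            exp_phase = weights * torch.exp(1j * phase)
            del phase  # drop the real phase buffer as soon as possible
            chunks.append(exp_phase.sum(dim=1))  # atom reduction -> [chunk, N_HKLs]
            del exp_phase  # let the allocator recycle the block next chunk
        return torch.cat(chunks, dim=0)

    # With B=90 and partition=20 this runs five chunks (20, 20, 20, 20, 10),
    # so peak memory tracks `partition` rather than the full batch size.
    F = chunked_structure_factors(
        torch.rand(90, 100, 3), torch.rand(120, 3), torch.rand(100, 120)
    )
    assert F.shape == (90, 120) and F.is_complex()

Note that del only drops the Python reference; PyTorch's caching allocator
then reuses the freed block for the next chunk rather than returning it to
the driver, which is what keeps the peak flat across iterations.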