From d3872538a522d7cd528c14be548dc03d37265b8d Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Fri, 13 Sep 2024 22:57:31 +0000 Subject: [PATCH 1/4] Empty cache before we run folding --- chai_lab/chai1.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 69f0a35..858b84b 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -327,6 +327,9 @@ def run_folding_on_context( if device is None: device = torch.device("cuda:0") + # Clear memory + torch.cuda.empty_cache() + ## ## Validate inputs ## From c7a06f72c68b3f4c9af05648d9bed9974aefcc3b Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Fri, 13 Sep 2024 23:25:21 +0000 Subject: [PATCH 2/4] More aggressively move data off GPU --- chai_lab/chai1.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 858b84b..2e5bba1 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -613,6 +613,11 @@ def avg_per_token_1d(x): ## ## Write the outputs ## + # Move data to the CPU so we don't hit GPU memory limits + inputs = move_data_to_device(inputs, torch.device("cpu")) + atom_pos = atom_pos.cpu() + plddt_logits = plddt_logits.cpu() + pae_logits = pae_logits.cpu() # Plot coverage of tokens by MSA, save plot output_dir.mkdir(parents=True, exist_ok=True) From 448536075a13dd188f03de3e029472bfbae1d8c0 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Fri, 13 Sep 2024 23:38:07 +0000 Subject: [PATCH 3/4] Delete modules after we are finished with them --- chai_lab/chai1.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 2e5bba1..419bb55 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -446,6 +446,9 @@ def run_folding_on_context( token_single_mask=token_single_mask, token_pair_mask=token_pair_mask, ) + # We won't be using the trunk anymore; remove it from memory + del trunk + torch.cuda.empty_cache() ## ## Denoise the trunk representation by passing it through the diffusion module @@ -537,6 +540,10 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor: d_i_prime = (atom_pos - denoised_pos) / sigma_next atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2) + # We won't be running diffusion anymore + del diffusion_module + torch.cuda.empty_cache() + ## ## Run the confidence model ## From 90b8d27c2447478379046f201cb0079e19b9282e Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Fri, 13 Sep 2024 23:41:34 +0000 Subject: [PATCH 4/4] Remove redundant call to move inputs to cpu --- chai_lab/chai1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 419bb55..5cabab8 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -686,7 +686,7 @@ def avg_per_token_1d(x): write_pdbs_from_outputs( coords=atom_pos[idx : idx + 1], bfactors=scaled_plddt_scores_per_atom, - output_batch=move_data_to_device(inputs, torch.device("cpu")), + output_batch=inputs, write_path=pdb_out_path, ) output_paths.append(pdb_out_path)