Skip to content

Commit 1d3e499

Browse files
authored
Empty cache before we run folding (#47)
* Empty cache before we run folding
* More aggressively move data off GPU
* Delete modules after we are finished with them
* Remove redundant call to move inputs to cpu
1 parent ecd62ff commit 1d3e499

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

chai_lab/chai1.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@ def run_folding_on_context(
327327
if device is None:
328328
device = torch.device("cuda:0")
329329

330+
# Clear memory
331+
torch.cuda.empty_cache()
332+
330333
##
331334
## Validate inputs
332335
##
@@ -443,6 +446,9 @@ def run_folding_on_context(
443446
token_single_mask=token_single_mask,
444447
token_pair_mask=token_pair_mask,
445448
)
449+
# We won't be using the trunk anymore; remove it from memory
450+
del trunk
451+
torch.cuda.empty_cache()
446452

447453
##
448454
## Denoise the trunk representation by passing it through the diffusion module
@@ -534,6 +540,10 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
534540
d_i_prime = (atom_pos - denoised_pos) / sigma_next
535541
atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)
536542

543+
# We won't be running diffusion anymore
544+
del diffusion_module
545+
torch.cuda.empty_cache()
546+
537547
##
538548
## Run the confidence model
539549
##
@@ -610,6 +620,11 @@ def avg_per_token_1d(x):
610620
##
611621
## Write the outputs
612622
##
623+
# Move data to the CPU so we don't hit GPU memory limits
624+
inputs = move_data_to_device(inputs, torch.device("cpu"))
625+
atom_pos = atom_pos.cpu()
626+
plddt_logits = plddt_logits.cpu()
627+
pae_logits = pae_logits.cpu()
613628

614629
# Plot coverage of tokens by MSA, save plot
615630
output_dir.mkdir(parents=True, exist_ok=True)
@@ -671,7 +686,7 @@ def avg_per_token_1d(x):
671686
outputs_to_cif(
672687
coords=atom_pos[idx : idx + 1],
673688
bfactors=scaled_plddt_scores_per_atom,
674-
output_batch=move_data_to_device(inputs, torch.device("cpu")),
689+
output_batch=inputs,
675690
write_path=cif_out_path,
676691
entity_names={
677692
c.entity_data.entity_id: c.entity_data.entity_name

0 commit comments

Comments (0)