@@ -327,6 +327,9 @@ def run_folding_on_context(
     if device is None:
         device = torch.device("cuda:0")
 
+    # Clear memory
+    torch.cuda.empty_cache()
+
     ##
     ## Validate inputs
     ##
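A note on the torch.cuda.empty_cache() calls added throughout this patch: the call returns unused blocks held by PyTorch's caching allocator to the CUDA driver, but it cannot free tensors that still have live references. A minimal standalone sketch (not part of the patch) showing the difference between allocated and reserved memory:

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda:0")  # ~4 MB of GPU memory
    print(torch.cuda.memory_allocated())          # bytes held by live tensors
    del x                                         # drop the only reference
    print(torch.cuda.memory_reserved())           # allocator still caches the block
    torch.cuda.empty_cache()                      # hand cached blocks back to the driver
    print(torch.cuda.memory_reserved())           # reserved memory shrinks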
@@ -443,6 +446,9 @@ def run_folding_on_context(
         token_single_mask=token_single_mask,
         token_pair_mask=token_pair_mask,
     )
+    # We won't be using the trunk anymore; remove it from memory
+    del trunk
+    torch.cuda.empty_cache()
 
     ##
     ## Denoise the trunk representation by passing it through the diffusion module
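The del trunk / torch.cuda.empty_cache() pair above frees the trunk's weights only if that local name was the last reference to the module; del by itself just unbinds the name. A rough standalone sketch of the pattern, with a toy module standing in for the trunk:

import gc
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(4096, 4096).to("cuda:0")  # stand-in for the trunk
    # ... run the forward pass and keep only the outputs ...
    del model                 # unbind the name so the parameters become collectable
    gc.collect()              # optional: collect lingering reference cycles right away
    torch.cuda.empty_cache()  # return the freed blocks to the CUDA driver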
@@ -534,6 +540,10 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
         d_i_prime = (atom_pos - denoised_pos) / sigma_next
         atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)
 
+    # We won't be running diffusion anymore
+    del diffusion_module
+    torch.cuda.empty_cache()
+
     ##
     ## Run the confidence model
     ##
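The same pattern applies to the diffusion module, with the same caveat: the memory is only reclaimed once no other object still points at it. A standalone illustration (sizes are arbitrary):

import torch

if torch.cuda.is_available():
    a = torch.empty(256, 1024, 1024, device="cuda:0")  # ~1 GiB of float32
    b = a                                              # second reference to the same storage
    del a
    torch.cuda.empty_cache()
    print(torch.cuda.memory_allocated())               # still ~1 GiB: b keeps it alive
    del b
    torch.cuda.empty_cache()
    print(torch.cuda.memory_allocated())               # drops back toward zero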
@@ -610,6 +620,11 @@ def avg_per_token_1d(x):
     ##
     ## Write the outputs
     ##
+    # Move data to the CPU so we don't hit GPU memory limits
+    inputs = move_data_to_device(inputs, torch.device("cpu"))
+    atom_pos = atom_pos.cpu()
+    plddt_logits = plddt_logits.cpu()
+    pae_logits = pae_logits.cpu()
 
     # Plot coverage of tokens by MSA, save plot
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -671,7 +686,7 @@ def avg_per_token_1d(x):
         outputs_to_cif(
             coords=atom_pos[idx : idx + 1],
             bfactors=scaled_plddt_scores_per_atom,
-            output_batch=move_data_to_device(inputs, torch.device("cpu")),
+            output_batch=inputs,
             write_path=cif_out_path,
             entity_names={
                 c.entity_data.entity_id: c.entity_data.entity_name
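Because inputs is already moved to the CPU earlier in the patch, the per-sample move_data_to_device(inputs, torch.device("cpu")) call became redundant, hence output_batch=inputs above. Assuming move_data_to_device is a recursive tensor-mapper (the repository's own helper may differ), the idea is roughly:

import torch

def move_data_to_device_sketch(data, device):
    # Hypothetical stand-in for move_data_to_device: walk dicts/lists/tuples
    # and move any tensors found to the target device.
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, dict):
        return {k: move_data_to_device_sketch(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(move_data_to_device_sketch(v, device) for v in data)
    return data

Moving the batch once, up front, keeps the rest of the output-writing path on CPU tensors and leaves GPU memory free.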