From d3698c601608cf7ab75f715639c14296fb474a4f Mon Sep 17 00:00:00 2001
From: Teddy Koker
Date: Mon, 2 Mar 2026 09:43:58 -0800
Subject: [PATCH] fix validation OOM, NCCL timeout

Build the second-order autograd graph for forces/stress only while
training (create_graph=self.training); retaining it during validation is
unnecessary and was causing the OOM. Also detach loss and metric values
in evaluate(), and raise the NCCL timeout to 30 minutes so rank-0-only
validation doesn't time out the other ranks waiting at the barrier.
---
 nequix/torch_impl/model.py |  4 ++--
 nequix/torch_impl/train.py | 10 +++++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/nequix/torch_impl/model.py b/nequix/torch_impl/model.py
index 2c65ccf..695ca11 100644
--- a/nequix/torch_impl/model.py
+++ b/nequix/torch_impl/model.py
@@ -635,7 +635,7 @@ def forward(
             minus_forces, virial = torch.autograd.grad(
                 outputs=[node_energies.sum()],
                 inputs=[positions, eps],
-                create_graph=True,
+                create_graph=self.training,
                 materialize_grads=True,
             )
 
@@ -647,7 +647,7 @@ def forward(
             minus_forces = torch.autograd.grad(
                 outputs=[node_energies.sum()],
                 inputs=[positions],
-                create_graph=True,
+                create_graph=self.training,
                 materialize_grads=True,
             )[0]
             stress = None
diff --git a/nequix/torch_impl/train.py b/nequix/torch_impl/train.py
index bd5c18b..a229d72 100644
--- a/nequix/torch_impl/train.py
+++ b/nequix/torch_impl/train.py
@@ -1,6 +1,7 @@
 import argparse
 import copy
 import os
+from datetime import timedelta
 import time
 from collections import defaultdict
 from pathlib import Path
@@ -123,9 +124,9 @@ def evaluate(
         val_loss, metrics = loss(
             model, batch, energy_weight, force_weight, stress_weight, loss_type, device
         )
-        total_metrics["loss"] += val_loss.item() * n_graphs
+        total_metrics["loss"] += val_loss.detach().item() * n_graphs
         for key, value in metrics.items():
-            total_metrics[key] += value.item() * n_graphs
+            total_metrics[key] += value.detach().item() * n_graphs
         total_count += n_graphs
 
     for key, value in total_metrics.items():
@@ -546,6 +547,7 @@ def train_step(model, ema_model, batch, step):
         torch.distributed.barrier()
 
         if rank == 0:
+            # TODO: multi-GPU validation: evaluate a subset on each rank and aggregate the metrics
            val_metrics = evaluate(
                 ema_model,
                 val_loader,
@@ -580,7 +582,9 @@
 
 def setup_ddp():
     """Initialize distributed training"""
-    init_process_group(backend="nccl")
+    # NOTE: set the timeout to 30 minutes so rank-0 validation doesn't cause an NCCL timeout
+    # (wouldn't be a problem if we used multi-GPU validation, see the TODO above)
+    init_process_group(backend="nccl", timeout=timedelta(minutes=30))
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 
 
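
Note (reviewer sketch, not part of the patch): one possible shape for the
multi-GPU validation TODO above. Each rank evaluates its own shard and the
summed totals are all-reduced, so every rank sees the global metrics and no
rank sits idle long enough to hit the NCCL timeout. `loss_fn`,
`batch.num_graphs`, and a per-rank pre-sharded `val_loader` are assumptions
here, not the repo's actual API:

    import torch
    import torch.distributed as dist

    def evaluate_ddp(model, loss_fn, val_loader, device):
        # eval() flips self.training, so forward() builds the force/stress
        # gradients with create_graph=False (the OOM fix in this patch)
        model.eval()
        # running [sum of loss * n_graphs, sum of n_graphs] for this rank's shard
        totals = torch.zeros(2, device=device)
        for batch in val_loader:  # loader assumed to be sharded per rank
            n_graphs = batch.num_graphs
            val_loss = loss_fn(model, batch)
            totals[0] += val_loss.detach() * n_graphs
            totals[1] += n_graphs
        # sum the totals across ranks; every rank now holds the global numbers
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        return (totals[0] / totals[1]).item()

Aggregating sums rather than per-rank means keeps the result exact when the
ranks end up with unequal shard sizes.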