Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nequix/torch_impl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def forward(
minus_forces, virial = torch.autograd.grad(
outputs=[node_energies.sum()],
inputs=[positions, eps],
create_graph=True,
create_graph=self.training,
materialize_grads=True,
)

Expand All @@ -647,7 +647,7 @@ def forward(
minus_forces = torch.autograd.grad(
outputs=[node_energies.sum()],
inputs=[positions],
create_graph=True,
create_graph=self.training,
materialize_grads=True,
)[0]
stress = None
Expand Down
10 changes: 7 additions & 3 deletions nequix/torch_impl/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import copy
import os
from datetime import timedelta
import time
from collections import defaultdict
from pathlib import Path
Expand Down Expand Up @@ -123,9 +124,9 @@ def evaluate(
val_loss, metrics = loss(
model, batch, energy_weight, force_weight, stress_weight, loss_type, device
)
total_metrics["loss"] += val_loss.item() * n_graphs
total_metrics["loss"] += val_loss.detach().item() * n_graphs
for key, value in metrics.items():
total_metrics[key] += value.item() * n_graphs
total_metrics[key] += value.detach().item() * n_graphs
total_count += n_graphs

for key, value in total_metrics.items():
Expand Down Expand Up @@ -546,6 +547,7 @@ def train_step(model, ema_model, batch, step):
torch.distributed.barrier()

if rank == 0:
# TODO: multi gpu validation, evaluate subset on each rank and aggregate metrics
val_metrics = evaluate(
ema_model,
val_loader,
Expand Down Expand Up @@ -580,7 +582,9 @@ def train_step(model, ema_model, batch, step):

def setup_ddp():
"""Initialize distributed training"""
init_process_group(backend="nccl")
# NOTE: set timeout to 30 minutes so validation doesn't cause NCCL timeout
# (wouldn't be a problem if we used multi-gpu validation, see TODO above)
init_process_group(backend="nccl", timeout=timedelta(minutes=30))
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


Expand Down