Skip to content

Commit

Permalink
Added reset default tensor type before calling the newtonnet tests. T…
Browse files Browse the repository at this point in the history
…his was added because MACE was resetting it.
  • Loading branch information
kumaranu committed Jul 9, 2024
1 parent eb9bd53 commit a935cee
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/geodesic_ts_with_hessian/test_using_newtonnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import pytest
import logging

import torch
from ase.io import read
from pathlib import Path
from quacc import get_settings
Expand All @@ -19,6 +21,11 @@ def setup_test_environment(tmp_path):
return reactant, product


@pytest.fixture(autouse=True)
def reset_default_tensor_type():
torch.set_default_tensor_type(torch.FloatTensor)


def test_geodesic_ts_hess_irc_newtonnet(setup_test_environment):
reactant, product = setup_test_environment

Expand Down
6 changes: 6 additions & 0 deletions tests/geodesic_ts_without_hessian/test_using_newtonnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import torch
import pytest
import logging
from ase.io import read
Expand All @@ -19,6 +20,11 @@ def setup_test_environment(tmp_path):
return reactant, product


@pytest.fixture(autouse=True)
def reset_default_tensor_type():
torch.set_default_tensor_type(torch.FloatTensor)


def test_geodesic_ts_hess_irc_newtonnet(setup_test_environment):
reactant, product = setup_test_environment

Expand Down
6 changes: 6 additions & 0 deletions tests/neb_ts_with_hessian/test_using_newtonnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import pytest
import logging
from ase.io import read
Expand All @@ -18,6 +19,11 @@ def setup_test_environment(tmp_path):
return reactant, product


@pytest.fixture(autouse=True)
def reset_default_tensor_type():
torch.set_default_tensor_type(torch.FloatTensor)


def test_neb_ts_hess_irc_newtonnet(setup_test_environment):
reactant, product = setup_test_environment

Expand Down
6 changes: 6 additions & 0 deletions tests/neb_ts_without_hessian/test_using_newtonnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import pytest
import logging
from ase.io import read
Expand All @@ -18,6 +19,11 @@ def setup_test_environment(tmp_path):
return reactant, product


@pytest.fixture(autouse=True)
def reset_default_tensor_type():
torch.set_default_tensor_type(torch.FloatTensor)


def test_neb_ts_no_hess_irc_newtonnet(setup_test_environment):
reactant, product = setup_test_environment

Expand Down

0 comments on commit a935cee

Please sign in to comment.