Move from torch.cuda.amp to torch.autocast; Add tests for amp #838

Open · wants to merge 28 commits into main

Conversation

@misko (Collaborator) commented Sep 9, 2024

This PR updates the deprecated torch.cuda.amp.autocast(args...) and torch.cuda.amp.GradScaler(args...) usage and also adds CPU AMP tests (https://pytorch.org/docs/stable/amp.html).
Recently an eSCN model failed to run under GPU AMP and the current set of tests did not catch it. This PR adds the equivalent tests on CPU AMP, which will catch this going forward.
Additional fixes have been made to eSCN / SCN / GemNet-OC in this PR.
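For context, a minimal sketch of the device-agnostic API this migration targets (illustrative only, not code from this diff; the actual call sites live in the changed files):

import torch

# Old, CUDA-only namespaces (deprecated):
#   with torch.cuda.amp.autocast():
#       ...
#   scaler = torch.cuda.amp.GradScaler()

# Device-agnostic replacements; device_type can be "cuda" or "cpu",
# which is what lets the new CPU AMP tests exercise the same code path.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.autocast(device_type=device_type, enabled=True):
    pass  # forward pass runs in mixed precision here

scaler = torch.amp.GradScaler(device_type, enabled=device_type == "cuda")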

@misko misko marked this pull request as ready for review September 9, 2024 23:52
@misko misko added the enhancement (New feature or request) and patch (Patch version release) labels Sep 10, 2024
@misko misko mentioned this pull request Sep 10, 2024
@lbluque (Collaborator) left a comment

Thanks @misko, mostly style suggestions

src/fairchem/core/models/equiformer_v2/layer_norm.py (outdated review thread, resolved)
if node_energy.device.type == "cuda":
energy.index_add_(0, data.batch, node_energy.view(-1))
else:
energy.index_add_(0, data.batch, node_energy.float().view(-1))
Collaborator commented:

Do we lose a lot of performance or use too much memory when casting? If not, should we cast regardless?
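A minimal sketch of what casting regardless would look like (illustrative only, not the PR's code); both device paths would then share a single line:

energy.index_add_(0, data.batch, node_energy.float().view(-1))  # cast unconditionally on CPU and GPU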

src/fairchem/core/modules/scaling/fit.py (review thread, resolved)
tests/core/e2e/test_s2ef.py (review thread, resolved)

codecov bot commented Sep 19, 2024

Codecov Report

Attention: Patch coverage is 71.13402% with 28 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
...c/fairchem/core/models/equiformer_v2/layer_norm.py | 64.93% | 27 Missing ⚠️
...e/models/equiformer_v2/equiformer_v2_deprecated.py | 0.00% | 1 Missing ⚠️

Files with missing lines | Coverage Δ
...rc/fairchem/applications/cattsunami/core/ocpneb.py | 87.12% <100.00%> (+0.09%) ⬆️
src/fairchem/core/common/relaxation/ase_utils.py | 63.15% <100.00%> (+1.20%) ⬆️
src/fairchem/core/models/escn/escn.py | 95.86% <100.00%> (ø)
src/fairchem/core/models/gemnet_oc/gemnet_oc.py | 89.91% <100.00%> (ø)
src/fairchem/core/models/scn/scn.py | 93.56% <100.00%> (ø)
src/fairchem/core/modules/scaling/fit.py | 72.72% <100.00%> (ø)
src/fairchem/core/trainers/base_trainer.py | 88.86% <100.00%> (+0.94%) ⬆️
src/fairchem/core/trainers/ocp_trainer.py | 69.12% <100.00%> (+0.10%) ⬆️
...e/models/equiformer_v2/equiformer_v2_deprecated.py | 91.15% <0.00%> (ø)
...c/fairchem/core/models/equiformer_v2/layer_norm.py | 57.64% <64.93%> (-0.26%) ⬇️

@misko misko requested a review from lbluque September 19, 2024 23:46
@lbluque (Collaborator) left a comment

lg, just minor suggestions!

@@ -36,6 +36,8 @@ def __init__(
precon=None,
cpu=False,
batch_size=4,
seed=0, # set a seed for reproducibility
amp=None,
Collaborator commented:

Why not just set amp=True as the default? From the lines below it looks like that's the case.

@@ -110,11 +112,13 @@ def __init__(
local_rank=config.get("local_rank", 0),
is_debug=config.get("is_debug", True),
cpu=cpu,
amp=True,
amp=(amp==None or amp), # AMP on by default
Collaborator commented:

Nit: it's more pythonic to check for None with amp is None (None is a singleton).
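A minimal sketch of the quoted line with that nit applied (illustrative only, not a committed change):

amp=(amp is None or amp),  # AMP on by default; identity check for None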

if self.lmax > 0:
num_m_components = (self.lmax + 1) ** 2
feature = node_input.narrow(1, 1, num_m_components - 1)
with torch.autocast(device_type=node_input.device.type, enabled=False):
Collaborator commented:

Is it possible to double up on autocast decorators like the forward method above?
I am just worried about potential incorrect indentation bugs in the future.
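For reference, a minimal sketch (not from this PR; the class and method names are hypothetical) contrasting the two forms. The decorator fixes device_type when the class is defined, while the context manager can follow the input's device at runtime, which is presumably why the quoted code uses the with-statement:

import torch
from torch import nn

class HypotheticalLayer(nn.Module):
    # Decorator form: device_type must be chosen up front.
    @torch.autocast(device_type="cuda", enabled=False)
    def _fp32_cuda_only(self, x: torch.Tensor) -> torch.Tensor:
        return x.float().norm(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Context-manager form: device_type can follow the input at runtime.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return x.float().norm(dim=-1)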

out = self._forward(batch)
out = {k: v.float() for k, v in out.items()}
Collaborator commented:

Did we decide to always cast predictions to float32? That sounds ok to me, just making sure because this is different than before.
