Skip to content

Commit 958577a

Browse files
committed
fix: add H200 TFLOPS
Signed-off-by: Alexander Zhipa <azzhipa@amazon.com>
1 parent c32778d commit 958577a

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

nemo_rl/utils/flops_tracker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def is_using_tf32() -> bool:
115115
("NVIDIA A100 80GB PCIe", torch.float32): 312 / 2 if is_using_tf32() else 19.5,
116116
("NVIDIA H100 80GB HBM3", torch.bfloat16): 1979 / 2,
117117
("NVIDIA H100 80GB HBM3", torch.float32): 989 / 2 if is_using_tf32() else 67.0,
118+
("NVIDIA H200", torch.bfloat16): 1979 / 2,
119+
("NVIDIA H200", torch.float32): 989 / 2 if is_using_tf32() else 67.0,
118120
("NVIDIA B200", torch.bfloat16): 4500 / 2,
119121
("NVIDIA B200", torch.float32): 2200 / 2 if is_using_tf32() else 80.0,
120122
("NVIDIA B300", torch.bfloat16): 4500 / 2,
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
import torch
18+
19+
from nemo_rl.utils.flops_tracker import get_theoretical_tflops, is_using_tf32
20+
21+
22+
# One row per device: (device name, bf16 figure, tf32 figure, plain fp32 TFLOPS).
# The bf16 and tf32 figures are halved when the expected value is built,
# matching how the expected values were originally written out by hand.
# NOTE(review): the halving presumably converts a sparse peak to a dense one;
# confirm against the table in nemo_rl/utils/flops_tracker.py.
_DEVICE_PEAKS = (
    ("NVIDIA A100 80GB PCIe", 624, 312, 19.5),
    ("NVIDIA H100 80GB HBM3", 1979, 989, 67.0),
    ("NVIDIA H200", 1979, 989, 67.0),
    ("NVIDIA B200", 4500, 2200, 80.0),
    ("NVIDIA B300", 4500, 2200, 80.0),
    ("NVIDIA GB200", 4900, 2500, 80.0),
    ("NVIDIA GB300", 4900, 2500, 80.0),
)


def _expected_tflops_cases():
    """Expand ``_DEVICE_PEAKS`` into ``(device_name, dtype, tflops)`` triples.

    Each device yields two consecutive cases: bfloat16 first, then float32.
    The float32 expectation depends on ``is_using_tf32()``, which is queried
    once per float32 case.
    """
    cases = []
    for name, bf16_figure, tf32_figure, fp32_tflops in _DEVICE_PEAKS:
        cases.append((name, torch.bfloat16, bf16_figure / 2))
        if is_using_tf32():
            fp32_expected = tf32_figure / 2
        else:
            fp32_expected = fp32_tflops
        cases.append((name, torch.float32, fp32_expected))
    return cases


@pytest.mark.parametrize(
    ("device_name", "model_dtype", "tflops"),
    _expected_tflops_cases(),
)
def test_theoretical_tflops(device_name, model_dtype, tflops):
    """Check the looked-up TFLOPS for a device/dtype pair against the expected value."""
    expected = pytest.approx(tflops)
    actual = get_theoretical_tflops(device_name, model_dtype)
    assert actual == expected

0 commit comments

Comments
 (0)