
Commit

test update
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Nov 30, 2024
1 parent 8e4f0b0 commit d60b4dc
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions thunder/tests/test_tensor_subclass.py
@@ -12,7 +12,6 @@
 from thunder.tests.framework import (
     instantiate,
     TorchExecutor,
-    TorchCompileCatExecutor,
     nvFuserExecutor,
     DynamoThunderExecutor,
 )
@@ -248,7 +247,7 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
 @instantiate(
     dtypes=(thunder.core.dtypes.float32, thunder.core.dtypes.bfloat16),
     devicetypes=(thunder.core.devices.DeviceType.CUDA,),
-    executors=(TorchExecutor, TorchCompileCatExecutor, nvFuserExecutor, DynamoThunderExecutor),
+    executors=(TorchExecutor, nvFuserExecutor, DynamoThunderExecutor),
     decorators=(
         pytest.mark.skipif(
             not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
@@ -264,11 +263,10 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
     device = torch.device("cuda")
     torch_dtype = thunder.core.dtypes.to_torch_dtype(dtype)

-    # model = nn.Linear(in_features, out_features, bias=False, device=device, dtype=torch_dtype)
     model = nn.Sequential(
         nn.Linear(in_features, out_features, bias=bias),
-        # nn.GELU(approximate="tanh"),
-        nn.Linear(out_features, out_features, bias=False),
+        nn.GELU(approximate="tanh"),
+        nn.Linear(out_features, out_features, bias=bias),
     ).to(device=device, dtype=torch_dtype)
     fp8_model = convert_to_float8_training(model)
     x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype)
@@ -286,8 +284,8 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
     expected = fp8_model(x)
     jitted = executor.make_callable(fp8_model)

-    if bias and dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor:
-        with pytest.raises(AssertionError, match="unexpected a_role GemmInputRole.GRAD_OUTPUT and b_role GemmInputRole.GRAD_OUTPUT"):
+    if bias and dtype == thunder.core.dtypes.bfloat16 and executor == nvFuserExecutor:
+        with pytest.raises(RuntimeError, match="INTERNAL ASSERT FAILED"):
             jitted(x)
         return
     actual = jitted(x)
@@ -296,13 +294,15 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
         torch.testing.assert_close(actual, expected)
         return

+    if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor):
+        pytest.xfail("numerical error")
     torch.testing.assert_close(actual, expected)

     # TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
     # Currently no subgraphs go to thunder.jit.
     if is_thunderfx:
         for subgraph in backend.subgraph_infos:
-            if not bias:
+            if not bias and dtype == thunder.core.dtypes.bfloat16:
                 assert not subgraph.thunder_compiled_fns
             else:
-                assert len(subgraph.thunder_compiled_fns) == 1
+                assert subgraph.thunder_compiled_fns
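
For context (not part of the commit), here is a minimal standalone sketch of the flow the updated test exercises: build the Linear, GELU, Linear stack, convert it with torchao's convert_to_float8_training, and compare an eager run against a thunder-compiled run. The feature sizes and the direct call to thunder.jit are assumptions for illustration; the test itself parametrizes the executor and builds its callable through executor.make_callable.

# Standalone sketch; sizes and the thunder.jit call are assumptions, not the test's exact setup.
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

import thunder

in_features, out_features, batch_size = 64, 32, 16  # assumed values
device, dtype = torch.device("cuda"), torch.bfloat16

model = nn.Sequential(
    nn.Linear(in_features, out_features, bias=True),
    nn.GELU(approximate="tanh"),
    nn.Linear(out_features, out_features, bias=True),
).to(device=device, dtype=dtype)

# torchao swaps the nn.Linear modules for float8-training linears.
fp8_model = convert_to_float8_training(model)

x = torch.randn(batch_size, in_features, device=device, dtype=dtype)
expected = fp8_model(x)

# The test goes through executor.make_callable; thunder.jit stands in for the
# plain Torch executor path here (sketch assumption).
jitted = thunder.jit(fp8_model)
actual = jitted(x)

torch.testing.assert_close(actual, expected)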
