Functionalize in-place ops #584

Merged · 38 commits · Jun 20, 2024

Commits
c1a13fa
init: some binary ops
crcrpar Jun 12, 2024
05afe9d
replace in-place with out-of-place
crcrpar Jun 12, 2024
351149a
correctly use `prims.copy_` outputs in trace.output
crcrpar Jun 12, 2024
c1d5234
fix proxy swaps.
crcrpar Jun 12, 2024
0044f6d
preserve the change from upstream
crcrpar Jun 12, 2024
6355d34
docstring
crcrpar Jun 12, 2024
e348150
relax
crcrpar Jun 12, 2024
d7dca4d
cover more in-place ops (#590)
crcrpar Jun 13, 2024
6093ffe
store original and intermediate traces
crcrpar Jun 13, 2024
532abbb
import cleanup
crcrpar Jun 13, 2024
82da20c
a bit simpler
crcrpar Jun 13, 2024
cf814f1
guard copy if dst is an arg or a kwarg
crcrpar Jun 15, 2024
ad6d84c
add `OpTags.IN_PLACE`
crcrpar Jun 15, 2024
acdab60
remove redundant checks
crcrpar Jun 15, 2024
b7ad1e2
remove workaround for in-place add of `num_batches_tracked`
crcrpar Jun 15, 2024
203c396
clena up
crcrpar Jun 15, 2024
54f9fa4
remove redundant copy bsym check
crcrpar Jun 16, 2024
8a04a5e
docstring
crcrpar Jun 16, 2024
e2a5c12
revert outdated diff
crcrpar Jun 17, 2024
c7d13d0
remove comment-out
crcrpar Jun 17, 2024
5afcbbe
`is_inplace` -> `is_functionalizable`
crcrpar Jun 17, 2024
367711f
cleanup imports
crcrpar Jun 17, 2024
1447071
Update thunder/core/transform_common.py
crcrpar Jun 17, 2024
8f6f23b
maintain `_inplace_to_out_of_place` map and give in-place tag to the …
crcrpar Jun 18, 2024
08479ef
support `inplace` args
crcrpar Jun 18, 2024
04808b5
register `relu_`
crcrpar Jun 18, 2024
05605f7
add inplace functionalization test
crcrpar Jun 18, 2024
f88e6d4
fix the bugs exposed by tests
crcrpar Jun 18, 2024
ac5abe0
remove debug prints
crcrpar Jun 18, 2024
494e08f
test ops with inplace args
crcrpar Jun 18, 2024
efea02d
comments
crcrpar Jun 18, 2024
828ba29
another comment
crcrpar Jun 18, 2024
2040d18
comments
crcrpar Jun 18, 2024
acd7655
`relu_` calls `relu(a, True)`
crcrpar Jun 18, 2024
fa9756f
Revert "`relu_` calls `relu(a, True)`"
crcrpar Jun 18, 2024
bae091e
silu sig update
crcrpar Jun 18, 2024
3f03f74
remove unused `TYPE_CHECKING`
crcrpar Jun 19, 2024
d4402b0
`return t` not `return torch_func(...)` for clarity
crcrpar Jun 19, 2024
11 changes: 10 additions & 1 deletion thunder/__init__.py
@@ -33,7 +33,13 @@
import thunder.core.prims as prims
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.transform_common import dce, EarlyTransform, AdditionalTransform, PostOptimizationTransform
from thunder.core.transform_common import (
dce,
EarlyTransform,
AdditionalTransform,
PostOptimizationTransform,
functionalize_inplace_ops,
)
from thunder.common import (
CompileData,
CompileStats,
@@ -503,6 +509,9 @@ def get_computation_and_inputs(*args, **kwargs):

prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
if not compile_options.get("skip_inplace_functionalization", False):
Review comment (Collaborator): Longer term, I wonder if we should have a set of default transformations and this be one of them, but for now it is OK.
computation_traces.extend(functionalize_inplace_ops(computation_trace=computation_trc))
computation_trc = computation_traces[-1]
Review comment (Collaborator): Longer term, I wonder if this should be a "default transform", but maybe it is important that this goes first, so the timing makes it tricky.

if epilogue_trc is not None:
epilogue_traces = [epilogue_trc]
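The hunk above gates the new pass behind a `skip_inplace_functionalization` compile option. A minimal, hedged sketch of opting out, assuming compile options are forwarded through `thunder.jit`'s keyword arguments (everything except the option name read from the diff is illustrative):

import torch
import thunder

def fn(x):
    y = x.sin()
    y.add_(1.0)  # in-place op that the pass would normally rewrite to ltorch.add
    return y

# Keep the in-place representation (out-of-place op + prims.copy_) by skipping the pass.
jfn = thunder.jit(fn, skip_inplace_functionalization=True)
print(jfn(torch.randn(3)))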
2 changes: 1 addition & 1 deletion thunder/core/langctxs.py
@@ -72,7 +72,7 @@ def resolve_method(id: Any, *args, **kwargs) -> None | Callable:
# ctx.get_method throws an AttributeError when the context does not have the requested attribute, except
# for the prims language context, which always throws a ValueError
method: Callable = ctx.get_method(id, *args, **kwargs)
except (AttributeError, ValueError) as e:
except (AttributeError, ValueError):
Review comment (Collaborator): Great catch!
Review comment (Collaborator): Great catch
return None
return method

1 change: 1 addition & 0 deletions thunder/core/prims.py
@@ -275,6 +275,7 @@ class OpTags(Enum):
DEVICE_SYNC_OP = auto()
# Labels operations that should not be removed by the dead code elimination (DCE) pass
DONT_DCE = auto()
IN_PLACE = auto()


# TODO RC1 Document this function and describe the parts of a primitive
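A hedged sketch of how the new tag can be queried; `has_tags` is the helper that `transform_common.py` imports in this PR, but this exact usage is an assumption rather than something taken from the diff:

import thunder.core.prims as prims
from thunder.core.symbol import BoundSymbol, has_tags

def is_tagged_in_place(bsym: BoundSymbol) -> bool:
    # True if the bound symbol carries the new IN_PLACE tag
    # (the precise matching semantics of has_tags are assumed here).
    return has_tags(bsym, {prims.OpTags.IN_PLACE})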
35 changes: 35 additions & 0 deletions thunder/core/proxies.py
@@ -1391,6 +1391,13 @@ def __add__(self, other):
method = resolve_method("add", self, other)
return method(self, other)

def __iadd__(self, other):
return self.add_(other)

def add_(self, other):
method = resolve_method("add_", self, other)
return method(self, other)

def __radd__(self, other):
method = resolve_method("add", other, self)
return method(other, self)
@@ -1427,6 +1434,13 @@ def __mul__(self, other):
method = resolve_method("mul", self, other)
return method(self, other)

def __imul__(self, other):
return self.mul_(other)

def mul_(self, other):
method = resolve_method("mul_", self, other)
return method(self, other)

def __rmul__(self, other):
method = resolve_method("mul", other, self)
return method(other, self)
@@ -1435,6 +1449,13 @@ def __pow__(self, other):
method = resolve_method("pow", self, other)
return method(self, other)

def __ipow__(self, other):
return self.pow_(other)

def pow_(self, other):
method = resolve_method("pow_", self, other)
return method(self, other)

def __rpow__(self, other):
method = resolve_method("pow", other, self)
return method(other, self)
@@ -1443,6 +1464,13 @@ def __sub__(self, other):
method = resolve_method("sub", self, other)
return method(self, other)

def __isub__(self, other):
return self.sub_(other)

def sub_(self, other):
method = resolve_method("sub_", self, other)
return method(self, other)

def __rsub__(self, other):
method = resolve_method("sub", other, self)
return method(other, self)
@@ -1455,6 +1483,13 @@ def __rtruediv__(self, other):
method = resolve_method("true_divide", other, self)
return method(other, self)

def __itruediv__(self, other):
return self.div_(other)

def div_(self, other, *, rounding_mode: str | None = None):
method = resolve_method("div_", self, other, rounding_mode=rounding_mode)
return method(self, other)

#
# Logical operations
#
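With these dunder methods in place, augmented assignments on tensor proxies dispatch to the matching in-place methods (`__iadd__` to `add_`, `__imul__` to `mul_`, and so on), which resolve through `resolve_method`. A hedged usage sketch; the traced function and the trace-inspection call are illustrative, not taken from the PR:

import torch
import thunder

def f(x, y):
    z = x.sin()   # z is an intermediate tensor proxy in the trace
    z += y        # __iadd__ -> add_ -> resolve_method("add_", ...)
    return z

jf = thunder.jit(f)
out = jf(torch.randn(4), torch.randn(4))
# After functionalization, the final trace should record ltorch.add rather than ltorch.add_.
print(thunder.last_traces(jf)[-1])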
115 changes: 112 additions & 3 deletions thunder/core/transform_common.py
@@ -1,18 +1,23 @@
from __future__ import annotations
import time
from typing import Any
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
from collections.abc import Sequence
from itertools import filterfalse
from functools import partial

import thunder.core.prims as prims
from thunder.core.baseutils import BoundSymbolInterface
from thunder.core.proxies import Proxy, variableify, Variable
from thunder.core.pytree import tree_flatten, tree_map
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy
from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace
from thunder.core.utils import ProxyDict, producers, check

if TYPE_CHECKING:
from thunder.core.proxies import ProxyInterface
from thunder.core.symbol import Symbol, VariableInterface


#
# Common optimization and transform passes
@@ -363,3 +368,107 @@ class PostOptimizationTransform(Transform, ABC):
@abstractmethod
def transform_trace(self, computation_trace: Trace, **kwargs):
pass


def functionalize_inplace_ops(computation_trace: Trace) -> list[Trace]:
"""Functionalize in-place ops in ``computation_trace``.

In thunder, an in-place op is an out-of-place (functional) op followed by :func:`~thunder.core.prims.copy_`.
This function replaces such in-place ops with out-of-place ops.
Review comment (Collaborator): "... only if the in-place argument is intermediate to the trace", right?
Review comment (Collaborator): Oh, I see later: "functionalization is not applied, if any of an in-place op's arguments are computation_trace.args or computation_trace.kwargs."
Review comment (Collaborator): Should we error / warn in that case?
Author reply (crcrpar): It seems that BatchNorm's num_batches_tracked tensor update is expressed as ltorch.add_(num_batches_tracked, 1) and that tensor is an arg, so this makes sense to me. Also, if one or more of the args and kwargs are updated in place, then I guess there is some intention behind it, so I'm not inclined to ban such cases.

Note that functionalization is not applied, if any of an in-place op's arguments are
``computation_trace.args`` or ``computation_trace.kwargs``.
Review comment (Collaborator): I wonder what we should do in these cases, though: warn or error?


For example, :func:`thunder.torch.add_` is represented as a :class:`thunder.core.symbol.BoundSymbol`
whose `subsymbols` are :func:`thunder.torch.add` and :func:`thunder.core.prims.copy_`. This function
replaces it with a :class:`~thunder.core.symbol.BoundSymbol` of :func:`~thunder.torch.add`.
"""
import thunder.torch

def is_functionalizable(bsym: BoundSymbol) -> bool:
"""Has `OpTags.IN_PLACE` and its args are NOT ``computation_trace.args`` nor ``computation_trace.kwargs``."""
Review comment on lines +387 to +388 (nikitaved, Jun 20, 2024): nit: is IN_PLACE actually used here? EDIT: yes, implicitly through being added to the map.
Review comment (Contributor): Is it also true that the trace args/kwargs are also being checked implicitly somewhere outside?
return (
Review comment (jjsjann123, Jun 17, 2024): I don't think we care about having the IN_PLACE tag or not here, since the replacement logic below doesn't take it into account. I feel the logic here should just check for torch.xxx_ and see if there is a torch.xxx. If we want to move forward with the in-place tag here, maybe we should maintain a map from in-place to out-of-place functions, instead of relying on the trailing underscore.
Author reply (crcrpar): Now there is _inplace_to_out_of_place.

bsym.sym in thunder.torch._inplace_to_out_of_place
and bsym.subsymbols
and bsym.subsymbols[-1].sym.id == prims.PrimIDs.COPY_
Review comment (Collaborator): We certainly should drop the subsymbols check. This is irrelevant to how this PR is handling functionalization.
Author reply (crcrpar): Why is it irrelevant? Currently in-place bsyms have the out-of-place op and copy as their subsymbols, so I think it's fair to check that the last sub bound symbol is a copy.
Author reply (crcrpar): How can we tell an appropriate output tensor proxy if a bsym doesn't have a copy_ as its last sub-bsym, while avoiding introducing a lot of new tensor proxy names in a trace?
Review comment (Collaborator): Oops, sorry, I read your implementation wrong earlier. I thought we were doing a blind torch.xxx_ to torch.xxx replacement, but that's not the case. You are only looking at the last subsymbol and replacing that one entry.
Review comment (Collaborator): That feels a bit restricted... but a first step is still better than nothing, and I'll stop nitpicking on that.
Author reply (crcrpar): How would it be more restricted than a blind torch.foo_ to torch.foo replacement?
Review comment (Collaborator): No, that's not what I meant. By "that feels a bit restricted", I'm referring to the alternative of functionalizing directly at the subsymbol prims.copy_ level. But again, we don't have to do that in this PR.
)

if not any(is_functionalizable(bsym) for bsym in computation_trace.bound_symbols):
return []

# Step 1: where possible, have the trace use the tensors returned from `prims.copy_` rather than the copy destinations, for clarity.
bsym: BoundSymbol
swap_map: dict[VariableInterface, ProxyInterface] = {}
bsyms: list[BoundSymbol] = []
for bsym in computation_trace.bound_symbols:
new_bsym = bsym.from_bsym_swap_proxies(swap_map)

# in-place functionalizable ops have `prims.copy_` as their last subsymbol.
if not is_functionalizable(new_bsym):
bsyms.append(new_bsym)
continue

copy_bsym = bsym.subsymbols[-1]
copy_out = copy_bsym.flat_proxy_outs[0]
copy_dst = copy_bsym.flat_proxy_args[1]
swap_map[variableify(copy_dst)] = copy_out
# make sure an in-place bsym returns `prims.copy_` output
new_bsym = new_bsym.from_bsym_swap_proxies(swap_map, skip_inputs=True, skip_subsymbols=True)
bsyms.append(new_bsym)

intermediate_trace = from_trace(computation_trace)
intermediate_trace.bound_symbols = bsyms[:]
Review comment (Collaborator): Nit: do we need to copy if we del below?
Review comment (Collaborator): Nit: I don't think we strictly need the copy here.
intermediate_trace.set_provenance(TraceProvenance("Intermediate trace of `functionalize_inplace_ops`"))
del bsyms
Review comment on lines +418 to +421 (nikitaved): nit: can't we just do intermediate_trace.bound_symbols = bsyms?
Author reply (crcrpar): Yes. I didn't want to reuse bsyms.

# Step 2: Remove `prims.copy_` if it's the last one of `bsym.subsymbols`,
# unless `copy_to` is among `computation_trace.args` or `computation_trace.kwargs`
trace_args_set = ProxyDict()
for a in filter(
lambda a: isinstance(a, TensorProxy), tree_flatten((computation_trace.args, computation_trace.kwargs))[0]
):
trace_args_set[a] = a
bsym_inplace_to_functional = {}
swap_map.clear()
new_bsyms: list[BoundSymbol] = []
for bsym in intermediate_trace.bound_symbols:
new_bsym = bsym.from_bsym_swap_proxies(swap_map)

if not is_functionalizable(new_bsym):
new_bsyms.append(new_bsym)
continue
copy_bsym = bsym.subsymbols[-1]
copy_return = copy_bsym.flat_proxy_outs[0]
copy_from = copy_bsym.flat_proxy_args[0]
copy_to = copy_bsym.flat_proxy_args[1]
if copy_to in trace_args_set:
new_bsyms.append(new_bsym)
else:
swap_map[variableify(copy_return)] = copy_from
new_bsym.subsymbols = new_bsym.subsymbols[:-1]
new_bsym = new_bsym.from_bsym_swap_proxies(swap_map)
Review comment on lines +439 to +448 (nikitaved, Jun 20, 2024): This logic looks similar to what step 1 is doing. Couldn't they be merged? It seems like the whole thing could be done in a single pass?

functional_sym: Symbol
optional_inplace_arg_index: int
functional_sym, optional_inplace_arg_index = thunder.torch._inplace_to_out_of_place[new_bsym.sym]

flat_args, flat_args_spec = tree_flatten((new_bsym.args, new_bsym.kwargs))
if optional_inplace_arg_index > -1:
flat_args[optional_inplace_arg_index] = False
Review comment on lines +455 to +456 (Contributor): This probably needs a comment.
args, kwargs = tree_unflatten(flat_args, flat_args_spec)
new_functional_bsym = functional_sym.bind(
*args,
**kwargs,
output=new_bsym.output,
subsymbols=new_bsym.subsymbols,
_call_ctx=new_bsym._call_ctx,
)
new_bsyms.append(new_functional_bsym)
bsym_inplace_to_functional[new_bsym] = new_functional_bsym
Review comment (Contributor): bsym_inplace_to_functional is never read from?
Author reply (crcrpar): You're right, but I once tried to register this as an attribute of the provenance at L473.

functionalized_computation_trace = from_trace(computation_trace)
functionalized_computation_trace.bound_symbols = new_bsyms
functionalized_computation_trace.set_provenance(TraceProvenance("Functionalize in-place ops"))
# note(crcrpar): I kind of want to do the following two.
# functionalized_computation_trace._provenance.swap_map = swap_map
# functionalized_computation_trace._provenance.bsym_inplace_to_functional = bsym_inplace_to_functional
return [intermediate_trace, functionalized_computation_trace]
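Pulling the two steps together, here is a hedged end-to-end sketch of the pass's observable effect. The function, shapes, and expected trace contents are illustrative assumptions; only the general before/after behavior follows from the code above:

import torch
import thunder

def f(x):
    y = x.sin()   # intermediate tensor: its in-place update is functionalizable
    y.mul_(2)     # traced as ltorch.mul + prims.copy_; step 2 drops the copy_ and binds ltorch.mul
    x.add_(1)     # x is a trace arg, so its copy_ is kept to preserve the visible side effect
    return y

jf = thunder.jit(f)
x = torch.randn(4)
out = jf(x)

# Expectation under the pass: the last computation trace contains ltorch.mul without a trailing
# prims.copy_ for the intermediate, while the argument update keeps its prims.copy_.
print(thunder.last_traces(jf)[-1])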
30 changes: 0 additions & 30 deletions thunder/tests/opinfos.py
@@ -1278,16 +1278,10 @@ def _abs_torch(x: torch.Tensor | Number):
elementwise_unary_ops.append(signbit_opinfo)


def silu_error_generator(op, device, dtype=torch.float32, **kwargs):
a = make_tensor((), dtype=dtype, device=device)
yield (SampleInput(a, inplace=True), NotImplementedError, "Thunder only supports silu with inplace=False")


silu_opinfo = OpInfo(
ltorch.silu,
dtypes=(datatypes.floating,),
sample_input_generator=partial(elementwise_unary_generator, supports_numbers=False),
error_input_generator=silu_error_generator,
torch_reference=_elementwise_unary_torch(torch.nn.functional.silu),
test_directives=(
DecorateInfo(
@@ -1623,20 +1617,9 @@ def silu_error_generator(op, device, dtype=torch.float32, **kwargs):
elementwise_unary_ops.append(reciprocal_opinfo)


def relu_error_generator(op, device, dtype=torch.float32, **kwargs):
a = make_tensor((), dtype=dtype, device=device)
yield (SampleInput(a, inplace=True), NotImplementedError, "relu only supports inplace=False")


def relu6_error_generator(op, device, dtype=torch.float32, **kwargs):
a = make_tensor((), dtype=dtype, device=device)
yield (SampleInput(a, inplace=True), NotImplementedError, "relu6 only supports inplace=False")


relu_opinfo = OpInfo(
ltorch.relu,
sample_input_generator=elementwise_unary_generator,
error_input_generator=relu_error_generator,
torch_reference=_elementwise_unary_torch(torch.relu),
test_directives=(
# PyTorch does not support bool and complex types
@@ -1665,7 +1648,6 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs):
relu6_opinfo = OpInfo(
ltorch.relu6,
sample_input_generator=elementwise_unary_generator,
error_input_generator=relu6_error_generator,
torch_reference=_elementwise_unary_torch(torch.nn.functional.relu6),
test_directives=(
# PyTorch does not support bool for both CPU and CUDA relu6
@@ -1684,15 +1666,9 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs):
elementwise_unary_ops.append(relu6_opinfo)


def hardswish_error_generator(op, device, dtype=torch.float32, **kwargs):
a = make_tensor((), dtype=dtype, device=device)
yield (SampleInput(a, inplace=True), NotImplementedError, "hardswish only supports inplace=False")


hardswish_opinfo = OpInfo(
ltorch.hardswish,
sample_input_generator=elementwise_unary_generator,
error_input_generator=hardswish_error_generator,
torch_reference=_elementwise_unary_torch(torch.nn.functional.hardswish),
dtypes=(datatypes.floating,),
test_directives=(
@@ -1713,16 +1689,10 @@ def hardswish_error_generator(op, device, dtype=torch.float32, **kwargs):
elementwise_unary_ops.append(hardswish_opinfo)


def selu_error_generator(op, device, dtype=torch.float32, **kwargs):
a = make_tensor((), dtype=dtype, device=device)
yield (SampleInput(a, inplace=True), NotImplementedError, "selu only supports inplace=False")


selu_opinfo = OpInfo(
ltorch.selu,
dtypes=(datatypes.floating,),
sample_input_generator=elementwise_unary_generator,
error_input_generator=selu_error_generator,
torch_reference=_elementwise_unary_torch(torch.selu),
test_directives=(
# Some versions of PyTorch do not support CPU float16 selu
3 changes: 2 additions & 1 deletion thunder/tests/test_core.py
@@ -2121,7 +2121,8 @@ def test_xor(s, o):

for t in tests:
cfn = thunder.jit(t)
with pytest.raises(RuntimeError, match="not supported"):
# Some ops in `tests` now support in-place, so they fail with a broadcast error instead of "not supported"
with pytest.raises(RuntimeError, match="not supported|Attempting"):
cfn(t1, t2)
# Note: Python maps inplace operations on (immutuables) to
# out of place operations, NumberProxy does this, too.
5 changes: 3 additions & 2 deletions thunder/tests/test_inplace_copy.py
@@ -89,8 +89,6 @@ class Net(nn.Module):
def __init__(self):
super().__init__()
self.dense1_bn = nn.BatchNorm3d(2, track_running_stats=True)
# To address the failure, use a workaround since `add_` is utilized in `nn.BatchNorm3d` when `num_batches_tracked` is not None.
self.dense1_bn.num_batches_tracked = None

def forward(self, x):
x = self.dense1_bn(x)
@@ -112,6 +110,9 @@ def forward(self, x):
assert_close(thunder_out, torch_out)
assert_close(net.state_dict()["dense1_bn.running_mean"], torch_net.state_dict()["dense1_bn.running_mean"])
assert_close(net.state_dict()["dense1_bn.running_var"], torch_net.state_dict()["dense1_bn.running_var"])
assert_close(
net.state_dict()["dense1_bn.num_batches_tracked"], torch_net.state_dict()["dense1_bn.num_batches_tracked"]
)
assert_close(x.grad, x1.grad)


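The restored `num_batches_tracked` assertion ties back to the arg-exclusion rule in `functionalize_inplace_ops`: BatchNorm bumps that buffer with an in-place add on a traced argument, so the `prims.copy_` has to survive functionalization. A hedged, self-contained illustration (the module configuration mirrors the test above; the expected value assumes a single forward pass in training mode):

import torch
import torch.nn as nn
import thunder

bn = nn.BatchNorm3d(2, track_running_stats=True)  # training mode by default
jbn = thunder.jit(bn)

x = torch.randn(3, 2, 4, 4, 4)
_ = jbn(x)

# The buffer is updated in place on a traced argument, so its copy_ is preserved
# and the counter should advance.
print(bn.num_batches_tracked)  # expected: tensor(1)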