Skip to content

Commit

Permalink
some tweaks
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 8, 2024
1 parent f8e0803 commit abf0167
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
65 changes: 36 additions & 29 deletions thunder/core/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.baseutils import ProxyInterface
from types import FunctionType


OPTREE_NAMESPACE = "thunder"

Expand All @@ -30,36 +30,43 @@
)


allowed_types = {
FunctionType,
dict,
list,
str,
int,
bool,
tuple,
NamedTuple,
torch.dtype,
float,
dtypes.floating,
dtypes.bool_,
devices.Device,
torch.memory_format,
type(None),
slice,
complex,
type,
type(Ellipsis),
torch.Size,
torch.finfo,
dtypes.signedinteger,
# FakeTensor type is used for automatic registration of torch ops
torch._subclasses.fake_tensor.FakeTensor,
torch.device,
torch.autograd.function.FunctionCtx,
}


def register_type(typ):
allowed_types.add(typ)


def tree_flatten(args, namespace=OPTREE_NAMESPACE):
if (
type(args)
not in {
FunctionType,
dict,
list,
str,
int,
bool,
tuple,
torch.dtype,
float,
dtypes.floating,
dtypes.bool_,
devices.Device,
torch.memory_format,
type(None),
slice,
complex,
type,
type(Ellipsis),
torch.Size,
torch.finfo,
dtypes.signedinteger,
# FakeTensor type is used for automatic registration of torch ops
torch._subclasses.fake_tensor.FakeTensor,
torch.device,
torch.autograd.function.FunctionCtx,
}
type(args) not in allowed_types
and not isinstance(args, (ProxyInterface))
and not dataclasses.is_dataclass(args)
and not type(args).__module__.startswith("torch.return_types")
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
# may mark some of the operation's outputs as unused
some_unused = False
for out in bsym.flat_proxy_outs:
if variableify(out) in needed_proxies and producer_map[out] == bsym:
if variableify(out) in needed_proxies and producer_map.get(out, None) == bsym:
needed = True
else:
some_unused = True
Expand Down

0 comments on commit abf0167

Please sign in to comment.