Skip to content

Commit

Permalink
fake_tensor.foo -> foo
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 6, 2024
1 parent c069073 commit 6043450
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import fake_tensor
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensorMode
Expand Down Expand Up @@ -59,7 +58,7 @@ class OutputWrapperForFxTracing(NamedTuple):
metadata: dict[str, Any] | None


def _materialize_tensor_proxy(t: TensorProxy, fake_tensor_mode: fake_tensor.FakeTensorMode | None) -> torch.Tensor:
def _materialize_tensor_proxy(t: TensorProxy, fake_tensor_mode: FakeTensorMode | None) -> torch.Tensor:
shape = t.shape
device = devices.to_torch_device(t.device)
dtype = dtypes.to_torch_dtype(t.dtype)
Expand All @@ -77,7 +76,7 @@ def _materialize_tensor_proxy(t: TensorProxy, fake_tensor_mode: fake_tensor.Fake

def _make_fake_subclass_tensor_from_subclass_tensor_proxy(
tensor_proxy: SubclassTensorProxy,
fake_tensor_mode: fake_tensor.FakeTensorMode,
fake_tensor_mode: FakeTensorMode,
subclass_to_attr_names: dict[_TensorMeta, tuple[list[str], list[str]]],
) -> torch.Tensor:
utils.check(
Expand Down Expand Up @@ -132,8 +131,8 @@ def maybe_materialize_tensor(
return t.value


def proxy_fake_tensor(t: torch.Tensor | fake_tensor.FakeTensor) -> ProxyInterface:
if isinstance(t, fake_tensor.FakeTensor) or (isinstance(t, torch.Tensor) and not issubclass(type(t), torch.Tensor)):
def proxy_fake_tensor(t: torch.Tensor | FakeTensor) -> ProxyInterface:
if isinstance(t, FakeTensor) or (isinstance(t, torch.Tensor) and not issubclass(type(t), torch.Tensor)):
return TensorProxy(
None,
shape=list(t.shape),
Expand Down Expand Up @@ -192,7 +191,7 @@ def aten_core_ir_op_to_ltorch_op(aten_op: OpOverload) -> Symbol:
class DesugarTensorSubclass:
computation_trace: TraceCtx
swap_map: dict[Variable, ProxyInterface] = field(init=False, default_factory=dict)
fake_tensor_mode: fake_tensor.FakeTensorMode = field(init=False, default_factory=fake_tensor.FakeTensorMode)
fake_tensor_mode: FakeTensorMode = field(init=False, default_factory=FakeTensorMode)
fx_computation_trace: GraphModule = field(init=False, default=None)
computation_trace_output: tuple[OutputWrapperForFxTracing, ...] = field(init=False, default=None)
fx_computation_trace_result: tuple[torch.Tensor, ...] = field(init=False, default=None)
Expand Down

0 comments on commit 6043450

Please sign in to comment.