remove requires_desugarring
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Nov 28, 2024
1 parent 71f6474 commit 148fc8c
Showing 1 changed file with 20 additions and 55 deletions.
75 changes: 20 additions & 55 deletions thunder/transforms/tensor_subclasses.py
@@ -187,40 +187,19 @@ class DesugarTensorSubclass:
computation_trace: TraceCtx
swap_map: dict[Variable, ProxyInterface] = field(init=False, default_factory=dict)
fake_tensor_mode: FakeTensorMode = field(init=False, default_factory=FakeTensorMode)
fx_computation_trace: GraphModule = field(init=False, default=None)
computation_trace_output: tuple[OutputWrapperForFxTracing, ...] = field(init=False, default=None)
fx_computation_trace_result: tuple[torch.Tensor, ...] = field(init=False, default=None)
spec_of_fx_computation_trace_result: PyTreeSpec = field(init=False, default=None)
flat_trace_args: Sequence[ProxyInterface] = field(init=False, default=None)
flat_trace_args_spec: Any = field(init=False, default=None)
requires_desugarring: bool = field(init=False, default=False)
subclass_proxy_to_flatten: set[Variable] = field(init=False, default_factory=set)
bsym_to_new_outputs: dict[BoundSymbol, list[TensorProxy]] = field(init=False, default_factory=dict)

def __post_init__(self) -> None:
self.flat_trace_args, self.flat_trace_args_spec = tree_flatten(
(self.computation_trace.args, self.computation_trace.kwargs)
)
for arg in self.flat_trace_args:
# self.maybe_update_subclass_type_dict(arg)
if isinstance(arg, SubclassTensorProxy):
self.subclass_proxy_to_flatten.add(variableify(arg))

# TODO(crcrpar): From my perspective, this check is mainly for the sake of faster compilation.
# There could be a computation graph where none of the inputs is a subclass while
# the graph calls subclass creation inside of it.
self.requires_desugarring = any(isinstance(t, SubclassTensorProxy) for t in self.flat_trace_args)
if not self.requires_desugarring:
return

(
self.fx_computation_trace,
self.computation_trace_output,
self.fx_computation_trace_result,
self.spec_of_fx_computation_trace_result,
) = self.convert_trace_to_fx_graph_and_get_fake_result(
self.computation_trace,
)

def _get_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]:
return p._tensor_attr_names

@@ -230,7 +209,7 @@ def _get_non_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]:
def translate_fx_graph_into_bsym(
self,
bsym: BoundSymbol,
fx: GraphModule,
fx_graph: GraphModule,
) -> BoundSymbol | tuple[BoundSymbol, ...]:
import thunder.torch as ltorch

@@ -272,7 +251,7 @@ def translate_fx_graph_into_bsym(
list_of_placeholder_node: list[Node] = []
list_of_function_call_node: list[Node] = []
node_of_output: Node
for node in fx.graph.nodes:
for node in fx_graph.graph.nodes:
if node.op == PLACEHOLDER:
list_of_placeholder_node.append(node)
if node.op == CALL_FUNCTION:
@@ -310,6 +289,7 @@ def translate_fx_graph_into_bsym(
if is_subclass_ctor_bsym := bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR:
utils.check_type(orig_output, SubclassTensorProxy)
if isinstance(orig_output, SubclassTensorProxy):
# note(crcrpar): args[0] would be a list of tensors, and args[1] could be a list of non-tensors.
args: list[Node] = node_of_output.args[0]
new_tensor_proxies = []
for a in args:
@@ -328,6 +308,15 @@ def translate_fx_graph_into_bsym(
new_subclass = orig_output.replace()
for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies):
setattr(new_subclass, name, value)
bsyms.append(
prims.unflatten_tensor_subclass.bind(
new_subclass._subclass_type,
dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)),
dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)),
output=new_subclass,
)
)

self.swap_map[variableify(orig_output)] = new_subclass
return bsyms
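
For context, the prims.unflatten_tensor_subclass bind added above mirrors PyTorch's traceable tensor subclass protocol: a subclass names its inner tensors via __tensor_flatten__ and is rebuilt from them via __tensor_unflatten__. A minimal sketch of that protocol, using a hypothetical TwoTensor subclass that is not part of this repository:

import torch

class TwoTensor(torch.Tensor):
    # Wrapper subclass holding two inner tensors, `a` and `b`.
    @staticmethod
    def __new__(cls, a: torch.Tensor, b: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(cls, a.shape, dtype=a.dtype, device=a.device)

    def __init__(self, a: torch.Tensor, b: torch.Tensor):
        self.a = a
        self.b = b

    def __tensor_flatten__(self):
        # Tensor attribute names plus non-tensor metadata: the same
        # (tensor attrs, non-tensor attrs) split the prim's args carry.
        return ["a", "b"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        # Reconstruct the subclass from its flattened pieces.
        return TwoTensor(inner_tensors["a"], inner_tensors["b"])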

@@ -423,35 +412,8 @@ def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing,
def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map)
if updated_bsym.sym.id == prims.PrimIDs.RETURN:
unflatten_fake_tensor_result = tree_unflatten(
self.fx_computation_trace_result,
self.spec_of_fx_computation_trace_result,
)
outputs: dict[str, Any] = updated_bsym.args[0] # {"output": ..., "flat_args": ...}
utils.check_type(outputs, dict)
utils.check(
isinstance(outputs, dict) and len(outputs) == 2 and ("output", "flat_args") == tuple(outputs.keys()),
lambda: fr"{outputs=} does not conform to the format of \{'output': ..., 'flat_args': [...]\}",
)
seq_outs = utils.sequencify(outputs["output"])
seq_fake_ret = utils.sequencify(unflatten_fake_tensor_result["output"])
utils.check(
len(seq_outs) == len(seq_fake_ret),
lambda: f"{outputs['output']=}, {unflatten_fake_tensor_result['output']=}",
)

bsyms: list[BoundSymbol] = []
for proxy_output, fx_output in zip(seq_outs, seq_fake_ret):
if not isinstance(proxy_output, SubclassTensorProxy):
continue
tensor_attrs, metadata = proxy_output.__tensor_flatten__()
tensors = [getattr(proxy_output, name) for name in tensor_attrs]
bsyms.append(
prims.unflatten_tensor_subclass.bind(
type(fx_output), dict(zip(tensor_attrs, tensors)), metadata, output=proxy_output
)
)
return [*bsyms, updated_bsym]
if not self.subclass_proxy_to_flatten or True:
return [updated_bsym]

is_subclass_ctor = bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR
if not is_subclass_ctor and not any(isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_args):
@@ -505,6 +467,8 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
self.swap_map.update(dict(zip(sequence_out, utils.sequencify(out_proxy))))

bsym_with_modified_output = updated_bsym.from_bsym_swap_proxies(self.swap_map)

self.bsym_to_new_outputs[bsym_with_modified_output] = bsym_with_modified_output
return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx)


@@ -535,14 +499,15 @@ def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx:
behavior is spelled out.
"""
desugar_tensor_subclass = DesugarTensorSubclass(computation_trace=computation_trace)
if not desugar_tensor_subclass.requires_desugarring:
return computation_trace
updated_bsyms: list[BoundSymbol] = []
bsym: BoundSymbol
for bsym in computation_trace.bound_symbols:
maybe_desugared_bsyms = desugar_tensor_subclass(bsym)
updated_bsyms.extend(maybe_desugared_bsyms)

if not desugar_tensor_subclass.subclass_proxy_to_flatten:
return computation_trace

computation_trace_with_subclass_tensor_proxy_output = from_trace(computation_trace)
computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms)
computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared"))
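
A hedged usage sketch of the pass, assuming an already-built Thunder computation trace named trc (the variable name is illustrative):

from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses

# Per the early return above, the input trace comes back unchanged when no
# SubclassTensorProxy is encountered; otherwise a new trace with provenance
# "tensor subclasses desugared" is returned.
desugared_trace = flatten_tensor_subclasses(trc)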
