
thunderfx : detecting parameters and buffers on thunderfx path #1575

Open
kshitij12345 opened this issue Dec 19, 2024 · 1 comment
Labels
jit thunderfx for things that could be applicable to the dynamo+thunder frontend

Comments

@kshitij12345 (Collaborator)

The FXGraph provided by Dynamo takes Parameters and Buffers as arguments; however, thunder.jit currently only determines a TensorProxy to be a parameter if it is unpacked from a Module. So, on the thunderfx path, these parameters are not tagged with STATIC_MEMORY_LOCATION, leading to problems with CUDAGraphTransform and ExtractionOnlyPrologueTransform, which depend on these tags. The tagging currently happens only in thunder.jit's module-unpacking logic:

if typ == "_parameters":
bsym = prims.unpack_parameter.bind(root_module, name, output=output)
output.tags.add(ProxyTag.STATIC_MEMORY_LOCATION)
elif typ == "_buffers":
bsym = prims.unpack_buffer.bind(root_module, name, output=output)
output.tags.add(ProxyTag.STATIC_MEMORY_LOCATION)
elif typ == "_modules":
bsym = prims.unpack_submodule.bind(root_module, name, output=output)

Potential solution for parameters: thunder.jit could perhaps tag proxies based on isinstance(obj, nn.Parameter) instead of relying on unpacking from a Module.
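
A minimal sketch of what such a check could look like (the helper below is hypothetical, not thunder's actual API, and it assumes ProxyTag is importable from thunder.core.proxies; only ProxyTag.STATIC_MEMORY_LOCATION is taken from the excerpt above):

import torch.nn as nn

from thunder.core.proxies import ProxyTag

def maybe_tag_static(obj, proxy):
    # Hypothetical helper: tag the proxy whenever the underlying object is an
    # nn.Parameter, even when it was not unpacked from a Module.
    if isinstance(obj, nn.Parameter):
        proxy.tags.add(ProxyTag.STATIC_MEMORY_LOCATION)
    return proxy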

Sample showing that Dynamo passes the module's parameters as plain graph inputs:

import torch

def backend(gm, sample_arg):
    # Dynamo hands the backend the FXGraph and its sample inputs;
    # print both to see how the parameters arrive.
    gm.print_readable()
    print(sample_arg)
    return gm

model = torch.nn.Linear(2, 2)

cmodel = torch.compile(model, backend=backend)
cmodel(torch.randn(1, 2))

Output

class GraphModule(torch.nn.Module):
    def forward(self, L_fn_parameters_weight_: "f32[2, 2]", L_fn_parameters_bias_: "f32[2]", L_args_0_: "f32[1, 2]"):
        l_fn_parameters_weight_ = L_fn_parameters_weight_
        l_fn_parameters_bias_ = L_fn_parameters_bias_
        l_args_0_ = L_args_0_
        
         # File: /home/kkalambarkar/git/pytorch/torch/_dynamo/external_utils.py:31 in inner, code: return fn(*args, **kwargs)
        linear: "f32[1, 2]" = torch._C._nn.linear(l_args_0_, l_fn_parameters_weight_, l_fn_parameters_bias_);  l_args_0_ = l_fn_parameters_weight_ = l_fn_parameters_bias_ = None
        return (linear,)
        
[Parameter containing:
tensor([[ 0.6478,  0.6590],
        [-0.5319, -0.3303]], requires_grad=True), Parameter containing:
tensor([0.3209, 0.6565], requires_grad=True), tensor([[ 0.7282, -0.3549]])]
{}
kshitij12345 added the jit and thunderfx labels on Dec 19, 2024
@kshitij12345 (Collaborator, Author)

Example of ExtractionOnlyPrologueTransform not working:

import torch
import thunder
import thunder.transforms
from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
from thunder.dynamo import thunderfx

model = torch.nn.Linear(16, 16)
x = torch.randn(16, 16)

# Compile via the thunderfx path with the extraction-only prologue transform.
cmodel = thunderfx(model, transforms=[ExtractionOnlyPrologueTransform()])
_ = cmodel(x)

# Grab the prologue trace of the single Thunder-compiled subgraph.
assert len(cmodel._backend.subgraph_infos) == 1
subgraph_info = cmodel._backend.subgraph_infos[0]
thunder_fn = subgraph_info.thunder_compiled_fns[0]
original_subgraph = subgraph_info.original_graph_module

prl_trc = thunder.last_prologue_traces(thunder_fn)[-1]

with open("prologue_trc.py", "w") as f:
    f.write(str(prl_trc))

Prologue trace (the check_tensor_metadata calls on the parameters remain, since they were never tagged STATIC_MEMORY_LOCATION):

# Constructed by Transform for execution (took 3 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 3)
    # prims.check_len(args, 3)
  # kwargs: "Any"
  check_len(kwargs, 0)
    # prims.check_len(kwargs, 0)
  l_args_0_: "cpu f32[16, 16]" = args[0]
  l_fn_parameters_weight_: "cpu f32[16, 16]" = args[1]
  l_fn_parameters_bias_: "cpu f32[16]" = args[2]
  check_tensor_metadata(l_args_0_, (16, 16), 'cpu', torch.float32, False)
    # prims.check_tensor_shape_and_metadata(l_args_0_, (16, 16), 'cpu', torch.float32, False)
  check_tensor_metadata(l_fn_parameters_weight_, (16, 16), 'cpu', torch.float32, True)
    # prims.check_tensor_shape_and_metadata(l_fn_parameters_weight_, (16, 16), 'cpu', torch.float32, True)
  check_tensor_metadata(l_fn_parameters_bias_, (16,), 'cpu', torch.float32, True)
    # prims.check_tensor_shape_and_metadata(l_fn_parameters_bias_, (16,), 'cpu', torch.float32, True)
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  check_literal_like(cache_info_default_dtype, torch.float32)
    # prims.check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
  check_literal_like(cache_info_default_device, torch.device("cpu"))
    # prims.check_literal_like(cache_info_default_device, torch.device("cpu"))
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  check_number_type_and_value(cache_info_is_autocast_enabled, False)
    # prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  check_number_type_and_value(cache_info_no_grad_sync, False)
    # prims.check_number_type_and_value(cache_info_no_grad_sync, False)
  cache_info_alias_tensor_indices: "str" = cache_info['alias_tensor_indices']
  check_string_value(cache_info_alias_tensor_indices, '')
    # prims.check_string_value(cache_info_alias_tensor_indices, '')
  cache_info_is_grad_enabled: "bool True" = cache_info['is_grad_enabled']
  check_number_type_and_value(cache_info_is_grad_enabled, True)
    # prims.check_number_type_and_value(cache_info_is_grad_enabled, True)
  return ((l_args_0_, l_fn_parameters_weight_, l_fn_parameters_bias_), ())
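
One rough way to confirm this programmatically, building on prl_trc from the snippet above (a sketch; the string matching below is an assumption based on the trace printed here, not a stable API):

remaining_checks = [
    bsym
    for bsym in prl_trc.bound_symbols
    if "check_tensor" in bsym.sym.name
    and any("parameters" in getattr(a, "name", "") for a in bsym.args)
]
# Two checks remain (weight and bias); if the parameters were tagged
# STATIC_MEMORY_LOCATION, the transform should have removed them.
print(len(remaining_checks))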
