[ThunderFX][HF] ValueError: unrecognized type in arguments: <class 'NoneType'> #1482

Closed
wprazuch opened this issue Nov 27, 2024 · 2 comments · Fixed by #1563
Labels: high priority · nemo (Issues needed to support NVIDIA NeMo models) · sdpa · thunderfx (for things that could be applicable to the dynamo+thunder frontend)

Comments

wprazuch (Contributor) commented Nov 27, 2024

🐛 Bug

When running a training loop for Qwen/Qwen2.5-7B-Instruct, we get a ValueError in ctx.compiled_backward:

ValueError: unrecognized type in arguments: <class 'NoneType'>

To Reproduce

  1. Use the pjnl-20241120 container
  2. Run pip install datasets==3.0.2 (needed to create the dummy dataset)
  3. Run this small example:
import time
import torch
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler
import thunder
import thunder.dynamo
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform


def make_dummy_dataset(tokenizer, seq_len, batch_size, n=1):
    data = {'text': "Below is an instruction ... endoftext|>"}

    def fmt(example):
        ans = tokenizer(example['text'], padding="max_length", truncation=True, max_length=seq_len)
        tokens = ans['input_ids']
        return {'tokens': torch.tensor(tokens, dtype=torch.long), 'labels': torch.tensor(tokens[1:] + [tokens[-1]], dtype=torch.long)}

    from datasets import Dataset

    dataset = Dataset.from_dict({"text": [data['text'] for _ in range(n)]})
    dataset = dataset.map(fmt, batched=False, batch_size=batch_size)
    dataset.set_format(type='torch')
    return dataset

def main():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct", torch_dtype=torch.bfloat16)
    config.num_hidden_layers = 2
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        torch_dtype=torch.bfloat16,
    )

    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tokenizer))

    xforms: list = [NvtxProfileTransform()]
    be = thunder.dynamo.ThunderCompiler(transforms=xforms)
    model.compile(backend=be)
    
    dataset = make_dummy_dataset(
        tokenizer, 128, 1, n=10
    )
    dataloader = DataLoader(
        dataset, batch_size=1, shuffle=True, drop_last=True
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=5e-5)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=1 * len(dataloader),
    )

    model.to(device)
    model.train()
    iter_times = []

    for epoch in range(1):
        total_loss = 0
        for i, batch in enumerate(dataloader):
            iter_t0 = time.perf_counter()
            input_ids = batch["tokens"].to(device)
            attention_mask = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
            )
            loss = outputs.loss
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lr_scheduler.step()
            iter_t1 = time.perf_counter()
            iter_times.append(iter_t1 - iter_t0)

        avg_loss = total_loss / len(dataloader)


if __name__ == "__main__":
    main()
  4. See the error:
minimal_qwen.py", line 77, in main
    optimizer.step()
    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 626, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 600, in wrapper
    outputs = fn(ctx, *args)
              ^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 115, in backward
    grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/torchex.py", line 179, in no_autocast_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "thunder.backward_fn_12", line 197, in backward_fn
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 443, in __call__
    fd = self.get_fd(self.to_descriptors(args))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 423, in to_descriptors
    return tuple(to_descriptor(proxy_arg, arg) for proxy_arg, arg in zip(proxy_args, args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 423, in <genexpr>
    return tuple(to_descriptor(proxy_arg, arg) for proxy_arg, arg in zip(proxy_args, args))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 421, in to_descriptor
    raise ValueError(f"unrecognized type in arguments: {type(arg)}")
ValueError: unrecognized type in arguments: <class 'NoneType'>

Code sample

As in the repro steps.

Expected behavior

It should run smoothly, as it does in eager mode and with the default torch.compile backend.
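
For comparison, a minimal baseline sketch (assuming the same script as above; only the compile call changes):

# Eager baseline: simply skip the model.compile(...) call.
# Default torch.compile baseline: swap ThunderCompiler for the stock backend.
model.compile()  # same as torch.compile(model) with the default ("inductor") backend
# model.compile(backend="eager")  # Dynamo tracing without a compiler, also handy for triage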

Environment

As in the container.

Additional context

Happy to provide any information if needed :)

cc @apaz-cli @tfogal

wprazuch (Contributor, Author) commented:

In newer container releases, this is shadowed by #1479.

kiya00 (Collaborator) commented Nov 28, 2024

The error occurs because the last output of sdpaex_scaled_dot_product_efficient_attention_backward is inferred to be a TensorProxy in the meta function (attn_mask is not None, but attn_mask.requires_grad is False), while the actual output value is None:

grad_attn_mask = None
if attn_mask is not None:
    grad_attn_mask = TensorProxy(like=attn_mask, shape=attn_mask.shape)
# Return gradients for query, key, value, and attn_mask tensor inputs
return (grad_query, grad_key, grad_value, grad_attn_mask)

According to https://github.com/pytorch/pytorch/blob/6b430c26bd78cf9f3736e0f9caf23f40e2a867f1/torch/_meta_registrations.py#L5318-L5329, grad_attn_mask is None when grad_input_mask[3] is False, even though attn_mask is not None.
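
Paraphrasing (not a verbatim quote) the relevant part of that meta registration, the attn_mask/bias gradient is gated on grad_input_mask[3]; the real code also pads the gradient's last dimension for alignment:

# Simplified sketch of the grad_bias handling in the meta registration of
# aten._scaled_dot_product_efficient_attention_backward (paraphrased):
grad_bias = None
if grad_input_mask[3]:
    grad_bias = torch.empty_like(attn_bias)  # simplified; real code aligns the last dim
return grad_query, grad_key, grad_value, grad_bias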

In the actual execution, the last output is None (attn_mask.requires_grad is False):

grad_input_mask = [a.requires_grad for a in (query, key, value)]
if attn_mask is None:
    grad_input_mask.append(False)
else:
    grad_input_mask.append(attn_mask.requires_grad)
# Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L394-L415
return torch.ops.aten._scaled_dot_product_efficient_attention_backward(
    grad_out,
    _sdpa_enforce_input_tensor_contiguity(query),
    _sdpa_enforce_input_tensor_contiguity(key),
    _sdpa_enforce_input_tensor_contiguity(value),
    _attention_mask_memory_efficient_helper(attn_mask, query),
    out,
    logsumexp,
    philox_seed,
    philox_offset,
    dropout_p,
    grad_input_mask,
    is_causal,
    scale=scale,
)
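
A minimal sketch of one way to align the thunder meta function with this behavior (an illustration only; the actual fix in #1563 may differ): only advertise a gradient for attn_mask when one will actually be produced.

# Hypothetical adjustment to the meta function (sketch): mirror the
# grad_input_mask[3] / requires_grad check so the advertised output matches
# the runtime None.
grad_attn_mask = None
if attn_mask is not None and attn_mask.requires_grad:
    grad_attn_mask = TensorProxy(like=attn_mask, shape=attn_mask.shape)
# Return gradients for query, key, value, and attn_mask tensor inputs
return (grad_query, grad_key, grad_value, grad_attn_mask)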

@tfogal added the high priority, nemo, and thunderfx labels on Dec 13, 2024
@nvMelissa added the sdpa label on Dec 16, 2024
kiya00 added a commit that referenced this issue Dec 17, 2024