Skip to content

The example_input_array in LightningModule does not support using NestedTensor with layout=torch.jagged within the model. #21588

@Zhiyong-Xu2

Description

@Zhiyong-Xu2

Bug description

The example_input_array in LightningModule does not support using NestedTensor with layout=torch.jagged within the model. Even though example_input_array itself consists of plain torch.Tensor objects, an error occurs when the model internally uses NestedTensor with layout=torch.jagged. Below I provide an alternative version of the multi-head attention layer that supports NestedTensor. We flatten the other dimensions into the second dimension t, which is an irrelevant dimension. Due to the limitations of NestedTensor, these cumbersome dimension transformations are required.

What version are you seeing the problem on?

master

Reproduced in studio

No response

How to reproduce the bug

class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    The second dimension ``t`` is treated as "irrelevant": it is folded into
    the head dimension before scaled_dot_product_attention and unfolded
    afterwards. This keeps the layer compatible with NestedTensor
    (layout=torch.jagged) inputs whose jagged dimension sits at position 2.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        emb_dim (int): Total embedding dim of combined heads post input projection. Each head
            has dim emb_dim // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability, applied only while the
            module is in training mode. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
        device: Optional device for the projection layers.
        dtype: Optional dtype for the projection layers.
    """

    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        emb_dim: int,
        nheads: int,
        dropout: float = 0.0,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.nheads = nheads
        self.dropout = dropout
        # A single packed q/k/v projection is only possible when all three
        # input embedding dims agree.
        self._qkv_same_embed_dim = E_q == E_k == E_v
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, emb_dim * 3, bias=bias, **factory_kwargs)
        else:
            self.q_proj = nn.Linear(E_q, emb_dim, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_k, emb_dim, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_v, emb_dim, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(emb_dim, E_out, bias=bias, **factory_kwargs)
        assert emb_dim % nheads == 0, "Embedding dim is not divisible by nheads"
        # Head dim must be a multiple of 8 for the fused SDPA kernels.
        assert (emb_dim // nheads) % 8 == 0, "emb_dim//nheads dim is not a multiple of 8"
        self.E_head = emb_dim // nheads
        self.bias = bias

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape [B, t, N_i_q, E_q] NestedTensor, t is an irrelevant dimension.
            key (torch.Tensor): key of shape [B, t, N_i_kv, E_k] NestedTensor, t is an irrelevant dimension.
            value (torch.Tensor): value of shape [B, t, N_i_kv, E_v] NestedTensor, t is an irrelevant dimension.
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): [B, t, N_i_q, E_q]
        """
        un_flag = False
        if query.dim() == 3:
            # Record identity BEFORE reshaping: unflatten/transpose produce new
            # tensor objects, which would otherwise defeat the packed-projection
            # fast path below even for genuine self-attention inputs.
            same_qkv = query is key and key is value
            # Promote [B, N, E] to [B, 1, N, E] so the 4-D code path applies.
            query = query.unflatten(-1, [1, -1]).transpose(1, 2)
            if same_qkv:
                key = value = query
            else:
                key = key.unflatten(-1, [1, -1]).transpose(1, 2)
                value = value.unflatten(-1, [1, -1]).transpose(1, 2)
            un_flag = True

        if self._qkv_same_embed_dim:
            if query is key and key is value:
                # Self-attention: one fused projection, then split into q/k/v.
                result = self.Bproj(self.packed_proj, query)
                query, key, value = self.Bchunk(result, 3, dim=-1)
            else:
                # Distinct inputs: slice the packed weights and project each.
                q_weight, k_weight, v_weight = torch.chunk(
                    self.packed_proj.weight, 3, dim=0,
                )
                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(
                        self.packed_proj.bias, 3, dim=0,
                    )
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = (
                    self.Bproj(F.linear, query, q_weight, q_bias),
                    self.Bproj(F.linear, key, k_weight, k_bias),
                    self.Bproj(F.linear, value, v_weight, v_bias),
                )
        else:
            query = self.Bproj(self.q_proj, query)
            key = self.Bproj(self.k_proj, key)
            value = self.Bproj(self.v_proj, value)

        query, q_eps_shape = self.trans_shape(query)
        key, k_eps_shape = self.trans_shape(key)
        value, v_eps_shape = self.trans_shape(value)

        # BUGFIX: F.scaled_dot_product_attention is a functional API and does
        # not consult module.training — dropout must be disabled explicitly in
        # eval mode, otherwise inference is stochastic.
        dropout_p = self.dropout if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=is_causal,
        )
        attn_output = self.inv_shape(attn_output, q_eps_shape)
        attn_output = self.Bproj(self.out_proj, attn_output)

        if un_flag:
            # Undo the [B, 1, N, E] promotion done at the top.
            return self.Bsqueeze(attn_output)
        return attn_output

    def Bproj(self, proj_func, x, *args):
        """Apply a projection over the last dim with the jagged dim moved next
        to the feature dim (required by NestedTensor linear)."""
        x = x.transpose(1, 2)  # [b, n, x, c]
        x = proj_func(x, *args)
        x = x.transpose(1, 2)  # [b, x, n, c]
        return x

    def Bchunk(self, x, chunks, dim):
        """Chunk along `dim` after moving the jagged dim out of position 1."""
        x = x.transpose(1, 2)  # [b, n, x, c]
        x_l = torch.chunk(x, chunks, dim=dim)
        return [x_.transpose(1, 2) for x_ in x_l]

    def Bsqueeze(self, x):
        """Collapse the singleton t dim: [b, 1, n, c] -> [b, n, c]."""
        x = x.transpose(1, 2)  # [b, n, 1, c]
        x = x.flatten(-2)  # [b, n, c]
        return x

    def trans_shape(self, x):
        """Fold t and heads together for SDPA: [b, t, n, emb] -> [b, t*h, n, E_head].

        Returns the transformed tensor and the size of t (needed by inv_shape).
        """
        x = x.transpose(1, 2)  # [b, n, x, c]
        eps_shape = x.size(-2)
        x = x.unflatten(-1, [self.nheads, self.E_head])  # [b, n, x, h, E_head]
        x = x.flatten(2, -2)  # [b, n, xh, E_head]
        x = x.transpose(1, 2)  # [b, xh, n, E_head]
        return x, eps_shape

    def inv_shape(self, x, eps_shape):
        """Inverse of trans_shape: [b, t*h, n, E_head] -> [b, t, n, emb]."""
        x = x.transpose(1, 2)  # [b, n, xh, c]
        x = x.unflatten(-2, [eps_shape, self.nheads])  # [b, n, x, h, c]
        x = x.flatten(-2)  # [b, n, x, h*c]
        x = x.transpose(1, 2)  # [b, x, n, c]
        return x

Error messages and logs

# Error messages and logs here please
Traceback (most recent call last):
  File "/home/xzy/PycharmProjects/DyG/args.py", line 146, in <module>
    cli_main(
  File "/home/xzy/PycharmProjects/DyG/args.py", line 127, in cli_main
    cli = MyLightningCLI(
          ^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 421, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 759, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 584, in fit
    call._call_and_handle_interrupt(
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 49, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 630, in _fit_impl
    self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1057, in _run
    call._call_callback_hooks(self, "on_fit_start")
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 228, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_summary.py", line 64, in on_fit_start
    model_summary = self._summary(trainer, pl_module)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_summary.py", line 90, in _summary
    return summarize(pl_module, max_depth=self._max_depth)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 565, in summarize
    return ModelSummary(lightning_module, max_depth=max_depth)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 228, in __init__
    self._layer_summary = self.summarize()
                          ^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 328, in summarize
    self._forward_example_input()
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 369, in _forward_example_input
    model(**input_)
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/PycharmProjects/DyG/models.py", line 115, in forward
    return self.model(
           ^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/PycharmProjects/DyG/layers4.py", line 545, in forward
    x = self.block(x, inp_time_emb, node_emb, batch, edge_index, seed, edge_weight)  # [iT, BN, emb_dim]
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torch_geometric.nn.sequential_Sequential_390062_b66vhssz.py", line 21, in forward
    x = self.module_0(x, time_emb, node_emb, batch, edge_index, seed, edge_weight)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/PycharmProjects/DyG/layers4.py", line 383, in forward
    gx = self.global_block(x + node_emb, batch)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/PycharmProjects/DyG/layers4.py", line 303, in forward
    out = self.att(x, x, x)  # [B, t, N_i, C]
          ^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/PycharmProjects/DyG/utils.py", line 346, in forward
    result = self.Bproj(self.packed_proj, query)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/PycharmProjects/DyG/utils.py", line 384, in Bproj
    x = proj_func(x, *args)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py", line 353, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/utils/flop_counter.py", line 767, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_ops.py", line 766, in decompose
    return self._op_dk(dk, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/utils/flop_counter.py", line 767, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
    return self.py_kernels[dk](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 289, in _fn
    result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_decomp/decompositions.py", line 4437, in matmul
    t1_folded = t1.reshape(folded_dim1, sizes_1[-1])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/utils/flop_counter.py", line 767, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_ops.py", line 766, in decompose
    return self._op_dk(dk, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/utils/flop_counter.py", line 772, in __torch_dispatch__
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py", line 325, in __torch_dispatch__
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nested/_internal/ops.py", line 218, in inner
    return func(aten_op, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nested/_internal/ops.py", line 1546, in view_default
    raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
RuntimeError: view(): cannot view shape (8, j1, 12, 64) as [96*j1, 64]

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0): 2.6.0
#- PyTorch Version (e.g., 2.5): 2.6.0+cu124
#- Python version (e.g., 3.12): 3.12.12
#- OS (e.g., Linux): ubuntu
#- CUDA/cuDNN version: 12.4
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

I hope this bug can be fixed in a future release.

cc @ethanwharris

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working); needs triage (Waiting to be triaged by maintainers); ver: 2.5.x

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions