-
Notifications
You must be signed in to change notification settings - Fork 3.7k
The example_input_array in LightningModule does not support using NestedTensor with layout=torch.jagged within the model. #21588
Copy link
Copy link
Open
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x
Description
Bug description
The example_input_array in LightningModule does not support using NestedTensor with layout=torch.jagged within the model. Although example_input_array consists of torch.Tensor, an error occurs when I use NestedTensor with layout=torch.jagged in the model. Here I provide an alternative version of the multi-head attention layer that supports NestedTensor. We flatten the other dimensions into the second dimension t, which is an irrelevant dimension. Due to the limitations of NestedTensor, I have to perform these cumbersome dimension transformations.
What version are you seeing the problem on?
master
Reproduced in studio
No response
How to reproduce the bug
class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    All inputs carry an extra "irrelevant" dimension ``t`` that is folded into
    the head dimension around the SDPA call; the helper methods below do the
    transposes needed to work around NestedTensor's limited support for
    reshaping across the jagged dimension.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        emb_dim (int): Total embedding dim of combined heads post input projection. Each head
            has dim emb_dim // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
    """

    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        emb_dim: int,
        nheads: int,
        dropout: float = 0.0,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = { "device": device, "dtype": dtype }
        super().__init__()
        self.nheads = nheads
        self.dropout = dropout
        # When all three embedding dims match, a single packed projection is
        # created so self-attention can run the QKV projection as one matmul.
        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, emb_dim * 3, bias=bias, **factory_kwargs)
        else:
            self.q_proj = nn.Linear(E_q, emb_dim, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_k, emb_dim, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_v, emb_dim, bias=bias, **factory_kwargs)
        # Output projection maps the combined heads back to the query embed dim.
        E_out = E_q
        self.out_proj = nn.Linear(emb_dim, E_out, bias=bias, **factory_kwargs)
        assert emb_dim % nheads == 0, "Embedding dim is not divisible by nheads"
        # NOTE(review): the multiple-of-8 head-dim requirement presumably
        # targets fused SDPA kernel alignment — confirm it is actually needed
        # for the backends in use.
        assert (emb_dim // nheads) % 8 == 0, "emb_dim//nheads dim is not a multiple of 8"
        self.E_head = emb_dim // nheads
        self.bias = bias

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        is_causal=False,
    ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
        1. Apply input projection
        2. Split heads and prepare for SDPA
        3. Run SDPA
        4. Apply output projection

        Args:
            query (torch.Tensor): query of shape [B, t, N_i_q, E_q] NestedTensor, t is an irrelevant dimension.
            key (torch.Tensor): key of shape [B, t, N_i_kv, E_k] NestedTensor, t is an irrelevant dimension.
            value (torch.Tensor): value of shape [B, t, N_i_kv, E_v] NestedTensor, t is an irrelevant dimension.
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): [B, t, N_i_q, E_q]
        """
        # 3-D inputs [B, N_i, E] are promoted to 4-D with t == 1 here and
        # squeezed back via Bsqueeze before returning.
        un_flag = False
        if query.dim() == 3:
            query = query.unflatten(-1, [1, -1]).transpose(1, 2)
            key = key.unflatten(-1, [1, -1]).transpose(1, 2)
            value = value.unflatten(-1, [1, -1]).transpose(1, 2)
            un_flag = True
        if self._qkv_same_embed_dim:
            if query is key and key is value:
                # Self-attention fast path: one packed matmul, then chunk into q/k/v.
                result = self.Bproj(self.packed_proj, query)
                query, key, value = self.Bchunk(result, 3, dim=-1)
            else:
                # Same embed dims but distinct tensors: slice the packed
                # weight/bias into thirds and project each input separately.
                q_weight, k_weight, v_weight = torch.chunk(
                    self.packed_proj.weight, 3, dim=0,
                )
                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(
                        self.packed_proj.bias, 3, dim=0,
                    )
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = (
                    self.Bproj(F.linear, query, q_weight, q_bias),
                    self.Bproj(F.linear, key, k_weight, k_bias),
                    self.Bproj(F.linear, value, v_weight, v_bias),
                )
        else:
            query = self.Bproj(self.q_proj, query)
            key = self.Bproj(self.k_proj, key)
            value = self.Bproj(self.v_proj, value)
        # Fold t and the heads into one dimension so SDPA sees [b, t*h, n, E_head].
        query, q_eps_shape = self.trans_shape(query)
        key, k_eps_shape = self.trans_shape(key)
        value, v_eps_shape = self.trans_shape(value)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout, is_causal=is_causal,
        )
        # Undo the t/head folding using the t-size recorded from the query.
        attn_output = self.inv_shape(attn_output, q_eps_shape)
        attn_output = self.Bproj(self.out_proj, attn_output)
        if un_flag:
            # Drop the synthetic t dimension added at the top of forward.
            return self.Bsqueeze(attn_output)
        return attn_output

    def Bproj(self, proj_func, x, *args):
        """Apply ``proj_func`` to ``x`` with dims 1 and 2 swapped, then swap back.

        Presumably the transpose keeps the jagged dimension in a position the
        NestedTensor linear kernel supports — TODO confirm.
        """
        x = x.transpose(1, 2)  # [b, n, x, c]
        x = proj_func(x, *args)
        x = x.transpose(1, 2)  # [b, x, n, c]
        return x

    def Bchunk(self, x, chunks, dim):
        """Chunk ``x`` along ``dim`` with dims 1 and 2 swapped; swap each chunk back."""
        x = x.transpose(1, 2)  # [b, n, x, c]
        x_l = torch.chunk(x, chunks, dim=dim)
        return [x_.transpose(1, 2) for x_ in x_l]

    def Bsqueeze(self, x):
        """Remove the synthetic size-1 t dimension: [b, 1, n, c] -> [b, n, c]."""
        x = x.transpose(1, 2)  # [b, n, 1, c]
        x = x.flatten(-2)  # [b, n, c]
        return x

    def trans_shape(self, x):
        """Split heads and fold the t dimension into them for SDPA.

        [b, x, n, c] -> [b, x*h, n, E_head]; also returns the size of the t
        ("x") dimension so inv_shape can restore it.
        """
        x = x.transpose(1, 2)  # [b, n, x, c]
        eps_shape = x.size(-2)
        x = x.unflatten(-1, [self.nheads, self.E_head])  # [b, n, x, h, E_head]
        x = x.flatten(2, -2)  # [b, n, xh, E_head]
        x = x.transpose(1, 2)  # [b, xh, n, E_head]
        return x, eps_shape

    def inv_shape(self, x, eps_shape):
        """Inverse of trans_shape: [b, x*h, n, E_head] -> [b, x, n, h*E_head]."""
        x = x.transpose(1, 2)  # [b, n, xh, c]
        x = x.unflatten(-2, [eps_shape, self.nheads])  # [b, n, x, h, c]
        x = x.flatten(-2)  # [b, n, x, h*c]
        x = x.transpose(1, 2)  # [b, x, n, c]
        return x
Error messages and logs
# Error messages and logs here please
Traceback (most recent call last):
File "/home/xzy/PycharmProjects/DyG/args.py", line 146, in <module>
cli_main(
File "/home/xzy/PycharmProjects/DyG/args.py", line 127, in cli_main
cli = MyLightningCLI(
^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 421, in __init__
self._run_subcommand(self.subcommand)
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 759, in _run_subcommand
fn(**fn_kwargs)
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 584, in fit
call._call_and_handle_interrupt(
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 49, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 630, in _fit_impl
self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1057, in _run
call._call_callback_hooks(self, "on_fit_start")
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 228, in _call_callback_hooks
fn(trainer, trainer.lightning_module, *args, **kwargs)
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_summary.py", line 64, in on_fit_start
model_summary = self._summary(trainer, pl_module)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_summary.py", line 90, in _summary
return summarize(pl_module, max_depth=self._max_depth)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 565, in summarize
return ModelSummary(lightning_module, max_depth=max_depth)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 228, in __init__
self._layer_summary = self.summarize()
^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 328, in summarize
self._forward_example_input()
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 369, in _forward_example_input
model(**input_)
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/PycharmProjects/DyG/models.py", line 115, in forward
return self.model(
^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/PycharmProjects/DyG/layers4.py", line 545, in forward
x = self.block(x, inp_time_emb, node_emb, batch, edge_index, seed, edge_weight) # [iT, BN, emb_dim]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/torch_geometric.nn.sequential_Sequential_390062_b66vhssz.py", line 21, in forward
x = self.module_0(x, time_emb, node_emb, batch, edge_index, seed, edge_weight)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/PycharmProjects/DyG/layers4.py", line 383, in forward
gx = self.global_block(x + node_emb, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/PycharmProjects/DyG/layers4.py", line 303, in forward
out = self.att(x, x, x) # [B, t, N_i, C]
^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/PycharmProjects/DyG/utils.py", line 346, in forward
result = self.Bproj(self.packed_proj, query)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/PycharmProjects/DyG/utils.py", line 384, in Bproj
x = proj_func(x, *args)
^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py", line 353, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/utils/flop_counter.py", line 767, in __torch_dispatch__
r = func.decompose(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_ops.py", line 766, in decompose
return self._op_dk(dk, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/utils/flop_counter.py", line 767, in __torch_dispatch__
r = func.decompose(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
return self.py_kernels[dk](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 289, in _fn
result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(out_layer): Sequential(
(out_layer): Sequential(
(out_layer): Sequential(
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_decomp/decompositions.py", line 4437, in matmul
t1_folded = t1.reshape(folded_dim1, sizes_1[-1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/utils/flop_counter.py", line 767, in __torch_dispatch__
r = func.decompose(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_ops.py", line 766, in decompose
return self._op_dk(dk, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/utils/flop_counter.py", line 772, in __torch_dispatch__
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py", line 325, in __torch_dispatch__
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nested/_internal/ops.py", line 218, in inner
return func(aten_op, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xzy/miniconda3/envs/py312cu124/lib/python3.12/site-packages/torch/nested/_internal/ops.py", line 1546, in view_default
raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
RuntimeError: view(): cannot view shape (8, j1, 12, 64) as [96*j1, 64]
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0): 2.6.0
#- PyTorch Version (e.g., 2.5): 2.6.0+cu124
#- Python version (e.g., 3.12): 3.12.12
#- OS (e.g., Linux): ubuntu
#- CUDA/cuDNN version: 12.4
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
I look forward to fixing this bug in the future.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x