
NotImplementedError: VJP for PrimIDs.ITEM is not implemented for ThunderFX and litGPT models #1479

Closed · mpatel31415 opened this issue Nov 26, 2024 · 3 comments · Fixed by #1481
Labels: autograd, mixology (Issues that the mixology team has surfaced)

@mpatel31415 (Contributor) commented Nov 26, 2024

🐛 Bug

When benchmarking Llama-3-8B, Mistral-7B-v0.1, and Phi-3-mini-4k-instruct with ThunderFX we get the following error:

6: [rank6]: Traceback (most recent call last):
6: [rank6]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 980, in
6: [rank6]: CLI(benchmark_main)
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 96, in CLI
6: [rank6]: return _run_component(components, init)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 204, in _run_component
6: [rank6]: return component(**cfg)
6: [rank6]: ^^^^^^^^^^^^^^^^
6: [rank6]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 877, in benchmark_main
6: [rank6]: benchmark.train()
6: [rank6]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 748, in train
6: [rank6]: logits = self.model(input_ids)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
6: [rank6]: return self._call_impl(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
6: [rank6]: return forward_call(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 573, in _fn
6: [rank6]: return fn(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py", line 31, in inner
6: [rank6]: return fn(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
6: [rank6]: return self._call_impl(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
6: [rank6]: return forward_call(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 864, in forward
6: [rank6]: output = self._fsdp_wrapped_module(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
6: [rank6]: return self._call_impl(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
6: [rank6]: return forward_call(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/litgpt/model.py", line 94, in forward
6: [rank6]: x = block(x, cos, sin, mask, input_pos)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
6: [rank6]: return self._call_impl(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
6: [rank6]: return forward_call(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 864, in forward
6: [rank6]: output = self._fsdp_wrapped_module(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
6: [rank6]: return self._call_impl(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
6: [rank6]: return forward_call(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/litgpt/model.py", line 167, in forward
6: [rank6]: def forward(
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
6: [rank6]: return self._call_impl(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
6: [rank6]: return forward_call(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 744, in _fn
6: [rank6]: return fn(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
6: [rank6]: return self._wrapped_call(self, *args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in call
6: [rank6]: raise e
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 387, in call
6: [rank6]: return super(self.cls, obj).call(*args, **kwargs) # type: ignore[misc]
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
6: [rank6]: return self._call_impl(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
6: [rank6]: return forward_call(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "<eval_with_key>.14", line 5, in forward
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
6: [rank6]: return self._call_impl(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
6: [rank6]: return forward_call(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/module.py", line 80, in forward
6: [rank6]: res = self.forward_fn(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/init.py", line 774, in wrapped
6: [rank6]: return fn(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/init.py", line 824, in fn

6: [rank6]: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/init.py", line 756, in wrapped
6: [rank6]: cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
6: [rank6]: result = fn(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/init.py", line 236, in cache_info_wrapper
6: [rank6]: res = fn(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/init.py", line 659, in get_computation_and_inputs
6: [rank6]: computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 156, in split_forward_backward
6: [rank6]: fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 3016, in forward_and_backward_from_trace
6: [rank6]: forward_trace, result, env = augmented_forward_pass_trace(trace, *trace.args, **trace.kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 2632, in augmented_forward_pass_trace
6: [rank6]: trace, result, env = interpret_trace_to_trace(
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/trace_interpreter.py", line 168, in interpret_trace_to_trace
6: [rank6]: result = prim_func(*args, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 2545, in _vjp_impl
6: [rank6]: out_primal, out_residuals = vjp_impl(*primals, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 2319, in decomposed_fn_aug_fwd_rule
6: [rank6]: result, env = augmented_forward_pass(*args, trace=trace, **kwargs)
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 2604, in augmented_forward_pass
6: [rank6]: result, env = eval_trace(
6: [rank6]: ^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/trace_interpreter.py", line 63, in interpret_trace
6: [rank6]: prim_func = symbol_mapper(symbol) if symbol_mapper is not None else symbol.sym
6: [rank6]: ^^^^^^^^^^^^^^^^^^^^^
6: [rank6]: File "/opt/pytorch/lightning-thunder/thunder/core/transforms.py", line 2541, in vjp_symbol_mapper
6: [rank6]: raise NotImplementedError(f"VJP for {symbol.sym.id} is not implemented")
6: [rank6]: NotImplementedError: VJP for PrimIDs.ITEM is not implemented

To Reproduce

Please use:
- 1 GPU (H100)
- Image "INTERNAL_IMAGE:pjnl-20241125"
- Training script:

python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mistral-7B-v0.1 \
    --compile dynamo_thunder \
    --micro_batch_size 1

Expected behavior

The benchmark should run to completion without raising NotImplementedError.

Environment

As in the pjnl image

cc @tfogal

@IvanYashchuk added the nemo (Issues needed to support NVIDIA NeMo models), mixology (Issues that the mixology team has surfaced), and autograd labels on Nov 26, 2024
@IvanYashchuk (Collaborator) commented:

Here's a simple reproducer for the problem:

import torch
import thunder

@thunder.jit
def f(x):
    return x.item()

a = torch.randn(1, requires_grad=True)
f(a)

Running this fails with:
File ~/dev/lightning-thunder/thunder/core/trace_interpreter.py:63, in interpret_trace(trace, symbol_mapper, with_env, *args, **kwargs)
     61 args = tree_map(read, symbol.args)
     62 kwargs = tree_map(read, symbol.kwargs)
---> 63 prim_func = symbol_mapper(symbol) if symbol_mapper is not None else symbol.sym
     64 if prim_func is None:
     65     continue

File ~/dev/lightning-thunder/thunder/core/transforms.py:2541, in vjp_symbol_mapper(symbol, *args, **kwargs)
   2539         if symbol.sym.id == "torch.nn.functional.dropout":
   2540             return None
-> 2541         raise NotImplementedError(f"VJP for {symbol.sym.id} is not implemented")
   2543 def _vjp_impl(*args, **kwargs):
   2544     primals, kwargs = tree_map(lambda x: x.primal if isinstance(x, VJPDual) else x, (args, kwargs))

NotImplementedError: VJP for PrimIDs.ITEM is not implemented
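
Note that even in eager PyTorch, .item() converts a one-element tensor into a plain Python number, which is detached from autograd, so there is no gradient to propagate through it. A quick eager-mode illustration (plain PyTorch, not Thunder-specific):

import torch

x = torch.randn(1, requires_grad=True)
y = x.item()    # plain Python float, detached from the autograd graph
print(type(y))  # <class 'float'> -- no grad_fn, nothing for backward to use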

A "grad rule" is missing for thunder.prims.item.

@mpatel31415 (Contributor, Author) commented:

Just to let you know: this issue was discovered in canary runs of LitGPT models. I'm not sure it's related to NeMo, as the label suggests.

@wprazuch (Contributor) commented Nov 27, 2024

To chip in: this problem also occurs for the SFT benchmarks, and it shadows issue #1482 in the newer pjnl container release (tested on pjnl-20241126). It also breaks for Phi-3. We don't have NeMo dependencies in the SFT script, so it's probably unrelated to NeMo.
