NeVa - ValueError: <object> had an unexpected type <class 'object'> #717

Closed
kshitij12345 opened this issue Jul 5, 2024 · 9 comments · Fixed by #718
Labels: nemo (Issues needed to support NVIDIA NeMo models.)
kshitij12345 commented Jul 5, 2024

NOTE: For a minimal repro, see the comment below.

[rank0]:   File "/opt/pytorch/lightning-thunder/NeMo/nemo/collections/nlp/modules/common/megatron/language_model.py", line 348, in forward
[rank0]:     words_embeddings = self.word_embeddings(input_ids)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1714, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1725, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/opt/pytorch/lightning-thunder/NeMo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 155, in forward
[rank0]:     return self.replace_media_embeddings(input_ids, words_embeddings, media)
[rank0]:   File "/opt/pytorch/lightning-thunder/NeMo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 206, in replace_media_embeddings
[rank0]:     for idx, input_id in enumerate(input_ids):
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 4059, in _next_impl
[rank0]:     return next(tos)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1862, in impl
[rank0]:     return iterator.__next__()
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 6188, in _call_dispatch
[rank0]:     opaque_result: Any = fn(*args_, **kwargs_)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 5887, in thunder_interpreter_generator
[rank0]:     raise InterpreterError(msg) from e
[rank0]: thunder.core.interpreter.InterpreterError: Encountered exception ValueError: <object object at 0x7f433cba2d80> had an unexpected type <class 'object'>. Supported types are (<class 'thunder.core.proxies.TensorProxy'>, <class 'numbers.Number'>, <class 'thunder.core.proxies.NumberProxy'>)

Full Log - error.log

(Steps to repro are the same as in #678 and copied from there, except for the addition of the megatron_core commit details in the environment.)
To repro:

HYDRA_FULL_ERROR=1 \
THUNDER_ANNOTATE_TRACES=1 \
NEMO_THUNDER_NEVA=1 \
python3 ./examples/multimodal/multimodal_llm/neva/neva_pretrain.py trainer.precision=16 model.megatron_amp_O2=False trainer.num_nodes=1 trainer.devices=1 trainer.val_check_interval=10 trainer.limit_val_batches=5 trainer.log_every_n_steps=1 ++exp_manager.max_time_per_run=00:00:03:00 trainer.max_steps=20 model.micro_batch_size=2 model.global_batch_size=4 model.tensor_model_parallel_size=1 model.pipeline_model_parallel_size=1 exp_manager.create_checkpoint_callback=False model.data.data_path=./data/multimodal/tiny-neva/dummy.json model.data.image_folder=./data/multimodal/tiny-neva/images model.tokenizer.library=sentencepiece model.tokenizer.model=./data/multimodal/tiny-neva/tokenizer_add_special.model model.num_layers=2 model.hidden_size=5120 model.ffn_hidden_size=13824 model.num_attention_heads=40 model.normalization=rmsnorm model.data.num_workers=0 model.data.conv_template=llama_2 model.mm_cfg.vision_encoder.from_pretrained=openai/clip-vit-large-patch14 model.mm_cfg.llm.from_pretrained=null model.use_flash_attention=false exp_manager.exp_dir=./foo-neva-train

Note you'll need the referenced ./data directory; ping @tfogal privately for now.

Environment

$ nvidia-smi | grep -i cuda
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
$ python3 -m pip freeze | egrep -i "(nvfuser)|(lightning)|(thunder)|(nemo)|(megatron)|(torch)"
-e git+ssh://git@github.com/tfogal/lightning.git@8df5db52ead1804f9021bb07caa2d4a7a6ab03a1#egg=lightning
lightning-cloud==0.5.69
-e git+ssh://git@github.com/Lightning-AI/lightning-thunder.git@72e033a0e0dfe44d4770dec2399a9058971003ec#egg=lightning_thunder
lightning-utilities==0.11.2
megatron_core @ git+https://github.com/NVIDIA/Megatron-LM.git@e33c8f78a35765d5aa37475a144da60e8a2349d1
-e git+ssh://git@github.com/NVIDIA/NeMo.git@c86449e1a93049d2283ebcea8ee4546f2ea241de#egg=nemo_toolkit
# Editable Git install with no remote (nvfuser==0.2.6+git9c5f006)
-e /opt/pytorch/nvfuser
open-clip-torch==2.24.0
pytorch-lightning==2.3.0
-e git+https://github.com/pytorch/pytorch.git@bd72e28314d8d63bb347becb8309f5ac7761c6b5#egg=torch
torchdiffeq==0.2.4
torchmetrics==1.4.0.post0
torchsde==0.2.6
torchvision @ git+https://github.com/pytorch/vision.git@bf01bab6125c5f1152e4f336b470399e52a8559d
-e git+https://gitlab-ci-token:glcbt-64_VRyDQgDXFf-uV3J9S3gy@gitlab-master.nvidia.com/dl/pytorch/update-scripts.git@5bbcbd6d7aff52c6e6d0b47febe053d4894b3464#egg=zpyt_nightly_ci

cc @tfogal

kshitij12345 added the nemo label on Jul 5, 2024
kshitij12345 commented:

Minimal Repro:

import torch
import thunder

def f(ids):
    for ix, t in enumerate(ids):
        pass

    return ids

jf = thunder.jit(f)

ids = torch.randn(2, 2).to(torch.long)
jf(ids)
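For context, the behavior the jitted function needs to reproduce is eager PyTorch's: iterating a tensor yields its slices along dim 0, which is exactly what torch.unbind(t, 0) returns. A quick eager-mode check (a sketch, plain PyTorch only):

import torch

t = torch.arange(6).reshape(3, 2)

# Iterating a tensor in eager mode yields slices along dim 0...
rows = list(t)

# ...which matches what torch.unbind(t, 0) returns.
assert all(torch.equal(a, b) for a, b in zip(rows, torch.unbind(t, 0)))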

t-vi commented Jul 5, 2024

Great repro. I think we would want Tensor.__iter__ here. I'm not 100% sure what the best strategy is, given that iterators are not first-class objects; one approach might be:

  • add a torch.Tensor.__iter__ lookaside to jit_ext.py,
  • in an impl:
    • use torch.unbind or torch.split to get a list,
    • iterate over the list.

Does that make sense?
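A minimal standalone sketch of that strategy in plain PyTorch (the actual lookaside registration in jit_ext.py is attempted in the patch below):

import torch

def _tensor_iter_impl(t):
    # Materialize the slices with unbind, then yield them one by one,
    # so only tensors (not iterator objects) need to appear in a trace.
    for t_slice in torch.unbind(t, 0):
        yield t_slice

x = torch.rand(3, 2)
for s in _tensor_iter_impl(x):
    print(s.shape)  # torch.Size([2])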

kshitij12345 self-assigned this on Jul 5, 2024
kshitij12345 commented Jul 5, 2024

I tried adding a lookaside for torch.Tensor.__iter__ with the following patch:

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index ac9e127..2da42d7 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -890,6 +890,18 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs):
         model, model.named_buffers, model.get_buffer, *unwrapped_args, **unwrapped_kwargs
     )
 
+@general_jit_lookaside(torch.Tensor.__iter__)
+def _general_tensor_iter_lookaside(obj: Any, *args, **kwargs):
+
+    # NOTE: This will be interpreted.
+    def _tensor_iter_impl(t):
+        for t_slice in t.unbind():
+            yield t_slice
+
+    pr = ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[wrap_const(torch.Tensor.__iter__).provenance])
+
+    return _interpret_call(_tensor_iter_impl, wrap(unwrap(obj), pr))
+
 
 @general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
 def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):

But it never gets called.

On printing what iter_lookaside receives, I see that it gets a TensorProxy, which then ends up calling __getitem__ (which does something unintended) since it doesn't have __iter__.
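For reference, this is standard Python behavior rather than anything Thunder-specific: when a class defines __getitem__ but not __iter__, iter() falls back to the legacy sequence protocol and indexes from 0 until IndexError, which is how the TensorProxy ends up in __getitem__. A tiny illustration:

class Seq:
    # No __iter__ defined: iter(Seq()) falls back to the legacy
    # sequence protocol, calling __getitem__(0), __getitem__(1), ...
    # until IndexError is raised.
    def __getitem__(self, i):
        if i >= 3:
            raise IndexError(i)
        return i * 10

print(list(Seq()))  # [0, 10, 20]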

Adding __iter__ to TensorProxy with the following patch works (I think other iterable proxies like TupleProxy and ListProxy work because they inherit from tuple and list, which gives them __iter__):

diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index df6dce4..b440091 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -1340,6 +1340,9 @@ class TensorProxy(Proxy, TensorProxyInterface):
         method = resolve_method("getitem", self, key)
         return method(self, key)
 
+    def __iter__(self):
+        return iter(self.unbind())
+
     #
     # Elementwise unary operators
     #

What do you think about this (or maybe there is still a way with a lookaside)?

nikitaved commented Jul 5, 2024

Oops, sorry, @kshitij12345, I accidentally stepped on your toes... But I am not sure about my approach. Is it fine to have a lookaside for TensorProxies?

kshitij12345 commented:
Ah, no worries, as long as the issue is fixed :) Thanks for looking into this.

Is it fine to have a lookaside for TensorProxies?

@t-vi what are your thoughts?

t-vi commented Jul 5, 2024

Seems good to have one. The reason to do it as a lookaside is to avoid handling iterator objects in the trace.

t-vi commented Jul 5, 2024

But yeah, if defining __iter__ on the TensorProxy works, let's just have that; I think it gives the same result, just with a slightly different execution model. We should comment the __iter__ method. What is the trace we get from that?
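One way to see the trace (a sketch; assumes the proxies.py patch above is applied, and uses thunder.last_traces to pull the final computation trace):

import torch
import thunder

def f(x):
    res = x
    for xi in x:  # dispatches to TensorProxy.__iter__, i.e. iter(self.unbind())
        res = res + xi.unsqueeze(0)
    return res

jf = thunder.jit(f)
jf(torch.rand(3, 2, 2))

# The last trace is the final computation trace (output in the next comment).
print(thunder.last_traces(jf)[-1])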

nikitaved commented Jul 5, 2024

@t-vi, something like this:

In [8]: def f(x):
   ...:     res = x
   ...:     for xi in x:
   ...:         res = res + xi.unsqueeze(0)
   ...:     return res

f(torch.rand(3, 2, 2))

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cpu f32[3, 2, 2]"
  (xi, t4, t5) = torch.unbind(x, 0)
    # (xi, t4, t5) = ltorch.unbind(x, 0)
      # (t15, t16, t17) = ltorch.tensor_split(x, 3, 0)
        # t15 = prims.slice_prim(x, [0, 0, 0], [1, 2, 2], [1, 1, 1])  # t15: "cpu f32[1, 2, 2]"
        # t16 = prims.slice_prim(x, [1, 0, 0], [2, 2, 2], [1, 1, 1])  # t16: "cpu f32[1, 2, 2]"
        # t17 = prims.slice_prim(x, [2, 0, 0], [3, 2, 2], [1, 1, 1])  # t17: "cpu f32[1, 2, 2]"
      # xi = ltorch.squeeze(t15, 0)  # xi: "cpu f32[2, 2]"
        # xi = prims.squeeze(t15, (0,))  # xi: "cpu f32[2, 2]"
      # t4 = ltorch.squeeze(t16, 0)  # t4: "cpu f32[2, 2]"
        # t4 = prims.squeeze(t16, (0,))  # t4: "cpu f32[2, 2]"
      # t5 = ltorch.squeeze(t17, 0)  # t5: "cpu f32[2, 2]"
        # t5 = prims.squeeze(t17, (0,))  # t5: "cpu f32[2, 2]"
  b = torch.unsqueeze(xi, 0)  # b: "cpu f32[1, 2, 2]"
    # b = ltorch.unsqueeze(xi, 0)  # b: "cpu f32[1, 2, 2]"
      # b = prims.broadcast_in_dim(xi, [1, 2, 2], [1, 2])  # b: "cpu f32[1, 2, 2]"
  del xi
  t9 = torch.unsqueeze(t4, 0)  # t9: "cpu f32[1, 2, 2]"
    # t9 = ltorch.unsqueeze(t4, 0)  # t9: "cpu f32[1, 2, 2]"
      # t9 = prims.broadcast_in_dim(t4, [1, 2, 2], [1, 2])  # t9: "cpu f32[1, 2, 2]"
  del t4
  t12 = torch.unsqueeze(t5, 0)  # t12: "cpu f32[1, 2, 2]"
    # t12 = ltorch.unsqueeze(t5, 0)  # t12: "cpu f32[1, 2, 2]"
      # t12 = prims.broadcast_in_dim(t5, [1, 2, 2], [1, 2])  # t12: "cpu f32[1, 2, 2]"
  del t5
  result = torch.add(x, b)  # result: "cpu f32[3, 2, 2]"
    # result = ltorch.add(x, b, alpha=None)  # result: "cpu f32[3, 2, 2]"
      # t22 = prims.broadcast_in_dim(b, (3, 2, 2), (0, 1, 2))  # t22: "cpu f32[3, 2, 2]"
      # result = prims.add(x, t22)  # result: "cpu f32[3, 2, 2]"
  del x, b
  res = torch.add(result, t9)  # res: "cpu f32[3, 2, 2]"
    # res = ltorch.add(result, t9, alpha=None)  # res: "cpu f32[3, 2, 2]"
      # t25 = prims.broadcast_in_dim(t9, (3, 2, 2), (0, 1, 2))  # t25: "cpu f32[3, 2, 2]"
      # res = prims.add(result, t25)  # res: "cpu f32[3, 2, 2]"
  del result, t9
  t14 = torch.add(res, t12)  # t14: "cpu f32[3, 2, 2]"
    # t14 = ltorch.add(res, t12, alpha=None)  # t14: "cpu f32[3, 2, 2]"
      # t28 = prims.broadcast_in_dim(t12, (3, 2, 2), (0, 1, 2))  # t28: "cpu f32[3, 2, 2]"
      # t14 = prims.add(res, t28)  # t14: "cpu f32[3, 2, 2]"
  del res, t12
  return t14

...

t-vi commented Jul 5, 2024

I'd say not overly pretty but not too terrible.
