
Dtype mismatch in cat: bfloat16 and float16 #812

Closed
tfogal opened this issue Jul 20, 2024 · 1 comment · Fixed by #819
Labels
nemo Issues needed to support NVIDIA NeMo models. program-coverage Requests for model and program coverage
tfogal commented Jul 20, 2024

🚀 Model / language coverage

First, I applied this diff to thunder:

diff --git a/thunder/core/utils.py b/thunder/core/utils.py
index 271dcdf3..346f264e 100644
--- a/thunder/core/utils.py
+++ b/thunder/core/utils.py
@@ -237,6 +237,10 @@ def check_same_dtype(*args):
             if dtype is None:
                 dtype = typ

+            if not are_same_dtypes(dtype, typ):
+                import traceback
+                print(f"mismatched types: {dtype}, {typ}")
+                traceback.print_stack()
             check(
                 are_same_dtypes(dtype, typ),
                 lambda: f"Expected dtype {dtype} but found {typ}!",

This diff was necessary to produce the beginning of the output below, which shows that a cat operation is at fault:

  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/torch/__init__.py", line 812, in cat
    return clang.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/clang/__init__.py", line 1289, in cat
    return prims.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 272, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/prims.py", line 2983, in cat_meta
    utils.check_same_dtype(*tensors)
  File "/home/tfogal/dev/thunder/thunder/core/utils.py", line 243, in check_same_dtype
    traceback.print_stack()
mismatched types: thunder.dtypes.bfloat16, thunder.dtypes.float16
Error executing job with overrides: ['trainer.precision=16', 'model.megatron_amp_O2=False', 'trainer.num_nodes=1', 'trainer.devices=1', 'trainer.val_check_interval=10', 'trainer.limit_val_batches=5', 'trainer.log_every_n_steps=1', '++exp_manager.max_time_per_run=00:00:03:00', 'trainer.max_steps=20', 'model.micro_batch_size=2', 'model.global_batch_size=4', 'model.tensor_model_parallel_size=1', 'model.pipeline_model_parallel_size=1', 'exp_manager.create_checkpoint_callback=False', 'model.data.data_path=./data/multimodal/tiny-neva/dummy.json', 'model.data.image_folder=./data/multimodal/tiny-neva/images', 'model.tokenizer.library=sentencepiece', 'model.tokenizer.model=./data/multimodal/tiny-neva/tokenizer_add_special.model', 'model.num_layers=2', 'model.hidden_size=5120', 'model.ffn_hidden_size=13824', 'model.num_attention_heads=40', 'model.normalization=rmsnorm', 'model.data.num_workers=0', 'model.data.conv_template=llama_2', 'model.mm_cfg.vision_encoder.from_pretrained=openai/clip-vit-large-patch14', 'model.mm_cfg.llm.from_pretrained=null', 'model.use_flash_attention=false', 'exp_manager.exp_dir=./foo-neva-train']
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/tfogal/dev/nemo/./examples/multimodal/multimodal_llm/neva/neva_pretrain.py", line 51, in <module>
[rank0]:     main()
[rank0]:   File "/home/tfogal/dev/nemo/nemo/core/config/hydra_runner.py", line 129, in wrapper
[rank0]:     _run_hydra(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
[rank0]:     _run_app(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
[rank0]:     run_and_report(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
[rank0]:     raise ex
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
[rank0]:     return func()
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
[rank0]:     lambda: hydra.run(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
[rank0]:     _ = ret.return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
[rank0]:     raise self._return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
[rank0]:     ret.return_value = task_function(task_cfg)
[rank0]:   File "/home/tfogal/dev/nemo/./examples/multimodal/multimodal_llm/neva/neva_pretrain.py", line 45, in main
[rank0]:     trainer.fit(model)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
[rank0]:     results = self._run_stage()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1028, in _run_stage
[rank0]:     self._run_sanity_check()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1057, in _run_sanity_check
[rank0]:     val_loop.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
[rank0]:     return loop_run(self, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
[rank0]:     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
[rank0]:     output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 410, in validation_step
[rank0]:     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
[rank0]:     wrapper_output = wrapper_module(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1640, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1456, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
[rank0]:     out = method(*_args, **_kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 897, in validation_step
[rank0]:     return MegatronGPTModel.validation_step(self, dataloader_iter)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1370, in validation_step
[rank0]:     loss = self.fwd_bwd_step(dataloader_iter, True, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 665, in fwd_bwd_step
[rank0]:     return MegatronGPTModel.fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 684, in fwd_bwd_step
[rank0]:     losses_reduced_per_micro_batch = fwd_bwd_function(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 395, in forward_backward_no_pipelining
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 219, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 832, in fwd_output_and_loss_func
[rank0]:     output_tensor = model(**forward_args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/module.py", line 61, in forward
[rank0]:     res = self._forward_fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 683, in fn_
[rank0]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 225, in cache_info_wrapper
[rank0]:     res = fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
[rank0]:     jit_results: TraceResults = interpreter(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 213, in _general_frontend
[rank0]:     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
[rank0]:     result = jfn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6769, in fn_
[rank0]:     raise e
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6737, in fn_2
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 470, in forward
[rank0]:     result = GPTModel.forward(self, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py", line 280, in forward
[rank0]:     lm_output = self.language_model(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py", line 764, in forward
[rank0]:     encoder_input = self.embedding(enc_input_ids, enc_position_ids, token_type_ids=token_type_ids)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py", line 348, in forward
[rank0]:     words_embeddings = self.word_embeddings(input_ids)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 155, in forward
[rank0]:     return self.replace_media_embeddings(input_ids, words_embeddings, media)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 195, in replace_media_embeddings
[rank0]:     media_features = self.encode_vision_x(media)  # b T F S(eq) H(idden)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 176, in encode_vision_x
[rank0]:     vision_x = self.vision_encoder(vision_x, output_hidden_states=True)
[rank0]: RuntimeError: Expected dtype thunder.dtypes.bfloat16 but found thunder.dtypes.float16!

Full log of the run

Instructions on how to run NeVA are in #343.

Pitch

This is for the NeVA model #343 .

Alternatives / Potential work-arounds

It seems like our cat checks are too stringent, in that torch allows mismatched dtypes here:

>>> a = torch.randn((5,3), dtype=torch.bfloat16)
>>> b = torch.randn((2,3), dtype=torch.float16)
>>> c = torch.cat((a,b), dim=0)
>>> c
tensor([[ 1.7734,  0.4414, -0.3086],
        [-0.4453, -2.2969, -0.2129],
        [-0.6680, -1.3984, -0.0649],
        [ 0.0242, -0.6875,  0.4277],
        [-0.9141,  0.6367,  0.3828],
        [ 1.0635, -0.4417, -0.6030],
        [ 0.5215, -0.6226,  0.9912]])

I suppose torch semantics are to cast each type to the first type?
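A quick check (this snippet is my own verification sketch, not part of the repro) suggests otherwise: `torch.result_type` exposes torch's type-promotion lattice, and mixed bfloat16/float16 inputs promote to float32 rather than being cast to the first input's dtype:

```python
import torch

# torch's type-promotion rules, not "cast to first dtype", decide the
# output dtype of cat with mixed inputs; result_type exposes the lattice.
a = torch.empty((5, 3), dtype=torch.bfloat16)
b = torch.empty((2, 3), dtype=torch.float16)

print(torch.result_type(a, b))            # torch.float32
print(torch.cat((a, b), dim=0).dtype)     # torch.float32
```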

Note this is very similar to #750. There, the error merely surfaced in cat even though it originated earlier; here we are hitting the cat dtype check through a different code path.

Minimal Repro

$ cat cat-dtype.py
import torch
import thunder

def foo():
  x = torch.randn((5,3), dtype=torch.bfloat16)
  y = torch.randn((2,3), dtype=torch.float16)
  z = torch.cat((x,y), dim=0)
  return z

foo()
thfoo = thunder.jit(foo)
thfoo()
$ python3 cat-dtype.py
Traceback (most recent call last):
  File "/tmp/cat-dtype.py", line 12, in <module>
    thfoo()
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 683, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 225, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 213, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6769, in fn_
    raise e
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6737, in fn_2
    return fn(*args, **kwargs)
  File "/tmp/cat-dtype.py", line 7, in foo
    z = torch.cat((x,y), dim=0)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 1272, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 704, in wrapper
    return fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 276, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/torch/__init__.py", line 812, in cat
    return clang.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/clang/__init__.py", line 1289, in cat
    return prims.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 272, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/prims.py", line 2983, in cat_meta
    utils.check_same_dtype(*tensors)
  File "/home/tfogal/dev/thunder/thunder/core/utils.py", line 240, in check_same_dtype
    check(
  File "/home/tfogal/dev/thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
RuntimeError: Expected dtype thunder.dtypes.bfloat16 but found thunder.dtypes.float16!

cc @tfogal

@tfogal tfogal added nemo Issues needed to support NVIDIA NeMo models. program-coverage Requests for model and program coverage labels Jul 20, 2024
t-vi commented Jul 20, 2024

I suppose torch semantics are to cast each type to the first type?

No, I think they use upcasting (in the above, bf16 and fp16 give fp32).
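A minimal sketch of what a promotion-aware check could look like (hypothetical code, not thunder's actual `cat_meta`; the helper name `promoted_cat_dtype` is made up): instead of requiring identical dtypes, the meta-function could fold `torch.promote_types` over the inputs and use the resulting common dtype for the output.

```python
import torch

def promoted_cat_dtype(tensors):
    # Hypothetical: compute the common dtype for cat via torch's
    # promotion lattice instead of asserting all dtypes are equal.
    dtype = tensors[0].dtype
    for t in tensors[1:]:
        dtype = torch.promote_types(dtype, t.dtype)
    return dtype

a = torch.empty((5, 3), dtype=torch.bfloat16)
b = torch.empty((2, 3), dtype=torch.float16)
print(promoted_cat_dtype([a, b]))  # torch.float32
```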
