
Dtype mismatch in linear layer #678

Closed · tfogal opened this issue Jun 28, 2024 · 7 comments · Fixed by #705
Labels: amp, nemo (Issues needed to support NVIDIA NeMo models.)

Comments

@tfogal (Collaborator) commented Jun 28, 2024

🐛 Bug

[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/torch/__init__.py", line 4580, in linear
[rank0]:     return prims.linear(a, w, bias)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 264, in __call__
[rank0]:     result = self.meta(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/prims.py", line 3585, in linear_meta
[rank0]:     utils.check(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/baseutils.py", line 103, in check
[rank0]:     raise exception_type(s())
[rank0]: RuntimeError: Expected a.dtype=bfloat16 and w.dtype=float32 to be the same!

Full log of the run that includes the unabbreviated traceback.

To Reproduce

HYDRA_FULL_ERROR=1 \
THUNDER_ANNOTATE_TRACES=1 \
NEMO_THUNDER_NEVA=1 \
python3 ./examples/multimodal/multimodal_llm/neva/neva_pretrain.py trainer.precision=16 model.megatron_amp_O2=False trainer.num_nodes=1 trainer.devices=1 trainer.val_check_interval=10 trainer.limit_val_batches=5 trainer.log_every_n_steps=1 ++exp_manager.max_time_per_run=00:00:03:00 trainer.max_steps=20 model.micro_batch_size=2 model.global_batch_size=4 model.tensor_model_parallel_size=1 model.pipeline_model_parallel_size=1 exp_manager.create_checkpoint_callback=False model.data.data_path=./data/multimodal/tiny-neva/dummy.json model.data.image_folder=./data/multimodal/tiny-neva/images model.tokenizer.library=sentencepiece model.tokenizer.model=./data/multimodal/tiny-neva/tokenizer_add_special.model model.num_layers=2 model.hidden_size=5120 model.ffn_hidden_size=13824 model.num_attention_heads=40 model.normalization=rmsnorm model.data.num_workers=0 model.data.conv_template=llama_2 model.mm_cfg.vision_encoder.from_pretrained=openai/clip-vit-large-patch14 model.mm_cfg.llm.from_pretrained=null model.use_flash_attention=false exp_manager.exp_dir=./foo-neva-train

Note you'll need the referenced ./data directory.

Expected behavior

Environment

$ nvidia-smi | grep -i cuda
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
$ python3 -m pip freeze | egrep -i "(nvfuser)|(lightning)|(thunder)|(nemo)|(megatron)|(torch)"
-e git+ssh://git@github.com/tfogal/lightning.git@8df5db52ead1804f9021bb07caa2d4a7a6ab03a1#egg=lightning
lightning-cloud==0.5.69
-e git+ssh://git@github.com/Lightning-AI/lightning-thunder.git@72e033a0e0dfe44d4770dec2399a9058971003ec#egg=lightning_thunder
lightning-utilities==0.11.2
megatron_core @ file:///home/tfogal/Megatron-LM
-e git+ssh://git@github.com/NVIDIA/NeMo.git@c86449e1a93049d2283ebcea8ee4546f2ea241de#egg=nemo_toolkit
# Editable Git install with no remote (nvfuser==0.2.6+git9c5f006)
-e /opt/pytorch/nvfuser
open-clip-torch==2.24.0
pytorch-lightning==2.3.0
-e git+https://github.com/pytorch/pytorch.git@bd72e28314d8d63bb347becb8309f5ac7761c6b5#egg=torch
torchdiffeq==0.2.4
torchmetrics==1.4.0.post0
torchsde==0.2.6
torchvision @ git+https://github.com/pytorch/vision.git@bf01bab6125c5f1152e4f336b470399e52a8559d
-e git+https://gitlab-ci-token:glcbt-64_VRyDQgDXFf-uV3J9S3gy@gitlab-master.nvidia.com/dl/pytorch/update-scripts.git@5bbcbd6d7aff52c6e6d0b47febe053d4894b3464#egg=zpyt_nightly_ci

cc @crcrpar @tfogal

tfogal added the nemo (Issues needed to support NVIDIA NeMo models.) label on Jun 28, 2024
@t-vi (Collaborator) commented Jun 28, 2024

[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py", line 914, in forward
[rank0]:     return self.mm_projector(x)

We could likely print self.mm_projector.weight.dtype and x.dtype to figure out which dtypes we actually get here.
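
For reference, a minimal sketch of that kind of dtype probe. The DummyAdapter module below is hypothetical, standing in for the NeMo adapter that owns mm_projector; it uses CPU autocast so the sketch runs anywhere, while the print shows the raw (pre-autocast) dtypes.

import torch
import torch.nn as nn

class DummyAdapter(nn.Module):
    """Hypothetical stand-in for the NeMo adapter that owns mm_projector."""
    def __init__(self):
        super().__init__()
        self.mm_projector = nn.Linear(16, 16)  # parameters default to float32

    def forward(self, x):
        # The debug print suggested above: compare input vs. weight dtypes.
        print("x.dtype:", x.dtype, "weight.dtype:", self.mm_projector.weight.dtype)
        return self.mm_projector(x)

m = DummyAdapter()
x = torch.randn(2, 16, dtype=torch.bfloat16)
with torch.autocast("cpu", torch.bfloat16):
    m(x)  # prints bfloat16 vs. float32; eager autocast still reconciles the mix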

kshitij12345 self-assigned this on Jul 1, 2024
@kshitij12345 (Collaborator) commented Jul 1, 2024

I am seeing a different error related to advanced indexing.

[rank0]:   File "thunder/core/proxies.py", line 1333, in __getitem__
[rank0]:     return method(self, key)
[rank0]:   File "thunder/core/symbol.py", line 268, in __call__
[rank0]:     result = self.meta(*args, **kwargs)
[rank0]:   File "thunder/core/langctxs.py", line 132, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "thunder/torch/__init__.py", line 890, in getitem
[rank0]:     return clang.getitem(a, key)
[rank0]:   File "thunder/core/langctxs.py", line 132, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "thunder/clang/__init__.py", line 868, in getitem
[rank0]:     return _advanced_indexing(a, key)
[rank0]:   File "thunder/core/langctxs.py", line 132, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:   File "thunder/clang/__init__.py", line 729, in _advanced_indexing
[rank0]:     utils.check(
[rank0]:   File "thunder/core/baseutils.py", line 103, in check
[rank0]:     raise exception_type(s())
[rank0]: RuntimeError: Advanced indexing currently only supports zero or one-dimensional integer tensors, but found a tensor with dtype int64 and 2 dimensions

Thunder commit used: 72e033a

Full Log: neva.log

@tfogal (Collaborator, Author) commented Jul 1, 2024

Triage review:

  • @tfogal to dig in and see whether the latest Thunder (or 72e033a) reproduces the indexing error or something else
  • we may have a type promotion that PyTorch does not have; let's print some dtypes at various levels to see what's going on here
  • we might also not be mimicking some autocast logic
  • this seems to happen to us a bunch through recurrent paths

tfogal self-assigned this on Jul 1, 2024
@kshitij12345 (Collaborator) commented

I have been able to repro the failure with an independent script. The failure happens due to the interaction of autocast and mixed input dtypes.

import thunder
import torch

def foo(x, w):
    return torch.nn.functional.linear(x, w)

device = torch.device("cuda")
with device:
    # Mixed input types.
    x, w = torch.randn(16, 16, dtype=torch.bfloat16), torch.randn(16, 16)

    # Same input types (works with thunder)
    # x, w = torch.randn(16, 16), torch.randn(16, 16)

    print(x.dtype, w.dtype)
    
with torch.autocast("cuda", torch.bfloat16):
    # Eager autocast handles mixed input types.
    eager_out = foo(x, w)

    # `thunder.jit` doesn't handle mixed inputs.    
    jfoo = thunder.jit(foo)
    jit_out = jfoo(x, w)


print(thunder.last_traces(jfoo)[-1])
torch.testing.assert_close(eager_out, jit_out)

tfogal removed their assignment on Jul 2, 2024
@tfogal (Collaborator, Author) commented Jul 2, 2024

I have been able to repro the failure with an independent script.

Great! Thank you, excellent work :-)

@kshitij12345 (Collaborator) commented

The reason it currently fails is that, while tracing with thunder.jit:

  1. We first try to generate the computation trace based on the actual function/model.
  2. And we apply the autocast transform later on the above computation trace.

With mixed input dtypes, we fail at step 1 as these operators don't allow mixed inputs.

(In eager mode, with the context manager active, the dispatcher first applies the conversion before passing the converted inputs to the operators.)
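
For illustration, here is a rough sketch (not Thunder code) of what that eager conversion effectively does for linear: when autocast is active, floating-point operands are downcast to the autocast dtype before the op ever sees them. The autocast_linear helper below is hypothetical and assumes a CUDA device with bfloat16 autocast, mirroring the repro script above.

import torch
import torch.nn.functional as F

def autocast_linear(a, w, bias=None):
    # Hypothetical model of the eager autocast rule for linear: downcast
    # floating-point operands to the autocast dtype, then call the op on
    # uniformly-typed tensors.
    if torch.is_autocast_enabled():
        dtype = torch.get_autocast_gpu_dtype()
        cast = lambda t: t.to(dtype) if (t is not None and t.is_floating_point()) else t
        a, w, bias = cast(a), cast(w), cast(bias)
    return F.linear(a, w, bias)

with torch.device("cuda"):
    x, w = torch.randn(16, 16, dtype=torch.bfloat16), torch.randn(16, 16)

with torch.autocast("cuda", torch.bfloat16):
    out = autocast_linear(x, w)  # both operands are bfloat16 by the time linear runs
print(out.dtype)  # torch.bfloat16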

Potential Fix:

  1. We should stash the detail that autocast is enabled in CompileData.
  2. Then, while tracing, at the place where we translate torch functions to thunder, we should check whether autocast is active and whether we have a rule for this function.

@t-vi, I would like your opinion on this and some pointers. Thank you!

@t-vi (Collaborator) commented Jul 3, 2024

Great analysis @kshitij12345 !

For 1: We do have autocast handling in thunder.jit and cache_info.

cache_info["is_autocast_enabled"] = is_autocast_enabled

For 2: To my mind, this is a thunder.torch thing more than something specific to jit_ext, so I would probably look at trying to handle it in thunder.torch.torchsymbol:

def __call__(self, fn: Callable) -> Symbol:

WDYT?
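
To make the idea concrete, here is a rough non-Thunder sketch of that kind of torchsymbol-level hook: if autocast was recorded as enabled (e.g. via cache_info), dispatch to a per-function autocast rule when one exists. The registry and the maybe_autocast wrapper are hypothetical, not Thunder's actual API.

import torch
import torch.nn.functional as F

# Hypothetical registry mapping a torch function to its autocast rule.
_AUTOCAST_RULES = {}

def register_autocast_rule(torch_fn):
    def _register(rule):
        _AUTOCAST_RULES[torch_fn] = rule
        return rule
    return _register

@register_autocast_rule(F.linear)
def _linear_autocast_rule(a, w, bias=None, *, autocast_dtype):
    cast = lambda t: t.to(autocast_dtype) if (t is not None and t.is_floating_point()) else t
    return F.linear(cast(a), cast(w), cast(bias))

def maybe_autocast(torch_fn, *args, autocast_enabled, autocast_dtype, **kwargs):
    # Sketch of the check a torchsymbol wrapper could do: if autocast was
    # active at jit time and a rule exists for this function, use the rule;
    # otherwise fall through to the plain translation.
    rule = _AUTOCAST_RULES.get(torch_fn)
    if autocast_enabled and rule is not None:
        return rule(*args, autocast_dtype=autocast_dtype, **kwargs)
    return torch_fn(*args, **kwargs)

x, w = torch.randn(4, 4, dtype=torch.bfloat16), torch.randn(4, 4)
out = maybe_autocast(F.linear, x, w, autocast_enabled=True, autocast_dtype=torch.bfloat16)
print(out.dtype)  # torch.bfloat16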
