
type inference: mismatched dtype in cat operator #750

Closed
tfogal opened this issue Jul 10, 2024 · 4 comments · Fixed by #775
Labels: high priority · nemo (Issues needed to support NVIDIA NeMo models.) · program-coverage (Requests for model and program coverage)

tfogal (Collaborator) commented Jul 10, 2024

🚀 Model / language coverage

The following code results in a

RuntimeError: Expected dtype thunder.dtypes.float32 but found thunder.dtypes.int64_!

error in cat's implementation. It seems thunder ends up confused about the proper dtype of the second tensor.

#!python3
import torch
import thunder
import einops

def foo(input_ids, inputs_embeds):
  batch_size, sequence_length, hidden_size = inputs_embeds.shape

  media_features = torch.randn((2,1,1,256,5120), dtype=torch.float16)
  num_images_per_sample = media_features.size(1)
  num_patches = media_features.size(3) * media_features.size(2)

  media_end_id = 32005
  sorted_media_end_positions_mask, media_end_positions_mask_sort_idx = (
      # NOTE: to(torch.long) is needed because PyTorch does not have sort for boolean tensors on CUDA
      (input_ids == media_end_id).to(torch.long).sort(dim=-1, descending=True, stable=True)
  )

  padded_media_indices = torch.where(
    sorted_media_end_positions_mask.to(torch.bool),
    media_end_positions_mask_sort_idx - num_patches + 1,
    sequence_length
  )
  padded_media_indices = padded_media_indices.unsqueeze(-1) + torch.arange(
    num_patches, device=padded_media_indices.device
  ).repeat(*padded_media_indices.shape, 1)
  padded_media_indices = padded_media_indices.reshape(batch_size, -1)
  padded_media_indices = einops.repeat(padded_media_indices, 'b s -> b s h', h=hidden_size)

  second = torch.zeros((batch_size, num_patches, hidden_size), device=inputs_embeds.device)
  # Note: thunder can be made to work by explicitly setting the dtype:
  #   second = torch.zeros((batch_size, num_patches, hidden_size), dtype=torch.float32, device=inputs_embeds.device)
  #print(f"ii dt:shape={inputs_embeds.dtype}:{inputs_embeds.shape}")
  #print(f"2nd dt:shape={second.dtype}:{second.shape}")
  updated_input_embeds = torch.cat(
    (inputs_embeds, second), dim=1
  )
  return updated_input_embeds

at = torch.zeros((2,384), dtype=torch.int64)
bt = torch.randn((2,384, 5120), dtype=torch.float32)

foo(at, bt)

thfoo = thunder.jit(foo)
thfoo(at, bt)

As the comment on the zeros line indicates, thunder can be coerced into compiling this by explicitly passing a dtype to the zeros call. However, the bug seems to be more global than just zeros, since zeros works perfectly on its own:

>>> def z1(x: torch.Tensor) -> torch.Tensor:
...     return torch.zeros([2,1,2], device=x.device)
...
>>> abc = torch.ones((2, 2, 2))
>>> abc
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
>>> abc.shape
torch.Size([2, 2, 2])
>>> abc.dtype
torch.float32
>>> z1(abc).dtype
torch.float32
>>> th_z1 = thunder.jit(z1)
>>> th_z1(abc).dtype
torch.float32
>>> assert z1(abc).dtype == th_z1(abc).dtype
>>> 

Pitch

This came about while using Nik's patch to try to get #343 working. Nik and I still need to iterate on his patch, so there's no guarantee that this will be the next bug after #124, but it's plausibly a blocker.

cc @apaz-cli @tfogal

@tfogal tfogal added the program-coverage Requests for model and program coverage label Jul 10, 2024
@tfogal tfogal added nemo Issues needed to support NVIDIA NeMo models. triage review high priority labels Jul 10, 2024
t-vi (Collaborator) commented Jul 11, 2024

Note that PyTorch upcasts automatically when given tensors of varying dtypes, while thunder currently errors. When I tried to add this (clumsily) in #41, I seemed to hit some inconsistency between torch eager and torch.compile.
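
For reference, here is a minimal demonstration of the eager-mode promotion behavior described above (plain PyTorch, no thunder involved):

import torch

a = torch.zeros(3, dtype=torch.float32)
b = torch.zeros(3, dtype=torch.int64)

# Eager PyTorch type-promotes the operands, so a mixed-dtype cat succeeds.
print(torch.cat((a, b)).dtype)  # torch.float32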

tfogal (Collaborator, Author) commented Jul 11, 2024

Note that PyTorch upcasts automatically when given tensors of varying dtype while Thunder currently errors.

Ahh, yeah, I suspected something was off there; thanks for the confirmation!

But I think something more insidious is going on here: when run in eager mode, the dtypes match. That is, print(f"dtypes: {inputs_embeds.dtype}, {second.dtype}") prints 'float32' twice in eager mode, but 'float32, int64' under thunder.

If we were to actually do #41, it should get us past this error, but it would end up masking the deeper bug.

@tfogal tfogal changed the title Mismatched dtype in cat operator type inference: mismatched dtype in cat operator Jul 11, 2024
kshitij12345 (Collaborator) commented Jul 15, 2024

The actual issue here is that factory functions like zeros and ones rely on full, which infers its dtype from the fill value when dtype is not passed explicitly:

# Infers dtype from the fill_value when not explicitly provided
if dtype is None:
    dtype = dtypes.numbertype_to_dtype(dtypes.to_dtype(fill_value))
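
The consequence is visible in eager PyTorch itself: zeros consults the global default dtype, while full infers a dtype from its fill value, so routing zeros through full(shape, 0) silently changes the inferred dtype. A quick illustration in plain PyTorch:

import torch

# zeros/ones consult torch.get_default_dtype() when no dtype is given...
print(torch.zeros(2).dtype)       # torch.float32

# ...but full infers its dtype from the fill value: an int fill gives int64.
print(torch.full((2,), 0).dtype)  # torch.int64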

Also, this is hidden during execution with torchex, as it does the correct thing and reads the value from torch.get_default_dtype:

_register_implementation(ltorch.zeros, checker=_always_executable, execution_transform=_zeros_transform)
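
Roughly, the execution transform amounts to the following (a hedged sketch, not thunder's actual _zeros_transform):

import torch

# Lowering to torch.zeros lets PyTorch apply torch.get_default_dtype() when
# dtype is None, even though the trace-time proxy was already (wrongly)
# inferred as int64 -- which is why torchex execution masks the bug.
def zeros_transform_sketch(shape, device=None, dtype=None):
    return torch.zeros(shape, device=device, dtype=dtype)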

Minimal repro (the output is float, but in the trace we can see that the proxy has an integer dtype):

import torch
import thunder

def foo(x: torch.Tensor) -> torch.Tensor:
    o = torch.zeros((2,1,2), device=x.device)
    return o

jfoo = thunder.jit(foo)
o = jfoo(torch.randn(3, 3))
print(o.dtype)
print(thunder.last_traces(jfoo)[0])

Output

torch.float32

import thunder
import thunder.core.devices as devices
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  # /home/kkalambarkar/lightning-thunder/scratchpad/test.py:63:             o = torch.zeros((2,1,2), device=x.device)
  o = ltorch.zeros((2, 1, 2), device=devices.Device("cpu"), dtype=None)  # o: "cpu i64[2, 1, 2]"
    # o = ltorch.full((2, 1, 2), 0, device=devices.Device("cpu"), dtype=None)  # o: "cpu i64[2, 1, 2]"
      # o = prims.full((2, 1, 2), 0, device=devices.Device("cpu"), dtype=dtypes.int64_)  # o: "cpu i64[2, 1, 2]"
  return o

I think this is a duplicate of #621

tfogal (Collaborator, Author) commented Jul 15, 2024

triage review:

  • subtlety: do we want to record the default dtype at call time and re-trace when it changes? (Put it in CacheInfo and record dtypes in calls to factory primitives.) Yes, there is agreement that this is the way to go; see the sketch below.
  • another option is to query the default dtype during execution, but that would be very hard because dtypes are constexpr today.
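
A hypothetical sketch of the agreed direction (the names are illustrative, not thunder's actual API): capture torch.get_default_dtype() at trace time, store it alongside the cache metadata, and retrace when it no longer matches.

import torch

class CacheInfoSketch:
    """Illustrative stand-in for trace cache metadata, not thunder's real CacheInfo."""

    def __init__(self):
        # Record the default dtype in effect when the trace was built.
        self.default_dtype = torch.get_default_dtype()

    def is_valid(self) -> bool:
        # If the global default dtype changed since tracing, the cached
        # trace must be discarded and the function retraced.
        return torch.get_default_dtype() == self.default_dtype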
