
support inferring dtype with torch.get_default_dtype for factory functions #775


Merged: 3 commits into Lightning-AI:main on Jul 16, 2024

Conversation

@kshitij12345 (Collaborator) commented on Jul 16, 2024

Fixes: #750

Changes -

  1. Stash torch.get_default_dtype in cache_info. This also adds a check to the prologue trace to verify that the jitted function is called with the same default dtype (see the example prologue below).
  2. Factory functions infer the dtype from the torch.get_default_dtype value stashed in cache_info (if dtype is not passed explicitly).
  3. We don't support changing the default dtype inside the jitted function, as reordering and fusion can lead to issues; it is a loud error for now (we can revisit in a follow-up if required). See the sketch after this list.
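
As a rough illustration of point 3, a minimal sketch of the unsupported pattern (the exact exception type and message are assumptions, not taken from this PR):

import torch
import thunder

def bar(x: torch.Tensor) -> torch.Tensor:
    # Mutating the global default dtype inside a jitted function is rejected.
    torch.set_default_dtype(torch.float64)
    return torch.zeros(x.shape)

jbar = thunder.jit(bar)
jbar(torch.randn(3, 3))  # expected to fail loudly rather than trace through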

Repro:

import torch
import thunder

def foo(x: torch.Tensor) -> torch.Tensor:
    o = torch.zeros(x.shape, device=x.device)
    return o

jfoo = thunder.jit(foo)
o = jfoo(torch.randn(3, 3))
print(o.dtype)
print(thunder.last_prologue_traces(jfoo)[0])
print()
print(thunder.last_traces(jfoo)[0])
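
With the process default dtype left at torch.float32, the first print emits torch.float32, matching the f32 annotations in the traces below.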

Prologue Trace

import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  prims.check_len(args, 1)
  # kwargs: "Any"
  prims.check_len(kwargs, 0)
  t_0: "cpu f32[3, 3]" = args[0]
  prims.check_tensor_metadata(t_0, (3, 3), 'cpu', torch.float32, False)
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  # NOTE - We bake the torch.dtype in trace (check_tensor_metadata also does the same).
  prims.check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  prims.check_number_type_and_value(cache_info_no_grad_sync, False)
  return ((), ())
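
Because the prologue bakes the default dtype into a prims.check_literal_like check, calling jfoo again after changing the default dtype should fail that check rather than silently reuse the f32 trace. A hedged sketch of that interaction (the re-trace-on-mismatch behavior is inferred from the check above, not stated in this PR):

import torch
import thunder

def foo(x: torch.Tensor) -> torch.Tensor:
    return torch.zeros(x.shape, device=x.device)

jfoo = thunder.jit(foo)
print(jfoo(torch.randn(3, 3)).dtype)  # torch.float32 under the f32 default

# The prologue's check_literal_like(cache_info_default_dtype, torch.float32)
# no longer holds after this, so the cached trace should not be reused as-is.
torch.set_default_dtype(torch.float64)
x = torch.randn(3, 3, dtype=torch.float32)  # keep the input's dtype fixed
print(jfoo(x).dtype)  # torch.float64 (assumed: re-traced under the new default)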

Computation Trace

import thunder
import thunder.core.devices as devices
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  # /home/kkalambarkar/lightning-thunder/scratchpad/test.py:64:             o = torch.zeros(x.shape, device=x.device)
  o = ltorch.zeros((3, 3), device=devices.Device("cpu"), dtype=None)  # o: "cpu f32[3, 3]"
    # o = ltorch.full((3, 3), 0, device=devices.Device("cpu"), dtype=None)  # o: "cpu f32[3, 3]"
      # o = prims.full((3, 3), 0, device=devices.Device("cpu"), dtype=dtypes.float32)  # o: "cpu f32[3, 3]"
  return o
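
Note how ltorch.zeros is recorded with dtype=None while the underlying prims.full call carries dtype=dtypes.float32: the default dtype is resolved during tracing. A hypothetical sketch of that resolution step (the helper name and signature are illustrative, not thunder's actual internals):

def _resolve_factory_dtype(dtype, cache_info):
    # An explicit dtype argument wins; otherwise fall back to the default
    # dtype stashed in cache_info when the function was jitted.
    if dtype is not None:
        return dtype
    return cache_info["default_dtype"]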

@kshitij12345 marked this pull request as ready for review on July 16, 2024 at 12:34

@t-vi (Collaborator) left a comment


Supergood, thank you @kshitij12345

@t-vi merged commit 6703b35 into Lightning-AI:main on Jul 16, 2024
39 checks passed
@kshitij12345 changed the title from "[WIP] support inferring dtype with torch.get_default_dtype for factory functions" to "support inferring dtype with torch.get_default_dtype for factory functions" on Jul 16, 2024
@github-actions bot deleted the support-torch-default-dtype branch on October 16, 2024 at 00:46
Development

Successfully merging this pull request may close this issue:

type inference: mismatched dtype in cat operator (#750)