
support inferring dtype with torch.get_default_dtype for factory functions #775

Merged

t-vi merged 3 commits into Lightning-AI:main from kshitij12345:support-torch-default-dtype on Jul 16, 2024

Conversation

@kshitij12345 (Collaborator) commented Jul 16, 2024

Fixes: #750

Changes:

  1. Stash torch.get_default_dtype in cache_info. This also adds a check to the prologue trace verifying that the jitted fn is called under the same default dtype (see the example prologue below).
  2. Factory functions infer their dtype from the torch.get_default_dtype value in cache_info (when dtype is not passed explicitly).
  3. We don't support changing the default dtype inside the jitted fn, since reordering and fusion can lead to issues; for now this raises a loud error (we can revisit in a follow-up if required).
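Step 2 mirrors eager PyTorch behavior, where factory functions fall back to torch.get_default_dtype() when no dtype is passed; a quick stand-alone illustration:

```python
import torch

# With the stock default dtype (float32), dtype-less factory calls produce float32.
assert torch.get_default_dtype() == torch.float32
assert torch.zeros(2, 2).dtype == torch.float32

# Changing the default dtype changes what dtype-less factory calls produce,
# which is why the prologue must verify the default dtype on every call.
torch.set_default_dtype(torch.float64)
assert torch.zeros(2, 2).dtype == torch.float64

torch.set_default_dtype(torch.float32)  # restore the stock default
```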

Repro:

import torch
import thunder

def foo(x: torch.Tensor) -> torch.Tensor:
    o = torch.zeros(x.shape, device=x.device)
    return o

jfoo = thunder.jit(foo)
o = jfoo(torch.randn(3, 3))
print(o.dtype)
print(thunder.last_prologue_traces(jfoo)[0])
print()
print(thunder.last_traces(jfoo)[0])

Prologue Trace

import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  prims.check_len(args, 1)
  # kwargs: "Any"
  prims.check_len(kwargs, 0)
  t_0: "cpu f32[3, 3]" = args[0]
  prims.check_tensor_metadata(t_0, (3, 3), 'cpu', torch.float32, False)
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  # NOTE - We bake the torch.dtype in trace (check_tensor_metadata also does the same).
  prims.check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  prims.check_number_type_and_value(cache_info_no_grad_sync, False)
  return ((), ())
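
Conceptually, the new prologue check compares the caller's current default dtype against the literal baked in at trace time; a simplified sketch, where check_default_dtype and BAKED_DEFAULT_DTYPE are hypothetical stand-ins (not Thunder's actual prims.check_literal_like):

```python
import torch

BAKED_DEFAULT_DTYPE = torch.float32  # captured when the trace was built

def check_default_dtype() -> None:
    # Stand-in for prims.check_literal_like(cache_info_default_dtype, torch.float32):
    # fail loudly if the runtime default no longer matches the traced literal.
    current = torch.get_default_dtype()
    if current is not BAKED_DEFAULT_DTYPE:
        raise RuntimeError(
            f"default dtype mismatch: traced with {BAKED_DEFAULT_DTYPE}, got {current}"
        )

check_default_dtype()  # passes while the default is still float32
```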

Computation Trace

import thunder
import thunder.core.devices as devices
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  # /home/kkalambarkar/lightning-thunder/scratchpad/test.py:64:             o = torch.zeros(x.shape, device=x.device)
  o = ltorch.zeros((3, 3), device=devices.Device("cpu"), dtype=None)  # o: "cpu f32[3, 3]"
    # o = ltorch.full((3, 3), 0, device=devices.Device("cpu"), dtype=None)  # o: "cpu f32[3, 3]"
      # o = prims.full((3, 3), 0, device=devices.Device("cpu"), dtype=dtypes.float32)  # o: "cpu f32[3, 3]"
  return o
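
Note how dtype=None survives the ltorch.zeros and ltorch.full levels and only resolves to a concrete dtype at the prims level. A hypothetical resolve_dtype helper (not Thunder's API) sketches the rule:

```python
import torch

def resolve_dtype(dtype, cached_default=None):
    # Hypothetical helper mirroring the decomposition above: an explicitly
    # passed dtype wins; otherwise fall back to the default dtype stashed in
    # cache_info (approximated here by torch.get_default_dtype()).
    if dtype is not None:
        return dtype
    return cached_default if cached_default is not None else torch.get_default_dtype()

# Matches the trace: dtype=None resolves to float32 under the stock default.
assert resolve_dtype(None) == torch.float32
assert resolve_dtype(torch.float16) == torch.float16
```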

@kshitij12345 kshitij12345 marked this pull request as ready for review July 16, 2024 12:34
@t-vi (Collaborator) left a comment
Supergood, thank you @kshitij12345

@t-vi t-vi merged commit 6703b35 into Lightning-AI:main Jul 16, 2024
@kshitij12345 kshitij12345 changed the title from "[WIP] support inferring dtype with torch.get_default_dtype for factory functions" to "support inferring dtype with torch.get_default_dtype for factory functions" Jul 16, 2024
@github-actions github-actions bot deleted the support-torch-default-dtype branch October 16, 2024 00:46

Development

Successfully merging this pull request may close these issues.

type inference: mismatched dtype in cat operator