factory functions - fix handling of default device #820

Merged: 5 commits into Lightning-AI:main from default-device-handling on Jul 22, 2024

Conversation

kshitij12345 (Collaborator):

Fixes #621

Changes are similar to #775

  • Stash torch.get_default_device in cache_info. This also adds a check to the prologue trace verifying that the jitted function is called with the same default device (see the example prologue below).
  • Factory functions infer the device from the torch.get_default_device value stashed in cache_info (when device is not passed explicitly); a minimal sketch follows this list.
  • We don't support changing the default device inside the jitted function, as reordering and fusion can lead to issues; for now this is a loud error (we can revisit in a follow-up if required).
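To make the inference concrete, here is a minimal sketch (not Thunder's actual implementation; make_ones and the cache_info layout are hypothetical) of a factory function falling back to the cached default device:

import torch

# Hypothetical helper, illustrative only: a factory function consults the
# stashed cache_info when no device is passed explicitly.
def make_ones(shape, device=None, cache_info=None):
    if device is None and cache_info is not None:
        # Fall back to the default device recorded at trace time.
        device = cache_info["default_device"]
    return torch.ones(shape, device=device)

# Stands in for the cache_info dict populated during jitting.
cache_info = {"default_device": torch.get_default_device()}
print(make_ones((3,), cache_info=cache_info).device)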

Repro:

import torch
import thunder

def foo(x):
    # Factory function without an explicit device: should follow the
    # current default device.
    return torch.ones(x.shape).device

x = torch.randn(3)  # created on CPU, before the default device is changed

torch.set_default_device("cuda")
print(foo(x))   # cuda:0
jfoo = thunder.jit(foo)
print(jfoo(x))  # cuda:0 with this fix

print(thunder.last_prologue_traces(jfoo)[-1])
print(thunder.last_traces(jfoo)[-1])

Prologue Trace:

# Constructed by Transform for execution (took 1 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 1)
    # prims.check_len(args, 1)
  # kwargs: "Any"
  check_len(kwargs, 0)
    # prims.check_len(kwargs, 0)
  t_0: "cpu f32[3]" = args[0]
  check_tensor_metadata(t_0, (3,), 'cpu', torch.float32, False)
    # prims.check_tensor_metadata(t_0, (3,), 'cpu', torch.float32, False)
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  check_literal_like(cache_info_default_dtype, torch.float32)
    # prims.check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
  check_literal_like(cache_info_default_device, torch.device("cuda:0"))
    # prims.check_literal_like(cache_info_default_device, torch.device("cuda:0"))
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  check_number_type_and_value(cache_info_is_autocast_enabled, False)
    # prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  check_number_type_and_value(cache_info_no_grad_sync, False)
    # prims.check_number_type_and_value(cache_info_no_grad_sync, False)
  return ((), ())

Computation Trace:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  return torch.device("cuda:0")
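Note that the computation trace bakes the device in as a constant (torch.device("cuda:0")); correctness across calls relies on the prologue check shown above. A hedged illustration, assuming a failed prologue check results in a cache miss and a retrace rather than reuse of the specialized trace:

# Continuing the repro above: change the default device after jitting.
torch.set_default_device("cpu")
print(jfoo(x))  # expected: the call no longer matches the cached "cuda:0"
                # specialization, so it should retrace under the new default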

kshitij12345 (Collaborator, Author):

Current stable PyTorch is 2.3. For Mac, it looks like we install torch 2.2.2 (https://github.com/Lightning-AI/lightning-thunder/actions/runs/10040430679/job/27746469483?pr=820#step:9:311), which may have an issue with torch.get_default_device (see pytorch/pytorch#126632).
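For reference, a defensive pattern for older PyTorch builds where torch.get_default_device may be missing or unreliable; this is an assumed workaround, not something this PR adds:

import torch

if hasattr(torch, "get_default_device"):
    default_device = torch.get_default_device()
else:
    # Assumed fallback: a freshly created tensor reflects the default device.
    default_device = torch.empty(0).device
print(default_device)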

t-vi (Collaborator) commented Jul 22, 2024:

Ugh, I'll look into updating; it should not hold you back.

kshitij12345 marked this pull request as ready for review on July 22, 2024, 16:51.
t-vi (Collaborator) left a comment:

As always, a pleasure to read. Thank you @kshitij12345

t-vi merged commit 85df4f6 into Lightning-AI:main on Jul 22, 2024. 39 checks passed.
The github-actions bot deleted the default-device-handling branch on October 23, 2024, 00:46.
Merging this pull request closes: use torch.get_default_dtype and torch.get_default_device for factory method in thunder/torch/__init__.py (#621)