
Automatic registration of torch operators using FakeTensor #554

Merged: 58 commits merged into main on Aug 6, 2024
Conversation

kiya00 (Collaborator) commented on Jun 7, 2024

What does this PR do?

Design doc: https://docs.google.com/document/d/1BVr_5co6GDj3wC7NrWq0ckfsbgX5l6w1dOSFBi2c-Hc/edit?usp=sharing

This PR runs some experiments on using FakeTensor for our meta functions, to allow an automatic fallback to torch operators when an op is not supposed to have a fusion backend.

Currently it works for the example test cases below (sinc, hsplit, normalize, and adaptive_avg_pool3d are ops not yet added in Thunder):

import torch
import thunder
from torch import nn

def func(a, out):
    t1 = torch.nn.functional.adaptive_avg_pool3d(a, output_size=out)
    return torch.nn.functional.normalize(t1)

def func1(a):
    x = torch.sinc(a)
    t = torch.arange(16.0).reshape(4,4)
    return torch.hsplit(t, 2), x


a = torch.randn((1, 64, 8, 9, 10)).cuda()
jfunc = thunder.jit(func)
out = jfunc(a, (5, 7, 9))
res = func(a, (5, 7, 9))
torch.testing.assert_close(out, res)

inp = torch.rand(2,2).cuda()
jfunc1 = thunder.jit(func1)
out = jfunc1(inp)
expected = func1(inp)
torch.testing.assert_close(out, expected)
A second example also checks the backward pass:

import torch
import thunder
def func(a):
    x = torch.sinc(a)
    x1, x2 = torch.hsplit(x, 2)
    return x1 + x2

def func1(a, out):
    t1 = torch.nn.functional.adaptive_avg_pool3d(a, output_size=out)
    return torch.nn.functional.normalize(t1)


inp = torch.rand(4,4).cuda().requires_grad_()
inp_1 = inp.detach().clone().requires_grad_()


jfunc = thunder.jit(func)
out = jfunc(inp)
out.sum().backward()
print(thunder.last_traces(jfunc)[-1])
print(thunder.last_backward_traces(jfunc)[-1])


expected = func(inp_1)
expected.sum().backward()

torch.testing.assert_close(out, expected)
torch.testing.assert_close(inp.grad, inp_1.grad)


a = torch.randn((1, 64, 8, 9, 10)).cuda().requires_grad_()
a1 = a.detach().clone().requires_grad_()

res = func1(a1, (5, 7, 9))
res.sum().backward()

jfunc = thunder.jit(func1)
out = jfunc(a, (5, 7, 9))
out.sum().backward()
print(thunder.last_traces(jfunc)[-1])
print(thunder.last_backward_traces(jfunc)[-1])

torch.testing.assert_close(out, res)
torch.testing.assert_close(a.grad, a1.grad)
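
For context, the core idea behind the fallback can be illustrated with a small sketch (this is not Thunder's actual implementation, and infer_output_meta is just an illustrative name): run the real torch operator on FakeTensors under PyTorch's FakeTensorMode, so the output's shape, dtype, and device are inferred without doing any real computation, and use that metadata as the op's meta function.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def infer_output_meta(op, *args, **kwargs):
    # Convert real tensors to FakeTensors, then run the op under the fake mode;
    # the result is a FakeTensor carrying only metadata (shape, dtype, device).
    fake_mode = FakeTensorMode()
    fake_args = [fake_mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
    with fake_mode:
        return op(*fake_args, **kwargs)

fake_out = infer_output_meta(torch.sinc, torch.randn(4, 4))
print(fake_out.shape, fake_out.dtype, fake_out.device)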

Fixes #811

cc: @IvanYashchuk

@riccardofelluga (Collaborator) commented:

Hey, this is pretty cool! I tried it with this example, for which we had no op implemented, and it worked!

import torch
import thunder

@thunder.jit
def func(x, y, z):
    t0 = torch.baddbmm(x, y, z)
    return t0

t0 = torch.randn(10, 3, 5)
b1 = torch.randn(10, 3, 4)
b2 = torch.randn(10, 4, 5)
out = func(t0, b1, b2)
print(thunder.last_traces(func)[-1])
print(out.size())

Looking forward to this feature!

kiya00 (Collaborator, Author) commented on Jun 11, 2024

Currently the output trace looks like this:

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[1, 64, 8, 9, 10]"
  t0 = torch.nn.functional.adaptive_avg_pool3d(a, output_size=(5, 7, 9))  # t0: "cuda:0 f32[1, 64, 5, 7, 9]"
  t1 = torch.nn.functional.normalize(t0)  # t1: "cuda:0 f32[1, 64, 5, 7, 9]"
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((a, t0), (5, 7, 9))
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  a, t0, = C0
  clear_mutable_collection(C0)
  del C0
  i0, i1, i2, = C1
  clear_mutable_collection(C1)
  del C1
  (t14,) = normalize_vjp({'inputs': ((t0,), {}), 'func': _function_0}, t2)
  del t0, t2
  (t15,) = adaptive_avg_pool3d_vjp({'inputs': ((a,), {'output_size': (i0, i1, i2)}), 'func': _function_1}, t14)
  del a, i0, i1, i2, t14
  return (t15,)
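
The normalize_vjp and adaptive_avg_pool3d_vjp calls in this backward trace pull gradients back through the original torch functions. A minimal sketch of that idea (not the actual registered transform; generic_vjp is a hypothetical helper name) is:

import torch

def generic_vjp(func, args, kwargs, cotangents):
    # Re-run the fallback torch op with autograd enabled, then backpropagate the
    # given cotangents to obtain gradients of the tensor inputs.
    args = tuple(a.detach().requires_grad_() if isinstance(a, torch.Tensor) else a for a in args)
    tensor_args = [a for a in args if isinstance(a, torch.Tensor)]
    with torch.enable_grad():
        out = func(*args, **kwargs)
    return torch.autograd.grad(out, tensor_args, grad_outputs=cotangents)

# Mirrors the trace above: gradient of normalize with respect to t0.
t0 = torch.randn(1, 64, 5, 7, 9)
t2 = torch.randn(1, 64, 5, 7, 9)  # cotangent
(t14,) = generic_vjp(torch.nn.functional.normalize, (t0,), {}, (t2,))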

@kiya00 changed the title from "[no need for review][draft] Experiments of auto fallback to torch operator using FakeTensor" to "[WIP] Experiments of auto fallback to torch operator using FakeTensor" on Jul 3, 2024
@kiya00 kiya00 requested a review from IvanYashchuk July 5, 2024 15:42
@kiya00 force-pushed the faketensor branch 5 times, most recently from e2ea5c3 to 210ed60 on July 18, 2024 15:41
t-vi (Collaborator) commented on Jul 20, 2024

First off, great job, @kiya00 !

One fundamental question I have about the design: given that we are talking about a finite set of operators here, would it be more transparent to autogenerate the code for them and keep it in a file? That would require us to keep up with new operators PyTorch introduces, but in general it would be less magic. When we implement something in a more detailed way, it would move from the autogenerated file to a manual implementation.

I'm imagining that we could still use the bulk of your code; only the time of registration would move.

WDYT?
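
For illustration, such a generated file could be produced by something like the following sketch (purely hypothetical: generate_registration_file, _fallback_to_torch, and the output path are made-up names, not Thunder's API):

# Hypothetical generator: emit one stub per torch operator into a checked-in file,
# so the registrations are visible in the repository rather than created at import time.
TEMPLATE = '''
def {name}(*args, **kwargs):
    # Falls back to torch.{name}; metadata is inferred via FakeTensor at trace time.
    return _fallback_to_torch(torch.{name}, *args, **kwargs)
'''

def generate_registration_file(op_names, path="thunder/torch/_generated_fallbacks.py"):
    with open(path, "w") as f:
        f.write("import torch\n")
        for name in op_names:
            f.write(TEMPLATE.format(name=name))

generate_registration_file(["sinc", "hsplit", "baddbmm"])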

@kiya00 force-pushed the faketensor branch 2 times, most recently from f6126fc to 88fe4b3 on July 30, 2024 12:43
kiya00 and others added 2 commits August 2, 2024 16:47
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>

@IvanYashchuk (Collaborator) left a comment


Awesome work!

There are only a few cosmetic changes to be done and then I think we should merge this.

@crcrpar (Collaborator) left a comment


What awesome work!

@kshitij12345 (Collaborator) left a comment


The PR looks really great! Just had a couple of suggestions, thank you @kiya00!

kiya00 (Collaborator, Author) commented on Aug 5, 2024

Hi @t-vi, I found that the CI failure seems unrelated to this PR; I can only reproduce it locally by running pytest thunder/tests/test_inplace_functionalization.py thunder/tests/test_nvfuser_remat.py::test_find_cut_dropout_nvfuser_cuda_None together (separate execution succeeds). I tried to fix it in #928.

@t-vi t-vi enabled auto-merge (squash) August 6, 2024 07:12
@t-vi (Collaborator) left a comment


@t-vi t-vi merged commit 36f46dd into main Aug 6, 2024
37 checks passed
@t-vi t-vi deleted the faketensor branch August 6, 2024 08:08