Skip to content

Commit

Permalink
Merge branch 'main' into te_ddp
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 authored Apr 3, 2024
2 parents bf6d275 + 6babe4e commit 413e1f5
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 8 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/release-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ jobs:

- name: Install dependencies
run: python -m pip install --user --upgrade setuptools wheel
- name: Build
- name: Build package
env:
CONVERT_VERSION2NIGHTLY: "1"
run: python setup.py sdist bdist_wheel

# We do this, since failures on test.pypi aren't that bad
- name: Publish to Test PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
uses: pypa/gh-action-pypi-publish@v1.8.14
with:
user: __token__
password: ${{ secrets.test_pypi_password }}
repository_url: https://test.pypi.org/legacy/

- name: Publish distribution 📦 to PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
uses: pypa/gh-action-pypi-publish@v1.8.14
with:
user: __token__
Expand Down
82 changes: 82 additions & 0 deletions docs/source/basic/faq.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
Thunder FAQ
################

=================================================
1. How does Thunder compare to ``torch.compile``?
=================================================

Both `Thunder <https://github.com/Lightning-AI/lightning-thunder>`_ and `torch.compile <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler-overview>`_ are both deep learning compilers that take a pytorch module or callable, and return a callable. It seems reasonable to compare them that way. With that said, the focus of the projects are completely different.

Torch compile is a framework for generating optimized kernels. Its focus is on making pytorch code run faster by generating optimized kernels, with minimal code changes. Thunder is a framework for layering optimizations. It generates none of these optimizations itself, instead delegating to other libraries, including ``torch.compile`` and `nvfuser <https://github.com/NVIDIA/Fuser>`_. Our focus is on understandability, extendability, and usability.

Modern deep learning optimization often involves stitching various kernels together, often from different sources. Thunder is designed to make it easy to use these tools together, and easy to add new tools to the mix.

As such, the two are not necessarily comparable. Or, they are, but you would be comparing against a default configuration, which we expect you to extend and change anyway.



============================================
2. How can I use torch.compile with Thunder?
============================================

The correct way to use ``torch.compile`` with Thunder is to use the executor. This way, you get finer grained control over which parts of the model are handled by which executor.

Calling ``torch.compile()`` and then ``thunder.jit()`` is not what you want. It doesn't give a good error message right now, but it should not work.

Instead, register the executor like so::

import torch
import thunder

def model(x, y):
return x + y

exc_list = [thunder.extend.get_executor('torchcompile'), *thunder.get_always_executors()]
jmodel = thunder.jit(model, executors=exc_list)


====================================================================================
3. I have a CUDA, Triton, CUDNN, or other gpu kernel. How can I use it with thunder?
====================================================================================

Why, yes! You can register it as an operator for a custom executor. See :doc:`extending thunder <../intermediate/additional_executors>` for more information.


========================================================================
3. Do you support custom hardware, or accelerators that aren't Nvidia's?
========================================================================

Yes, executors are device-agnostic. The python executor for example runs the operation with cpython on the cpu. We've been focusing on the Nvidia stack, but Thunder is designed to be extensible so you can write your own executors for any backend. Just make an executor and register the operators you need. We welcome contributions for executors for other accelerator backends.


=================================================================
4. I ran ``thunder.jit(model)(*args)``, and my model didn't work!
=================================================================

Thunder is in alpha. There will be bugs, and many torch operations are not supported. Try to run ``from thunder.examine import examine; examine(model, *args)``. This will list the operations which are not supported, and if they are all supported, test the model for consistency against torch eager.

If you need certain operations supported for your model, please let us know by creating an issue. We plan to get to all of them (with the exception of any :doc:`sharp edges <sharp_edges>`), but your issues help us prioritize which to do first.

There are potentially any number of other problems which could arise. Some of the problems are known, some may not be. Check out the :doc:`sharp edges <sharp_edges>` page. If what you're seeing still doesn't make sense, let us know by creating an issue.


=======================================
5. Does Thunder support dynamic shapes?
=======================================

No, not at the moment. However, we're actively working on it.

Meta functions operate on the exact shapes of the tensor proxies that pass through them. This is a limitation of the current implementation, and we plan to incorporate dynamic shapes in the future. If you have relevant experience experience with this problem in pytorch, or you are interested in helping us implement it, please let us know by creating an issue or reaching out.


================================================================
6. Does Thunder support inplace operations?
================================================================

Not at the moment. Implementing inplace operations would require tracking which tensors in a trace have been modified by operations in our optimization passes, which currently we represent as purely functional. All deep learning compiler frameworks have to deal with the problem of tensor aliasing in some way. The way we've chosen for now is to pretend that the problem doesn't exist.

The common solution is to represent programs in `SSA form <https://en.wikipedia.org/wiki/Static_single-assignment_form>`_, or do some form of SSA-inspired variable renaming, but SSA is a much less understandable representation than a list of symbols in a trace. Switching to SSA would also complicate optimization passes, and require rewriting many of them to handle these aliasing rules.

There also exists the problem that some backend executors do not support in-place operations. We have some ideas on how to functionalize ops for these executors, but some api issues are unresolved.

We want to support inplace operations eventually, but we are attached to traces as our program representation of choice for optimization passes. Much like with dynamic shapes, if you have relevant experience on how to best incorporate inplace operations without complicating optimization passes, come talk to us about it.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik
The sharp edges <basic/sharp_edges>
Train a MLP on MNIST <basic/mlp_mnist>
Functional jit <notebooks/functional_jit>
FAQ <basic/faq>

.. toctree::
:maxdepth: 1
Expand Down
12 changes: 11 additions & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,17 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable:
def is_from_torch(fn):
return hasattr(fn, "__module__") and fn.__module__ and fn.__module__.startswith("torch")

if is_opaque(fn) and is_from_torch(fn):
has_tensor_arg = False
for a in args:
if isinstance(a.value, TensorProxy):
has_tensor_arg = True
break
if isinstance(a.value, Sequence):
if any(isinstance(i, TensorProxy) for i in a.value):
has_tensor_arg = True
break

if is_opaque(fn) and is_from_torch(fn) and has_tensor_arg:
if fn.__module__.startswith("torch._C"):
return lookaside

Expand Down
21 changes: 21 additions & 0 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,27 @@ def _sum_prim_grad(a: TensorProxy, /, dims: Sequence[int]) -> TensorProxy:
register_grad(pids.SUM, _sum_prim_grad)


@torchctx
def _topk_prim_grad(
a: TensorProxy, /, k: int, dim: None | int = None, largest: bool = True, sorted: bool = True, *, out=None
):
fwd = prims.topk(a, k, dim, largest, sorted, out=out)
val, idx = fwd

val_grad = get_grad(val)

a_grad = ltorch.zeros_like(a)
# TODO: replace with scatter once we have it.
# scatter_add is a prim and it relies on atomic ops.
a_grad = ltorch.scatter_add(a_grad, dim, idx, val_grad)
put_grad(a, a_grad)

return fwd


register_grad(pids.TOPK, _topk_prim_grad)


# TODO Fix division by zero when n_elem_reduced == 0 or when mean.numel == 0
# by returning zeros_like(a) or similar.
# TODO Fix grad when correction > n_elem_reduced.
Expand Down
33 changes: 31 additions & 2 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4761,12 +4761,41 @@ def topk_error_generator(op, device, **kwargs):
yield (SampleInput(make(3, 3), 1, -3), IndexError, err_msg)


# Phantom grad tests do not handle tensor outputs
# that do not require grad and/or do not have grad_fn.
# Therefore we explicitly filter outputs.
# See https://github.com/Lightning-AI/lightning-thunder/issues/119 {
def topk_thunder_ref(*args, **kwargs):
return clang.topk(*args, **kwargs)[0]


def topk_torch_ref(*args, **kwargs):
return torch.topk(*args, **kwargs)[0]


# }


topk_opinfo = OpInfo(
clang.topk,
topk_thunder_ref,
name="topk",
supports_grad=True,
# Without the fixed seed this generator does not guarantee
# to produce inputs at which topk is differentiable
# (i.e. when topk(x, ...).indices == topk(x + dx, ...).indices).
# TODO: (@nikitaved): potentially modify these inputs to
# fix the issue.
sample_input_generator=topk_sample_generator,
error_input_generator=topk_error_generator,
torch_reference=torch.topk,
torch_reference=topk_torch_ref,
dtypes=(datatypes.signedinteger, datatypes.unsignedinteger, datatypes.floating),
test_directives=(
DecorateInfo(
# See https://github.com/Lightning-AI/lightning-thunder/issues/120
pytest.mark.skip(reason="Cannot handle inputs/outputs which do not require grads"),
"test_vjp_correctness",
),
),
)
reduction_ops.append(topk_opinfo)

Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,8 +1394,8 @@ def test_populate_grads_nanogpt(executor, device, dtype):

from thunder.benchmarks import NanoGPTBenchmark, NanoGPTConfig

# NOTE Currently setting dropout to zero for reproducibility, other settings taken from gpt2 config
config = NanoGPTConfig(dropout=0, n_layer=12, n_head=12, n_embd=768)
# NOTE Currently setting dropout to zero for reproducibility
config = NanoGPTConfig(dropout=0, n_layer=2, n_head=1, n_embd=64)

bench = NanoGPTBenchmark(config=config, requires_grad=True, device=device, dtype=dtype)
model = bench.fn()
Expand Down
8 changes: 8 additions & 0 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,14 @@ def foo(a, b, i):
assert_close(expected, actual)


# see https://github.com/Lightning-AI/lightning-thunder/issues/95
def test_get_default_dtype():
def foo():
return torch.get_default_dtype()

assert foo() == thunder.jit(foo)()


@pytest.mark.parametrize(
"device",
("cpu", "cuda"),
Expand Down

0 comments on commit 413e1f5

Please sign in to comment.