Merge branch 'main' into tfogal/nvtx
tfogal committed Oct 30, 2024
2 parents 6e7953d + 9c916d9 commit 8d803ae
Showing 17 changed files with 307 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -23,7 +23,7 @@ repos:
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
rev: v3.18.0
rev: v3.19.0
hooks:
- id: pyupgrade
args: ["--py310-plus"]
8 changes: 8 additions & 0 deletions thunder/__init__.py
@@ -39,6 +39,7 @@
check_inplace_to_views,
functionalize_inplace_ops,
)
from thunder.core.recipe import Recipe, Lookaside
from thunder.common import (
CompileData,
CompileStats,
@@ -265,6 +266,13 @@ def _recursive_jit_call_warning() -> None:
)


def compile(fn: Callable, recipe: Recipe | None):
if recipe is None:
return thunder.jit(fn)

return recipe.apply(fn)


# This function will replace compile() (below) before RC1
# TODO RC1 Consider adding a debug_log parameter to control debug printing
# TODO RC1 Consider renaming compile_options to additional_compile_options
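A hedged sketch of the compile() entry point added in this hunk: with recipe=None it falls back to thunder.jit, and with a Recipe instance it defers to recipe.apply(fn). The call below assumes this commit's thunder/__init__.py, so the function resolves as thunder.compile.

import torch
import thunder

def double(x):
    return x * 2

jdouble = thunder.compile(double, recipe=None)  # equivalent to thunder.jit(double)
print(jdouble(torch.ones(3)))                   # tensor([2., 2., 2.])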
5 changes: 3 additions & 2 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -819,8 +819,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None

print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
if "tokens_per_sec" in benchmark.perf_metrics:
print(f"Tokens/s: {benchmark.perf_metrics.get['tokens_per_sec']:.02f}")
print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")

if benchmark.dump_memory_snapshot:
3 changes: 1 addition & 2 deletions thunder/benchmarks/distributed.py
@@ -12,9 +12,8 @@
NanoGPTConfig,
NanoGPTBenchmark,
LitGPTBenchmark,
LitGPTConfig,
)
from thunder.tests.litgpt_model import name_to_config
from thunder.tests.litgpt_model import name_to_config, Config as LitGPTConfig
from thunder.distributed import FSDPBucketingStrategy
from thunder.distributed import FSDPType

5 changes: 2 additions & 3 deletions thunder/core/interpreter.py
@@ -64,9 +64,8 @@
# of where values originated. This mode is used by the general jit. This is done by
# wrapping all values in WrappedValues.
#
# Both thunder.jit and thunder.functional.jit use extensions (in jit_ext.py) to
# create Thunder Programs. They use callbacks and additional lookasides to
# add their functionality.
# thunder.jit uses extensions (in jit_ext.py) to create Thunder Programs.
# It uses callbacks and additional lookasides to add its functionality.
#
# The Thunder program constructed has three parts, a "prologue trace", a
# "computation trace", and (optionally) an "epilogue trace". The prologue trace has
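As a minimal illustration of what the comment block above describes (a sketch, assuming a recent lightning-thunder build running on CPU), thunder.last_traces returns the computation traces recorded for the last call of a jitted function; the final entry is the computation trace produced after all transforms have run.

import torch
import thunder

def f(x):
    return torch.sin(x) + x

jf = thunder.jit(f)
jf(torch.randn(4))
# Print the final computation trace built by the interpreter/jit pipeline.
print(thunder.last_traces(jf)[-1])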
77 changes: 77 additions & 0 deletions thunder/core/recipe.py
@@ -0,0 +1,77 @@
from typing import List, Dict

from thunder.core.transform_common import Transform
from thunder.extend import Executor, get_default_executors

import torch


class Lookaside:
def __init__(self, fn, replace_with):
self._fn = fn
self._replace_with = replace_with


class Recipe:
# thunder.jit | torch.compile
compiler = "thunder.jit"

def __init__(self):
pass

def validate(self, model):
# this is supposed to raise
pass

# def setup_operators(self) -> List[Operator]:
# # this is for registering custom kernels on the fly
# return None

def setup_lookasides(self) -> list[Lookaside] | None:
return None

def setup_transforms(self) -> list[Transform] | None:
return None

def setup_executors(self):
return get_default_executors()

def setup_config(self) -> dict:
return {}

def apply(self, model):
self.validate(model)

self.config = self.setup_config()
lookasides = self.setup_lookasides()

from thunder.core import jit_ext, interpreter

if lookasides is not None:
for lookaside in lookasides:
wrapped_replacement_fn = interpreter.interpreter_needs_wrap(lookaside._replace_with)
jit_ext._general_jit_lookaside_map[lookaside._fn] = wrapped_replacement_fn

self.lookasides = lookasides
self.executors = self.setup_executors()
self.transforms = self.setup_transforms()

if self.compiler == "thunder.jit":
from thunder import jit

thunder_model = jit(model, transforms=self.transforms, executors=self.executors, **self.config)

elif self.compiler == "torch.compile":
from thunder.dynamo import ThunderCompiler

thunder_backend = ThunderCompiler(transforms=self.transforms, executors=self.executors, **self.config)
thunder_model = torch.compile(model, backend=thunder_backend)

else:
raise AttributeError(f"Compiler must be one of 'thunder.jit', 'torch.compile'. Found: {self.compiler}.")

return thunder_model


class DynamoRecipe(Recipe):
compiler = "torch.compile"
20 changes: 20 additions & 0 deletions thunder/core/transforms.py
@@ -1414,6 +1414,26 @@ def _maximum_grad(a: TensorProxy, b: TensorProxy, /):
# This operation creates no grad associations
register_grad(pids.SHAPE, prims.shape)


def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
fwd = prims.copy_with_setitem(a, index, value)
g = get_grad(fwd)

a_grad = prims.copy_with_setitem(g, index, 0)
put_grad(a, a_grad)

if isinstance(value, TensorProxy):
value_grad = g[index]
expanded_dims = value_grad.ndim - value.ndim
if expanded_dims > 0:
value_grad = prims.sum(value_grad, tuple(range(expanded_dims)))
put_grad(value, value_grad)

return fwd


register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)

#
# Phantom grad transform helpers
#
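A hedged numerical check of the rule _copy_with_setitem_grad implements, written in plain PyTorch rather than Thunder primitives: the gradient w.r.t. a is the upstream gradient with the written positions zeroed, and the gradient w.r.t. value is the upstream gradient gathered at index (summed over any broadcast dimensions).

import torch

a = torch.randn(4, 3, requires_grad=True)
value = torch.randn(3, requires_grad=True)

out = a.clone()
out[1] = value          # copy-with-setitem
out.sum().backward()

g = torch.ones(4, 3)    # upstream gradient of sum()
expected_a_grad = g.clone()
expected_a_grad[1] = 0  # the overwritten row receives no gradient
assert torch.equal(a.grad, expected_a_grad)
assert torch.equal(value.grad, g[1])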
6 changes: 6 additions & 0 deletions thunder/executors/apex_fused_rms_norm_impl.py
@@ -14,6 +14,12 @@

APEX_FUSED_NORMS_AVAILABLE = True
try:
# Fused layer norm is only importable if torch.distributed is available
# https://github.com/NVIDIA/apex/issues/1853
from torch.distributed import is_available

if not is_available():
raise ImportError
import fused_layer_norm_cuda
from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
except ImportError:
19 changes: 19 additions & 0 deletions thunder/executors/torch_autograd.py
@@ -70,10 +70,29 @@ def forward(
ctx.saved_other = saved_other
ctx.compiled_backward = compiled_backward

# NOTE [Saved view of output of torch.autograd.Function leaks]
# We detach here to avoid a bug in PyTorch where
# it leaks memory if view of the output of torch.autograd.Function
# is saved for backward.
# See - https://github.com/pytorch/pytorch/issues/94990#issuecomment-1435181804
# NOTE - Detaching here would lead to problems with higher-order differentiation, but
# this is ok for now because ThunderFunction is only `once_differentiable`.
def detach_if_tensor(t):
# Some operations may claim to return Tensor (as per their meta function)
# but may return None at runtime (e.g. we noticed this for sdpa)
if isinstance(t, torch.Tensor):
return t.detach()
return t

saved_tensors = tuple(map(detach_if_tensor, saved_tensors))

# We must save tensors using ctx.save_for_backward
ctx.save_for_backward(*saved_tensors)
return flat_output

# NOTE: If `torch.autograd.function.once_differentiable` is to be removed,
# one must take care of correctly removing the `detach_if_tensor` above.
# For more context, see NOTE [Saved view of output of torch.autograd.Function leaks] above.
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *args):
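A hedged standalone sketch of the pattern the NOTE above refers to, outside Thunder: saving a view of a torch.autograd.Function's own output for backward can keep memory alive (pytorch/pytorch#94990), so the view is detached before ctx.save_for_backward. The Exp class below is illustrative and not part of the commit.

import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.exp(x)
        view_of_output = y.view_as(y)                   # a view of the Function's own output
        ctx.save_for_backward(view_of_output.detach())  # detach to sidestep the leak
        return y

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        return grad_out * y                             # d/dx exp(x) = exp(x)

x = torch.randn(4, requires_grad=True)
Exp.apply(x).sum().backward()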
Empty file added thunder/recipes/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions thunder/recipes/hf_bert.py
@@ -0,0 +1,24 @@
import thunder
import transformers
import torch


class HFBertBasic(thunder.Recipe):
def validate(self, model):
if not isinstance(model, transformers.BertForSequenceClassification):
raise ValueError("The model must be a BertForSequenceClassification")

def setup_lookasides(self):
warn_lookaside = thunder.Lookaside(
fn=transformers.modeling_utils.PreTrainedModel.warn_if_padding_and_no_attention_mask,
replace_with=lambda *args: None,
)

if hasattr(torch, "compiler") and hasattr(torch.compiler, "is_compiling"):
is_compiling = torch.compiler.is_compiling
else:
is_compiling = torch._dynamo.is_compiling

is_compiling_lookaside = thunder.Lookaside(fn=is_compiling, replace_with=lambda *_: True)

return [warn_lookaside, is_compiling_lookaside]
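A hedged usage sketch of the recipe above (config sizes are illustrative, and whether the full Hugging Face forward traces cleanly depends on the Thunder version): the warn_if_padding_and_no_attention_mask lookaside suppresses the padding warning that would otherwise fire when no attention mask is passed.

import torch
import transformers
import thunder
from thunder.recipes.hf_bert import HFBertBasic

config = transformers.BertConfig(num_hidden_layers=2, hidden_size=64,
                                 num_attention_heads=2, intermediate_size=128)
model = transformers.BertForSequenceClassification(config).eval()

compiled = thunder.compile(model, recipe=HFBertBasic())
input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    logits = compiled(input_ids=input_ids).logits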
11 changes: 10 additions & 1 deletion thunder/tests/test_apex_fused_norms.py
@@ -3,14 +3,19 @@
from torch.testing import assert_close

fused_layer_norm_cuda = pytest.importorskip("fused_layer_norm_cuda")
from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction

from torch.distributed import is_available
from thunder.executors.apexex import apex_ex
import thunder


# See https://github.com/NVIDIA/apex/issues/1853
@pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
@pytest.mark.parametrize("requires_grad", [True, False])
@pytest.mark.parametrize("memory_efficient", [True, False])
def test_apex_fused_rms_norm(requires_grad, memory_efficient):
from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction

def fn(x, weight, normalized_shape, eps):
return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)

@@ -34,9 +39,13 @@ def fn(x, weight, normalized_shape, eps):
assert_close(actual_grad, expected_grad)


# See https://github.com/NVIDIA/apex/issues/1853
@pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
@pytest.mark.parametrize("requires_grad", [True, False])
@pytest.mark.parametrize("memory_efficient", [True, False])
def test_apex_fused_rms_norm_autoregister(requires_grad, memory_efficient):
from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction

def fn(x, weight, normalized_shape, eps):
return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)

39 changes: 39 additions & 0 deletions thunder/tests/test_core.py
@@ -2981,3 +2981,42 @@ def bar(a: torch.Tensor):
baz(torch.randn(19))

foo()


def test_saved_view_of_output_of_autograd_function_does_not_leak():
# Verify that we have side-stepped the bug in torch.autograd.Function
# where saving a view of the output for backward leads to a leak.
# See NOTE [Saved view of output of torch.autograd.Function leaks]
def fn(idx, weight):
tok_emb = torch.nn.functional.embedding(idx, weight)
emb = torch.reshape(tok_emb, (2, 32))
matmul = emb @ emb.T
return tok_emb, matmul

weight = make_tensor((16, 32), dtype=torch.float, device="cpu", requires_grad=True)
x = make_tensor((1, 2), dtype=torch.int64, low=0, high=10, device="cpu")

jfn = thunder.jit(fn)

# Computation Trace for jfn
# We save view of the output `tok_emb` for backward.
# @torch.no_grad()
# @no_autocast
# def computation(idx, t_wte_weight):
# # idx: "cuda:0 i64[1, 2]"
# # t_wte_weight: "cuda:0 f32[16, 32]"
# tok_emb = torch.nn.functional.embedding(idx, t_wte_weight, None, None, 2.0, False, False) # tok_emb: "cuda:0 f32[1, 2, 32]"
# [emb, t4] = nvFusion0(tok_emb)
# # emb = prims.reshape(tok_emb, (2, 32)) # emb: "cuda:0 f32[2, 32]"
# # t4 = prims.transpose(emb, (1, 0)) # t4: "cuda:0 f32[32, 2]"
# matmul = torch.matmul(emb, t4) # matmul: "cuda:0 f32[2, 2]"
# return {'output': (tok_emb, matmul), 'flat_args': [idx, t_wte_weight], 'flat_output': (tok_emb, matmul)}, ((emb, idx, t4), ())

prev_iter_refs = []
for iter_n in range(4):
tok_emb, _ = jfn(x, weight)
if iter_n < 3:
prev_iter_refs.append(weakref.ref(tok_emb))

for ref in prev_iter_refs:
assert ref() is None
3 changes: 2 additions & 1 deletion thunder/tests/test_grad.py
@@ -1725,7 +1725,8 @@ def f(x):
# With activation checkpointing, we are saving only the original input.
# The intermediate values are recomputed during backward pass.
assert len(out.grad_fn.saved_tensors) == 1
assert out.grad_fn.saved_tensors[0] is x
# We detach the saved tensors (which returns a new Python tensor backed by same storage)
assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()

g = torch.ones_like(out)
out.backward(g)
6 changes: 4 additions & 2 deletions thunder/tests/test_jit_general.py
@@ -675,11 +675,13 @@ def test_nanogpt():
("cpu", "cuda", "meta"),
)
def test_litgpt_variants(name, device):
from litgpt.config import Config
from thunder.tests.litgpt_model import Config
from litgpt.model import GPT

if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
if device == "cuda" and name == "falcon-7b-like":
pytest.skip("NVFuser reenable when https://github.com/NVIDIA/Fuser/issues/3292 is fixed")

device = torch.device(device)

@@ -733,7 +735,7 @@ def test_litgpt_variants(name, device):
("cpu", "cuda"),
)
def test_litgpt_variants_kvcache(name, device):
from litgpt.config import Config
from thunder.tests.litgpt_model import Config
from litgpt.model import GPT
import torch._dynamo # this monkeypatches torch.manual_seed
