Commit ca40e14

Merge branch 'main' into cudnn/default

2 parents fdf1248 + 2fa5cab

22 files changed: +384 additions, -57 deletions

notebooks/zero_to_thunder.ipynb

Lines changed: 2 additions & 1 deletion
@@ -3389,7 +3389,8 @@
 }
 ],
 "source": [
-"!torchrun --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py"
+"# commented out for CI limitations, see https://github.com/Lightning-AI/lightning-thunder/issues/465\n",
+"# !torchrun --standalone --nnodes=1 --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py"
 ]
 },
 {

thunder/__init__.py

Lines changed: 18 additions & 1 deletion
@@ -92,6 +92,10 @@
     "int32",
     "int64",
     "bfloat16",
+    "float8_e5m2",
+    "float8_e5m2fnuz",
+    "float8_e4m3fn",
+    "float8_e4m3fnuz",
     "float16",
     "float32",
     "float64",
@@ -130,6 +134,10 @@ def __version__():
 int32 = dtypes.int32
 int64 = dtypes.int64
 bfloat16 = dtypes.bfloat16
+float8_e5m2 = dtypes.float8_e5m2
+float8_e5m2fnuz = dtypes.float8_e5m2fnuz
+float8_e4m3fn = dtypes.float8_e4m3fn
+float8_e4m3fnuz = dtypes.float8_e4m3fnuz
 float16 = dtypes.float16
 float32 = dtypes.float32
 float64 = dtypes.float64
@@ -328,14 +336,17 @@ def jit(
     assert type(record_history) is bool

     # TODO RC1 Refine the compile data option to remove unused options
+    # TODO: refine options
+    # NOTE(fixme): use_cudagraphs is being absorbed into compile_options
+    use_cudagraphs = compile_options.get("use_cudagraphs", False)
     cd = CompileData(
         fn=fn,
         langctx=langctx,
         executors_list=executors,
         cache_option=cache,
         sharp_edges=sharp_edges,
         using_jit=True,
-        use_cudagraphs=False,
+        use_cudagraphs=use_cudagraphs,
         disable_torch_autograd_support=disable_torch_autograd,
         use_rematerialization=False,
         only_execute_prims=False,
@@ -587,6 +598,12 @@ def get_computation_and_inputs(*args, **kwargs):
         else:
             backward_fn = None

+        # TODO: using vanilla CUDAGraphExecutor is not safe unless the graph is always static!
+        # (fixme): inspect torch.cuda.make_graph_callables and/or use it instead!
+        # See https://github.com/Lightning-AI/lightning-thunder/issues/433
+        if cd.use_cudagraphs:
+            comp = CUDAGraphExecutor(comp)
+
         # TODO RC1 Update the cache
         cache_entry = CacheEntry(
             pro,
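
With this change, `thunder.jit` reads `use_cudagraphs` from its compile options instead of hard-coding `False`, and wraps the computation in `CUDAGraphExecutor` when requested. A minimal sketch of how a caller might opt in (hypothetical toy function; assumes extra keyword arguments to `jit` are collected into `compile_options` and that a CUDA device is available; per the in-code note, the vanilla executor is only safe when the captured graph is static):

    import torch
    import thunder

    def fn(x):
        return torch.sin(x) * 2.0

    # opts the generated computation into CUDA graph execution
    jfn = thunder.jit(fn, use_cudagraphs=True)
    out = jfn(torch.randn(8, device="cuda"))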

thunder/clang/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1948,7 +1948,7 @@ def argmin(a: TensorProxy, /, dim: int | None = None, keepdim: bool | None = Fal
 @clangop()
 def topk(
     a: TensorLike, /, k: int, dim: int | None = None, largest: bool = True, sorted: bool = True, *, out=None
-) -> (TensorProxy, TensorProxy):
+) -> tuple[TensorProxy, TensorProxy]:
     if dim is None:
         dim = a.ndim - 1 if a.ndim > 0 else 0
     dim = utils.canonicalize_dim(a.ndim, dim)
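
The annotation fix matters because `(TensorProxy, TensorProxy)` is evaluated as a plain tuple of classes, which is not a type; `tuple[TensorProxy, TensorProxy]` is the PEP 585 generic that type checkers understand. A minimal illustration:

    def bad() -> (int, int):  # evaluates to the tuple object (int, int); not a type
        ...

    def good() -> tuple[int, int]:  # a real annotation: "returns a pair of ints"
        ...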

thunder/core/baseutils.py

Lines changed: 4 additions & 0 deletions
@@ -304,6 +304,10 @@ def indent(level):
     torch.int32: "torch.int32",
     torch.int64: "torch.int64",
     torch.bfloat16: "torch.bfloat16",
+    torch.float8_e4m3fn: "torch.float8_e4m3fn",
+    torch.float8_e4m3fnuz: "torch.float8_e4m3fnuz",
+    torch.float8_e5m2: "torch.float8_e5m2",
+    torch.float8_e5m2fnuz: "torch.float8_e5m2fnuz",
     torch.float16: "torch.float16",
     torch.float32: "torch.float32",
     torch.float64: "torch.float64",

thunder/core/dtypes.py

Lines changed: 56 additions & 12 deletions
@@ -59,9 +59,10 @@ def __new__(cls, *args, **kwargs):

         return object.__new__(cls)

-    def __init__(self, *, python_type, name, shortname, bytes, is_weak):
+    def __init__(self, *, python_type, name, shortname, bytes, is_weak, variant=None):
         self._python_type = python_type
         self._name = name
+        self._variant = variant
         self._shortname = shortname
         self._bytes = bytes
         self._is_weak = is_weak
@@ -80,23 +81,30 @@ def is_weak(self):
         return self._is_weak

     def shortname(self):
-        return f"{self._shortname}{8 * self._bytes}"
+        return f"{self._shortname}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}"

     # TODO Fix name printing
     def __repr__(self):
-        return f"{self._name}{8 * self._bytes}{'_' if self._is_weak else ''}"
+        return (
+            f"{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"
+        )

     def __str__(self):
         return self.__repr__()

     def __hash__(self) -> int:
-        return hash((self._name, self._bytes, self._is_weak))
+        return hash((self._name, self._bytes, self._is_weak, f"{self._variant if self._variant else ''}"))

     def __eq__(self, other) -> bool:
         if not isinstance(other, dtype):
             return False

-        return self._name == other._name and self._bytes == other._bytes and self._is_weak == other._is_weak
+        return (
+            self._name == other._name
+            and self._bytes == other._bytes
+            and self._is_weak == other._is_weak
+            and self._variant == other._variant
+        )


 class exact(dtype):
@@ -152,14 +160,24 @@ class inexact(dtype):


 class floating(inexact):
-    """Base class for the floating dtypes: bfloat16, float16, float32, float64."""
+    """Base class for the floating dtypes: float8, bfloat16, float16, float32, float64."""

-    def __init__(self, name, shortname, *, bytes, is_weak):
-        super().__init__(python_type=float, name=name, shortname=shortname, bytes=bytes, is_weak=is_weak)
+    def __init__(self, name, shortname, *, bytes, is_weak, variant=None):
+        super().__init__(
+            python_type=float, name=name, shortname=shortname, bytes=bytes, is_weak=is_weak, variant=variant
+        )


 bfloat16 = floating("bfloat", "bf", bytes=2, is_weak=False)
 bfloat16_ = floating("bfloat", "bf", bytes=2, is_weak=True)
+float8_e5m2 = floating("float", "f", bytes=1, is_weak=False, variant="e5m2")
+float8_e5m2_ = floating("float", "f", bytes=1, is_weak=True, variant="e5m2")
+float8_e5m2fnuz = floating("float", "f", bytes=1, is_weak=False, variant="e5m2fnuz")
+float8_e5m2fnuz_ = floating("float", "f", bytes=1, is_weak=True, variant="e5m2fnuz")
+float8_e4m3fn = floating("float", "f", bytes=1, is_weak=False, variant="e4m3fn")
+float8_e4m3fn_ = floating("float", "f", bytes=1, is_weak=True, variant="e4m3fn")
+float8_e4m3fnuz = floating("float", "f", bytes=1, is_weak=False, variant="e4m3fnuz")
+float8_e4m3fnuz_ = floating("float", "f", bytes=1, is_weak=True, variant="e4m3fnuz")
 float16 = floating("float", "f", bytes=2, is_weak=False)
 float16_ = floating("float", "f", bytes=2, is_weak=True)
 float32 = floating("float", "f", bytes=4, is_weak=False)
@@ -200,6 +218,14 @@ def __init__(self, name, shortname, *, bytes, is_weak):
     int64_,
     bfloat16,
     bfloat16_,
+    float8_e5m2,
+    float8_e5m2_,
+    float8_e5m2fnuz,
+    float8_e5m2fnuz_,
+    float8_e4m3fn,
+    float8_e4m3fn_,
+    float8_e4m3fnuz,
+    float8_e4m3fnuz_,
     float16,
     float16_,
     float32,
@@ -242,6 +268,10 @@ def __init__(self, name, shortname, *, bytes, is_weak):

 float_dtypes = {d for d in all_dtypes if isinstance(d, floating)} | {float}

+float_math_dtypes = {d for d in all_dtypes if isinstance(d, floating) and d.bytes >= 2}
+
+float_8bit_dtypes = {d for d in all_dtypes if (isinstance(d, floating) and d.bytes == 1)}
+
 complex_dtypes = {d for d in all_dtypes if isinstance(d, complexfloating)} | {complex}

 inexact_dtypes = float_dtypes | complex_dtypes
@@ -306,11 +336,12 @@ def has_subdtype(x, cls):


 # Translates a sequence of dtypes and dtype classes into a concrete set of corresponding (strong) dtypes
-def resolve_dtypes(args):
+def resolve_dtypes(args: Iterable) -> set[dtype]:
     dtypes = set()
     for arg in args:
         if isinstance(arg, dtype):
-            dtypes.add(arg)
+            if not arg.is_weak:
+                dtypes.add(arg)
             continue

         if isinstance(arg, Iterable):
@@ -320,7 +351,8 @@ def resolve_dtypes(args):
                    lambda: f"Iterables passed to resolve_dtypes must only contain dtypes, but found an Iterable with {a}",
                    exception_type=NotImplementedError,
                )
-                if not isinstance(a, dtype) else dtypes.add(a)
+                if not a.is_weak:
+                    dtypes.add(a)

        baseutils.check(
            arg in (dtype, exact, signedinteger, unsignedinteger, bool_, inexact, floating, complexfloating),
@@ -373,6 +405,10 @@ def corresponding_complex_dtype(dtype):
     int32: int32_,
     int64: int64_,
     bfloat16: bfloat16_,
+    float8_e5m2: float8_e5m2_,
+    float8_e5m2fnuz: float8_e5m2fnuz_,
+    float8_e4m3fn: float8_e4m3fn_,
+    float8_e4m3fnuz: float8_e4m3fnuz_,
     float16: float16_,
     float32: float32_,
     float64: float64_,
@@ -520,6 +556,14 @@ def are_same_dtypes(a, b, *, weak_and_strong_are_equivalent=True):
     int64: torch.int64,
     bfloat16_: torch.bfloat16,
     bfloat16: torch.bfloat16,
+    float8_e5m2: torch.float8_e5m2,
+    float8_e5m2_: torch.float8_e5m2,
+    float8_e5m2fnuz: torch.float8_e5m2fnuz,
+    float8_e5m2fnuz_: torch.float8_e5m2fnuz,
+    float8_e4m3fn: torch.float8_e4m3fn,
+    float8_e4m3fn_: torch.float8_e4m3fn,
+    float8_e4m3fnuz: torch.float8_e4m3fnuz,
+    float8_e4m3fnuz_: torch.float8_e4m3fnuz,
     float16_: torch.float16,
     float16: torch.float16,
     float32_: torch.float32,
@@ -551,7 +595,7 @@ def to_torch_dtype(x: None | torch.dtype | dtype) -> None | torch.dtype:

 # Converts NumPy dtypes to and from thunder dtypes

-# NOTE NumPy does not support the bfloat16 or complexhalf (complex32) datatypes
+# NOTE NumPy does not support the bfloat16, complexhalf (complex32) or float8 datatypes
 _thunder_to_numpy_dtype_map = {
     bool: np.bool_,
     int: np.int_,
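
Taken together, the new `variant` field keeps the four float8 formats distinct in printing, hashing, and equality, and the torch map round-trips them. A quick sketch of the expected behavior (assuming this revision of thunder and a PyTorch build with float8 support are importable):

    import torch
    from thunder.core import dtypes

    print(dtypes.float8_e5m2)                # float8_e5m2 (the variant appears in the repr)
    print(dtypes.float8_e4m3fn.shortname())  # f8_e4m3fn

    # same name and byte width, but different variants compare unequal
    assert dtypes.float8_e5m2 != dtypes.float8_e4m3fn

    # weak dtypes are now dropped when resolving to concrete (strong) dtypes
    assert dtypes.resolve_dtypes([dtypes.float8_e5m2, dtypes.float8_e5m2_]) == {dtypes.float8_e5m2}

    # thunder float8 dtypes map onto the corresponding torch dtypes
    assert dtypes.to_torch_dtype(dtypes.float8_e5m2fnuz) is torch.float8_e5m2fnuz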

thunder/core/module.py

Lines changed: 42 additions & 7 deletions
@@ -30,24 +30,22 @@ def __init__(self, model, compiled_model_call):
         # we populate these here for performance reasons (same as module cache),
         # a single dict lookup is cheaper than traversing the module
         # hierarchy, see https://github.com/Lightning-AI/lightning-thunder/issues/396#issuecomment-2113231498
-        self._overrides = {
-            k: v for k, v in itertools.chain(self._model.named_parameters(), self._model.named_buffers())
-        }
+        self._overrides_parameters = dict(self._model.named_parameters())
+        self._overrides_buffers = dict(self._model.named_buffers())
         self._module_cache = {k: v for k, v in self._model.named_modules()}
-
         self._null = object()

     def get_buffer(self, name):
-        p = self._overrides.get(name, self._null)
+        p = self._overrides_buffers.get(name, self._null)
         if p is not self._null:
             return p
         return self._model.get_buffer(name)

     def set_buffer(self, name, value):
-        p = self._overrides[name] = value
+        p = self._overrides_buffers[name] = value

     def get_parameter(self, name):
-        p = self._overrides.get(name, self._null)
+        p = self._overrides_parameters.get(name, self._null)
         if p is not self._null:
             return p
         return self._model.get_parameter(name)
@@ -62,6 +60,43 @@ def forward(self, *args, **kwargs):
         res = self._forward_fn(*args, **kwargs)
         return res

+    def _named_parameters_or_buffers(self, overrides, orig_iter, prefix="", recurse=True, remove_duplicate=True):
+        seen_ids = set()
+        seen_names = set()
+        for k, v in itertools.chain(overrides.items(), orig_iter(remove_duplicate=remove_duplicate)):
+            if remove_duplicate:
+                id_v = id(v)
+                if id_v in seen_ids:
+                    continue
+                seen_ids.add(id_v)
+
+            mod, _, base_param = k.rpartition(".")
+            if recurse or not mod:
+                if k not in seen_names:
+                    seen_names.add(k)
+                    if prefix:
+                        yield (f"{prefix}.{k}", v)
+                    else:
+                        yield (k, v)
+
+    def named_parameters(self, prefix="", recurse=True, remove_duplicate=True):
+        yield from self._named_parameters_or_buffers(
+            self._overrides_parameters,
+            self._model.named_parameters,
+            prefix=prefix,
+            recurse=recurse,
+            remove_duplicate=remove_duplicate,
+        )
+
+    def named_buffers(self, prefix="", recurse=True, remove_duplicate=True):
+        yield from self._named_parameters_or_buffers(
+            self._overrides_buffers,
+            self._model.named_buffers,
+            prefix=prefix,
+            recurse=recurse,
+            remove_duplicate=remove_duplicate,
+        )
+
     @contextmanager
     def no_sync(self):
         r"""Context manager to disable gradient synchronization in data parallel mode.

thunder/core/proxies.py

Lines changed: 16 additions & 6 deletions
@@ -17,7 +17,7 @@
 from thunder.core.trace import VariableInterface, get_tracectx, TraceCtx
 from thunder.core.baseutils import ProxyInterface, NumberProxyInterface, TensorProxyInterface
 import thunder.core.baseutils as baseutils
-from thunder.core.langctxs import resolve_method
+from thunder.core.langctxs import resolve_method, get_langctx
 import thunder.core.devices as devices
 import thunder.core.dtypes as dtypes

@@ -592,13 +592,18 @@ def known_value(self) -> bool:
     # fn is the function to call if executing outside a language context
     @staticmethod
     def _elementwise_unary_helper(a, name, fn, type_promotion_kind=None):
-        trace: None | TraceCtx = get_tracectx()

         vala = pyval(a)

-        if trace is None:
-            # Outside of a trace context, operations on NumberProxies are executed by the
-            # Python interpreter
+        trace: None | TraceCtx = get_tracectx()
+        lang: None | LangCtx = None
+        try:
+            lang = get_langctx()
+        except LookupError:
+            pass
+        if trace is None or lang is None:
+            # Outside of a trace or language context, operations on NumberProxies are
+            # executed by the Python interpreter
             baseutils.check(
                 vala is not None,
                 lambda: f"Trying to {name} a number with an unknown value",
@@ -649,7 +654,12 @@ def _elementwise_binary_helper(a, b, name, fn, type_promotion_kind=None):
         valb = pyval(b) if isinstance(b, NumberProxy) else b

         trace: None | TraceCtx = get_tracectx()
-        if trace is None:
+        lang: None | LangCtx = None
+        try:
+            lang = get_langctx()
+        except LookupError:
+            pass
+        if trace is None or lang is None:
             # Outside of a trace or language context, binary operations on NumberProxies are
             # executed by the Python interpreter
             baseutils.check(

thunder/core/rematerialization.py

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,7 @@
 from thunder.core import prims, utils
 from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface
 from thunder.core.prims import PrimIDs
-from thunder.core.proxies import TensorProxy, variableify
+from thunder.core.proxies import TensorProxy, variableify, NumberProxy
 from thunder.core.pytree import tree_flatten, tree_unflatten
 from thunder.core.symbol import has_tags
 from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
@@ -332,6 +332,8 @@ def add_edge(src, dst, capacity):
     def get_weight(var):
         if isinstance(var, TensorProxy):
             return WEIGHT * var.dtype.bytes
+        elif isinstance(var, NumberProxy):
+            return 0.0
         return WEIGHT

     def add_edges(var):
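
For context: `get_weight` supplies edge capacities for the min-cut that decides which intermediates the joint forward/backward trace saves versus recomputes, so giving `NumberProxy` a weight of `0.0` means scalar values never force a cut. An annotated restatement of the rule (`WEIGHT` is the module's existing base cost per value):

    def get_weight(var):
        if isinstance(var, TensorProxy):
            return WEIGHT * var.dtype.bytes  # tensors cost memory in proportion to element size
        elif isinstance(var, NumberProxy):
            return 0.0  # scalars are effectively free to save; never worth cutting around
        return WEIGHT  # default cost for anything else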
