
Commit 381916d

Merge branch 'main' into add_celu

2 parents: ad978cb + 72345cc

File tree: 6 files changed (+108, -54 lines)

thunder/__init__.py

Lines changed: 68 additions & 34 deletions
@@ -458,15 +458,10 @@ def get_computation_and_inputs(*args, **kwargs):
                 _vanilla_args,
             ) = cache_entry
             try:
-                cs.last_prologue_execution_start = time.perf_counter_ns()
                 inps, pro_to_epi = pro(*args, **kwargs)
-                cs.last_prologue_execution_stop = time.perf_counter_ns()
             except Exception as _:
                 continue

-            cs.last_trace_host_tracing_start = time.perf_counter_ns()
-            cs.last_trace_host_tracing_stop = time.perf_counter_ns()
-
             # Updates cache statistics
             cs.cache_hits += 1
             cs.last_traces = comp_traces
@@ -495,12 +490,7 @@ def get_computation_and_inputs(*args, **kwargs):
                 backward_traces,
             ) = cache_entry

-            cs.last_prologue_execution_start = time.perf_counter_ns()
             inps, pro_to_epi = pro(*args, **kwargs)
-            cs.last_prologue_execution_stop = time.perf_counter_ns()
-
-            cs.last_trace_host_tracing_start = time.perf_counter_ns()
-            cs.last_trace_host_tracing_stop = time.perf_counter_ns()

             # Updates cache statistics
             cs.cache_hits += 1
@@ -622,6 +612,7 @@ def get_computation_and_inputs(*args, **kwargs):
         )
         prologue_trc = prologue_traces[-1]
         pro = prologue_trc.python_callable(include_decorators=False)
+        pro = prologue_execution_timer(pro)

         if epilogue_trc is not None:
             epilogue = epilogue_trc.python_callable()
@@ -637,9 +628,7 @@ def get_computation_and_inputs(*args, **kwargs):
         cs.last_interpreter_log = last_interpreter_log
         cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

-        cs.last_prologue_execution_start = time.perf_counter_ns()
         inps, pro_to_epi = pro(*args, **kwargs)
-        cs.last_prologue_execution_stop = time.perf_counter_ns()

         computation_trc = dce(computation_trc)
         computation_traces.append(computation_trc)
@@ -729,23 +718,55 @@ def get_computation_and_inputs(*args, **kwargs):

         return cache_entry, inps, pro_to_epi

-    cd.get_computation_and_inputs = get_computation_and_inputs
+    def host_execution_timer(fn):
+        def wrapped(*args, **kwargs):
+            cs.last_trace_host_execution_start = time.perf_counter_ns()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                cs.last_trace_host_execution_stop = time.perf_counter_ns()

-    @wraps(fn)
-    def fn_(*args, **kwargs) -> Any:
-        if is_tracing():
-            _recursive_jit_call_warning()
-            return fn(*args, **kwargs)
+        return wrapped

-        # Updats call statistics
-        cs.last_trace_host_start = time.perf_counter_ns()
-        cs.calls += 1
+    def prologue_execution_timer(fn):
+        def wrapped(*args, **kwargs):
+            cs.last_prologue_execution_start = time.perf_counter_ns()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                cs.last_prologue_execution_stop = time.perf_counter_ns()

-        cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
-        cs.last_trace_host_execution_start = time.perf_counter_ns()
+        return wrapped
+
+    def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
+        def wrapped(*args, **kwargs):
+            cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
+            decorated_computation_fn = cache_entry.computation_fn
+            for decorator in decorators:
+                decorated_computation_fn = decorator(decorated_computation_fn)
+            if decorators:
+                cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn)
+            return cache_entry, inps, pro_to_epi
+
+        return wrapped
+
+    get_computation_and_inputs = decorate_computation_function(get_computation_and_inputs, host_execution_timer)
+    cd.get_computation_and_inputs = get_computation_and_inputs
+
+    def update_call_statistics(fn):
+        def wrapped(*args, **kwargs):
+            cs.calls += 1
+            cs.last_trace_host_start = time.perf_counter_ns()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                cs.last_trace_host_stop = time.perf_counter_ns()

+        return wrapped
+
+    def check_storage_aliases(cache_entry, args):
         if cache_entry.vanilla_tensor_args:
-            if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps):
+            if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args):
                 alias_tensor_indices = alias_tensor_indices_str
                 alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")}
                 vanilla_tensor_args = cache_entry.vanilla_tensor_args
@@ -755,13 +776,12 @@ def fn_(*args, **kwargs) -> Any:
                     NotImplementedError,
                 )

-        result = cache_entry.computation_fn(*inps)
-
+    def maybe_connect_to_autograd(cache_entry, result):
         if cache_entry.backward_fn:
-            # Run the compiled forward function
+            # If the backward function is available, we need to connect the
+            # resulting tensors to PyTorch's Autograd graph using the
+            # ThunderFunction (which is a torch.autograd.Function subclass)
             data_for_autograd, (saved_tensors, saved_other) = result
-
-            # Connect produced tensors with PyTorch's autograd graph
             ThunderFunction.apply(
                 cache_entry.return_none_instead_of_grads,
                 cache_entry.backward_fn,
@@ -772,17 +792,31 @@ def fn_(*args, **kwargs) -> Any:
             )
             result = data_for_autograd["output"]

+        return result
+
+    def maybe_call_epilogue(cache_entry, result, pro_to_epi):
         if cache_entry.epilogue_fn:
             result, comp_to_epi = result
             cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi)

-        cs.last_trace_host_execution_stop = time.perf_counter_ns()
-        cs.last_computation_execution_stop = cs.last_trace_host_execution_stop
+        return result

-        cs.last_executed = cache_entry.computation_fn
-        cs.last_trace_cache_stop = time.perf_counter_ns()
-        cs.last_trace_host_stop = time.perf_counter_ns()
+    @wraps(fn)
+    @update_call_statistics
+    def fn_(*args, **kwargs) -> Any:
+        if is_tracing():
+            _recursive_jit_call_warning()
+            return fn(*args, **kwargs)
+
+        cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
+
+        check_storage_aliases(cache_entry, inps)
+
+        result = cache_entry.computation_fn(*inps)
+        result = maybe_connect_to_autograd(cache_entry, result)
+        result = maybe_call_epilogue(cache_entry, result, pro_to_epi)

+        cs.last_computation = cache_entry.computation_fn
         return result

     if isinstance(fn, pytorch.nn.Module):

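The main change to thunder/__init__.py is structural: the scattered time.perf_counter_ns() bookkeeping moves into small wrapper factories (prologue_execution_timer, host_execution_timer, update_call_statistics) that record start/stop timestamps on the CompileStats object inside try/finally, so the stop timestamp is written even when the wrapped call raises, and fn_ shrinks to a sequence of named steps. Below is a minimal, self-contained sketch of that timing pattern; SimpleStats and unpack_inputs are illustrative stand-ins, and the real wrappers close over the enclosing cs object instead of taking the stats as an argument.

import time
from functools import wraps


class SimpleStats:
    # Illustrative stand-in for the timestamp fields on thunder's CompileStats.
    def __init__(self):
        self.last_prologue_execution_start = 0
        self.last_prologue_execution_stop = 0


def prologue_execution_timer(stats, fn):
    # Record wall-clock timestamps around fn; the finally block mirrors the diff,
    # so the stop timestamp is written even if fn raises.
    @wraps(fn)
    def wrapped(*args, **kwargs):
        stats.last_prologue_execution_start = time.perf_counter_ns()
        try:
            return fn(*args, **kwargs)
        finally:
            stats.last_prologue_execution_stop = time.perf_counter_ns()

    return wrapped


def unpack_inputs(x, y):
    # Stand-in for a generated prologue callable returning (inputs, pro_to_epi).
    return (x, y), None


stats = SimpleStats()
timed_prologue = prologue_execution_timer(stats, unpack_inputs)
inps, pro_to_epi = timed_prologue(1, 2)
print("prologue ns:", stats.last_prologue_execution_stop - stats.last_prologue_execution_start)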
thunder/benchmarks/benchmark_litgpt.py

Lines changed: 21 additions & 15 deletions
@@ -532,7 +532,7 @@ def setup_distributed(self, model):
         return model

     def setup_activation_checkpointing(self):
-        if "thunder" in self.compile:
+        if "thunder" in self.compile and "dynamo" not in self.compile:
             # checkpointing is an option to thunder.jit
             return

@@ -571,25 +571,23 @@ def setup_compile(self, model):

             executors.insert(0, transformer_engine_ex)

-        jit_options = {
-            "enable_saved_for_backward_recomputation": self.checkpoint_activations,
-            "recomputation_policy": None,
-        }
-
         if "dynamo" in self.compile:
             if self.distributed_mode == "fsdp2":
                 print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile")
                 import torch._dynamo.config as dynamo_config

                 dynamo_config.cache_size_limit = 64

-            backend = ThunderCompiler(executors=executors, **jit_options)
+            self.backend = ThunderCompiler(executors=executors)
             # Because Lightning Fabric is imported in this script it monkey patches the torch.compile function
             # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
             # using __wrapped__ to access the original torch.compile function did not work
             # so we are using the lower level torch._dynamo.optimize function
-            model = torch._dynamo.optimize(backend=backend)(model)
+            model = torch._dynamo.optimize(backend=self.backend)(model)
         else:
+            jit_options = {
+                "enable_saved_for_backward_recomputation": self.checkpoint_activations,
+            }
             jit_options["fp8_shard_intermediate_activation"] = self.fp8_shard_intermediate_activation
             model = thunder.jit(model, executors=executors, **jit_options)

@@ -844,16 +842,24 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
             for jitted in benchmark.thunder_as_torch_compile_backend.gm_to_thunder.values():
                 fwd_traces.append(thunder.last_traces(jitted))
                 bwd_traces.append(thunder.last_backward_traces(jitted))
-        else:
+        elif "dynamo" not in benchmark.compile:
             fwd_traces = [thunder.last_traces(benchmark.model)]
             bwd_traces = [thunder.last_backward_traces(benchmark.model)]

-        for i, f_traces in enumerate(fwd_traces, start=1):
-            print(f"##########\n#{i}-th ThunderModule\n##########")
-            print(f_traces[-1])
-        for i, b_traces in enumerate(bwd_traces, start=1):
-            print(f"##########\n#{i}-th ThunderModule\n##########")
-            print(b_traces[-1])
+        if "dynamo" in benchmark.compile:
+            for gid, infos in enumerate(benchmark.backend.subgraph_infos):
+                for subgid, thunder_fn in enumerate(infos.thunder_compiled_fns):
+                    print(f"##########\n#Graph{gid}-ThunderFn{subgid} last forward trace\n##########")
+                    print(thunder.last_traces(thunder_fn)[-1])
+                    print(f"##########\n#Graph{gid}-ThunderFn{subgid} last backward trace\n##########")
+                    print(thunder.last_backward_traces(thunder_fn)[-1])
+        else:
+            for i, f_traces in enumerate(fwd_traces, start=1):
+                print(f"##########\n#{i}-th ThunderModule\n##########")
+                print(f_traces[-1])
+            for i, b_traces in enumerate(bwd_traces, start=1):
+                print(f"##########\n#{i}-th ThunderModule\n##########")
+                print(b_traces[-1])

     if global_rank in [0, None]:
         if return_metrics_as_json:

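For the dynamo path, the benchmark now keeps the ThunderCompiler instance on self.backend so that benchmark_main can later walk backend.subgraph_infos and print the last forward and backward trace of every thunder-compiled FX subgraph. A hedged sketch of that flow follows, assuming the thunder.dynamo.ThunderCompiler import path and the subgraph_infos / thunder_compiled_fns fields used in the diff; the toy model and batch are illustrative.

import torch
import thunder
from thunder.dynamo import ThunderCompiler  # import path assumed from recent lightning-thunder releases

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
batch = torch.randn(4, 8)

backend = ThunderCompiler()  # the benchmark also passes executors=[...]
compiled = torch._dynamo.optimize(backend=backend)(model)
compiled(batch).sum().backward()  # run once so dynamo hands its subgraphs to thunder

# One SubgraphInfo per dynamo graph; each may hold several thunder-compiled functions.
for gid, info in enumerate(backend.subgraph_infos):
    for subgid, thunder_fn in enumerate(info.thunder_compiled_fns):
        print(f"Graph{gid}-ThunderFn{subgid} last forward trace")
        print(thunder.last_traces(thunder_fn)[-1])
        print(f"Graph{gid}-ThunderFn{subgid} last backward trace")
        print(thunder.last_backward_traces(thunder_fn)[-1])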
thunder/common.py

Lines changed: 2 additions & 3 deletions
@@ -64,7 +64,6 @@


 # Holds statistics and caches for a compiled function
-# TODO RC1 Update last_executed to last_computation
 # TODO RC1 Review how autograd traces are presented
 class CompileStats:
     """A class holding statistics and caches for a compiled function.
@@ -76,7 +75,7 @@ class CompileStats:
     See :mod:`thunder` for more of such utility functions.

     Attributes:
-        last_executed:
+        last_computation (Callable):
         last_traces (Sequence[TraceCtx]):
         last_prologue (TraceCtx):
         last_prologue_traces (Sequence[TraceCtx]):
@@ -107,7 +106,7 @@ class CompileStats:

     def __init__(self):
         # Callables and traces
-        self.last_executed = None
+        self.last_computation = None
         self.last_traces = None
         self.last_prologue = None
         self.last_prologue_traces = None

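The docstring and attribute rename (last_executed to last_computation) matches the assignment cs.last_computation = cache_entry.computation_fn in thunder/__init__.py above. A small hedged example of reading the renamed field back after a call, assuming thunder.compile_stats is the accessor that returns a jitted callable's CompileStats:

import torch
import thunder


def square(x):
    return x * x


jsquare = thunder.jit(square)
jsquare(torch.randn(3))

stats = thunder.compile_stats(jsquare)  # assumed accessor returning the CompileStats instance
print(stats.calls)                      # number of calls recorded by update_call_statistics
print(stats.last_computation)           # compiled computation callable (previously last_executed)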
thunder/core/jit_ext.py

Lines changed: 1 addition & 1 deletion
@@ -1512,7 +1512,7 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:
        return not isinstance(p, TensorProxy)

    # TODO: This is just a WAR to get things working. We'll revisit this when
-    # we deal with cosntraints in prologue trace.
+    # we deal with constraints in prologue trace.
    #
    # We sort variables to before `unpack` to put TensorProxy before others.
    # Because we could have TensorProxy.shape be part of `pro_to_xxx` along with

thunder/executors/apex_fused_rms_norm_impl.py

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,12 @@

 APEX_FUSED_NORMS_AVAILABLE = True
 try:
+    # Fused layer norm is only importable if torch.distributed is available
+    # https://github.com/NVIDIA/apex/issues/1853
+    from torch.distributed import is_available
+
+    if not is_available():
+        raise ImportError
     import fused_layer_norm_cuda
     from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
 except ImportError:

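The guard above folds "torch.distributed is present but unusable" into the same ImportError branch that already covers "apex is not installed", because apex's fused layer norm imports torch.distributed internally (NVIDIA/apex issue 1853). The pattern generalizes to any optional extension; in this sketch SOME_EXTENSION_AVAILABLE and some_optional_extension are illustrative names, not thunder identifiers.

# Probe a prerequisite first and raise ImportError manually so that both
# "dependency missing" and "dependency unusable" take the same fallback path.
SOME_EXTENSION_AVAILABLE = True
try:
    from torch.distributed import is_available

    if not is_available():
        raise ImportError
    import some_optional_extension  # illustrative optional dependency
except ImportError:
    SOME_EXTENSION_AVAILABLE = False

print("extension available:", SOME_EXTENSION_AVAILABLE)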
thunder/tests/test_apex_fused_norms.py

Lines changed: 10 additions & 1 deletion
@@ -3,14 +3,19 @@
 from torch.testing import assert_close

 fused_layer_norm_cuda = pytest.importorskip("fused_layer_norm_cuda")
-from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
+
+from torch.distributed import is_available
 from thunder.executors.apexex import apex_ex
 import thunder


+# See https://github.com/NVIDIA/apex/issues/1853
+@pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
 @pytest.mark.parametrize("requires_grad", [True, False])
 @pytest.mark.parametrize("memory_efficient", [True, False])
 def test_apex_fused_rms_norm(requires_grad, memory_efficient):
+    from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
+
     def fn(x, weight, normalized_shape, eps):
         return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)

@@ -34,9 +39,13 @@ def fn(x, weight, normalized_shape, eps):
     assert_close(actual_grad, expected_grad)


+# See https://github.com/NVIDIA/apex/issues/1853
+@pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
 @pytest.mark.parametrize("requires_grad", [True, False])
 @pytest.mark.parametrize("memory_efficient", [True, False])
 def test_apex_fused_rms_norm_autoregister(requires_grad, memory_efficient):
+    from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
+
     def fn(x, weight, normalized_shape, eps):
         return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)
