Merge branch 'main' into fix_torchspecial
kiya00 authored Aug 20, 2024
2 parents 052967b + 820a7d2 commit 126bda6
Showing 32 changed files with 977 additions and 531 deletions.
8 changes: 4 additions & 4 deletions .azure/gpu-tests.yml
@@ -16,11 +16,11 @@ jobs:
   strategy:
     matrix:
       # CUDA 12.1
-      "ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.3 | regular":
-        docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.3.0-dev"
+      "ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.3.1 | regular":
+        docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.3.1-dev"
         CUDA_VERSION_MM: "121"
-      "ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.3 | distributed":
-        docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.3.0-dev"
+      "ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.3.1 | distributed":
+        docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.3.1-dev"
         CUDA_VERSION_MM: "121"
         testing: "distributed"
       "ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | regular":
69 changes: 22 additions & 47 deletions thunder/__init__.py
@@ -311,6 +311,12 @@ def jit(
     if executors is None:
         executors = compile_options.pop("executors_list")
 
+    if "early_transforms" in compile_options:
+        raise RuntimeError("early_transforms= has been absorbed by transforms=")
+
+    if compile_options.get("use_cudagraphs") is not None:
+        raise RuntimeError("use_cudagraphs is replaced by using thunder.transforms.CUDAGraphTransform")
+
     # Resolves interpreter option
     interpretation = resolve_interpretation_option(interpretation)
     interpreter: Callable
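
Note: the two new guards encode the caller-facing migration. early_transforms= is absorbed into transforms=, and the use_cudagraphs= flag gives way to an explicit transform. A minimal before/after sketch (the Linear module is illustrative; the import path and the transforms= keyword come from the benchmarks change later in this commit):

    import torch
    import thunder
    from thunder.transforms.cudagraph import CUDAGraphTransform

    model = torch.nn.Linear(64, 64).cuda()  # illustrative module, not from this commit

    # Before this commit:
    #   jfn = thunder.jit(model, use_cudagraphs=True)  # now raises RuntimeError

    # After this commit, CUDA graphs are requested via a transform:
    jfn = thunder.jit(model, transforms=[CUDAGraphTransform()])
    out = jfn(torch.randn(8, 64, device="cuda"))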
@@ -341,16 +347,13 @@ def jit(
 
     # TODO RC1 Refine the compile data option to remove unused options
     # TODO: refine options
-    # NOTE(fixme): use_cudagraphs is being absorbed into compile_options
-    use_cudagraphs = compile_options.get("use_cudagraphs", False)
     cd = CompileData(
         fn=fn,
         langctx=langctx,
         executors_list=executors,
         cache_option=cache,
         sharp_edges=sharp_edges,
         using_jit=True,
-        use_cudagraphs=use_cudagraphs,
         disable_torch_autograd_support=disable_torch_autograd,
         use_rematerialization=False,
         only_execute_prims=False,
@@ -385,6 +388,7 @@ def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str:
             return ""
         return "-".join(alias_indices)
 
+    @langctxs.langctx(cd.langctx)
     @_with_cache_info_ctx
     def get_computation_and_inputs(*args, **kwargs):
         # set up a record of things in the current environment that impact caching / prologues
@@ -527,16 +531,15 @@ def get_computation_and_inputs(*args, **kwargs):
             # returns the (proxied) result of the operation
             cs.last_trace_tracing_start = time.perf_counter_ns()
 
-            with langctxs.langctx(cd.langctx):
-                prologue_trc: TraceCtx
-                computation_trc: TraceCtx
-                jit_results: TraceResults = interpreter(
-                    fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
-                )
-                prologue_trc = jit_results.prologue_trace
-                computation_trc = jit_results.computation_trace
-                epilogue_trc = jit_results.epilogue_trace
-                last_interpreter_log = jit_results.interpreter_log
+            prologue_trc: TraceCtx
+            computation_trc: TraceCtx
+            jit_results: TraceResults = interpreter(
+                fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
+            )
+            prologue_trc = jit_results.prologue_trace
+            computation_trc = jit_results.computation_trace
+            epilogue_trc = jit_results.epilogue_trace
+            last_interpreter_log = jit_results.interpreter_log
 
             prologue_traces = [prologue_trc]
             computation_traces = [computation_trc]
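
Note: this removal pairs with the decorator added a few hunks up; langctxs.langctx appears in this file both as a with-statement and as a decorator, which works when the object is both a context manager and a decorator. A rough equivalence, assuming a contextlib-style implementation rather than Thunder's actual one:

    from contextlib import ContextDecorator

    class langctx(ContextDecorator):
        # Illustrative stand-in for langctxs.langctx: usable both as
        # `with langctx(ctx): ...` and as `@langctx(ctx)`.
        def __init__(self, ctx):
            self.ctx = ctx

        def __enter__(self):
            print("entering language context:", self.ctx)  # would set the active context
            return self

        def __exit__(self, *exc):
            print("leaving language context:", self.ctx)  # would restore the previous one
            return False

    @langctx("torch")
    def traced_fn():
        ...  # the entire body runs inside the language context

Decorating get_computation_and_inputs once removes the need to re-enter the context around each tracing and execution-transform call.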
@@ -659,49 +662,24 @@ def get_computation_and_inputs(*args, **kwargs):
         # by split_forward_backward
 
         if backward_trc is None:
-            ## EPILOGUE and TRANSFORMS should not mix...
-            # applies transforms
-            cs.last_computation_transformation_start = time.perf_counter_ns()
-            for transform in transforms:
-                new_computation_trc = transform.transform_trace_additionally(
-                    computation_trc, executors_list=cd.executors_list
-                )
-                if new_computation_trc is not computation_trc:
-                    # todo: deprecation
-                    computation_trc = new_computation_trc
-                    computation_traces.append(computation_trc)
-            cs.last_computation_transformation_stop = time.perf_counter_ns()
-
             from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
             from thunder.executors.passes import _transform_for_operator_executor_execution
             from thunder.distributed.utils import maybe_sort_waits
 
-            with langctxs.langctx(cd.langctx):
-                tmp_comp_trc = _transform_for_operator_executor_execution(computation_trc, cd.executors_list)
+            tmp_comp_trc = _transform_for_operator_executor_execution(computation_trc, cd.executors_list)
             is_transformed, tmp_comp_trc = maybe_sort_waits(tmp_comp_trc)
             if is_transformed:
                 computation_trc = tmp_comp_trc
                 computation_traces.append(computation_trc)
 
-            with langctxs.langctx(cd.langctx):
-                extraces = transform_for_execution(
-                    computation_trc,
-                    executors_list=cd.executors_list,
-                    use_del_last_used=False,
-                )
+            extraces = transform_for_execution(
+                computation_trc,
+                executors_list=cd.executors_list,
+                use_del_last_used=False,
+            )
             computation_traces.extend(extraces)
             computation_trc = computation_traces[-1]
 
-        if cd.use_cudagraphs:
-            from thunder.executors.cudagraphex import cudagraphex
-
-            computation_trc = cudagraphex.fusion_pass(computation_trc)
-            computation_traces.append(computation_trc)
-
-            if backward_trc is not None:
-                backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
-                backward_traces.append(backward_trc)
-
         if backward_trc is None:
             computation_trc = thunder.executors.passes.del_last_used(computation_trc)
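
Note: the deleted cd.use_cudagraphs branch is not lost; the cudagraphex.fusion_pass calls it made are meant to come from thunder.transforms.CUDAGraphTransform instead (see the benchmarks change below). A hedged sketch of how such a transform could package the old pass; the method name follows the transform_trace_additionally hook visible in the removed loop above, and the real CUDAGraphTransform implementation may differ:

    from thunder.executors.cudagraphex import cudagraphex

    class CUDAGraphTransformSketch:
        # Illustrative only: wraps the old `if cd.use_cudagraphs:` fusion
        # pass in the transform interface used by thunder.jit(transforms=...).
        def transform_trace_additionally(self, trace, executors_list=None):
            # Fuse the trace into CUDA-graph launch regions, as the deleted
            # branch did for the computation trace.
            return cudagraphex.fusion_pass(trace)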

@@ -767,7 +745,6 @@ def fn_(*args, **kwargs) -> Any:
         cs.last_trace_host_execution_start = time.perf_counter_ns()
 
         if cache_entry.vanilla_tensor_args:
-
             if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps):
                 alias_tensor_indices = alias_tensor_indices_str
                 alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")}
@@ -835,7 +812,6 @@ def compile(
     langctx: None | Any = None,
     executors_list: None | Sequence[Executor] = None,
     cache_mode: None | str | CACHE_OPTIONS = None,
-    use_cudagraphs: bool = False,
     disable_torch_autograd_support: bool = False,
     use_rematerialization: bool = False,
     only_execute_prims: bool = False,
@@ -847,7 +823,6 @@
         langctx=langctx,
         executors_list=executors_list,
         cache_option=cache_mode,
-        use_cudagraphs=use_cudagraphs,
         disable_torch_autograd_support=disable_torch_autograd_support,
         use_rematerialization=use_rematerialization,
         only_execute_prims=only_execute_prims,
4 changes: 3 additions & 1 deletion thunder/benchmarks/__init__.py
@@ -28,6 +28,7 @@
 from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE
 from thunder.executors.sdpaex import sdpa_ex
 from thunder.executors.torch_compile import torch_compile_cat_ex, torch_compile_ex
+from thunder.transforms.cudagraph import CUDAGraphTransform
 from thunder.tests import nanogpt_model, hf_bart_self_attn, litgpt_model
 from thunder.tests.litgpt_model import Config as LitGPTConfig
 from thunder.tests.make_tensor import make_tensor, make_tensor_like
@@ -954,7 +955,8 @@ def default_thunder_cudagraphs_executor(fn: Callable) -> Callable:
         executors_list.append("apex_xentropy")
 
     executors_list.extend((executors.NVFUSER, executors.TORCH))
-    return thunder.jit(fn, executors=executors_list, use_cudagraphs=True, disable_torch_autograd=True)
+    transforms = [CUDAGraphTransform()]
+    return thunder.jit(fn, executors=executors_list, transforms=transforms, disable_torch_autograd=True)
 
 
 #
(Diffs for the remaining 29 changed files in this commit are not shown.)
