Skip to content

Commit

Permalink
Merge branch 'main' into tfogal/nemo-test-case
Browse files Browse the repository at this point in the history
  • Loading branch information
tfogal authored Oct 31, 2024
2 parents d77bc05 + a2587e2 commit 9deeef0
Show file tree
Hide file tree
Showing 23 changed files with 474 additions and 206 deletions.
4 changes: 2 additions & 2 deletions .azure/docker-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ jobs:
#maxParallel: "3"
matrix:
# CUDA 12.1
"cuda 12.1 | torch 2.4.0 | cudnn FE v1.5.2":
{ CUDA_VERSION: "12.1.1", TORCH_VERSION: "2.4.0", TRITON_VERSION: "3.0.0", CUDNN_FRONTEND_VERSION: "1.5.2" }
"cuda 12.1 | torch 2.5.1 | cudnn FE v1.5.2":
{ CUDA_VERSION: "12.1.1", TORCH_VERSION: "2.5.1", TRITON_VERSION: "3.1.0", CUDNN_FRONTEND_VERSION: "1.5.2" }
"cuda 12.1 | torch 2.5 /nightly | cudnn FE v1.5.2":
{ CUDA_VERSION: "12.1.1", TORCH_VERSION: "main", TORCH_INSTALL: "source", CUDNN_FRONTEND_VERSION: "1.5.2" }
#'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found
Expand Down
8 changes: 4 additions & 4 deletions .azure/gpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ jobs:
strategy:
matrix:
# CUDA 12.1
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.4.0 | regular":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.4.0-dev"
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.5.1 | regular":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.5.1-dev"
CUDA_VERSION_MM: "121"
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.4.0 | distributed":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.4.0-dev"
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.5.1 | distributed":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.5.1-dev"
CUDA_VERSION_MM: "121"
testing: "distributed"
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | regular":
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ pip install git+https://github.com/Lightning-AI/lightning-thunder.git@main

To achieve the best performance, you can install Thunder with the following additional dependencies:

- install nightly [nvFuser](https://github.com/NVIDIA/Fuser) built for PyTorch 2.4 as follows:
- install prerelease [nvFuser](https://github.com/NVIDIA/Fuser) built for PyTorch 2.5.1 as follows:

```bash
# install nvFuser built for the matching stable PyTorch
pip install --pre nvfuser-cu121-torch24
pip install --pre nvfuser-cu121-torch25
```

- install [cudnn](https://github.com/NVIDIA/cudnn-frontend) as follows:
Expand Down
6 changes: 3 additions & 3 deletions docs/source/fundamentals/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ Minimal dependencies

Follow these instructions to install PyTorch, nvFuser, and finally Thunder.

Install PyTorch and nvFuser with pip (command shown is for CUDA 12.1 and PyTorch 2.4.x)::
Install PyTorch and nvFuser with pip (command shown is for CUDA 12.1 and PyTorch 2.5.x)::

pip install --pre nvfuser-cu121-torch24
pip install --pre nvfuser-cu121-torch25

cu121 can be replaced with cu118 depending on your CUDA version. nvFuser builds typically support the latest point release of PyTorch stable versions.
For torch 2.4, cu124 is also supported. For nightly versions and more detailed instructions, please see https://github.com/NVIDIA/Fuser/#installation
For torch 2.5, cu124 is also supported. For nightly versions and more detailed instructions, please see https://github.com/NVIDIA/Fuser/#installation

You're all set with minimal dependencies, so you can follow `Install Thunder`_.

Expand Down
109 changes: 75 additions & 34 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Transform,
wrap_return_value_together_with_argments,
unwrap_return_value,
remove_context_manager_prims_from_trace,
)
from thunder.core.functionalization import (
check_inplace_to_views,
Expand Down Expand Up @@ -419,6 +420,9 @@ def get_computation_and_inputs(*args, **kwargs):
)
autocast_thunder_dtype = autocast_cpu_dtype if pytorch.is_autocast_cpu_enabled() else autocast_gpu_dtype
cache_info.update(autocast_thunder_dtype=str(autocast_thunder_dtype))
device = "cuda" if pytorch.is_autocast_enabled() else "cpu"
dtype = autocast_thunder_dtype
cd.autocast_stack.push(device, dtype, is_autocast_enabled)

cache_info["is_autocast_enabled"] = is_autocast_enabled

Expand Down Expand Up @@ -458,15 +462,10 @@ def get_computation_and_inputs(*args, **kwargs):
_vanilla_args,
) = cache_entry
try:
cs.last_prologue_execution_start = time.perf_counter_ns()
inps, pro_to_epi = pro(*args, **kwargs)
cs.last_prologue_execution_stop = time.perf_counter_ns()
except Exception as _:
continue

cs.last_trace_host_tracing_start = time.perf_counter_ns()
cs.last_trace_host_tracing_stop = time.perf_counter_ns()

# Updates cache statistics
cs.cache_hits += 1
cs.last_traces = comp_traces
Expand Down Expand Up @@ -495,12 +494,7 @@ def get_computation_and_inputs(*args, **kwargs):
backward_traces,
) = cache_entry

cs.last_prologue_execution_start = time.perf_counter_ns()
inps, pro_to_epi = pro(*args, **kwargs)
cs.last_prologue_execution_stop = time.perf_counter_ns()

cs.last_trace_host_tracing_start = time.perf_counter_ns()
cs.last_trace_host_tracing_stop = time.perf_counter_ns()

# Updates cache statistics
cs.cache_hits += 1
Expand Down Expand Up @@ -544,6 +538,9 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = wrap_return_value_together_with_argments(computation_trc)
computation_traces.append(computation_trc)

computation_trc = remove_context_manager_prims_from_trace(computation_trc)
computation_traces.append(computation_trc)

orig_to_view_swap_map = check_inplace_to_views(computation_trc)
vanilla_tensor_args: set[int] | None = None
if not compile_options.get("skip_inplace_functionalization", False):
Expand Down Expand Up @@ -622,6 +619,7 @@ def get_computation_and_inputs(*args, **kwargs):
)
prologue_trc = prologue_traces[-1]
pro = prologue_trc.python_callable(include_decorators=False)
pro = prologue_execution_timer(pro)

if epilogue_trc is not None:
epilogue = epilogue_trc.python_callable()
Expand All @@ -637,9 +635,7 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

cs.last_prologue_execution_start = time.perf_counter_ns()
inps, pro_to_epi = pro(*args, **kwargs)
cs.last_prologue_execution_stop = time.perf_counter_ns()

computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)
Expand Down Expand Up @@ -729,23 +725,55 @@ def get_computation_and_inputs(*args, **kwargs):

return cache_entry, inps, pro_to_epi

cd.get_computation_and_inputs = get_computation_and_inputs
def host_execution_timer(fn):
def wrapped(*args, **kwargs):
cs.last_trace_host_execution_start = time.perf_counter_ns()
try:
return fn(*args, **kwargs)
finally:
cs.last_trace_host_execution_stop = time.perf_counter_ns()

@wraps(fn)
def fn_(*args, **kwargs) -> Any:
if is_tracing():
_recursive_jit_call_warning()
return fn(*args, **kwargs)
return wrapped

# Updates call statistics
cs.last_trace_host_start = time.perf_counter_ns()
cs.calls += 1
def prologue_execution_timer(fn):
def wrapped(*args, **kwargs):
cs.last_prologue_execution_start = time.perf_counter_ns()
try:
return fn(*args, **kwargs)
finally:
cs.last_prologue_execution_stop = time.perf_counter_ns()

cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
cs.last_trace_host_execution_start = time.perf_counter_ns()
return wrapped

def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
    """Wrap a cache-lookup callable so its computation function gets decorated.

    ``get_computation_and_inputs_fn`` must return a 3-tuple
    ``(cache_entry, inps, pro_to_epi)`` where ``cache_entry`` exposes a
    ``computation_fn`` attribute and a ``_replace`` method. Each decorator in
    ``decorators`` is applied in the order given (the first one ends up
    innermost). When no decorators are passed, the cache entry is returned
    untouched.
    """

    def wrapped(*args, **kwargs):
        entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
        computation_fn = entry.computation_fn
        if not decorators:
            # Nothing to apply: hand back the very same cache entry object.
            return entry, inps, pro_to_epi

        for decorate in decorators:
            computation_fn = decorate(computation_fn)
        # Build a new entry rather than mutating the one that was looked up.
        return entry._replace(computation_fn=computation_fn), inps, pro_to_epi

    return wrapped

get_computation_and_inputs = decorate_computation_function(get_computation_and_inputs, host_execution_timer)
cd.get_computation_and_inputs = get_computation_and_inputs

def update_call_statistics(fn):
    """Decorate ``fn`` so that every call is counted and timed on ``cs``.

    Each invocation increments ``cs.calls`` and stores a
    ``time.perf_counter_ns()`` timestamp in ``cs.last_trace_host_start``
    before calling ``fn``. ``cs.last_trace_host_stop`` is written in a
    ``finally`` clause, so it is recorded even when ``fn`` raises.
    """

    def wrapped(*call_args, **call_kwargs):
        cs.calls += 1
        cs.last_trace_host_start = time.perf_counter_ns()
        try:
            outcome = fn(*call_args, **call_kwargs)
        finally:
            # Runs on both the success and the exception path.
            cs.last_trace_host_stop = time.perf_counter_ns()
        return outcome

    return wrapped

def check_storage_aliases(cache_entry, args):
if cache_entry.vanilla_tensor_args:
if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps):
if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args):
alias_tensor_indices = alias_tensor_indices_str
alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")}
vanilla_tensor_args = cache_entry.vanilla_tensor_args
Expand All @@ -755,13 +783,12 @@ def fn_(*args, **kwargs) -> Any:
NotImplementedError,
)

result = cache_entry.computation_fn(*inps)

def maybe_connect_to_autograd(cache_entry, result):
if cache_entry.backward_fn:
# Run the compiled forward function
# If the backward function is available, we need to connect the
# resulting tensors to PyTorch's Autograd graph using the
# ThunderFunction (which is a torch.autograd.Function subclass)
data_for_autograd, (saved_tensors, saved_other) = result

# Connect produced tensors with PyTorch's autograd graph
ThunderFunction.apply(
cache_entry.return_none_instead_of_grads,
cache_entry.backward_fn,
Expand All @@ -772,17 +799,31 @@ def fn_(*args, **kwargs) -> Any:
)
result = data_for_autograd["output"]

return result

def maybe_call_epilogue(cache_entry, result, pro_to_epi):
if cache_entry.epilogue_fn:
result, comp_to_epi = result
cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi)

cs.last_trace_host_execution_stop = time.perf_counter_ns()
cs.last_computation_execution_stop = cs.last_trace_host_execution_stop
return result

cs.last_executed = cache_entry.computation_fn
cs.last_trace_cache_stop = time.perf_counter_ns()
cs.last_trace_host_stop = time.perf_counter_ns()
@wraps(fn)
@update_call_statistics
def fn_(*args, **kwargs) -> Any:
    # Entry point of the compiled callable: look up (or build) the cache
    # entry, run the computation, then apply the optional autograd and
    # epilogue post-processing steps.
    if is_tracing():
        # Called while a trace is already in progress: warn and fall back to
        # calling the original function directly.
        _recursive_jit_call_warning()
        return fn(*args, **kwargs)

    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    check_storage_aliases(cache_entry, inps)

    computation_fn = cache_entry.computation_fn
    raw_result = computation_fn(*inps)
    after_autograd = maybe_connect_to_autograd(cache_entry, raw_result)
    result = maybe_call_epilogue(cache_entry, after_autograd, pro_to_epi)

    # Remember which computation function served this call.
    cs.last_computation = computation_fn
    return result

if isinstance(fn, pytorch.nn.Module):
Expand Down
Loading

0 comments on commit 9deeef0

Please sign in to comment.