Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into support-output-dataclass
kshitij12345 committed Jun 24, 2024
2 parents 3bd342e + fa55b09 commit b45dbd0
Showing 32 changed files with 1,592 additions and 132 deletions.
8 changes: 4 additions & 4 deletions .azure/docker-build.yml
@@ -40,10 +40,10 @@ jobs:
#maxParallel: "3"
matrix:
# CUDA 12.1
"cuda 12.1 | torch 2.3 | cudnn FE v1.4":
{ CUDA_VERSION: "12.1.1", TORCH_VERSION: "2.3.0", TRITON_VERSION: "2.3.0", CUDNN_FRONTEND_VERSION: "1.4.0" }
"cuda 12.1 | torch 2.4 /nightly | cudnn FE v1.4":
{ CUDA_VERSION: "12.1.1", TORCH_VERSION: "main", TORCH_INSTALL: "source", CUDNN_FRONTEND_VERSION: "1.4.0" }
"cuda 12.1 | torch 2.3 | cudnn FE v1.5.1":
{ CUDA_VERSION: "12.1.1", TORCH_VERSION: "2.3.0", TRITON_VERSION: "2.3.0", CUDNN_FRONTEND_VERSION: "1.5.1" }
"cuda 12.1 | torch 2.4 /nightly | cudnn FE v1.5.1":
{ CUDA_VERSION: "12.1.1", TORCH_VERSION: "main", TORCH_INSTALL: "source", CUDNN_FRONTEND_VERSION: "1.5.1" }
#'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found
# how much time to give 'run always even if cancelled tasks' before stopping them
cancelTimeoutInMinutes: "2"
8 changes: 4 additions & 4 deletions .azure/gpu-tests.yml
@@ -17,17 +17,17 @@ jobs:
matrix:
# CUDA 12.1
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.3 | regular":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_2.3.0-dev"
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_2.3.0-dev"
CUDA_VERSION_MM: "121"
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.3 | distributed":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_2.3.0-dev"
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_2.3.0-dev"
CUDA_VERSION_MM: "121"
testing: "distributed"
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | regular":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_main-dev"
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_main-dev"
CUDA_VERSION_MM: "121"
"ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | distributed":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_main-dev"
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_main-dev"
CUDA_VERSION_MM: "121"
testing: "distributed"
# how much time to give 'run always even if cancelled tasks' before stopping them
4 changes: 2 additions & 2 deletions .azure/notebook-runs.yml
@@ -16,10 +16,10 @@ jobs:
strategy:
matrix:
"ubuntu22.04 | cuda 12.1 | torch 2.3":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_2.3.0-dev"
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_2.3.0-dev"
CUDA_VERSION_MM: "121"
"ubuntu22.04 | cuda 12.1 | torch-nightly":
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_main-dev"
docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_main-dev"
CUDA_VERSION_MM: "121"
# how long to run the job before automatically cancelling
timeoutInMinutes: "45"
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -2,7 +2,7 @@

# These owners will be the default owners for everything in the repo. Unless a later match takes precedence,
# @global-owner1, @global-owner2, and @global-owner3 will be requested for review when someone opens a pull request.
-* @mruberry @lantiga @robieta @t-vi @carmocca
+* @mruberry @lantiga @t-vi @carmocca

# CI/CD and configs
/.azure/ @borda @lantiga @t-vi @carmocca
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/program_coverage.md
@@ -1,6 +1,6 @@
---
-name: Feature request
-about: Suggest an idea for this project
+name: Program Coverage
+about: Expand the programs / models Thunder can process
title: ''
labels: program-coverage
assignees: ''
2 changes: 1 addition & 1 deletion dockers/ubuntu-cuda/Dockerfile
@@ -20,7 +20,7 @@ ARG IMAGE_TYPE="devel"
FROM nvidia/cuda:${CUDA_VERSION}-${IMAGE_TYPE}-ubuntu${UBUNTU_VERSION}

ARG CUDNN_VERSION="9.1.0.70"
ARG CUDNN_FRONTEND_VERSION="1.4.0"
ARG CUDNN_FRONTEND_VERSION="1.5.1"
ARG PYTHON_VERSION="3.10"
ARG TORCH_VERSION="2.2.1"
ARG TRITON_VERSION="2.2.0"
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -17,6 +17,7 @@ absl-py # thunder/benchmarks/test_benchmark_litgpt.py
pandas # thunder/benchmarks/test_benchmark_litgpt.py
xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py
jsonargparse # thunder/benchmarks/benchmark_litgpt.py
+bitsandbytes==0.42.0 # fixed version!

# Installs JAX on Linux and MacOS
jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation
25 changes: 21 additions & 4 deletions thunder/__init__.py
@@ -33,7 +33,14 @@
import thunder.core.prims as prims
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
-from thunder.core.transform_common import dce, EarlyTransform, AdditionalTransform, PostOptimizationTransform
+from thunder.core.transform_common import (
+    dce,
+    EarlyTransform,
+    AdditionalTransform,
+    PostOptimizationTransform,
+    functionalize_inplace_ops,
+    check_inplace_to_views,
+)
from thunder.common import (
CompileData,
CompileStats,
@@ -418,7 +425,7 @@ def get_computation_and_inputs(*args, **kwargs):
) = cache_entry
try:
cs.last_prologue_execution_start = time.time_ns()
-if epilogue:
+if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON:
inps, pro_to_epi = pro(*args, **kwargs)
else:
inps = pro(*args, **kwargs)
@@ -459,7 +466,7 @@ def get_computation_and_inputs(*args, **kwargs):
) = cache_entry

cs.last_prologue_execution_start = time.time_ns()
-if epilogue:
+if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON:
inps, pro_to_epi = pro(*args, **kwargs)
else:
inps = pro(*args, **kwargs)
@@ -503,6 +510,14 @@ def get_computation_and_inputs(*args, **kwargs):

prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
+orig_to_view_swap_map = check_inplace_to_views(computation_trc)
+if not compile_options.get("skip_inplace_functionalization", False):
+    computation_traces.extend(
+        functionalize_inplace_ops(
+            computation_trace=computation_trc, orig_to_view_swap_map=orig_to_view_swap_map
+        )
+    )
+computation_trc = computation_traces[-1]

if epilogue_trc is not None:
epilogue_traces = [epilogue_trc]
@@ -541,7 +556,7 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_prologue_transformation_stop = time.time_ns()

cs.last_prologue_execution_start = time.time_ns()
-if epilogue:
+if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON:
inps, pro_to_epi = pro(*args, **kwargs)
else:
inps = pro(*args, **kwargs)
@@ -692,6 +707,8 @@ def fn_(*args, **kwargs) -> Any:
if isinstance(fn, pytorch.nn.Module):
fn_ = ThunderModule(fn, fn_)
cd._thunder_module_map[id(fn)] = fn_
+for transform in early_transforms:
+    transform.transform_module(fn_)

# Sets compile options and statistics attributes
cd._get_computation_and_inputs = get_computation_and_inputs
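An aside for readers tracking the new pass: functionalization of in-place ops is on by default and is gated by the `skip_inplace_functionalization` compile option read above (keyword arguments to `thunder.jit` land in `compile_options`). A minimal sketch of passing that option, with an illustrative function of my own, not from this commit:

    import torch
    import thunder

    def f(x):
        y = x.sin()
        y.add_(1)  # in-place op; the new pass rewrites it out-of-place by default
        return y

    jf = thunder.jit(f)  # in-place ops are functionalized
    jf_raw = thunder.jit(f, skip_inplace_functionalization=True)  # pass is skipped
    out = jf(torch.randn(3))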
1 change: 0 additions & 1 deletion thunder/core/codeutils.py
@@ -135,7 +135,6 @@ def to_printable(

if is_collection(x):
flat, spec = tree_flatten(x)
-
printables = []
for f in flat:
printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
69 changes: 61 additions & 8 deletions thunder/core/jit_ext.py
@@ -121,6 +121,7 @@
EXT_FLAG_IS_MODULE_MEMBER_DICT = 4
EXT_FLAG_IS_MODULE = 8
EXT_FLAG_IS_CALLABLE = 16
+EXT_FLAG_IS_CONSTRAINABLE_INPUT = 32
MODULE_MEMBER_DICT_ATTRS = {
"_parameters",
"_modules",
@@ -652,6 +653,8 @@ def proxify(self, value: WrappedValue) -> Any:
value.provenance.ext_flag |= EXT_FLAG_IS_PROXY_DERIVED
# we follow the caching mechanisms of the eager_unpack_interpreter
p = proxy(uvalue, history=value.provenance)
+if value.provenance.ext_flag & EXT_FLAG_IS_CONSTRAINABLE_INPUT and hasattr(p, "make_constrainable"):
+    p.make_constrainable()
assert p.history is not None, f"{p.history}, {value.provenance} {type(p)}"

co: CACHE_OPTIONS = get_cache_option()
@@ -849,6 +852,11 @@ def _general_jit_hasattr_lookaside(obj: Any, name: str):
# recording the constraint to conditional jumps and such.
def _general_jit_bool_lookaside(wrapped_x: Any) -> bool | INTERPRETER_SIGNALS:
assert isinstance(wrapped_x, WrappedValue)
+# It doesn't feel right to insert constraints in the bool lookaside; constraints here only apply when the bool value is used in control flow.
+if isinstance(wrapped_x.value, NumberProxy):
+    if wrapped_x.value.is_dynamic():
+        raise NotImplementedError(f"conversion to bool is not allowed on dynamic proxy={wrapped_x.value}")
+    wrapped_x.value.make_static_constrained()
bool_lookaside = default_lookaside(bool) or bool
return bool_lookaside(wrapped_x)
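A hedged illustration of what this buys (the function and values below are made up, not from this commit): branching on a proxied number converts it to bool, which now pins the proxy to its traced value, while a proxy already marked dynamic raises instead of silently specializing on one branch.

    import torch
    import thunder

    def f(x, n: int):
        if n > 0:  # bool() on the comparison's NumberProxy marks it statically constrained
            return x * n
        return x - n

    jf = thunder.jit(f)
    print(jf(torch.randn(2), 3))  # traces the n > 0 branch; the constraint reaches the prologue via propagate_constraints below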

@@ -1168,6 +1176,7 @@ def _general_jit_wrap_callback(value):
pass
elif should_register_for_prologue(value.provenance):
value.provenance.ext_flag |= EXT_FLAG_IS_PROXY_DERIVED
+value.provenance.ext_flag |= EXT_FLAG_IS_CONSTRAINABLE_INPUT
# we follow the caching mechanisms of the eager_unpack_interpreter
p = ctx.proxify(value)
else:
@@ -1213,6 +1222,52 @@ def _general_jit_store_deref_callback(
general_jit_callbacks = default_callbacks | general_jit_callbacks


+# This pass identifies NumberProxys that are marked as statically constrained and propagates the constraints to the inputs of the trace.
+# The logic is that, if all inputs that produce a NumberProxy are marked statically constrained, then the NumberProxy itself is statically constrained.
+# This pass currently only does backward propagation to insert constraints into the prologue trace.
+# TODO: We should be able to apply constant folding and simplify the computation_trace.
+# TODO: If we allowed symbolic constraints, we would get more cache re-use, i.e. rather than requiring a NumberProxy to be static, we could use a finer-grained constraint such as `check_number_gt`.
+def propagate_constraints(ctx, inputs, intermediates, computation_trace):
+    import thunder.core.utils as utils
+
+    # set of NumberProxy variables that have already been traversed and marked as statically constrained.
+    static_np_set = set()
+
+    # add static constraints for inputs
+    for inp in inputs:
+        u_inp = unvariableify(inp)
+        if not isinstance(u_inp, NumberProxy):
+            continue
+        if u_inp.is_static_constrained():
+            ctx.add_constraint((clang.check_number_type_and_value, u_inp, u_inp.value))
+            static_np_set.add(inp)
+
+    producers = utils.producers(computation_trace.bound_symbols, _map_to_numbers=False)
+    # add static constraints propagated from intermediates.
+    for intermediate in intermediates:
+        u_intermediate = unvariableify(intermediate)
+        if not isinstance(u_intermediate, NumberProxy) or not u_intermediate.is_static_constrained():
+            continue
+
+        # DFS traversal along producers, starting from seed `intermediate`
+        front = [intermediate]
+        while len(front) != 0:
+            v = front.pop()
+            if v in static_np_set:
+                continue
+            static_np_set.add(v)
+
+            uv = unvariableify(v)
+            if v in inputs:
+                ctx.add_constraint((clang.check_number_type_and_value, uv, uv.value))
+            else:
+                producer = producers[uv]
+                for inp in producer.flat_proxy_args:
+                    if not isinstance(inp, NumberProxy):
+                        continue
+                    front.append(variableify(inp))
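The backward walk above can be pictured with a toy producer map standing in for real bound symbols; this self-contained sketch (all names illustrative) mirrors the DFS:

    # toy producer map: each value maps to the values it was computed from
    producers = {"c": ["a", "b"], "d": ["c"]}
    inputs = {"a", "b"}

    def propagate(seed: str) -> set[str]:
        """Depth-first walk from `seed` back to the inputs it depends on."""
        seen, front, pinned = set(), [seed], set()
        while front:
            v = front.pop()
            if v in seen:
                continue
            seen.add(v)
            if v in inputs:
                pinned.add(v)  # would become a check_number_type_and_value constraint in the prologue
            else:
                front.extend(producers[v])
        return pinned

    assert propagate("d") == {"a", "b"}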


def get_computation_inputs_and_intermediates(computation_trace):
inputs_list = []
inputs_set = set()
@@ -1254,7 +1309,7 @@ def get_parameter_or_buffer_or_submodule_name_and_root(provenance):
return typ, name, mprovenance


-def unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs, *, has_epilogue: bool):
+def unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs):
already_unpacked: dict[int, Proxy] = {}
orig_modules: dict[int, Proxy] = {}

@@ -1502,10 +1557,7 @@ def from_provenance(provenance, *, new_output=False):
else:
raise NotImplementedError(f"cache info of type {type(v).__name__}")

-if has_epilogue:
-    prims.python_return((pro_to_comp, pro_to_epi))
-else:
-    prims.python_return(pro_to_comp)
+prims.python_return((pro_to_comp, pro_to_epi))

return pro_to_comp, pro_to_epi

@@ -1660,6 +1712,9 @@ def thunder_general_jit(
comp_to_epi = []
pro_to_epi = []

+# propagate statically constrained intermediates back to inputs
+propagate_constraints(ctx, pro_to_comp, computation_intermediates, computation_trace)

for i in epilogue_inputs:
if i in computation_intermediates:
comp_to_epi.append(i)
@@ -1674,9 +1729,7 @@
assert last.sym.id == prims.PrimIDs.RETURN
prims.python_return(comp_to_epi_proxies)

-pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(
-    ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None
-)
+pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs)

proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))
2 changes: 1 addition & 1 deletion thunder/core/langctxs.py
@@ -72,7 +72,7 @@ def resolve_method(id: Any, *args, **kwargs) -> None | Callable:
# ctx.get_method throws an AttributeError when the context does not have the requested attribute, except
# for the prims language context, which always throws a ValueError
method: Callable = ctx.get_method(id, *args, **kwargs)
-except (AttributeError, ValueError) as e:
+except (AttributeError, ValueError):
return None
return method

25 changes: 25 additions & 0 deletions thunder/core/module.py
@@ -1,6 +1,7 @@
from contextlib import contextmanager
import itertools
from typing import Any
+import collections

import torch as pytorch

@@ -97,6 +98,30 @@ def named_buffers(self, prefix="", recurse=True, remove_duplicate=True):
remove_duplicate=remove_duplicate,
)

+def load_original_state_dict(self, state_dict):
+    # this loads the state dict incrementally so as not to exhaust memory
+    module_names = {n for n, _ in self.named_modules()}
+    sd_per_module = collections.defaultdict(dict)
+    for k, v in state_dict.items():
+        prefix, sep, _ = k.rpartition(".")
+        # not great, but this should not happen too often or recurse too deeply
+        while prefix not in module_names:
+            prefix, sep, _ = prefix.rpartition(".")
+        sd_per_module[prefix][k[len(prefix) + len(sep) :]] = v
+
+    for submodule_name, sd_part in sd_per_module.items():
+        prefix = submodule_name + ("." if submodule_name else "")
+        for transform in self._lc_early_transforms:
+            sd_part = transform.transform_state_dict_for_submodule(self, submodule_name, sd_part)
+        for k, v in sd_part.items():
+            full_k = prefix + k
+            if full_k in self._overrides_parameters:
+                self._overrides_parameters[full_k] = v
+            elif full_k in self._overrides_buffers:
+                self._overrides_buffers[full_k] = v
+            else:
+                raise NotImplementedError(f"don't know how to handle {full_k}")

@contextmanager
def no_sync(self):
r"""Context manager to disable gradient synchronization in data parallel mode.
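A hedged usage sketch of the new method (the model class and checkpoint path are placeholders): because the state dict is regrouped per submodule, an early transform that shards or quantizes weights only ever sees one submodule's tensors at a time instead of the whole checkpoint.

    import torch
    import thunder

    model = MyModel()                          # placeholder pytorch.nn.Module
    tm = thunder.jit(model)                    # jit-ing a module returns a ThunderModule
    state_dict = torch.load("checkpoint.pt")   # original, untransformed weights
    tm.load_original_state_dict(state_dict)    # transformed per submodule, then applied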