Commit

Merge branch 'main' into check_inplace_leafs

beverlylytle committed Nov 19, 2024
2 parents b56dd80 + 60f3ee1, commit d79173c
Showing 29 changed files with 627 additions and 156 deletions.
16 changes: 10 additions & 6 deletions .github/CODEOWNERS
@@ -2,11 +2,15 @@

# These owners will be the default owners for everything in the repo. Unless a later match takes precedence,
# @global-owner1, @global-owner2, and @global-owner3 will be requested for review when someone opens a pull request.
* @mruberry @lantiga @t-vi @carmocca

# Thank you, our previous code owners for their service:
# @carmocca

* @mruberry @lantiga @t-vi

# CI/CD and configs
/.azure/ @borda @lantiga @t-vi @carmocca
/.github/ @borda @lantiga @t-vi @carmocca
/dockers/ @borda @lantiga @t-vi @carmocca
Makefile @borda @lantiga @t-vi @carmocca
*.yml @borda @lantiga @t-vi @carmocca
/.azure/ @borda @lantiga @t-vi
/.github/ @borda @lantiga @t-vi
/dockers/ @borda @lantiga @t-vi
Makefile @borda @lantiga @t-vi
*.yml @borda @lantiga @t-vi
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -45,7 +45,7 @@ repos:
exclude: "examples"

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.18
rev: 0.7.19
hooks:
- id: mdformat
additional_dependencies:
@@ -71,7 +71,7 @@ repos:
# args: ["--fix"]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
rev: v4.0.0-alpha.8
hooks:
- id: prettier
# https://prettier.io/docs/en/options.html#print-width
2 changes: 1 addition & 1 deletion README.md
@@ -38,7 +38,7 @@ Thunder aims to be usable, understandable, and extensible.

 

> \[!Note\]
> [!Note]
> Lightning Thunder is in alpha. Feel free to get involved, but expect a few bumps along the way.
 
1 change: 1 addition & 0 deletions docs/source/reference/thunder.rst
@@ -21,6 +21,7 @@ Querying information on compiled functions and modules
.. autosummary::
:toctree: generated/

DebugOptions
compile_data
compile_stats
last_traces
2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -18,7 +18,7 @@ pandas # thunder/benchmarks/test_benchmark_litgpt.py
xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py
jsonargparse # thunder/benchmarks/benchmark_litgpt.py
bitsandbytes==0.42.0 # fixed version!
transformers==4.43.3 # for test_networks.py
transformers==4.46.2 # for test_networks.py

# Installs JAX on Linux and MacOS
jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation
18 changes: 12 additions & 6 deletions thunder/__init__.py
@@ -16,6 +16,7 @@
from thunder.core.options import (
CACHE_OPTIONS,
SHARP_EDGES_OPTIONS,
DebugOptions,
)
from thunder.core.trace import (
TraceResults,
@@ -124,6 +125,7 @@
"nvfuser_executor",
"pytorch_executor",
# debugging functions
"DebugOptions",
"set_execution_callback_file",
"jit",
"resolve_executors",
@@ -275,7 +277,6 @@ def compile(fn: Callable, recipe: Recipe | None):


# This function will replace compile() (below) before RC1
# TODO RC1 Consider adding a debug_log parameter to control debug printing
# TODO RC1 Consider renaming compile_options to additional_compile_options
def jit(
fn: Callable,
@@ -287,7 +288,7 @@ def jit(
cache: None | CACHE_OPTIONS | str = None,
disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1
transforms: list[Transform] | None = None,
record_history: bool = False,
debug_options: DebugOptions | None = None,
**compile_options, # TODO RC1 Make this explicit -- dict of options
) -> Callable:
"""Just-in-time compile a callable (function or model).
@@ -313,7 +314,9 @@
- ``"constant values"`` - require Tensors to be of the same shape, device, dtype etc., and integers and strings to match exactly,
- ``"same input"`` - don't check, but just assume that a cached function works if it exists.
transforms: List of transforms to be applied. It should be an instance :class:`thunder.core.transforms.Transform`. Default: ``None``
transforms: optional list of transforms to be applied. It should be a list of instances of :class:`thunder.core.transforms.Transform`. Default: ``None``
debug_options: optional :class:`thunder.DebugOptions` instance. See the doc string of :class:`DebugOptions` for supported debug options. Default: ``None``
"""

if "executors_list" in compile_options:
@@ -345,8 +348,6 @@
# TODO: sharp edge if lookasides are shadowed?
executor_lookasides.update(ex._lookasides)

assert type(record_history) is bool

# TODO RC1 Refine the compile data option to remove unused options
# TODO: refine options
cd = CompileData(
@@ -361,6 +362,7 @@
disable_preprocessing=True,
compile_options=compile_options,
executor_lookasides=executor_lookasides,
debug_options=debug_options,
)
cs = CompileStats()

@@ -442,6 +444,11 @@ def get_computation_and_inputs(*args, **kwargs):
# which seems to break the consistency of cache_info, leading to a failure in cache_info check.
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

# Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform
# to treat certain Symbols as constant.
cache_info["is_grad_enabled"] = pytorch.is_grad_enabled()
cd.is_grad_enabled = pytorch.is_grad_enabled()

# TODO RC1 Add module and function checks to prologue (make it a compile option)

# Checks cache
@@ -524,7 +531,6 @@ def get_computation_and_inputs(*args, **kwargs):
args,
kwargs,
ad_hoc_executor=ad_hoc_executor,
record_history=record_history,
sharp_edges=cd.sharp_edges,
)
prologue_trc = jit_results.prologue_trace
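
Note on the thunder/__init__.py change above: the boolean `record_history` keyword of `thunder.jit` is replaced by a `debug_options` argument that takes a `thunder.DebugOptions` instance, and PyTorch's grad mode is now recorded in `cache_info` so the vjp transform can treat certain symbols as constant. A minimal sketch of the new entry point follows; the diff does not show which attributes `DebugOptions` defines, so the sketch only passes a default instance.

```python
import torch
import thunder
from thunder import DebugOptions  # re-exported via the new __all__ entry


def fn(x):
    return torch.sin(x) * 2.0


# Default instance; specific debug fields (not shown in this diff) would be set here.
jfn = thunder.jit(fn, debug_options=DebugOptions())
print(jfn(torch.randn(4)))
```

Because `is_grad_enabled` is now stored both in `cache_info` and on the compile data, calling the jitted function under `torch.no_grad()` and again with gradients enabled may resolve to different traces rather than reusing a single cached computation.
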
19 changes: 11 additions & 8 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -114,11 +114,13 @@ def _resursively_swap_linear_layers_for_te(module: torch.nn.Module) -> None:

if isinstance(m, torch.nn.Linear):
has_bias = m.bias is not None
new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=device)
# Pass device as str (as there is a bug in TransformerEngine's handling of torch.device)
new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=str(device))
setattr(module, n, new_linear)

if swap_layernorm and isinstance(m, torch.nn.LayerNorm):
new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=device)
# Pass device as str (as there is a bug in TransformerEngine's handling of torch.device)
new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=str(device))
setattr(module, n, new_layernorm)

initial_params_cnt = parameters_cnt(model)
@@ -366,11 +368,6 @@ def __init__(
self.model = self.init_model()
print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")

if self.use_te_fp8_autocast:
is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
swap_linear_layers_for_te(self.model, device, swap_layernorm=not is_wo_layernorm)
self.model.to(torch.bfloat16)

# Setup the distributed algorithm choices
if distributed_first := (self.compile in ("eager", "inductor") or "dynamo" in self.compile):
self.model = self.setup_distributed(self.model)
@@ -407,8 +404,14 @@ def init_model(self):
init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device
with init_device:
model = GPT(self.config)
model.to(dtype=torch.bfloat16)

# Handle fp8 related Linear layer swapping (for torchao or TransformerEngine)
model = self._torchao_fp8_handler.convert_model_to_fp8(model)
if self.use_te_fp8_autocast:
is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
swap_linear_layers_for_te(model, init_device, swap_layernorm=not is_wo_layernorm)

model.to(dtype=torch.bfloat16)
return model

def setup_distributed(self, model):
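
The benchmark change above moves the TransformerEngine layer swap from the constructor into `init_model`, so the swap now runs on the freshly built model before the single `model.to(dtype=torch.bfloat16)` call, and it passes the device as a string to work around the `torch.device` handling bug noted in the comments. A minimal sketch of the recursive swap pattern, with a plain `nn.Linear` standing in for `te.Linear` so it runs without TransformerEngine installed:

```python
import torch
from torch import nn


def swap_linears(module: nn.Module, device: torch.device) -> None:
    # Walk the immediate children, replace nn.Linear instances in place via
    # setattr, and recurse into everything else. The device is passed as a
    # string, mirroring the workaround used in the benchmark.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_linear = nn.Linear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=str(device),
            )
            setattr(module, name, new_linear)
        else:
            swap_linears(child, device)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
swap_linears(model, torch.device("cpu"))
print(model)
```
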
7 changes: 7 additions & 0 deletions thunder/common.py
@@ -12,6 +12,7 @@
resolve_cache_option,
SHARP_EDGES_OPTIONS,
resolve_sharp_edges_option,
DebugOptions,
)
from thunder.core.utils import check, is_collection, AutocastStack
from thunder.core.pytree import tree_flatten, tree_map
@@ -202,6 +203,7 @@ def __init__(
compile_options: dict[str, Any] = {},
get_computation_and_inputs: Callable | None = None,
executor_lookasides: dict[Callable, Callable] | None = None,
debug_options: DebugOptions | None = None,
):
# Records whether we're using the thunder.jit() entrypoint or not
# The thunder.jit() entrypoint introduces important architectural updates,
@@ -221,6 +223,10 @@
# State for pytorch autocast context managers.
self.autocast_stack: AutocastStack = AutocastStack()

# State to query whether grad is enabled or disabled using
# torch.no_grad/torch.enable_grad/torch._C._set_grad_enabled
self.is_grad_enabled: bool = True

#
# Gathers additional metadata
#
@@ -257,6 +263,7 @@ def __init__(
self.disable_preprocessing = disable_preprocessing
self.disable_torch_autograd_support = disable_torch_autograd_support
self.debug_log = debug_log
self.debug_options = DebugOptions() if debug_options is None else debug_options

# TODO Consider validating that this dict has exclusively string keys
self.compile_options = compile_options
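
The new `CompileData.is_grad_enabled` flag defaults to `True` and mirrors PyTorch's global grad mode, which the context managers named in the comment above toggle. A quick illustration of the state being tracked:

```python
import torch

print(torch.is_grad_enabled())      # True by default
with torch.no_grad():
    print(torch.is_grad_enabled())  # False inside the context
with torch.enable_grad():
    print(torch.is_grad_enabled())  # True again

torch._C._set_grad_enabled(False)   # the low-level switch also listed in the comment
print(torch.is_grad_enabled())      # False until re-enabled
torch._C._set_grad_enabled(True)
```
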
70 changes: 60 additions & 10 deletions thunder/core/baseutils.py
@@ -2,24 +2,63 @@
# This feature is available in Python 3.7 and later.
# This import (like all __future__ imports) must be at the beginning of the file.
from __future__ import annotations
from collections.abc import Sequence
from enum import Enum
from types import MappingProxyType, ModuleType, CodeType, EllipsisType, FunctionType, MethodType
from typing import TYPE_CHECKING
import collections.abc
import dis
import functools
import inspect
import os
import dis

import sys
import collections.abc
from numbers import Number
from typing import Any, Type, Union, Optional, Tuple, List
from collections.abc import Callable
from collections.abc import Sequence
from types import MappingProxyType, ModuleType, CodeType, EllipsisType, FunctionType, MethodType
import re
import inspect
import sys

import torch
import numpy as np

if TYPE_CHECKING:
from collections.abc import Callable
from numbers import Number
from typing import Any


__all__ = [
"BoundSymbolInterface",
"NumberProxyInterface",
"ProxyInterface",
"SymbolInterface",
"TagBase",
"TensorProxyInterface",
"TermColors",
"TorchAutogradFunctionCtxProxyInterface",
"build_callable",
"check",
"check_type",
"check_types",
"check_valid_length",
"check_valid_shape",
"default_dataclass_params",
"extract_callable_name",
"fnprint",
"get_module",
"indent",
"init_colors",
"init_windows_terminal",
"is_base_printable",
"is_base_printable_literal",
"is_base_printable_type",
"is_base_printable_value",
"is_collection",
"print_base_printable",
"print_base_type",
"print_number",
"print_type",
"run_once",
"sequencify",
"warn_term_variable_once",
]


#
# Common utilities importable by any other file
@@ -171,6 +210,17 @@ def get_module(name: str) -> Any:
return sys.modules[name]


def is_likely_from_collections_namedtuple(tuple_type):
from collections import namedtuple

# Check if tuple_type code object is coming from namedtuple
return (
hasattr(tuple_type, "__repr__")
and hasattr(tuple_type.__repr__, "__code__")
and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts
)


#
# Functions related to printing and debugging
#
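
The new `is_likely_from_collections_namedtuple` helper leans on a CPython detail: `collections.namedtuple` defines the generated `__repr__` inside its own function body, so that method's code object appears among the constants in `namedtuple.__code__.co_consts`, while C-level methods such as `tuple.__repr__` have no `__code__` at all. A small self-contained check of the heuristic, with the helper body copied from the diff:

```python
from collections import namedtuple


def is_likely_from_collections_namedtuple(tuple_type):
    # True when tuple_type's __repr__ was generated by collections.namedtuple.
    return (
        hasattr(tuple_type, "__repr__")
        and hasattr(tuple_type.__repr__, "__code__")
        and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts
    )


Point = namedtuple("Point", ["x", "y"])


class PlainTuple(tuple):
    pass


print(is_likely_from_collections_namedtuple(Point))       # True
print(is_likely_from_collections_namedtuple(PlainTuple))  # False
```
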
46 changes: 31 additions & 15 deletions thunder/core/codeutils.py
@@ -1,26 +1,40 @@
from types import CodeType, FunctionType, MethodType, EllipsisType
from typing import List, Dict, Tuple, Set, Deque, Any, NamedTuple, Optional
from numbers import Number
from collections import deque
from collections.abc import Mapping, Sequence, Iterable, Callable
import inspect
from inspect import Parameter
import string
import functools
from __future__ import annotations
from functools import partial
from inspect import Parameter
from typing import TYPE_CHECKING, NamedTuple
import dataclasses
import dis
import functools
import inspect
import linecache
import dataclasses
import sys

import torch

import thunder.core.baseutils as baseutils
from thunder.core.baseutils import ProxyInterface, check
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.pytree import tree_flatten, tree_unflatten

if TYPE_CHECKING:
from typing import Any
from collections.abc import Callable, Sequence
from thunder.core.trace import TraceCtx


__all__ = [
"ContextObject",
"SigInfo",
"get_siginfo",
"get_source_line",
"indent_string",
"is_literal",
"is_printable",
"is_simple_printable_collection",
"module_shortname",
"prettyprint",
"to_printable",
]

#
# Functions related to analyzing and printing functions and arguments
#
@@ -106,7 +120,7 @@ def is_literal(x: Any) -> bool:
return True


def _to_printable(tracectx: Optional, x: Any) -> tuple[Any, tuple[str, Any] | None]:
def _to_printable(tracectx: TraceCtx | None, x: Any) -> tuple[Any, tuple[str, Any] | None]:
can_print, module_info = is_printable(x)
if can_print:
return x, module_info
@@ -123,7 +137,7 @@ def _to_printable(tracectx: Optional, x: Any) -> tuple[Any, tuple[str, Any] | No

# TODO Improve type annotations
def to_printable(
trace: Optional,
trace: TraceCtx | None,
x: Any,
*,
import_ctx: dict | None = None,
@@ -302,7 +316,9 @@ def __repr__(self):
# TODO Print the original signature's type annotations
# TODO Maybe be clear about what inputs are const and what aren't?
# TODO Improve this signature's type annotations
def prettyprint(self, *, trace: Optional = None, import_ctx: Optional = None, object_ctx=None) -> str:
def prettyprint(
self, *, trace: TraceCtx | None = None, import_ctx: Any | None = None, object_ctx: Any | None = None
) -> str:
def _arg_printer(name: str, has_default: bool, default: Any = None) -> str:
# NOTE In this case the argument has a default value, like 'a' in foo(a=5)
if has_default:
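
Both refactored modules above (thunder/core/baseutils.py and thunder/core/codeutils.py) now use the same typing pattern: `from __future__ import annotations` plus a `TYPE_CHECKING` block, so names used only in annotations add no runtime import cost and no import cycles. A minimal sketch of the pattern outside Thunder:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static analysis; with postponed evaluation of
    # annotations, the name never needs to exist at runtime.
    from collections.abc import Sequence


def first_or_none(xs: Sequence[int]) -> int | None:
    return xs[0] if xs else None


print(first_or_none([3, 1, 2]))  # 3
print(first_or_none([]))         # None
```
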