Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into …te-intermediate-sharding
kshitij12345 committed Jul 4, 2024
2 parents a2353ec + 40da5bd commit b062aff
Showing 16 changed files with 254 additions and 110 deletions.
1 change: 1 addition & 0 deletions .azure/gpu-tests.yml
@@ -79,6 +79,7 @@ jobs:
 
   - bash: |
       set -ex
+      export CUDA_LAUNCH_BLOCKING=1
       coverage run --source thunder -m \
         pytest thunder/tests/ \
           -m "not standalone" \
9 changes: 1 addition & 8 deletions thunder/__init__.py
@@ -391,6 +391,7 @@ def get_computation_and_inputs(*args, **kwargs):
                 autocast_cpu_dtype=str(autocast_cpu_dtype),
             )
             autocast_thunder_dtype = autocast_cpu_dtype if pytorch.is_autocast_cpu_enabled() else autocast_gpu_dtype
+            cache_info.update(autocast_thunder_dtype=str(autocast_thunder_dtype))
 
             cache_info["is_autocast_enabled"] = is_autocast_enabled
 
@@ -571,14 +572,6 @@ def get_computation_and_inputs(*args, **kwargs):
         computation_trc = dce(computation_trc)
         computation_traces.append(computation_trc)
 
-        if is_autocast_enabled:
-            from thunder.core.transforms import autocast
-
-            computation_trc = trace(compile_data=cd)(
-                autocast(computation_trc.python_callable(), dtype=autocast_thunder_dtype), *inps
-            )
-            computation_traces.append(computation_trc)
-
         backward_trc = None
         if not cd.disable_torch_autograd_support:
             tensor_cls = (pytorch.Tensor, TensorProxy)
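
For context, autocast_thunder_dtype mirrors PyTorch's autocast state; the removed block used to re-trace the computation under an autocast transform, and this commit instead records the dtype in cache_info. A standalone sketch of the dtype selection (illustrative only, not thunder's code):

    import torch

    def current_autocast_dtype() -> torch.dtype:
        # Prefer the CPU autocast dtype when CPU autocast is active,
        # otherwise fall back to the GPU autocast dtype.
        if torch.is_autocast_cpu_enabled():
            return torch.get_autocast_cpu_dtype()
        return torch.get_autocast_gpu_dtype()

    with torch.autocast("cpu", dtype=torch.bfloat16):
        print(current_autocast_dtype())  # torch.bfloat16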
69 changes: 40 additions & 29 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -363,30 +363,37 @@ def pad_collate(batch):
         return train_dataloader
 
     def calculate_model_flops(self):
-        meta = torch.device("meta")
         device = self.device
-        self.device = meta
-
-        # calculate flops on a meta-device model because we only care about the shapes and
-        # because the flops calculator installs hooks on the model
-        meta_model = self.init_model()
-
-        x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta)
-        model_fwd = lambda: meta_model(x)
-        model_loss = lambda y: torch.nn.functional.cross_entropy(
-            y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1
-        )
-        self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss)
-
-        self.device = device
+        try:
+            meta = torch.device("meta")
+            self.device = meta
+
+            # calculate flops on a meta-device model because we only care about the shapes and
+            # because the flops calculator installs hooks on the model
+            meta_model = self.init_model()
+
+            x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta)
+            model_fwd = lambda: meta_model(x)
+            model_loss = lambda y: torch.nn.functional.cross_entropy(
+                y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1
+            )
+            self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss)
+        finally:
+            self.device = device
 
     def train(self):
         t0 = None
         if global_rank in [0, None]:
-            # Calculate the model FLOPs
-            self.calculate_model_flops()
-            # Setup throughput Collection
-            self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size)
+            try:
+                # Calculate the model FLOPs
+                self.calculate_model_flops()
+                # Setup throughput Collection
+                self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size)
+            except:
+                self.throughput = None
+                print(
+                    f"Model Flops/Throughput calculation failed for model {self.model_name}. Skipping throughput metric collection."
+                )
 
         if self.skip_data_sync:
             data_sync_ctx = self.model.no_sync
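
The try/finally rewrite above guarantees that self.device is restored even if meta-device initialization or the FLOPs measurement raises. A minimal sketch of the save/restore pattern in isolation (hypothetical object, not the benchmark class):

    class Holder:
        device = "cuda"

    h = Holder()
    saved = h.device
    try:
        h.device = "meta"
        # ... work that may raise ...
    finally:
        # Runs on success and on exception alike.
        h.device = saved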
@@ -442,22 +449,26 @@ def train(self):
                         f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}"
                     )
                 if i >= self.warmup_iter:
-                    self.throughput.update(
-                        time=(t1 - t0),
-                        flops=self.perf_metrics["model_flops"],
-                        batches=i,
-                        samples=(i * self.micro_batch_size * self.gradient_accumulation_steps),
-                        lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size),
-                    )
+                    if self.throughput:
+                        self.throughput.update(
+                            time=(t1 - t0),
+                            flops=self.perf_metrics["model_flops"],
+                            batches=i,
+                            samples=(i * self.micro_batch_size * self.gradient_accumulation_steps),
+                            lengths=(
+                                i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size
+                            ),
+                        )
 
         if global_rank in [0, None]:
             # print(f"Total time: {(t1 - t0):.2f}s")
             self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter)
 
     def add_perf_metrics(self):
-        metrics = self.throughput.compute()
-        self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"])
-        self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"])
+        if self.throughput:
+            metrics = self.throughput.compute()
+            self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"])
+            self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"])
         self.perf_metrics["memory_used_GB"] = torch.cuda.max_memory_allocated() / 1e9
 
     def add_model_info_to_metrics(self):
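
Side note on the meta-device trick in calculate_model_flops: a model initialized on torch.device("meta") allocates no parameter storage, so shape-based FLOPs accounting stays cheap even for large configurations. A standalone illustration (hypothetical layer, not the benchmark's model):

    import torch

    with torch.device("meta"):
        # No parameter memory is allocated on the meta device;
        # only shapes and dtypes are tracked.
        layer = torch.nn.Linear(4096, 4096)

    x = torch.empty(8, 4096, device="meta")
    y = layer(x)    # shape propagation only, no real compute
    print(y.shape)  # torch.Size([8, 4096])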
5 changes: 4 additions & 1 deletion thunder/core/codeutils.py
@@ -227,7 +227,10 @@ def prettyprint(
         unflattened_str = unflattened_str.replace(f"'{_quote_marker}", "")
         return unflattened_str
     if isinstance(x, dtypes.dtype):
-        return m(f"dtypes.{str(x)}")
+        # str(x) -> thunder.dtypes.foo
+        # For consistency with previous repr,
+        # remove `thunder.` from the representation.
+        return m(f"{str(x).replace('thunder.', '')}")
     if isinstance(x, devices.Device):
         return m(f'devices.Device("{x.device_str()}")')
     if type(x) is type:
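
A quick illustration of the string handling in the new branch (the input value is assumed from the comment above):

    # str(x) on a thunder dtype is said to produce "thunder.dtypes.float32";
    # stripping the "thunder." prefix restores the previous repr.
    s = "thunder.dtypes.float32"
    print(s.replace("thunder.", ""))  # dtypes.float32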
37 changes: 18 additions & 19 deletions thunder/core/interpreter.py
@@ -16,7 +16,7 @@
 import traceback
 import weakref
 import torch
-from typing import Any, Literal, NamedTuple, TypedDict
+from typing import Any, Literal, TypedDict
 from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping, Sequence, Set, Sized
 import collections
 import operator
@@ -29,7 +29,6 @@
     CodeType,
     CoroutineType,
     FrameType,
-    GetSetDescriptorType,
     FunctionType,
     MethodType,
     MethodDescriptorType,
@@ -180,7 +179,7 @@ def track_items(self):
         self.has_item_tracking = True
 
     def register_proxy(self, proxy):
-        # note: the proxy is responsible for capturiing all the existing attributes/values
+        # note: the proxy is responsible for capturing all the existing attributes/values
         assert (
             self.original_value is self.nothing
         ), "cannot proxy multiple times, please file an issue to discuss your use-case"
@@ -272,7 +271,7 @@ def wrap(value: Any, /, *, provenance: ProvenanceRecord) -> WrappedValue:
     if cached is not None:
         potential_wrap = cached[0]()
         if potential_wrap is not None:
-            # Note: we want to cache mutuable objects to not run into trouble
+            # Note: we want to cache mutable objects to not run into trouble
             # with multiple accesses to the same.
             # As the cache only holds a weakref to the WrappedValue instance
             # one must persist WrappedValues once things are starting to be modified.
@@ -368,7 +367,7 @@ def populate_item_wrappers(l):
             l.key_wrappers[k] = wk  # or have those from an iteration of the input?
         return
 
-    raise NotImplementedError(f"populate item wrapppers for {type(l.value)}")
+    raise NotImplementedError(f"populate item wrappers for {type(l.value)}")
 
 
 #
@@ -495,7 +494,7 @@ class ReturnLogItem(TypedDict):
 # To the interpreter User Exceptions are part of its normal operations. So we usually
 # don't use Python's exception mechanism to handle these. Instead the
 # interpreter mimics CPython's implementation of exception handling by
-# defining analoguous structures (curexec and exception_stack) set by do_raise and returns
+# defining analogous structures (curexec and exception_stack) set by do_raise and returns
 # INTERPRETER_SIGNALS.EXCEPTION_RAISED when running things that raised exceptions.
 # Here in particular:
 # - Handlers are "inside the interpreter",
@@ -524,7 +523,7 @@ def __init__(self, record_history: bool, debug_log: None | StringIO):
         self.exception_stack = [None]
         # ts.exc_info is the top of the stack (self.exception_stack[-1]).
         # Note that most of the time, we are changing the self.exception_stack[-1] instead of popping/pushing exceptions.
-        # `exception_stack[-1] = ...` is the equivalent of assiging ts.exc_info->exc_type/value/traceback.
+        # `exception_stack[-1] = ...` is the equivalent of assigning ts.exc_info->exc_type/value/traceback.
         # ts.exc_state is exc_info (the bottom-most element of the stack(?))
         # ts.curexc (in Python 3.10 as _type / _value / _traceback) is the exception currently being raised
@@ -613,7 +612,7 @@ def get_current_user_source_location(self) -> tuple[str, Positions]:
                 return frame.code.co_filename, frame.positions
         return None, None
 
-    # TODO Instead of appending to both the log and and interpreted_instructions we could
+    # TODO Instead of appending to both the log and interpreted_instructions we could
     #   consider just appending to the log and then filtering to only instructions when
     #   interpreted_instructions is accessed
     def record_interpreted_instruction(self, inst: dis.Instruction, /) -> InterpreterRuntimeCtx:
@@ -1297,7 +1296,7 @@ def wrapping_wrapper(*args, **kwargs):
     return wrapping_wrapper
 
 
-# Calling a function as an opaque function makes the interpeter not trace into it
+# Calling a function as an opaque function makes the interpreter not trace into it
 @interpreter_needs_wrap
 def call_opaque(fn, /, *args, **kwargs):
     return fn(*args, **kwargs)
@@ -2003,15 +2002,15 @@ def __next__(self):
         return res
 
 
-# wrapper-handling lookasides for sequences and mutuable sequences.
+# wrapper-handling lookasides for sequences and mutable sequences.
 # note:
 # - these are only for use when wrapping is enabled
 # - the methods (or the corresponding functions) will be registered
 #   as lookasides for tuple and list. This means that they will be
 #   called with wrapped values also for self and self.value will point
 #   to the actual object...
 #
-# TODO: maybe make these generic for sequences / mutuable sequence
+# TODO: maybe make these generic for sequences / mutable sequence
 # https://docs.python.org/3/library/stdtypes.html#common-sequence-operations
 class SequenceWrapperMethods(WrappedValue):
     # NOTE! This is not actually a WrappedValue. However,
@@ -2700,8 +2699,8 @@ def create_namedtuple(typename: str, field_names: str, **kwargs):
 }
 
 
-# While mutuable sequences (lists) are created empty in __new__ and populated in __init__,
-# immutuable sequences (tuples) are created with contents in __new__ and __init__ is a nop
+# While mutable sequences (lists) are created empty in __new__ and populated in __init__,
+# immutable sequences (tuples) are created with contents in __new__ and __init__ is a nop
 # (object.__init__, actually).
 def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /):
     new_tuple_type = cls.value
@@ -3950,7 +3949,7 @@ def _end_async_for_handler_3_10(
 
             assert len(stack) >= try_block.level + 3
             del stack[try_block.level + 3 :]
-            exc_type = frame.interpreter_stack.pop()  # we ignore that and asume == type(exc_value)
+            exc_type = frame.interpreter_stack.pop()  # we ignore that and assume == type(exc_value)
             exc_value = frame.interpreter_stack.pop()
             exc_traceback = frame.interpreter_stack.pop()
             if exc_value != None:
@@ -4205,11 +4204,11 @@ def _import_name_handler(
     fromlist = stack.pop()
     level = stack.pop()
 
-    # relative imports rely on the the current module's name (from the frame stac?)
+    # relative imports rely on the current module's name (from the frame stac?)
     # but that isn't available if we use impl, so we resolve it here.
     if level > 0:  # relative import
         # cannot do this in impl easily, but error handling?
-        # TODO: model this more after resove_name in CPython's Python/import.c
+        # TODO: model this more after resolve_name in CPython's Python/import.c
         def get_current_name(globals):
             package = globals.get("__package__")
             if package is None:
@@ -5218,7 +5217,7 @@ def impl(tos):
 def _push_exc_info_handler(inst: dis.Instruction, /, stack: InterpreterStack, exception_stack: list, **kwargs) -> None:
     assert exception_stack
     top = stack.pop()
-    # CPython reads exc_info->exc_type/value/traceback, see RuntimeCtx inititalization of exception_stack for more info
+    # CPython reads exc_info->exc_type/value/traceback, see RuntimeCtx initialization of exception_stack for more info
     stack.append(exception_stack[-1])
     stack.append(top)
     assert isinstance(top, BaseException)
@@ -5869,7 +5868,7 @@ def _yield_value_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kw
 # (generator, async generator, coroutine), we define generic equivalents here
 # that take the interpreter frame and execute until the next yield point.
 # The way these functions work in Python is that objects are created and
-# retruned either on invocation (Python <=3.10) or by the RETURN_GENERATOR
+# returned either on invocation (Python <=3.10) or by the RETURN_GENERATOR
 # opcode (Python >= 3.11).
 def make_generator(
     frame: InterpreterFrame,
@@ -6482,7 +6481,7 @@ def _run_frame(
         # NormalizeException ?
 
        # CPython sets exc_info->exc_type/value/traceback here
-        # see RuntimeCtx inititalization of exception_stack for more info
+        # see RuntimeCtx initialization of exception_stack for more info
         runtimectx.exception_stack[-1] = current_exception
         with frame.interpreter_stack.set_cur_instruction(PseudoInst.EXCEPTION_HANDLER):
             frame.interpreter_stack.append(
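
For context on the RETURN_GENERATOR comment fixed above: on Python >= 3.11 that opcode appears in a generator's bytecode, while on <= 3.10 the generator object is built when the function is called. This is observable directly:

    import dis

    def gen():
        yield 1

    # On Python >= 3.11 the disassembly begins with RETURN_GENERATOR;
    # on Python <= 3.10 that opcode does not exist.
    dis.dis(gen)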
14 changes: 14 additions & 0 deletions thunder/core/jit_ext.py
@@ -891,6 +891,20 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs):
     )
 
 
+@general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
+def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
+
+    custom_autograd_function_cls = unwrap(obj)
+    custom_forward = custom_autograd_function_cls.forward
+    args_, kwargs_ = tree_map(unwrap, (args, kwargs))
+    ctx = torch.autograd.function.FunctionCtx()
+
+    pr = ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[wrap_const(custom_forward).provenance])
+    wrapped_ctx = wrap(ctx, provenance=pr)
+    args_, kwargs_ = tree_map(lambda a: wrap(a, provenance=pr), (args_, kwargs_))
+    return _interpret_call(custom_forward, wrapped_ctx, *args_, **kwargs_)
+
+
 # Adds proxy methods
 # NOTE These methods map to themselves, which prevents the interpreter from looking into them
 # This is OK because these methods are written in a tracing-safe manner, and trying to
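
For context, the new lookaside intercepts calls to torch.autograd.Function.apply and interprets the class's forward with a fresh FunctionCtx. A minimal example of the kind of user code it targets (illustrative; backward handling is outside the scope of this hunk):

    import torch

    class MyMul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # ctx is a FunctionCtx, like the one constructed in the lookaside
            ctx.save_for_backward(x)
            return x * 2.0

        @staticmethod
        def backward(ctx, grad):
            # d(2x)/dx = 2
            return grad * 2.0

    y = MyMul.apply(torch.ones(3, requires_grad=True))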
2 changes: 1 addition & 1 deletion thunder/core/proxies.py
@@ -1302,7 +1302,7 @@ def __repr__(self):
         return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape})>'
 
     def type_string(self):
-        return f"{self.device} {self.dtype.shortname()}{list(self.shape)}"
+        return f"{self.device.device_str()} {self.dtype.shortname()}{list(self.shape)}"
 
     # NOTE __getattr__ is overridden to support language-specific methods
     def __getattr__(self, attr: str, /):
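
The type_string form is what thunder traces print for tensor proxies; a sketch of the intended output with assumed values (not thunder's code):

    # Hypothetical values mirroring the f-string above:
    device_str = "cuda:0"   # what device_str() is expected to return
    shortname = "f32"       # dtype short name
    shape = [2, 3]
    print(f"{device_str} {shortname}{shape}")  # cuda:0 f32[2, 3]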