diff --git a/docs/source/index.rst b/docs/source/index.rst
index c094a6dbd7..fec36ba4e3 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -102,6 +102,7 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik
    What's next
    FSDP Under the Hood Tutorial
    Benchmarking Thunder
+   Writing a Transform

 .. toctree::
    :maxdepth: 1
diff --git a/notebooks/writing_a_trace_transform_cpu_offloading.ipynb b/notebooks/writing_a_trace_transform_cpu_offloading.ipynb
new file mode 100644
index 0000000000..8ccd0424e1
--- /dev/null
+++ b/notebooks/writing_a_trace_transform_cpu_offloading.ipynb
@@ -0,0 +1,877 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Introduction\n",
+    "\n",
+    "In this tutorial, we will write a Trace transformation to perform CPU Offloading of intermediate tensors.\n",
+    "\n",
+    "CPU Offloading is a technique to decrease the peak memory usage during training. This can allow us to train a larger model which would otherwise not be possible. However, we have to trade off some performance (increased memory transfers) to achieve this."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gc\n",
+    "from typing import Callable, Mapping\n",
+    "\n",
+    "import torch\n",
+    "import torch.utils.benchmark\n",
+    "\n",
+    "import thunder\n",
+    "from thunder.core.trace import TraceCtx\n",
+    "from thunder.core.transform_common import Transform\n",
+    "from thunder.core.proxies import TensorProxy, variableify, Variable\n",
+    "from thunder.core.pytree import tree_map\n",
+    "from thunder.core.trace import tracectx, from_trace\n",
+    "from thunder.extend import OperatorExecutor\n",
+    "from thunder.core.symbol import BoundSymbol\n",
+    "from thunder.core import prims\n",
+    "from thunder.core.transforms import bsym_list_to_dag, Node, toposort_bsym_dag, TOPOSORT_ORDER\n",
+    "from thunder.core.vjp_utils import get_saved_for_backward_tensors\n",
+    "from thunder.core.module import ThunderModule"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Transforms\n",
+    "\n",
+    "To understand transforms, we first need to know what a `Trace` is. In `thunder`, a `Trace` is the representation of the jitted program in terms of thunder operations/symbols. Each operation in a `Trace` is represented by a `BoundSymbol`, i.e. a `Symbol` together with its inputs and outputs. We can also print the trace as a Python program for easier inspection; we will do this later in the notebook. To learn more about these concepts, you can read the helpful `zero_to_thunder.ipynb`.\n",
+    "\n",
+    "`thunder` allows us to write custom transforms that operate on traces. Such transforms can be used to replace pointwise operations with a fused implementation, to compute the gradient of a given computation, etc. Besides this, `thunder` lets us apply these transforms at different stages of compilation. In this tutorial, we will use the post-optimization stage, the point at which we already have the forward and the backward execution traces ready. To write our transform, we have to inherit from the `Transform` class. This class defines the interface that each transform should implement and, by default, provides no-op transformations. To use our transform, we pass an instance of our transform object to `thunder.jit` via the `transforms` argument. A minimal sketch of this interface is shown right after this cell.\n",
+    "\n",
+    "Before writing the offloading transform itself, we will create an `OperatorExecutor` with which we will register two operators/symbols: 1. to offload tensors to CPU, and 2. to load the offloaded tensors back onto the CUDA device. You can read more about adding custom operators in `adding_custom_operator.ipynb` and also in `zero_to_thunder.ipynb`."
+   ]
+  },
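+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch of the `Transform` interface described above, the cell below defines a transform that overrides `transform_trace_post_optimization` but simply returns the trace unchanged. (`IdentityTransform` is just an illustrative name; it is not used in the rest of this notebook.)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal, do-nothing transform illustrating the `Transform` interface.\n",
+    "class IdentityTransform(Transform):\n",
+    "    # This hook is invoked with the forward execution trace and, separately, with the backward trace.\n",
+    "    def transform_trace_post_optimization(self, computation_trace, **kwargs):\n",
+    "        # Returning the trace unchanged makes this transform a no-op.\n",
+    "        return computation_trace\n",
+    "\n",
+    "# It would be used as: thunder.jit(model, transforms=[IdentityTransform()])"
+   ]
+  },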
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create a new executor.\n",
+    "offload_ex = OperatorExecutor(\"offload_ex\")\n",
+    "\n",
+    "# NOTE: We create the offloaded CPU tensor in pinned memory and load the tensor back onto the GPU with `to(non_blocking=True)`.\n",
+    "# These allow for better memory transfer speeds.\n",
+    "# Read the following tutorial for a detailed explanation - https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html\n",
+    "\n",
+    "# Offload the GPU tensor to a pinned CPU tensor.\n",
+    "def offload_to_cpu_impl(t):\n",
+    "    # Due to https://github.com/Lightning-AI/lightning-thunder/issues/950\n",
+    "    # it may receive a tensor that is already on the CPU.\n",
+    "    if t.device == torch.device(\"cpu\"):\n",
+    "        return t\n",
+    "\n",
+    "    packed = torch.empty(\n",
+    "        t.size(),\n",
+    "        dtype=t.dtype,\n",
+    "        layout=t.layout,\n",
+    "        pin_memory=True,\n",
+    "    )\n",
+    "    packed.copy_(t)\n",
+    "    return packed\n",
+    "\n",
+    "offload_to_cpu = offload_ex.register_operator(\n",
+    "    \"offload_to_cpu\",\n",
+    "    meta=lambda t: TensorProxy(\"offloaded_\" + t.name, like=t, device=thunder.core.devices.Device(\"cpu\")),\n",
+    "    fn=offload_to_cpu_impl,\n",
+    ")\n",
+    "\n",
+    "\n",
+    "# Load the tensor back onto the given GPU device.\n",
+    "def load_to_gpu_impl(t, device):\n",
+    "    return t.to(device, non_blocking=True)\n",
+    "\n",
+    "\n",
+    "load_to_gpu = offload_ex.register_operator(\n",
+    "    \"load_to_gpu\",\n",
+    "    meta=lambda t, device: TensorProxy(like=t, device=thunder.core.devices.Device(device)),\n",
+    "    fn=load_to_gpu_impl,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First, we will define some helper functions used to implement our transformation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_symbols_to_first_or_last_used_variables(symbols, first_used=False):\n",
+    "    \"\"\"\n",
+    "    This function processes a sequence of symbols and determines which variables\n",
+    "    are first/last used by each symbol, based on the argument `first_used`.\n",
+    "    It returns a mapping from variables to the symbols where they are first/last used.\n",
+    "\n",
+    "    Args:\n",
+    "        symbols (iterable): An iterable of symbols.\n",
+    "        first_used (bool): If True, return the mapping from each variable to the symbol where it is first used;\n",
+    "                           otherwise return the mapping for the last use. Defaults to False.\n",
+    "\n",
+    "    Returns:\n",
+    "        variable_to_symbol (dict): A dictionary mapping each variable to the symbol where it is first/last used, based on the `first_used` argument.\n",
+    "    \"\"\"\n",
+    "    variable_to_symbol = {}\n",
+    "\n",
+    "    def _mark_first_or_last_use(symbol, variable):\n",
+    "        if variable not in variable_to_symbol:\n",
+    "            variable_to_symbol[variable] = symbol\n",
+    "\n",
+    "    iter_symbols = symbols if first_used else reversed(symbols)\n",
+    "    for symbol in iter_symbols:\n",
+    "        # If this function is used in the combined nvfuser+torch executor, there are no symbols but regions.\n",
+    "        # Regions do not have args, kwargs.\n",
+    "        if hasattr(symbol, \"inputs\"):\n",
+    "            variables = tuple(symbol.inputs) + tuple(symbol.outputs)\n",
+    "        else:\n",
+    "            variables = tuple(symbol.flat_variableified_proxy_args) + tuple(symbol.flat_variableified_proxy_outs)\n",
+    "        tree_map(lambda x: _mark_first_or_last_use(symbol, x), variables)\n",
+    "\n",
+    "    return variable_to_symbol\n",
+    "\n",
+    "\n",
+    "def get_symbol_to_idx(symbols):\n",
+    "    '''\n",
+    "    This function returns a map from each symbol to its position in the sequence.\n",
+    "    '''\n",
+    "    return {sym: idx for idx, sym in enumerate(symbols)}\n",
+    "\n",
+    "\n",
+    "def move_closer_to_consumer(execution_trace: TraceCtx) -> TraceCtx:\n",
+    "    '''\n",
+    "    This function takes the trace and reorders the operations so that an operation producing an input\n",
+    "    is placed closer to the operation consuming it.\n",
+    "\n",
+    "    This is required because, in the backward trace, the first consumer of a saved_for_backward tensor may be\n",
+    "    a reshape or permute op while the actual computation occurs 50-100 (or more) lines later.\n",
+    "    Because of this, we would eagerly load more tensors than required (thus decreasing the memory gains from CPU Offloading).\n",
+    "\n",
+    "    Args:\n",
+    "        execution_trace (TraceCtx): Trace to be re-ordered.\n",
+    "    '''\n",
+    "    order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)}\n",
+    "\n",
+    "    def prefer_ops_closer_to_consumer(eligible_nodes: list[Node]) -> int:\n",
+    "        def key(node: Node) -> int:\n",
+    "            return order_in_trace[node.bsym]\n",
+    "\n",
+    "        return min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))\n",
+    "\n",
+    "    # This moves all del / clear-collection symbols to the bottom (as they don't return anything).\n",
+    "    bound_symbols = toposort_bsym_dag(\n",
+    "        bsym_list_to_dag(execution_trace.bound_symbols)[1],\n",
+    "        TOPOSORT_ORDER.BOTTOM_UP,\n",
+    "        selector=prefer_ops_closer_to_consumer,\n",
+    "    )\n",
+    "\n",
+    "    for idx, bsym in enumerate(bound_symbols):\n",
+    "        if bsym.sym.id == prims.PrimIDs.DEL:\n",
+    "            break\n",
+    "\n",
+    "    new_execution_trace = from_trace(execution_trace)\n",
+    "    new_execution_trace.bound_symbols = bound_symbols[:idx]\n",
+    "\n",
+    "    new_execution_trace = thunder.executors.passes.del_last_used(new_execution_trace, clear_mutable_collections=True)\n",
+    "    return new_execution_trace"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now to the main topic: writing the `CPUOffloading` transform.\n",
+    "\n",
+    "The rough implementation of the transform looks like this:\n",
+    "1. From the forward computation trace, determine which tensors we want to offload to CPU. The `return` symbol of the forward trace has a sequence of tensors which are saved for the backward trace. We go through this list of tensors and find all the intermediate tensors (i.e. the ones which are not an input to the trace). Here, we will also call a user-provided callback which can further filter this list of tensors to offload (an example policy is sketched after this cell).\n",
+    "2. In the forward trace, we then find the last use of each tensor to offload from the above step and insert a call to the `offload_to_cpu` symbol that we created above. Note that we also save a map of which tensors we offloaded, along with the original device where each tensor lived, so that we can load it back to the correct device.\n",
+    "3. In the forward trace, we then update the `return` symbol to return the offloaded tensors (which are saved for the backward pass).\n",
+    "4. In the backward trace, we read from the map of the tensors which were offloaded and update the `unpack` symbol of the saved tensors to replace the original tensors with our offloaded tensors.\n",
+    "5. In the backward trace, we then find the first use of each offloaded tensor in a computation and insert a `load_to_gpu` call before it. Note that here, we will use the previously stored map of tensor to original device so that we load it onto the correct device.\n",
+    "\n",
+    "To see these steps in our implementation:\n",
+    "1. See method `transform_trace_post_optimization`, which is invoked by `thunder` first with the forward trace and then separately with the backward trace.\n",
+    "2. See method `_offload_tensors_from_forward`, which implements Steps 1, 2 and 3 from above.\n",
+    "3. See method `_load_tensors_for_backward`, which implements Steps 4 and 5 from above.\n",
+    "\n",
+    "Note that each of the above methods has more details regarding the implementation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CPUOffloading(Transform):\n",
+    "    '''\n",
+    "    Transform to implement CPU Offloading.\n",
+    "\n",
+    "    Args:\n",
+    "        save_tensor_policy: Users can pass a callback with signature fn(offloaded_tensors, forward_trace) to filter\n",
+    "                            the offloaded_tensors based on their preference, e.g. the biggest 20% of intermediate tensors or\n",
+    "                            the intermediates of certain operations.\n",
+    "    '''\n",
+    "    def __init__(self, save_tensor_policy: Callable[[tuple[TensorProxy, ...], TraceCtx], tuple[TensorProxy, ...]] | None = None):\n",
+    "        self.forward_pass = None\n",
+    "        self.backward_pass = None\n",
+    "        self._offloaded_tensors: Mapping[Variable, TensorProxy] = {}\n",
+    "        self._offloaded_tensors_dev: Mapping[Variable, str] = {}\n",
+    "        self.save_tensor_policy = None\n",
+    "        if save_tensor_policy is not None:\n",
+    "            assert callable(save_tensor_policy)\n",
+    "            self.save_tensor_policy = save_tensor_policy\n",
+    "\n",
+    "    def _get_tensors_to_offload(self, forward_trace):\n",
+    "        '''\n",
+    "        Based on the `forward_trace`, we find the tensors that we want to offload to CPU.\n",
+    "        This function finds the intermediate tensors that are saved for backward, i.e. the ones that are not an input or output of this trace.\n",
+    "        '''\n",
+    "        return_bsym = forward_trace.bound_symbols[-1]\n",
+    "        trace_args = return_bsym.args[0][\"flat_args\"]\n",
+    "        saved_tensors = get_saved_for_backward_tensors(forward_trace)\n",
+    "\n",
+    "        tensor_args_name = tuple(arg.name for arg in trace_args if isinstance(arg, TensorProxy))\n",
+    "\n",
+    "        def is_in_tensor_args(t):\n",
+    "            return t.name in tensor_args_name\n",
+    "\n",
+    "        def is_cuda_tensor(t):\n",
+    "            return t.device.type == \"cuda\"\n",
+    "\n",
+    "        # Tensors which are intermediate and not arguments to the computation trace are\n",
+    "        # the ones we are interested in offloading.\n",
+    "        tensors_to_offload = tuple(t for t in saved_tensors if ((not is_in_tensor_args(t)) and is_cuda_tensor(t)))\n",
+    "\n",
+    "        # Apply the user's policy if present.\n",
+    "        if self.save_tensor_policy is not None:\n",
+    "            tensors_to_offload = self.save_tensor_policy(tensors_to_offload, forward_trace)\n",
+    "        self.tensors_to_offload = tensors_to_offload\n",
+    "        return self.tensors_to_offload\n",
+    "\n",
+    "    def _replace_saved_tensors(self, forward_trace, new_output_map):\n",
+    "        return_bsym = forward_trace.bound_symbols[-1]\n",
+    "        new_return_bsym = return_bsym.from_bsym_swap_proxies(new_output_map)\n",
+    "\n",
+    "        # Replace the old return with our new return.\n",
+    "        forward_trace.bound_symbols.pop(-1)\n",
+    "        forward_trace.bound_symbols.append(new_return_bsym)\n",
+    "\n",
+    "    def _offload_tensors_from_forward(self, computation_trace):\n",
+    "        '''\n",
+    "        This function takes the forward computation trace and performs the following steps:\n",
+    "        1. 
Find the tensors to be offloaded using `_get_tensors_to_offload` (this also calls users `save_tensor_policy` if present).\n", + " 2. Insert calls to the `offload_to_cpu` symbol with the tensor to offload. These calls are placed after the last computational\n", + " use of the tensors to be offloaded so that we free the memory as soon as possible.\n", + " 3. Finally, we update the last symbol i.e. `return` symbol to return the offloaded tensors instead of the original tensors.\n", + " '''\n", + " # Step 1\n", + " # Find the tensors to offload.\n", + " # We offload saved tensors which are not arguments to the computation trace and are saved for backwards.\n", + " tensors_to_offload = self._get_tensors_to_offload(computation_trace)\n", + "\n", + " # Step 2\n", + " # Insert the offloading calls after the last use of the saved tensor (which we want to offload).\n", + " # NOTE - We pass `computation_trace.bound_symbols[:-1]` as we don't want to pass the `return` symbol (which will otherwise be the last use of the saved tensors).\n", + " variable_to_last_symbol = get_symbols_to_first_or_last_used_variables(\n", + " computation_trace.bound_symbols[:-1], first_used=False\n", + " )\n", + " symbol_to_idx = get_symbol_to_idx(computation_trace.bound_symbols)\n", + "\n", + " # Book keeping for backward pass update.\n", + " new_output_map: Mapping[Variable, TensorProxy] = {}\n", + " new_output_dev_map: Mapping[Variable, str] = {}\n", + "\n", + " # Since we are inserting in the list (we need to obey increasing order) - else the insertions will be incorrect.\n", + " sorted_tensors_to_offload = sorted(\n", + " tensors_to_offload, key=lambda t: symbol_to_idx[variable_to_last_symbol[variableify(t)]]\n", + " )\n", + " for idx, t in enumerate(sorted_tensors_to_offload):\n", + " last_used_symbol = variable_to_last_symbol[variableify(t)]\n", + " last_used_symbol_idx = symbol_to_idx[last_used_symbol]\n", + " computation_trace.push_scope([])\n", + " with tracectx(computation_trace):\n", + " o = offload_to_cpu(t)\n", + " prims.python_del(t)\n", + " scoped_comp = computation_trace.pop_scope()\n", + " scoped_comp[0].header = \"Created by CPU Offloading Transform\"\n", + " offload_to_cpu_symbol = scoped_comp[0]\n", + " del_symbol = scoped_comp[1]\n", + "\n", + " # This will insert `del` first and then push it down when we insert `offload_to_cpu`.\n", + " computation_trace.bound_symbols.insert(last_used_symbol_idx + 1 + (idx * 2), del_symbol)\n", + " computation_trace.bound_symbols.insert(last_used_symbol_idx + 1 + (idx * 2), offload_to_cpu_symbol)\n", + "\n", + " # Update book keeping.\n", + " new_output_map[variableify(t)] = o\n", + " new_output_dev_map[variableify(t)] = t.device.device_str()\n", + "\n", + " # Step 3\n", + " # Update the return symbol to return our offloaded tensors in saved for backward.\n", + " self._replace_saved_tensors(computation_trace, new_output_map)\n", + "\n", + " # Book keeping for backward pass update.\n", + " self._offloaded_tensors = new_output_map\n", + " self._offloaded_tensors_dev = new_output_dev_map\n", + " return computation_trace\n", + "\n", + " def _load_tensors_for_backward(self, computation_trace):\n", + " '''\n", + " This function takes the backward computation trace and performs following step\n", + " 1. Finds the unpack collection symbol which unpacks the saved tensors passed to the backward trace.\n", + " 2. Updates the unpack collection to unpack the offloaded tensors instead of the original ones.\n", + " 3. 
Before the first use of the offloaded tensor in computation, we insert the `load_to_gpu` to load the tensor back on GPU.\n", + " '''\n", + " self.backward_pass = computation_trace\n", + " offloaded_tensors = self._offloaded_tensors\n", + " offloaded_tensors_dev_map = self._offloaded_tensors_dev\n", + "\n", + " compute_producers, compute_consumers = thunder.core.utils.producers_and_consumers(computation_trace)\n", + "\n", + " # We want to insert `loads` before the first use of offloaded_tensors.\n", + " variable_to_first_symbol = get_symbols_to_first_or_last_used_variables(computation_trace.bound_symbols, first_used=True)\n", + "\n", + " symbol_to_idx = get_symbol_to_idx(computation_trace.bound_symbols)\n", + "\n", + " # Step 1 and 2\n", + " # Update unpack collection so that it\n", + " # outputs the offloaded tensor proxies (not the original ones).\n", + " unpack_sym = compute_producers[list(offloaded_tensors.keys())[0].proxy]\n", + " unpack_idx = symbol_to_idx[unpack_sym]\n", + " unpack_sym_out = unpack_sym.output\n", + " new_out = []\n", + " for out in unpack_sym_out:\n", + " if (vout := variableify(out)) in offloaded_tensors:\n", + " new_out.append(offloaded_tensors[vout])\n", + " else:\n", + " new_out.append(out)\n", + " new_unpack_bsym = BoundSymbol.from_bsym(unpack_sym, output=tuple(new_out))\n", + " computation_trace.bound_symbols[unpack_idx] = new_unpack_bsym\n", + "\n", + " # Now we again find the first usages of offloaded tensor\n", + " # This will actually point us to the first consumer of the offloaded tensor.\n", + " offset = unpack_idx + 1\n", + " variable_to_first_symbol = get_symbols_to_first_or_last_used_variables(computation_trace.bound_symbols[offset:], first_used=True)\n", + "\n", + " # Step 3\n", + " # Load the offloaded tensors to GPU before usage.\n", + " # Should iterate in correct order (else insertion positions will be incorrect).\n", + " for idx, (vt, offloaded_t) in enumerate(\n", + " sorted(offloaded_tensors.items(), key=lambda kv: symbol_to_idx[variable_to_first_symbol[kv[0]]])\n", + " ):\n", + " first_used_symbol = variable_to_first_symbol[vt]\n", + " first_used_symbol_idx = symbol_to_idx[first_used_symbol]\n", + " t = vt.proxy\n", + " device = offloaded_tensors_dev_map[vt]\n", + "\n", + " with tracectx(computation_trace):\n", + " new_sym = load_to_gpu.bind(offloaded_t, device, output=t)\n", + "\n", + " new_sym.header = \"Created by CPU Offloading Transform\"\n", + " computation_trace.bound_symbols.insert(first_used_symbol_idx + idx, new_sym)\n", + "\n", + " return computation_trace\n", + "\n", + " def transform_trace_post_optimization(self, computation_trace: thunder.TraceCtx, **kwargs):\n", + " if self.forward_pass is None:\n", + " self.forward_pass = computation_trace\n", + " # Processing for the forward pass (only if we are going to compute backward).\n", + " if \"augmented_forward\" in computation_trace.fn.__name__:\n", + " # Create a new copy of computation trace using `from_trace`.\n", + " new_computation_trace = from_trace(computation_trace)\n", + " # `from_trace` creates a shallow copy where `bound_symbols` and `provenance` are not copied.\n", + " new_computation_trace.bound_symbols = computation_trace.bound_symbols\n", + "\n", + " new_computation_trace = self._offload_tensors_from_forward(new_computation_trace)\n", + " else:\n", + " # Skip if no tensor was offloaded.\n", + " if len(self._offloaded_tensors) == 0:\n", + " return computation_trace\n", + "\n", + " # Create a new copy of computation trace using `from_trace`.\n", + " new_computation_trace 
= from_trace(computation_trace)\n", + " # `from_trace` creates a shallow copy where `bound_symbols` and `provenance` are not copied.\n", + " new_computation_trace.bound_symbols = computation_trace.bound_symbols\n", + "\n", + " # We need this because in unmodified backward trace, the first consumer of saved_for_backward maybe\n", + " # a reshape or permute op and the actual computation occurs 50-100 (or more) lines later.\n", + " # Because of this we load more tensors than required eagerly (thus decreasing the memory gains from CPU Offloading).\n", + " # Eg. on line 92\n", + " # # Created by CPU Offloading Transform\n", + " # t1319 = load_to_gpu(offloaded_t1319, 'cuda:0') # t1319: \"cuda:0 f32[8, 1024, 11008]\"\n", + " # t4021 = torch.reshape(t1319, (-1, 11008)) # t4021: \"cuda:0 f32[8192, 11008]\"\n", + " # # t4021 = ltorch.reshape(t1319, (-1, 11008)) # t4021: \"cuda:0 f32[8192, 11008]\"\n", + " # # t4021 = prims.reshape(t1319, (8192, 11008)) # t4021: \"cuda:0 f32[8192, 11008]\"\n", + " # del t1319\n", + " # And it's usage in computation is at 612\n", + " # t4022 = torch.matmul(t4020, t4021) # t4022: \"cuda:0 f32[4096, 11008]\"\n", + " # t4022 = ltorch.matmul(t4020, t4021) # t4022: \"cuda:0 f32[4096, 11008]\"\n", + " # t4022 = prims.matmul(t4020, t4021) # t4022: \"cuda:0 f32[4096, 11008]\"\n", + " new_computation_trace = move_closer_to_consumer(new_computation_trace)\n", + "\n", + " # Transform the backward trace to load offloaded tensors back to the device.\n", + " new_computation_trace = self._load_tensors_for_backward(new_computation_trace)\n", + "\n", + " return new_computation_trace" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def clear_memory():\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + " torch.cuda.reset_accumulated_memory_stats()\n", + " torch.cuda.reset_peak_memory_stats()\n", + " print(f\"Allocated Memory after cleaning {torch.cuda.memory_allocated() / 1e9} GB\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark(jmodel: ThunderModule, model: torch.nn.Module, args, kwargs):\n", + " # NOTE - This function takes care of warm-up\n", + " stmt = \"\"\"\n", + "# Use the optimized model for prediction and backward\n", + "o = jmodel(*args, **kwargs)\n", + "o.sum().backward()\n", + "for param in model.parameters(): # use original model for clear grads\n", + " param.grad = None\n", + "\"\"\"\n", + " timer = torch.utils.benchmark.Timer(\n", + " stmt=stmt, globals={\"jmodel\": jmodel, \"model\": model, \"args\": args, \"kwargs\": kwargs}\n", + " ).timeit(number=10)\n", + " return timer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Testing our Transform on a Simple Model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Peak Memory with thunder : 19279872 bytes\n", + "Peak Memory with CPU Offloading : 18444288 bytes\n", + "Allocated Memory after cleaning 0.017047552 GB\n" + ] + } + ], + "source": [ + "class MySimpleModel(torch.nn.Module):\n", + " def __init__(self, n_layers=10):\n", + " super().__init__()\n", + " self.fcs = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(n_layers)])\n", + "\n", + " def forward(self, x):\n", + " for fc in self.fcs:\n", + " x = torch.nn.functional.relu(fc(x))\n", + " \n", + " return x\n", + "\n", + "def get_model_and_args():\n", + " device = 
'cuda'\n", + " model = MySimpleModel(n_layers=100).to(device)\n", + " args = (torch.randn(128, 16, device=device),)\n", + " kwargs = {}\n", + " return model, args, kwargs\n", + "\n", + "model, args, kwargs = get_model_and_args()\n", + "\n", + "# Check against the vanilla `thunder.jit` model\n", + "expected = thunder.jit(model)(*args, **kwargs)\n", + "\n", + "grad_output = torch.randn_like(expected)\n", + "expected_grads = torch.autograd.grad(expected, model.parameters(), grad_output)\n", + "\n", + "print(f\"Peak Memory with thunder : {torch.cuda.max_memory_allocated()} bytes\")\n", + "torch.cuda.reset_peak_memory_stats()\n", + "\n", + "with torch.no_grad():\n", + " expected_cpu = expected.to(\"cpu\")\n", + " expected_grads_cpu = tree_map(lambda t: t.to(\"cpu\"), expected_grads)\n", + "\n", + "jmodel = thunder.jit(model, transforms=[CPUOffloading()])\n", + "\n", + "actual = jmodel(*args, **kwargs)\n", + "\n", + "# Verify that saved tensors are on CPU.\n", + "saved_tensor_devices = set()\n", + "for t in actual.grad_fn.saved_tensors:\n", + " saved_tensor_devices.add(str(t.device))\n", + "\n", + "assert \"cpu\" in saved_tensor_devices # Verify that we actually have saved tensors on CPU\n", + "actual_grads = torch.autograd.grad(actual, jmodel.parameters(), grad_output)\n", + "\n", + "print(f\"Peak Memory with CPU Offloading : {torch.cuda.max_memory_allocated()} bytes\")\n", + "\n", + "with torch.no_grad():\n", + " actual_cpu = actual.to(\"cpu\")\n", + " actual_grads_cpu = tree_map(lambda t: t.to(\"cpu\"), actual_grads)\n", + "\n", + "# Sanity Check that values match\n", + "torch.testing.assert_close(actual_cpu, expected_cpu)\n", + "torch.testing.assert_close(actual_grads_cpu, expected_grads_cpu)\n", + "\n", + "# Fetch the forward and backward traces for inspection\n", + "fw_traces = thunder.last_traces(jmodel)\n", + "bw_traces = thunder.last_backward_traces(jmodel)\n", + "\n", + "del jmodel, model, args, kwargs, actual, actual_grads, expected, expected_grads, grad_output # Free memory.\n", + "\n", + "clear_memory()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inspecting the forward and the backward traces." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fw_traces[-1] # Note the calls to `offload_to_cpu` and verify that they are after the last usage of the tensor." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Snippet from the forward trace\n", + "```python\n", + " t485 = torch.nn.functional.linear(t484, t_fcs_97_weight, t_fcs_97_bias) # t485: \"cuda:0 f32[128, 16]\"\n", + " # t485 = ltorch.linear(t484, t_fcs_97_weight, t_fcs_97_bias) # t485: \"cuda:0 f32[128, 16]\"\n", + " # t485 = prims.linear(t484, t_fcs_97_weight, t_fcs_97_bias) # t485: \"cuda:0 f32[128, 16]\"\n", + " # Created by CPU Offloading Transform\n", + " offloaded_t484 = offload_to_cpu(t484) # offloaded_t484: \"cpu f32[128, 16]\"\n", + " del t484\n", + " [t487, t489] = nvFusion97(t485)\n", + " # t487 = prims.gt(t485, 0.0) # t487: \"cuda:0 b8[128, 16]\"\n", + " # t489 = prims.where(t487, t485, 0.0) # t489: \"cuda:0 f32[128, 16]\"\n", + " # Created by CPU Offloading Transform\n", + " offloaded_t487 = offload_to_cpu(t487) # offloaded_t487: \"cpu b8[128, 16]\"\n", + " del t487\n", + " del t485\n", + " t490 = torch.nn.functional.linear(t489, t_fcs_98_weight, t_fcs_98_bias) # t490: \"cuda:0 f32[128, 16]\"\n", + " # t490 = ltorch.linear(t489, t_fcs_98_weight, t_fcs_98_bias) # t490: \"cuda:0 f32[128, 16]\"\n", + " # t490 = prims.linear(t489, t_fcs_98_weight, t_fcs_98_bias) # t490: \"cuda:0 f32[128, 16]\"\n", + " # Created by CPU Offloading Transform\n", + " offloaded_t489 = offload_to_cpu(t489) # offloaded_t489: \"cpu f32[128, 16]\"\n", + " del t489\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bw_traces[-1] # Note the calls to `load_to_gpu` and verify that they are before the first usage of the tensor." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Snippet from the backward trace\n", + "\n", + "```python\n", + " # Created by CPU Offloading Transform\n", + " t489 = load_to_gpu(offloaded_t489, 'cuda:0') # t489: \"cuda:0 f32[128, 16]\"\n", + " t2015 = torch.reshape(t489, (-1, 16)) # t2015: \"cuda:0 f32[128, 16]\"\n", + " # t2015 = ltorch.reshape(t489, (-1, 16)) # t2015: \"cuda:0 f32[128, 16]\"\n", + " # t2015 = prims.reshape(t489, (128, 16)) # t2015: \"cuda:0 f32[128, 16]\"\n", + " del t489\n", + " t2016 = torch.matmul(t2014, t2015) # t2016: \"cuda:0 f32[16, 16]\"\n", + " # t2016 = ltorch.matmul(t2014, t2015) # t2016: \"cuda:0 f32[16, 16]\"\n", + " # t2016 = prims.matmul(t2014, t2015) # t2016: \"cuda:0 f32[16, 16]\"\n", + " del t2014, t2015\n", + " t2005 = torch.permute(t2002, (1, 0)) # t2005: \"cuda:0 f32[16, 128]\"\n", + " # t2005 = ltorch.permute(t2002, (1, 0)) # t2005: \"cuda:0 f32[16, 128]\"\n", + " # t2005 = prims.transpose(t2002, (1, 0)) # t2005: \"cuda:0 f32[16, 128]\"\n", + " del t2002\n", + " # Created by CPU Offloading Transform\n", + " t494 = load_to_gpu(offloaded_t494, 'cuda:0') # t494: \"cuda:0 f32[128, 16]\"\n", + " t2006 = torch.reshape(t494, (-1, 16)) # t2006: \"cuda:0 f32[128, 16]\"\n", + " # t2006 = ltorch.reshape(t494, (-1, 16)) # t2006: \"cuda:0 f32[128, 16]\"\n", + " # t2006 = prims.reshape(t494, (128, 16)) # t2006: \"cuda:0 f32[128, 16]\"\n", + " del t494\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Benchmark `thunder` vs `thunder + CPU Offloading` on Simple Model" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Allocated Memory after cleaning 0.017047552 GB\n" + ] + } + ], + "source": [ + "model, args, kwargs = get_model_and_args()\n", + "\n", + "measurement_thunder = 
benchmark(thunder.jit(model), model, args, kwargs)\n", + "measurement_thunder_offload = benchmark(thunder.jit(model, transforms=[CPUOffloading()]), model, args, kwargs)\n", + "\n", + "del model, args, kwargs\n", + "\n", + "clear_memory()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\n", + "stmt:\n", + " # Use the optimized model for prediction and backward\n", + " o = jmodel(*args, **kwargs)\n", + " o.sum().backward()\n", + " for param in model.parameters(): # use original model for clear grads\n", + " param.grad = None\n", + "\n", + " 8.50 ms\n", + " 1 measurement, 10 runs , 1 thread" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "measurement_thunder" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\n", + "stmt:\n", + " # Use the optimized model for prediction and backward\n", + " o = jmodel(*args, **kwargs)\n", + " o.sum().backward()\n", + " for param in model.parameters(): # use original model for clear grads\n", + " param.grad = None\n", + "\n", + " 12.62 ms\n", + " 1 measurement, 10 runs , 1 thread" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "measurement_thunder_offload" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us try it on a real-life model Llama-3. We will run it on a smaller Llama-3. Feel free to update `N_LAYER` and `BLOCK_SIZE` based on the available device memory.\n", + "\n", + "**NOTE**: Running the cell below requires `litgpt` installed. Use `pip install litgpt` if it is not available." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "from litgpt import Config, GPT\n", + "from functools import partial\n", + "from torch.testing import make_tensor\n", + "\n", + "N_LAYER = 9\n", + "BLOCK_SIZE = 1024\n", + "\n", + "def get_model_and_args(batchdims=8):\n", + " with torch.device(\"cuda\"):\n", + " cfg: Config = Config.from_name(\"Llama-3-8B\")\n", + " # Smaller configuration\n", + " cfg.n_layer = N_LAYER\n", + " cfg.block_size = BLOCK_SIZE\n", + "\n", + " model = GPT(cfg)\n", + " make = partial(make_tensor, low=0, high=255, device='cuda', dtype=torch.int64, requires_grad=False)\n", + " shape = (batchdims,) + (cfg.block_size,)\n", + "\n", + " x = make(shape)\n", + " args, kwargs = (x,), {}\n", + "\n", + " return model, args, kwargs, cfg\n", + "\n", + "def print_memory_usage_and_benchmark(name):\n", + " print(f\"{name} took -\")\n", + " model, args, kwargs, cfg = get_model_and_args()\n", + "\n", + " if name == 'thunder':\n", + " jmodel = thunder.jit(model)\n", + " elif name == 'thunder_offload':\n", + " jmodel = thunder.jit(model, transforms=[CPUOffloading()])\n", + " else:\n", + " raise RuntimeError(\"Received invalid value for `name` - try `thunder` or `thunder_offload`.\")\n", + "\n", + " memory_after_model_load = torch.cuda.max_memory_allocated() / 1e9\n", + " print(f\"Peak memory after loading the model : {memory_after_model_load} GB\")\n", + "\n", + " a = jmodel(*args, **kwargs)\n", + "\n", + " memory_after_forward = torch.cuda.max_memory_allocated() / 1e9\n", + " print(f\"Peak memory after forward the model : {memory_after_forward} GB\")\n", + "\n", + " g = torch.rand_like(a)\n", + " actual_grads = torch.autograd.grad(a, model.parameters(), g)\n", + "\n", + " memory_after_backward 
= torch.cuda.max_memory_allocated() / 1e9\n", + " print(f\"Peak memory after backward the model : {memory_after_backward} GB\")\n", + "\n", + " del a, g, actual_grads # Clear data which is not required for benchmark to free some memory.\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + "\n", + " measurement = benchmark(jmodel, model, args, kwargs)\n", + " print(f\"Benchmark Timings - mean : {measurement.mean} - median {measurement.median}\")\n", + "\n", + " del jmodel, model, cfg, args, kwargs\n", + "\n", + " clear_memory()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "thunder took -\n", + "Peak memory after loading the model : 12.073366016 GB\n", + "Peak memory after forward the model : 38.901342208 GB\n", + "Peak memory after backward the model : 46.245552128 GB\n", + "Benchmark Timings - mean : 5.008525840996299 - median 5.008525840996299\n", + "Allocated Memory after cleaning 0.017047552 GB\n" + ] + } + ], + "source": [ + "# Uncomment this to run the benchmarks.\n", + "# print_memory_usage_and_benchmark(\"thunder\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "thunder_offload took -\n", + "Peak memory after loading the model : 12.073366016 GB\n", + "Peak memory after forward the model : 16.409812992 GB\n", + "Peak memory after backward the model : 35.545775616 GB\n", + "Benchmark Timings - mean : 5.91704241540283 - median 5.91704241540283\n", + "Allocated Memory after cleaning 0.017047552 GB\n" + ] + } + ], + "source": [ + "# Uncomment this to run the benchmarks.\n", + "# print_memory_usage_and_benchmark(\"thunder_offload\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Conclusion\n", + "\n", + "In this notebook, we have understood how to write our own `Transform` in `thunder`. As an example, we wrote an `CPUOffloading` transform to implement CPU offloading technique to decrease peak memory usage during training." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytorch-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}