Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Notebook: ThunderFX tutorial #1524

Merged
merged 10 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik
Thunder step by step <basic/inspecting_traces>
The sharp edges <basic/sharp_edges>
Train a MLP on MNIST <basic/mlp_mnist>
Hello world ThunderFX <notebooks/hello_world_thunderfx>
FAQ <basic/faq>

.. toctree::
Expand Down
281 changes: 281 additions & 0 deletions notebooks/hello_world_thunderfx.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
{
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \"Hello, World!\" ThunderFX\n",
"\n",
"In this tutorial, we’ll explore how to use ThunderFX to accelerate a PyTorch program.\n",
"\n",
"We’ll cover the basics of ThunderFX, demonstrate how to apply it to PyTorch functions and models, and evaluate its performance in both inference (forward-only) and training (forward and backward)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Getting Started with ThunderFX\n",
"\n",
"Let's see an example of using ThunderFX on a PyTorch function. ThunderFX optimizes the given callable and returns a compiled version of the function. You can then use the compiled function just like you would the original one."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from thunder.dynamo import thunderfx\n",
"\n",
"def foo(x, y):\n",
" return torch.sin(x) + torch.cos(y)\n",
"\n",
"# Compiles foo with ThunderFX\n",
"compiled_foo = thunderfx(foo)\n",
"\n",
"# Creates inputs\n",
"inputs = [torch.randn(4, 4), torch.randn(4, 4)]\n",
"\n",
"eager_results = foo(*inputs)\n",
"# Runs the compiled model\n",
"thunderfx_results = compiled_foo(*inputs)\n",
"\n",
"torch.testing.assert_close(eager_results, thunderfx_results)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ThunderFX supports both CPU and CUDA tensors. However, its primary focus is optimizing CUDA calculations. The following example demonstrates ThunderFX with CUDA tensors:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# Checks if CUDA is available\n",
"if not torch.cuda.is_available():\n",
" print(\"No suitable GPU detected. Unable to proceed with the tutorial. Cell execution has been stopped.\")\n",
" sys.exit()\n",
"\n",
"\n",
"# Creates inputs\n",
"inputs = [torch.randn(4, 4, device=\"cuda\"), torch.randn(4, 4, device=\"cuda\")]\n",
"\n",
"eager_result = foo(*inputs)\n",
"thunderfx_result = compiled_foo(*inputs)\n",
"\n",
"torch.testing.assert_close(eager_result, thunderfx_result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Performance Optimization with ThunderFX\n",
"\n",
"Next, let’s evaluate how ThunderFX improves performance on a real-world model. We'll use the Llama3 model as an example and compare the execution time for both inference and gradient calculations.\n",
"\n",
"We begin by loading and configuring a smaller version of the Llama3 model:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GPT(\n",
" (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n",
" (transformer): ModuleDict(\n",
" (wte): Embedding(128256, 4096)\n",
" (h): ModuleList(\n",
" (0-1): 2 x Block(\n",
" (norm_1): RMSNorm()\n",
" (attn): CausalSelfAttention(\n",
" (attn): Linear(in_features=4096, out_features=6144, bias=False)\n",
" (proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (post_attention_norm): Identity()\n",
" (norm_2): RMSNorm()\n",
" (mlp): LLaMAMLP(\n",
" (fc_1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (fc_2): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
" )\n",
" (post_mlp_norm): Identity()\n",
" )\n",
" )\n",
" (ln_f): RMSNorm()\n",
" )\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from litgpt import Config, GPT\n",
"from functools import partial\n",
"from torch.testing import make_tensor\n",
"from thunder.dynamo import thunderfx\n",
"\n",
"cfg = Config.from_name(\"Llama-3-8B\")\n",
"\n",
"# Uses a reduced configuration for this tutorial\n",
"cfg.n_layer = 2\n",
"cfg.block_size = 1024\n",
"batch_dim = 4\n",
"\n",
"torch.set_default_dtype(torch.bfloat16)\n",
"make = partial(make_tensor, low=0, high=255, device='cuda', dtype=torch.int64)\n",
"\n",
"with torch.device('cuda'):\n",
" model = GPT(cfg)\n",
" shape = (batch_dim, cfg.block_size)\n",
" x = make(shape)\n",
"\n",
"model "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again we first compile our model and compare the output"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"deviation: 0.015625\n"
]
}
],
"source": [
"compiled_model = thunderfx(model)\n",
"thunderfx_result = compiled_model(x)\n",
"eager_result = model(x)\n",
"print(\"deviation:\", (thunderfx_result - eager_result).abs().max().item())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: ThunderFX compiles the model into optimized kernels as it executes. Compiling these kernels can take seconds or even minutes for larger models, but each kernel only has to be compiled once, and subsequent runs will benefit from it.\n",
"\n",
"To evaluate ThunderFX’s inference performance, we compare the execution time of the compiled model versus the standard PyTorch model:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ThunderFX Inference Time:\n",
"66.7 ms ± 289 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"Torch Eager Inference Time:\n",
"72.2 ms ± 287 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"# Clears data to free some memory.\n",
"del thunderfx_result, eager_result\n",
"import gc\n",
"gc.collect()\n",
"torch.cuda.empty_cache()\n",
"\n",
"# Measures inference time\n",
"print(\"ThunderFX Inference Time:\")\n",
"%timeit r = compiled_model(x); torch.cuda.synchronize()\n",
"print(\"Torch Eager Inference Time:\")\n",
"%timeit r = model(x); torch.cuda.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly, let’s measure the performance improvement for training:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ThunderFX Training Time:\n",
"197 ms ± 5.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"Torch Eager Training Time:\n",
"213 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"print(\"ThunderFX Training Time:\")\n",
"%timeit r = compiled_model(x); r.sum().backward(); torch.cuda.synchronize()\n",
"print(\"Torch Eager Training Time:\")\n",
"%timeit r = model(x); r.sum().backward(); torch.cuda.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Conclusion\n",
"\n",
"ThunderFX can accelerate PyTorch programs, particularly CUDA programs. By compiling optimized kernels specific to the program you're running. It can accelerate both inference (forward-only) and training (forward and backward) computations.\n",
"\n",
"For more information about Thunder and ThunderFX in particular, see https://github.com/Lightning-AI/lightning-thunder/tree/main/notebooks."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading