Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 2 changed files with 244 additions and 308 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## \"Hello, World!\" ThunderFX\n", | ||
"\n", | ||
"In this tutorial, we’ll use ThunderFX to accelerate PyTorch programs.\n", | |
"\n", | |
"We’ll cover the basics of ThunderFX, apply it to PyTorch functions and models, and evaluate its performance for both inference and gradient calculations." | |
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Getting Started with ThunderFX\n", | ||
"\n", | ||
"Let's see an example of using ThunderFX on a PyTorch function. ThunderFX optimizes the given callable and returns a compiled version of the function. You can then use the compiled function just like you would the original one." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/wayan/lightning-thunder/thunder/dynamo/compiler.py:24: UserWarning: The ThunderCompiler is in active development and may not work as expected. Please report any issues you encounter to the Lightning Thunder team.\n", | ||
" warnings.warn(\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"<function foo at 0x7d1508a37600>\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"from thunder.dynamo import thunderfx\n", | ||
"\n", | ||
"def foo(x, y):\n", | ||
" return torch.sin(x) + torch.cos(y)\n", | ||
"\n", | ||
"# Compiles foo with ThunderFX\n", | ||
"compiled_foo = thunderfx(foo)\n", | ||
"\n", | ||
"# Creates inputs\n", | ||
"inputs = [torch.randn(4, 4), torch.randn(4, 4)]\n", | ||
"\n", | ||
"eager_results = foo(*inputs)\n", | ||
"# Runs the compiled function just like the original\n", | |
"thunderfx_results = compiled_foo(*inputs)\n", | ||
"\n", | ||
"torch.testing.assert_close(eager_results, thunderfx_results)\n", | ||
"print(compiled_foo)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"ThunderFX supports both CPU and CUDA tensors. However, its primary focus is optimizing CUDA calculations. The following example demonstrates ThunderFX with CUDA tensors:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import sys\n", | ||
"\n", | ||
"# Checks if CUDA is available\n", | ||
"if not torch.cuda.is_available():\n", | ||
" print(\"No suitable GPU detected. Unable to proceed with the tutorial. Cell execution has been stopped.\")\n", | ||
" sys.exit()\n", | ||
"\n", | ||
"\n", | ||
"# Creates inputs\n", | ||
"inputs = [torch.randn(4, 4, device=\"cuda\"), torch.randn(4, 4, device=\"cuda\")]\n", | ||
"\n", | ||
"eager_result = foo(*inputs)\n", | ||
"thunderfx_result = compiled_foo(*inputs)\n", | ||
"\n", | ||
"torch.testing.assert_close(eager_result, thunderfx_result)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Performance Optimization with ThunderFX\n", | ||
"\n", | ||
"Next, let’s evaluate how ThunderFX improves performance on a real-world model. We'll use the Llama3 model as an example and compare the execution time for both inference and gradient calculations.\n", | ||
"\n", | ||
"We begin by loading and configuring a lightweight version of the Llama3 model:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from litgpt import Config, GPT\n", | ||
"from functools import partial\n", | ||
"from torch.testing import make_tensor\n", | ||
"from thunder.dynamo import thunderfx\n", | ||
"\n", | ||
"cfg = Config.from_name(\"Llama-3-8B\")\n", | ||
"\n", | ||
"# We use fewer layers and a smaller block size here; adjust these settings to suit the GPU you're using\n", | |
"cfg.n_layer = 2 # fewer layers\n", | ||
"cfg.block_size = 1024\n", | ||
"batch_dim = 8\n", | ||
"torch.set_default_dtype(torch.bfloat16)\n", | ||
"make = partial(make_tensor, low=0, high=255, device='cuda', dtype=torch.int64, requires_grad=False)\n", | ||
"\n", | ||
"with torch.device('cuda'):\n", | ||
" model = GPT(cfg)\n", | ||
" shape = (batch_dim, cfg.block_size)\n", | ||
" x = make(shape)\n", | ||
"\n", | ||
"model " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"As before, we first compile the model and compare its output with the eager result:" | |
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"compiled_model = thunderfx(model)\n", | ||
"thunderfx_result = compiled_model(x)\n", | ||
"eager_result = model(x)\n", | ||
"print(\"deviation:\", (thunderfx_result - eager_result).abs().max().item())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Note: ThunderFX compiles the model into optimized kernels as it executes. The first run therefore takes longer because it includes compilation, while subsequent runs reuse the compiled kernels and run faster.\n", | |
"\n", | ||
"To evaluate ThunderFX’s inference performance, we compare the execution time of the compiled model versus the standard PyTorch model:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"137 ms ± 18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | ||
"161 ms ± 1.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Deletes tensors the benchmark does not need, to free some memory.\n", | |
"del thunderfx_result, eager_result\n", | ||
"import gc\n", | ||
"gc.collect()\n", | ||
"torch.cuda.empty_cache()\n", | ||
"\n", | ||
"%timeit r = compiled_model(x); torch.cuda.synchronize()\n", | ||
"%timeit r = model(x); torch.cuda.synchronize()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Similarly, let’s measure the performance improvement for gradient calculations:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"443 ms ± 4.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | ||
"481 ms ± 2.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%timeit r = compiled_model(x); torch.autograd.grad(r.sum(), model.parameters()); torch.cuda.synchronize()\n", | ||
"%timeit r = model(x); torch.autograd.grad(r.sum(), model.parameters()); torch.cuda.synchronize()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Conclusion\n", | ||
"\n", | ||
"ThunderFX provides an efficient way to accelerate PyTorch programs, particularly for GPU workloads. In the runs above, the compiled 2-layer Llama 3 model took about 137 ms per forward pass versus 161 ms in eager mode, and about 443 ms versus 481 ms for the forward pass plus gradient computation. This tutorial demonstrated its usage on both a simple function and a real-world model." | |
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
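The notebook notes that the first call to a compiled model includes compilation, which suggests timing warm-up calls separately from steady-state calls. Below is a minimal sketch of that idea using only the Python standard library. The helper `benchmark` and its parameters (`warmup`, `runs`, `sync`, `timer`) are hypothetical names chosen for illustration; they are not part of ThunderFX or PyTorch.

```python
import statistics
import time


def benchmark(fn, *args, warmup=1, runs=5, sync=None, timer=time.perf_counter):
    """Time fn(*args), reporting warm-up calls apart from steady-state calls.

    Hypothetical helper for illustration only.
    `sync` is an optional zero-argument callable run after each call,
    e.g. a device synchronization function for asynchronous GPU work.
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")

    def timed_call():
        start = timer()
        fn(*args)
        if sync is not None:
            sync()  # wait for pending work before stopping the clock
        return timer() - start

    # Warm-up calls absorb one-time costs such as compilation.
    warmup_times = [timed_call() for _ in range(warmup)]
    # Steady-state calls are the ones summarized in the statistics.
    steady_times = [timed_call() for _ in range(runs)]

    return {
        "warmup_s": warmup_times,
        "mean_s": statistics.mean(steady_times),
        "stdev_s": statistics.stdev(steady_times) if runs > 1 else 0.0,
    }


# Usage with a cheap CPU-only callable.
stats = benchmark(sum, range(1000), warmup=1, runs=3)
print(stats)
```

With the notebook's objects, a call such as `benchmark(compiled_model, x, sync=torch.cuda.synchronize)` would presumably report the compile-inclusive warm-up time apart from the steady-state mean, assuming a CUDA device is available.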