Trainable transformer with fwd+bwd ops in Triton, matching the performance of PyTorch + cuDNN/cuBLAS

fattorib/tritonformer

Tritonformer

A GPT2-style transformer in Triton. Fully differentiable, with forward + backward speeds matching those of a PyTorch transformer using cuBLAS kernels and flash-attention.

Supports Autograd and Automatic Mixed Precision.

Kernels

Triton kernel definitions are located under tritonformer/kernels. Autograd functions wrapping the kernels are located under tritonformer.

Kernels implemented:

Attention

  • Fused Attention kernels/attention.py

  • Fused Attention w/ ALiBi kernels/biased_attention.py
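For readers unfamiliar with ALiBi, the bias that such a kernel adds to the attention scores can be sketched in plain Python. This is only a reference for the math described in the ALiBi paper, assuming a power-of-two head count and causal masking; the function names `alibi_slopes` and `alibi_bias` are hypothetical and are not taken from this repository, whose kernel may compute the bias differently.

```python
import math

def alibi_slopes(num_heads):
    """Per-head slopes from the ALiBi paper, for a power-of-two head count."""
    if num_heads <= 0 or num_heads & (num_heads - 1):
        raise ValueError("this sketch only handles power-of-two head counts")
    # Geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8).
    return [2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)]

def alibi_bias(slope, seq_len):
    """Causal bias added to one head's attention scores before softmax.

    Entry [i][j] is -slope * (i - j) for visible keys (j <= i) and
    -inf for masked future keys (j > i).
    """
    return [
        [-slope * (i - j) if j <= i else -math.inf for j in range(seq_len)]
        for i in range(seq_len)
    ]

slopes = alibi_slopes(8)                  # one slope per head
bias = alibi_bias(slopes[0], seq_len=4)   # (seq_len, seq_len) bias for head 0
```

The penalty grows linearly with the query-key distance, which is why no learned position embedding is needed for the biased attention path.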

GEMM

  • MatMul: (b,m,n) @ (n,k) -> (b,m,k) kernels/gemm.py
  • Fused MatMul + ReLU: (b,m,n) @ (n,k) -> (b,m,k) kernels/gemm.py
  • Fused MatMul + Add: (b,m,n) @ (n,k) + (k,) -> (b,m,k) kernels/gemm.py
  • Fused MatMul + Add + ReLU: (b,m,n) @ (n,k) + (k,) -> (b,m,k) kernels/gemm.py
  • Batched MatMul: (b,m,n) @ (b,n,k) -> (b,m,k) kernels/gemm.py
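The shape signatures above can be read against a small reference implementation. The sketch below uses nested Python lists and is only meant to pin down what the fused variants compute, including how the (k,) bias broadcasts over the batch and row dimensions. The name `matmul_add_relu` is hypothetical, and this is not how the Triton kernels in kernels/gemm.py are written.

```python
def matmul_add_relu(x, w, bias=None, relu=False):
    """Reference for y = act(x @ w + bias).

    x: nested list of shape (b, m, n); w: (n, k); bias: (k,) or None.
    Returns a nested list of shape (b, m, k).
    """
    n, k = len(w), len(w[0])
    out = []
    for batch in x:                       # loop over b
        rows = []
        for row in batch:                 # loop over m
            assert len(row) == n, "inner dimensions must match"
            y = [sum(row[i] * w[i][j] for i in range(n)) for j in range(k)]
            if bias is not None:
                # the (k,) bias is broadcast over both b and m
                y = [v + bias[j] for j, v in enumerate(y)]
            if relu:
                y = [max(v, 0.0) for v in y]
            rows.append(y)
        out.append(rows)
    return out

x = [[[1.0, -2.0], [0.5, 3.0]]]           # (b=1, m=2, n=2)
w = [[1.0, 0.0, -1.0], [2.0, 1.0, 0.0]]   # (n=2, k=3)
bias = [0.5, 0.0, -0.5]                   # (k=3,)
y = matmul_add_relu(x, w, bias, relu=True)
```

Fusing the bias add and the activation into the GEMM avoids writing the intermediate (b,m,k) tensor to memory and reading it back, which is the usual motivation for kernels of this kind.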

Losses

  • CrossEntropy kernels/crossentropy.py

Activations/Normalization

  • LayerNorm kernels/layernorm.py
  • Softmax kernels/softmax.py

Use

Clone repo and install requirements:

git clone https://github.com/fattorib/tritonformer.git
cd tritonformer
pip install -r requirements.txt
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.dev20231014192330

PyTorch transformer models using Triton kernels and cuBLAS kernels are provided in transformer_triton.py and transformer_torch.py, respectively. For example, to use the model with Triton kernels, add:

...
from transformer import Transformer, TransformerConfig
config = TransformerConfig(...) # create config
model = Transformer(config) # create model
...

Train like you would any other model!

Benchmarks

Benchmark code is quite simple:

# imports and instantiate model
...
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for step in range(100):
  with torch.cuda.amp.autocast():
    logits, loss = model(batch, batch)
  loss.backward()

end.record()
# wait until the end event has completed before reading the timer
torch.cuda.synchronize()
time_s = start.elapsed_time(end) / 1e3  # elapsed_time returns milliseconds
# convert time_s to TFLOPs
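The final conversion step is not spelled out in this README. One common way to do it, shown here as a sketch under stated assumptions rather than the exact formula used for the numbers below, is to count roughly 6 FLOPs per parameter per token for a combined forward and backward pass, plus an optional attention term of 12 * n_layers * d_model * seq_len FLOPs per token. The function name `achieved_tflops` and all argument names are hypothetical.

```python
def achieved_tflops(n_params, batch_size, seq_len, n_steps, time_s,
                    n_layers=0, d_model=0):
    """Estimate training throughput in TFLOP/s.

    Assumes ~6 FLOPs per parameter per token (forward + backward), plus an
    optional attention term of 12 * n_layers * d_model * seq_len per token.
    Leave n_layers and d_model at 0 to ignore the attention term.
    """
    flops_per_token = 6 * n_params + 12 * n_layers * d_model * seq_len
    tokens = batch_size * seq_len * n_steps
    return flops_per_token * tokens / time_s / 1e12

# Hypothetical inputs, chosen only to exercise the function; not measurements.
tflops = achieved_tflops(n_params=1e9, batch_size=2, seq_len=1024,
                         n_steps=10, time_s=4.0)
```

Whichever convention is used, the same one should be applied to both models so that the two columns in the table remain comparable.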

Setup: Benchmarks are performed on an RTX 3090 with torch==2.0.0 and triton-nightly==2.1.0.dev20231014192330. Flash Attention from PyTorch is enabled with the torch.backends.cuda.sdp_kernel context manager.

| Model Size | Context | Tritonformer (TFLOPs) | PyTorch (TFLOPs) | Config |
| --- | --- | --- | --- | --- |
| 410M | 2048 | 50.1719 | 50.2183 | bs=4, use_linear_bias=False, attn_bias=False |
| 410M | 4096 | 42.2310 | 45.5872 | bs=2, use_linear_bias=False, attn_bias=False |
| 840M | 2048 | 52.6596 | 52.2332 | bs=2, use_linear_bias=False, attn_bias=False |
| 840M | 4096 | 45.8857 | 48.4785 | bs=1, use_linear_bias=False, attn_bias=False |
| 1.4B | 2048 | 55.6311 | 55.8419 | bs=2, use_linear_bias=False, attn_bias=False |

Sample Loss Curves

On a small 160M-parameter model, end-to-end training [1] with Tritonformer achieves parity with the equivalent PyTorch model.

Tests

Tests are handled by pytest. Run pytest tritonformer/ to run all tests. Note: there are many tests, so a full run can take a while!

Acknowledgments

  • Parts of the PyTorch transformer code were originally based on zphang/minimal-opt

  • Some earlier forward-pass kernels were written as I followed along through the Triton Tutorials

Future Extensions

I'm currently happy with the state of this project but given more time, here are a few extensions I'd like to do:

  • The weight backward pass for matrix multiplication requires a reduction over the batch dimension after performing bmm(activation.T, grad_output) [2]. It should be possible to fuse this reduction into the matmul kernel.
  • Indexing support in Triton is difficult. As such, token and position embeddings use nn.Embedding instead of a custom kernel.
  • Extend Fused Attention to support Flash Attention 2.
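To make the first extension concrete: summing the per-batch products over the batch dimension gives the same matrix as a single product over the flattened (b*m) dimension, which is one reason the reduction looks fusable. The plain-Python sketch below checks that identity on a tiny example. All function names are hypothetical, and it says nothing about how efficiently a Triton kernel could exploit this.

```python
def matmul(a, b):
    """Plain 2-D matrix product of nested lists: (r, n) @ (n, c) -> (r, c)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a):
    """Transpose a 2-D nested list."""
    return [list(col) for col in zip(*a)]

def grad_weight_bmm_then_reduce(activation, grad_output):
    """One (n, k) product per batch element, then a sum over the batch."""
    per_batch = [matmul(transpose(x), g) for x, g in zip(activation, grad_output)]
    n, k = len(per_batch[0]), len(per_batch[0][0])
    return [[sum(p[i][j] for p in per_batch) for j in range(k)] for i in range(n)]

def grad_weight_flattened(activation, grad_output):
    """Same result from a single product over the flattened (b*m) dimension."""
    x_flat = [row for x in activation for row in x]     # (b*m, n)
    g_flat = [row for g in grad_output for row in g]    # (b*m, k)
    return matmul(transpose(x_flat), g_flat)

activation = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [2.0, 0.0]]]  # (b=2, m=2, n=2)
grad_output = [[[1.0], [2.0]], [[3.0], [4.0]]]                      # (b=2, m=2, k=1)

two_step = grad_weight_bmm_then_reduce(activation, grad_output)
one_step = grad_weight_flattened(activation, grad_output)
```

Both paths accumulate over every (batch, row) pair; the two-step version simply materializes b intermediate (n, k) matrices before adding them up.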

Footnotes

  1. Run parameters: Dataset: MiniPile, Context: 2048 tokens (NeoX tokenizer), Batch Size: 64, FP16 autocast enabled, RTX 3090.

  2. The full weight gradient is: grad_weight = reduce(bmm(activation.T, grad_output), 'b r c -> r c').
