Enable autocast for llama2.c example (bf16) (PR1771)
Co-authored-by: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>
2 people authored and Borda committed Mar 20, 2024
1 parent 296f04d commit 9536cd7
Showing 4 changed files with 43 additions and 35 deletions.
21 changes: 10 additions & 11 deletions examples/llama2.c/README.md
@@ -28,9 +28,9 @@ The code is configured to run with Thunder by default.

Results with 1 GPU:

- - ~339 ms/iter (torch.compile 'inductor')
- - ~347 ms/iter (thunder nvfuser)
- - ~431 ms/iter (eager)
+ - ~215 ms/iter (torch.compile 'inductor')
+ - ~239 ms/iter (thunder nvfuser)
+ - ~339 ms/iter (eager)

CUDAGraphs are not used as the results were worse with them.

@@ -46,15 +46,14 @@ nanoGPT doesn't implement KV caching so this is expectedly slow. Please checkout
## Setup

```text
- Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
+ Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Is debug build: False
- CUDA used to build PyTorch: 12.1
- CUDA runtime version: 12.1.105
+ CUDA used to build PyTorch: 12.4
+ CUDA runtime version: 12.4.99
GPU 0: NVIDIA A100-SXM4-40GB
- Nvidia driver version: 525.125.06
+ Nvidia driver version: 550.54.14
- pytorch-triton @ https://download.pytorch.org/whl/nightly/pytorch_triton-3.0.0%2B901819d2b6-cp310-cp310-linux_x86_64.whl
- torch @ https://download.pytorch.org/whl/nightly/cu121/torch-2.3.0.dev20240130%2Bcu121-cp310-cp310-linux_x86_64.whl
- lightning-thunder==8b107c6fe531c94c6705dbf39700863685ba5b65
- nvfuser_cu121==0.1.5.dev20240131
+ triton == 3.0.0
+ torch == 2.4.0a0+git685ace3
+ nvfuser @ 0.2.0+git70101da
```
2 changes: 1 addition & 1 deletion examples/llama2.c/sample.py
@@ -55,7 +55,7 @@
from thunder.executors.sdpaex import sdpa_ex

executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
- cmodel = thunder.jit(model, disable_torch_autograd_support=True, executors_list=executors)
+ cmodel = thunder.jit(model, disable_torch_autograd_support=True, executors=executors)
# the generate implementation is not compile friendly, so bind the compiled model to the generate implementation
generate = partial(Transformer.generate, cmodel)
# workaround for "Foward nn.Module attributes through the ThunderOptimizedModule"
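For reference, a minimal sketch of the renamed keyword in isolation, assuming the same Thunder install and model object as sample.py above; the helper name `jit_for_inference` is illustrative only:

```python
import thunder
from thunder.executors.sdpaex import sdpa_ex


def jit_for_inference(model):
    # Executor priority: fused SDPA kernels first, then nvFuser, then eager PyTorch.
    executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
    # Inference-only compilation, so Thunder's autograd support stays disabled.
    return thunder.jit(model, disable_torch_autograd_support=True, executors=executors)
```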
44 changes: 25 additions & 19 deletions examples/llama2.c/train.py
@@ -70,8 +70,8 @@
warmup_iters = 1000 # how many steps to warm up for
# system
device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
# dtype = "bfloat16" # float32|bfloat16|float16
compile = "thunder" # eager|torch|thunder
dtype = "bfloat16" # float32|bfloat16|float16
compile = "thunder" # thunder|torch|eager
# -----------------------------------------------------------------------------
config_keys = [
k
@@ -122,8 +122,15 @@
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
- # ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
- ctx = nullcontext() # torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+ ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
+ ctx = (
+ nullcontext()
+ if device_type == "cpu"
+ else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+ )
+ # Disable other than FlashAttention backends for SDPA
+ torch.backends.cuda.enable_math_sdp(False)
+ torch.backends.cuda.enable_mem_efficient_sdp(False)

# task-specific setup
iter_batches = partial(
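A short sketch of the autocast setup this hunk introduces, assuming a CUDA machine and the config values from train.py (`dtype` defaults to "bfloat16"); the helper name `make_autocast_ctx` is illustrative only:

```python
from contextlib import nullcontext

import torch


def make_autocast_ctx(device: str = "cuda", dtype: str = "bfloat16"):
    device_type = "cuda" if "cuda" in device else "cpu"
    ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
    if device_type == "cpu":
        # Autocast is only configured for CUDA here; CPU falls back to a no-op context.
        return nullcontext()
    # Keep only the FlashAttention backend for scaled_dot_product_attention, as in the hunk above.
    torch.backends.cuda.enable_math_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    return torch.amp.autocast(device_type=device_type, dtype=ptdtype)
```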
Expand Down Expand Up @@ -179,30 +186,31 @@
model.load_state_dict(state_dict)
iter_num = checkpoint["iter_num"]
best_val_loss = checkpoint["best_val_loss"]

model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
- scaler = torch.cuda.amp.GradScaler(enabled=(False)) # dtype == "float16"))
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == "resume" and "optimizer" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer"])
checkpoint = None # free up memory

- raw_model = eval_model = train_model = model
+ raw_model = model

# wrap model into DDP container
if ddp:
if compile == "thunder":
from thunder.distributed import ddp

- train_model = ddp(train_model)
+ model = ddp(model)
else:
# Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
# construction time since NCCL does not support `ComplexFloat`
- train_model._ddp_params_and_buffers_to_ignore = {"freqs_cis"}
- train_model = DDP(train_model, device_ids=[ddp_local_rank])
+ model._ddp_params_and_buffers_to_ignore = {"freqs_cis"}
+ model = DDP(model, device_ids=[ddp_local_rank])

# compile the model
if compile == "thunder":
@@ -212,31 +220,29 @@
from thunder.executors.sdpaex import sdpa_ex
executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]

- eval_model = thunder.compile(eval_model.eval(), disable_torch_autograd_support=True, executors_list=executors)
- train_model = thunder.compile(train_model.train(), executors_list=executors)
+ model = thunder.jit(model, executors=executors)
elif compile == "torch":
print("compiling the model with torch... (takes a ~minute)")
- eval_model = torch.compile(eval_model)
- train_model = torch.compile(train_model)
+ model = torch.compile(model)

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
out = {}
if compile != "thunder":
- eval_model.eval()
+ model.eval()
for split in ["train", "val"]:
batch_iter = iter_batches(split=split)
losses = torch.zeros(eval_iters) # keep on CPU
for k in range(eval_iters):
X, Y = next(batch_iter)
with ctx:
- logits = eval_model(X, Y)
+ logits = model(X, Y)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
losses[k] = loss.item()
out[split] = losses.mean()
if compile != "thunder":
- train_model.train()
+ model.train()
return out

# learning rate decay scheduler (cosine with warmup)
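The train.py hunks above collapse the old eval_model/train_model split into a single `model`. Below is a sketch of the resulting setup path, assuming the config values (`dtype`, `ddp`, `ddp_local_rank`) and checkpoint handling from the surrounding script; the helper name `setup_model` and its keyword arguments are illustrative only:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_model(model, *, dtype="bfloat16", backend="thunder", use_ddp=False, ddp_local_rank=0):
    # The GradScaler is a no-op unless training in float16.
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
    raw_model = model  # unwrapped handle, e.g. for checkpointing

    if use_ddp:
        if backend == "thunder":
            from thunder.distributed import ddp

            model = ddp(model)
        else:
            # `freqs_cis` is ComplexFloat, which NCCL cannot broadcast, so DDP ignores it.
            model._ddp_params_and_buffers_to_ignore = {"freqs_cis"}
            model = DDP(model, device_ids=[ddp_local_rank])

    if backend == "thunder":
        import thunder
        from thunder.executors.sdpaex import sdpa_ex

        executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
        model = thunder.jit(model, executors=executors)  # autograd stays enabled for training
    elif backend == "torch":
        model = torch.compile(model)

    return model, raw_model, scaler
```

As in `estimate_loss` above, a Thunder-jitted module skips the explicit `eval()`/`train()` toggling.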
@@ -313,9 +319,9 @@ def get_lr(it):
# the official way to do this is with model.no_sync() context manager, but
# this forces us to repeat code.
# looking at the source of that context manager, it just toggles this variable
- train_model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
+ model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
with ctx:
- logits = train_model(X, Y)
+ logits = model(X, Y)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
loss = loss / gradient_accumulation_steps
# immediately async prefetch next batch while model is doing the forward pass on the GPU
@@ -325,7 +331,7 @@ def get_lr(it):
# clip the gradient
if grad_clip != 0.0:
scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(train_model.parameters(), grad_clip)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# step the optimizer and scaler if training in fp16
scaler.step(optimizer)
scaler.update()
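Putting the pieces together, a sketch of one optimization step under the new autocast context; `scaler.scale(loss).backward()` and `optimizer.zero_grad(set_to_none=True)` follow the standard GradScaler recipe and are assumed from the parts of train.py not shown in this diff, and the helper name `train_step` is illustrative only:

```python
import torch
import torch.nn.functional as F


def train_step(model, optimizer, scaler, ctx, batch_iter,
               gradient_accumulation_steps=4, grad_clip=1.0, ddp=False):
    X, Y = next(batch_iter)
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # Only sync gradients on the last micro step (what DDP's no_sync() would do).
            model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
        with ctx:  # autocast, or nullcontext on CPU
            logits = model(X, Y)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
            loss = loss / gradient_accumulation_steps
        X, Y = next(batch_iter)  # prefetch the next batch while the GPU is busy
        scaler.scale(loss).backward()
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)  # unscales first when fp16, then steps the optimizer
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    return loss
```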
11 changes: 7 additions & 4 deletions thunder/core/transforms.py
@@ -3984,10 +3984,13 @@ def decorator(func):

def maybe_downcast_to(dtype, args):
allowed_downcast_types = (dtypes.float16, dtypes.bfloat16, dtypes.float32)
- if all(tree_map(lambda a: a.dtype in allowed_downcast_types, args)):
- return tree_map(lambda a: maybe_convert_to_dtype(a, dtype), args)
- else:
- return args
+
+ def map_fn(a):
+ if isinstance(a, TensorProxy) and a.dtype in allowed_downcast_types:
+ return maybe_convert_to_dtype(a, dtype)
+ return a
+
+ return tree_map(map_fn, args)


@register_autocast_rule("torch.matmul")
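A self-contained analogue of the new per-leaf behaviour, using plain torch tensors and torch's pytree helper in place of Thunder's TensorProxy and internal tree_map: leaves with an allowed floating dtype are downcast, every other leaf passes through unchanged instead of the whole argument tree being rejected.

```python
import torch
from torch.utils._pytree import tree_map


def maybe_downcast_to(dtype, args):
    allowed_downcast_types = (torch.float16, torch.bfloat16, torch.float32)

    def map_fn(a):
        # Only tensor leaves with an allowed floating dtype are converted.
        if isinstance(a, torch.Tensor) and a.dtype in allowed_downcast_types:
            return a.to(dtype)
        return a

    return tree_map(map_fn, args)


# The float32 tensor is downcast to bf16; the int64 tensor and the plain int stay as-is.
out = maybe_downcast_to(torch.bfloat16, (torch.ones(2), torch.arange(3), 4))
```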