19 changes: 11 additions & 8 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -114,11 +114,13 @@ def _resursively_swap_linear_layers_for_te(module: torch.nn.Module) -> None:

             if isinstance(m, torch.nn.Linear):
                 has_bias = m.bias is not None
-                new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=device)
+                # Pass device as str (as there is a bug in TransformerEngine's handling of torch.device)
+                new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=str(device))
                 setattr(module, n, new_linear)
 
             if swap_layernorm and isinstance(m, torch.nn.LayerNorm):
-                new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=device)
+                # Pass device as str (as there is a bug in TransformerEngine's handling of torch.device)
+                new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=str(device))
                 setattr(module, n, new_layernorm)
 
     initial_params_cnt = parameters_cnt(model)
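Aside: the `str(device)` conversion above works around TransformerEngine mishandling `torch.device` objects passed as `device=`. A minimal sketch of the same normalization as a standalone helper (the helper name `_te_device_str` is hypothetical and not part of this PR):

import torch

def _te_device_str(device) -> str:
    # Hypothetical helper: convert a torch.device (or pass through a plain string) to the
    # "type[:index]" string form, e.g. torch.device("cuda", 0) -> "cuda:0".
    return str(device) if isinstance(device, torch.device) else device

# Usage sketch (assuming `import transformer_engine.pytorch as te`):
# te.Linear(in_features, out_features, bias=True, device=_te_device_str(torch.device("cuda", 0)))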
@@ -366,11 +368,6 @@ def __init__(
         self.model = self.init_model()
         print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
 
-        if self.use_te_fp8_autocast:
-            is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
-            swap_linear_layers_for_te(self.model, device, swap_layernorm=not is_wo_layernorm)
-            self.model.to(torch.bfloat16)
-
         # Setup the distributed algorithm choices
         if distributed_first := (self.compile in ("eager", "inductor") or "dynamo" in self.compile):
             self.model = self.setup_distributed(self.model)
@@ -407,8 +404,14 @@ def init_model(self):
         init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device
         with init_device:
             model = GPT(self.config)
-        model.to(dtype=torch.bfloat16)
+
+        # Handle fp8 related Linear layer swapping (for torchao or TransformerEngine)
         model = self._torchao_fp8_handler.convert_model_to_fp8(model)
+        if self.use_te_fp8_autocast:
+            is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
+            swap_linear_layers_for_te(model, init_device, swap_layernorm=not is_wo_layernorm)
+
+        model.to(dtype=torch.bfloat16)
         return model
 
     def setup_distributed(self, model):
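Net effect: all fp8-related layer swapping (torchao and TransformerEngine) now happens inside init_model, before the bfloat16 cast and before __init__ applies any distributed wrapping. A condensed sketch of the resulting order, following the hunks above (illustrative only, not the full method):

def init_model(self):
    init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device
    with init_device:
        model = GPT(self.config)

    # Swap Linear (and optionally LayerNorm) layers for their fp8-capable variants first...
    model = self._torchao_fp8_handler.convert_model_to_fp8(model)
    if self.use_te_fp8_autocast:
        is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
        swap_linear_layers_for_te(model, init_device, swap_layernorm=not is_wo_layernorm)

    # ...then cast, so the swapped modules are included in the bfloat16 conversion.
    model.to(dtype=torch.bfloat16)
    return model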