Skip to content

Commit

Permalink
add litgpt fabric test
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali Alshaarawy committed Dec 12, 2024
1 parent e2aacb9 commit 246e8a8
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions thunder/tests/test_torch_compile_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,32 @@ def fn(a):
a = torch.randn(3)
jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,))
assert_close(jfn(a), fn(a))

@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported")
@requiresCUDA
@pytest.mark.skipif(not device_supports_bf16(torch.device("cuda")), reason="bf16 is not supported")
def test_litgpt_fabric_for_callable():
from typing import Any, Callable, Optional, Tuple, Union, List, Dict
from litgpt.model import Config, GPT
import torch.nn as nn

def jit(fn: Callable, executors: List[str]) -> Any:
assert executors is not None
return thunder.jit(fn, executors=executors)

def forward_and_loss(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
logits = model(input_ids)
return logits

forward_and_loss_jitted = jit(forward_and_loss, executors=("sdpa", "torchcompile", "nvfuser", "torch"))

config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)

with torch.device("cuda"):
model = GPT(config)

input_ids = torch.zeros(1, 2, dtype=torch.int64, device="cuda")
out = forward_and_loss(model, input_ids)
out_jitted = forward_and_loss_jitted(model, input_ids)

assert_close(out, out_jitted)

0 comments on commit 246e8a8

Please sign in to comment.