From 246e8a8c22b8c97b497d07fa6b62b3a66d9c0c38 Mon Sep 17 00:00:00 2001
From: Ali Alshaarawy
Date: Thu, 12 Dec 2024 14:21:56 +0000
Subject: [PATCH] add litgpt fabric test

---
 thunder/tests/test_torch_compile_executor.py | 29 ++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py
index 6560dddbc8..545c3c05d6 100644
--- a/thunder/tests/test_torch_compile_executor.py
+++ b/thunder/tests/test_torch_compile_executor.py
@@ -90,3 +90,32 @@ def fn(a):
     a = torch.randn(3)
     jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,))
     assert_close(jfn(a), fn(a))
+
+@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported")
+@requiresCUDA
+@pytest.mark.skipif(not device_supports_bf16(torch.device("cuda")), reason="bf16 is not supported")
+def test_litgpt_fabric_for_callable():
+    from typing import Any, Callable, Sequence
+    from litgpt.model import Config, GPT
+    import torch.nn as nn
+
+    def jit(fn: Callable, executors: Sequence[str]) -> Any:
+        assert executors is not None
+        return thunder.jit(fn, executors=executors)
+
+    def forward_and_loss(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
+        logits = model(input_ids)
+        return logits
+
+    forward_and_loss_jitted = jit(forward_and_loss, executors=("sdpa", "torchcompile", "nvfuser", "torch"))
+
+    config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
+
+    with torch.device("cuda"):
+        model = GPT(config)
+
+    input_ids = torch.zeros(1, 2, dtype=torch.int64, device="cuda")
+    out = forward_and_loss(model, input_ids)
+    out_jitted = forward_and_loss_jitted(model, input_ids)
+
+    assert_close(out, out_jitted)