diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index fb8b6f4546..36ba5c3cd3 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -400,6 +400,51 @@ def test_thunderfx_mistral_nemo_small(): assert th_backend.subgraph_infos, "Should have at least 1 subgraph" +@thunder.tests.framework.requiresCUDA +def test_hf_qwen2(): + from thunder.dynamo import ThunderCompiler + from transformers import Qwen2Config, Qwen2ForCausalLM + + # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + configuration = Qwen2Config( + # Qwen2.5-7B-Instruct uses Grouped-Query Attention, while the default + # config uses Multi-Head Attention + num_attention_heads=28, + num_key_value_heads=4, + # Scaled down for testing + hidden_size=56, + vocab_size=16, + max_position_embeddings=32, + ) + configuration.num_hidden_layers = 1 + with torch.device("cuda"): + model = Qwen2ForCausalLM(configuration).to(torch.bfloat16) + + # thunder.jit doesn't work with Qwen2, so we use torch.compile + # https://github.com/Lightning-AI/lightning-thunder/issues/1405 + backend = ThunderCompiler() + compiled_model = torch.compile(model, backend=backend, fullgraph=True) + + input_ids = torch.randint(0, configuration.vocab_size, (1, configuration.max_position_embeddings), device="cuda") + ref_output = model(input_ids=input_ids, labels=input_ids) + ref_loss = ref_output.loss + + compiled_output = compiled_model(input_ids=input_ids, labels=input_ids) + compiled_loss = compiled_output.loss + + # Less strict tolerance probably due to different type promotion order for bfloat16 + # TODO: Investigate why the loss is different + # https://github.com/Lightning-AI/lightning-thunder/issues/1407 + torch.testing.assert_close(compiled_loss, ref_loss, rtol=1e-4, atol=1e-4) + + assert len(backend.subgraph_infos) == 1, "Should have exactly 1 subgraph because of fullgraph=True" + loss_grad = torch.randn_like(compiled_loss) + + grads_ref = torch.autograd.grad(ref_loss, model.parameters(), grad_outputs=loss_grad) + grads_compiled = torch.autograd.grad(compiled_loss, model.parameters(), grad_outputs=loss_grad) + torch.testing.assert_close(grads_ref, grads_compiled, rtol=1e-2, atol=1e-2) + + LLAMA_3_2_1B_CFG = { "architectures": ["LlamaForCausalLM"], "attention_bias": False,