Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test HF's implementation of Qwen 2 model #1406

Merged
merged 8 commits into from
Nov 13, 2024
45 changes: 45 additions & 0 deletions thunder/tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,51 @@ def test_thunderfx_mistral_nemo_small():
assert th_backend.subgraph_infos, "Should have at least 1 subgraph"


@thunder.tests.framework.requiresCUDA
def test_hf_qwen2():
from thunder.dynamo import ThunderCompiler
from transformers import Qwen2Config, Qwen2ForCausalLM

# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
configuration = Qwen2Config(
# Qwen2.5-7B-Instruct uses Grouped-Query Attention, while the default
# config uses Multi-Head Attention
num_attention_heads=28,
num_key_value_heads=4,
# Scaled down for testing
hidden_size=56,
vocab_size=16,
max_position_embeddings=32,
)
configuration.num_hidden_layers = 1
with torch.device("cuda"):
model = Qwen2ForCausalLM(configuration).to(torch.bfloat16)

# thunder.jit doesn't work with Qwen2, so we use torch.compile
# https://github.com/Lightning-AI/lightning-thunder/issues/1405
backend = ThunderCompiler()
compiled_model = torch.compile(model, backend=backend, fullgraph=True)

input_ids = torch.randint(0, configuration.vocab_size, (1, configuration.max_position_embeddings), device="cuda")
ref_output = model(input_ids=input_ids, labels=input_ids)
ref_loss = ref_output.loss

compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
compiled_loss = compiled_output.loss

# Less strict tolerance probably due to different type promotion order for bfloat16
# TODO: Investigate why the loss is different
# https://github.com/Lightning-AI/lightning-thunder/issues/1407
torch.testing.assert_close(compiled_loss, ref_loss, rtol=1e-4, atol=1e-4)

assert len(backend.subgraph_infos) == 1, "Should have exactly 1 subgraph because of fullgraph=True"
loss_grad = torch.randn_like(compiled_loss)

grads_ref = torch.autograd.grad(ref_loss, model.parameters(), grad_outputs=loss_grad)
grads_compiled = torch.autograd.grad(compiled_loss, model.parameters(), grad_outputs=loss_grad)
torch.testing.assert_close(grads_ref, grads_compiled, rtol=1e-2, atol=1e-2)


LLAMA_3_2_1B_CFG = {
"architectures": ["LlamaForCausalLM"],
"attention_bias": False,
Expand Down
Loading