diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index c5a44a0cbb..6076b681a4 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -359,3 +359,54 @@ def test_quantization(): assert len(sd) == len(sd2) for k, v in sd.items(): assert_close(v, sd2[k]) + + +@thunder.tests.framework.requiresCUDA +def test_thunderfx_mistral_nemo_small(): + """ + Runs a small version of Mistral-NeMo + + This is largely based on code from Alexandros Koumparoulis. + """ + import transformers + + model_id = "mistralai/Mistral-Nemo-Base-2407" + + # Setup a "small" version of NeMo-Mistral that does not require downloading + # weights. This is not a configuration that is worth benchmarking. + # This was created by using + # MistralConfig(num_hidden_layers=1, max_position_embeddings=1024) + # and then manually diffing that returned object with: + # transformers.AutoConfig.from_pretrained(model_id) + # until they lined up sans the hidden and embeddings changes, above. + config = transformers.models.mistral.configuration_mistral.MistralConfig( + num_hidden_layers=1, + torch_dtype=torch.bfloat16, + max_position_embeddings=1024, + architectures=["MistralForCausalLM"], + hidden_size=5120, + rms_norm_eps=1e-05, + rope_theta=1000000.0, + sliding_window=None, + vocab_size=131072, + head_dim=128, + _name_or_path=model_id, + ) + model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=False) + device = torch.device("cuda") + model.to(device) + model.train() + th_backend = thunder.dynamo.ThunderCompiler() + mdl = torch.compile(model, backend=th_backend) + + batch_size = 1 + iid_size = (batch_size, config.max_position_embeddings) + input_ids = torch.randint(0, config.vocab_size, iid_size, device=device) + attention_mask = torch.ones_like(input_ids) + + output = mdl(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids) + logits = output.logits + grad_logits = torch.randn_like(logits) + logits.backward(grad_logits) + + assert th_backend.subgraph_infos, "Should have at least 1 subgraph"