Add a test for Mistral-NeMo. (#1340)
tfogal authored Oct 31, 2024
1 parent a2587e2 commit 908be57
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions thunder/tests/test_networks.py
@@ -359,3 +359,54 @@ def test_quantization():
    assert len(sd) == len(sd2)
    for k, v in sd.items():
        assert_close(v, sd2[k])


@thunder.tests.framework.requiresCUDA
def test_thunderfx_mistral_nemo_small():
    """
    Runs a small version of Mistral-NeMo.

    This is largely based on code from Alexandros Koumparoulis.
    """
    import transformers

    model_id = "mistralai/Mistral-Nemo-Base-2407"

    # Set up a "small" version of Mistral-NeMo that does not require
    # downloading weights. This is not a configuration that is worth
    # benchmarking. It was created by starting from
    #   MistralConfig(num_hidden_layers=1, max_position_embeddings=1024)
    # and then manually diffing that object against
    #   transformers.AutoConfig.from_pretrained(model_id)
    # until they matched, apart from the hidden-layer and
    # position-embedding changes above.
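    # A rough sketch of that diffing step (illustrative only; `ref` and
    # `small` are hypothetical names, and this is not executed by the test):
    #   ref = transformers.AutoConfig.from_pretrained(model_id)
    #   small = transformers.MistralConfig(num_hidden_layers=1, max_position_embeddings=1024)
    #   for key, value in vars(small).items():
    #       if getattr(ref, key, None) != value:
    #           print(key, value, getattr(ref, key, None))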
    config = transformers.models.mistral.configuration_mistral.MistralConfig(
        num_hidden_layers=1,
        torch_dtype=torch.bfloat16,
        max_position_embeddings=1024,
        architectures=["MistralForCausalLM"],
        hidden_size=5120,
        rms_norm_eps=1e-05,
        rope_theta=1000000.0,
        sliding_window=None,
        vocab_size=131072,
        head_dim=128,
        _name_or_path=model_id,
    )
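    # from_config initializes the model with random weights, so nothing is
    # downloaded; trust_remote_code=False restricts loading to model code
    # that ships with transformers itself.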
    model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=False)
    device = torch.device("cuda")
    model.to(device)
    model.train()
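    # ThunderCompiler acts as a torch.compile backend: TorchDynamo hands it
    # the FX subgraphs it traces, and Thunder compiles the ones it supports.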
    th_backend = thunder.dynamo.ThunderCompiler()
    mdl = torch.compile(model, backend=th_backend)

    batch_size = 1
    iid_size = (batch_size, config.max_position_embeddings)
    input_ids = torch.randint(0, config.vocab_size, iid_size, device=device)
    attention_mask = torch.ones_like(input_ids)

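    # A single training-style step: the forward pass produces logits, and
    # backpropagating a random cotangent through them exercises the compiled
    # backward pass.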
    output = mdl(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    logits = output.logits
    grad_logits = torch.randn_like(logits)
    logits.backward(grad_logits)

    assert th_backend.subgraph_infos, "Should have at least 1 subgraph"
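
To run just this test locally (assuming a CUDA-capable machine with thunder and transformers installed), the usual pytest invocation would be:

    pytest thunder/tests/test_networks.py -k test_thunderfx_mistral_nemo_small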
