
Add a test for Mistral-NeMo. #1340

Merged
merged 2 commits into main on Oct 31, 2024

Conversation

@tfogal (Collaborator) commented Oct 21, 2024

See issue #1285.

What does this PR do?

Adds a test case for #1285.

@tfogal requested a review from IvanYashchuk on October 21, 2024 23:01

@crcrpar (Collaborator) left a comment

We already have https://github.com/Lightning-AI/lightning-thunder/blob/79e59d0c5c5f8aa8ef80eb31f3fe918466d64c1c/thunder/tests/test_networks.py; I think it would be nicer to add this test to that file instead of creating a new one.

@tfogal (Collaborator, Author) commented Oct 22, 2024

> We already have https://github.com/Lightning-AI/lightning-thunder/blob/79e59d0c5c5f8aa8ef80eb31f3fe918466d64c1c/thunder/tests/test_networks.py; I think it would be nicer to add this test to that file instead of creating a new one.

Thanks, good idea, will do.

@tfogal (Collaborator, Author) commented Oct 22, 2024

The latest version moves the code where it belongs, test_networks.py, as Masaki pointed out.

I also rebuilt my container, and that made it clear that there are a couple of difficulties with this:

  • I needed to turn on trust_remote_code to get the tiny_shakespeare dataset to work.
  • One needs an active login to Hugging Face's Hub to grab the configuration.
  • This requires some dependencies (transformers, datasets) that we wouldn't otherwise depend on.

As such, I made the test skipped by default. I still think it makes sense to keep it as a Thunder test, as that's a logical source of truth for the whole team to work on. But I invite discussion.
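
For reference, a minimal sketch of the two pieces described above, the skip marker and the dataset call that needs trust_remote_code (both lines appear in the patch quoted later in this thread; the surrounding test body is omitted here):

    import datasets
    import pytest

    @pytest.mark.skip(reason="Dependencies, trust issues")
    def test_thunderfx_mistral_nemo_small():
        # tiny_shakespeare is distributed as a dataset script, so loading it
        # requires opting in with trust_remote_code=True.
        dataset = datasets.load_dataset("tiny_shakespeare", split="train", trust_remote_code=True)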

@tfogal marked this pull request as ready for review on October 22, 2024 18:28
@tfogal (Collaborator, Author) commented Oct 23, 2024

@t-vi this is ready for review.

The CI failure was a node timing out while running the tests, even though all this PR does is add a single, skipped test; it seems like there's more going on there. Do you want me to add empty commits until this happens to pass, or can you override that?

@lantiga (Collaborator) commented Oct 26, 2024

Thank you @tfogal.

I don't think we need tiny_shakespeare to be downloaded; we can get away with something simpler or random (just like in the other examples). That way we don't even have to fetch the tokenizer.

Do we actually need credentials for the config, or is it the tokenizer checkpoint that requires them?
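
A minimal sketch of the kind of random input being suggested (the names config and device are assumed from the test setup; Ivan's patch below uses the same idea):

    # Random token ids in place of a tokenized dataset; no tokenizer or
    # dataset download is needed.
    input_ids = torch.randint(0, config.vocab_size, (1, config.max_position_embeddings), device=device)
    attention_mask = torch.ones_like(input_ids)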

@IvanYashchuk (Collaborator) commented Oct 28, 2024

> Do we actually need credentials for the config, or is it the tokenizer checkpoint that requires them?

There should be a way to avoid using any credentials, since the tokenizer is unnecessary and we don't load any weights for the config.

> One needs an active login to Hugging Face's Hub to grab the configuration.

But is downloading anything required here? The configuration is defined directly with transformers.models.mistral.configuration_mistral.MistralConfig.

Usually, we check Thunder's ability to run a network with a sample random input and then invoke backward. A mock training loop is unnecessary here and can be avoided. Here's a patch that verifies Thunder successfully runs the model without any dataset, optimizer, etc.:

diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py
index 57a9a759..9bf0cb57 100644
--- a/thunder/tests/test_networks.py
+++ b/thunder/tests/test_networks.py
@@ -362,7 +362,7 @@ def test_quantization():
 
 
 @thunder.tests.framework.requiresCUDA
-@pytest.mark.skip(reason="Dependencies, trust issues")
+# @pytest.mark.skip(reason="Dependencies, trust issues")
 def test_thunderfx_mistral_nemo_small():
     """
     Runs a small version of Mistral-NeMo
@@ -370,17 +370,9 @@ def test_thunderfx_mistral_nemo_small():
     This is largely based on code from Alexandros Koumparoulis.
     """
     import transformers
-    import datasets
 
     model_id = "mistralai/Mistral-Nemo-Base-2407"
 
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_id,
-        torch_dtype=torch.bfloat16,
-        ignore_mismatched_sizes=True,
-        trust_remote_code=False,
-    )
-
     # Setup a "small" version of NeMo-Mistral that does not require downloading
     # weights. This is not a configuration that is worth benchmarking.
     # This was created by using
@@ -389,7 +381,7 @@ def test_thunderfx_mistral_nemo_small():
     #   transformers.AutoConfig.from_pretrained(model_id)
     # until they lined up.
     config = transformers.models.mistral.configuration_mistral.MistralConfig(
-        num_hidden_layers=2,
+        num_hidden_layers=1,
         torch_dtype=torch.bfloat16,
         max_position_embeddings=1024,
         architectures=["MistralForCausalLM"],
@@ -404,53 +396,18 @@ def test_thunderfx_mistral_nemo_small():
     model = transformers.AutoModelForCausalLM.from_config(config)
     device = torch.device("cuda")
     model.to(device)
-    mdl = torch.compile(model, backend=thunder.dynamo.ThunderCompiler())
+    backend = thunder.dynamo.ThunderCompiler()
+    mdl = torch.compile(model, backend=backend)
     del model
 
-    # Add a padding token to the tokenizer
-    if tokenizer.pad_token is None:
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        mdl.resize_token_embeddings(len(tokenizer))
-
-    dataset = datasets.load_dataset("tiny_shakespeare", split="train", trust_remote_code=True)
-
-    def tokenize_function(examples):
-        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=2)
+    batch_size = 1
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, config.max_position_embeddings), device=device)
+    attention_mask = torch.ones_like(input_ids)
 
-    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
-
-    # Convert the dataset to PyTorch format and specify columns to return as tensors
-    tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
-
-    dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=1, shuffle=True)
-
-    # Define optimizer and learning rate scheduler
-    optimizer = torch.optim.AdamW(mdl.parameters(), lr=5e-5)
-    num_epochs = 3
-    lr_scheduler = transformers.get_scheduler(
-        "linear",
-        optimizer=optimizer,
-        num_warmup_steps=0,
-        num_training_steps=num_epochs * len(dataloader),
-    )
+    output = mdl(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
+    logits = output.logits
+    grad_logits = torch.randn_like(logits)
+    logits.backward(grad_logits)
 
-    mdl.train()
-    for epoch in range(num_epochs):
-        total_loss = 0
-        for batch in dataloader:
-            # Move input tensors to device
-            input_ids = batch["input_ids"].to(device)
-            attention_mask = batch["attention_mask"].to(device)
-
-            # Forward pass
-            outputs = mdl(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
-            loss = outputs.loss
-            total_loss += loss.item()
-
-            # Backward pass
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-            # Update learning rate
-            lr_scheduler.step()
+    # Check that Thunder has actually compiled the model
+    assert backend.subgraph_infos, "No subgraphs found"

@tfogal (Collaborator, Author) commented Oct 28, 2024

> There should be a way to avoid using any credentials

Thanks Luca, Ivan. I've applied Ivan's patch plus some other minor changes, and it indeed appears not to download anything new (into ~/.cache/huggingface, at least) now.

Unfortunately, in the interim something seems to have tickled things such that #1240 is now a blocking issue, so I am leaving the skip designation for now :-(

@kshitij12345 (Collaborator) commented

#1240 has been fixed just now, so we can probably remove the skip.

@tfogal (Collaborator, Author) commented Oct 29, 2024

Hrm, my merge of main seems to have not gone well. I'm in meetings now but will fix it after...

See issue #1285.

Thanks:
	Alexandros Koumparoulis
	Ivan Yashchuk
	Masaki Kozuki
for various fixes/guidance. + Kshiteej Kalambarkar for fixing 1240.
@tfogal force-pushed the tfogal/nemo-test-case branch from 0d438dd to d77bc05 on October 29, 2024 18:40

@tfogal (Collaborator, Author) commented Oct 29, 2024

Hi, sorry for the weirdness. I couldn't figure out why the GitHub diff was wild even though git diff main..tfogal/nemo-test-case was sane, so I just ended up rebasing it.

It's pretty tiny (thanks to Ivan's patch), so hopefully it's not too painful to review from scratch again.

@tfogal (Collaborator, Author) commented Oct 29, 2024

The CI failure is real: "Found two different const extents in the same set: { bS39{1}; bS43{1}; bS66{1}; bS70{1 ex 8}; bS50{1}; bS54{1 ex 32} }", which is an nvFuser issue.

However, it works fine with nvFuser e33316d9480508b49db788a7472f4df52e53af92. Do we need an nvFuser upgrade for this to work?
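
For context, a quick way to check which nvFuser build is installed in a container might look like the sketch below. The nvfuser.version() call is an assumption about the nvFuser Python API; the fallback reads the installed package metadata instead.

    # Print the installed nvFuser version, to compare against the build that
    # is known to work.
    import importlib.metadata

    try:
        import nvfuser
        print("nvFuser version:", nvfuser.version())
    except (ImportError, AttributeError):
        for dist in importlib.metadata.distributions():
            name = dist.metadata["Name"] or ""
            if "nvfuser" in name.lower():
                print(name, dist.version)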

@tfogal mentioned this pull request on Oct 30, 2024
@t-vi enabled auto-merge (squash) on October 31, 2024 13:42

@t-vi (Collaborator) left a comment

The slimmed-down version looks good.
Thank you @tfogal @IvanYashchuk @lantiga @crcrpar @kshitij12345

@t-vi merged commit 908be57 into main on Oct 31, 2024
41 checks passed
@t-vi deleted the tfogal/nemo-test-case branch on October 31, 2024 18:15

A collaborator left a review comment on the following lines in thunder/tests/test_networks.py:

    # until they lined up sans the hidden and embeddings changes, above.
    config = transformers.models.mistral.configuration_mistral.MistralConfig(
        num_hidden_layers=1,
        torch_dtype=torch.bfloat16,
        ...

This setting is ignored by the model instantiation; this can be checked by inspecting, for example, model.model.layers[0].mlp.gate_proj.weight.dtype.

@riccardofelluga, when you look into what Thunder executes for this and other HF models and what is missing for performance, please update this test to use bfloat16 weights.
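
For illustration, a minimal sketch of the check described above, plus one way to actually get bfloat16 weights (the explicit .to(torch.bfloat16) cast is an assumption about how the test could be updated, not something this PR does):

    import torch
    import transformers

    # A tiny config just for illustration; only the fields relevant here are set.
    config = transformers.models.mistral.configuration_mistral.MistralConfig(
        num_hidden_layers=1,
        torch_dtype=torch.bfloat16,
    )
    model = transformers.AutoModelForCausalLM.from_config(config)

    # torch_dtype in the config does not change the instantiated parameters:
    print(model.model.layers[0].mlp.gate_proj.weight.dtype)  # torch.float32

    # Explicitly casting the module is one way to obtain bfloat16 weights.
    model = model.to(torch.bfloat16)
    print(model.model.layers[0].mlp.gate_proj.weight.dtype)  # torch.bfloat16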
