Refactor and parallelize MaxText test runner #1113
Conversation
Force-pushed from e2f9e3b to 638509f (compare)
Force-pushed from 0b49275 to 7368a0b (compare)
Force-pushed from bcc3ce8 to 5afaef1 (compare)
- name: Test train.py with dropout
  run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} max_target_length=128 per_device_batch_size=1 dropout_rate=0.02
- name: Test generate_param_only_checkpoint
  run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a ${{ matrix.device.attention }}
Note that this is equivalent to the following command, since int8 is the default quantization in test_generate_param_only_checkpoint.sh.
Force-pushed from 84bc33f to 36242f2 (compare)
Force-pushed from bf32d38 to ba3131b (compare)
.github/workflows/UnitTests.yml (outdated)
@@ -23,8 +23,8 @@ on:
     branches: [ "main" ]
   workflow_dispatch:
   schedule:
-    # Run the job every 2 hours
-    - cron: '0 */2 * * *'
+    # Run the job every 6 hours
Why change this? Every 2 hours is pretty helpful for catching regressions.
IMHO, given that this also runs on every PR, 6 hours is plenty. For reference, MaxDiffusion has it set to every 12 hours, while JAX only runs its CI once or twice daily, depending on the workflow. I changed it to 4 hours as a compromise for now :) but please let me know if you still disagree.
@@ -85,7 +85,7 @@ def setUp(self):
     }
     self.model_vars = init_random_model_vars(self.model, self.rng, self.example_batch)

-  @pytest.mark.tpu
+  @pytest.mark.tpu_only
   def test_logits_numerically(self):
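For context on the marker rename above: custom marks like tpu_only normally need to be registered so pytest does not warn about unknown marks, and CI jobs can then select them with -m. Below is a minimal, hypothetical sketch; the actual registration in this repo may live in pytest.ini or pyproject.toml instead.

# conftest.py (illustrative sketch; names are assumptions, not this repo's actual setup)
def pytest_configure(config):
    # Register the custom mark so pytest does not emit unknown-mark warnings.
    config.addinivalue_line(
        "markers", "tpu_only: test that should run only on TPU runners"
    )

# A TPU CI job could then select just these tests with:
#   python3 -m pytest -m tpu_only MaxText/tests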
@ZhiyuLi-goog Can you explain this test?
- It looks like it is using random model weights.
- In an offline process, you pass the same random model weights to get the ground truth from a Paxml implementation (how is this done, and how do you ensure the same model weights are used?).
Won't this test fail if we pass a different rng, or if JAX changes its rng implementation, since the model weights will change? If my understanding is correct, I think this correctness test makes more sense as an e2e test that loads a checkpoint from GCS.
Yes, you are right. This test is simply meant to catch numerical discrepancies caused by unexpected implementation changes in shared modules like attention or linear layers.
We don't care about the exact values of the model weights. A fixed random seed ensures frozen model weights, and we expect the same results with the same model weights as long as the implementation is unchanged.
The root cause might be related to a change in random state, which could be affected by other tests running in parallel. I propose double-checking the random state in the parallel tests first. Otherwise, we can switch to checkpoint loading if there is a hard conflict between random state and parallel testing.
In an offline process you pass the same random model weights to get ground truth in a paxml implementation (how is this done, how do you ensure the same model weights are used?)
This was done offline as part of #293 (see the numerical test there). The Paxml environment is hard to set up and might introduce conflicts. The golden logits actually come from a Paxml model with the same model weights.
Won't this test fail if we pass different rng or jax changes its rng implementation since the model weights will change?
Yes. The goal is to verify that we get the same logit outputs with a fixed rng and random weights, while not caring too much about the exact values of the random weights. I also expect the rng implementation in JAX to be fairly stable.
I think this correctness test makes more sense as an e2e test that loads a checkpoint from GCS.
You are right, it would be ideal to avoid the dependency on random initialization. However, a checkpoint might be hard to maintain over time, since we would need to regenerate the Orbax checkpoint every time
- we change a layer name,
- we change the checkpoint implementation, or
- the Orbax library is updated.
This happened in the last MLPerf 4.1 submission.
As for e2e tests, there is already a functional e2e test for checkpoint conversion and loading.
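For readers following along, here is a minimal, self-contained sketch of the frozen-seed principle described above. It is not the actual MaxText test (which initializes the full model via init_random_model_vars and compares against golden logits generated offline with Paxml); init_params and tiny_model are illustrative stand-ins.

import jax
import jax.numpy as jnp
import numpy as np


def init_params(key):
    # Stand-in for init_random_model_vars: weights derived purely from the seed.
    k_w, k_b = jax.random.split(key)
    return {"w": jax.random.normal(k_w, (4, 2)), "b": jax.random.normal(k_b, (2,))}


def tiny_model(params, x):
    # Stand-in for the shared attention/linear layers under test.
    return jnp.dot(x, params["w"]) + params["b"]


def test_fixed_seed_gives_reproducible_logits():
    x = jnp.ones((1, 4))
    # Same fixed seed -> same "random" weights -> identical logits,
    # regardless of what the weight values happen to be.
    logits_a = tiny_model(init_params(jax.random.PRNGKey(0)), x)
    logits_b = tiny_model(init_params(jax.random.PRNGKey(0)), x)
    np.testing.assert_allclose(np.asarray(logits_a), np.asarray(logits_b))
    # In the real test, the logits would instead be compared against golden
    # values recorded once from the reference (Paxml) implementation.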
MaxText/tests/decode_int8_test.py (outdated)
from absl.testing import absltest


class Train(unittest.TestCase):
Should this class be called "DecodeInt8Test" or similar instead of Train?
Good catch, thanks!
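For illustration, the rename being agreed on here would look roughly like this (the class body is a placeholder, not the actual test):

import unittest


class DecodeInt8Test(unittest.TestCase):
    """Smoke test for decoding with int8 quantization (placeholder body)."""

    def test_decode_int8(self):
        ...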
from absl.testing import absltest


class Train(unittest.TestCase):
Should this class have a more specific name?
Done
MaxText/tests/decode_test.py (outdated)
from absl.testing import absltest


class Train(unittest.TestCase):
DecodeTest class or similar?
Done
I don't think a shmapped collective matmul test counts as an "integration test" - this is more of a unit test.
Done
MaxText/tests/train_int8_test.py (outdated)
Ideally I think these tests should be moved in with the other train.py tests - they differ only by quantization=int8, so it is probably easier to maintain them in the same file.
I moved all the new train configs into train_tests.py and all the new decode configs into decode_tests.py.
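As a rough sketch of what keeping the int8 variant next to the base train test could look like (hypothetical; it assumes train.py exposes a main() that accepts an argv-style list of config overrides, and the override values here are illustrative):

import unittest

from train import main as train_main  # assumed entry point


class Train(unittest.TestCase):
  BASE_CONFIG = [
      None,
      "configs/base.yml",
      "base_output_directory=gs://runner-maxtext-logs",
      "dataset_path=gs://maxtext-dataset",
      "steps=2",
      "enable_checkpointing=False",
  ]

  def test_train_default(self):
    train_main(self.BASE_CONFIG)

  def test_train_int8(self):
    # Identical to the default test except for the quantization override.
    train_main(self.BASE_CONFIG + ["quantization=int8"])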
MaxText/tests/train_pdb_lt_1_test.py (outdated)
Same here - this is just a slightly different config of train. Does it make sense to include it in the same file as the other train variants?
I moved all the new train configs into train_tests.py and all the new decode configs into decode_tests.py.
Force-pushed from a4659d0 to b56cfa6 (compare)
  ]
}

@pytest.mark.tpu_only
What do you think of running a separate @pytest test for each config, e.g. something like the AOT tests:
maxtext/MaxText/tests/train_compile_test.py (lines 27 to 80 in 4699f4f)
  @pytest.mark.tpu
  def test_save_compiled_v4(self):
    compiled_trainstep_file = "/tmp/test_compiled_v4.pickle"
    train_compile_main(
        (
            None,
            "configs/base.yml",
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v4-8",
            "compile_topology_num_slices=1",
            "base_emb_dim=256",
            "base_mlp_dim=256",
            "base_num_decoder_layers=2",
        )
    )

  @pytest.mark.tpu
  def test_save_compiled_v5e(self):
    compiled_trainstep_file = "/tmp/test_compiled_v5e.pickle"
    train_compile_main(
        (
            None,
            "configs/base.yml",
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5e-16",
            "compile_topology_num_slices=1",
            "base_emb_dim=256",
            "base_mlp_dim=256",
            "base_num_decoder_layers=2",
        )
    )

  # TODO (b/366200617) : This tests fails in AOT, but config works fine on real hardware
  @pytest.mark.skip(reason="Issue w/ kernels_test. Error: The TPU is already in use by process...")
  def test_minimal_offloaded_v5e(self):
    compiled_trainstep_file = "/tmp/test_compiled_v5e_offload.pickle"
    train_compile_main(
        (
            None,
            "configs/base.yml",
            f"compiled_trainstep_file={compiled_trainstep_file}",
            "compile_topology=v5e-256",
            "compile_topology_num_slices=1",
            "per_device_batch_size=1",
            "ici_fsdp_parallelism=16",
            "ici_tensor_parallelism=16",
            "max_target_length=2048",
            "fused_qkv=true",
            "fused_mlp=true",
            "remat_policy=minimal_offloaded",
            "use_iota_embed=true",
            "global_parameter_scale=128",
        )
    )
I prefer this since:
- Each test is more self-contained; you don't need to find the right config.
- If a test fails in the current "for loop" implementation, we won't get a clear error message about which test really failed - we will just see that "test_tpu_config" failed and will need to carefully scan the logs for the "Running TPU test for config...." message. (A parametrize-based sketch follows below.)
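As a reference point, a parametrize-based middle ground keeps the configs in one place while still reporting each config as its own test (hypothetical sketch; train_main and the config contents are assumptions, not this repo's actual code):

import pytest

from train import main as train_main  # assumed entry point

TPU_CONFIGS = {
    "base": ["configs/base.yml", "steps=2"],
    "int8": ["configs/base.yml", "steps=2", "quantization=int8"],
}


@pytest.mark.tpu_only
@pytest.mark.parametrize(
    "overrides", list(TPU_CONFIGS.values()), ids=list(TPU_CONFIGS.keys())
)
def test_tpu_config(overrides):
    # A failure is reported as e.g. test_tpu_config[int8], so no log scanning is needed.
    train_main([None] + overrides)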
Done.
Force-pushed from 9c683a1 to 39c1d44 (compare)
This change moves the train and decode tests out into separate files so they can be run using pytest rather than being explicitly invoked from UnitTests.yml. This will allow us, in a subsequent PR, to parallelize the execution of these tests using pytest's parallelization support. Several end-to-end tests are still invoked from UnitTests.yml; that does not seem like the right place for them, and will hopefully be addressed in the future.
Force-pushed from 39c1d44 to 244a382 (compare)
This change moves train, decode, and integration tests into separate files so they can be run using pytest rather than being explicitly invoked from UnitTests.yml, and splits the various tasks in UnitTests.yml into separate jobs (with explicit dependencies) that execute concurrently. As a result, the run time of tests in MaxText PRs is reduced from ~40 min to ~24 min.
Before: (screenshot of the CI run)
After: (screenshot of the CI run)
The following changes are made:
There are multiple opportunities for further speed-up that are not done in this PR: for example, using pytest's parallelization support, reducing the time to build a GPU image (currently ~13 min), avoiding downloading images for every test stage, etc.
FIXES: b/385505333
FIXES: b/385510862
Tests
This PR refactors existing unit tests and introduces new ones.