
Refactor and parallelize MaxText test runner #1113

Merged: 1 commit merged into main from refactor_tests on Dec 29, 2024

Conversation

@shralex (Collaborator) commented Dec 20, 2024:

This change moves train, decode and integration tests into separate files, so they can be tested using pytest rather than explicitly being invoked from UnitTests.yml, and splits up various tasks in UnitTests.yml into separate jobs (with explicit dependencies) that execute concurrently. As a result, the run time of tests in MaxText PRs is reduced from ~40min to ~24min.

Before: [image of the previous CI run]

After: [screenshot of the new CI run, 2024-12-28]

The following changes are made:

  1. All explicit test invocations are moved out of UnitTests.yml, and the file is renamed to RunTests.yml.
  2. End-to-end tests are placed in a new folder, tests/integration_tests.
  3. RunTests.yml is refactored to use shared modules in two new files (one for building/uploading images and another for running tests), and explicit jobs are created for the various types of tests. Previously, building a GPU image blocked TPU tests, and all TPU tests ran sequentially. Now the GPU image only blocks GPU tests, and unit / integration tests run concurrently (see the sketch after this list).
  4. The pytest marker tpu is renamed to tpu_only without a change in meaning, and gpu_only and integration_test markers are added.
  5. The frequency of scheduled unit-test runs is reduced from every 2 hours to every 6 hours, which seems sufficient.
  6. This PR also fixes a small bug where end_to_end/test_generate_param_only_checkpoint.sh was invoked twice with effectively identical arguments.
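For illustration only, the resulting job graph might be structured roughly as below; the job names, reusable-workflow paths, and marker expressions are hypothetical sketches, not the literal contents of RunTests.yml.

# Illustrative sketch of the new job graph; names and inputs are made up.
jobs:
  build_tpu_image:
    uses: ./.github/workflows/build_and_upload_image.yml
  build_gpu_image:
    uses: ./.github/workflows/build_and_upload_image.yml
  tpu_unit_tests:
    needs: build_tpu_image          # TPU tests no longer wait for the GPU image
    uses: ./.github/workflows/run_tests.yml
    with:
      pytest_marker: "not gpu_only and not integration_test"
  tpu_integration_tests:
    needs: build_tpu_image          # runs concurrently with tpu_unit_tests
    uses: ./.github/workflows/run_tests.yml
    with:
      pytest_marker: "integration_test"
  gpu_unit_tests:
    needs: build_gpu_image          # the GPU image only gates GPU tests
    uses: ./.github/workflows/run_tests.yml
    with:
      pytest_marker: "gpu_only"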

There are multiple opportunities for further speed-up that are not addressed in this PR: for example, utilizing pytest's parallelization support, reducing the time to build the GPU image (currently ~13 min), and avoiding downloading images for every test stage.
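As one example of the pytest route, a test step could use pytest-xdist (not added in this PR) to spread tests across worker processes. The step below is only a sketch, assuming pytest-xdist is installed in the image:

# Hypothetical step; requires pytest-xdist, which this PR does not add.
- name: Run unit tests in parallel
  run: python3 -m pytest -n 4 tests/ -m "not gpu_only and not integration_test"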

FIXES: b/385505333
FIXES: b/385510862

Tests

This PR refactors existing tests and introduces new unit tests.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@shralex force-pushed the refactor_tests branch 9 times, most recently from e2f9e3b to 638509f on December 20, 2024 18:00
@shralex force-pushed the refactor_tests branch 5 times, most recently from 0b49275 to 7368a0b on December 21, 2024 06:23
@shralex requested a review from RissyRan as a code owner on December 21, 2024 06:23
@shralex force-pushed the refactor_tests branch 5 times, most recently from bcc3ce8 to 5afaef1 on December 21, 2024 19:32
- name: Test train.py with dropout
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} max_target_length=128 per_device_batch_size=1 dropout_rate=0.02
- name: Test generate_param_only_checkpoint
run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a ${{ matrix.device.attention }}
@shralex (Collaborator, Author) commented:
Note that this is equivalent to the following command, since int8 is the default quantization in
test_generate_param_only_checkpoint.sh

@shralex force-pushed the refactor_tests branch 4 times, most recently from 84bc33f to 36242f2 on December 21, 2024 21:58
@shralex force-pushed the refactor_tests branch 12 times, most recently from bf32d38 to ba3131b on December 25, 2024 19:33
@shralex changed the title from "Move tests from UnitTests.yml to tests files" to "Refactor and parallelize MaxText test runner" on Dec 25, 2024
@@ -23,8 +23,8 @@ on:
branches: [ "main" ]
workflow_dispatch:
schedule:
# Run the job every 2 hours
- cron: '0 */2 * * *'
# Run the job every 6 hours
A Collaborator commented:

Why change this? Every 2 hours is pretty helpful for catching regressions.

@shralex (Collaborator, Author) replied:

IMHO, given that this also runs on every PR, 6 hours is plenty. For reference, MaxDiffusion has it set to every 12 hours, while JAX only runs its CI once or twice daily, depending on the workflow. I changed it to 4 hours as a compromise for now :) but please let me know if you still disagree.
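For reference, an every-4-hours schedule in GitHub Actions cron syntax would look like this:

schedule:
  # Run the job every 4 hours
  - cron: '0 */4 * * *'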

@@ -85,7 +85,7 @@ def setUp(self):
}
self.model_vars = init_random_model_vars(self.model, self.rng, self.example_batch)

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_logits_numerically(self):
A Collaborator commented:

@ZhiyuLi-goog Can you explain this test?

  1. It looks like it is using random model weights.
  2. In an offline process you pass the same random model weights to get ground truth in a paxml implementation (how is this done? how do you ensure the same model weights are used?)

Won't this test fail if we pass a different rng, or if JAX changes its rng implementation, since the model weights will change? If my understanding is correct, I think this correctness test makes more sense as an e2e test that loads a checkpoint from GCS.

@ZhiyuLi-goog (Collaborator) replied Dec 27, 2024:

Yep, you are right. This test is simply there to avoid numerical discrepancies caused by any unexpected implementation changes in shared modules like attention or linear layers.

We don't care about the exact values of the model weights. A fixed random seed ensures frozen model weights, and we do expect the same results with the same model weights as long as the implementation is not changed.

The root cause might be related to a change in random state, which could be affected by other tests running in parallel. I propose to double-check the random state in parallel tests first. Otherwise, we can switch to checkpoint loading if there is a hard conflict between random state and parallel testing.

"In an offline process you pass the same random model weights to get ground truth in a paxml implementation (how is this done? how do you ensure the same model weights are used?)"

This was done offline as part of #293 (see the numerical test there). The Paxml environment is hard to set up and might introduce conflicts. The golden logits are indeed from a Paxml model with the same model weights.

"Won't this test fail if we pass a different rng, or if JAX changes its rng implementation, since the model weights will change?"

Yes. The goal is to verify the same logit outputs with a fixed rng and random weights, while we don't care too much about the exact values of the random weights. I think the rng implementation is probably quite stable in JAX.

"I think this correctness test makes more sense as an e2e test that loads a checkpoint from GCS."

You are right, it would be ideal to avoid the dependency on random initialization. However, it might be hard to maintain in the long run, since we would need to regenerate the Orbax checkpoint each time:

  • we change a layer name,
  • the checkpoint implementation changes, or
  • the Orbax library is updated.

This happened in the last MLPerf 4.1 submission.

As for e2e tests, there is already a functional e2e test for checkpoint conversion and loading.
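To make the pattern concrete, here is a toy sketch (made-up model and file names, not MaxText's actual test): a fixed PRNG key freezes the "random" weights, so the same logits can be compared against golden values captured once offline.

import jax
import jax.numpy as jnp
import numpy as np
import pytest

# Toy illustration of the fixed-seed golden-logits pattern; not MaxText code.
def init_toy_params(rng, d_in=8, d_out=4):
  w_key, b_key = jax.random.split(rng)
  return {"w": jax.random.normal(w_key, (d_in, d_out)),
          "b": jax.random.normal(b_key, (d_out,))}

def toy_logits(params, x):
  return x @ params["w"] + params["b"]

@pytest.mark.tpu_only  # marker renamed from `tpu` in this PR
def test_logits_numerically_toy():
  rng = jax.random.PRNGKey(0)        # fixed seed => identical "random" weights every run
  params = init_toy_params(rng)
  logits = toy_logits(params, jnp.ones((2, 8)))
  # In the real test the reference would be golden logits stored offline from a
  # Paxml run with the same weights (e.g. np.load("golden_logits.npy")); here we
  # use the computed values as a stand-in to keep the sketch self-contained.
  golden = np.asarray(logits)
  np.testing.assert_allclose(np.asarray(logits), golden, rtol=1e-6)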

from absl.testing import absltest


class Train(unittest.TestCase):
A Collaborator commented:

Should this class be called "DecodeInt8Test" or similar instead of Train?

@shralex (Collaborator, Author) replied:

Good catch, thanks!

from absl.testing import absltest


class Train(unittest.TestCase):
A Collaborator commented:

Should this class have a more specific name?

@shralex (Collaborator, Author) replied:

Done.

from absl.testing import absltest


class Train(unittest.TestCase):
A Collaborator commented:

DecodeTest class or similar?

@shralex (Collaborator, Author) replied:

Done.

A Collaborator commented:

I don't think a shmapped collective matmul test counts as an "integration test" - this is more of a unit test.

@shralex (Collaborator, Author) replied:

Done.

A Collaborator commented:

Ideally I think these tests should be moved in with the other train.py tests - they differ only in quantization=int8, so it is probably easier to maintain them in the same file.

@shralex (Collaborator, Author) replied:

I moved all the new train configs into train_tests.py and all the new decode configs into decode_tests.py.

@gobbleturk (Collaborator) commented Dec 26, 2024:

Same here, this is just a slightly different config of train - does it make sense to include it in the same file as the other train variants?

@shralex (Collaborator, Author) replied:

I moved all the new train configs into train_tests.py and all the new decode configs into decode_tests.py.

@shralex force-pushed the refactor_tests branch 5 times, most recently from a4659d0 to b56cfa6 on December 27, 2024 10:48
]
}

@pytest.mark.tpu_only
A Collaborator commented:

What do you think of running a separate pytest test for each config? E.g. something like:

AOT tests:

@pytest.mark.tpu
def test_save_compiled_v4(self):
  compiled_trainstep_file = "/tmp/test_compiled_v4.pickle"
  train_compile_main(
      (
          None,
          "configs/base.yml",
          f"compiled_trainstep_file={compiled_trainstep_file}",
          "compile_topology=v4-8",
          "compile_topology_num_slices=1",
          "base_emb_dim=256",
          "base_mlp_dim=256",
          "base_num_decoder_layers=2",
      )
  )

@pytest.mark.tpu
def test_save_compiled_v5e(self):
  compiled_trainstep_file = "/tmp/test_compiled_v5e.pickle"
  train_compile_main(
      (
          None,
          "configs/base.yml",
          f"compiled_trainstep_file={compiled_trainstep_file}",
          "compile_topology=v5e-16",
          "compile_topology_num_slices=1",
          "base_emb_dim=256",
          "base_mlp_dim=256",
          "base_num_decoder_layers=2",
      )
  )

# TODO (b/366200617): This test fails in AOT, but the config works fine on real hardware
@pytest.mark.skip(reason="Issue w/ kernels_test. Error: The TPU is already in use by process...")
def test_minimal_offloaded_v5e(self):
  compiled_trainstep_file = "/tmp/test_compiled_v5e_offload.pickle"
  train_compile_main(
      (
          None,
          "configs/base.yml",
          f"compiled_trainstep_file={compiled_trainstep_file}",
          "compile_topology=v5e-256",
          "compile_topology_num_slices=1",
          "per_device_batch_size=1",
          "ici_fsdp_parallelism=16",
          "ici_tensor_parallelism=16",
          "max_target_length=2048",
          "fused_qkv=true",
          "fused_mlp=true",
          "remat_policy=minimal_offloaded",
          "use_iota_embed=true",
          "global_parameter_scale=128",
      )
  )

I prefer this since:

  • Each test is more self-contained; you don't need to find the right config.
  • If a test fails in the current "for loop" implementation, we won't get a clear error message about which test really failed; we will get "test_tpu_config failed" and will need to carefully scan the logs for the "Running TPU test for config..." message.
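Another option (a sketch only, not part of the suggestion above): pytest.mark.parametrize keeps a single test function while still reporting each config under its own test id, so a failure names the exact config. The import path here is approximate.

import pytest
# Import path is approximate; in MaxText the entry point lives in MaxText/train_compile.py.
from train_compile import main as train_compile_main

# Sketch only: parametrize produces per-config test ids, e.g. test_aot_compile[v4-8].
AOT_CONFIGS = {
    "v4-8": ("compile_topology=v4-8",),
    "v5e-16": ("compile_topology=v5e-16",),
}

@pytest.mark.tpu_only
@pytest.mark.parametrize("topology", sorted(AOT_CONFIGS))
def test_aot_compile(topology):
  train_compile_main((
      None,
      "configs/base.yml",
      f"compiled_trainstep_file=/tmp/test_compiled_{topology.replace('-', '_')}.pickle",
      *AOT_CONFIGS[topology],
      "compile_topology_num_slices=1",
      "base_emb_dim=256",
      "base_mlp_dim=256",
      "base_num_decoder_layers=2",
  ))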

@shralex (Collaborator, Author) replied:

Done.

@shralex force-pushed the refactor_tests branch 4 times, most recently from 9c683a1 to 39c1d44 on December 28, 2024 08:28
This change moves train and decode tests out into separate files, so they can be tested using pytest rather than being invoked explicitly from UnitTests.yml. This will allow us, in a subsequent PR, to parallelize the execution of these tests by utilizing pytest's parallelization support.

Several end-to-end tests are still being invoked from UnitTests.yml.
This does not seem like the right place, and will hopefully be
addressed in the future.
@copybara-service bot merged commit bb4bfb7 into main on Dec 29, 2024 (14 of 15 checks passed)
@copybara-service bot deleted the refactor_tests branch on December 29, 2024 11:14