diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 02fc45ed5c..b74e5f2f42 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ # Changes in this file should match with requiredReviewers in file .github/workflows/AddLabel.yml -* @gobbleturk @khatwanimohit @bvandermoon @vipannalla +* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan diff --git a/.github/workflows/AddLabel.yml b/.github/workflows/AddLabel.yml index d949e25007..451668d244 100644 --- a/.github/workflows/AddLabel.yml +++ b/.github/workflows/AddLabel.yml @@ -16,7 +16,7 @@ name: Add Label on: workflow_run: - workflows: [Unit Test, CodeQL] + workflows: [Tests, CodeQL] types: - completed pull_request_review: @@ -57,6 +57,7 @@ jobs: khatwanimohit: "", bvandermoon: "", vipannalla: "", + RissyRan: "", } const reviews = await github.rest.pulls.listReviews({ owner, diff --git a/.github/workflows/CPUTests.yml b/.github/workflows/CPUTests.yml index 6ce5efa9b9..03876616af 100644 --- a/.github/workflows/CPUTests.yml +++ b/.github/workflows/CPUTests.yml @@ -1,9 +1,9 @@ name: Linter on: + pull_request: push: - branches: - - '**' + branches: [ "main" ] jobs: cpu: diff --git a/.github/workflows/RunTests.yml b/.github/workflows/RunTests.yml new file mode 100644 index 0000000000..13b9056dce --- /dev/null +++ b/.github/workflows/RunTests.yml @@ -0,0 +1,114 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Tests + +on: + pull_request: + push: + branches: [ "main" ] + workflow_dispatch: + schedule: + # Run the job every 4 hours + - cron: '0 */4 * * *' + +jobs: + prelim: + runs-on: ["self-hosted"] + steps: + - name: Test gsutil installation + run: which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. Aborting"; exit 24;} + - name: Cleanup old docker images + run: docker system prune --all --force + + tpu_image: + needs: prelim + uses: ./.github/workflows/build_upload_internal.yml + with: + device_type: tpu + device_name: v4-8 + build_mode: stable_stack + base_image: us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest + + gpu_image: + needs: prelim + uses: ./.github/workflows/build_upload_internal.yml + with: + device_type: gpu + device_name: a100-40gb-4 + build_mode: pinned + + tpu_unit_tests: + needs: tpu_image + uses: ./.github/workflows/run_tests_internal.yml + with: + device_type: tpu + device_name: v4-8 + pytest_marker: 'not gpu_only and not integration_test' + test_directory: 'tests' + xla_python_client_mem_fraction: 0.75 + tf_force_gpu_allow_growth: false + container_resource_option: "--privileged" + + tpu_integration_tests: + needs: tpu_image + uses: ./.github/workflows/run_tests_internal.yml + with: + device_type: tpu + device_name: v4-8 + pytest_marker: 'not gpu_only and integration_test' + test_directory: 'tests/integration_tests' + xla_python_client_mem_fraction: 0.75 + tf_force_gpu_allow_growth: false + container_resource_option: "--privileged" + + gpu_unit_tests: + needs: gpu_image + uses: ./.github/workflows/run_tests_internal.yml + with: + device_type: gpu + device_name: a100-40gb-4 + pytest_marker: 'not tpu_only and not integration_test' + test_directory: 'tests' + 
xla_python_client_mem_fraction: 0.65 + tf_force_gpu_allow_growth: true + container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged" + + gpu_integration_tests: + needs: gpu_image + uses: ./.github/workflows/run_tests_internal.yml + with: + device_type: gpu + device_name: a100-40gb-4 + pytest_marker: 'not tpu_only and integration_test' + test_directory: 'tests/integration_tests' + xla_python_client_mem_fraction: 0.65 + tf_force_gpu_allow_growth: true + container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged" + + + clean_up: + if: ${{ always() }} # always execute, regardless of previous jobs or steps. + needs: [gpu_unit_tests, gpu_integration_tests, tpu_unit_tests, tpu_integration_tests] + name: "Clean up" + runs-on: ["self-hosted"] + steps: + - name: Delete GPU image + run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu --force-delete-tags --quiet + - name: Delete TPU image + run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu --force-delete-tags --quiet + diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml deleted file mode 100644 index 91a62efc2c..0000000000 --- a/.github/workflows/UnitTests.yml +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - -name: Unit Test - -on: - pull_request: - push: - branches: [ "main" ] - workflow_dispatch: - schedule: - # Run the job every 2 hours - - cron: '0 */2 * * *' - -jobs: - build_and_upload_image: - strategy: - fail-fast: false - matrix: - device: - - type: tpu - name: v4-8 - mode: stable - - type: gpu - name: a100-40gb-4 - mode: pinned - name: Build and upload image (${{ matrix.device.name }}) - runs-on: ["self-hosted", "${{ matrix.device.type }}", "${{ matrix.device.name }}"] - steps: - - uses: actions/checkout@v4 - - name: Cleanup old docker images - run: docker system prune --all --force - - name: Build an image - run: | - bash docker_build_dependency_image.sh MODE=${{ matrix.device.mode }} DEVICE=${{ matrix.device.type }} - - name: Tag the image - run: | - docker tag maxtext_base_image gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }} - - name: Upload the image - run: | - docker push gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }} - - common: - needs: build_and_upload_image - strategy: - fail-fast: False - matrix: - device: - - type: tpu - name: v4-8 - attention: autoselected - pytest_marker: '' - container_env: - XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75 - TF_FORCE_GPU_ALLOW_GROWTH: false - container_resource_option: "--privileged" - - type: gpu - name: a100-40gb-4 - image_suffix: gpu_jax_pinned - attention: dot_product - pytest_marker: -m 'not tpu' - container_env: - XLA_PYTHON_CLIENT_MEM_FRACTION: 0.65 - TF_FORCE_GPU_ALLOW_GROWTH: true - container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged" - name: Common test (${{ matrix.device.name }}) - runs-on: ["self-hosted", "${{ matrix.device.type }}", "${{ matrix.device.name }}"] - container: - image: 
gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }} - volumes: - - /home/runner/actions-runner/_work/maxtext/maxtext:/deps - env: - XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ matrix.device.container_env.XLA_PYTHON_CLIENT_MEM_FRACTION }} - TF_FORCE_GPU_ALLOW_GROWTH: ${{ matrix.device.container_env.TF_FORCE_GPU_ALLOW_GROWTH }} - options: ${{ matrix.device.container_resource_option }} - steps: - - uses: actions/checkout@v4 - - name: Test gsutil installation - run: which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. Aborting"; exit 24;} - - name: Test with pytest - run: cd MaxText;python3 -m pytest ${{ matrix.device.pytest_marker }} - - name: Test train.py with TFDS c4 - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} - - name: Test train.py with HF c4 - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs hf_train_files=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet hf_path=parquet dataset_type=hf steps=2 tokenizer_path=google-t5/t5-large attention=${{ matrix.device.attention }} enable_checkpointing=false - - name: Test train.py with synthetic data - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} dataset_type=synthetic - - name: Test train.py with per_device_batch_size < 1 - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 
ici_tensor_parallelism=4 enable_checkpointing=false attention=${{ matrix.device.attention }} - - name: Test decode.py - run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1 - - name: Test int8_decode - run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1 quantization=int8 quantize_kvcache=True - - name: Test decode.py with per_device_batch_size < 1 - run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=.25 - - name: Test int8_training - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} - - name: Test fp8_training - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=fp8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} - - name: Test train.py with dropout - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) 
base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} max_target_length=128 per_device_batch_size=1 dropout_rate=0.02 - - name: Test generate_param_only_checkpoint - run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a ${{ matrix.device.attention }} - - name: Test generate_param_only_checkpoint with int8 quantization - run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8 -a ${{ matrix.device.attention }} - - name: Test grain checkpoint determinism - run: bash end_to_end/test_checkpointing.sh runner_$(date +%Y-%m-%d-%H-%M-%S) gs://runner-maxtext-logs gs://maxtext-dataset False grain ${{ matrix.device.attention }} - - name: Test checkpoint compatibility - run: bash end_to_end/test_checkpoint_compatibility.sh runner_$(date +%Y-%m-%d-%H-%M-%S) gs://runner-maxtext-logs gs://maxtext-dataset ${{ matrix.device.attention }} - - tpu: - needs: build_and_upload_image - strategy: - fail-fast: false - matrix: - device-type: ["v4-8"] - name: "TPU test (${{ matrix.device-type }})" - runs-on: ["self-hosted", "tpu", "${{ matrix.device-type }}"] - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu - volumes: - - /home/runner/actions-runner/_work/maxtext/maxtext:/deps - options: "--privileged" - steps: - - uses: actions/checkout@v4 - - name: Validate Pedagogical Example, Shmap_collective_matmul - run: python3 pedagogical_examples/shmap_collective_matmul.py - - gpu: - needs: build_and_upload_image - strategy: - fail-fast: false - matrix: - device-type: ["a100-40gb-4"] - build-mode: ["pinned"] - name: "GPU test (${{ matrix.device-type }}, ${{ matrix.build-mode }})" - runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"] - container: - 
image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu - volumes: - - /home/runner/actions-runner/_work/maxtext/maxtext:/deps - env: - XLA_PYTHON_CLIENT_MEM_FRACTION: 0.65 - TF_FORCE_GPU_ALLOW_GROWTH: true - options: "--shm-size 2g --runtime=nvidia --gpus all --privileged" - steps: - - uses: actions/checkout@v4 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Test train.py with flash attention - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=cudnn_flash_te - - clean_up: - if: ${{ always() }} - needs: [common, gpu, tpu] - name: "Clean up" - runs-on: ["self-hosted"] - steps: - - name: Delete GPU image - run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu --force-delete-tags --quiet - - name: Delete TPU image - run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu --force-delete-tags --quiet - diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml index 7f091b5a37..8971be9d48 100644 --- a/.github/workflows/UploadDockerImages.yml +++ b/.github/workflows/UploadDockerImages.yml @@ -41,10 +41,13 @@ jobs: bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_nightly MODE=nightly DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_nightly - name: build jax stable stack image run : | - bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_stable_stack_0.4.35 MODE=stable_stack DEVICE=TPU PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable_stack_0.4.35 BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt + bash .github/workflows/build_and_upload_images.sh 
CLOUD_IMAGE_NAME=maxtext_jax_stable_stack MODE=stable_stack DEVICE=TPU PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt - name: build image with stable stack nightly jax run: | - bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_stable_stack_nightly_jax MODE=stable_stack DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack_nightly BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt + bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_stable_stack_nightly_jax MODE=stable_stack DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_tpu_jax_stable_stack_nightly BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt + - name: build image with jax stable stack release candidate image + run: | + bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_stable_stack_candidate MODE=stable_stack DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_stable_stack_candidate BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt gpu: strategy: fail-fast: false diff --git a/.github/workflows/build_upload_internal.yml b/.github/workflows/build_upload_internal.yml new file mode 100644 index 0000000000..dcf7810a4d --- /dev/null +++ b/.github/workflows/build_upload_internal.yml @@ -0,0 +1,50 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file defines a module for building and uploading an image used in RunTests.yml + +name: Build and Upload Image + +on: + workflow_call: + inputs: + device_type: + required: true + type: string + device_name: + required: true + type: string + build_mode: + required: true + type: string + base_image: + required: false + type: string + +jobs: + build_and_upload: + name: Build and upload image (${{ inputs.device_name }}) + runs-on: ["self-hosted", "${{ inputs.device_type }}", "${{ inputs.device_name }}"] + steps: + - uses: actions/checkout@v4 + - name: Build an image + run: | + bash docker_build_dependency_image.sh MODE=${{ inputs.build_mode }} DEVICE=${{ inputs.device_type }} BASEIMAGE=${{ inputs.base_image }} + - name: Tag the image + run: | + docker tag maxtext_base_image gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ inputs.device_type }} + - name: Upload the image + run: | + docker push gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ inputs.device_type }} + diff --git a/.github/workflows/run_tests_internal.yml b/.github/workflows/run_tests_internal.yml new file mode 100644 index 0000000000..03cb84c563 --- /dev/null +++ b/.github/workflows/run_tests_internal.yml @@ -0,0 +1,60 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file defines a module for running tests used in RunTests.yml + +name: Run Tests + +on: + workflow_call: + inputs: + device_type: + required: true + type: string + device_name: + required: true + type: string + pytest_marker: + required: true + type: string + test_directory: + required: true + type: string + xla_python_client_mem_fraction: + required: true + type: string + tf_force_gpu_allow_growth: + required: true + type: string + container_resource_option: + required: true + type: string + +jobs: + run: + runs-on: ["self-hosted", "${{ inputs.device_type }}", "${{ inputs.device_name }}"] + container: + image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ inputs.device_type }} + volumes: + - /home/runner/actions-runner/_work/maxtext/maxtext:/deps + env: + XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} + TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} + options: ${{ inputs.container_resource_option }} + steps: + - uses: actions/checkout@v4 + - name: Run Tests + run: | + cd MaxText + python3 -m pytest ${{ inputs.test_directory }} -m "${{ inputs.pytest_marker }}" diff --git a/MaxText/accelerator_to_spec_map.py b/MaxText/accelerator_to_spec_map.py index acd2a0659d..0bd240e237 100644 --- a/MaxText/accelerator_to_spec_map.py +++ b/MaxText/accelerator_to_spec_map.py @@ -45,6 +45,9 @@ class SystemCharacteristics: "v6e-128": SystemCharacteristics("tpu", "v6e:8x16", "default", (2, 2, 1), 128, (False, True, False)), "v6e-256": SystemCharacteristics("tpu", "v6e:16x16", "default", (2, 2, 1),
256, (True, True, False)), # v5e: one core per chip with 16 GB HBM + "v5e-1": SystemCharacteristics("tpu", "v5e:1x1", "default", (1, 1, 1), 1, (False, False, False)), + "v5e-4": SystemCharacteristics("tpu", "v5e:2x2", "default", (2, 2, 1), 4, (False, False, False)), + "v5e-8": SystemCharacteristics("tpu", "v5e:2x4", "default", (2, 2, 1), 8, (False, False, False)), "v5e-16": SystemCharacteristics("tpu", "v5e:4x4", "default", (2, 2, 1), 16, (False, False, False)), "v5e-32": SystemCharacteristics("tpu", "v5e:4x8", "default", (2, 2, 1), 32, (False, False, False)), "v5e-64": SystemCharacteristics("tpu", "v5e:8x8", "default", (2, 2, 1), 64, (False, False, False)), @@ -169,4 +172,9 @@ class SystemCharacteristics: def get_system_characteristics(user_facing_name): - return UserFacingNameToSystemCharacteristics.get(user_facing_name) + system_characteristics = UserFacingNameToSystemCharacteristics.get(user_facing_name) + if system_characteristics is None: + raise ValueError( + f"Invalid compile topology: {user_facing_name}. 
Valid topology names: {UserFacingNameToSystemCharacteristics.keys()}" + ) + return system_characteristics diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index 5c7fbf2baa..4ead9b4e93 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -27,6 +27,7 @@ import numpy as np import orbax.checkpoint as ocp import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager +import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager # pylint: disable=too-many-positional-arguments @@ -91,7 +92,7 @@ def create_orbax_emergency_checkpoint_manager( persistent_save_interval_steps: int, orbax_logger: Optional[abstract_logger.AbstractLogger] = None, ): - """Returns an emergency checkpoint.""" + """Returns an emergency checkpoint manager.""" flags.FLAGS.experimental_orbax_use_distributed_process_id = True max_logging.log("Creating emergency checkpoint manager...") @@ -99,7 +100,7 @@ def create_orbax_emergency_checkpoint_manager( local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps), persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps), ) - emergency_mngr = emergency_checkpoint_manager.CheckpointManager( + manager = emergency_checkpoint_manager.CheckpointManager( local_checkpoint_dir, epath.Path(persistent_checkpoint_dir), global_mesh=global_mesh, @@ -109,7 +110,37 @@ def create_orbax_emergency_checkpoint_manager( ) max_logging.log("Emergency checkpoint manager created!") - return emergency_mngr + return manager + + +def create_orbax_emergency_replicator_checkpoint_manager( + local_checkpoint_dir: str, + save_interval_steps: int, + global_mesh: jax.sharding.Mesh, +): + """Returns an emergency replicator checkpoint manager.""" + flags.FLAGS.experimental_orbax_use_distributed_process_id = True + max_logging.log("Creating emergency replicator checkpoint manager...") + + options = 
emergency_replicator_checkpoint_manager.ReplicatorCheckpointManagerOptions( + save_interval_steps=save_interval_steps, + ) + manager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager( + local_checkpoint_dir, + options, + global_mesh=global_mesh, + ) + + max_logging.log("Emergency replicator checkpoint manager created!") + return manager + + + +def print_save_message(step, async_checkpointing): + if async_checkpointing: + max_logging.log(f"Started an asynchronous checkpoint save for step {step}") + else: + max_logging.log(f"Saved a checkpoint at step {step}.") def _find_idx(array: np.ndarray, replica_axis_idx: int): diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 83afb73e15..074cf49e00 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -160,10 +160,10 @@ set_remat_policy_on_layers_per_stage: False # Choose 'remat_policy' between 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', -# 'save_qkv_proj', 'qkv_proj_offloaded', 'custom' 'minimal_offloaded', 'save_out_proj' and 'full'. +# 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'. # These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) remat_policy: 'full' -# If custom_save_offload remat_policy is chosen, you can select tensors from the following list to offload on host memory, rematerialize or save on device memory. +# If "custom" remat_policy is chosen, you can select tensors from the following list to offload on host memory, rematerialize or save on device memory. 
# Pick one of these options for following tensors: ['remat','device','offload'] decoder_layer_input: 'device' # this tensor cannot be rematerialized - it serves as periodic checkpoints that act as the remat start points context: 'remat' # From https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/attention.py#L581-L583 @@ -227,47 +227,49 @@ jax_cache_dir: "~/jax_cache" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' # Parallelism -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive'] +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_heads', ['tensor','sequence']], - ['activation_kv_heads', ['tensor','sequence']], - ['activation_length', 'sequence'], + ['activation_heads', ['tensor','sequence','tensor_sequence']], + ['activation_kv_heads', ['tensor','sequence','tensor_sequence']], + ['activation_length', ['sequence']], + ['activation_norm_length', ['tensor_sequence', 'sequence']], ['activation_embed', 'tensor'], - ['activation_mlp', 'tensor'], - ['activation_kv', 'tensor'], + ['activation_mlp', ['tensor', 'tensor_sequence']], + ['activation_kv', ['tensor', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_head_dim', 'tensor'], - ['activation_vocab', ['tensor', 'sequence']], + ['activation_kv_head_dim', ['tensor', 'tensor_sequence']], + ['activation_vocab', ['tensor', 'sequence', 'tensor_sequence']], ['activation_vocab', 'tensor'], + ['activation_vocab', 'tensor_sequence'], 
['activation_vocab', 'sequence'], ['activation_stage', 'stage'], ['activation_exp', 'expert'], - ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']], - ['vocab', ['tensor', 'autoregressive']], + ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['vocab', ['tensor', 'tensor_sequence', 'autoregressive']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], ['embed', ['fsdp', 'sequence', 'expert']], ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']], ['embed_no_exp', ['fsdp', 'sequence']], - ['norm', 'tensor'], - ['q_heads', ['tensor', 'autoregressive']], - ['heads', ['tensor', 'autoregressive']], + ['norm', ['tensor', 'tensor_sequence']], + ['q_heads', ['tensor', 'tensor_sequence', 'autoregressive']], + ['heads', ['tensor', 'tensor_sequence', 'autoregressive']], ['layers', 'stage'], ['kv', []], - ['kv_heads', ['tensor', 'autoregressive']], + ['kv_heads', ['tensor', 'tensor_sequence', 'autoregressive']], ['kv_head_dim', []], ['cache_batch_prefill', []], ['cache_batch', []], - ['cache_heads', ['autoregressive', 'tensor']], + ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['cache_kv', []], ['cache_sequence', []], ['exp', 'expert'], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']] +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']] # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters. 
sharding_tolerance: 0.02 @@ -281,6 +283,7 @@ dcn_fsdp_parallelism: 1 dcn_fsdp_transpose_parallelism: 1 dcn_sequence_parallelism: 1 # never recommended dcn_tensor_parallelism: 1 # never recommended +dcn_tensor_sequence_parallelism: 1 # never recommended dcn_pipeline_parallelism: 1 dcn_expert_parallelism: 1 dcn_autoregressive_parallelism: 1 # never recommended @@ -289,6 +292,7 @@ ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_fsdp_transpose_parallelism: 1 ici_sequence_parallelism: 1 ici_tensor_parallelism: 1 +ici_tensor_sequence_parallelism: 1 ici_autoregressive_parallelism: 1 ici_pipeline_parallelism: 1 ici_expert_parallelism: 1 @@ -346,6 +350,10 @@ jax_distributed_initialization_timeout: 300 # This is the default timeout in htt # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers # only to the jax coordination service. jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization. +skip_jax_distributed_system: False # If True we will not initialize the jax distributed system. +# Currently the jax distributed is needed on cloud TPUs for async checkpointing. +# However when run on google internal TPUs the coordination service is started automatically +# and we should set this to True so we won't try to initialize a second time manually. # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 # Learning rate schedule has either two or three parts: @@ -378,6 +386,20 @@ skip_first_n_steps_for_profiler: 1 # Profile for a small number of steps to avoid a large profile file size. profiler_steps: 5 profile_cleanly: True # If set to true, adds a block_until_ready on train state which aligns the profile for each step. 
+profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps. +# This is useful to debug scenarios where performance is changing. + + +# Dump HLO options +dump_hlo: False +dump_hlo_local_dir: "/tmp/xla_dump/" +dump_hlo_delete_local_after: True # Cleans local directory after its uploaded +dump_hlo_gcs_dir: "" # Defaults to {base_output_directory}/{run_name}/xla_dump +dump_hlo_module_name: "jit_train_step" # Filter uploading modules by this string. Set to empty string to remove any filter. +dump_hlo_xla_flags: "" # Defaults to "--xla_dump_to={dump_hlo_local_dir} --xla_dump_hlo_module_re={dump_hlo_module_name} --xla_dump_large_constants" +dump_hlo_upload_all: False # If true all hosts dump HLO, false only jax.process_index()==0 +# All hosts should have identical HLO for SPMD programs, however we have encountered some bugs +# where this is not the case and it is helpful to compare HLO across hosts. # When dropout is false the model is a deterministic function of the # data_shuffle_seed and init_weights_seed (i.e. reproducible losses) @@ -456,6 +478,8 @@ inference_microbenchmark_stages: "prefill,generate" inference_microbenchmark_loop_iters: 10 inference_microbenchmark_log_file_path: "" inference_metadata_file: "" # path to a json file +inference_server: "MaxtextInterleavedServer" # inference server to start +inference_benchmark_test: False enable_model_warmup: False # Stack prefill cache across the layer to reduce the diff --git a/MaxText/configs/trillium/llama2_70b_4096.sh b/MaxText/configs/trillium/llama2_7b_4096.sh similarity index 90% rename from MaxText/configs/trillium/llama2_70b_4096.sh rename to MaxText/configs/trillium/llama2_7b_4096.sh index 4c0e4b8769..fea2b3d1a6 100644 --- a/MaxText/configs/trillium/llama2_70b_4096.sh +++ b/MaxText/configs/trillium/llama2_7b_4096.sh @@ -1,4 +1,4 @@ -# Llama2-70b model. +# Llama2-7b model. # This config will work out of the box for any number of trillium-256 slices. 
# # Command Flags: @@ -7,7 +7,7 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash MaxText/configs/trillium/llama2_70b_4096.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash MaxText/configs/trillium/llama2_7b_4096.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # diff --git a/MaxText/decode.py b/MaxText/decode.py index b743852073..ef2f2fc79b 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -21,10 +21,20 @@ import os import pyconfig -import sys +from typing import Sequence +from absl import app + + +def main(argv: Sequence[str]) -> None: + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + pyconfig.initialize(argv) + config = pyconfig.config + validate_config(config) + max_utils.print_system_information() -def main(config): engine = maxengine.MaxEngine(config) rng = jax.random.PRNGKey(1234) rng, rng_load_params = jax.random.split(rng) @@ -71,10 +81,4 @@ def validate_config(config): if __name__ == "__main__": - jax.config.update("jax_default_prng_impl", "unsafe_rbg") - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - pyconfig.initialize(sys.argv) - cfg = pyconfig.config - validate_config(cfg) - max_utils.print_system_information() - main(cfg) + app.run(main) diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py index 3254faf056..fd2f39f578 100644 --- a/MaxText/inference_microbenchmark.py +++ b/MaxText/inference_microbenchmark.py @@ -84,8 +84,8 @@ def prefill_insert_benchmark_loop( config, engine, decode_state, params, total_slots, tokens, true_length, iters, profile_name ): """Inner loop for benchmarking prefill and insert step.""" - prof = profiler.Profiler(config, profile_name) - prof.activate() + prof = profiler.Profiler(config) + prof.activate(optional_postfix=profile_name) start = datetime.datetime.now() rng = jax.random.PRNGKey(1234) for i in range(iters): @@ -121,8 
+121,8 @@ def prefill_insert_benchmark(config, engine, decode_state, params, total_slots, def ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name): """Inner loop for benchmarking ar step.""" - prof = profiler.Profiler(config, profile_name) - prof.activate() + prof = profiler.Profiler(config) + prof.activate(optional_postfix=profile_name) start = datetime.datetime.now() rng = jax.random.PRNGKey(1234) for _ in range(iters): diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 784bd5f282..f5990e7d9b 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -65,6 +65,7 @@ class AttentionType(enum.Enum): KV_BATCH = common_types.KV_BATCH LENGTH = common_types.LENGTH HEAD = common_types.HEAD +EMBED = common_types.EMBED KV_HEAD = common_types.KV_HEAD D_KV = common_types.D_KV KV_HEAD_DIM = common_types.KV_HEAD_DIM @@ -393,23 +394,30 @@ def cudnn_flash_attention( model_mode: str = common_types.MODEL_MODE_TRAIN, ) -> Array: """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA - 2. Supports head_dim till 128; head_dim=256 support will be added soon + 1. Stable API, supports GQA, SWA (only with causal masking) + 2. Head_dim = 256 is also supported from TE-1.12 stable release with CUDNN 12.6 """ # These imports are only meant to work in a GPU build. 
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error _, _, _, head_dim = query.shape # pylint: disable=unused-variable - # generate attn_mask - attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + sliding_window_size = self.sliding_window_size + if self.attention_type == AttentionType.LOCAL_SLIDING: + sliding_window_size = [self.sliding_window_size, 0] + mask_type = "causal" # SWA only works with causal masking + attn_mask = None + else: + # generate attn_mask + mask_type = "padding_causal" # only padding_causal mask type can take a created mask + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) dpa_layer = DotProductAttention( head_dim=head_dim, num_attention_heads=self.num_query_heads, num_gqa_groups=self.num_kv_heads, - attn_mask_type="padding_causal", # 'no_mask', 'padding', 'causal', or 'padding_causal' - attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + attn_mask_type=mask_type, # 'no_mask', 'padding', 'causal', or 'padding_causal' + attn_bias_type="no_bias", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' attention_dropout=self.dropout_rate, dropout_rng_name="aqt", dtype=self.dtype, @@ -417,6 +425,7 @@ def cudnn_flash_attention( qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' scale_factor=1.0 / math.sqrt(head_dim), transpose_batch_sequence=False, + window_size=sliding_window_size, ) return dpa_layer(query, key, value, mask=attn_mask) @@ -1106,6 +1115,7 @@ class Attention(nn.Module): prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED) key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) value_axis_names: AxisNames = (KV_BATCH, 
LENGTH, KV_HEAD, KV_HEAD_DIM) out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) @@ -1257,6 +1267,9 @@ def __call__( Returns: output of shape `[batch, length, q_features]`. """ + inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names) + inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names) + # apply projection. if self.config.fused_qkv: query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") diff --git a/MaxText/layers/gemma.py b/MaxText/layers/gemma.py index 9ce072d8a0..52ff402431 100644 --- a/MaxText/layers/gemma.py +++ b/MaxText/layers/gemma.py @@ -69,14 +69,14 @@ def __call__( ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) inputs = checkpoint_name(inputs, "decoder_layer_input") # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("norm",))( inputs ) - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed")) attention_layer = Attention( config=cfg, @@ -108,7 +108,9 @@ def __call__( model_mode=model_mode, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) + attention_lnx = nn.with_logical_constraint( + attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed") + ) attention_lnx += inputs residual = attention_lnx attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("norm",))( @@ -126,7 +128,7 @@ def __call__( config=cfg, quant=self.quant, )(attn_output, 
deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) next_layer_addition = mlp_lnx + residual @@ -137,7 +139,7 @@ def __call__( layer_output = next_layer_addition_dropped_out layer_output = nn.with_logical_constraint( layer_output, - ("activation_batch", "activation_length", "activation_embed"), + ("activation_batch", "activation_norm_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: diff --git a/MaxText/layers/gemma2.py b/MaxText/layers/gemma2.py index 2286984810..dd2db3a586 100644 --- a/MaxText/layers/gemma2.py +++ b/MaxText/layers/gemma2.py @@ -69,14 +69,14 @@ def __call__( ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) inputs = checkpoint_name(inputs, "decoder_layer_input") # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] lnx = RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm_local", kernel_axes=("norm",) )(inputs) - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed")) attention_layer = Attention( config=cfg, @@ -113,7 +113,9 @@ def __call__( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm_local", kernel_axes=("norm",) )(attention_lnx) - attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) + attention_lnx = nn.with_logical_constraint( + attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed") + ) 
attention_lnx += inputs residual = attention_lnx @@ -137,7 +139,7 @@ def __call__( mlp_lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_ffw_norm_local", kernel_axes=("norm",))( mlp_lnx ) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) next_layer_addition = mlp_lnx + residual @@ -148,18 +150,18 @@ def __call__( layer_output = next_layer_addition_dropped_out layer_output = nn.with_logical_constraint( layer_output, - ("activation_batch", "activation_length", "activation_embed"), + ("activation_batch", "activation_norm_length", "activation_embed"), ) ### global part - inputs = nn.with_logical_constraint(layer_output, ("activation_batch", "activation_length", "activation_embed")) + inputs = nn.with_logical_constraint(layer_output, ("activation_batch", "activation_norm_length", "activation_embed")) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] lnx = RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm_global", kernel_axes=("norm",) )(inputs) - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed")) attention_layer = Attention( config=cfg, @@ -195,7 +197,9 @@ def __call__( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm_global", kernel_axes=("norm",) )(attention_lnx) - attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) + attention_lnx = nn.with_logical_constraint( + attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed") + ) attention_lnx += inputs residual = attention_lnx @@ -219,7 +223,7 @@ def __call__( mlp_lnx ) - mlp_lnx = 
nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) next_layer_addition = mlp_lnx + residual @@ -230,7 +234,7 @@ def __call__( layer_output = next_layer_addition_dropped_out layer_output = nn.with_logical_constraint( layer_output, - ("activation_batch", "activation_length", "activation_embed"), + ("activation_batch", "activation_norm_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: diff --git a/MaxText/layers/gpt3.py b/MaxText/layers/gpt3.py index 07ee9b0fb5..e9b6e65e9f 100644 --- a/MaxText/layers/gpt3.py +++ b/MaxText/layers/gpt3.py @@ -47,6 +47,7 @@ LENGTH = common_types.LENGTH HEAD = common_types.HEAD D_KV = common_types.D_KV +EMBED = common_types.EMBED DenseGeneral = linears.DenseGeneral NdInitializer = initializers.NdInitializer @@ -67,7 +68,7 @@ class Gpt3LayerNorm(nn.Module): epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 - kernel_axes: Tuple[str, ...] = () + kernel_axes: Tuple[Optional[str], ...] 
= () scale_init: Initializer = nn.initializers.zeros use_bias: bool = True reductions_in_fp32: bool = False @@ -148,6 +149,7 @@ class Gpt3MultiHeadAttention(nn.Module): kv_quant: Optional[KVQuant] = None use_bias: bool = True + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED) query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) @@ -213,6 +215,7 @@ def __call__( model_mode: str = common_types.MODEL_MODE_TRAIN, deterministic: bool = False, ): + inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names) if self.fused_qkv: query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") else: @@ -279,7 +282,7 @@ def __call__( cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) inputs = checkpoint_name(inputs, "decoder_layer_input") lnx_layer_norm = Gpt3LayerNorm( dtype=cfg.dtype, @@ -291,7 +294,7 @@ def __call__( ) lnx = lnx_layer_norm(inputs) - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed")) # Self-attention block assert ( @@ -319,7 +322,9 @@ def __call__( lnx, decoder_segment_ids=decoder_segment_ids, model_mode=model_mode, deterministic=deterministic ) - attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) + attention_lnx = nn.with_logical_constraint( + attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed") + ) attention_lnx += inputs # MLP block. 
@@ -335,7 +340,7 @@ def __call__( config=cfg, quant=self.quant, )(attention_lnx, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = attention_lnx + mlp_lnx @@ -343,7 +348,7 @@ def __call__( layer_output = nn.with_logical_constraint( layer_output, - ("activation_batch", "activation_length", "activation_embed"), + ("activation_batch", "activation_norm_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 10d1d04525..31d01b3879 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -55,13 +55,6 @@ COMBINE = "combine" -def _get_model_call_mode(config): - if config.model_cal_mode == "inference": - return "inference" - else: - return None - - def _convert_to_activation_function(fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]: """Convert a string to an activation function.""" if fn_or_string == "linear": @@ -107,7 +100,7 @@ class DenseGeneral(nn.Module): weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") - kernel_axes: Tuple[str, ...] = () + kernel_axes: Tuple[Optional[str], ...] = () quant: Optional[Quant] = None use_bias: bool = False matmul_precision: str = "default" @@ -309,7 +302,7 @@ class MoeBlock(nn.Module): num_experts_per_tok: int mesh: Mesh kernel_init: NdInitializer - kernel_axes: Tuple[str, ...] + kernel_axes: Tuple[Optional[str], ...] 
weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 quant: Optional[Quant] = None @@ -597,7 +590,8 @@ def maybe_all_gather_kernel_weight_in_expert_parallelism(self, kernel, kernel_ax return kernel def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): - gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", "activation_embed")) + # gate_logits: batch, length, expert + gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", None)) softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype) # shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok) top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, self.num_experts_per_tok) @@ -609,7 +603,10 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): mask_axes = ("activation_batch", "activation_length", None, None) dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) - loss = self.load_balance_loss(top_k_indices, softmax_probs) + if self.config.model_call_mode != "inference": + loss = self.load_balance_loss(top_k_indices, softmax_probs) + else: + loss = None inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) with jax.named_scope("dispatch"): dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)( @@ -657,6 +654,8 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): intermediate_layer, ("activation_exp", "activation_batch_no_exp", None, "activation_embed"), ) + if self.config.activations_in_float32: + intermediate_layer = intermediate_layer.astype(jnp.float32) intermediate_layer = checkpoint_name(intermediate_layer, "mlpwo") with jax.named_scope("combine"): # Matmul & element wise operation @@ -665,7 +664,7 @@ def dense_matmul(self, inputs, 
gate_logits, w0_kernel, w1_kernel, wo_kernel): intermediate_layer, combine_mask, precision=matmul_precision, - ) + ).astype(self.dtype) return output, loss else: top_k_weights /= top_k_weights.sum(-1, keepdims=True) @@ -674,12 +673,16 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): with jax.named_scope("wi_0"): layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision - ).astype(jnp.float32) + ) + if self.config.activations_in_float32: + layer_w0 = layer_w0.astype(jnp.float32) layer_w0 = checkpoint_name(layer_w0, "mlpwi_0") with jax.named_scope("wi_1"): layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision - ).astype(jnp.float32) + ) + if self.config.activations_in_float32: + layer_w1 = layer_w1.astype(jnp.float32) layer_w1 = checkpoint_name(layer_w1, "mlpwi_1") layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype) @@ -687,12 +690,14 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): intermediate_layer = self.get_einsum(rhs_mesh_axes=self.wo_kernel_axes)( "BSEH,EHM -> BSEM", layer_multiply, wo_kernel, precision=matmul_precision ) + if self.config.activations_in_float32: + intermediate_layer = intermediate_layer.astype(jnp.float32) intermediate_layer = checkpoint_name(intermediate_layer, "mlpwo") with jax.named_scope("w_sum"): output = jnp.einsum( "BSEM,BSE -> BSM", - intermediate_layer.astype(jnp.float32), - weights.astype(jnp.float32), + intermediate_layer, + weights, ).astype(self.dtype) return output, None diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index 604ccc7305..9b198c5947 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -78,7 +78,7 @@ def __call__( cfg = self.config mesh = self.mesh - inputs = 
nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) inputs = checkpoint_name(inputs, "decoder_layer_input") lnx_rms = models.RMSNorm( dtype=cfg.dtype, @@ -89,7 +89,7 @@ def __call__( ) lnx = lnx_rms(inputs) - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed")) # Self-attention block attention_layer = Attention( @@ -124,7 +124,9 @@ def __call__( model_mode=model_mode, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) + attention_lnx = nn.with_logical_constraint( + attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed") + ) intermediate_inputs = inputs + attention_lnx # Fully Connected @@ -135,7 +137,9 @@ def __call__( kernel_axes=("norm",), epsilon=cfg.normalization_layer_epsilon, )(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) # MLP block. 
mlp_lnx = linears.MlpBlock( @@ -148,7 +152,7 @@ def __call__( config=cfg, quant=self.quant, )(hidden_states, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs @@ -156,7 +160,7 @@ def __call__( layer_output = nn.with_logical_constraint( layer_output, - ("activation_batch", "activation_length", "activation_embed"), + ("activation_batch", "activation_norm_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index 6208efdd91..5fbb9e1e12 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -70,7 +70,7 @@ def __call__( cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) inputs = checkpoint_name(inputs, "decoder_layer_input") lnx_rms = models.RMSNorm( dtype=cfg.dtype, @@ -81,7 +81,7 @@ def __call__( ) lnx = lnx_rms(inputs) - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed")) # Self-attention block attention_layer = Attention( @@ -110,7 +110,9 @@ def __call__( model_mode=model_mode, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) + attention_lnx = nn.with_logical_constraint( + attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed") + ) intermediate_inputs = inputs + attention_lnx # Fully Connected @@ -121,7 +123,9 @@ def __call__( kernel_axes=("norm",), 
epsilon=cfg.normalization_layer_epsilon, )(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) load_balance_loss = None if cfg.num_experts > 1: @@ -131,12 +135,12 @@ def __call__( num_experts_per_tok=cfg.num_experts_per_tok, mesh=mesh, kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), - kernel_axes=("embed", "mlp"), + kernel_axes=("embed", None), dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, quant=self.quant, )(hidden_states) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) else: mlp_lnx = linears.MlpBlock( intermediate_dim=cfg.mlp_dim, @@ -148,14 +152,14 @@ def __call__( config=cfg, quant=self.quant, )(hidden_states, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, - ("activation_batch", "activation_length", "activation_embed"), + ("activation_batch", "activation_norm_length", "activation_embed"), ) if cfg.num_experts > 1 and load_balance_loss is not None: diff --git a/MaxText/layers/normalizations.py b/MaxText/layers/normalizations.py index 862c586c9e..fcb8bf0e50 100644 --- a/MaxText/layers/normalizations.py +++ b/MaxText/layers/normalizations.py @@ -14,7 +14,7 @@ """Normalization Layers.""" -from typing import Any, Tuple 
+from typing import Any, Tuple, Optional from flax import linen as nn from jax import lax @@ -30,7 +30,7 @@ class RMSNorm(nn.Module): epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 - kernel_axes: Tuple[str, ...] = () + kernel_axes: Tuple[Optional[str], ...] = () scale_init: Initializer = nn.initializers.ones @nn.compact diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index 00cca229fa..8bfebf021d 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -153,6 +153,8 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper( mesh_axes, is_tiled, replicate_scale=self.replicate_scale ) + # module_path = "/".join(nn.module._context.module_stack[-1].path) + # print(f"quant_dg: {quant_dg}, is_tiled: {is_tiled}, module_path: {module_path}") aqt_dg_cls = functools.partial( aqt_flax.AqtDotGeneral, quant_dg, @@ -247,7 +249,9 @@ def _get_mixed_precision_quant_config(mixed_precision_config): ret_config = {} default_mp_config = _get_default_mp_config(default=mixed_precision_config.get(DEFAULT, None)) for layer_name_re, layer_quantization_config in mixed_precision_config.items(): - quant_config = default_mp_config + # Make a copy of default_mp_config to avoid updaing original dict + quant_config = default_mp_config.copy() + # print(f"Mixed precision config: processing {layer_name_re} - {layer_quantization_config}, default config - {quant_config}") if layer_name_re != DEFAULT: for k in quant_config.keys(): quant_config[k] = layer_quantization_config.get(k, default_mp_config[k]) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index b47bf29a41..04504f78e9 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -15,6 +15,7 @@ """ """ Common Max Utils needed by multiple modules""" +import shutil import numpy as np import jax import jax.numpy as jnp @@ -202,6 +203,12 @@ def 
parse_gcs_bucket_and_prefix(destination_gcs_name): return bucket, key +def add_trailing_slash(path): + if not path.endswith("/"): + return path + "/" + return path + + def upload_blob(destination_gcs_name, source_file_name): """Uploads a file to a GCS location""" bucket_name, prefix_name = parse_gcs_bucket_and_prefix(destination_gcs_name) @@ -211,6 +218,34 @@ def upload_blob(destination_gcs_name, source_file_name): blob.upload_from_filename(source_file_name) +def upload_dump(local_dir, target_dir, module_name=None, delete_local_after=True, all_host_upload=False): + """Uploads a directory to a GCS location, with an optional filter""" + if not all_host_upload and jax.process_index() != 0: + return + storage_client = storage.Client() + bucket_name, prefix_name = parse_gcs_bucket_and_prefix(target_dir) + bucket = storage_client.get_bucket(bucket_name) + if all_host_upload: + hostname = socket.gethostname() # Alternatively can use jax.process_id() + prefix_name = os.path.join(prefix_name, hostname) + target_dir = os.path.join(target_dir, hostname) + max_logging.log(f"Uploading HLO Dump to {target_dir}...") + for root, _, files in os.walk(local_dir): + for file in files: + if module_name and module_name not in file: + continue + else: + max_logging.log(f"Uploading {file}") + local_path = os.path.join(root, file) + relative_path = os.path.relpath(local_path, local_dir) + blob_name = os.path.join(prefix_name, relative_path) + blob = bucket.blob(blob_name) + blob.upload_from_filename(local_path) + max_logging.log(f"HLO Dump Uploaded to {target_dir}!") + if delete_local_after: + shutil.rmtree(local_dir) + + def maybe_initialize_jax_distributed_system(raw_keys): """The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of indirection in MaxText to avoid breaking the call sites unnecessarily. 
@@ -219,6 +254,12 @@ def maybe_initialize_jax_distributed_system(raw_keys): For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. """ + if raw_keys["skip_jax_distributed_system"]: + max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.") + return + if raw_keys["inference_benchmark_test"]: + # Disable initialization for inference benmark test. + return if raw_keys["compile_topology"]: # Don't initialize jax distributed with AOT compilation return @@ -422,10 +463,10 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ parallelism_vals[parallelism_vals.index(-1)] = int(determined_val) target_type = "slices" if parallelism_type == "DCN" else "devices per slice" - assert ( - np.prod(parallelism_vals) == target_product - ), f"Number of {target_type} {target_product} does not match\ - the product of the {parallelism_type} parallelism {np.prod(parallelism_vals)}" + assert np.prod(parallelism_vals) == target_product, ( + f"Number of {target_type} {target_product} does not match" + f" the product of the {parallelism_type} parallelism {np.prod(parallelism_vals)}" + ) return parallelism_vals @@ -531,18 +572,18 @@ def create_device_mesh(config, devices=None): if devices is None: devices = jax.devices() num_devices = len(devices) - num_slices = config.num_slices + num_slices = 1 if config.inference_benchmark_test else config.num_slices num_devices_per_slice = num_devices // num_slices multi_slice_env = num_slices > 1 # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism, num_devices_per_slice, "ICI") + ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False if multi_slice_env: - dcn_parallelism = fill_unspecified_mesh_axes(config.dcn_parallelism, num_slices, 
"DCN") + dcn_parallelism = fill_unspecified_mesh_axes(config.dcn_parallelism.copy(), num_slices, "DCN") if is_valid_custom_mesh(ici_parallelism, config.custom_mesh): mesh = create_custom_device_mesh(ici_parallelism, dcn_parallelism, devices, config.custom_mesh) else: @@ -1024,15 +1065,15 @@ def save_quantized_checkpoint_if_configured(config, params): def print_mem_stats(label: str): - print(f"\nMemstats: {label}:") + max_logging.log(f"\nMemstats: {label}:") try: for d in jax.local_devices(): stats = d.memory_stats() used = round(stats["bytes_in_use"] / 2**30, 2) limit = round(stats["bytes_limit"] / 2**30, 2) - print(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") + max_logging.log(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") except (RuntimeError, KeyError, TypeError) as ex: - print(f"\tMemstats unavailable, error: {ex}") + max_logging.log(f"\tMemstats unavailable, error: {ex}") def print_system_information(): diff --git a/MaxText/maxengine_server.py b/MaxText/maxengine_server.py index 586c2362a5..e45c3b6b0e 100644 --- a/MaxText/maxengine_server.py +++ b/MaxText/maxengine_server.py @@ -37,7 +37,7 @@ def main(config): # No devices for local cpu test. A None for prefill and a None for generate. 
devices = server_lib.get_devices() - server_config = maxengine_config.get_server_config("MaxtextInterleavedServer", config) + server_config = maxengine_config.get_server_config(config.inference_server, config) metrics_server_config: config_lib.MetricsServerConfig | None = None if config.prometheus_port != 0: diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 14abcd5689..64caf06138 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -239,13 +239,13 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance): """ total_num_params = max_utils.calculate_num_params_from_pytree(params) product_num_devices_for_weight_sharding = 1 - for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor", "stage", "expert"]: + for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "stage", "expert"]: product_num_devices_for_weight_sharding *= mesh.shape[axis] total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, ( "Number of parameters per chip must not be less than in the ideal sharded " - "scenario across `fsdp`, `fsdp_transpose`,`sequence`, `tensor`, `expert` axes." + "scenario across `fsdp`, `fsdp_transpose`,`sequence`, `tensor`, `tensor_sequence`, `expert` axes." 
) unsharded_param_perc = total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 assert unsharded_param_perc < tolerance, ( diff --git a/MaxText/profiler.py b/MaxText/profiler.py index 9a430f2bc1..207f516b17 100644 --- a/MaxText/profiler.py +++ b/MaxText/profiler.py @@ -28,18 +28,27 @@ class Profiler: """Activate/deactivate a profiler based on the 'profiler' config""" - def __init__(self, config, optional_postfix=""): + def __init__(self, config, offset_step=0): self.libcudart = None self.mode = config.profiler - self.upload_all_profiler_results = config.upload_all_profiler_results if self.mode != "": - self.output_path = os.path.join(config.tensorboard_dir, optional_postfix) + self.base_output_dir = config.tensorboard_dir + self.output_path = "" + self.upload_all_profiler_results = config.upload_all_profiler_results + self.profile_cleanly = config.profile_cleanly + self.profile_period = config.profile_periodically_period + self.start_initial_profile_step = self._set_first_profiler_step(config.skip_first_n_steps_for_profiler, offset_step) + self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps) - def activate(self): + def activate(self, blocking_object=None, optional_postfix=""): """Start the profiler. nsys profiler becomes no-op when libcudart.so is not available on the system""" + if self.profile_cleanly and blocking_object is not None: + jax.block_until_ready(blocking_object) if not (self.upload_all_profiler_results or jax.process_index() == 0): return + if self.mode != "": + self.output_path = os.path.join(self.base_output_dir, optional_postfix) if self.mode == "nsys": try: self.libcudart = cdll.LoadLibrary("libcudart.so") @@ -50,9 +59,11 @@ def activate(self): elif self.mode == "xplane": jax.profiler.start_trace(self.output_path) - def deactivate(self): + def deactivate(self, blocking_object=None): """End the profiler. 
The result is uploaded to the output bucket""" + if self.profile_cleanly and blocking_object is not None: + jax.block_until_ready(blocking_object) if not (self.upload_all_profiler_results or jax.process_index() == 0): return if self.mode == "nsys": @@ -68,3 +79,15 @@ def deactivate(self): max_logging.log("WARNING: gsutil is not installed or not found in the system's PATH. Skipping upload...") elif self.mode == "xplane": jax.profiler.stop_trace() + + def _set_first_profiler_step(self, skip_steps, start_step): + return start_step + skip_steps + + def _set_last_profiler_step(self, profiler_steps, last_job_step): + return min(self.start_initial_profile_step + profiler_steps - 1, last_job_step - 1) + + def should_activate_periodic_profile(self, step): + return self.profile_period > 0 and (step - self.start_initial_profile_step) % self.profile_period == 0 + + def should_deactivate_periodic_profile(self, step): + return self.profile_period > 0 and (step - self.finished_initial_profile_step) % self.profile_period == 0 diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 5b203c9129..fd0d177567 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -85,6 +85,17 @@ def validate_profiler_type(s: str) -> None: raise ValueError("Invalid profiler type was passed. Valid options ", valid_profiler_types) +def validate_periodic_profiler(profiler, profile_periodically_period, profiler_steps): + if profile_periodically_period <= 0: + return + if not profiler: + raise ValueError("Periodic profiler requested but no profiler was set, set it via profiler=xplane or profiler=nsys") + if profile_periodically_period < profiler_steps: + raise ValueError( + f"You must set the profile_periodically_period {profile_periodically_period} at least as long profiler_steps {profiler_steps}." 
+ ) + + def validate_model_call_mode(s: str) -> None: valid_model_call_modes = ("", "inference") if s not in valid_model_call_modes: # currently supported attention @@ -94,7 +105,8 @@ def validate_model_call_mode(s: str) -> None: def validate_prefill_and_target_lengths(max_prefill_length: int, max_target_length: int) -> None: if max_prefill_length <= 0: raise ValueError(f"Invalid max_prefill_predict_length {max_prefill_length}, it should be a positive number") - if max_target_length <= max_prefill_length: + if max_target_length < max_prefill_length: + # valid max_target_length = max_prefill_length for existing logit checks raise ValueError( f"Invalid max_target_length {max_target_length}, this should be sum of " f"max_prefill_predict_length ({max_prefill_length}) and max output length expected." @@ -105,6 +117,7 @@ def validate_keys(keys): validate_attention_kernel(keys["attention"]) validate_attention_type(keys["attention_type"]) validate_profiler_type(keys["profiler"]) + validate_periodic_profiler(keys["profiler"], keys["profile_periodically_period"], keys["profiler_steps"]) validate_compute_axis_order(keys["compute_axis_order"]) validate_kv_quant_axis(keys["kv_quant_axis"], keys["quantize_kvcache"]) validate_model_call_mode(keys["model_call_mode"]) @@ -344,6 +357,9 @@ def __init__(self, argv: list[str], **kwargs): max_logging.log(f"Updating keys from model: {keys_from_model}") validate_no_keys_overwritten_twice(keys_from_env_and_command_line, keys_from_model) + # This must be invoked before initializing the backend + raw_keys = validate_and_set_hlo_dump_defaults(raw_keys) + # We initialize the jax distributed system here because it must be done before device backend is initialized. 
if raw_keys["jax_debug_log_modules"]: jax.config.update("jax_debug_log_modules", raw_keys["jax_debug_log_modules"]) @@ -403,6 +419,8 @@ def user_init(raw_keys): raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"] raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"] + # This is the first command that initializes the backend - it calls + # jax.devices() ( raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"], @@ -494,6 +512,7 @@ def create_parallelisms_list(raw_keys): raw_keys["ici_fsdp_transpose_parallelism"], raw_keys["ici_sequence_parallelism"], raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_sequence_parallelism"], raw_keys["ici_expert_parallelism"], raw_keys["ici_autoregressive_parallelism"], ] @@ -504,6 +523,7 @@ def create_parallelisms_list(raw_keys): raw_keys["dcn_fsdp_transpose_parallelism"], raw_keys["dcn_sequence_parallelism"], raw_keys["dcn_tensor_parallelism"], + raw_keys["dcn_tensor_sequence_parallelism"], raw_keys["dcn_expert_parallelism"], raw_keys["dcn_autoregressive_parallelism"], ] @@ -512,6 +532,26 @@ def create_parallelisms_list(raw_keys): return raw_keys +def validate_and_set_hlo_dump_defaults(raw_keys): + if not raw_keys["dump_hlo"]: + return raw_keys + if os.environ.get("XLA_FLAGS") and raw_keys["dump_hlo_xla_flags"]: + raise ValueError("You must set either XLA_FLAGS or dump_hlo_xla_flags to dump HLO, but not both.") + if not os.environ.get("XLA_FLAGS") and not raw_keys["dump_hlo_xla_flags"]: + raw_keys["dump_hlo_xla_flags"] = f"--xla_dump_to={raw_keys['dump_hlo_local_dir']} --xla_dump_large_constants" + if raw_keys["dump_hlo_module_name"]: + raw_keys["dump_hlo_xla_flags"] = ( + f"{raw_keys['dump_hlo_xla_flags']} --xla_dump_hlo_module_re={raw_keys['dump_hlo_module_name']}" + ) + if not raw_keys["dump_hlo_gcs_dir"]: + raw_keys["dump_hlo_gcs_dir"] = os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], "xla_dump") + else: + 
raw_keys["dump_hlo_gcs_dir"] = max_utils.add_trailing_slash(raw_keys["dump_hlo_gcs_dir"]) + if not os.environ.get("XLA_FLAGS"): + os.environ["XLA_FLAGS"] = raw_keys["dump_hlo_xla_flags"] + return raw_keys + + def validate_multiple_slices(raw_keys): if ( math.fabs( @@ -523,6 +563,7 @@ def validate_multiple_slices(raw_keys): raw_keys["dcn_fsdp_transpose_parallelism"], raw_keys["dcn_sequence_parallelism"], raw_keys["dcn_tensor_parallelism"], + raw_keys["dcn_tensor_sequence_parallelism"], raw_keys["dcn_expert_parallelism"], raw_keys["dcn_autoregressive_parallelism"], ] @@ -558,6 +599,7 @@ def pipeline_first_axis(raw_keys): raw_keys["ici_fsdp_transpose_parallelism"], raw_keys["ici_sequence_parallelism"], raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_sequence_parallelism"], raw_keys["ici_expert_parallelism"], raw_keys["ici_autoregressive_parallelism"], ] @@ -568,11 +610,24 @@ def pipeline_first_axis(raw_keys): raw_keys["dcn_fsdp_transpose_parallelism"], raw_keys["dcn_sequence_parallelism"], raw_keys["dcn_tensor_parallelism"], + raw_keys["dcn_tensor_sequence_parallelism"], raw_keys["dcn_expert_parallelism"], raw_keys["dcn_autoregressive_parallelism"], ] - mesh_axes = ["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"] - data_sharding = [["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]] + mesh_axes = [ + "stage", + "data", + "fsdp", + "fsdp_transpose", + "sequence", + "tensor", + "tensor_sequence", + "expert", + "autoregressive", + ] + data_sharding = [ + ["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "expert", "autoregressive"] + ] raw_keys["ici_parallelism"] = ici_parallelism raw_keys["dcn_parallelism"] = dcn_parallelism @@ -621,7 +676,12 @@ def validate_megablox_parallelism(raw_keys): using_sequence_parallelism(raw_keys) or using_pipeline_parallelism(raw_keys) or using_expert_parallelism(raw_keys) ): raise ValueError("Currently we 
only support Megablox with data and tensor parallelism.") - tensor_parallelism = raw_keys["ici_tensor_parallelism"] * raw_keys["dcn_tensor_parallelism"] + tensor_parallelism = ( + raw_keys["ici_tensor_parallelism"] + * raw_keys["dcn_tensor_parallelism"] + * raw_keys["ici_tensor_sequence_parallelism"] + * raw_keys["dcn_tensor_sequence_parallelism"] + ) if raw_keys["megablox"] and using_tensor_parallelism(raw_keys) and (raw_keys["emb_dim"] % tensor_parallelism): raise ValueError( f"The embedding dimension {raw_keys['emb_dim']} is not divisible by tensor parallelism setting {tensor_parallelism}." @@ -717,8 +777,10 @@ def calculate_global_batch_sizes( def get_num_target_devices(raw_keys): - compile_topology = accelerator_to_spec_map.get_system_characteristics(raw_keys.get("compile_topology", "")) - if compile_topology is not None: + # In AOT case compile_topology is set (e.g. is not the empty string), and we determine the + # number of devices from the compile_topology. In non-AOT settings we simply can use jax.devices(). 
+ if raw_keys.get("compile_topology"): + compile_topology = accelerator_to_spec_map.get_system_characteristics(raw_keys["compile_topology"]) devices_per_slice = compile_topology.devices_per_slice return int(devices_per_slice * raw_keys["compile_topology_num_slices"]) else: @@ -737,7 +799,12 @@ def using_pipeline_parallelism(raw_keys) -> bool: def using_tensor_parallelism(raw_keys) -> bool: - return int(raw_keys["ici_tensor_parallelism"]) > 1 or int(raw_keys["dcn_tensor_parallelism"]) > 1 + return ( + int(raw_keys["ici_tensor_parallelism"]) > 1 + or int(raw_keys["dcn_tensor_parallelism"]) > 1 + or int(raw_keys["ici_tensor_sequence_parallelism"]) > 1 + or int(raw_keys["dcn_tensor_sequence_parallelism"]) > 1 + ) def using_sequence_parallelism(raw_keys) -> bool: diff --git a/MaxText/pytest.ini b/MaxText/pytest.ini index fa8e8142be..fc6c896a99 100644 --- a/MaxText/pytest.ini +++ b/MaxText/pytest.ini @@ -3,11 +3,15 @@ testpaths = tests python_files = *_test.py -addopts = - -rf --import-mode=importlib +addopts = + -rf --import-mode=importlib --strict-markers --ignore=tests/profiler_test.py --ignore=tests/train_smoke_test.py --ignore=tests/train_int8_smoke_test.py --ignore=tests/train_gpu_smoke_test.py -markers = - tpu: marks tests to be run on TPU \ No newline at end of file +markers = + tpu_only: marks tests to be run on TPUs only + gpu_only: marks tests to be run on GPUs only + integration_test: tests exercising larger portions of the system, + including interactions with other systems like GCS, + e.g., end_to_end tests diff --git a/MaxText/scratch_code/mixtral-numerical-verification.ipynb b/MaxText/scratch_code/mixtral-numerical-verification.ipynb new file mode 100644 index 0000000000..0d55263051 --- /dev/null +++ b/MaxText/scratch_code/mixtral-numerical-verification.ipynb @@ -0,0 +1,249 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "bce1951a-8eef-4842-a70f-987b85a3240f", + "metadata": {}, + "outputs": [], + "source": [ + "# 
installation\n", + "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", + "!pip3 install tokenizers -U\n", + "!pip3 install transformers -U" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9769e847-d838-473d-8d32-1061b3e0f1c8", + "metadata": {}, + "outputs": [], + "source": [ + "# go to maxtext/MaxText for library import\n", + "\n", + "current_dir = %pwd\n", + "working_dir = current_dir.replace(\"scratch_code\", \"\") \n", + "%cd $working_dir" + ] + }, + { + "cell_type": "markdown", + "id": "f1c108fc-d739-471d-9c64-c08151845f06", + "metadata": {}, + "source": [ + "# one layer mixtral model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf8eee59-295e-41f4-8c09-d2177b410ddc", + "metadata": {}, + "outputs": [], + "source": [ + "import pyconfig\n", + "from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n", + "\n", + "pyconfig.initialize(\n", + " [None, \"configs/base.yml\"],\n", + " base_emb_dim=4096,\n", + " base_num_query_heads=32,\n", + " base_num_kv_heads=8,\n", + " base_mlp_dim=14336,\n", + " base_num_decoder_layers=1, # 1 layer for simplicity\n", + " head_dim=128,\n", + " mlp_activations=[\"silu\",\"linear\"],\n", + " vocab_size=32000,\n", + " enable_dropout=False,\n", + " logits_via_embedding=False,\n", + " normalization_layer_epsilon=1.0e-5,\n", + " num_experts=8,\n", + " num_experts_per_tok=2,\n", + " rope_max_timescale=1_000_000,\n", + " decoder_block=\"mistral\",\n", + " run_name=\"moe_test\",\n", + " enable_checkpointing=False,\n", + " dtype=\"bfloat16\",\n", + " weight_dtype=\"bfloat16\",\n", + " megablox=True, # or False\n", + " max_target_length=4,\n", + " max_prefill_predict_length=3,\n", + " per_device_batch_size=1,\n", + " capacity_factor=-1,\n", + " scan_layers=False,\n", + ")\n", + "config_maxtext = pyconfig.config\n", + "\n", + "config_hf = MixtralConfig(\n", + " vocab_size=config_maxtext.vocab_size,\n", + " 
hidden_size=config_maxtext.emb_dim,\n", + " intermediate_size=config_maxtext.mlp_dim,\n", + " num_hidden_layers=config_maxtext.num_decoder_layers, \n", + " num_attention_heads=config_maxtext.base_num_query_heads,\n", + " num_key_value_heads=config_maxtext.num_kv_heads,\n", + " rms_norm_eps=config_maxtext.normalization_layer_epsilon,\n", + " rope_theta=config_maxtext.rope_max_timescale,\n", + " attention_dropout=0.0,\n", + " num_experts_per_tok=config_maxtext.num_experts_per_tok,\n", + " num_local_experts=config_maxtext.num_experts,\n", + " tie_word_embeddings=config_maxtext.logits_via_embedding,\n", + " output_router_logits=False,\n", + " router_aux_loss_coef=0.001,\n", + " router_jitter_noise=0.0,\n", + " torch_dtype=\"bfloat16\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c94c857a-2efd-48f3-9669-aef926329cbd", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, set_seed\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from layers.models import Transformer\n", + "import max_utils\n", + "from jax.sharding import Mesh\n", + "\n", + "# ensure the same model initialization\n", + "set_seed(0)\n", + "\n", + "model_hf = AutoModelForCausalLM.from_config(config_hf)\n", + "\n", + "devices_array = max_utils.create_device_mesh(config_maxtext)\n", + "mesh = Mesh(devices_array, config_maxtext.mesh_axes)\n", + "prng_key = jax.random.PRNGKey(1234)\n", + "model_maxtext = Transformer(config=config_maxtext, mesh=mesh, quant=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "707df022-ec37-44b3-b203-5f938151c6ca", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "input_np = {\n", + " 'inputs': np.random.randint(0, config_maxtext.vocab_size, size=(int(config_maxtext.per_device_batch_size), config_maxtext.max_target_length)),\n", + " 'inputs_position': np.tile(np.arange(config_maxtext.max_target_length), 
(int(config_maxtext.per_device_batch_size), 1)),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "baca50fb-28f2-48b1-b4f5-0145ac6cfe38", + "metadata": {}, + "outputs": [], + "source": [ + "state_maxtext = model_maxtext.init({'params': prng_key, 'dropout': prng_key, 'aqt': prng_key},\n", + " jnp.array(input_np['inputs']),\n", + " jnp.array(input_np['inputs_position']),\n", + " enable_dropout=config_maxtext.enable_dropout,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74e8353b-b87a-4c5e-9a7c-138052249250", + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "from flax import linen as nn\n", + "\n", + "state_map = {\n", + " \"['params']['decoder']['decoder_norm']['scale'].value\": (\"model.norm.weight\", lambda x: x), \n", + " \"['params']['decoder']['layers_0']['MoeBlock_0']['gate']['kernel'].value\": (\"model.layers.0.block_sparse_moe.gate.weight\", lambda x: x.T),\n", + " \"['params']['decoder']['layers_0']['MoeBlock_0']['wi_0'].value\": (\"model.layers.0.block_sparse_moe.experts..w1.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n", + " \"['params']['decoder']['layers_0']['MoeBlock_0']['wi_1'].value\": (\"model.layers.0.block_sparse_moe.experts..w3.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n", + " \"['params']['decoder']['layers_0']['MoeBlock_0']['wo'].value\": (\"model.layers.0.block_sparse_moe.experts..w2.weight\", lambda *x: torch.stack(*x, dim=0).transpose(1,2)),\n", + " \"['params']['decoder']['layers_0']['post_self_attention_layer_norm']['scale'].value\": (\"model.layers.0.post_attention_layernorm.weight\", lambda x: x),\n", + " \"['params']['decoder']['layers_0']['pre_self_attention_layer_norm']['scale'].value\": (\"model.layers.0.input_layernorm.weight\", lambda x:x),\n", + " \"['params']['decoder']['layers_0']['self_attention']['key']['kernel'].value\": (\"model.layers.0.self_attn.k_proj.weight\", lambda 
x:x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim)),\n", + " \"['params']['decoder']['layers_0']['self_attention']['out']['kernel'].value\": (\"model.layers.0.self_attn.o_proj.weight\", lambda x:x.T.reshape(config_hf.num_attention_heads, config_maxtext.head_dim, config_hf.hidden_size)),\n", + " \"['params']['decoder']['layers_0']['self_attention']['query']['kernel'].value\": (\"model.layers.0.self_attn.q_proj.weight\", lambda x:x.T.reshape(config_hf.hidden_size, config_hf.num_attention_heads, config_maxtext.head_dim) / np.sqrt(config_maxtext.head_dim)),\n", + " \"['params']['decoder']['layers_0']['self_attention']['value']['kernel'].value\": (\"model.layers.0.self_attn.v_proj.weight\", lambda x:x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim)),\n", + " \"['params']['decoder']['logits_dense']['kernel'].value\": (\"lm_head.weight\", lambda x:x.T),\n", + " \"['params']['token_embedder']['embedding'].value\": (\"model.embed_tokens.weight\", lambda x:x),\n", + " }\n", + "\n", + "state_hf = model_hf.state_dict()\n", + "def map_fn(key_path, value):\n", + " key_path_str = jax.tree_util.keystr(key_path)\n", + " torch_key, transform_fn = state_map[key_path_str]\n", + " if \"\" in torch_key:\n", + " torch_tensors = [state_hf[torch_key.replace(\"\", str(i))] for i in range(config_hf.num_local_experts)]\n", + " else:\n", + " torch_tensors = state_hf[torch_key]\n", + " \n", + " torch_tensors = transform_fn(torch_tensors)\n", + "\n", + " assert value.shape == torch_tensors.shape, f\"{key_path_str}, {value.shape}, {torch_tensors.shape}\"\n", + " new_value = jnp.array(torch_tensors.to(torch.float32).numpy(), dtype=value.dtype)\n", + " if isinstance(value, nn.LogicallyPartitioned):\n", + " new_value = value.replace_boxed(new_value)\n", + " return new_value\n", + "\n", + "loaded_state_maxtext = jax.tree_util.tree_map_with_path(map_fn, state_maxtext)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "d1f88708-c3a6-4b95-bc51-94adfebdf2aa", + "metadata": {}, + "outputs": [], + "source": [ + "logits_hf = model_hf(torch.from_numpy(input_np['inputs'])).logits.detach()\n", + "\n", + "logits_maxtext = model_maxtext.apply(\n", + " loaded_state_maxtext,\n", + " input_np['inputs'],\n", + " input_np['inputs_position'],\n", + " enable_dropout=False,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1207375a-b92c-4a8c-975a-21f2f027d91e", + "metadata": {}, + "outputs": [], + "source": [ + "# currently, pass the following tests in both \"megablox=True\" & \"megablox=False capacity_factor=-1\"\n", + "\n", + "np.testing.assert_allclose(np.array(logits_maxtext), logits_hf.numpy(), rtol=1e-1, atol=1e-1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index f04ebe1904..2cf47169ef 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -118,7 +118,7 @@ def get_structured_data(self, dtype): return lnx, decoder_segment_ids, decoder_positions - @pytest.mark.tpu + @pytest.mark.tpu_only def test_autoregression(self): prefill_length = self.cfg.max_prefill_predict_length decode_total_length = self.cfg.max_target_length @@ -174,11 +174,11 @@ def test_autoregression(self): self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) self.assertTrue(jax.numpy.allclose(mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_model_mode_prefill_dtype_float32(self): 
self._test_model_mode_prefill_dtype(jnp.float32) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_model_mode_prefill_dtype_bfloat16(self): self._test_model_mode_prefill_dtype(jnp.bfloat16) @@ -224,15 +224,15 @@ def _test_model_mode_prefill_dtype(self, dtype): self.assertEqual(dtype, mha_prefill.dtype) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tpu_kernel_attention_mha(self): self.tpu_kernel_attention_helper(self.num_kv_heads) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tpu_kernel_attention_gqa(self): self.tpu_kernel_attention_helper(self.num_kv_heads // 2) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tpu_kernel_attention_mqa(self): self.tpu_kernel_attention_helper(1) @@ -309,7 +309,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_dot_product_cache_axis_order(self): all_axis_orders = [axis_order for axis_order in itertools.permutations(range(4))] for axis_order in random.choices(all_axis_orders, k=4): @@ -423,7 +423,7 @@ def _dot_product_attention( jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_dot_product_reshape_q(self): for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: self._dot_product_attention_reshape_q( diff --git a/MaxText/tests/decode_tests.py b/MaxText/tests/decode_tests.py new file mode 100644 index 0000000000..c86f47e6aa --- /dev/null +++ b/MaxText/tests/decode_tests.py @@ -0,0 +1,98 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Tests for decode with various configs""" +import os +import unittest +import pytest +from decode import main as decode_main +from absl.testing import absltest + + +class DecodeTests(unittest.TestCase): + """Tests decode with various configs""" + + CONFIGS = { + "base": [ # tests decode + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "ici_tensor_parallelism=4", + "max_target_length=128", + "per_device_batch_size=1", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + "int8": [ # tests decode with int8 quantization + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "ici_tensor_parallelism=4", + "max_target_length=128", + "per_device_batch_size=1", + "quantization=int8", + "quantize_kvcache=True", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + "pdb_lt_1": [ # tests decode with per_device_batch_size < 1 + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "ici_tensor_parallelism=4", + "max_target_length=128", + "per_device_batch_size=.25", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + } + + @pytest.mark.tpu_only + def test_tpu_base(self): + decode_main(DecodeTests.CONFIGS["base"]) + + @pytest.mark.gpu_only + def 
test_gpu_base(self): + decode_main(DecodeTests.CONFIGS["base"] + ["attention=dot_product"]) + + @pytest.mark.tpu_only + def test_tpu_int8(self): + decode_main(DecodeTests.CONFIGS["int8"]) + + @pytest.mark.gpu_only + def test_gpu_int8(self): + decode_main(DecodeTests.CONFIGS["int8"] + ["attention=dot_product"]) + + @pytest.mark.tpu_only + def test_tpu_pdb_lt_1(self): + decode_main(DecodeTests.CONFIGS["pdb_lt_1"]) + + @pytest.mark.gpu_only + def test_gpu_pdb_lt_1(self): + decode_main(DecodeTests.CONFIGS["pdb_lt_1"] + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py index b1f0bed521..fea40b9e09 100644 --- a/MaxText/tests/gpt3_test.py +++ b/MaxText/tests/gpt3_test.py @@ -85,7 +85,7 @@ def setUp(self): } self.model_vars = init_random_model_vars(self.model, self.rng, self.example_batch) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_logits_numerically(self): # ground truth values are calculated from paxml after loading above model_vars # note we expect all xents are the same except the padding one since: @@ -108,4 +108,7 @@ def test_logits_numerically(self): # Mask out paddings at the end of each example. 
per_example_xent = per_example_xent * (self.example_batch["targets_segmentation"] != 0) - self.assertTrue(jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03)) + self.assertTrue( + jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03), + msg=f"per_example_xent:\n{per_example_xent}\n\nper_example_xent_truth:\n{per_example_xent_truth}", + ) diff --git a/MaxText/tests/gradient_accumulation_test.py b/MaxText/tests/gradient_accumulation_test.py index e2730fe971..29c0ab0876 100644 --- a/MaxText/tests/gradient_accumulation_test.py +++ b/MaxText/tests/gradient_accumulation_test.py @@ -28,7 +28,7 @@ def generate_random_string(length=10): class GradientAccumulationTest(unittest.TestCase): - @pytest.mark.tpu + @pytest.mark.tpu_only def test_grad_accumulate_same_loss(self): random_suffix = generate_random_string() run_accumulate_metrics_file = f"/tmp/runner_grad_accumulate_{random_suffix}.txt" diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py index c28de3dcc6..43ceb82a15 100644 --- a/MaxText/tests/inference_microbenchmark_smoke_test.py +++ b/MaxText/tests/inference_microbenchmark_smoke_test.py @@ -23,7 +23,7 @@ class Inference_Microbenchmark(unittest.TestCase): - @pytest.mark.tpu + @pytest.mark.tpu_only def test(self): pyconfig.initialize( [ diff --git a/MaxText/tests/integration_tests/checkpoint_compatibility_test.py b/MaxText/tests/integration_tests/checkpoint_compatibility_test.py new file mode 100644 index 0000000000..0470d70938 --- /dev/null +++ b/MaxText/tests/integration_tests/checkpoint_compatibility_test.py @@ -0,0 +1,48 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Integraion tests for test_checkpointing.sh""" +from datetime import datetime +import subprocess +import pytest + + +def run_checkpoint_compatibility(attention_type): + """Tests checkpoint compatibility.""" + + run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + command = [ + "bash", + "end_to_end/test_checkpoint_compatibility.sh", + f"runner_{run_date}", # run_name + r"gs://runner-maxtext-logs", # output_path + r"gs://maxtext-dataset", # dataset_path + attention_type, + ] + + subprocess.run(command, check=True, cwd="..") + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +def test_autoselected_attention(): + run_checkpoint_compatibility("autoselected") + + +@pytest.mark.integration_test +@pytest.mark.gpu_only +def test_with_dot_product(): + run_checkpoint_compatibility("dot_product") diff --git a/MaxText/tests/integration_tests/checkpointing_test.py b/MaxText/tests/integration_tests/checkpointing_test.py new file mode 100644 index 0000000000..01db18cba2 --- /dev/null +++ b/MaxText/tests/integration_tests/checkpointing_test.py @@ -0,0 +1,51 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Integraion tests for test_checkpointing.sh""" +from datetime import datetime +import subprocess +import pytest + + +def run_checkpointing(attention_type): + """Tests grain checkpoint determinism.""" + + run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + + command = [ + "bash", + "end_to_end/test_checkpointing.sh", + f"runner_{run_date}", # run_name + r"gs://runner-maxtext-logs", # output_path + r"gs://maxtext-dataset", # dataset_path + "False", # collect_stack_trace + "grain", # dataset_type + attention_type, + ] + + subprocess.run(command, check=True, cwd="..") + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +def test_autoselected_attention(): + run_checkpointing("autoselected") + + +@pytest.mark.integration_test +@pytest.mark.gpu_only +def test_with_dot_product(): + run_checkpointing("dot_product") diff --git a/MaxText/tests/integration_tests/generate_param_only_checkpoint_test.py b/MaxText/tests/integration_tests/generate_param_only_checkpoint_test.py new file mode 100644 index 0000000000..6afb389c7b --- /dev/null +++ b/MaxText/tests/integration_tests/generate_param_only_checkpoint_test.py @@ -0,0 +1,53 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Integraion tests for test_generate_param_only_checkpoint.sh""" +from datetime import datetime +import subprocess +import pytest + + +def run_generate_param_only_checkpoint(attention_type, quantization): + """Tests generating a parameter-only checkpoint.""" + + run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + # fmt: off + command = [ + "bash", + "end_to_end/test_generate_param_only_checkpoint.sh", + "-r", f"runner_{run_date}", + "-o", r"gs://runner-maxtext-logs", + "-d", r"gs://maxtext-dataset", + "-i", "4", + "-a", attention_type, + "-q", quantization, + ] + + subprocess.run(command, check=True, cwd="..") + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +@pytest.mark.parametrize("quantization", [(""), ("int8")]) +def test_autoselected_attention(quantization): + run_generate_param_only_checkpoint("autoselected", quantization) + + +@pytest.mark.integration_test +@pytest.mark.gpu_only +@pytest.mark.parametrize("quantization", [(""), ("int8")]) +def test_with_dot_product(quantization): + run_generate_param_only_checkpoint("dot_product", quantization) diff --git a/MaxText/tests/integration_tests/shmap_collective_matmul_test.py b/MaxText/tests/integration_tests/shmap_collective_matmul_test.py new file mode 100644 index 0000000000..55658e414b --- /dev/null +++ b/MaxText/tests/integration_tests/shmap_collective_matmul_test.py @@ -0,0 +1,32 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Integraion test for pedagogical_examples/shmap_collective_matmul.py""" +import subprocess +import pytest + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +def test_shmap_collective_matmul_example(): + """Validate Pedagogical Example, Shmap_collective_matmul.""" + + command = [ + "python3", + "pedagogical_examples/shmap_collective_matmul.py", + ] + + subprocess.run(command, check=True, cwd="..") diff --git a/MaxText/tests/kernels_test.py b/MaxText/tests/kernels_test.py index 5ec2d1c17d..6313aa884a 100644 --- a/MaxText/tests/kernels_test.py +++ b/MaxText/tests/kernels_test.py @@ -38,7 +38,7 @@ class RaggedAttentionTest(unittest.TestCase): key = jax.random.key(0) k1, k2, k3 = jax.random.split(key, 3) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_ragged_mqa(self): q = jax.random.normal(self.k1, (self.batch_size, 1, self.head_dim), dtype=self.dtype) k = jax.random.normal(self.k2, (self.batch_size, self.max_target_length, self.head_dim), dtype=self.dtype) @@ -56,7 +56,7 @@ def test_ragged_mqa(self): msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2", ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_ragged_mha(self): q = jax.random.normal(self.k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype) k = jax.random.normal( @@ -79,7 +79,7 @@ def test_ragged_mha(self): msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2", ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_ragged_gqa(self): q = jax.random.normal(self.k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype) k = jax.random.normal( diff --git a/MaxText/tests/model_test.py b/MaxText/tests/model_test.py index 9791af93ba..ed1eecdb69 100644 --- a/MaxText/tests/model_test.py +++ b/MaxText/tests/model_test.py @@ -105,7 +105,7 @@ def test_logits_dtype_with_cast_to_fp32(self): def test_logits_dtype_without_cast(self): self._test_logits_cast_driver(cast_logits_to_fp32=False, 
expected_dtype=jnp.bfloat16) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_train_vs_prefill_and_autoregress(self): PREFILL_RANGE = MAX_PREFILL_PREDICT_LENGTH diff --git a/MaxText/tests/multihost_dataloading_test.py b/MaxText/tests/multihost_dataloading_test.py index ba289c040d..297d753708 100644 --- a/MaxText/tests/multihost_dataloading_test.py +++ b/MaxText/tests/multihost_dataloading_test.py @@ -62,7 +62,7 @@ def setUp(self): dataset = dataset.batch(batch_size) self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_batch_sharded_data_pipeline(self): first_batch = next(self.multihost_gen) sec_batch = next(self.multihost_gen) diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py index 193c677fa9..0e1e18d251 100644 --- a/MaxText/tests/pipeline_parallelism_test.py +++ b/MaxText/tests/pipeline_parallelism_test.py @@ -150,7 +150,7 @@ def regular_sequential_layers_dummy_loss( dummy_targets, ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_circular_minimum_microbatches_same_output_and_grad(self): # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches pyconfig.initialize( @@ -167,7 +167,7 @@ def test_circular_minimum_microbatches_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_circular_extra_microbatches_same_output_and_grad(self): # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches pyconfig.initialize( @@ -184,7 +184,7 @@ def test_circular_extra_microbatches_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_non_circular_same_output_and_grad(self): # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches pyconfig.initialize( @@ -201,7 +201,7 @@ def 
test_non_circular_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_full_train_circular(self): # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches train_main( @@ -231,7 +231,7 @@ def test_full_train_circular(self): ] ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_delay_activation_forwarding_same_output_and_grad(self): # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches pyconfig.initialize( @@ -249,7 +249,7 @@ def test_delay_activation_forwarding_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_full_train_non_circular(self): # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches train_main( diff --git a/MaxText/tests/profiler_test.py b/MaxText/tests/profiler_test.py index 04ddeff5ce..6eb0568f34 100644 --- a/MaxText/tests/profiler_test.py +++ b/MaxText/tests/profiler_test.py @@ -14,66 +14,86 @@ limitations under the License. """ -"""Profiler tests for TPUs.""" -import glob -import json -import os +"""Profiler tests.""" +import sys import unittest +import pytest -from tensorboard_plugin_profile.convert import raw_to_tool_data - - -class TpuJAXTest(unittest.TestCase): - """Test for profile collected with JAX.""" - - def _get_session_snapshot(self): - """Gets a session snapshot of current session. assume only one session.""" - profile_plugin_root = "tensorboard/plugins/profile" - # The session exists under a director whose name is time-dependent. 
- profile_session_glob = os.path.join(profile_plugin_root, "*", "*.xplane.pb") - return glob.glob(profile_session_glob) - - def test_xplane_is_present(self): - files = self._get_session_snapshot() - self.assertEqual(len(files), 1) - - def test_overview_page(self): - xspace_filenames = self._get_session_snapshot() - result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "overview_page^", {}) - result = json.loads(result) - run_environment = result[2] - self.assertEqual(run_environment["p"]["host_count"], "1") - self.assertRegex(run_environment["p"]["device_type"], "TPU.*") - - def test_op_profile(self): - xspace_filenames = self._get_session_snapshot() - result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "op_profile^", {}) - result = json.loads(result) - self.assertIn("byCategory", result) - self.assertIn("metrics", result["byCategory"]) - overall_metrics = result["byCategory"]["metrics"] - self.assertIn("flops", overall_metrics) - self.assertIn("bandwidthUtils", overall_metrics) - self.assertGreater(overall_metrics["flops"], 0) - - def test_device_trace_contains_threads(self): - xspace_filenames = self._get_session_snapshot() - result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "trace_viewer^", {}) - result = json.loads(result) - thread_names = [] - for event in result["traceEvents"]: - if "name" in event and event["name"] == "thread_name": - thread_names.append((event["args"]["name"])) - expected_threads = [ - "Framework Name Scope", - "Framework Ops", - "XLA Modules", - "XLA Ops", - "XLA TraceMe", - "Steps", - ] - # Ensure that thread_names contains at least all expected threads. 
- self.assertEqual(set(expected_threads) - set(thread_names), set()) +import profiler +import pyconfig + + +class ProfilerTest(unittest.TestCase): + """Test for profiler.""" + + # These periodic proilfer tests can run on any platform (cpu, gpu or tpu) + @pytest.mark.tpu_only + def test_periodic_profiler_third_period_starts(self): + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="test_periodic_profiler_starts_after_regular_profile", + profiler="xplane", + skip_first_n_steps_for_profiler=7, + profiler_steps=4, + profile_periodically_period=5, + ) + config = pyconfig.config + prof = profiler.Profiler(config, offset_step=2) + + step = 24 # 3 * 5 + 7 + 2: 3 periods of 5 after skipping initial 7 skip + 2 offset. + assert prof.should_activate_periodic_profile(step) + + @pytest.mark.tpu_only + def test_periodic_profiler_not_start_middle_period(self): + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="test_periodic_profiler_starts_after_regular_profile", + profiler="xplane", + skip_first_n_steps_for_profiler=7, + profiler_steps=4, + profile_periodically_period=5, + ) + config = pyconfig.config + prof = profiler.Profiler(config, offset_step=2) + + step = 25 # This corresponds to the middle of period 3 which started at step 24. 
+ assert not prof.should_activate_periodic_profile(step) + + @pytest.mark.tpu_only + def test_periodic_profiler_third_period_ends(self): + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="test_periodic_profiler_starts_after_regular_profile", + profiler="xplane", + skip_first_n_steps_for_profiler=7, + profiler_steps=4, + profile_periodically_period=5, + ) + config = pyconfig.config + prof = profiler.Profiler(config, offset_step=2) + + step = 27 # 3 * 5 + 4 + 7 + 2: 3 periods of 5, profile takes 4 steps + skipping initial 7 skip + 2 offset + assert prof.should_deactivate_periodic_profile(step) + + @pytest.mark.tpu_only + def test_periodic_profiler_third_period_middle_not_end(self): + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + enable_checkpointing=False, + run_name="test_periodic_profiler_starts_after_regular_profile", + profiler="xplane", + skip_first_n_steps_for_profiler=7, + profiler_steps=4, + profile_periodically_period=5, + ) + config = pyconfig.config + prof = profiler.Profiler(config, offset_step=2) + + step = 28 # Corresponds to 1 after the third period ended. 
+ assert not prof.should_deactivate_periodic_profile(step) if __name__ == "__main__": diff --git a/MaxText/tests/quantizations_test.py b/MaxText/tests/quantizations_test.py index 2218c34496..c399855bdf 100644 --- a/MaxText/tests/quantizations_test.py +++ b/MaxText/tests/quantizations_test.py @@ -27,6 +27,7 @@ from aqt.jax.v2 import calibration _QUERY_REGEX = ".*/query" +_VALUE_REGEX = ".*/value" class QuantTestModule(nn.Module): @@ -147,6 +148,11 @@ def test_mixed_precision_config_subchannel(self): self.assertEqual(quant_cfg.fwd.dg_quantizer.rhs.numerics.bits, 4) self.assertEqual(tile_size, 128) + quant_cfg, tile_size = quant.quant_dg[_VALUE_REGEX] + self.assertEqual(quant_cfg.fwd.dg_quantizer.lhs.numerics.bits, 8) + self.assertEqual(quant_cfg.fwd.dg_quantizer.rhs.numerics.bits, 4) + self.assertEqual(tile_size, -1) + def test_remove_quantized_params(self): _params = { "decoder": { diff --git a/MaxText/tests/simple_decoder_layer_test.py b/MaxText/tests/simple_decoder_layer_test.py index afa2d0aeb4..ba6fa7c3ca 100644 --- a/MaxText/tests/simple_decoder_layer_test.py +++ b/MaxText/tests/simple_decoder_layer_test.py @@ -18,7 +18,7 @@ class SimpleDecoderLayerTest(unittest.TestCase): - @pytest.mark.tpu + @pytest.mark.tpu_only def test_simple_decoder_layer(self): train_main( [ @@ -34,7 +34,7 @@ def test_simple_decoder_layer(self): ] ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_mlp_decoder_layer(self): train_main( [ diff --git a/MaxText/tests/standalone_dl_ckpt_test.py b/MaxText/tests/standalone_dl_ckpt_test.py index 1bd774946e..652b652e17 100644 --- a/MaxText/tests/standalone_dl_ckpt_test.py +++ b/MaxText/tests/standalone_dl_ckpt_test.py @@ -34,7 +34,7 @@ def _get_random_test_name(self, test_name): random_run_name = test_name + date_time + random_string return random_run_name - @pytest.mark.tpu + @pytest.mark.tpu_only def test_standalone_dataloader(self): random_run_name = self._get_random_test_name("standalone_dataloader") sdl_main( @@ -50,7 +50,7 @@ def 
test_standalone_dataloader(self): ) ) # need to pass relative path to tokenizer - @pytest.mark.tpu + @pytest.mark.tpu_only def test_standalone_checkpointer(self): random_run_name = self._get_random_test_name("standalone_checkpointer") # checkpoint at 50 diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py index c5222f0dea..a64888e428 100644 --- a/MaxText/tests/tokenizer_test.py +++ b/MaxText/tests/tokenizer_test.py @@ -58,12 +58,12 @@ def tearDownClass(cls): os.remove(cls.tokenizer_path) @pytest.mark.skip(reason="mohitkhatwani@ will fix this") - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tokenize(self): text = "This is a test" self.assertTrue(np.array_equal(self.source_tokenizer.encode(text).numpy(), self.test_tokenizer.encode(text).numpy())) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_detokenize(self): tokens = [66, 12, 10, 698] self.assertEqual(np.asarray(self.source_tokenizer.decode(tokens)), np.asarray(self.test_tokenizer.decode(tokens))) @@ -86,13 +86,13 @@ def setUpClass(cls): train_ds_builder = tfds.builder(dataset_name) cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tokenize(self): text = "This is a test" tokens = [2028, 374, 264, 1296] self.assertTrue(np.array_equal(self.source_tokenizer.encode(text), tokens)) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_detokenize(self): tokens = [2028, 374, 264, 1296] text = "This is a test" diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index 87977079c7..880a7a30c2 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -24,7 +24,7 @@ class TrainCompile(unittest.TestCase): """Tests for the Ahead of Time Compilation functionality, train_compile.py""" - @pytest.mark.tpu + @pytest.mark.tpu_only def test_save_compiled_v4(self): compiled_trainstep_file = "/tmp/test_compiled_v4.pickle" 
train_compile_main( @@ -40,7 +40,7 @@ def test_save_compiled_v4(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_save_compiled_v5e(self): compiled_trainstep_file = "/tmp/test_compiled_v5e.pickle" train_compile_main( @@ -79,7 +79,7 @@ def test_minimal_offloaded_v5e(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_save_compiled_v5p_two_slices(self): compiled_trainstep_file = "/tmp/test_compiled_v5p_two_slices.pickle" train_compile_main( @@ -97,7 +97,7 @@ def test_save_compiled_v5p_two_slices(self): # TODO (b/374764692) : Enable when v6e AOT test when stable Jax supports v6e AOT. @pytest.mark.skip(reason="Enable when downstream v6e AOT support reaches stable Jax.") - @pytest.mark.tpu + @pytest.mark.tpu_only def test_save_compiled_v6e(self): compiled_trainstep_file = "/tmp/test_compiled_v6e.pickle" train_compile_main( @@ -113,7 +113,7 @@ def test_save_compiled_v6e(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_sequence_parallelism(self): compiled_trainstep_file = "/tmp/test_compiled.pickle" train_compile_main( @@ -131,7 +131,7 @@ def test_sequence_parallelism(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_remat_save_dot_except_mlpwi(self): compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlpwi.pickle" train_compile_main( @@ -153,7 +153,7 @@ def test_remat_save_dot_except_mlpwi(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_remat_save_dot_except_mlp(self): compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlp.pickle" train_compile_main( @@ -175,7 +175,7 @@ def test_remat_save_dot_except_mlp(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_remat_save_qkv_proj(self): compiled_trainstep_file = "/tmp/test_remat_save_qkv_proj.pickle" train_compile_main( @@ -197,7 +197,7 @@ def test_remat_save_qkv_proj(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_remat_full(self): compiled_trainstep_file = "/tmp/test_remat_full.pickle" train_compile_main( @@ -219,7 
+219,7 @@ def test_remat_full(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_custom_64x4_mesh(self): compiled_trainstep_file = "/tmp/test_custom_64x4_mesh.pickle" train_compile_main( @@ -241,7 +241,7 @@ def test_custom_64x4_mesh(self): # TODO (b/376470419) : Enable when AOT test work with host offloading. @pytest.mark.skip(reason="Enable when AOT test work with host offloading.") - @pytest.mark.tpu + @pytest.mark.tpu_only def test_llama3_1_70b_opt_offload(self): compiled_trainstep_file = "/tmp/test_llama3_1_70b_opt_offload.pickle" train_compile_main( @@ -259,7 +259,7 @@ def test_llama3_1_70b_opt_offload(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_custom_32x8_mesh(self): compiled_trainstep_file = "/tmp/test_custom_32x8_mesh.pickle" train_compile_main( diff --git a/MaxText/tests/train_tests.py b/MaxText/tests/train_tests.py new file mode 100644 index 0000000000..d09e01d05a --- /dev/null +++ b/MaxText/tests/train_tests.py @@ -0,0 +1,184 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Tests for train.py with various configs""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class TrainTests(unittest.TestCase): + """Tests train.py with various configs""" + + CONFIGS = { + "base": [ # short test for train.py with TFDS c4 + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + "synthetic": [ # tests base config with synthtic dataset + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "dataset_type=synthetic", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + "pdb_lt_1": [ # tests base config with per_device_batch_size < 1 + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "per_device_batch_size=0.25", + "ici_tensor_parallelism=4", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + "int8": [ # tests base config with int8 + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "quantization=int8", + "steps=2", + "enable_checkpointing=False", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + "fp8": [ # tests base config with fp8 + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "quantization=fp8", + "steps=2", + "enable_checkpointing=False", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + "dropout": [ # tests base config with dropout + None, + "configs/base.yml", + 
r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "max_target_length=128", + "per_device_batch_size=1", + "dropout_rate=0.02", + r"tokenizer_path=../assets/tokenizer.llama2", + ], + "hf_input_pipeline": [ # test for train.py with TFDS c4, using HF input pipeline + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + "steps=2", + "enable_checkpointing=False", + "dataset_type=hf", + "hf_path=parquet", + r"hf_train_files=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet", + r"tokenizer_path=google-t5/t5-large", + ], + } + + @pytest.mark.tpu_only + def test_tpu_base(self): + train_main(TrainTests.CONFIGS["base"]) + + @pytest.mark.gpu_only + def test_gpu_base(self): + train_main(TrainTests.CONFIGS["base"] + ["attention=dot_product"]) + + @pytest.mark.tpu_only + def test_tpu_synthetic(self): + train_main(TrainTests.CONFIGS["synthetic"]) + + @pytest.mark.gpu_only + def test_gpu_synthetic(self): + train_main(TrainTests.CONFIGS["synthetic"] + ["attention=dot_product"]) + + @pytest.mark.tpu_only + def test_tpu_pdb_lt_1(self): + train_main(TrainTests.CONFIGS["pdb_lt_1"]) + + @pytest.mark.gpu_only + def test_gpu_pdb_lt_1(self): + train_main(TrainTests.CONFIGS["pdb_lt_1"] + ["attention=dot_product"]) + + @pytest.mark.tpu_only + def test_tpu_int8(self): + train_main(TrainTests.CONFIGS["int8"]) + + @pytest.mark.gpu_only + def test_gpu_int8(self): + train_main(TrainTests.CONFIGS["int8"] + ["attention=dot_product"]) + + @pytest.mark.tpu_only + def test_tpu_fp8(self): + train_main(TrainTests.CONFIGS["fp8"]) + + @pytest.mark.gpu_only + def test_gpu_fp8(self): + train_main(TrainTests.CONFIGS["fp8"] + ["attention=dot_product"]) + + @pytest.mark.tpu_only + def test_tpu_dropout(self): + train_main(TrainTests.CONFIGS["dropout"]) + + @pytest.mark.gpu_only + def test_gpu_dropout(self): + 
train_main(TrainTests.CONFIGS["dropout"] + ["attention=dot_product"]) + + @pytest.mark.tpu_only + def test_tpu_hf_input_pipeline(self): + train_main(TrainTests.CONFIGS["hf_input_pipeline"]) + + @pytest.mark.gpu_only + def test_gpu_hf_input_pipeline(self): + train_main(TrainTests.CONFIGS["hf_input_pipeline"] + ["attention=dot_product"]) + + @pytest.mark.gpu_only + def test_gpu_cudnn_flash_te(self): + cudnn_flash_te = [ # tests base config on GPU with flash attention""" + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "attention=cudnn_flash_te", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + train_main(cudnn_flash_te) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/train.py b/MaxText/train.py index 4443b10dde..b52e93ee32 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -648,16 +648,23 @@ def setup_mesh_and_model(config): tx = optimizers.get_optimizer(config, learning_rate_schedule) logger = checkpointing.setup_checkpoint_logger(config) if config.enable_emergency_checkpoint: - abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) - checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( - config.local_checkpoint_directory, - config.checkpoint_dir, - mesh, - abstract_state, - config.local_checkpoint_period, - config.checkpoint_period, - logger, - ) + if config.use_replicator_service: + checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager( + config.local_checkpoint_directory, + config.local_checkpoint_period, + mesh, + ) + else: + abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) + checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( + config.local_checkpoint_directory, + config.checkpoint_dir, + mesh, + 
abstract_state, + config.local_checkpoint_period, + config.checkpoint_period, + logger, + ) else: # TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend use_ocdbt = config.checkpoint_storage_use_ocdbt @@ -843,19 +850,19 @@ def train_loop(config, state=None): running_gcs_metrics = [] if config.gcs_metrics else None start_step = get_first_step(state) # this is the start_step for training - first_profiling_step = start_step + config.skip_first_n_steps_for_profiler + prof = profiler.Profiler(config, offset_step=start_step) + first_profiling_step = prof.start_initial_profile_step if config.profiler != "" and first_profiling_step >= config.steps: raise ValueError("Profiling requested but initial profiling step set past training final step") - last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1) + last_profiling_step = prof.finished_initial_profile_step example_batch = None last_step_completion = datetime.datetime.now() - prof = profiler.Profiler(config) + for step in np.arange(start_step, config.steps): - if step == first_profiling_step: - if config.profile_cleanly: - jax.block_until_ready(state) # Block until previous state finishes to start profile cleanly - prof.activate() + if step == first_profiling_step or prof.should_activate_periodic_profile(step): + optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else "" + prof.activate(blocking_object=state, optional_postfix=optional_postfix) with jax.profiler.StepTraceAnnotation("train", step_num=step): record_goodput(recorder, config, recorder.record_data_loading_start_time if recorder else None) @@ -877,7 +884,7 @@ def train_loop(config, state=None): if checkpoint_manager is not None: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config): - max_logging.log(f"saved a checkpoint at 
step {step}") + checkpointing.print_save_message(step, config.async_checkpointing) # Upon preemption, exit when and only when all ongoing saves are complete. if checkpoint_manager.reached_preemption(step): @@ -886,6 +893,16 @@ def train_loop(config, state=None): write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config) + if config.dump_hlo and step == start_step: + jax.block_until_ready(state) # Ensure compilation has finished. + max_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: assert eval_data_iterator cumulative_eval_metrics = { @@ -930,10 +947,11 @@ def train_loop(config, state=None): prof.deactivate() break - if step == last_profiling_step: - if config.profile_cleanly: - jax.block_until_ready(state) # Block until current state finishes to end profile cleanly - prof.deactivate() + if step == last_profiling_step or prof.should_deactivate_periodic_profile(step): + prof.deactivate(blocking_object=state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") if checkpoint_manager is not None: checkpoint_manager.wait_until_finished() @@ -941,6 +959,15 @@ def train_loop(config, state=None): max_utils.close_summary_writer(writer) record_goodput(recorder, config, recorder.record_job_end_time if recorder else None) clear_buffered_metrics() + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + compiled = p_train_step.lower(state, example_batch, nextrng).compile() + compiled_stats = compiled.memory_analysis() + max_logging.log( + f"Output size: {compiled_stats.output_size_in_bytes}, " + f"temp size: {compiled_stats.temp_size_in_bytes}, " + f"argument size: {compiled_stats.argument_size_in_bytes}, " + f"host temp size: 
{compiled_stats.host_temp_size_in_bytes}, in bytes." + ) return state @@ -953,8 +980,8 @@ def main(argv: Sequence[str]) -> None: if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" pyconfig.initialize(argv) - max_utils.print_system_information() config = pyconfig.config + max_utils.print_system_information() validate_train_config(config) os.environ["TFDS_DATA_DIR"] = config.dataset_path vertex_tensorboard_manager = VertexTensorboardManager() diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py index a6dc4682fc..08a33462f9 100644 --- a/MaxText/train_compile.py +++ b/MaxText/train_compile.py @@ -183,6 +183,16 @@ def main(argv: Sequence[str]) -> None: print(f"Cost analysis: {compiled.cost_analysis()}") print(f"Memory analysis: {compiled.memory_analysis()}") + # Dump HLO if requested + if config.dump_hlo: + max_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + if __name__ == "__main__": app.run(main) diff --git a/README.md b/README.md index e6edd0ffac..ba5ae950ce 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ --> -[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxtext/actions/workflows/UnitTests.yml) +[![Unit Tests](https://github.com/google/maxtext/actions/workflows/RunTests.yml/badge.svg)](https://github.com/google/maxtext/actions/workflows/RunTests.yml) # Overview @@ -49,7 +49,7 @@ Some extra helpful guides: * [Llama2](https://llama.meta.com/llama2/): a family of open-weights Large Language Model (LLM) by Meta. You can run decode and finetuning using [these instructions](getting_started/Run_Llama2.md). 
* [Mixtral](https://mistral.ai/news/mixtral-of-experts/): a family of open-weights sparse mixture-of-experts (MoE) model by Mistral AI. You can run decode and finetuning using [these instructions](end_to_end/tpu/mixtral/Run_Mixtral.md) -In addition to the getting started guides, there are always other MaxText capabilities that are being constantly being added! The full suite of end-to-end tests is in [end_to_end](end_to_end). We run them with a nightly cadence. They can be a good source for understanding MaxText Alternatively you can see the continuous [unit tests](.github/workflows/UnitTests.yml) which are run almost continuously. +In addition to the getting started guides, there are always other MaxText capabilities that are being constantly being added! The full suite of end-to-end tests is in [end_to_end](end_to_end). We run them with a nightly cadence. They can be a good source for understanding MaxText Alternatively you can see the continuous [unit tests](.github/workflows/RunTests.yml) which are run almost continuously. 
# Runtime Performance Results diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 909c362767..8e629fe555 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -94,6 +94,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'gemma2_9b_8192', 'gemma2_27b_8192', 'llama3_1_70b_129024', + 'llama3_1_8b_8192', + 'llama3_1_70b_8192', ], default='llama2_70b_4096', help=( diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index b3606ec8f0..952d803cab 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -412,6 +412,86 @@ class MaxTextModel: ), ) +llama3_1_8b_8192 = MaxTextModel( + model_name="llama3_1-8b-8192", + model_type="llama3.1-8b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 8192, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), +) + +llama3_1_70b_8192 = MaxTextModel( + model_name="llama3_1-70b-8192", + 
model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 8192, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), +) + llama3_1_70b_129024 = MaxTextModel( model_name="llama3_1-70b-129024", model_type="llama3.1-70b", @@ -619,6 +699,8 @@ class MaxTextModel: llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet llama3_1_405b_8192_fsdp_dcn, + llama3_1_8b_8192, + llama3_1_70b_8192, llama3_1_70b_129024, mixtral_8x7b_dropped, mixtral_8x7b_dropped_int8, diff --git a/benchmarks/xla_flags_library.py b/benchmarks/xla_flags_library.py index 705e838d1c..d18b2a178a 100644 --- a/benchmarks/xla_flags_library.py +++ b/benchmarks/xla_flags_library.py @@ -66,6 +66,16 @@ # " --xla_tpu_enable_offloading_scatter_to_sparsecore=true" ) +ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE = ( + " --xla_sc_disable_megacore_partitioning=true" + " --xla_tpu_enable_all_reduce_offload_tracing=true" + " --xla_tpu_use_tc_device_shape_on_sc=true" + " --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true" + " --xla_sc_enable_instruction_fusion=false" + " --xla_sc_disjoint_spmem=false" + " 
--2a886c8_chip_config_name=megachip_tccontrol" +) + # Better memory layout for all-reduce LAYOUT_FOR_ALL_REDUCE_SCATTER = ( " --xla_tpu_use_minor_sharding_for_major_trivial_input=true" diff --git a/constraints_gpu.txt b/constraints_gpu.txt index b45174399b..be61774edb 100644 --- a/constraints_gpu.txt +++ b/constraints_gpu.txt @@ -77,10 +77,10 @@ importlib_metadata==8.4.0 importlib_resources==6.4.5 iniconfig==2.0.0 isort==5.13.2 -jax==0.4.35 -jax-cuda12-pjrt==0.4.35 -jax-cuda12-plugin==0.4.35 -jaxlib==0.4.34 +jax==0.4.38 +jax-cuda12-pjrt==0.4.38 +jax-cuda12-plugin==0.4.38 +jaxlib==0.4.38 jaxtyping==0.2.34 Jinja2==3.1.4 jsonlines==4.0.0 @@ -108,16 +108,16 @@ networkx==3.4.2 ninja==1.11.1.1 nodeenv==1.9.1 numpy==1.26.4 -nvidia-cublas-cu12==12.6.3.3 +nvidia-cublas-cu12==12.6.4.1 nvidia-cuda-cupti-cu12==12.6.80 -nvidia-cuda-nvcc-cu12==12.6.77 +nvidia-cuda-nvcc-cu12==12.6.85 nvidia-cuda-runtime-cu12==12.6.77 -nvidia-cudnn-cu12==9.5.0.50 +nvidia-cudnn-cu12==9.6.0.74 nvidia-cufft-cu12==11.3.0.4 nvidia-cusolver-cu12==11.7.1.2 nvidia-cusparse-cu12==12.5.4.2 nvidia-nccl-cu12==2.23.4 -nvidia-nvjitlink-cu12==12.6.77 +nvidia-nvjitlink-cu12==12.6.85 oauthlib==3.2.2 opentelemetry-api==1.27.0 opt_einsum==3.4.0 @@ -196,7 +196,7 @@ tomli==2.0.2 tomlkit==0.13.2 toolz==1.0.0 tqdm==4.66.5 -transformer-engine==1.5.0+297459b +transformer-engine==1.13.0+e5edd6c transformers==4.46.0 typeguard==2.13.3 typing_extensions==4.12.2 diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index ce3a5d3c39..24c245cdef 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -31,6 +31,7 @@ set -e export LOCAL_IMAGE_NAME=maxtext_base_image +echo "Building to $LOCAL_IMAGE_NAME" # Use Docker BuildKit so we can cache pip packages. 
export DOCKER_BUILDKIT=1 @@ -90,7 +91,7 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then build_stable_stack else if [[ ${MODE} == "pinned" ]]; then - export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-05-07 + export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-12-04 else export BASEIMAGE=ghcr.io/nvidia/jax:base fi diff --git a/getting_started/First_run.md b/getting_started/First_run.md index ad34ce64e5..b1d9a99713 100644 --- a/getting_started/First_run.md +++ b/getting_started/First_run.md @@ -13,7 +13,7 @@ We recommend starting with a single host first and then moving to multihost. Local development is a convenient way to run MaxText on a single host. It doesn't scale to multiple hosts. -1. [Create and SSH to the single-host VM of your choice.](https://cloud.google.com/tpu/docs/users-guide-tpu-vm#creating_a_cloud_tpu_vm_with_gcloud) We recommend a `v4-8`. +1. [Create and SSH to the single-host VM of your choice.](https://cloud.google.com/tpu/docs/users-guide-tpu-vm#creating_a_cloud_tpu_vm_with_gcloud). You can use any available single host TPU, such as `v5litepod-8`, `v5p-8` or `v4-8`. 2. Clone MaxText onto that TPUVM. 3. Within the root directory of that `git` repo, install dependencies and pre-commit hook by running: ``` diff --git a/inference/maxengine_server/Dockerfile b/inference/maxengine_server/Dockerfile new file mode 100644 index 0000000000..b94ea499a1 --- /dev/null +++ b/inference/maxengine_server/Dockerfile @@ -0,0 +1,46 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Ubuntu:22.04 +# Use Ubuntu 22.04 from Docker Hub. +# https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04 +FROM ubuntu:22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV MAXTEXT_VERSION=main + +RUN apt -y update && apt install -y --no-install-recommends \ + ca-certificates \ + git \ + python3.10 \ + python3-pip + +RUN update-alternatives --install \ + /usr/bin/python3 python3 /usr/bin/python3.10 1 + +RUN git clone https://github.com/AI-Hypercomputer/maxtext.git && \ +git clone https://github.com/AI-Hypercomputer/JetStream.git + +RUN cd maxtext/ && \ +git checkout ${MAXTEXT_VERSION} && \ +bash setup.sh + +RUN cd /JetStream && \ +pip install -e . + +COPY maxengine_server_entrypoint.sh /usr/bin/ + +RUN chmod +x /usr/bin/maxengine_server_entrypoint.sh + +ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"] \ No newline at end of file diff --git a/inference/maxengine_server/README.md b/inference/maxengine_server/README.md new file mode 100644 index 0000000000..40fad64b51 --- /dev/null +++ b/inference/maxengine_server/README.md @@ -0,0 +1,14 @@ +## Build and upload Maxengine Server image + +These instructions are to build the Maxengine Server image, which calls an entrypoint script that invokes the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the MaxText framework. + +``` +docker build -t maxengine-server . 
+docker tag maxengine-server us-docker.pkg.dev/${PROJECT_ID}/jetstream/maxengine-server:latest +docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/maxengine-server:latest +``` + +If you would like to change the version of MaxText the image is built off of, change the `MAXTEXT_VERSION` environment variable: +``` +ENV MAXTEXT_VERSION= +``` \ No newline at end of file diff --git a/inference/maxengine_server/maxengine_server_entrypoint.sh b/inference/maxengine_server/maxengine_server_entrypoint.sh new file mode 100644 index 0000000000..726926ce05 --- /dev/null +++ b/inference/maxengine_server/maxengine_server_entrypoint.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +cd /maxtext +python3 MaxText/maxengine_server.py \ +MaxText/configs/base.yml $@ \ No newline at end of file diff --git a/maxtext_gpu_dependencies.Dockerfile b/maxtext_gpu_dependencies.Dockerfile index 47cd646eb7..b09cf5208f 100644 --- a/maxtext_gpu_dependencies.Dockerfile +++ b/maxtext_gpu_dependencies.Dockerfile @@ -22,7 +22,7 @@ RUN apt-get update && apt-get install -y google-cloud-sdk ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}" # Upgrade libcusprase to work with Jax -RUN apt-get update && apt-get install -y libcusparse-12-3 +RUN apt-get update && apt-get install -y libcusparse-12-6 ARG MODE ENV ENV_MODE=$MODE @@ -38,12 +38,13 @@ RUN mkdir -p /deps # Set the working directory in the container WORKDIR /deps -# Copy all files from local workspace into docker container -COPY . . -RUN ls . +# Copy setup files and dependency files separately for better caching +COPY setup.sh ./ +COPY constraints_gpu.txt requirements.txt requirements_with_jax_stable_stack.txt ./ +# Install dependencies - these steps are cached unless the copied files change RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}" RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE} - -WORKDIR /deps +# Now copy the remaining code (source files that may change frequently) +COPY . . diff --git a/maxtext_runner.Dockerfile b/maxtext_runner.Dockerfile index 665da004c3..6c3a3f0d85 100644 --- a/maxtext_runner.Dockerfile +++ b/maxtext_runner.Dockerfile @@ -1,3 +1,5 @@ +# syntax=docker.io/docker/dockerfile:1.7-labs + ARG BASEIMAGE=maxtext_base_image FROM $BASEIMAGE @@ -6,7 +8,11 @@ FROM $BASEIMAGE # Set the working directory in the container WORKDIR /deps -# Copy all files from local workspace into docker container -COPY . . +# Copy assets separately +COPY assets/ . +COPY MaxText/test_assets/ MaxText/. 
+ +# Copy all files except assets from local workspace into docker container +COPY --exclude=assets --exclude=MaxText/test_assets . . WORKDIR /deps diff --git a/maxtext_transformerengine_builder.Dockerfile b/maxtext_transformerengine_builder.Dockerfile deleted file mode 100644 index 22a66e960a..0000000000 --- a/maxtext_transformerengine_builder.Dockerfile +++ /dev/null @@ -1,11 +0,0 @@ -FROM ghcr.io/nvidia/jax:base - -WORKDIR /root -ENV NVTE_FRAMEWORK=jax - - -RUN git clone https://github.com/NVIDIA/TransformerEngine -WORKDIR /root/TransformerEngine -RUN git checkout 297459bd08e1b791ca7a2872cfa8582220477782 -RUN git submodule update --init --recursive -RUN python setup.py bdist_wheel diff --git a/requirements.txt b/requirements.txt index 269170c1d2..19aa0b8e80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ jax>=0.4.30 jaxlib>=0.4.30 -orbax-checkpoint>=0.5.12,<0.10.3 +orbax-checkpoint>=0.5.12 absl-py array-record aqtp diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt index 9c68d62769..a2d24c68ba 100644 --- a/requirements_with_jax_stable_stack.txt +++ b/requirements_with_jax_stable_stack.txt @@ -1,6 +1,7 @@ # Requirements for Building the MaxText Docker Image # These requirements are additional to the dependencies present in the JAX SS base image. 
absl-py +aqtp==0.8.2 datasets pylint pytest diff --git a/setup.sh b/setup.sh index 4b4ba090eb..9cd898a325 100644 --- a/setup.sh +++ b/setup.sh @@ -31,7 +31,7 @@ export DEBIAN_FRONTEND=noninteractive export NEEDRESTART_SUSPEND=1 export NEEDRESTART_MODE=l - +apt-get update && apt-get install -y sudo (sudo bash || bash) <<'EOF' apt update && \ apt install -y numactl lsb-release gnupg curl net-tools iproute2 procps lsof git ethtool && \ @@ -90,9 +90,9 @@ run_name_folder_path=$(pwd) # Install dependencies from requirements.txt cd $run_name_folder_path && pip install --upgrade pip if [[ "$MODE" == "pinned" ]]; then - pip3 install -U -r requirements.txt -c constraints_gpu.txt + pip3 install --no-cache-dir -U -r requirements.txt -c constraints_gpu.txt else - pip3 install -U -r requirements.txt + pip3 install --no-cache-dir -U -r requirements.txt fi # Uninstall existing jax, jaxlib and libtpu-nightly @@ -110,11 +110,10 @@ if [[ "$MODE" == "pinned" ]]; then echo "pinned mode is supported for GPU builds only." exit 1 fi - echo "Installing pinned jax, jaxlib for NVIDIA gpu." + echo "Installing Jax and Transformer Engine." pip3 install "jax[cuda12]" -c constraints_gpu.txt - pip3 install "transformer-engine==1.5.0+297459b" \ - --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ \ - -c constraints_gpu.txt + pip install transformer-engine[jax]==1.13.0 + elif [[ "$MODE" == "stable" || ! -v MODE ]]; then # Stable mode if [[ $DEVICE == "tpu" ]]; then