Add Pallas GPU decode attention in Maxtext inference #1066
Description
This PR enables the use of Pallas GPU attention from `jax.experimental.pallas.ops.gpu` in MaxText GPU inference. Previously, MaxText GPU decode did not support flash attention and explicitly raised a ValueError when `attention_kernel` was set to `cudnn_flash_te` in the decoder. This PR adds a `pallas_gpu` option, allowing Pallas GPU attention to be used during decode for faster, more efficient inference.
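A minimal sketch of how the decode path can branch on `attention_kernel` (the dispatch function below is an illustrative assumption, not the exact MaxText code):

```python
# Illustrative sketch only: the surrounding dispatch structure is assumed,
# not copied from MaxText's AttentionOp.
from jax.experimental.pallas.ops.gpu import decode_attention as pallas_decode_attention


def gpu_decode_attention(query, key, value, attention_kernel: str):
    """Dispatch decode attention based on the configured kernel."""
    if attention_kernel == "pallas_gpu":
        # New path added by this PR: Pallas GQA decode kernel on GPU.
        return pallas_decode_attention.gqa(query, key, value)
    if attention_kernel == "cudnn_flash_te":
        # Previous behaviour: decode explicitly rejected flash attention.
        raise ValueError("Decode is not supported with cudnn_flash_te.")
    raise NotImplementedError(f"Unknown attention kernel: {attention_kernel}")
```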
Specific implementation details:
- Uses the `pallas_decode_attention.gqa` kernel from `jax.experimental.pallas.ops.gpu` (see the sketch below).
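For reference, a hedged example of calling the GQA decode kernel directly; the import alias mirrors the one referenced above, and the array layouts (one query step per batch element, full KV cache with grouped KV heads) are my reading of the kernel's expected shapes, not a documented contract:

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import decode_attention as pallas_decode_attention

batch, num_q_heads, num_kv_heads, head_dim, kv_len = 2, 32, 8, 128, 1024
q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)

# Assumed layouts: a single decode step of queries, and the full KV cache.
q = jax.random.normal(q_key, (batch, num_q_heads, head_dim), dtype=jnp.float16)
k = jax.random.normal(k_key, (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.float16)
v = jax.random.normal(v_key, (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.float16)

# Grouped-query attention: num_q_heads should be a multiple of num_kv_heads.
# Running the kernel requires an NVIDIA GPU (Pallas/Triton lowering).
out = pallas_decode_attention.gqa(q, k, v)  # expected shape: (batch, num_q_heads, head_dim)
```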
Shortcomings and future improvements:
- Depends on the experimental kernels under `jax.experimental.pallas.ops.gpu`.

FIXES: b/366477266, b/375269239
Tests
Instructions to reproduce:
XLA_FLAGS="--xla_dump_to=/llama-2-1vm-2024-11-14-05-28/HLO_dumps/ --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=true --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization" XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 TF_FORCE_GPU_ALLOW_GROWTH=true CUDA_VISIBLE_DEVICES=0 python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=gs://tohaowu/maxtext/direct_generate_param_only_checkpoint_2024-11-14-05-21/checkpoints/0/items run_name=runner_decode_finetuned_2024-11-14-05-21 base_output_directory=gs://tohaowu/maxtext per_device_batch_size=1 model_name=llama2-7b ici_autoregressive_parallelism=1 max_prefill_predict_length=1024 max_target_length=2048 attention=pallas_gpu scan_layers=false hardware=gpu async_checkpointing=false
Checklist
Before submitting this PR, please make sure (put X in square brackets):