
Add Pallas GPU decode attention in Maxtext inference #1066

Open · wants to merge 7 commits into main
Conversation

@tohaowu (Collaborator) commented on Nov 26, 2024

Description

This PR enables the use of Pallas GPU attention kernels from jax.experimental.pallas.ops.gpu in MaxText GPU inference.

Previously, MaxText GPU decode did not support flash attention and explicitly raised a ValueError when attention_kernel was set to cudnn_flash_te in the decoder. This PR adds a pallas_gpu option that enables Pallas GPU attention for improved efficiency and speed.

This change makes the performance benefits of Pallas GPU attention available to MaxText GPU inference, allowing faster and more efficient decoding.
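
For context, the core of the change is a new branch in the decode-time attention kernel selection. The sketch below is hypothetical: the function and variable names are illustrative rather than copied from MaxText/layers/attentions.py; only the attention_kernel values and the imported module come from this PR.

    # Hypothetical sketch of the new kernel-selection branch; names are illustrative.
    def select_decode_attention_kernel(attention_kernel: str):
      if attention_kernel == "pallas_gpu":
        # New path added by this PR: Pallas GPU decode attention.
        from jax.experimental.pallas.ops.gpu import decode_attention as pallas_decode_attention
        return pallas_decode_attention.gqa
      if attention_kernel == "cudnn_flash_te":
        # Previous behavior: the decode path rejected this kernel outright.
        raise ValueError("cudnn_flash_te is not supported in decode")
      return None  # fall back to the default dot-product attention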

Specific implementation details:

  • Imported the necessary modules from jax.experimental.pallas.ops.gpu.
  • Replaced the existing attention calculation in the decode path with calls to pallas_decode_attention.gqa (see the sketch after this list).
  • Adjusted parameters within these functions to optimize performance.
  • Ensured compatibility with the decoder's attention calculations.
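
A minimal sketch of calling the Pallas GPU decode attention kernel is shown below. The shape conventions (one query token per sequence, a [batch, kv_len, num_kv_heads, head_dim] KV cache) and the omission of tuning keyword arguments are assumptions about jax.experimental.pallas.ops.gpu.decode_attention that may vary across JAX versions; this is illustrative, not the exact MaxText integration, and it requires an NVIDIA GPU with Triton support.

    import jax.numpy as jnp
    from jax.experimental.pallas.ops.gpu import decode_attention as pallas_decode_attention

    # Decode-time inputs: a single query token per sequence plus the full KV cache.
    # Sizes roughly follow llama2-7b (32 heads, head_dim 128); shapes are assumptions.
    batch, num_q_heads, num_kv_heads, head_dim, kv_len = 1, 32, 32, 128, 2048
    q = jnp.zeros((batch, num_q_heads, head_dim), dtype=jnp.bfloat16)
    k = jnp.zeros((batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    v = jnp.zeros((batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)

    # GQA decode kernel: query heads are grouped over the (possibly smaller) set of KV heads.
    out = pallas_decode_attention.gqa(q, k, v)
    print(out.shape)  # expected: (batch, num_q_heads, head_dim)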

Shortcomings and future improvements:

  • Further performance tuning might be required for specific use cases.
  • Explore the possibility of integrating other optimized attention kernels from jax.experimental.pallas.ops.gpu.

FIXES: b/366477266, b/375269239

Tests

Instructions to reproduce:

  • Run the decoder with the modified code using the following command:
    XLA_FLAGS="--xla_dump_to=/llama-2-1vm-2024-11-14-05-28/HLO_dumps/ --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=true --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization" XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 TF_FORCE_GPU_ALLOW_GROWTH=true CUDA_VISIBLE_DEVICES=0 python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=gs://tohaowu/maxtext/direct_generate_param_only_checkpoint_2024-11-14-05-21/checkpoints/0/items run_name=runner_decode_finetuned_2024-11-14-05-21 base_output_directory=gs://tohaowu/maxtext per_device_batch_size=1 model_name=llama2-7b ici_autoregressive_parallelism=1 max_prefill_predict_length=1024 max_target_length=2048 attention=pallas_gpu scan_layers=false hardware=gpu async_checkpointing=false

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.
