Commit 1200dd7

Add llama2 configs for GPU A3
1 parent f04ba76 commit 1200dd7

5 files changed: +215 additions, -0 deletions

MaxText/configs/a3/llama_2_7b/16vm.sh

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
echo "Running 16vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/16vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 16 --priority=high \
# --command "bash MaxText/configs/a3/llama_2_7b/16vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-16vm-$(date +%Y-%m-%d-%H-%M)"

# Override environment variables from command-line KEY=VALUE arguments
for ARGUMENT in "$@"; do
    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
    export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

# 16 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=16 ici_fsdp_parallelism=8 per_device_batch_size=6 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true
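For orientation on the 16vm config above: dcn_data_parallelism=16 data-parallel replicas across VMs times ici_fsdp_parallelism=8 FSDP shards within each VM spans 128 GPUs in total, assuming 8 GPUs per A3 VM. A minimal sketch of the resulting global batch arithmetic:

VMS=16; GPUS_PER_VM=8; PER_DEVICE_BATCH=6; SEQ_LEN=4096
GLOBAL_BATCH=$((VMS * GPUS_PER_VM * PER_DEVICE_BATCH))   # 768 sequences per step
TOKENS_PER_STEP=$((GLOBAL_BATCH * SEQ_LEN))              # 3,145,728 tokens per step
echo "global batch: ${GLOBAL_BATCH} sequences, ${TOKENS_PER_STEP} tokens/step"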

MaxText/configs/a3/llama_2_7b/1vm.sh

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
echo "Running 1vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/1vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 1 \
# --command "bash MaxText/configs/a3/llama_2_7b/1vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-1vm-$(date +%Y-%m-%d-%H-%M)"

# Override environment variables from command-line KEY=VALUE arguments
for ARGUMENT in "$@"; do
    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
    export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"


# 1 node, DATA_DP=1, ICI_FSDP=8
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=1 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=$OUTPUT_PATH enable_profiler=false
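The for-loop over "$@" in these scripts means any exported default can be overridden by passing KEY=VALUE arguments. A hypothetical invocation (bucket and run name are placeholders, not values from this commit):

bash MaxText/configs/a3/llama_2_7b/1vm.sh \
    OUTPUT_PATH=gs://my-experiments-bucket RUN_NAME=llama2-7b-1vm-smoke-test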

MaxText/configs/a3/llama_2_7b/2vm.sh

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
echo "Running 2vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/2vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 2 \
# --command "bash MaxText/configs/a3/llama_2_7b/2vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-2vm-$(date +%Y-%m-%d-%H-%M)"

# Override environment variables from command-line KEY=VALUE arguments
for ARGUMENT in "$@"; do
    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
    export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=67108864 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"


# 2 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=2 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true
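The XPK example in each header leaves CLUSTER_NAME, WORKLOAD_NAME, LOCAL_IMAGE_NAME, and DEVICE_TYPE to the caller. A sketch with plausible values (all four are assumptions, not values from this commit; A3 VMs carry 8x H100 80GB GPUs, for which xpk uses the h100-80gb-8 device type):

export CLUSTER_NAME=my-a3-cluster             # assumption: your GKE cluster name
export WORKLOAD_NAME=llama2-7b-2vm            # assumption: any unique workload id
export LOCAL_IMAGE_NAME=my-maxtext-gpu-image  # assumption: image already pushed to the registry
export DEVICE_TYPE=h100-80gb-8                # assumption: A3 = 8x H100 80GB per VM
python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
    --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
    --device-type ${DEVICE_TYPE} --num-slices 2 \
    --command "bash MaxText/configs/a3/llama_2_7b/2vm.sh"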

MaxText/configs/a3/llama_2_7b/4vm.sh

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
echo "Running 4vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/4vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 4 \
# --command "bash MaxText/configs/a3/llama_2_7b/4vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-4vm-$(date +%Y-%m-%d-%H-%M)"

# Override environment variables from command-line KEY=VALUE arguments
for ARGUMENT in "$@"; do
    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
    export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=536870912 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

# 4 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=4 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true
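XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 in these scripts caps how much of each GPU's memory the JAX client preallocates, leaving headroom for NCCL buffers and other CUDA allocations. A back-of-the-envelope check, assuming the 80 GB H100s on A3:

python3 -c 'total_gb = 80; frac = 0.85; print(f"JAX pool ~{total_gb*frac:.0f} GB, headroom ~{total_gb*(1-frac):.0f} GB per GPU")'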

MaxText/configs/a3/llama_2_7b/8vm.sh

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
echo "Running 8vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/8vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 8 \
# --command "bash MaxText/configs/a3/llama_2_7b/8vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-8vm-$(date +%Y-%m-%d-%H-%M)"

# Override environment variables from command-line KEY=VALUE arguments
for ARGUMENT in "$@"; do
    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
    export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

# 8 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=8 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true
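Across these five configs, the XLA collective combine thresholds are the main values that vary: the all-reduce threshold is 128 MiB at 1 VM, 64 MiB at 2 VMs, 512 MiB at 4 VMs, and 1 GiB at 8 and 16 VMs, while the all-gather threshold stays at 128 MiB throughout. A quick way to read the raw byte values, assuming coreutils numfmt is available:

for BYTES in 67108864 134217728 536870912 1073741824; do
    numfmt --to=iec "$BYTES"   # prints 64M, 128M, 512M, 1.0G
done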
