Add llama2 configs
michelle-yooh committed Apr 5, 2024
1 parent 264c823 commit b4c4139
Showing 5 changed files with 206 additions and 0 deletions.
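All five scripts share the same structure and essentially differ only in scale: --num-slices and dcn_data_parallelism track the number of A3 VMs, ici_fsdp_parallelism=8 matches the eight H100 GPUs inside each VM (device type h100-80gb-8), and the XLA combine thresholds are tuned per scale. A minimal sketch of that shared pattern as a single parameterized launcher (the NUM_VMS variable and the launch.sh name are illustrative only, not part of this commit):

#!/bin/bash
# Hypothetical wrapper, not part of this commit: one XPK slice per A3 VM,
# data parallelism across VMs over DCN, FSDP across the 8 GPUs within a VM.
NUM_VMS=${1:-1}                      # e.g. ./launch.sh 4
export CLUSTER_NAME=maxtext-a3-20n
export WORKLOAD_NAME=llama-2-${NUM_VMS}vm-$(date +%m-%d-%H-%M)
export LOCAL_IMAGE_NAME=yooh/maxtext-tcpx
python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} --workload ${WORKLOAD_NAME} \
  --docker-image=${LOCAL_IMAGE_NAME} --device-type=h100-80gb-8 \
  --num-slices=${NUM_VMS} --env-file=xpk/env_${NUM_VMS}.txt \
  --command "python MaxText/train.py MaxText/configs/base.yml hardware=gpu \
  dcn_data_parallelism=${NUM_VMS} ici_fsdp_parallelism=8 per_device_batch_size=4 \
  max_target_length=4096 model_name=llama2-7b dataset_type=synthetic steps=30"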
43 changes: 43 additions & 0 deletions MaxText/configs/a3/llama_2_7b/16vm.sh
@@ -0,0 +1,43 @@
#!/bin/bash

echo "Running 16vm.sh"
# Config for the Llama 2 7B model
# on 16 GPU VMs using XPK

export CLUSTER_NAME=maxtext-a3-20n
export WORKLOAD_NAME=llama-2-16vm-$(date +%m-%d-%H-%M)
export LOCAL_IMAGE_NAME=yooh/maxtext-tcpx
export DEVICE_TYPE=h100-80gb-8

# Build and upload image
bash docker_build_dependency_image.sh DEVICE=gpu
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${LOCAL_IMAGE_NAME}

# Write XLA flags as env file
cat << EOF > xpk/env_16.txt
XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_FUSED_ATTN=1
NCCL_DEBUG=VERSION
# TF_CPP_VMODULE=profile_guided_latency_estimator=10
# TF_CPP_MIN_LOG_LEVEL=0
# TF_CPP_MAX_LOG_LEVEL=100
XLA_FLAGS=--xla_dump_to=gs://runner-maxtext-logs/yooh/llama2-7b-$(date +%Y-%m-%d-%H-%M)/HLO_dumps/ \
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true \
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0 \
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728 \
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization
EOF
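The combine thresholds above, expressed in human-readable units (plain byte arithmetic, no extra configuration implied):

1073741824 bytes = 1024^3       = 1 GiB    (all-reduce combine threshold)
 134217728 bytes = 128 * 1024^2 = 128 MiB  (all-gather and reduce-scatter combine thresholds)

The smaller configs below use different values for some of these, e.g. 536870912 bytes = 512 MiB and 67108864 bytes = 64 MiB.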

# 16 nodes
python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} --workload ${WORKLOAD_NAME} \
--docker-image=${LOCAL_IMAGE_NAME} --device-type=${DEVICE_TYPE} --num-slices=16 --env-file=xpk/env_16.txt \
--command "python MaxText/train.py MaxText/configs/base.yml hardware=gpu \
steps=30 dcn_data_parallelism=16 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true"
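As a sanity check on the sharding flags, the global batch follows directly from the command above (assuming one 8-GPU slice per VM and pure data parallelism across VMs):

  devices         = 16 VMs * 8 GPUs = 128
  global batch    = 128 devices * per_device_batch_size 4 = 512 sequences per step
  tokens per step = 512 * max_target_length 4096 = 2,097,152 tokens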
41 changes: 41 additions & 0 deletions MaxText/configs/a3/llama_2_7b/1vm.sh
@@ -0,0 +1,41 @@
#!/bin/bash

echo "Running 1vm.sh"
# Config for the Llama 2 7B model
# on 1 GPU VM using XPK

export CLUSTER_NAME=maxtext-a3-20n
export WORKLOAD_NAME=llama-2-1vm-$(date +%m-%d-%H-%M)
export LOCAL_IMAGE_NAME=yooh/maxtext-tcpx
export DEVICE_TYPE=h100-80gb-8

# Build and upload image
bash docker_build_dependency_image.sh DEVICE=gpu
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${LOCAL_IMAGE_NAME}

# Write XLA flags as env file
cat << EOF > xpk/env_1.txt
XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_FUSED_ATTN=1
NCCL_DEBUG=VERSION
XLA_FLAGS=--xla_dump_to=gs://runner-maxtext-logs/yooh/llama2-7b-$(date +%Y-%m-%d-%H-%M)/HLO_dumps/ \
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true \
--xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions \
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 \
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization
EOF


# 1 node: dcn_data_parallelism=1, ici_fsdp_parallelism=8
python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} --workload ${WORKLOAD_NAME} \
--docker-image=${LOCAL_IMAGE_NAME} --device-type=${DEVICE_TYPE} --num-slices=1 --env-file=xpk/env_1.txt \
--command "python MaxText/train.py MaxText/configs/base.yml hardware=gpu \
steps=30 dcn_data_parallelism=1 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true "
42 changes: 42 additions & 0 deletions MaxText/configs/a3/llama_2_7b/2vm.sh
@@ -0,0 +1,42 @@
#!/bin/bash

echo "Running 2vm.sh"
# Config for the Llama 2 7B model
# on 2 GPU VMs using XPK

export CLUSTER_NAME=maxtext-a3-20n
export WORKLOAD_NAME=llama-2-2vm-$(date +%m-%d-%H-%M)
export LOCAL_IMAGE_NAME=yooh/maxtext-tcpx
export DEVICE_TYPE=h100-80gb-8

# Build and upload image
bash docker_build_dependency_image.sh DEVICE=gpu
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${LOCAL_IMAGE_NAME}

# Write XLA flags as env file
cat << EOF > xpk/env_2.txt
XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_FUSED_ATTN=1
NCCL_DEBUG=VERSION
XLA_FLAGS=--xla_dump_to=gs://runner-maxtext-logs/yooh/llama2-7b-$(date +%Y-%m-%d-%H-%M)/HLO_dumps/ \
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true \
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0 \
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=67108864 --xla_gpu_all_gather_combine_threshold_bytes=134217728 \
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization
EOF


# 2 nodes
python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} --workload ${WORKLOAD_NAME} \
--docker-image=${LOCAL_IMAGE_NAME} --device-type=${DEVICE_TYPE} --num-slices=2 --env-file=xpk/env_2.txt \
--command "python MaxText/train.py MaxText/configs/base.yml hardware=gpu \
steps=30 dcn_data_parallelism=2 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true"

40 changes: 40 additions & 0 deletions MaxText/configs/a3/llama_2_7b/4vm.sh
@@ -0,0 +1,40 @@
#!/bin/bash

echo "Running 4vm.sh"
# Config for the Llama 2 7B model
# on 4 GPU VMs using XPK

export CLUSTER_NAME=maxtext-a3-20n
export WORKLOAD_NAME=llama-2-4vm-$(date +%m-%d-%H-%M)
export LOCAL_IMAGE_NAME=yooh/maxtext-tcpx
export DEVICE_TYPE=h100-80gb-8

# Build and upload image
bash docker_build_dependency_image.sh DEVICE=gpu
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${LOCAL_IMAGE_NAME}

# Write XLA flags as env file
cat << EOF > xpk/env_4.txt
XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_FUSED_ATTN=1
NCCL_DEBUG=VERSION
XLA_FLAGS=--xla_dump_to=gs://runner-maxtext-logs/yooh/llama2-7b-$(date +%Y-%m-%d-%H-%M)/HLO_dumps/ \
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true \
--xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions \
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=536870912 --xla_gpu_all_gather_combine_threshold_bytes=134217728 \
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization
EOF

# 4 nodes
python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} --workload ${WORKLOAD_NAME} \
--docker-image=${LOCAL_IMAGE_NAME} --device-type=${DEVICE_TYPE} --num-slices=4 --env-file=xpk/env_4.txt \
--command "python MaxText/train.py MaxText/configs/base.yml hardware=gpu \
steps=30 dcn_data_parallelism=4 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true"
40 changes: 40 additions & 0 deletions MaxText/configs/a3/llama_2_7b/8vm.sh
@@ -0,0 +1,40 @@
#!/bin/bash

echo "Running 8vm.sh"
# Config for the Llama 2 7B model
# on 8 GPU VMs using XPK

export CLUSTER_NAME=maxtext-a3-20n
export WORKLOAD_NAME=llama-2-8vm-$(date +%m-%d-%H-%M)
export LOCAL_IMAGE_NAME=yooh/maxtext-tcpx
export DEVICE_TYPE=h100-80gb-8

# Build and upload image
bash docker_build_dependency_image.sh DEVICE=gpu
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${LOCAL_IMAGE_NAME}

# Write XLA flags as env file
cat << EOF > xpk/env_8.txt
XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_FUSED_ATTN=1
NCCL_DEBUG=VERSION
XLA_FLAGS=--xla_dump_to=gs://runner-maxtext-logs/yooh/llama2-7b-$(date +%Y-%m-%d-%H-%M)/HLO_dumps/ \
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true \
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0 \
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728 \
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization
EOF

# 8 nodes
python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} --workload ${WORKLOAD_NAME} \
--docker-image=${LOCAL_IMAGE_NAME} --device-type=${DEVICE_TYPE} --num-slices=8 --env-file=xpk/env_8.txt \
--command "python MaxText/train.py MaxText/configs/base.yml hardware=gpu \
steps=30 dcn_data_parallelism=8 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true"
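Once submitted, any of these workloads can be inspected or torn down with the other xpk workload subcommands; a short sketch, assuming the standard xpk CLI and the cluster/workload names exported above:

# List workloads running on the cluster
python3 xpk/xpk.py workload list --cluster ${CLUSTER_NAME}
# Delete a finished or stuck workload
python3 xpk/xpk.py workload delete --cluster ${CLUSTER_NAME} --workload ${WORKLOAD_NAME}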
