
Commit 051eb1f

Improve SFT documentation
1 parent e3ddb1a

File tree

5 files changed (+144, -70 lines)


docs/index.md

Lines changed: 6 additions & 2 deletions
@@ -54,8 +54,12 @@ Our goal is to provide a variety of models (dimension “a”) and techniques (d
 
 Check out these getting started guides:
 
-* [SFT](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/sft.md) (Supervised Fine Tuning)
-* [GRPO](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/grpo.md) (Group Relative Policy Optimization)
+* Supervised Fine Tuning (SFT)
+  * [SFT on Single-Host TPUs](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/sft.md)
+  * [SFT on Multi-Host TPUs](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/sft_on_multi_host.md)
+* Group Relative Policy Optimization (GRPO)
+  * [GRPO on Single-Host TPUs](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/grpo.md)
+  * [GRPO on Multi-Host TPUs](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/grpo_with_pathways.md)
 
 ### Model library

docs/tutorials.md

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ You can also find other examples in the [MaxText repository](https://github.com/
 tutorials/first_run.md
 tutorials/pretraining.md
 tutorials/full_finetuning.md
+tutorials/how_to_run_colabs.md
 tutorials/grpo.md
 tutorials/sft.md
 tutorials/sft_on_multi_host.md

src/MaxText/examples/README_how_to_run_examples.md renamed to docs/tutorials/how_to_run_colabs.md

Lines changed: 2 additions & 2 deletions
@@ -122,8 +122,8 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla
 
 ### Supervised Fine-Tuning (SFT)
 
-- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B with Hugging Face ultrachat_200k dataset
-- **`sft_llama3_demo.ipynb`** → Llama3.1-8B with Hugging Face ultrachat_200k dataset
+- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k)
+- **`sft_llama3_demo.ipynb`** → Llama3.1-8B SFT training on the [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
 
 ### GRPO Training

docs/tutorials/sft.md

Lines changed: 70 additions & 52 deletions
@@ -14,81 +14,99 @@
 limitations under the License.
 -->
 
-# Try SFT
+# Supervised Fine-Tuning (SFT) on Single-Host TPUs
 Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks.
 
-This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 8B model on the [HuggingFaceH4/ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset using SFT. If you wish to use a different dataset, you can [update the dataset configurations](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/sft.yml).
+This tutorial provides step-by-step instructions for setting up the environment and then training a model on a Hugging Face dataset using SFT.
 
 We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT.
 
 In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get started!
 
-## Setup virtual environment
-
-### Create a Python3.12 virtual environment
+## Install dependencies
 ```sh
-bash tools/setup/setup.sh
+# 1. Clone the repository
+git clone https://github.com/AI-Hypercomputer/maxtext.git
+cd maxtext
+
+# 2. Create virtual environment
+export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
+pip install uv
+uv venv --python 3.12 --seed $VENV_NAME
+source $VENV_NAME/bin/activate
+
+# 3. Install dependencies in editable mode
+uv pip install -e .[tpu] --resolution=lowest
+install_maxtext_github_deps
 ```
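As an editorial aside (not part of the commit): after installation it can help to confirm that JAX sees the TPU chips before proceeding. A minimal sketch, assuming you are inside the activated virtual environment on the TPU VM:

```sh
# Should print the local TPU devices (e.g., 8 devices on a v6e-8).
python3 -c "import jax; print(jax.devices())"
```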
 
-### Activate virtual environment
-```
-# Replace with your virtual environment name if not using this default name
-venv_name="maxtext_venv"
-source ~/$venv_name/bin/activate
-```
-
-### Install MaxText dependencies
-```
-bash tools/setup/setup.sh
+## Set up environment variables
+Set the following environment variables before running SFT.
+```sh
+# -- Model configuration --
+export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b'
+export PRE_TRAINED_MODEL_TOKENIZER=<tokenizer path> # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
+export HF_TOKEN=<Hugging Face access token>
+
+# -- MaxText configuration --
+export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
+export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+export STEPS=<number of fine-tuning steps to run> # e.g., 1000
+export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1
+
+# -- Dataset configuration --
+export DATASET_NAME=<Hugging Face dataset name> # e.g., HuggingFaceH4/ultrachat_200k
+export TRAIN_SPLIT=<data split for train> # e.g., train_sft
+export TRAIN_DATA_COLUMNS=<data columns to train on> # e.g., ['messages']
 ```
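For illustration only (not part of the diff), the placeholders above might be filled in as follows, reusing the example values from the inline comments; the Hugging Face token and bucket path are hypothetical:

```sh
# Hypothetical example values for a Llama3.1-8B run on ultrachat_200k.
export PRE_TRAINED_MODEL='llama3.1-8b'
export PRE_TRAINED_MODEL_TOKENIZER='meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=hf_xxx                                    # your Hugging Face access token
export BASE_OUTPUT_DIRECTORY=gs://my-bucket/my-output-directory
export RUN_NAME=$(date +%Y-%m-%d-%H-%M-%S)
export STEPS=1000
export PER_DEVICE_BATCH_SIZE=1
export DATASET_NAME=HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=train_sft
export TRAIN_DATA_COLUMNS="['messages']"
```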
 
-## Run SFT
-There are two scenarios supported for running SFT:
-1. **Run SFT on Hugging Face checkpoint**
-   Download the checkpoint directly from Hugging Face and fine-tune it using SFT.
-
-2. **Run SFT on MaxText checkpoint**
-   Use a checkpoint generated by MaxText and fine-tune it using SFT.
-
-Choose the scenario that matches your workflow and follow the corresponding instructions below.
+## Get your model checkpoint
+This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
 
-### Run SFT on Hugging Face checkpoint
-* The script will first convert a Hugging Face checkpoint to a MaxText checkpoint.
-* It then runs SFT on this converted checkpoint.
+### Option 1: Using an existing MaxText checkpoint
+If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
 
-#### Setup environment variables
+```sh
+export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
 ```
-export HF_TOKEN=<Hugging Face access token>
 
-export BASE_OUTPUT_DIRECTORY=<output directory to store run logs>
+### Option 2: Converting a Hugging Face checkpoint
+If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
 
-export STEPS=<number of fine-tuning steps to run>
+1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.
 
-export PER_DEVICE_BATCH_SIZE=1
+   ```sh
+   export PRE_TRAINED_MODEL_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint/0/items
   ```
 
-Finally, run the [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh):
-```
-bash ~/maxtext/end_to_end/tpu/llama3.1/8b/run_sft.sh
-```
+2. **Run the Conversion Script:** Execute the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, please refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
 
-### Run SFT on MaxText checkpoint
-* The script directly runs SFT on MaxText checkpoint.
+   ```sh
+   pip install torch # Ensure torch is installed for the conversion script
 
-#### Setup environment variables
+   python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+     model_name=${PRE_TRAINED_MODEL} \
+     hf_access_token=${HF_TOKEN} \
+     base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint \
+     scan_layers=True
 ```
-export HF_TOKEN=<Hugging Face access token>
-
-export BASE_OUTPUT_DIRECTORY=<output directory to store run logs>
-
-export STEPS=<number of fine-tuning steps to run>
 
-export PER_DEVICE_BATCH_SIZE=1
+## Run SFT on a Hugging Face dataset
+Now you are ready to run SFT using the following command:
 
-export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint>
-```
-
-Finally, run the [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh):
-```
-bash ~/maxtext/end_to_end/tpu/llama3.1/8b/run_sft.sh
+```sh
+python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \
+  run_name=${RUN_NAME} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  model_name=${PRE_TRAINED_MODEL} \
+  load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \
+  hf_access_token=${HF_TOKEN} \
+  tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \
+  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
+  steps=${STEPS} \
+  hf_path=${DATASET_NAME} \
+  train_split=${TRAIN_SPLIT} \
+  train_data_columns=${TRAIN_DATA_COLUMNS} \
+  profiler=xplane
 ```
+Your fine-tuned model checkpoints will be saved to `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
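As a hedged aside (an editorial addition, not commit content): once a run finishes you could inspect the converted and fine-tuned checkpoints with `gsutil`, reusing the variables defined earlier:

```sh
# List the converted MaxText checkpoint (if you used Option 2).
gsutil ls ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint/

# List the checkpoints written during fine-tuning.
gsutil ls ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/
```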

docs/tutorials/sft_on_multi_host.md

Lines changed: 65 additions & 14 deletions
@@ -14,8 +14,14 @@
 # limitations under the License.
 -->
 
-# Supervised Fine-Tuning (SFT) with Deepseek-V3 model
-This guide provides step by step instructions to run SFT with Deepseek-V3 model on TPU v6e-256. Deepseek-V3 is a Mixture-of-Experts (MoE) language model with 671B parameters.
+# Supervised Fine-Tuning (SFT) on Multi-Host TPUs
+Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks.
+
+This tutorial provides step-by-step instructions for setting up the multi-host TPU environment and then training a model on a Hugging Face dataset using SFT. In this tutorial we use a multi-host TPU such as `v6e-256`.
+
+We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT.
+
+Let's get started!
 
 ## 1. Build and upload MaxText Docker image
 This section guides you through cloning the MaxText repository, building MaxText Docker image with dependencies, and uploading the docker image to your project's Artifact Registry.
@@ -28,7 +34,7 @@ cd maxtext
 
 ### 1.2. Build MaxText Docker image
 ```bash
-bash dependencies/scripts/docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
+bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
 ```
 This creates a local Docker image named `maxtext_base_image`.
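As an optional check (an editorial sketch, not part of the diff), you can confirm the image exists locally before tagging and uploading it:

```sh
# The build script should have produced a local image named maxtext_base_image.
docker images maxtext_base_image
```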
 
@@ -44,7 +50,7 @@ The `docker_upload_runner.sh` script uploads your Docker image to Artifact Regis
 Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation-via-pip).
 
 ## 3. Create GKE cluster
-If you don't already have a GKE cluster with a `v6e-256` TPU slice available, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create).
+If you don't already have a GKE cluster, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create).
 
 ## 4. Environment configuration
 ```bash
@@ -54,20 +60,63 @@ export CLUSTER_NAME=<Name of GKE Cluster>
 export ZONE=<GKE Cluster Zone>
 
 # -- Workload Configuration --
-export WORKLOAD_NAME="sft-$(date +%Y-%m-%d-%H-%M-%S)" # Or your desired workload name
-export TPU_TYPE=v6e-256
+export WORKLOAD_NAME=<Name of Workload> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+export TPU_TYPE=<TPU Type> # e.g., v6e-256
 export TPU_SLICE=1
 export DOCKER_IMAGE="gcr.io/${PROJECT}/${DOCKER_IMAGE_NAME}"
 
 # -- MaxText Configuration --
-export OUTPUT_PATH=<GCS Bucket Path for output/logs>
-export STEPS=100 # Number of fine-tuning steps to run
-export HF_TOKEN=<Hugging Face access token>
-export MODEL_CHECKPOINT_PATH=<GCS path to model checkpoint>
+export OUTPUT_PATH=<GCS Path for Output/Logs> # e.g., gs://my-bucket/my-output-directory
+export STEPS=<Fine-Tuning Steps> # e.g., 1000
+export HF_TOKEN=<Hugging Face Access Token>
+
+# -- Model Configuration --
+export MODEL_NAME=<Model Name> # e.g., deepseek3-671b
+export TOKENIZER_PATH=<Model Tokenizer> # e.g., deepseek-ai/DeepSeek-V3
+
+# -- Dataset configuration --
+export DATASET_NAME=<Hugging Face Dataset Name> # e.g., HuggingFaceH4/ultrachat_200k
+export TRAIN_SPLIT=<Data Split for Train> # e.g., train_sft
+export TRAIN_DATA_COLUMNS=<Data Columns to Train on> # e.g., ['messages']
+```
+
+## 5. Get MaxText model checkpoint
+This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
+
+### Option 1: Using an existing MaxText checkpoint
+If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
+
+```bash
+export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
 ```
 
-## 5. Submit workload on GKE cluster
-This section provides the command to run SFT with Deepseek-v3 model on a v6e-256 GKE cluster.
+### Option 2: Converting a Hugging Face checkpoint
+If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
+
+1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.
+
+   ```bash
+   export MODEL_CHECKPOINT_PATH=${OUTPUT_PATH}/${WORKLOAD_NAME}/maxtext-checkpoint/0/items
+   ```
+
+2. **Run the Conversion Script:** Execute the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, please refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
+
+   ```bash
+   xpk workload create \
+     --cluster=${CLUSTER_NAME} \
+     --project=${PROJECT} \
+     --zone=${ZONE} \
+     --docker-image=${DOCKER_IMAGE} \
+     --workload=ckpt-${WORKLOAD_NAME} \
+     --tpu-type=${TPU_TYPE} \
+     --num-slices=${TPU_SLICE} \
+     --command "python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml model_name=$MODEL_NAME hf_access_token=$HF_TOKEN base_output_directory=$OUTPUT_PATH/$WORKLOAD_NAME/maxtext-checkpoint scan_layers=True"
+   ```
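A brief illustrative follow-up (not part of the diff): after submitting the conversion workload you could watch its status with XPK and then confirm the checkpoint landed in GCS. The commands below only reuse variables already defined above; check `xpk workload list --help` if your XPK version expects different flags.

```sh
# Check the status of the checkpoint-conversion workload.
xpk workload list --cluster=${CLUSTER_NAME} --project=${PROJECT} --zone=${ZONE}

# Once it completes, the converted checkpoint should be visible here.
gsutil ls ${OUTPUT_PATH}/${WORKLOAD_NAME}/maxtext-checkpoint/
```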
+
+## 6. Submit workload on GKE cluster
+This section provides the command to run SFT on a GKE cluster.
+
+### 6.1. SFT with Multi-Controller JAX (McJAX)
 ```bash
 xpk workload create \
   --cluster=${CLUSTER_NAME} \
@@ -77,7 +126,9 @@ xpk workload create \
   --workload=${WORKLOAD_NAME} \
   --tpu-type=${TPU_TYPE} \
   --num-slices=${TPU_SLICE} \
-  --command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=deepseek3-671b load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=deepseek-ai/DeepSeek-V3 per_device_batch_size=1 steps=$STEPS profiler=xplane megablox=False sparse_matmul=False ici_expert_parallelism=16 ici_fsdp_parallelism=16 weight_dtype=bfloat16 dtype=bfloat16 remat_policy=full decoder_layer_input=offload sa_block_q=2048 sa_block_q_dkv=2048 sa_block_q_dq=2048 opt_type=sgd attention=flash capacity_factor=1.0 max_target_length=2048"
+  --command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane hf_path=$DATASET_NAME train_split=$TRAIN_SPLIT train_data_columns=$TRAIN_DATA_COLUMNS"
 ```
-Once the fine-tuning is completed, you can access your model checkpoint at `${OUTPUT_PATH}/${WORKLOAD_NAME}/checkpoints/${STEPS}/model_params`.
+Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
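For completeness (an editorial sketch, not commit content), the saved checkpoints can be listed the same way, and the finished workload can optionally be cleaned up; flag names follow the XPK documentation and may vary by version:

```sh
# List the fine-tuned checkpoints written by the workload.
gsutil ls ${OUTPUT_PATH}/${WORKLOAD_NAME}/checkpoints/

# Optionally clean up the finished workload.
xpk workload delete --workload=${WORKLOAD_NAME} --cluster=${CLUSTER_NAME} --project=${PROJECT} --zone=${ZONE}
```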
 
+### 6.2. SFT with Pathways
+Pathways support is coming soon.
