limitations under the License.
-->

# Supervised Fine-Tuning (SFT) on Single-Host TPUs
Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks.
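
SFT keeps the usual next-token, cross-entropy objective but computes it only on the labeled target tokens (for chat data, the assistant responses rather than the prompts). As a purely illustrative sketch of that idea, not the actual Tunix or MaxText implementation, the loss could look like the following, where `logits`, `target_ids`, and `loss_mask` are assumed array names:

```python
# Illustrative SFT objective: masked next-token cross-entropy (sketch only).
# logits:     [batch, seq, vocab] model outputs
# target_ids: [batch, seq]        labels shifted for next-token prediction
# loss_mask:  [batch, seq]        1 on response tokens, 0 on prompt/padding tokens
import jax.numpy as jnp
import optax

def sft_loss(logits, target_ids, loss_mask):
  per_token = optax.softmax_cross_entropy_with_integer_labels(logits, target_ids)
  # Average over response tokens only, so the model learns the answers, not the prompts.
  return jnp.sum(per_token * loss_mask) / jnp.maximum(jnp.sum(loss_mask), 1.0)
```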

This tutorial provides step-by-step instructions for setting up the environment and then training a model on a Hugging Face dataset using SFT.

We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT.

In this tutorial we use a single-host TPU VM such as `v6e-8` or `v5p-8`. Let's get started!

## Install dependencies
```sh
# 1. Clone the repository
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext

# 2. Create virtual environment
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed $VENV_NAME
source $VENV_NAME/bin/activate

# 3. Install dependencies in editable mode
uv pip install -e .[tpu] --resolution=lowest
install_maxtext_github_deps
```
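
After the install finishes, you can optionally confirm that JAX was installed with TPU support and can see the accelerators on the VM. This is just a quick sanity check; the device count depends on the TPU type.

```python
# Run inside the activated virtual environment on the TPU VM.
import jax

print("JAX version:", jax.__version__)
print("Device count:", jax.device_count())  # number of TPU devices visible to JAX
print("Devices:", jax.devices())            # should list TpuDevice entries, not CPU
```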

## Set up environment variables
Set the following environment variables before running SFT.
```sh
# -- Model configuration --
export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b'
export PRE_TRAINED_MODEL_TOKENIZER=<tokenizer path> # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<Hugging Face access token>

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
export STEPS=<number of fine-tuning steps to run> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1

# -- Dataset configuration --
export DATASET_NAME=<Hugging Face dataset name> # e.g., HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=<data split for train> # e.g., train_sft
export TRAIN_DATA_COLUMNS=<data columns to train on> # e.g., ['messages']
```
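
If you are not sure which split or columns your dataset provides, you can peek at a single example with the Hugging Face `datasets` library before training. The snippet below uses the `HuggingFaceH4/ultrachat_200k` defaults from the examples above; adjust the dataset name and split for your own data.

```python
# Stream one example so the full dataset is not downloaded.
from datasets import load_dataset

ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft", streaming=True)
example = next(iter(ds))
print(list(example.keys()))    # available columns, e.g. ['prompt', 'prompt_id', 'messages']
print(example["messages"][0])  # chat turns look like {'content': '...', 'role': 'user'}
```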

## Get your model checkpoint
This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

### Option 1: Using an existing MaxText checkpoint
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```sh
export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

### Option 2: Converting a Hugging Face checkpoint
If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.

1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.

```sh
export PRE_TRAINED_MODEL_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint/0/items
```

2. **Run the Conversion Script:** Execute the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, please refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).

```sh
pip install torch # Ensure torch is installed for the conversion script

python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
  model_name=${PRE_TRAINED_MODEL} \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint \
  scan_layers=True
```
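
If you prefer to check the supported models programmatically, one option is to print the `HF_IDS` dictionary referenced above. This is a hedged example, run from the repository root inside your virtual environment, and it assumes `HF_IDS` maps MaxText model names to Hugging Face repo ids:

```python
# Print the model names the checkpoint-conversion utilities know about.
from MaxText.utils.ckpt_conversion.utils.utils import HF_IDS

for maxtext_name, hf_repo_id in HF_IDS.items():
    print(f"{maxtext_name} -> {hf_repo_id}")
```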

## Run SFT on a Hugging Face dataset
Now you are ready to run SFT using the following command:

```sh
python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \
  run_name=${RUN_NAME} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  model_name=${PRE_TRAINED_MODEL} \
  load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \
  hf_access_token=${HF_TOKEN} \
  tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  steps=${STEPS} \
  hf_path=${DATASET_NAME} \
  train_split=${TRAIN_SPLIT} \
  train_data_columns=${TRAIN_DATA_COLUMNS} \
  profiler=xplane
```
Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
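
To confirm that checkpoints were written, you can list that directory. Below is a minimal sketch using `gcsfs`, which is not part of the MaxText setup above; install it with `pip install gcsfs` if needed, and note that it assumes `BASE_OUTPUT_DIRECTORY` is a `gs://` path.

```python
# List the checkpoint steps saved by the SFT run.
import os
import gcsfs

ckpt_dir = f"{os.environ['BASE_OUTPUT_DIRECTORY']}/{os.environ['RUN_NAME']}/checkpoints"
fs = gcsfs.GCSFileSystem()
for entry in fs.ls(ckpt_dir):
    print(entry)  # one entry per saved checkpoint step
```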