Skip to content

Commit 12ddad6

Browse files
committed
GRPO/GSPO
Signed-off-by: Vladimir Suvorov <suvorovv@google.com>
1 parent 8127aa3 commit 12ddad6

File tree

1 file changed

+110
-24
lines changed

1 file changed

+110
-24
lines changed

src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb

Lines changed: 110 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
"outputs": [],
3838
"source": [
3939
"# Clone MaxText repository\n",
40-
"!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
41-
"%cd maxtext"
40+
"!git clone https://github.com/AI-Hypercomputer/maxtext\n",
41+
"%cd maxtext/src"
4242
]
4343
},
4444
{
@@ -47,19 +47,57 @@
4747
"metadata": {},
4848
"outputs": [],
4949
"source": [
50-
"# Install dependencies\n",
51-
"!chmod +x setup.sh\n",
52-
"!./setup.sh\n",
50+
"!bash tools/setup/setup.sh\n",
51+
"%pip uninstall -y jax jaxlib libtpu\n",
5352
"\n",
54-
"# Install GRPO-specific dependencies\n",
55-
"!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n",
53+
"%pip install aiohttp==3.12.15\n",
54+
"\n",
55+
"# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.\n",
56+
"%pip install keyring keyrings.google-artifactregistry-auth\n",
57+
"\n",
58+
"# Install vLLM for Jax and TPUs from the artifact registry\n",
59+
"!VLLM_TARGET_DEVICE=\"tpu\" pip install --no-cache-dir --pre \\\n",
60+
" --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n",
61+
" --extra-index-url https://pypi.org/simple/ \\\n",
62+
" --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n",
63+
" --extra-index-url https://download.pytorch.org/whl/nightly/cpu \\\n",
64+
" --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n",
65+
" --find-links https://storage.googleapis.com/libtpu-wheels/index.html \\\n",
66+
" --find-links https://storage.googleapis.com/libtpu-releases/index.html \\\n",
67+
" --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \\\n",
68+
" --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \\\n",
69+
" vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu\n",
70+
"\n",
71+
"# Install tpu-commons from the artifact registry\n",
72+
"%pip install --no-cache-dir --pre \\\n",
73+
" --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n",
74+
" --extra-index-url https://pypi.org/simple/ \\\n",
75+
" --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n",
76+
" --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n",
77+
" tpu-commons==0.1.2\n",
78+
"\n",
79+
"%pip install numba==0.61.2"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"metadata": {},
86+
"outputs": [],
87+
"source": [
5688
"\n",
57-
"# Install additional requirements\n",
58-
"%pip install --force-reinstall numpy==2.1.2\n",
5989
"%pip install nest_asyncio\n",
6090
"\n",
6191
"import nest_asyncio\n",
62-
"nest_asyncio.apply() # Fix for Colab event loop"
92+
"nest_asyncio.apply() # Fix for Colab event loop\n",
93+
"\n",
94+
"%cd maxtext/src/\n",
95+
"\n",
96+
"#Fix nnx problems\n",
97+
"!pip uninstall flax \n",
98+
"!pip uninstall qwix\n",
99+
"!pip install flax \n",
100+
"!pip install qwix"
63101
]
64102
},
65103
{
@@ -97,19 +135,21 @@
97135
"source": [
98136
"# Configuration for GRPO training\n",
99137
"import os\n",
138+
"import MaxText\n",
100139
"\n",
101140
"# Set up paths (adjust if needed)\n",
102-
"MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n",
103-
"\n",
141+
"MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
142+
"RUN_NAME=\"grpo_test\"\n",
104143
"# Hardcoded defaults for Llama3.1-8B\n",
105144
"MODEL_NAME = \"llama3.1-8b\"\n",
106145
"HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
107-
"CHAT_TEMPLATE_PATH = \"src/MaxText/examples/chat_templates/gsm8k_rl.json\"\n",
146+
"CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n",
147+
"LOSS_ALGO=\"gspo-token\"\n",
108148
"\n",
109149
"# Required: Set these before running\n",
110-
"MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\" # Update this!\n",
111-
"OUTPUT_DIRECTORY = \"/tmp/grpo_output\" # Update this!\n",
112-
"HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\") # Set HF_TOKEN environment variable\n",
150+
"MODEL_CHECKPOINT_PATH = \"\" # Update this!\n",
151+
"OUTPUT_DIRECTORY = \"/tmp/gpo_output\" # Update this!\n",
152+
"HF_TOKEN = \"\" # Set HF_TOKEN environment variable\n",
113153
"\n",
114154
"# Optional: Override training parameters\n",
115155
"STEPS = 10 # Reduced for demo purposes\n",
@@ -118,14 +158,15 @@
118158
"NUM_GENERATIONS = 2\n",
119159
"GRPO_BETA = 0.08\n",
120160
"GRPO_EPSILON = 0.2\n",
121-
"CHIPS_PER_VM = 4\n",
161+
"CHIPS_PER_VM = 1\n",
122162
"\n",
123163
"print(f\"📁 MaxText Home: {MAXTEXT_REPO_ROOT}\")\n",
124164
"print(f\"🤖 Model: {MODEL_NAME}\")\n",
125165
"print(f\"📦 Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
126166
"print(f\"💾 Output: {OUTPUT_DIRECTORY}\")\n",
127167
"print(f\"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n",
128-
"print(f\"📊 Steps: {STEPS}\")"
168+
"print(f\"📊 Steps: {STEPS}\")\n",
169+
"print(f\"Loss Algorithm : {LOSS_ALGO}\")"
129170
]
130171
},
131172
{
@@ -140,7 +181,7 @@
140181
"from pathlib import Path\n",
141182
"\n",
142183
"# Add MaxText to Python path\n",
143-
"maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n",
184+
"maxtext_path = Path(MAXTEXT_REPO_ROOT) \n",
144185
"sys.path.insert(0, str(maxtext_path))\n",
145186
"\n",
146187
"from MaxText import pyconfig, max_utils\n",
@@ -163,6 +204,51 @@
163204
"print(f\"📁 MaxText path: {maxtext_path}\")"
164205
]
165206
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": null,
210+
"metadata": {},
211+
"outputs": [],
212+
"source": [
213+
"# Build configuration for GRPO training\n",
214+
"config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n",
215+
"\n",
216+
"# Verify chat template exists\n",
217+
"if not os.path.exists(os.path.join(MAXTEXT_REPO_ROOT, CHAT_TEMPLATE_PATH)):\n",
218+
" raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
219+
"\n",
220+
"# Build argv list for pyconfig.initialize()\n",
221+
"config_argv = [\n",
222+
" \"\", # argv[0] placeholder\n",
223+
" config_file,\n",
224+
" f\"model_name={MODEL_NAME}\",\n",
225+
" f\"tokenizer_path={HF_REPO_ID}\",\n",
226+
" f\"run_name={RUN_NAME}\",\n",
227+
" f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
228+
" f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n",
229+
" f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
230+
" f\"hf_access_token={HF_TOKEN}\",\n",
231+
" f\"steps={STEPS}\",\n",
232+
" f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n",
233+
" f\"learning_rate={LEARNING_RATE}\",\n",
234+
" f\"num_generations={NUM_GENERATIONS}\",\n",
235+
" f\"grpo_beta={GRPO_BETA}\",\n",
236+
" f\"grpo_epsilon={GRPO_EPSILON}\",\n",
237+
" f\"chips_per_vm={CHIPS_PER_VM}\",\n",
238+
" f\"loss_algo={LOSS_ALGO}\"\n",
239+
"]\n",
240+
"\n",
241+
"# Initialize configuration\n",
242+
"print(f\"🔧 Initializing configuration from: {config_file}\")\n",
243+
"config = pyconfig.initialize(config_argv)\n",
244+
"max_utils.print_system_information()\n",
245+
"\n",
246+
"print(\"\\n✅ Configuration initialized successfully\")\n",
247+
"print(f\"📊 Training steps: {config.steps}\")\n",
248+
"print(f\"📁 Output directory: {config.base_output_directory}\")\n",
249+
"print(f\"🤖 Model: {config.model_name}\")"
250+
]
251+
},
166252
{
167253
"cell_type": "code",
168254
"execution_count": null,
@@ -214,25 +300,25 @@
214300
"metadata": {},
215301
"outputs": [],
216302
"source": [
217-
"# Execute GRPO training\n",
303+
"# Execute GRPO/GSPO training\n",
218304
"print(\"\\n\" + \"=\"*80)\n",
219-
"print(\"🚀 Starting GRPO Training...\")\n",
305+
"print(\"🚀 Starting Training...\")\n",
220306
"print(\"=\"*80)\n",
221-
"\n",
307+
"print(1)\n",
222308
"try:\n",
223309
" # Call the rl_train function (it handles everything internally)\n",
224310
" rl_train(config)\n",
225311
" \n",
226312
" print(\"\\n\" + \"=\"*80)\n",
227-
" print(\"GRPO Training Completed Successfully!\")\n",
313+
" print(\"✅ Training Completed Successfully!\")\n",
228314
" print(\"=\"*80)\n",
229315
" print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n",
230316
" print(f\"📊 TensorBoard logs: {config.tensorboard_dir}\")\n",
231317
" print(f\"🎯 Model ready for inference!\")\n",
232318
" \n",
233319
"except Exception as e:\n",
234320
" print(\"\\n\" + \"=\"*80)\n",
235-
" print(\" GRPO Training Failed!\")\n",
321+
" print(\"❌Training Failed!\")\n",
236322
" print(\"=\"*80)\n",
237323
" print(f\"Error: {str(e)}\")\n",
238324
" import traceback\n",

0 commit comments

Comments
 (0)