|
37 | 37 | "outputs": [], |
38 | 38 | "source": [ |
39 | 39 | "# Clone MaxText repository\n", |
40 | | - "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n", |
41 | | - "%cd maxtext" |
| 40 | + "!git clone https://github.com/AI-Hypercomputer/maxtext\n", |
| 41 | + "%cd maxtext/src" |
42 | 42 | ] |
43 | 43 | }, |
44 | 44 | { |
|
47 | 47 | "metadata": {}, |
48 | 48 | "outputs": [], |
49 | 49 | "source": [ |
50 | | - "# Install dependencies\n", |
51 | | - "!chmod +x setup.sh\n", |
52 | | - "!./setup.sh\n", |
| 50 | + "!bash tools/setup/setup.sh\n", |
| 51 | + "%pip uninstall -y jax jaxlib libtpu\n", |
53 | 52 | "\n", |
54 | | - "# Install GRPO-specific dependencies\n", |
55 | | - "!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n", |
| 53 | + "%pip install aiohttp==3.12.15\n", |
| 54 | + "\n", |
| 55 | + "# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.\n", |
| 56 | + "%pip install keyring keyrings.google-artifactregistry-auth\n", |
| 57 | + "\n", |
| 58 | + "# Install vLLM for Jax and TPUs from the artifact registry\n", |
| 59 | + "!VLLM_TARGET_DEVICE=\"tpu\" pip install --no-cache-dir --pre \\\n", |
| 60 | + " --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n", |
| 61 | + " --extra-index-url https://pypi.org/simple/ \\\n", |
| 62 | + " --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n", |
| 63 | + " --extra-index-url https://download.pytorch.org/whl/nightly/cpu \\\n", |
| 64 | + " --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n", |
| 65 | + " --find-links https://storage.googleapis.com/libtpu-wheels/index.html \\\n", |
| 66 | + " --find-links https://storage.googleapis.com/libtpu-releases/index.html \\\n", |
| 67 | + " --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \\\n", |
| 68 | + " --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \\\n", |
| 69 | + " vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu\n", |
| 70 | + "\n", |
| 71 | + "# Install tpu-commons from the artifact registry\n", |
| 72 | + "%pip install --no-cache-dir --pre \\\n", |
| 73 | + " --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n", |
| 74 | + " --extra-index-url https://pypi.org/simple/ \\\n", |
| 75 | + " --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n", |
| 76 | + " --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n", |
| 77 | + " tpu-commons==0.1.2\n", |
| 78 | + "\n", |
| 79 | + "%pip install numba==0.61.2" |
| 80 | + ] |
| 81 | + }, |
| 82 | + { |
| 83 | + "cell_type": "code", |
| 84 | + "execution_count": null, |
| 85 | + "metadata": {}, |
| 86 | + "outputs": [], |
| 87 | + "source": [ |
56 | 88 | "\n", |
57 | | - "# Install additional requirements\n", |
58 | | - "%pip install --force-reinstall numpy==2.1.2\n", |
59 | 89 | "%pip install nest_asyncio\n", |
60 | 90 | "\n", |
61 | 91 | "import nest_asyncio\n", |
62 | | - "nest_asyncio.apply() # Fix for Colab event loop" |
| 92 | + "nest_asyncio.apply() # Fix for Colab event loop\n", |
| 93 | + "\n", |
| 94 | + "%cd maxtext/src/\n", |
| 95 | + "\n", |
| 96 | + "# Fix nnx problems by force-reinstalling flax and qwix\n", |
| 97 | + "!pip uninstall -y flax\n", |
| 98 | + "!pip uninstall -y qwix\n", |
| 99 | + "!pip install flax\n", |
| 100 | + "!pip install qwix" |
63 | 101 | ] |
64 | 102 | }, |
65 | 103 | { |
|
97 | 135 | "source": [ |
98 | 136 | "# Configuration for GRPO training\n", |
99 | 137 | "import os\n", |
| 138 | + "import MaxText\n", |
100 | 139 | "\n", |
101 | 140 | "# Set up paths (adjust if needed)\n", |
102 | | - "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n", |
103 | | - "\n", |
| 141 | + "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", |
| 142 | + "RUN_NAME=\"grpo_test\"\n", |
104 | 143 | "# Hardcoded defaults for Llama3.1-8B\n", |
105 | 144 | "MODEL_NAME = \"llama3.1-8b\"\n", |
106 | 145 | "HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n", |
107 | | - "CHAT_TEMPLATE_PATH = \"src/MaxText/examples/chat_templates/gsm8k_rl.json\"\n", |
| 146 | + "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n", |
| 147 | + "LOSS_ALGO=\"gspo-token\"\n", |
108 | 148 | "\n", |
109 | 149 | "# Required: Set these before running\n", |
110 | | - "MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\" # Update this!\n", |
111 | | - "OUTPUT_DIRECTORY = \"/tmp/grpo_output\" # Update this!\n", |
112 | | - "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\") # Set HF_TOKEN environment variable\n", |
| 150 | + "MODEL_CHECKPOINT_PATH = \"\" # Update this!\n", |
| 151 | + "OUTPUT_DIRECTORY = \"/tmp/grpo_output\" # Update this!\n", |
| 152 | + "HF_TOKEN = \"\" # Paste your Hugging Face access token here before running\n", |
113 | 153 | "\n", |
114 | 154 | "# Optional: Override training parameters\n", |
115 | 155 | "STEPS = 10 # Reduced for demo purposes\n", |
|
118 | 158 | "NUM_GENERATIONS = 2\n", |
119 | 159 | "GRPO_BETA = 0.08\n", |
120 | 160 | "GRPO_EPSILON = 0.2\n", |
121 | | - "CHIPS_PER_VM = 4\n", |
| 161 | + "CHIPS_PER_VM = 1\n", |
122 | 162 | "\n", |
123 | 163 | "print(f\"📁 MaxText Home: {MAXTEXT_REPO_ROOT}\")\n", |
124 | 164 | "print(f\"🤖 Model: {MODEL_NAME}\")\n", |
125 | 165 | "print(f\"📦 Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n", |
126 | 166 | "print(f\"💾 Output: {OUTPUT_DIRECTORY}\")\n", |
127 | 167 | "print(f\"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n", |
128 | | - "print(f\"📊 Steps: {STEPS}\")" |
| 168 | + "print(f\"📊 Steps: {STEPS}\")\n", |
| 169 | + "print(f\"Loss Algorithm : {LOSS_ALGO}\")" |
129 | 170 | ] |
130 | 171 | }, |
131 | 172 | { |
|
140 | 181 | "from pathlib import Path\n", |
141 | 182 | "\n", |
142 | 183 | "# Add MaxText to Python path\n", |
143 | | - "maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n", |
| 184 | + "maxtext_path = Path(MAXTEXT_REPO_ROOT) \n", |
144 | 185 | "sys.path.insert(0, str(maxtext_path))\n", |
145 | 186 | "\n", |
146 | 187 | "from MaxText import pyconfig, max_utils\n", |
|
163 | 204 | "print(f\"📁 MaxText path: {maxtext_path}\")" |
164 | 205 | ] |
165 | 206 | }, |
| 207 | + { |
| 208 | + "cell_type": "code", |
| 209 | + "execution_count": null, |
| 210 | + "metadata": {}, |
| 211 | + "outputs": [], |
| 212 | + "source": [ |
| 213 | + "# Build configuration for GRPO training\n", |
| 214 | + "config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n", |
| 215 | + "\n", |
| 216 | + "# Verify chat template exists\n", |
| 217 | + "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", |
| 218 | + " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", |
| 219 | + "\n", |
| 220 | + "# Build argv list for pyconfig.initialize()\n", |
| 221 | + "config_argv = [\n", |
| 222 | + " \"\", # argv[0] placeholder\n", |
| 223 | + " config_file,\n", |
| 224 | + " f\"model_name={MODEL_NAME}\",\n", |
| 225 | + " f\"tokenizer_path={HF_REPO_ID}\",\n", |
| 226 | + " f\"run_name={RUN_NAME}\",\n", |
| 227 | + " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", |
| 228 | + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", |
| 229 | + " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", |
| 230 | + " f\"hf_access_token={HF_TOKEN}\",\n", |
| 231 | + " f\"steps={STEPS}\",\n", |
| 232 | + " f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n", |
| 233 | + " f\"learning_rate={LEARNING_RATE}\",\n", |
| 234 | + " f\"num_generations={NUM_GENERATIONS}\",\n", |
| 235 | + " f\"grpo_beta={GRPO_BETA}\",\n", |
| 236 | + " f\"grpo_epsilon={GRPO_EPSILON}\",\n", |
| 237 | + " f\"chips_per_vm={CHIPS_PER_VM}\",\n", |
| 238 | + " f\"loss_algo={LOSS_ALGO}\"\n", |
| 239 | + "]\n", |
| 240 | + "\n", |
| 241 | + "# Initialize configuration\n", |
| 242 | + "print(f\"🔧 Initializing configuration from: {config_file}\")\n", |
| 243 | + "config = pyconfig.initialize(config_argv)\n", |
| 244 | + "max_utils.print_system_information()\n", |
| 245 | + "\n", |
| 246 | + "print(\"\\n✅ Configuration initialized successfully\")\n", |
| 247 | + "print(f\"📊 Training steps: {config.steps}\")\n", |
| 248 | + "print(f\"📁 Output directory: {config.base_output_directory}\")\n", |
| 249 | + "print(f\"🤖 Model: {config.model_name}\")" |
| 250 | + ] |
| 251 | + }, |
166 | 252 | { |
167 | 253 | "cell_type": "code", |
168 | 254 | "execution_count": null, |
|
214 | 300 | "metadata": {}, |
215 | 301 | "outputs": [], |
216 | 302 | "source": [ |
217 | | - "# Execute GRPO training\n", |
| 303 | + "# Execute GRPO/GSPO training\n", |
218 | 304 | "print(\"\\n\" + \"=\"*80)\n", |
219 | | - "print(\"🚀 Starting GRPO Training...\")\n", |
| 305 | + "print(\"🚀 Starting Training...\")\n", |
220 | 306 | "print(\"=\"*80)\n", |
221 | | - "\n", |
| 307 | + "\n", |
222 | 308 | "try:\n", |
223 | 309 | " # Call the rl_train function (it handles everything internally)\n", |
224 | 310 | " rl_train(config)\n", |
225 | 311 | " \n", |
226 | 312 | " print(\"\\n\" + \"=\"*80)\n", |
227 | | - " print(\"✅ GRPO Training Completed Successfully!\")\n", |
| 313 | + " print(\"✅ Training Completed Successfully!\")\n", |
228 | 314 | " print(\"=\"*80)\n", |
229 | 315 | " print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n", |
230 | 316 | " print(f\"📊 TensorBoard logs: {config.tensorboard_dir}\")\n", |
231 | 317 | " print(f\"🎯 Model ready for inference!\")\n", |
232 | 318 | " \n", |
233 | 319 | "except Exception as e:\n", |
234 | 320 | " print(\"\\n\" + \"=\"*80)\n", |
235 | | - " print(\"❌ GRPO Training Failed!\")\n", |
| 321 | + "    print(\"❌ Training Failed!\")\n", |
236 | 322 | " print(\"=\"*80)\n", |
237 | 323 | " print(f\"Error: {str(e)}\")\n", |
238 | 324 | " import traceback\n", |
|
0 commit comments