2 changes: 1 addition & 1 deletion CLAUDE.md
@@ -6,5 +6,5 @@
- Always run the linter and make sure the tests pass before finishing a task.
- Prefer running single tests, not the whole suite, when developing.
- To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes.
- Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`.
- Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo.sh`.
- Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`.
54 changes: 27 additions & 27 deletions docs/algorithms/grpo.md
@@ -6,30 +6,30 @@ GRPO is an online RL method used in [DeepSeek R1 paper](https://arxiv.org/abs/25

## Implemented Variants

- `grpo_fast.py` is a faster variant using [packing techniques](https://huggingface.co/blog/sirluk/llm-sequence-packing).
- `grpo.py` is a faster variant using [packing techniques](https://huggingface.co/blog/sirluk/llm-sequence-packing).
- `grpo_vllm_thread_ray_gtrl.py` is a more vanilla GRPO implementation, using vLLM and Ray.



## `grpo_fast.py`
## `grpo.py`

This implementation has the following features:

- Uses packing techniques to speed up the training process, inspired by [Open-Reasoner-Zero/Open-Reasoner-Zero](https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero)
- Uses a thread-based approach to parallelize the training and inference processes, based on [Asynchronous RLHF](https://arxiv.org/abs/2410.18252) (a rough sketch follows this list).
- Uses a data preparation thread to prepare the data for the training process.
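
A rough sketch of the asynchronous idea (illustrative only; the actual implementation uses Ray queues and vLLM engines rather than plain Python threads and a toy stand-in for generation):

```python
import queue
import threading

prompt_q: queue.Queue = queue.Queue()
results_q: queue.Queue = queue.Ueue() if False else queue.Queue()  # plain thread-safe queues for this sketch

def generation_loop(num_batches: int) -> None:
    """Stand-in for the inference side: pull prompts, 'generate', push results."""
    for _ in range(num_batches):
        prompts = prompt_q.get()
        completions = [p + " ... completion" for p in prompts]  # placeholder for a vLLM generate() call
        results_q.put(completions)

threading.Thread(target=generation_loop, args=(3,), daemon=True).start()

prompt_q.put(["prompt 0-a", "prompt 0-b"])  # seed the first batch
for step in range(3):
    prompt_q.put([f"prompt {step + 1}-a", f"prompt {step + 1}-b"])  # queue the next batch early
    batch = results_q.get()  # meanwhile, train on whichever batch just finished
    # a data-preparation step and train_step(batch) would go here
```

Generation for the next batch overlaps with training on the current one, which is where the speedup over a strictly synchronous generate-then-train loop comes from.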

In simpler tasks, we see 2x faster training, and even 10x faster for more complex tasks. With `grpo_fast.py`, we can run crank up `number_samples_per_prompt` and train on really large batch sizes.
On simpler tasks we see 2x faster training, and up to 10x faster on more complex tasks. With `grpo.py`, we can crank up `number_samples_per_prompt` and train on really large batch sizes.
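
Part of the speedup comes from packing. A minimal, hypothetical sketch of the idea (the actual packing code also has to track attention boundaries and masks) is to concatenate variable-length sequences into fixed-size buffers so fewer tokens are wasted on padding:

```python
def pack_sequences(sequences: list[list[int]], max_len: int, pad_id: int = 0) -> list[list[int]]:
    """Greedily concatenate variable-length token sequences into fixed-size buffers."""
    buffers: list[list[int]] = []
    current: list[int] = []
    for seq in sequences:
        seq = seq[:max_len]  # toy example only: truncate anything longer than a buffer
        if len(current) + len(seq) > max_len:
            buffers.append(current + [pad_id] * (max_len - len(current)))
            current = []
        current.extend(seq)
    if current:
        buffers.append(current + [pad_id] * (max_len - len(current)))
    return buffers

# Three short sequences fit into one 16-token buffer instead of three padded rows.
packed = pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_len=16)
```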

It implements additional optimizations:

* `grpo_fast.py` also implements an optimization to skip zero gradient batches. If we solve a prompt 100% correct or 0% correct, the std of the group is 0. So `adv = (score - score.mean()) / (score.std + 1e-5) = 0 / 1e-5 = 0`, causing 0 gradients. `grpo_fast.py` will skip these batches before packing the sequences.
* `grpo.py` also implements an optimization to skip zero-gradient batches. If we solve a prompt 100% or 0% correctly, the std of the group is 0, so `adv = (score - score.mean()) / (score.std() + 1e-5) = 0 / 1e-5 = 0`, causing zero gradients. `grpo.py` skips these batches before packing the sequences.

![](grpo/grpo_fast_gradient.png)
![](grpo/grpo_gradient.png)

Figure taken from [this discord thread by @the_real_jrb](https://discord.com/channels/1179127597926469703/1208183216843005962/1357712190957682839)
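
For intuition, here is a hypothetical sketch of that skipping logic (the real implementation differs in details):

```python
import torch

def nonzero_gradient_mask(scores: torch.Tensor) -> torch.Tensor:
    """scores: (num_prompts, num_samples_per_prompt) rewards for each group.

    Groups where every sample got the same score (all correct or all wrong)
    produce all-zero advantages, so they contribute no gradient and can be dropped.
    """
    advantages = (scores - scores.mean(dim=-1, keepdim=True)) / (scores.std(dim=-1, keepdim=True) + 1e-5)
    return advantages.abs().sum(dim=-1) > 0

scores = torch.tensor([[1.0, 1.0, 1.0, 1.0],   # solved every time -> zero advantages
                       [0.0, 0.0, 1.0, 1.0]])  # mixed outcomes -> useful signal
keep = nonzero_gradient_mask(scores)  # tensor([False,  True])
```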

* `grpo_fast.py` only applies the verification reward if the format reward is enabled (via `--additive_format_reward False` by default). See ([allenai/open-instruct/pull/659](https://github.com/allenai/open-instruct/pull/659)). A direct additive format reward is undesirable. In GRPO, the scale of the rewards is not relevant due to group normalization. For example, a group of [0, 0, 0, 0, 10], [0, 0, 0, 0, 11], [0, 0, 0, 0, 1] reward will have the same advantage.
* `grpo.py` only applies the verification reward if the format reward is enabled (via `--additive_format_reward False` by default); see [allenai/open-instruct/pull/659](https://github.com/allenai/open-instruct/pull/659). A direct additive format reward is undesirable: in GRPO, the scale of the rewards is not relevant due to group normalization. For example, the reward groups [0, 0, 0, 0, 10], [0, 0, 0, 0, 11], and [0, 0, 0, 0, 1] all produce the same advantages (a small numeric check follows below).

Now imagine the model generates a really long response (8k generation length) but only gets the format reward right: GRPO will push up the probabilities for this long response even though it is not actually correct. As a result, when using the format reward directly, we see the response length of unsolved prompts fluctuate significantly, causing stability issues.
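
To see the scale invariance concretely, here is a quick illustrative check that the three reward groups above normalize to identical advantages:

```python
import numpy as np

def group_advantages(rewards: list[float]) -> np.ndarray:
    r = np.asarray(rewards, dtype=np.float64)
    return (r - r.mean()) / (r.std() + 1e-5)

# Only the scale of the final reward differs (10 vs. 11 vs. 1), yet every group
# normalizes to roughly [-0.5, -0.5, -0.5, -0.5, 2.0], i.e. the same advantages.
for group in ([0, 0, 0, 0, 10], [0, 0, 0, 0, 11], [0, 0, 0, 0, 1]):
    print(group_advantages(group).round(4))
```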

@@ -41,26 +41,26 @@ You can run the script in a single GPU mode to debug the training process.

```bash
# single GPU
bash scripts/train/debug/grpo_fast.sh
bash scripts/train/debug/grpo.sh
# 3 GPU: 2 for training, 1 for inference (a more realistic setting for async training)
bash scripts/train/debug/grpo_fast_3_gpu.sh
bash scripts/train/debug/grpo_3_gpu.sh
```

### Reproduce `allenai/Llama-3.1-Tulu-3.1-8B` (1 Node)

You can reproduce our `allenai/Llama-3.1-Tulu-3.1-8B` model by running the following command:

```bash
bash scripts/train/tulu3/grpo_fast_8b_single_node.sh
bash scripts/train/tulu3/grpo_8b_single_node.sh
```

???+ info

Here the `grpo_fast.py` actually use 6 GPUs for training and 2 GPUs for inference, so it's using less hardware but runs faster than `grpo_vllm_thread_ray_gtrl.py` which uses 2 nodes (12 GPUs for training and 4 GPUs for inference).
Here `grpo.py` actually uses 6 GPUs for training and 2 GPUs for inference, so it uses less hardware but runs faster than `grpo_vllm_thread_ray_gtrl.py`, which uses 2 nodes (12 GPUs for training and 4 GPUs for inference).


![grpo_tulu3_8b](grpo/tulu3.1_8b_grpo_fast.png)
![grpo_tulu3_8b_time](grpo/tulu3.1_8b_grpo_fast-time.png)
![grpo_tulu3_8b](grpo/tulu3.1_8b_grpo.png)
![grpo_tulu3_8b_time](grpo/tulu3.1_8b_grpo-time.png)

??? note "👉 Tracked WandB Experiments (Click to expand)"

@@ -70,13 +70,13 @@ bash scripts/train/tulu3/grpo_fast_8b_single_node.sh

Below are some learning curves for the evaluation metrics during training. Basically, ifeval, gsm8k, and math:flex all go up.

![grpo_plot](grpo/tulu3.1_8b_grpo_fast_eval_curve.png)
![grpo_plot](grpo/tulu3.1_8b_grpo_eval_curve.png)

???+ info

Based on our internal evaluation, the GRPO model is roughly on par with the original `allenai/Llama-3.1-Tulu-3.1-8B` model, though there are some slight differences. Note that your results may vary slightly due to the random seeds used in the training.

![grpo_plot](grpo/tulu3.1_8b_grpo_fast_eval.png)
![grpo_plot](grpo/tulu3.1_8b_grpo_eval.png)


???+ info
@@ -89,12 +89,12 @@ bash scripts/train/tulu3/grpo_fast_8b_single_node.sh
We have

```bash
bash scripts/train/qwen/grpo_fast_7b.sh
bash scripts/train/qwen/grpo_7b.sh
```


![grpo_qwen2.5_7B_works](grpo/qwen2.5_7b_grpo_fast_zero.png)
![grpo_qwen2.5_7B_works_time](grpo/qwen2.5_7b_grpo_fast_zero-time.png)
![grpo_qwen2.5_7B_works](grpo/qwen2.5_7b_grpo_zero.png)
![grpo_qwen2.5_7B_works_time](grpo/qwen2.5_7b_grpo_zero-time.png)


??? note "👉 Tracked WandB Experiments (Click to expand)"
@@ -106,7 +106,7 @@ bash scripts/train/qwen/grpo_fast_7b.sh

Below are some learning curves for the evaluation metrics during training. Basically, ifeval, gsm8k, and math:flex all go up.

![grpo_plot](grpo/qwen2.5_7b_grpo_fast_zero_eval_curve.png)
![grpo_plot](grpo/qwen2.5_7b_grpo_zero_eval_curve.png)

???+ info

@@ -120,12 +120,12 @@ bash scripts/train/qwen/grpo_fast_7b.sh
We have

```bash
bash scripts/train/olmo2/grpo_fast_7b_zero.sh
bash scripts/train/olmo2/grpo_7b_zero.sh
```


![grpo_olmo2_7b_zero](grpo/olmo2_7b_grpo_fast_zero.png)
![grpo_olmo2_7b_zero_time](grpo/olmo2_7b_grpo_fast_zero-time.png)
![grpo_olmo2_7b_zero](grpo/olmo2_7b_grpo_zero.png)
![grpo_olmo2_7b_zero_time](grpo/olmo2_7b_grpo_zero-time.png)

??? note "👉 Tracked WandB Experiments (Click to expand)"

@@ -135,7 +135,7 @@ bash scripts/train/olmo2/grpo_fast_7b_zero.sh

Below are some learning curves for the evaluation metrics during training. Basically, ifeval, gsm8k, and math:flex all go up.

![grpo_plot](grpo/olmo2_7b_grpo_fast_zero_eval_curve.png)
![grpo_plot](grpo/olmo2_7b_grpo_zero_eval_curve.png)


???+ info
@@ -148,12 +148,12 @@ bash scripts/train/olmo2/grpo_fast_7b_zero.sh
We have

```bash
bash scripts/train/olmo2/grpo_fast_13b_zero.sh
bash scripts/train/olmo2/grpo_13b_zero.sh
```


![grpo_olmo2_13b_zero](grpo/olmo2_13b_grpo_fast_zero.png)
![grpo_olmo2_13b_zero_time](grpo/olmo2_13b_grpo_fast_zero-time.png)
![grpo_olmo2_13b_zero](grpo/olmo2_13b_grpo_zero.png)
![grpo_olmo2_13b_zero_time](grpo/olmo2_13b_grpo_zero-time.png)

??? note "👉 Tracked WandB Experiments (Click to expand)"

@@ -163,7 +163,7 @@ bash scripts/train/olmo2/grpo_fast_13b_zero.sh

Below are some learning curves for the evaluation metrics during training. Basically, ifeval, gsm8k, and math:flex all go up.

![grpo_plot](grpo/olmo2_13b_grpo_fast_zero_eval_curve.png)
![grpo_plot](grpo/olmo2_13b_grpo_zero_eval_curve.png)


???+ info
@@ -175,7 +175,7 @@ bash scripts/train/olmo2/grpo_fast_13b_zero.sh

### Training Metrics

See the Training Metrics for `grpo_vllm_thread_ray_gtrl.py` below for general metrics. `grpo_fast.py` includes the following additional metrics:
See the Training Metrics for `grpo_vllm_thread_ray_gtrl.py` below for general metrics. `grpo.py` includes the following additional metrics:


* `other/real_batch_size_ratio`: In GRPO, the effective batch size shrinks as we train. This is because if we solve a prompt 100% or 0% correctly, the std of the group is 0, so `adv = (score - score.mean()) / (score.std() + 1e-5) = 0 / 1e-5 = 0`, causing zero gradients. This metric is the ratio of the samples that have gradients to the total number of samples,
Binary file removed docs/algorithms/grpo/qwen2.5_7b_grpo_fast_zero.png
Binary file removed docs/algorithms/grpo/qwen2.5_7b_grpo_zero-time.png
Binary file removed docs/algorithms/grpo/qwen2.5_7b_grpo_zero.png
Binary file modified docs/algorithms/grpo/tulu3.1_8b_grpo-time.png
Binary file modified docs/algorithms/grpo/tulu3.1_8b_grpo.png
Binary file removed docs/algorithms/grpo/tulu3.1_8b_grpo_fast-time.png
Binary file removed docs/algorithms/grpo/tulu3.1_8b_grpo_fast.png
Binary file removed docs/algorithms/grpo/tulu3.1_8b_grpo_fast_eval.png
6 changes: 3 additions & 3 deletions docs/get_started/ai2_internal_setup.md
@@ -105,7 +105,7 @@ When submitting to the `ai2/augusta` cluster, mason will try to read your model
The [/scripts/train](/scripts/train) directory contains many examples of how to launch jobs with mason.py. Sometimes the commands can get long and hard to manage, so we wrote a script called [update_command_args.py](/update_command_args.py) that can be used to add or update arguments in a shell script. For example,

```bash
python update_command_args.py scripts/train/tulu3/grpo_fast_8b.sh \
python update_command_args.py scripts/train/tulu3/grpo_8b.sh \
--cluster ai2/augusta \
--priority normal \
--image costah/open_instruct_dev0320_11 --non_stop_penalty False | uv run bash
@@ -118,8 +118,8 @@ As another example, you can run something like this for a learning rate search:

```bash
for lr in 1e-6 1e-5 1e-4; do
python update_command_args.py scripts/train/tulu3/grpo_fast_8b.sh \
--exp_name grpo_fast_8b_lr_${lr} \
python update_command_args.py scripts/train/tulu3/grpo_8b.sh \
--exp_name grpo_8b_lr_${lr} \
--learning_rate $lr \
--image costah/open_instruct_dev0320_11 --non_stop_penalty False | uv run bash
done
4 changes: 2 additions & 2 deletions mason.py
@@ -22,14 +22,14 @@
OPEN_INSTRUCT_COMMANDS = [
"open_instruct/finetune.py",
"open_instruct/dpo_tune_cache.py",
"open_instruct/grpo_fast.py",
"open_instruct/grpo.py",
"open_instruct/ppo.py",
"open_instruct/grpo_vllm_thread_ray_gtrl.py",
"open_instruct/ppo_vllm_thread_ray_gtrl.py",
"open_instruct/reward_modeling.py",
]

OPEN_INSTRUCT_RESUMABLES = ["open_instruct/grpo_fast.py"]
OPEN_INSTRUCT_RESUMABLES = ["open_instruct/grpo.py"]


# ----------------------------------------------------------------------
24 changes: 11 additions & 13 deletions open_instruct/benchmark_generators.py
@@ -2,8 +2,8 @@
"""
Benchmark script for testing vLLM generator performance.

This script loads datasets in the same way as grpo_fast.py, sets up a generator
like in test_grpo_fast.py, and streams results to/from the generator to measure
This script loads datasets in the same way as grpo.py, sets up a generator
like in test_grpo.py, and streams results to/from the generator to measure
performance.
"""

@@ -27,7 +27,7 @@
import vllm
from ray.util import queue as ray_queue

from open_instruct import dataset_transformation, grpo_fast, logger_utils, model_utils, utils, vllm_utils
from open_instruct import dataset_transformation, grpo, logger_utils, model_utils, utils, vllm_utils
from open_instruct.actor_manager import ActorManager
from open_instruct.queue_types import PromptRequest

@@ -116,7 +116,7 @@ def get_git_commit() -> str:


def save_benchmark_results_to_csv(
results: list[dict[str, Any]], total_time: float, args: grpo_fast.Args, model_config: model_utils.ModelConfig
results: list[dict[str, Any]], total_time: float, args: grpo.Args, model_config: model_utils.ModelConfig
) -> None:
"""Save benchmark results to CSV file."""
git_commit = get_git_commit()
@@ -199,8 +199,8 @@ def free_all_gpu_memory(device: int | str = 0) -> None:
logger.info(f"[GPU {dev.index}] {free / gib:.2f} GiB free of {total / gib:.2f} GiB after cleanup")


def setup_dataset(args: grpo_fast.Args, tokenizer_config: dataset_transformation.TokenizerConfig) -> datasets.Dataset:
"""Set up the dataset using the same pipeline as grpo_fast.py."""
def setup_dataset(args: grpo.Args, tokenizer_config: dataset_transformation.TokenizerConfig) -> datasets.Dataset:
"""Set up the dataset using the same pipeline as grpo.py."""
logger.info("Loading and processing dataset...")

# Transform function arguments
@@ -229,7 +229,7 @@ def setup_dataset(args: grpo_fast.Args, tokenizer_config: dataset_transformation


def setup_vllm_engines(
args: grpo_fast.Args,
args: grpo.Args,
tokenizer_config: dataset_transformation.TokenizerConfig,
model_config: model_utils.ModelConfig,
max_model_len: int,
@@ -274,7 +274,7 @@ def setup_vllm_engines(


def simulate_weight_sync(
actor_manager: ray.actor.ActorHandle, vllm_engines: list[ray.actor.ActorHandle], args: grpo_fast.Args
actor_manager: ray.actor.ActorHandle, vllm_engines: list[ray.actor.ActorHandle], args: grpo.Args
) -> float:
"""Simulate weight sync by pausing all actors.

@@ -348,7 +348,7 @@ def run_benchmark(
param_prompt_Q: ray_queue.Queue,
inference_results_Q: ray_queue.Queue,
actor_manager: ray.actor.ActorHandle,
args: grpo_fast.Args,
args: grpo.Args,
model_config: model_utils.ModelConfig,
timestamp: int,
num_batches: int = 5,
@@ -565,7 +565,7 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
def print_summary(
results: list[dict[str, Any]],
total_time: float,
args: grpo_fast.Args,
args: grpo.Args,
model_config: model_utils.ModelConfig,
model_dims: utils.ModelDims,
) -> None:
@@ -642,9 +642,7 @@ def cleanup(vllm_engines: list[ray.actor.ActorHandle], actor_manager: ray.actor.
def main() -> None:
"""Main benchmark function."""
# Parse arguments using ArgumentParserPlus
parser = utils.ArgumentParserPlus(
(grpo_fast.Args, dataset_transformation.TokenizerConfig, model_utils.ModelConfig)
)
parser = utils.ArgumentParserPlus((grpo.Args, dataset_transformation.TokenizerConfig, model_utils.ModelConfig))

args, tokenizer_config, model_config = parser.parse_args_into_dataclasses()

File renamed without changes.