Conversation

@mnoukhov (Contributor) commented Nov 4, 2025

All integration tests passing.

True dynamic sampling, aka active sampling: we can filter out zero-std prompts and continuously sample until we have a full batch_size of non-zero-std prompts.

major changes:

  • moves the reward function into accumulate_inference_batches, simplifying the workflow while making it more fully async
  • active sampling now always waits to collect the exact batch size; there is no max_retries

minor changes:

  • we now call the reward function for every sample, so I've made its Timer logging a no-op and downgraded the reward logging to debug so that it doesn't clutter the logs.
  • extended the Batch dataclass to include all the fields necessary for a training batch. Alternatively, we could add a new dataclass with exactly the fields needed for a training batch and rename the current Batch to GenerationBatch.
  • episode now refers to a training episode (number of samples we train on) instead of a generation episode (number of samples we've generated), since everything runs completely async. We could move it back to a good approximation of a generation episode, but different runs would no longer have synced episode numbers for steps/updates.
  • reward_fn no longer takes a Batch, which didn't make much sense since we only used ground_truths and datasets. Instead, we match ppo.py in explicitly passing ground_truths and datasets (see the sketch after this list).
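A hedged sketch of the new call shape: reward_fn receives ground_truths and datasets explicitly instead of a Batch. Every name below (the toy verifier, the argument values) is illustrative, not the actual open-instruct implementation.

import asyncio

async def reward_fn(responses, decoded_responses, ground_truths, datasets,
                    finish_reasons, request_info, queries=None):
    # Toy verifier: score 1.0 when the decoded response contains the ground truth.
    scores = [float(gt in resp) for resp, gt in zip(decoded_responses, ground_truths)]
    reward_metrics = {"verifiable_reward_mean": sum(scores) / max(len(scores), 1)}
    return scores, reward_metrics

scores, reward_metrics = asyncio.run(
    reward_fn(
        responses=[[101, 102, 103]],          # token ids
        decoded_responses=["the answer is 42"],
        ground_truths=["42"],
        datasets=["gsm8k"],
        finish_reasons=["stop"],
        request_info=None,
    )
)
print(scores, reward_metrics)  # [1.0] {'verifiable_reward_mean': 1.0}

Passing the two lists directly keeps grpo_fast aligned with ppo.py's calling convention and avoids building a mostly-empty Batch just to carry two fields.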

Note

Adds active sampling with per-sample reward computed during accumulation, extends Batch and metrics, updates training/eval flow, tests, and scripts.

  • Core RL changes:
    • Introduce Args.active_sampling (replaces fill_completions) to keep sampling until batches have non-zero-std rewards; assert async_steps > 1.
    • Move reward computation into accumulate_inference_batches; now returns (GenerationResult, Batch, reward_metrics, BatchStatistics).
    • Compute and filter zero-std prompts during accumulation; track filtered counts in BatchStatistics (see the sketch after this list).
    • Update truncated completion masking to filter responses and align all related tensors in-place.
  • Data structures & APIs:
    • Extend model_utils.Batch with decoded_responses and scores.
    • Add BatchStatistics dataclass for prompt/response lengths and filtering stats.
    • Change apply_verifiable_reward and reward_fn signatures to accept ground_truths and datasets instead of Batch.
    • Add utils.combine_reward_metrics to aggregate per-prompt reward metrics.
    • Compute Args.max_possible_score for solved/unsolved stats.
  • Metrics & logging:
    • Log aggregated reward metrics and batch filtering stats; include time/reward and updated "real/unsolved batch size" ratios.
    • Evaluation uses eval_batch.scores and eval_batch.decoded_responses directly.
  • Training loop:
    • Main thread refills prompts based on num_filtered_prompts; adjusts episode counting and packing outputs accordingly.
    • load_data_from_packing_thread now returns num_filtered_prompts.
  • Tests:
    • Update tests to provide tokenizer/reward_fn and to validate new accumulation outputs and batch sizing.
  • Scripts:
    • Enable --active_sampling and set --async_steps in debug/integration scripts; minor parameter tweaks.
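Below is a minimal, hedged sketch of the active-sampling accumulation described in the summary above. get_next_result, score_prompt, and batch_size are illustrative stand-ins, not the actual accumulate_inference_batches internals.

import numpy as np

def accumulate_non_zero_std(get_next_result, score_prompt, batch_size):
    """Keep pulling prompt results until batch_size informative prompts are kept."""
    kept, num_filtered_prompts = [], 0
    while len(kept) < batch_size:
        prompt_result = get_next_result()     # one prompt with k sampled responses
        scores = score_prompt(prompt_result)  # k per-response rewards
        if np.std(scores) == 0:               # identical rewards -> zero advantage, no signal
            num_filtered_prompts += 1
            continue
        kept.append((prompt_result, scores))
    return kept, num_filtered_prompts

# Toy usage: even-numbered prompts get uniform scores and are filtered out.
stream = iter(range(100))
kept, filtered = accumulate_non_zero_std(
    get_next_result=lambda: next(stream),
    score_prompt=lambda r: [0.0, 0.0] if r % 2 == 0 else [0.0, 1.0],
    batch_size=4,
)
print(len(kept), filtered)  # 4 kept prompts, 4 filtered

The loop only returns once batch_size informative prompts have been collected, which is why there is no max_retries; the discarded count is what feeds num_filtered_prompts and the batch filtering stats.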

Written by Cursor Bugbot for commit 9c5ba65. This will update automatically on new commits.

@mnoukhov marked this pull request as ready for review November 7, 2025 17:52
raw_queries: list[str] | None
decoded_responses: list[str] | None
indices: list[int] | None
scores: list[float] | None

Bug: Batch Slicing: Incomplete Data Causes Errors

The __getitem__ method in the Batch dataclass doesn't include the newly added decoded_responses and scores fields when creating sliced/indexed batches. This causes a TypeError because the Batch constructor requires all fields, but slicing operations only pass the original fields and omit the new ones.
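A field-agnostic __getitem__ along the following lines would avoid dropping newly added fields. This is a sketch built against the field list quoted above, not the repository's actual fix.

from dataclasses import dataclass, fields

@dataclass
class Batch:
    queries: list | None
    ground_truths: list | None
    datasets: list | None
    raw_queries: list[str] | None
    decoded_responses: list[str] | None
    indices: list[int] | None
    scores: list[float] | None

    def __getitem__(self, key):
        # Slice every field that is present; leave None fields as None, so any
        # field added to the dataclass later is forwarded automatically.
        return Batch(**{
            f.name: (getattr(self, f.name)[key] if getattr(self, f.name) is not None else None)
            for f in fields(self)
        })

batch = Batch(queries=[[1], [2]], ground_truths=["a", "b"], datasets=["d1", "d2"],
              raw_queries=["q1", "q2"], decoded_responses=["r1", "r2"],
              indices=[0, 1], scores=[0.0, 1.0])
print(batch[0:1].scores)  # [0.0], including the newly added field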


Commits pushed:
  • move weight sync directly after update
  • episode now refers to "training episode", not "generation episode" as previously
  • becomes the same as reward when num_responses_per_prompt is 1
  • just because cursor keeps complaining

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant refactoring to simplify active sampling. The reward calculation is now integrated into accumulate_inference_batches, which streamlines the data preparation thread. Active sampling is also improved to continuously sample until a full batch is collected. My review focuses on improving code clarity and maintainability. I've suggested renaming a confusing variable and refactoring a function call to make the data flow more consistent. Overall, the changes are a good improvement to the codebase.

Comment on lines 1711 to 1729
scores, reward_metrics = asyncio.run(
    reward_fn(
        result.responses,
        decoded_responses,
        # note that you only need ground_truths and datasets for the reward model
        Batch(
            queries=None,
            ground_truths=k_ground_truths,
            datasets=k_datasets,
            raw_queries=None,
            decoded_responses=None,
            indices=None,
            scores=None,
        ),
        result.finish_reasons,
        result.request_info,
        k_raw_queries,
    )
)

Severity: medium

The way reward_fn is called with a partially constructed Batch object and k_raw_queries as a separate argument is a bit inconsistent. The Batch dataclass has a raw_queries field, which is set to None here, while the actual queries are passed separately.

For better code clarity and data flow consistency, I suggest including k_raw_queries in the Batch object. This would look like:

scores, reward_metrics = asyncio.run(
    reward_fn(
        result.responses,
        decoded_responses,
        Batch(
            queries=None,
            ground_truths=k_ground_truths,
            datasets=k_datasets,
            raw_queries=k_raw_queries,
            decoded_responses=None,
            indices=None,
            scores=None,
        ),
        result.finish_reasons,
        result.request_info,
    )
)

This would require a small change in make_reward_fn to use batch.raw_queries instead of the separate queries argument, and you could then remove the queries argument from reward_fn's signature.

Commits pushed:
  • makes grpo and ppo reward functions the same
  • we now return k repeats of a prompt, not just 1 in the batch
)

# Filter out zero std prompts
if filter_zero_std_samples and np.array(scores).std() == 0:
Collaborator:

We can just do np.std(scores) without the array!
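For example, both forms compute the same value on a plain list of floats, so the array conversion is unnecessary:

import numpy as np

scores = [0.0, 1.0, 1.0]
assert np.std(scores) == np.array(scores).std()  # both ≈ 0.4714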

Contributor Author:

done!

eval_result.request_info,
)
)
# eval_decoded_responses = tokenizer.batch_decode(eval_result.responses, skip_special_tokens=True)
Collaborator:

Why comment these out?

Contributor Author:

This should be deleted, not commented out; it was removed in a later commit. We're doing all of this in accumulate_inference_batches now.

decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)

k_queries = [query for _ in range(generation_config.n)]
k_ground_truths = [ground_truth for _ in range(generation_config.n)]
Collaborator:

Can you use the repeat_each function here? like we do in packing!
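For reference, a hedged sketch of what the suggestion amounts to. It assumes repeat_each(seq, n) repeats each element of seq n times back to back (the same effect as the comprehensions above); the helper is defined locally here for illustration, while the real one lives in the repo's utils.

def repeat_each(seq, n):
    # Repeat each element of seq n times, preserving order.
    return [item for item in seq for _ in range(n)]

n = 2  # stands in for generation_config.n
query, ground_truth, dataset, raw_query = [101, 102], "42", "gsm8k", "what is 6*7?"

k_queries = repeat_each([query], n)
k_ground_truths = repeat_each([ground_truth], n)
k_datasets = repeat_each([dataset], n)
k_raw_queries = repeat_each([raw_query], n)
print(k_ground_truths)  # ['42', '42']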

Contributor Author:

done!

k_datasets = [dataset for _ in range(generation_config.n)]
k_raw_queries = [raw_query for _ in range(generation_config.n)]

# with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
Collaborator:

Delete it!

Contributor Author:

done!

mnoukhov and others added 5 commits November 7, 2025 16:16
Co-authored-by: Finbarr Timbers <finbarrtimbers@gmail.com>
@finbarrtimbers self-requested a review November 7, 2025 21:32
@mnoukhov added this pull request to the merge queue Nov 9, 2025
Merged via the queue into main with commit da87d77 Nov 9, 2025
4 checks passed