Active sampling simplified #1143

Conversation
raw_queries: list[str] | None
decoded_responses: list[str] | None
indices: list[int] | None
scores: list[float] | None
Bug: Batch Slicing: Incomplete Data Causes Errors
The __getitem__ method in the Batch dataclass doesn't include the newly added decoded_responses and scores fields when creating sliced/indexed batches. This causes a TypeError because the Batch constructor requires all fields but slicing operations only pass the original fields, omitting the new ones.
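For illustration, a minimal sketch (not the actual open_instruct code) of one way to fix this: slice every dataclass field generically so newly added fields such as decoded_responses and scores cannot be silently dropped. Field names follow the diff above; everything else is assumed.

```python
from dataclasses import dataclass, fields


@dataclass
class Batch:
    # Field names mirror the PR diff; the real open_instruct definition may differ.
    queries: list[list[int]] | None
    ground_truths: list[str] | None
    datasets: list[str] | None
    raw_queries: list[str] | None
    decoded_responses: list[str] | None
    indices: list[int] | None
    scores: list[float] | None

    def __getitem__(self, key):
        # Slice every field in one place so new fields are always carried along.
        def select(value):
            if value is None:
                return None
            if isinstance(key, slice):
                return value[key]
            return [value[i] for i in key]  # key is a list of indices

        return Batch(**{f.name: select(getattr(self, f.name)) for f in fields(self)})
```

Slicing with `batch[0:2]` or `batch[[0, 3]]` then returns a Batch with all seven fields populated (or None), instead of raising a TypeError from missing constructor arguments.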
move weight sync directly after update
episode now refers to "training episode", not "generation episode" as previously
becomes the same as reward when num_responses_per_prompt is 1
just because cursor keeps complaining
Code Review
This pull request introduces a significant refactoring to simplify active sampling. The reward calculation is now integrated into accumulate_inference_batches, which streamlines the data preparation thread. Active sampling is also improved to continuously sample until a full batch is collected. My review focuses on improving code clarity and maintainability. I've suggested renaming a confusing variable and refactoring a function call to make the data flow more consistent. Overall, the changes are a good improvement to the codebase.
scores, reward_metrics = asyncio.run(
    reward_fn(
        result.responses,
        decoded_responses,
        # note that you only need ground_truths and datasets for the reward model
        Batch(
            queries=None,
            ground_truths=k_ground_truths,
            datasets=k_datasets,
            raw_queries=None,
            decoded_responses=None,
            indices=None,
            scores=None,
        ),
        result.finish_reasons,
        result.request_info,
        k_raw_queries,
    )
)
The way reward_fn is called with a partially constructed Batch object and k_raw_queries as a separate argument is a bit inconsistent. The Batch dataclass has a raw_queries field, which is set to None here, while the actual queries are passed separately.
For better code clarity and data flow consistency, I suggest including k_raw_queries in the Batch object. This would look like:
scores, reward_metrics = asyncio.run(
    reward_fn(
        result.responses,
        decoded_responses,
        Batch(
            queries=None,
            ground_truths=k_ground_truths,
            datasets=k_datasets,
            raw_queries=k_raw_queries,
            decoded_responses=None,
            indices=None,
            scores=None,
        ),
        result.finish_reasons,
        result.request_info,
    )
)

This would require a small change in make_reward_fn to use batch.raw_queries instead of the separate queries argument, and you could then remove the queries argument from reward_fn's signature.
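For concreteness, a hedged sketch of the follow-up change this suggestion implies. The real make_reward_fn in open_instruct has a different signature and internals; the verifier call here is a placeholder used only to show where batch.raw_queries would be read.

```python
def make_reward_fn(verifier):
    # Hypothetical shape only: reward_fn reads raw queries from the Batch rather than
    # taking a separate `queries` argument, matching the suggestion above.
    async def reward_fn(responses, decoded_responses, batch, finish_reasons, request_info):
        scores, reward_metrics = await verifier(
            decoded_responses,
            ground_truths=batch.ground_truths,
            datasets=batch.datasets,
            queries=batch.raw_queries,  # previously passed as a separate argument
        )
        return scores, reward_metrics

    return reward_fn
```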
makes grpo and ppo reward functions the same
open_instruct/grpo_fast.py
Outdated

)

# Filter out zero std prompts
if filter_zero_std_samples and np.array(scores).std() == 0:
We can just do np.std(scores) without the array!
done!
open_instruct/grpo_fast.py
Outdated

        eval_result.request_info,
    )
)
# eval_decoded_responses = tokenizer.batch_decode(eval_result.responses, skip_special_tokens=True)
Why comment these out?
this should be deleted, not commented out; it was deleted in a later commit, and we're doing all of this in accumulate_inference_batches now
open_instruct/grpo_fast.py
Outdated

decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)

k_queries = [query for _ in range(generation_config.n)]
k_ground_truths = [ground_truth for _ in range(generation_config.n)]
Can you use the repeat_each function here? like we do in packing!
done!
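For reference, a minimal sketch of what a repeat_each helper like the one used in packing could do; the actual helper in open_instruct may differ in name or signature, and the variable names below are placeholders.

```python
def repeat_each(xs, k):
    # [a, b] with k=2 -> [a, a, b, b]: repeat each element k times, preserving order.
    return [x for x in xs for _ in range(k)]


# Illustrative usage mirroring the diff above (values are made up):
n = 4
ground_truth, dataset, raw_query = "42", "gsm8k", "What is 6 * 7?"
k_ground_truths = repeat_each([ground_truth], n)  # == [ground_truth] * n
k_datasets = repeat_each([dataset], n)
k_raw_queries = repeat_each([raw_query], n)
```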
open_instruct/grpo_fast.py
Outdated

k_datasets = [dataset for _ in range(generation_config.n)]
k_raw_queries = [raw_query for _ in range(generation_config.n)]

# with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
Delete it!
done!
Co-authored-by: Finbarr Timbers <finbarrtimbers@gmail.com>
all integration tests passing.
true dynamic sampling aka active sampling: we can filter out 0 std prompts and continuously sample until we have a full batch_size number of non-zero-std prompts
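A minimal sketch of that loop with placeholder names (sample_prompt_group, batch_size); the real logic lives inside accumulate_inference_batches and differs in detail.

```python
import numpy as np


def collect_active_batch(sample_prompt_group, batch_size):
    """Keep sampling prompt groups until batch_size of them have non-zero reward std."""
    kept, num_filtered_prompts = [], 0
    while len(kept) < batch_size:
        group = sample_prompt_group()  # one prompt's n responses plus their scores
        if np.std(group["scores"]) == 0:  # identical rewards carry no GRPO advantage signal
            num_filtered_prompts += 1
            continue
        kept.append(group)
    return kept, num_filtered_prompts
```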
major changes:
- moves the reward calculation into `accumulate_inference_batches` and simplifies the workflow while being better async
- `max_retries`

minor changes:
- `noop` and the reward logging a `debug` so that it doesn't clutter the logs.
- extends the `Batch` dataclass to include all the fields necessary for a training batch. Can alternatively change to have a new dataclass that's just exactly the fields we need for a training batch and change the current `Batch` to be a `GenerationBatch`
- `episode` now refers to a `training episode` (number of samples we train on) instead of a `generation episode` (number of samples we've generated) since everything is running completely async. We can move it back to a good approx of a `generation episode` but different runs will no longer have synced episode numbers for steps / updates
- `reward_fn` does not take `Batch` anymore, which didn't make much sense as we only used `ground_truths` and `datasets`. Instead we match `ppo.py` in explicitly passing `ground_truths` and `datasets`

Note
Adds active sampling with per-sample reward computed during accumulation, extends Batch and metrics, updates training/eval flow, tests, and scripts.
- `Args.active_sampling` (replaces `fill_completions`) to keep sampling until batches have non-zero-std rewards; assert `async_steps > 1`.
- `accumulate_inference_batches`; now returns `(GenerationResult, Batch, reward_metrics, BatchStatistics)`.
- `BatchStatistics`.
- `model_utils.Batch` with `decoded_responses` and `scores`.
- `BatchStatistics` dataclass for prompt/response lengths and filtering stats.
- `apply_verifiable_reward` and `reward_fn` signatures to accept `ground_truths` and `datasets` instead of `Batch`.
- `utils.combine_reward_metrics` to aggregate per-prompt reward metrics.
- `Args.max_possible_score` for solved/unsolved stats.
- `time/reward` and updated "real/unsolved batch size" ratios.
- `eval_batch.scores` and `eval_batch.decoded_responses` directly.
- `num_filtered_prompts`; adjusts episode counting and packing outputs accordingly.
- `load_data_from_packing_thread` now returns `num_filtered_prompts`.
- `--active_sampling` and set `--async_steps` in debug/integration scripts; minor parameter tweaks.

Written by Cursor Bugbot for commit 9c5ba65.
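For intuition only, one plausible shape of aggregating per-prompt reward metrics into batch-level metrics; the actual utils.combine_reward_metrics in open_instruct may weight or structure things differently.

```python
from collections import defaultdict


def combine_reward_metrics(per_prompt_metrics: list[dict[str, float]]) -> dict[str, float]:
    # Average each metric key across prompts; keys missing from some prompts are
    # averaged only over the prompts that reported them.
    sums, counts = defaultdict(float), defaultdict(int)
    for metrics in per_prompt_metrics:
        for key, value in metrics.items():
            sums[key] += value
            counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}
```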