feat: add capability to update weights inflight during generation #1381
base: main
Conversation
…for generation

Signed-off-by: Parth Chadha <pchadha@nvidia.com>
📝 Walkthrough

The changes implement in-flight weight updates for async GRPO, add defensive error handling in the TensorBoard logger, and pass checkpointing configuration to checkpoint saving operations.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

The changes span multiple files with a mix of logic additions (conditional wait behavior in async utilities), parameter passing (checkpointing config), and error handling (logger improvements). While the individual changes are straightforward, they affect distinct concerns requiring separate reasoning for each modification, and the conditional logic in async utilities requires understanding the interaction between vLLM engine configurations and weight update flows.

Pre-merge checks and finishing touches: ✅ Passed checks (4 passed)
Actionable comments posted: 0
🧹 Nitpick comments (1)
nemo_rl/utils/logger.py (1)
136-147: Good defensive error handling for TensorBoard logging.

The changes correctly filter out non-scalar metrics and add exception handling to prevent logging failures from disrupting the training pipeline. This aligns with the robustness improvements mentioned in the PR objectives.
The static analysis tool flags the blind `Exception` catch at line 145. While defensive logging is appropriate, you could be more specific:

```diff
 try:
     self.writer.add_scalar(name, value, step)
-except Exception as e:
+except (ValueError, TypeError, RuntimeError) as e:
     print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}")
     continue
```

This catches the most common TensorBoard logging errors while avoiding masking unexpected issues.
📜 Review details

Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (4)

- examples/configs/grpo_math_1B.yaml (1 hunks)
- nemo_rl/algorithms/async_utils.py (2 hunks)
- nemo_rl/algorithms/grpo.py (1 hunks)
- nemo_rl/utils/logger.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
`**/*.py`:

- Follow the Google Python Style Guide for all Python code
- Target Python 3.12+ for all Python code in NeMo-RL
- Indent Python code with 4 spaces; do not use tabs
- Python filenames should be snake_case (e.g., some_file.py)
- Class names should be PascalCase
- Function and method names should be snake_case
- Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
- Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
- Constants should be UPPER_SNAKE_CASE
- Avoid shadowing variables declared in an outer scope
- Initialize all externally visible members of a class in the constructor
- For public interfaces used outside a file, prefer docstrings over comments
- Use comments mainly for code within a function or interfaces local to a file
- Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
- Use Google-style docstrings for classes and functions (Sphinx-parseable)
- Avoid using reflection when functionality can be easily achieved without it
- Limit except clauses to the smallest specific set of exceptions possible
- For duck-typing via try/except, keep the try body minimal and use else for main logic
- Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
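Two of the exception-handling guidelines above (smallest specific except set; minimal try body with else for the main logic) can be illustrated with a small hypothetical helper:

```python
def parse_port(raw: str, default: int = 8080) -> int:
    """Parse a TCP port string, falling back to `default` on bad input."""
    try:
        port = int(raw)  # keep the try body minimal: only the risky call
    except ValueError:   # smallest specific set of exceptions
        return default
    else:
        # Main logic lives in `else`, outside the protected region,
        # so bugs here are not accidentally swallowed by the except.
        return port if 0 < port < 65536 else default


result_ok = parse_port("9000")
result_bad = parse_port("not-a-port")
result_range = parse_port("70000")
```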
Files:
nemo_rl/algorithms/async_utils.py
nemo_rl/algorithms/grpo.py
nemo_rl/utils/logger.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
`nemo_rl/**/*.py`:

- Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
- Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
- Express configuration optionality via TypedDict using typing.NotRequired
- When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default in code
- For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/algorithms/async_utils.py
nemo_rl/algorithms/grpo.py
nemo_rl/utils/logger.py
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
`examples/configs/*.yaml`:

- Exemplar configs under examples/configs/*.yaml must include documented defaults
- When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml
Files:
examples/configs/grpo_math_1B.yaml
🪛 Ruff (0.14.0)
nemo_rl/utils/logger.py
145-145: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (4)
nemo_rl/algorithms/grpo.py (1)

1746-1746: LGTM! Consistent configuration propagation.

Passing the checkpointing configuration to `save_checkpoint` aligns with the config-driven design principle stated in the coding guidelines ("YAML is the single source of truth for defaults"). This change makes the async GRPO path consistent with the synchronous path (line 943) and properly integrates checkpointing configuration into the persistence flow.

examples/configs/grpo_math_1B.yaml (1)
20-20: Well-documented configuration addition.

The new `in_flight_weight_updates` option is properly documented with a descriptive comment and has a safe default value (`false`). This follows the coding guidelines requiring that "exemplar configs under examples/configs/*.yaml must include documented defaults."
nemo_rl/algorithms/async_utils.py (2)

526-573: Clean implementation of in-flight weight update logic.

The conditional waiting behavior is well-structured and clearly documented:

- The docstring explains the difference between async and non-async engines
- Safe config access with `.get()` and default values prevents errors
- Clear print statements aid debugging and observability
- The logic correctly skips waiting when both `async_engine` and `in_flight_weight_updates` are enabled

This implementation properly enables the throughput improvements mentioned in the PR objectives by allowing ongoing generations to continue during weight updates.
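The decision described above can be reduced to a small predicate. This is a hedged sketch, not the repository's implementation; the function name and the config nesting are assumptions, while `async_engine` and `in_flight_weight_updates` are the key names discussed in this review:

```python
def should_wait_for_generations(generation_cfg: dict) -> bool:
    """Return True when a weight update must wait for in-flight generations.

    Waiting is skipped only when the vLLM async engine is active AND
    in-flight weight updates are enabled; every other combination waits,
    because only the async engine can safely swap weights mid-generation.
    """
    async_engine = generation_cfg.get("vllm_cfg", {}).get("async_engine", False)
    in_flight = generation_cfg.get("in_flight_weight_updates", False)
    return not (async_engine and in_flight)
```

Using `.get()` with explicit `False` defaults mirrors the safe config access called out in the review, so a config that omits either key falls back to the conservative waiting path.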
478-480: Helpful observability improvement.

The additional message clarifies the behavior of the vLLM V1 async engine during weight updates, improving the user experience by explaining that active generation threads can continue executing. This aligns well with the in-flight weight update feature.
Awesome! IIUC, the throughput is 2-3x better than the sync baseline. What is the difference with just regular async RL and waiting for the generations to finish on Llama 8B?
Could you also update:

- Line 4 in dee3fd9
- Line 82 in dee3fd9

max_trajectory_age_steps: int
Hi @parthchadha, could we have an assertion when
What does this PR do?

This PR adds the capability to do in-flight weight updates, which prevents stalls in the async RL pipeline and provides increased throughput.

Convergence plots for Llama 8B, 4K seq len:

The plot below shows three types of runs:
Timing, tokens per sec per GPU (higher is better):

- Sync: ~187
- Async best: ~478

Total step time:

- Sync: ~60s
- Async best: ~21s
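As a quick sanity check, the figures above are consistent with the "2-3x better than the sync baseline" observation from the review thread:

```python
# Figures copied from the PR description above.
sync_tps, async_tps = 187, 478      # tokens/sec/GPU
sync_step, async_step = 60, 21      # seconds per training step

throughput_speedup = async_tps / sync_tps   # ~2.6x
step_time_speedup = sync_step / async_step  # ~2.9x
```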
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
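A minimal sketch of how the feature might be enabled. The key name comes from this PR; its exact placement within examples/configs/grpo_math_1B.yaml is not shown in this thread, so the excerpt below is an assumption:

```yaml
# Excerpt (placement within the exemplar YAML is illustrative).
# Default is false; per the review, this is only effective together with
# the vLLM async engine (async_engine: true).
in_flight_weight_updates: true
```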
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
- `in_flight_weight_updates` configuration option to control async GRPO weight update behavior

Bug Fixes

Chores