Extended fix OOM Issue #21634 on KerasHub side #2441
base: master
Conversation
Summary of Changes
Hello @praveenhosdrug123, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses critical out-of-memory issues encountered during the initialization of very large language models (7B+ parameters) on resource-constrained 8GB TPU devices within KerasHub. The core solution integrates distributed initializers for token embeddings across several model backbones (the Gemma, Llama, and Qwen families). By sharding these large weight arrays across the available TPU devices at the point of creation, the PR mitigates memory pressure and enables the successful loading of models that previously failed due to excessive memory consumption during their initial setup.
Code Review
This pull request introduces distributed initializers to resolve out-of-memory (OOM) errors on TPUs when initializing large models. The core of the change is a new dist_initializer.py utility that provides sharded weight initialization for the JAX backend, which is then applied to the token embedding layers of the Gemma, Llama, and Qwen models. The implementation includes a fallback to standard initializers when not in a distributed context, ensuring broader compatibility. The changes are well structured and include comprehensive tests for the new initializers. My review focuses on improving the maintainability and robustness of the new distributed initializer module.
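As a rough illustration of the fallback pattern described above, here is a minimal sketch (hypothetical names, not the PR's actual code; it also still materializes the full array before laying it out, which the real utility would presumably avoid):

```python
import keras
from keras import distribution, initializers


class MaybeShardedInitializer(initializers.Initializer):
    """Sketch: shard the result across the active device mesh when a Keras
    distribution is set, otherwise behave like the wrapped initializer."""

    def __init__(self, base="variance_scaling", shard_axis="model"):
        self.base = initializers.get(base)
        self.shard_axis = shard_axis  # assumed mesh axis name

    def __call__(self, shape, dtype=None):
        values = self.base(shape, dtype=dtype)
        dist = distribution.distribution()
        if dist is None:
            # Fallback: no distribution configured, keep standard behavior.
            return values
        # Lay the (vocab, hidden) table out across the mesh so each device
        # keeps only its shard after distribution.
        layout = distribution.TensorLayout(
            axes=(self.shard_axis, None), device_mesh=dist.device_mesh
        )
        return distribution.distribute_tensor(values, layout)
```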
```python
# This may need to be updated if models with different layout maps are used.
# The failsafe is to just use non-distributed initializers if no layout is found
TOKEN_EMBEDDING_PATH = "token_embedding/embeddings"
```
The hardcoded TOKEN_EMBEDDING_PATH is fragile. If a new model uses a different name for the token embedding layer or its weights, the distributed initialization will silently fall back to the default, non-distributed one. This could lead to unexpected OOM errors for large models, which this PR aims to fix.
To make this more robust and easier to debug, consider adding a warning when a distribution strategy is active but the embedding path is not found in the layout map. This would alert developers if a new model is not correctly configured for distributed initialization.
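A minimal sketch of the kind of check being suggested (hypothetical helper and attribute names; assumes the layout map returns None for unmatched paths):

```python
import warnings

from keras import distribution

TOKEN_EMBEDDING_PATH = "token_embedding/embeddings"


def resolve_embedding_layout():
    dist = distribution.distribution()
    if dist is None:
        # No distribution active: falling back silently is the intended behavior.
        return None
    # Assumption: a ModelParallel-style distribution carries a LayoutMap that
    # returns None when no key matches the queried path.
    layout_map = getattr(dist, "layout_map", None)
    layout = layout_map[TOKEN_EMBEDDING_PATH] if layout_map is not None else None
    if layout is None:
        warnings.warn(
            "A distribution is active but no layout was found for "
            f"'{TOKEN_EMBEDDING_PATH}'; falling back to non-distributed "
            "initialization, which may OOM for very large models."
        )
    return layout
```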
> The hardcoded TOKEN_EMBEDDING_PATH is fragile. If a new model uses a different name for the token embedding layer or its weights, the distributed initialization will silently fall back to the default, non-distributed one. This could lead to unexpected OOM errors for large models, which this PR aims to fix.

A warning has been added to the subsequent function called through the distributed initialization.

> To make this more robust and easier to debug, consider adding a warning when a distribution strategy is active but the embedding path is not found in the layout map. This would alert developers if a new model is not correctly configured for distributed initialization.

This is a good edge case for a warning; I will include it in the next push. TODO: @praveenhosdrug123
@praveenhosdrug123 thanks for the PR. I understand you are sharding the tensor at build time itself, so that instead of creating one large tensor you now create sharded tensors directly on the respective devices.
@amitsrivastava78 -- Thanks for the detailed question and the Colab example! The PR fixes the host OOM issue during the build phase.

Current Failure Point

The model OOMs during the build phase of the Embedding layer.

Expected flow:

Actual flow with large models:

Why This is Difficult to Reproduce on Single-RAM Systems

Testing on a single host with large RAM (even when logically divided into shards) does not reproduce the per-device memory limits of real multi-device TPU hardware.
Root Cause & Approach

The root cause behind the memory spike is somewhat speculative; I've documented it in the linked doc.

Evidence: TPU Test Results

Error logs from TPU testing show the failure occurs specifically during build:

File "/.../keras_hub/src/layers/modeling/reversible_embedding.py", line 103, in build

Before changes:

After changes:

Memory Analysis

A quick calculation reaffirms the memory profiling behavior observed.

From successful run (with distributed initialization):
From error log (without distributed initialization):
Conclusion: The ~1.3-1.4 GB overhead matches across both scenarios. When sharding is applied to the token embedding initializer array, the memory pressure is alleviated and the issue resolves. Happy to run additional experiments or profiling if it would help!
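As an illustrative back-of-the-envelope (assumed shapes for a 7B-class model: 256k vocabulary, 3072-dim embeddings, float32; not the measured values above), the embedding table alone is large enough to dominate an 8GB HBM budget unless it is sharded:

```python
# Back-of-the-envelope estimate with assumed shapes (not measured values).
vocab_size = 256_000      # assumed vocabulary size
hidden_dim = 3072         # assumed embedding width
bytes_per_param = 4       # float32

full_table_gib = vocab_size * hidden_dim * bytes_per_param / 1024**3
per_device_gib = full_table_gib / 8  # sharded across an 8-device TPU

print(f"full table: {full_table_gib:.2f} GiB")        # ~2.93 GiB on one device
print(f"per-device shard: {per_device_gib:.2f} GiB")  # ~0.37 GiB when sharded
```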
This PR depends on keras-team/keras#21755 (Keras repo) being merged first. Suggestion: Would it make sense for you to review #21755 on the Keras side, and have another team member review this KerasHub PR? That way both can progress in parallel. Let me know if you need any clarification on the approach or additional profiling!
Thank you for the investigation. This is indeed an issue. Somebody on the team is working on a fix that's generally applicable to all variables, so that you don't have to explicitly use the fix that you provided here.
@hertschuh - Thanks for the feedback and for taking the time to review the document; I want to clarify the technical issue. Thank you for the context on the general solution. A few follow-up questions to help me understand the timeline:
The reason I ask: users are blocked on this today for 7B+ models on 8GB TPU devices.
Let me know if that's feasible.
Summary
Applies the distributed initialization fix to model presets to resolve OOM errors during initialization of 7B+ parameter models on 8GB TPU devices. This PR updates the KerasHub model backbones and adds a dist_initializer utility.
Issue
Token embedding initialization materializes the full embedding array on a single device at creation time.
Combined with the forward passes performed during backbone initialization, this causes a 2x to 3x memory spike and triggers OOM on TPUs with limited HBM.
Solution
Updates the model backbones (Gemma, Llama, Qwen) to use the dist_initializer for token embeddings, sharding weights across TPU devices during instantiation. Validated on an 8-device TPU: models that previously OOM'd during backbone initialization now load successfully.
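A minimal JAX-only sketch of the shard-at-creation idea (assumed mesh axis name and embedding shape; not the code in this PR): by compiling the initializer with an output sharding, each device only ever materializes its own slice of the table.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1D mesh over all local devices (e.g. 8 TPU cores); "model" is an assumed axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Shard the (vocab, hidden) table along the vocabulary dimension.
out_sharding = NamedSharding(mesh, PartitionSpec("model", None))

init_fn = jax.nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")


def make_embeddings(key):
    # Assumed shapes: 256k vocabulary x 3072 hidden.
    return init_fn(key, (256_000, 3072), jnp.float32)


# Compiling with out_shardings means no single device holds the full array.
sharded_table = jax.jit(make_embeddings, out_shardings=out_sharding)(
    jax.random.PRNGKey(0)
)
print(sharded_table.sharding)  # with 8 devices, each holds a 32,000 x 3072 shard
```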
Reference
For memory profiling analysis, cache locality theory, validation logs and alternative solutions considered, refer to: Doc
Related PR: keras-team/keras#21755
Checklist