
Conversation


@praveenhosdrug123 praveenhosdrug123 commented Oct 17, 2025

Summary

Applies the distributed initialization fix to model presets to resolve OOM errors during initialization of 7B+ parameter models on 8GB TPU devices. This PR updates the Keras Hub model backbones and adds a dist_initializer utility.

Issue

Token embedding initialization creates large arrays at creation time, placing all weights on a single device.
Combined with the forward passes run during backbone initialization, this causes a 2X to 3X memory spike and triggers OOM on TPUs with limited HBM.
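For scale (assuming Gemma-7B-like dimensions), a ~256k x 3072 float32 embedding table is roughly 2.9GB on its own, so a 2X to 3X spike during initialization already exceeds the ~7.5GB of HBM actually usable on an 8GB device.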

Solution

Updates the model backbones (Gemma, Llama, Qwen) to use dist_initializer for token embeddings, sharding the weights across TPU devices during instantiation. Validated on an 8-device TPU: models that previously OOM'd during backbone initialization now load successfully.
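To illustrate the idea, here is a minimal JAX sketch of initializing the embedding table directly as a sharded array (this is not the actual dist_initializer API; sizes, axis names, and the scaling factor are assumptions for illustration):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical Gemma-7B-scale sizes, for illustration only.
VOCAB_SIZE, HIDDEN_DIM = 256_128, 3072

mesh = Mesh(jax.devices(), axis_names=("model",))
# Shard the embedding table along the vocabulary dimension across devices.
layout = NamedSharding(mesh, PartitionSpec("model", None))

def init_embeddings(key):
    # Variance-scaling-style init. With jit + out_shardings, XLA partitions
    # the computation so each device only materializes its own shard.
    return jax.random.normal(key, (VOCAB_SIZE, HIDDEN_DIM), dtype=jnp.bfloat16) * 0.02

sharded_init = jax.jit(init_embeddings, out_shardings=layout)
embeddings = sharded_init(jax.random.PRNGKey(0))  # never assembled on one device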

Reference

For the memory profiling analysis, cache locality theory, validation logs, and alternative solutions considered, refer to: Doc

Related PR: keras-team/keras#21755

Checklist

  • [o] I have added all the necessary unit tests for my change. -- I have added smoke tests since the exact test is difficult to replicate on a CPU/GPU runtime, but I have validated the fix on TPUs; the validation logs are linked in the document above.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

@github-actions github-actions bot added the Gemma Gemma model specific issues label Oct 17, 2025
@gemini-code-assist
Contributor

Summary of Changes

Hello @praveenhosdrug123, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses critical Out-Of-Memory issues encountered during the initialization of very large language models (7B+ parameters) on resource-constrained 8GB TPU devices within Keras Hub. The core solution involves integrating distributed initializers for token embeddings across various model backbones (Gemma, Llama, Qwen families). By sharding these large weight arrays across available TPU devices at the point of creation, the PR effectively mitigates memory pressure and enables the successful loading of models that previously failed due to excessive memory consumption during their initial setup.

Highlights

  • OOM Issue Resolution: This PR resolves Out-Of-Memory (OOM) errors that occur when initializing large language models (7B+ parameters) on 8GB TPU devices by implementing a distributed initialization strategy.
  • Distributed Initializers: A new utility file dist_initializer.py introduces distributed versions of standard Keras initializers (e.g., DistributedVarianceScaling, DistributedRandomNormal). These initializers automatically shard token embedding weights across TPU devices during instantiation, preventing memory spikes.
  • Model Backbone Updates: The Gemma, Gemma3, Llama, Qwen, and Qwen3 model backbones have been updated to utilize these new distributed initializers for their token embeddings, ensuring that large embedding tables are sharded from the start.
  • Memory Cleanup Enhancement: The jax_memory_cleanup utility function has been modified to unconditionally delete all JAX layer weights (sharded or unsharded) during cleanup, further optimizing memory usage during model loading.
  • Validation and Testing: The fix has been validated on an 8-device TPU, successfully loading models that previously OOM'd. New smoke tests and comprehensive unit tests for the distributed initializers have been added.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces distributed initializers to resolve Out-Of-Memory (OOM) errors on TPUs when initializing large models. The core of the change is a new dist_initializer.py utility that provides sharded weight initialization for the JAX backend, which is then applied to the token embedding layers of the Gemma, Llama, and Qwen models. The implementation includes a fallback to standard initializers when not in a distributed context, ensuring broader compatibility. The changes are well-structured and include comprehensive tests for the new initializers. My review focuses on improving the maintainability and robustness of the new distributed initializer module.

# This may need to be updated if models with different layout maps are used.
# The failsafe is to just use non-distributed initializers if no layout is found

TOKEN_EMBEDDING_PATH = "token_embedding/embeddings"

Severity: high

The hardcoded TOKEN_EMBEDDING_PATH is fragile. If a new model uses a different name for the token embedding layer or its weights, the distributed initialization will silently fall back to the default, non-distributed one. This could lead to unexpected OOM errors for large models, which this PR aims to fix.

To make this more robust and easier to debug, consider adding a warning when a distribution strategy is active but the embedding path is not found in the layout map. This would alert developers if a new model is not correctly configured for distributed initialization.

@praveenhosdrug123 praveenhosdrug123 Oct 20, 2025

The hardcoded TOKEN_EMBEDDING_PATH is fragile. If a new model uses a different name for the token embedding layer or its weights, the distributed initialization will silently fall back to the default, non-distributed one. This could lead to unexpected OOM errors for large models, which this PR aims to fix.
-- A warning has been added to the subsequent function called through the distributed initialization.

To make this more robust and easier to debug, consider adding a warning when a distribution strategy is active but the embedding path is not found in the layout map. This would alert developers if a new model is not correctly configured for distributed initialization.
-- This is a good edge case for a warning; I will include it in the next push. TODO: praveenhosdrug123
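For reference, a minimal sketch of the fallback-with-warning pattern discussed above (the lookup helper passed in is hypothetical; the actual wiring in dist_initializer may differ):

import warnings

import keras

TOKEN_EMBEDDING_PATH = "token_embedding/embeddings"

def resolve_embedding_layout(lookup_layout):
    # `lookup_layout` is a hypothetical callable that returns the sharding
    # layout registered for a variable path, or None when no entry matches.
    distribution = keras.distribution.distribution()
    if distribution is None:
        # No distribution strategy active: the standard initializer is fine.
        return None
    layout = lookup_layout(TOKEN_EMBEDDING_PATH)
    if layout is None:
        # Distribution is active but the path is missing from the layout map,
        # so initialization would silently fall back to a single-device
        # allocation -- the edge case flagged above.
        warnings.warn(
            f"No layout found for '{TOKEN_EMBEDDING_PATH}'; using a "
            "non-distributed initializer, which may OOM for large models."
        )
    return layout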

@amitsrivastava78
Collaborator

@praveenhosdrug123 thanks for the PR. I understand you are sharding the tensor at build time itself, so that instead of creating one large tensor you now create sharded tensors directly on the respective devices.
But I am curious about the existing flow:
build() → add_weight() → assign() → _direct_assign() → jax.device_put(sharded)
Here the host tensor is the large one, but eventually sharded tensors are created and loaded onto the respective devices. Taking my Colab below as an example, can you explain whether your PR fixes the host memory issue, the device memory issue, or both?
https://colab.research.google.com/gist/amitsrivastava78/f7f1990b154001e098c9d8cfe248e4ee/issuelargegemaamodel.ipynb

@praveenhosdrug123
Author

praveenhosdrug123 commented Oct 18, 2025

@amitsrivastava78 -- Thanks for the detailed question and the Colab example!

The PR fixes the host OOM issue during the build phase.

Current Failure Point

The model OOMs during the build phase of the Embedding layer,
specifically during the forward pass of text embedding - before any device
sharding logic is reached.

Expected flow:
build() → add_weight() → assign() → _direct_assign() → jax.device_put(sharded)

Actual flow with large models:
Failure happens at build() before the host array reaches the sharding logic
in _direct_assign().

Why This is Difficult to Reproduce on Single-RAM Systems

Testing on a single large RAM pool (even when it is logically divided into shards)
doesn't reproduce the real memory pressure, as the sketch after this list illustrates, because:

  • The entire array still materializes in one physical memory space
  • Memory is then distributed logically, but the materialization already happened
  • On real multi-device TPUs (8 devices × 8GB), each device has separate physical memory
  • The build and forward passes happen at host level, causing real OOM before any
    device sharding occurs
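A quick JAX-on-CPU illustration of this point (a logical 8-device mesh carved out of one host; sizes are Gemma-7B-scale and purely illustrative):

import os

# Must be set before importing jax to get 8 logical CPU devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices("cpu"), axis_names=("model",))
sharding = NamedSharding(mesh, PartitionSpec("model", None))

# The full ~2.9 GiB array still materializes in the single physical RAM pool...
embeddings = jnp.zeros((256_128, 3072), dtype=jnp.float32)
# ...before being split across the 8 logical devices, which share that same pool.
embeddings = jax.device_put(embeddings, sharding)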

Root Cause & Approach

The root cause behind the memory spike is somewhat speculative. I've documented
the theory and entire approach in detail here.
Please have a look - if something doesn't sit right with you, happy to discuss
or investigate further.

Evidence: TPU Test Results

Error logs from TPU testing show the failure occurs specifically during build
of the ReversibleEmbedding layer:

File "/.../keras_hub/src/layers/modeling/reversible_embedding.py", line 103, in build
super().build(inputs_shape)

Before changes:
errors_during_forward_pass.txt

After changes:
using_distributed_initializers.txt

Memory Analysis

A quick calculation reaffirms the memory profiling behavior observed:

From successful run (with distributed initialization):

  • Individual device memory: 0.37GB per device after text embedding scaling
  • Total memory across 8 devices: 2.96GB
  • Applying observed peak memory factor (3x): 2.96GB × 3 = 8.88GB
  • Available memory per device: 7.48GB
  • Overhead: 8.88GB - 7.48GB = 1.4GB

From error log (without distributed initialization):

  • Attempted allocation: 2.93GB
  • Available memory: 1.63GB
  • Overhead: 2.93GB - 1.63GB = 1.3GB

Conclusion: The ~1.3-1.4GB overhead matches across both scenarios,
confirming the ~3x peak memory factor during initialization.

Sharding the token embedding array at initialization alleviates the memory pressure and resolves the issue.

Happy to run additional experiments or profiling if it would help!

@praveenhosdrug123
Author

Hi @amitsrivastava78,

This PR depends on keras-team/keras#21755 (Keras repo) being merged first.

Suggestion: Would it make sense for you to review #21755 on the Keras side, and have another team member review this KerasHub PR? That way both can progress in parallel.

Let me know if you need any clarification on the approach or additional profiling!

@hertschuh
Contributor

@praveenhosdrug123

Thank you for the investigation. This is indeed an issue.

Somebody on the team is working on a fix that's generally applicable to all variables so that you don't have to explicitly use the fix that you provided here.

@praveenhosdrug123
Author

@hertschuh - Thanks for the feedback and for taking the time to review the document.

I want to clarify the technical issue:
The OOM problem is about large contiguous memory allocation, not total parameter count. Token embeddings are the largest single array and exceed device memory during initialization, even when the full model would fit after sharding.

Thank you for the context on the general solution. A few follow-up questions to help me understand the timeline:

  1. What's the expected completion date for the general fix?
  2. Will it handle the edge cases mentioned in the document (interleaving, quantization, LoRA)?
  3. Will it detect which variables actually need distribution?

The reason I ask: users are blocked on this today for 7B+ models on 8GB TPU devices.
If the general fix is months out, would it make sense to:

  • Merge this targeted fix as a stopgap
  • Mark it deprecated once the general solution ships
  • Remove it in a future release

Let me know if that's feasible.
