Skip to content

Conversation

@Mayankvlog
Copy link

Resolved issue #21105 where keras.ops.custom_gradient incorrectly captured Variable objects instead of their values when using the JAX backend. This caused gradient computation failures in custom quantization layers. The problem occurred when passing keras.Variable instances directly to functions decorated with @ops.custom_gradient, where the gradient function (roundpass_grad) captured the variable object rather than its underlying value.

Root Cause: JAX backend's custom gradient implementation didn't properly extract values from Variable objects before passing them to gradient functions.

Solution: Modified keras/src/backend/jax/core.py to automatically extract .value from Variable instances within the custom gradient decorator, eliminating the need for manual .value calls in user code. Added comprehensive tests in keras/src/ops/core_test.py::CoreOpsCorrectnessTest::test_custom_gradient_with_variable to prevent regression. This fix ensures seamless Variable handling across all backends without requiring workarounds.

- torch-xla is not available for Windows platform
- Manually installed tensorflow-cpu, torch, jax, and flax
- Fixed protobuf version conflicts (downgraded to <6.0.0)
- Tests now run successfully without ModuleNotFoundError
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Mayankvlog, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the Keras JAX backend where keras.ops.custom_gradient was failing to properly handle Variable objects, causing gradient calculation errors. The core solution involves enhancing the custom_gradient decorator in keras/src/backend/jax/core.py to automatically extract the underlying tensor value from Variable instances, thus ensuring correct gradient propagation without requiring users to manually access .value. This change improves the consistency and usability of Keras with the JAX backend. Additionally, the PR includes new documentation and tests for this fix, as well as new files documenting and testing known jit_compile limitations within the Torch backend for EfficientNetV2 models.

Highlights

  • JAX Custom Gradient Fix: Resolved an issue where keras.ops.custom_gradient in the JAX backend incorrectly captured Variable objects instead of their values, leading to gradient computation failures (Issue custom_gradient not working with JAX backend #21105).
  • Automatic Variable Value Extraction: Modified keras/src/backend/jax/core.py to automatically extract the .value from Variable instances within the custom_gradient decorator, eliminating the need for manual workarounds in user code.
  • Comprehensive Testing: Added new tests in keras/src/ops/core_test.py and a dedicated test file test_custom_gradient_jax_variable.py to ensure the fix works as expected and prevents regressions.
  • New Documentation for JAX Fix: Introduced CUSTOM_GRADIENT_JAX_FIX.md to document the issue, root cause, solution, and impact of the JAX custom gradient fix.
  • Torch JIT Compile Limitations Documentation: Added TORCH_JIT_COMPILE_LIMITATIONS.md and keras/src/applications/efficientnet_v2_jit_test.py to address and test known limitations of jit_compile=True with the Torch backend, particularly for EfficientNetV2 models (Issue 'jit_compile=True' raises InternalTorchDynamoError with EfficientNetV2 models on torch backend. #21647).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly fixes a variable capture issue with custom_gradient in the JAX backend. The implementation is clean, and the added tests are thorough, ensuring the issue is resolved and preventing regressions.

However, the pull request also includes changes related to a jit_compile issue for the Torch backend (issue #21647), which is not reflected in the PR's title or description. Combining unrelated fixes can complicate the review process and make the commit history less clear. In the future, please consider creating separate pull requests for distinct issues to improve clarity and maintainability.

Comment on lines 517 to 528
def wrapper(*args, **kwargs):
# Convert Variable objects to their values
def _convert_arg(arg):
if isinstance(arg, Variable):
return arg.value
return arg

args = tree.map_structure(_convert_arg, args)
kwargs = tree.map_structure(_convert_arg, kwargs)
return fun(*args, **kwargs)

return jax.custom_gradient(fun=wrapper)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation redefines the _convert_arg function on every call to the decorated function, which is inefficient. This helper function can be defined once outside the wrapper to avoid this overhead. Additionally, renaming it to _convert_variable_to_value would make its purpose clearer, following the style guide's preference for descriptive names.1

Suggested change
def wrapper(*args, **kwargs):
# Convert Variable objects to their values
def _convert_arg(arg):
if isinstance(arg, Variable):
return arg.value
return arg
args = tree.map_structure(_convert_arg, args)
kwargs = tree.map_structure(_convert_arg, kwargs)
return fun(*args, **kwargs)
return jax.custom_gradient(fun=wrapper)
def _convert_variable_to_value(arg):
if isinstance(arg, Variable):
return arg.value
return arg
def wrapper(*args, **kwargs):
# Convert Variable objects to their values
args = tree.map_structure(_convert_variable_to_value, args)
kwargs = tree.map_structure(_convert_variable_to_value, kwargs)
return fun(*args, **kwargs)
return jax.custom_gradient(fun=wrapper)

Style Guide References

Footnotes

  1. Argument names should be intuitive and easy to remember, and their meaning should be clear from the name. Overly generic names should be avoided.

Mayankvlog and others added 2 commits October 26, 2025 14:56
…ng errors

- Fixed custom_gradient in JAX backend to extract Variable values automatically
- Improved code structure by moving helper function outside wrapper
- Fixed EfficientNetV2B2 import to use direct module import
- Fixed all Ruff linting errors (line length, unused imports/variables)
- Tests now pass without requiring manual .value access on Variables
@codecov-commenter
Copy link

codecov-commenter commented Oct 26, 2025

Codecov Report

❌ Patch coverage is 53.84615% with 42 lines in your changes missing coverage. Please review.
✅ Project coverage is 63.78%. Comparing base (eecd34f) to head (22a3bf1).

Files with missing lines Patch % Lines
keras/src/losses/lpips.py 62.50% 21 Missing and 6 partials ⚠️
keras/src/backend/jax/core.py 7.14% 13 Missing ⚠️
keras/src/applications/imagenet_utils.py 33.33% 1 Missing and 1 partial ⚠️

❗ There is a different number of reports uploaded between BASE (eecd34f) and HEAD (22a3bf1). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (eecd34f) HEAD (22a3bf1)
keras 5 2
keras-torch 1 0
keras-tensorflow 1 0
keras-jax 1 0
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #21783       +/-   ##
===========================================
- Coverage   82.63%   63.78%   -18.86%     
===========================================
  Files         577      578        +1     
  Lines       59316    59404       +88     
  Branches     9300     9313       +13     
===========================================
- Hits        49018    37893    -11125     
- Misses       7910    19081    +11171     
- Partials     2388     2430       +42     
Flag Coverage Δ
keras 61.67% <53.84%> (-20.79%) ⬇️
keras-jax ?
keras-numpy 57.66% <53.84%> (+0.10%) ⬆️
keras-openvino 34.38% <17.58%> (+0.09%) ⬆️
keras-tensorflow ?
keras-torch ?
keras.applications 83.44% <33.33%> (?)
keras.applications-numpy 22.74% <33.33%> (?)
keras.applications-openvino 22.74% <33.33%> (?)
keras.applications-tensorflow 83.44% <33.33%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Mayankvlog and others added 5 commits October 26, 2025 15:25
- Changed input size from 64x64 to 224x224 (minimum supported by EfficientNetV2)
- Fixed EfficientNetV2B0 import to use direct module path
- Resolves ValueError: Input size must be at least 32x32
- Resolves ImportError for EfficientNetV2B0
…input_shape validation

This commit addresses three issues that were causing CI failures:

1. Fixed JAX Backend custom_gradient with Variables (Issue keras-team#21105)
   - Problem: Variables passed to custom_gradient in JAX backend caused
     'TypeError: NoneType object is not callable'
   - Root cause: JAX copies Variables during tracing, causing both _value
     and _initializer to become None
   - Solution:
     * Modified custom_gradient wrapper to properly convert Variables to values
     * Added fallback in __jax_array__ to handle uninitialized Variables
   - Added test: test_custom_gradient_with_variable in keras/src/ops/core_test.py

2. Fixed obtain_input_shape validation for channels_first format
   - Problem: Confusing error when users provide input_shape in wrong format
     (e.g., (224,224,3) when (3,224,224) expected for channels_first)
   - Solution: Added validation to detect format mismatch with clear error message
   - Updated efficientnet_v2_jit_test.py to use correct channels_first format

3. Code format fixes
   - Fixed line length violations
   - Fixed import ordering
   - Removed unused imports

Files modified:
- keras/src/backend/jax/core.py
- keras/src/ops/core_test.py
- keras/src/applications/imagenet_utils.py
- keras/src/applications/efficientnet_v2_jit_test.py
- test_custom_gradient_jax_variable.py

All tests passing with JAX backend.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants