
Conversation


@mattjj mattjj commented Oct 24, 2025

No description provided.

@gemini-code-assist

Summary of Changes

Hello @mattjj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request updates a test in keras/src/quantizers/quantizers_test.py to correctly handle gradient extraction when the jax.config.jax_vjp3 flag is enabled in JAX. This ensures that the test remains compatible with newer JAX VJP implementations, preventing potential failures related to how gradients are structured in the f_vjp object.

Highlights

  • JAX VJP3 Configuration: The test now checks for the jax.config.jax_vjp3 flag to correctly extract gradients from f_vjp.opaque_residuals when this new JAX VJP (vector-Jacobian product) implementation is enabled.
  • Gradient Extraction Logic: Modified the gradient extraction logic in quantizers_test.py to adapt to changes in JAX's VJP behavior, specifically when jax_vjp3 is active, ensuring compatibility with updated JAX internals.
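The branching described in the highlights above can be sketched as follows. This is an illustrative stand-in, not the actual test code: `SimpleNamespace` objects play the roles of `jax.config` and the `f_vjp` object, the attribute names `jax_vjp3` and `opaque_residuals` come from the PR summary, and the old-style residual layout shown here is a guess rather than JAX's real internals.

```python
from types import SimpleNamespace

# Stand-ins for the two f_vjp layouts: the new jax_vjp3 layout exposes
# residuals via `opaque_residuals`; the hypothetical old layout nests
# them inside `args`.
new_style_vjp = SimpleNamespace(opaque_residuals=("grads_new",))
old_style_vjp = SimpleNamespace(
    args=(SimpleNamespace(args=(("grads_old",),)),)
)

def extract_grads(config, f_vjp):
    # getattr with a default keeps this working on JAX versions whose
    # Config object predates the jax_vjp3 flag.
    if getattr(config, "jax_vjp3", False):
        return f_vjp.opaque_residuals
    return f_vjp.args[0].args[0][0]

print(extract_grads(SimpleNamespace(jax_vjp3=True), new_style_vjp))
print(extract_grads(SimpleNamespace(), old_style_vjp))
```

The key point is that the extraction path depends on which VJP implementation JAX is running, so the test must probe the config flag before reaching into `f_vjp`.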

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates a test in keras/src/quantizers/quantizers_test.py to correctly handle gradient extraction when jax.config.jax_vjp3 is enabled. The logic is sound and correctly adapts to the new JAX configuration. I have one minor suggestion to improve code style by fixing inconsistent indentation.

Co-authored-by: Yash Katariya <yashkatariya@google.com>

codecov-commenter commented Oct 24, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 78.05%. Comparing base (18f79d6) to head (9fd4c0f).
⚠️ Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21776      +/-   ##
==========================================
- Coverage   82.64%   78.05%   -4.59%     
==========================================
  Files         577      577              
  Lines       59254    59249       -5     
  Branches     9292     9291       -1     
==========================================
- Hits        48968    46248    -2720     
- Misses       7903    10724    +2821     
+ Partials     2383     2277     -106     
Flag              Coverage Δ
keras             77.91% <ø> (-4.56%) ⬇️
keras-jax         ?
keras-numpy       57.57% <ø> (-0.01%) ⬇️
keras-openvino    34.31% <ø> (+<0.01%) ⬆️
keras-tensorflow  64.11% <ø> (+<0.01%) ⬆️
keras-torch       63.65% <ø> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

# `f_vjp.args[0].args[0][0]`. Otherwise, they are at
# `f_vjp.args[0].args[0][1]`.
if sys.version_info >= (3, 10):
    if jax.config.jax_vjp3:
Collaborator
So we use a pretty old version of JAX on CPU (0.5.0, I spent hours trying to find a combination that plays nice with both TensorFlow and Torch and failed).
Right now the test is failing with

FAILED keras/src/quantizers/quantizers_test.py::QuantizersTest::test_fake_quant_with_min_max_vars_wide_8bits_multi_channel - AttributeError: 'Config' object has no attribute 'jax_vjp3'

I guess it needs to be changed to:
if getattr(jax.config, "jax_vjp3", False):
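The suggested `getattr` form can be demonstrated in isolation. Below, `OldConfig` is a hypothetical stand-in for `jax.config` on an older JAX (such as 0.5.0) whose `Config` object has no `jax_vjp3` attribute; direct attribute access raises the same kind of AttributeError seen in the failing test, while `getattr` with a default degrades gracefully.

```python
# Hypothetical stand-in for jax.config on an older JAX where the
# `jax_vjp3` attribute does not exist.
class OldConfig:
    pass

cfg = OldConfig()

# Direct access raises AttributeError, analogous to the test failure.
try:
    cfg.jax_vjp3
    raised = False
except AttributeError:
    raised = True

# getattr with a default returns False instead of raising.
flag = getattr(cfg, "jax_vjp3", False)
print(raised, flag)
```

This keeps the test runnable across JAX versions that predate the flag, at the cost of silently treating "flag absent" the same as "flag off", which is the desired behavior here.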
