@jberchtold-nvidia
Collaborator

Description

Makes test tolerances stricter for TE/JAX tests

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Remove relaxed tolerance overrides from test_layer.py tests
  • Adjust default dtype tolerances to be stricter

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 3 commits October 27, 2025 08:59
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L2 jax

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR tightens numerical validation for JAX tests by removing relaxed tolerance overrides and adjusting default epsilon calculations. The changes remove explicit rtol=1e-4, atol=1e-3 overrides from FP8-specific tests in test_layer.py, making them use the default tolerances from utils.py. Simultaneously, utils.py now calculates eps_relaxed directly from machine epsilon (finfo.eps) instead of the previous formula (math.pow(finfo.eps, 2/3)), which results in stricter tolerances across all non-NVFP4 dtypes. This change integrates with the existing dtype_tols() function that all test methods already use, requiring no structural modifications beyond parameter removal. The stricter tolerances ensure FP8 quantization and other numerical operations maintain higher accuracy standards, reflecting confidence in kernel maturity.
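To make the size of the change concrete, here is a minimal sketch comparing the two epsilon formulas named in the summary. It uses NumPy's `finfo` as a stand-in and the helper names `eps_old` and `eps_new` are hypothetical; the actual code in tests/jax/utils.py may be structured differently. Because machine epsilon is below 1, raising it to the power 2/3 yields a larger number, so dropping the exponent tightens the tolerance by a factor of `eps ** (-1/3)`.

```python
import math

import numpy as np


def eps_old(dtype):
    """Previous relaxed epsilon as described in the summary: eps ** (2/3)."""
    return math.pow(float(np.finfo(dtype).eps), 2 / 3)


def eps_new(dtype):
    """New value as described in the summary: machine epsilon, cast to a Python float."""
    return float(np.finfo(dtype).eps)


# NumPy has no native bfloat16, so only float16 and float32 are shown here.
for dtype in (np.float16, np.float32):
    old, new = eps_old(dtype), eps_new(dtype)
    print(f"{np.dtype(dtype).name}: old={old:.3e} new={new:.3e} tighter by {old / new:.1f}x")
```

For float32 the ratio is `2 ** (23 / 3)`, roughly two orders of magnitude, which is why previously passing tests may need a second look.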

Important Files Changed

| Filename | Score | Overview |
|---|---|---|
| tests/jax/test_layer.py | 5/5 | Removed explicit tolerance overrides (rtol/atol) from FP8 forward/backward test methods, deferring to stricter defaults |
| tests/jax/utils.py | 4/5 | Changed eps_relaxed calculation from math.pow(finfo.eps, 2/3) to finfo.eps and added explicit float() casts for type safety |

Confidence score: 4/5

  • This PR is safe to merge with minimal risk, as it only tightens test tolerances without changing functional code
  • Score reflects the possibility that stricter tolerances could expose legitimate numerical precision issues in existing FP8 kernels that were previously masked by relaxed tolerances, though this is more a validation concern than a code defect
  • Pay closer attention to CI test results to ensure all existing tests pass with the new stricter tolerances, particularly FP8-related tests that previously used relaxed values
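The masking concern above can be illustrated with a small sketch. The helper `check_close` is hypothetical and only assumes that the defaults fall back to machine epsilon for both `rtol` and `atol`; the real `dtype_tols()` logic may differ. An error of about 5e-4 passes under the removed `rtol=1e-4, atol=1e-3` override but fails once the override is dropped.

```python
import numpy as np


def check_close(actual, desired, dtype, rtol=None, atol=None):
    """Hypothetical helper: use dtype-derived tolerances unless overridden."""
    eps = float(np.finfo(dtype).eps)  # assumed default, not the real dtype_tols()
    rtol = eps if rtol is None else rtol
    atol = eps if atol is None else atol
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)


reference = np.array([1.0, 2.0, 3.0], dtype=np.float32)
noisy = reference + np.float32(5e-4)  # stand-in for quantization error

# With the relaxed override that this PR removes, the comparison passes.
check_close(noisy, reference, np.float32, rtol=1e-4, atol=1e-3)

# With the assumed stricter defaults, the same data is rejected.
try:
    check_close(noisy, reference, np.float32)
    strict_passed = True
except AssertionError:
    strict_passed = False

print("passes with strict defaults:", strict_passed)
```

A failure of this kind after the change would point to error that the override had been absorbing, not to a regression introduced by the PR itself.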

2 files reviewed, no comments
