Skip to content

Convert wealth_dynamics lecture from numba to JAX#632

Closed
Copilot wants to merge 3 commits intomainfrom
copilot/fix-34715820-086e-4db5-bbf5-7af0855feae8
Closed

Convert wealth_dynamics lecture from numba to JAX#632
Copilot wants to merge 3 commits intomainfrom
copilot/fix-34715820-086e-4db5-bbf5-7af0855feae8

Conversation

Copy link
Contributor

Copilot AI commented Sep 26, 2025

This PR converts the wealth_dynamics lecture from numba to JAX, following the QuantEcon JAX style guidelines and addressing issue #[issue_number].

Overview

The wealth dynamics lecture previously used numba's jitclass and parallel processing for performance. This conversion modernizes the implementation to use JAX's functional programming paradigm, vectorization, and JIT compilation.

Key Changes

Architecture Conversion

  • WealthDynamics class: Converted from numba jitclass to JAX NamedTuple for immutability and better type safety
  • Parameter storage: Now uses computed properties instead of instance variables for derived quantities like z_mean, y_mean, etc.
  • Factory function: Added create_wealth_dynamics() to handle stability condition checking since NamedTuple doesn't support __post_init__

Function Modernization

  • update_states(): Converted to pure JAX function with functional random key handling
  • wealth_time_series(): Replaced iterative loop with JAX lax.scan for efficient time series generation
  • update_cross_section(): Replaced numba's prange parallelization with JAX vmap vectorization

Random Number Generation

  • Functional RNG: All functions now use JAX's functional random number generation with proper key splitting
  • Reproducibility: Explicit random keys ensure reproducible results across runs
  • Performance: JAX's RNG is optimized for vectorized operations

Performance & Compatibility

  • JIT Compilation: Uses @jax.jit with static_argnames for parameters that need to be compile-time constants
  • 64-bit Precision: Enabled via jax.config.update("jax_enable_x64", True) for financial calculations
  • QuantEcon Integration: Handles JAX→NumPy array conversion for qe.gini_coefficient() and qe.lorenz_curve()

Example Usage

The API remains largely the same, with the main difference being explicit random key management:

# Old numba version
wdy = WealthDynamics()
w = wealth_time_series(wdy, wdy.y_mean, 200)

# New JAX version  
wdy = create_wealth_dynamics()
key = jax.random.PRNGKey(42)
w = wealth_time_series(wdy, wdy.y_mean, 200, key)

Benefits

  • GPU Ready: Code can leverage GPU acceleration when available
  • Vectorization: More efficient batch processing of household simulations
  • Functional: Eliminates mutable state for better composability
  • Type Safety: Better type hints and compile-time error checking
  • Future-Proof: Ready for automatic differentiation if needed for optimization

Testing

All functionality has been comprehensively tested:

  • Basic model creation and property computation
  • Time series generation with proper random key handling
  • Cross-section simulation with vectorized household updates
  • Inequality measures (Lorenz curves, Gini coefficients)
  • Parameter variation studies
  • Stability condition enforcement
  • Performance benchmarks showing equivalent speed

Documentation

  • Added JAX admonition explaining key features and benefits
  • Updated installation requirements to include JAX
  • Removed reference to separate JAX version (this now IS the JAX version)
  • Updated performance commentary to reflect JAX capabilities
  • All code examples updated to use new JAX patterns

The conversion maintains full functional compatibility while providing the benefits of JAX's modern computational framework.

Original prompt

This section details on the original issue you should resolve

<issue_title>[wealth_dynamics] Update lecture to use JAX instead of numba</issue_title>
<issue_description>Please update lecture/wealth_dynamics lecture and conversation the code to make use of the JAX python package. This includes looking at code that has high complexity or long run times, and areas where numba is being used to accelerate code.

Our style guide has a number of guidelines we should follow when making these updates

  1. JAX code conversions
  2. JAX admonitions and sequence generation

Please only use JAX where it would be beneficial to do, solving a problem well suited to JAX. </issue_description>

Comments on the Issue (you are @copilot in this section)

Fixes #631

💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.

Copilot AI and others added 2 commits September 26, 2025 02:48
Co-authored-by: mmcky <8263752+mmcky@users.noreply.github.com>
Co-authored-by: mmcky <8263752+mmcky@users.noreply.github.com>
Copilot AI changed the title [WIP] [wealth_dynamics] Update lecture to use JAX instead of numba Convert wealth_dynamics lecture from numba to JAX Sep 26, 2025
Copilot AI requested a review from mmcky September 26, 2025 02:55
@github-actions
Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-632--sunny-cactus-210e3e.netlify.app (85edd95)

📚 Changed Lecture Pages: wealth_dynamics

@mmcky mmcky added the example label Sep 29, 2025
@mmcky mmcky closed this Sep 29, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[wealth_dynamics] Update lecture to use JAX instead of numba

2 participants