JAX-PT is a rewrite of the FAST-PT codebase that is compatible with JAX's automatic differentiation and JIT compilation tools. It can be integrated into full JAX computation pipelines or used on its own. When compiled, the main JAX-PT functions (the same functions FAST-PT provides) run roughly 5-100x faster than FAST-PT 4.0, depending on the function. For more in-depth examples of JAX-PT's features and functionality, please see the examples.
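Because JAX-PT is designed to compose with `jax.jit` and `jax.grad`, the following minimal sketch shows that pattern on a toy function. `toy_spectrum` and `summary` are hypothetical stand-ins, not JAX-PT functions; in a real pipeline, the JAX-PT call would sit inside the function being compiled or differentiated.

```python
import jax
import jax.numpy as jnp

k = jnp.logspace(-3, 1, 200)


def toy_spectrum(amplitude, k):
    """Hypothetical stand-in for a JAX-PT computation (not part of JAX-PT)."""
    p_lin = amplitude * k / (1.0 + (k / 0.02) ** 2) ** 1.5  # toy "linear" spectrum
    return p_lin + 0.1 * p_lin**2                           # toy quadratic correction


def summary(amplitude):
    # Reduce to a scalar so that jax.grad can differentiate it
    return jnp.sum(toy_spectrum(amplitude, k))


summary_jit = jax.jit(summary)            # compiled on the first call, reused afterwards
dsummary_dA = jax.jit(jax.grad(summary))  # derivative w.r.t. amplitude, also compiled

value = summary_jit(2.0)
slope = dsummary_dA(2.0)
print(value, slope)
```

The first call to each jitted function triggers compilation. Later calls with inputs of the same shape and dtype reuse the compiled code, and compiled speedups are usually measured in this regime.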
FAST-PT is a code to calculate quantities in cosmological perturbation theory at 1-loop (including, e.g., corrections to the matter power spectrum). The code utilizes Fourier methods combined with analytic expressions to reduce the computation time to scale as N log N, where N is the number of grid points in the input linear power spectrum.
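The Fourier idea behind the N log N scaling can be illustrated with a minimal sketch. This is not FAST-PT's actual implementation, which involves further steps such as windowing and padding. A power spectrum sampled on a logarithmically spaced k grid is decomposed with a single FFT into a sum of complex power laws, P(k) = Re Σ_m c_m k^(ν + iη_m). The function names and the bias exponent ν = -2 are illustrative choices. The block also enables 64-bit floats globally, so that the k^ν rescaling stays accurate over several decades in k.

```python
import jax
import jax.numpy as jnp

# Double precision keeps the k**nu rescaling accurate over several decades in k.
jax.config.update("jax_enable_x64", True)


def power_law_coefficients(k, P, nu=-2.0):
    """Decompose P(k), sampled on a log-spaced grid, into complex power laws.

    Returns (c, eta) such that P(k_j) = Re sum_m c_m * k_j**(nu + 1j * eta_m).
    Hypothetical helper for illustration only.
    """
    N = k.size
    dlnk = jnp.log(k[1] / k[0])                  # constant spacing in ln k
    biased = P * k ** (-nu)                      # remove the real power-law bias k**nu
    c = jnp.fft.fft(biased) / N                  # one FFT: O(N log N)
    m = jnp.fft.fftfreq(N) * N                   # integer frequencies in [-N/2, N/2)
    eta = 2.0 * jnp.pi * m / (N * dlnk)          # imaginary parts of the exponents
    c = c * jnp.exp(-1j * eta * jnp.log(k[0]))   # absorb the grid origin k_0
    return c, eta


def evaluate_power_laws(k, c, eta, nu=-2.0):
    """Sum the complex power laws at wavenumbers k and keep the real part."""
    exponents = nu + 1j * eta                          # shape (N,)
    terms = jnp.exp(jnp.outer(jnp.log(k), exponents))  # k_j**(nu + i*eta_m)
    return jnp.real(terms @ c)


k = jnp.logspace(-3, 1, 64)
P = k / (1.0 + k**2) ** 2  # toy spectrum, not a physical linear P(k)
c, eta = power_law_coefficients(k, P)
P_rec = evaluate_power_laws(k, c, eta)
print(jnp.max(jnp.abs(P_rec / P - 1.0)))
```

On the grid points, the power-law sum reproduces the input spectrum to within round-off. Methods of this kind then treat each power law analytically; the kernels that do that are not shown here.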
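Install JAX-PT with pip:

```
pip install jax-pt
```

To also install the optional `dev` extras:

```
pip install "jax-pt[dev]"
```

JAX-PT lets you choose the device your computations run on. During initialization, pass `'cpu'`, `'gpu'`, or any other `jax.Device` to the `device` keyword argument: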
```python
import jax
import jax.numpy as jnp
from jaxpt import JAXPT

# Check available devices
print("Available devices:", jax.devices())

k = jnp.logspace(-3, 1, 1000)

# Create a JAXPT instance (defaults to CPU)
jpt = JAXPT(k, warmup="moderate")

# Run on the GPU instead
jpt = JAXPT(k, warmup="moderate", device="gpu")

# Or pass a specific jax.Device
devices = jax.devices()
jpt = JAXPT(k, warmup="moderate", device=devices[0])  # or any index from the devices list
```

Please remember to install the JAX CUDA libraries that match your CUDA version. For example:
```
pip install "jax[cuda12]"
```

Note: JAX currently supports GPU acceleration only on Linux systems. Please see the JAX documentation for the latest updates.
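After installing a CUDA-enabled JAX, you can check which backend JAX picked up with a quick sketch like the one below, which falls back to the CPU when no GPU is available. `pick_device` is a hypothetical helper, not part of JAX-PT. The commented-out last line only shows where the selected device would go, based on the `device` keyword described above.

```python
import jax
import jax.numpy as jnp


def pick_device(platform="gpu"):
    """Return the first device on `platform`, else the first CPU device.

    Hypothetical helper for illustration; not part of JAX-PT.
    """
    try:
        return jax.devices(platform)[0]
    except RuntimeError:
        # jax.devices raises RuntimeError when no backend for `platform` is available
        return jax.devices("cpu")[0]


print("Default backend:", jax.default_backend())
dev = pick_device("gpu")
k = jax.device_put(jnp.logspace(-3, 1, 1000), dev)  # place the k grid on that device
print("k is stored on:", k.devices())

# The chosen device could then be handed to JAX-PT, e.g.
# jpt = JAXPT(k, warmup="moderate", device=dev)
```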