Bridging numerical relativity and automatic differentiation using JAX
autograv is a Python library that uses JAX and automatic differentiation to compute tensors and scalar quantities from Einstein's general theory of relativity. Given a metric function, it can calculate Christoffel symbols, curvature tensors, and the stress-energy-momentum tensor implied by the Einstein field equations, with derivatives accurate to floating-point precision.
- Automatic Differentiation: Uses JAX's jax.jacfwd for forward-mode automatic differentiation to compute derivatives of metric tensors to machine precision, with no finite-difference truncation error
- Tensor Calculus: Leverages jax.numpy.einsum for efficient Einstein summation notation operations
- High Precision: Configured to use 64-bit floating point arithmetic for maximum accuracy
- Pure Functions: All computations are functional and composable
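JAX defaults to 32-bit floats, so 64-bit arithmetic has to be switched on explicitly. The source does not show how autograv configures this; the following is a minimal sketch, assuming the standard jax_enable_x64 configuration flag is used:

```python
import jax

# Assumption: x64 mode is enabled through JAX's standard config flag.
# It should be set before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With x64 enabled, floating-point arrays default to float64.
x = jnp.array([5.0, jnp.pi / 3, jnp.pi / 2])
print(x.dtype)
```

Without the flag, JAX would silently downcast a requested float64 array to float32, which matters for curvature computations that differentiate the metric twice.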
Given a metric tensor function, autograv can compute:
- Christoffel symbols (affine connection coefficients)
- Torsion tensor (verification that connection is symmetric)
- Riemann curvature tensor (intrinsic curvature of spacetime)
- Ricci tensor and Ricci scalar (curvature related to volume change)
- Einstein tensor (left-hand side of Einstein field equations)
- Stress-energy-momentum tensor (mass-energy content)
- Kretschmann invariant (scalar curvature for detecting singularities)
# Using uv
uv pip install autograv
# Or using pip
pip install autograv

import jax.numpy as jnp
from autograv import (
    spherical_polar_metric,
    christoffel_symbols,
    riemann_tensor,
    einstein_tensor,
)
# Define coordinates
coordinates = jnp.array([5, jnp.pi/3, jnp.pi/2], dtype=jnp.float64)
# Compute Christoffel symbols for the 2-sphere
christoffels = christoffel_symbols(coordinates, spherical_polar_metric)
print(christoffels)
# Compute Riemann tensor
riemann = riemann_tensor(coordinates, spherical_polar_metric)
print(riemann)

The examples/ directory contains complete examples:
- sphere_example.py: Computing quantities for a 2-sphere metric
- schwarzschild_example.py: Computing quantities for the Schwarzschild black hole metric
Run them with:
uv run python examples/sphere_example.py
uv run python examples/schwarzschild_example.py

Traditional approaches to computing derivatives in physics use either:
- Symbolic differentiation: Exact but computationally expensive
- Numerical differentiation: Fast but prone to truncation and round-off errors
Automatic differentiation (autodiff) combines the best of both worlds by:
- Tracing computational operations to build a directed acyclic graph (DAG)
- Computing gradients via the chain rule by traversing the graph
- Producing derivatives that are exact up to floating-point round-off, with no step-size truncation error
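The trade-off described above can be seen on a function whose derivative is known in closed form. This is a small illustrative sketch, not part of autograv; the test function f, the step size h, and the evaluation point x0 are arbitrary choices made for the example:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def f(x):
    # Illustrative test function with a known derivative.
    return jnp.sin(x) * jnp.exp(x)


def exact_derivative(x):
    # d/dx [sin(x) exp(x)] = (sin(x) + cos(x)) exp(x)
    return (jnp.sin(x) + jnp.cos(x)) * jnp.exp(x)


x0 = jnp.array(1.0)

# Forward-mode automatic differentiation.
autodiff = jax.jacfwd(f)(x0)

# Forward finite difference; its truncation error scales with h.
h = 1e-4
finite_diff = (f(x0 + h) - f(x0)) / h

autodiff_error = jnp.abs(autodiff - exact_derivative(x0))
finite_diff_error = jnp.abs(finite_diff - exact_derivative(x0))
print("autodiff error:   ", autodiff_error)
print("finite-diff error:", finite_diff_error)
```

The autodiff error sits at the level of floating-point round-off, while the finite-difference error is dominated by the step size and cannot be driven to zero by shrinking h, because round-off then takes over.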
JAX provides:
- jax.jacfwd: Forward-mode autodiff for computing Jacobians
- jax.numpy.einsum: Efficient Einstein summation for tensor operations
- NumPy-compatible API with GPU/TPU acceleration support
Given a metric tensor g_ij, the Christoffel symbols are:
Γ^j_kl = (1/2) g^jm (∂g_mk/∂x^l + ∂g_lm/∂x^k - ∂g_kl/∂x^m)
In code:
import jax
import jax.numpy as jnp

def christoffel_symbols(coordinates, metric):
    g = metric(coordinates)
    g_inv = jnp.linalg.inv(g)
    # jacobian[k, l, m] = ∂g_kl/∂x^m, computed by automatic differentiation
    jacobian = jax.jacfwd(metric)(coordinates)
    # Contract g^jm with (∂g_lm/∂x^k + ∂g_mk/∂x^l - ∂g_kl/∂x^m)
    return 0.5 * jnp.einsum('jm, klm -> jkl', g_inv,
                            jnp.einsum('klm -> mkl', jacobian) +
                            jnp.einsum('klm -> lmk', jacobian) - jacobian)

- minkowski_metric(coordinates): Flat spacetime metric
- spherical_polar_metric(coordinates): 2-sphere metric in (r, θ, φ)
- christoffel_symbols(coordinates, metric): Affine connection coefficients
- torsion_tensor(coordinates, metric): Antisymmetric part of connection
- riemann_tensor(coordinates, metric): Curvature tensor
- ricci_tensor(coordinates, metric): Trace of Riemann tensor
- ricci_scalar(coordinates, metric): Scalar curvature
- kretschmann_invariant(coordinates, metric): Curvature invariant
- einstein_tensor(coordinates, metric): G_ij = R_ij - (1/2) g_ij R
- stress_energy_momentum_tensor(coordinates, metric): T_ij from Einstein equations
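To illustrate how the curvature quantities listed above can be built from the Christoffel symbols by differentiating once more, here is a self-contained sketch that does not import autograv. The names unit_sphere_metric, riemann_sketch, and ricci_scalar_sketch are hypothetical, the sign and index conventions are an assumption (the library may use different ones), and the metric is the unit 2-sphere in coordinates (θ, φ), chosen because its Ricci scalar is known to be 2:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def unit_sphere_metric(coordinates):
    # Hypothetical example metric: g = diag(1, sin^2(theta)) in (theta, phi).
    theta = coordinates[0]
    return jnp.diag(jnp.array([1.0, jnp.sin(theta) ** 2]))


def christoffel_symbols(coordinates, metric):
    # Same construction as in the text: result[j, k, l] = Gamma^j_kl.
    g_inv = jnp.linalg.inv(metric(coordinates))
    jacobian = jax.jacfwd(metric)(coordinates)  # [k, l, m] = d g_kl / d x^m
    return 0.5 * jnp.einsum(
        "jm,klm->jkl",
        g_inv,
        jnp.einsum("klm->mkl", jacobian)
        + jnp.einsum("klm->lmk", jacobian)
        - jacobian,
    )


def riemann_sketch(coordinates, metric):
    # Assumed convention:
    # R^j_klm = d_l Gamma^j_mk - d_m Gamma^j_lk
    #           + Gamma^j_ln Gamma^n_mk - Gamma^j_mn Gamma^n_lk
    gamma = christoffel_symbols(coordinates, metric)
    # d_gamma[j, k, l, m] = d Gamma^j_kl / d x^m
    d_gamma = jax.jacfwd(christoffel_symbols)(coordinates, metric)
    return (
        jnp.einsum("jmkl->jklm", d_gamma)
        - jnp.einsum("jlkm->jklm", d_gamma)
        + jnp.einsum("jln,nmk->jklm", gamma, gamma)
        - jnp.einsum("jmn,nlk->jklm", gamma, gamma)
    )


def ricci_scalar_sketch(coordinates, metric):
    riemann = riemann_sketch(coordinates, metric)
    # Ricci tensor: contract the upper index with the second lower index.
    ricci = jnp.trace(riemann, axis1=0, axis2=2)
    g_inv = jnp.linalg.inv(metric(coordinates))
    return jnp.einsum("km,km->", g_inv, ricci)


coordinates = jnp.array([jnp.pi / 3, jnp.pi / 2])
gamma = christoffel_symbols(coordinates, unit_sphere_metric)
scalar = ricci_scalar_sketch(coordinates, unit_sphere_metric)

# Analytically, Gamma^theta_phiphi = -sin(theta) cos(theta) and R = 2.
print(gamma[0, 1, 1])
print(scalar)
```

Nesting jax.jacfwd this way means the second derivatives of the metric are obtained by automatic differentiation as well, so no finite-difference step enters at any stage.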
- close_to_zero(func): Decorator to suppress near-zero numerical noise
- TOLERANCE: Threshold for zero suppression (default: 1e-8)
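The source does not show how close_to_zero is implemented. As a rough sketch of how such a decorator might work, assuming it simply zeroes every entry whose magnitude falls below TOLERANCE (the library's actual behavior may differ):

```python
import functools

import jax.numpy as jnp

TOLERANCE = 1e-8  # default threshold stated in the text


def close_to_zero(func):
    """Assumed behavior: zero out result entries smaller than TOLERANCE."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        return jnp.where(jnp.abs(result) < TOLERANCE, 0.0, result)

    return wrapper


@close_to_zero
def noisy_values():
    # Hypothetical output containing round-off residue around true zeros.
    return jnp.array([1.0, 3e-17, -2e-12, 0.5])


cleaned = noisy_values()
print(cleaned)
```

Suppressing values this way makes printed tensors easier to read, at the cost of also discarding any genuinely small components below the threshold.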
- Python 3.11+
- JAX (CPU-only on Windows, GPU/TPU support on Linux/macOS)
- NumPy
- Add more standard metrics (Kerr, Kerr-Newman, FRW, etc.)
- Implement Weyl tensor and Weyl invariant
- Support for JIT compilation with @jax.jit
- GPU/TPU acceleration examples
- Integration with differential equation solvers
- Visualization tools for curvature
@software{autograv,
author = {Kataru, Baalateja},
title = {autograv: Numerical Relativity with Automatic Differentiation},
year = {2026},
publisher = {GitHub},
url = {https://github.com/planckeon/autograv},
note = {JAX-based numerical relativity framework}
}

See also: Medium Article
MIT
Based on concepts from the blog post "Bridging numerical relativity and automatic differentiation using JAX". This project demonstrates the synergy between modern machine learning tools and classical physics computations.