linox is a Python package that provides a collection of linear operators for JAX, enabling efficient and flexible linear algebra with lazy evaluation. It is designed as a JAX alternative to probnum.linops, but it is still under development and currently offers fewer and less stable features. Its only dependencies are JAX and plum (for multiple dispatch).
- Lazy Evaluation: All operators support lazy evaluation, allowing for efficient computation of complex linear transformations
- JAX Integration: Built on top of JAX, providing automatic differentiation, parallelization, JIT compilation, and GPU/TPU support
- Composable Operators: Operators can be combined to form complex linear transformations, as sketched below
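For instance, composition uses ordinary arithmetic. The snippet below is a minimal sketch; it assumes linox operators overload @, +, and .T to build lazy compositions, in the same spirit as probnum.linops:

```python
import jax.numpy as jnp
from linox import Diagonal, Matrix

A = Matrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
D = Diagonal(jnp.array([1.0, 2.0]))

# Assumption: arithmetic on operators builds a lazy composition
# rather than a dense matrix.
C = A @ D + A.T

# The composition is only evaluated when applied to a vector.
y = C @ jnp.ones((2,))
```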
- Matrix: General matrix operator
- Identity: Identity matrix operator
- Diagonal: Diagonal matrix operator
- Scalar: Scalar multiple of the identity
- Zero: Zero matrix operator
- Ones: Matrix of ones operator
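A short sketch of constructing basic operators. The Matrix and Diagonal constructors match the quick-start example below; the Identity signature is an assumption by analogy with probnum.linops:

```python
import jax.numpy as jnp
from linox import Diagonal, Identity, Matrix

A = Matrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))  # wraps a dense array
D = Diagonal(jnp.array([1.0, 2.0]))              # stores only the diagonal
Id = Identity(2)                                 # assumed: takes the dimension

x = jnp.ones((2,))
print(A @ x, D @ x, Id @ x)
```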
- BlockMatrix: General block matrix operator
- BlockMatrix2x2: 2x2 block matrix operator
- BlockDiagonal: Block diagonal matrix operator
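BlockMatrix takes a nested list of blocks, as in the quick-start example below. The BlockDiagonal call here is a sketch that assumes the diagonal blocks are passed positionally; check the linox API for the exact signature:

```python
import jax.numpy as jnp
from linox import BlockDiagonal, Diagonal, Matrix

A = Matrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
D = Diagonal(jnp.array([1.0, 2.0]))

# Assumed signature: diagonal blocks passed positionally; the result
# is a lazy 4x4 operator with A and D on the diagonal.
B = BlockDiagonal(A, D)
y = B @ jnp.ones((4,))
```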
- LowRank: General low-rank operator
- SymmetricLowRank: Symmetric low-rank operator
- IsotropicScalingPlusSymmetricLowRank: Isotropic scaling plus symmetric low rank
- PositiveDiagonalPlusSymmetricLowRank: Positive diagonal plus symmetric low rank
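A hedged sketch of the low-rank case. The parameterization below (factors U and V, with the operator acting as U @ (V.T @ x)) is an assumption; consult the linox API for the actual constructor:

```python
import jax
import jax.numpy as jnp
from linox import LowRank

U = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
V = jax.random.normal(jax.random.PRNGKey(1), (100, 5))

# Assumed parameterization: the operator represents U @ V.T, so a
# matrix-vector product costs O(n * r) instead of O(n^2).
L = LowRank(U, V)
y = L @ jnp.ones((100,))
```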
- Kronecker: Kronecker product operator
- Permutation: Permutation matrix operator
- EigenD: Eigenvalue decomposition operator
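A sketch for the Kronecker operator, assuming Kronecker(A, B) represents the Kronecker product of A and B, as in probnum.linops:

```python
import jax.numpy as jnp
from linox import Kronecker, Matrix

A = Matrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
B = Matrix(jnp.eye(3))

# Assumed: a lazy 6x6 operator. A Kronecker matvec can exploit
# (A kron B) vec(X) = vec(B @ X @ A.T) instead of forming the
# dense product.
K = Kronecker(A, B)
y = K @ jnp.ones((6,))
```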
- Automatic Differentiation: Compute gradients automatically through operator compositions
- JIT Compilation: Speed up computations with just-in-time compilation
- Vectorization: Efficient batch processing of linear operations via e.g. jax.vmap
- GPU/TPU Support: Run computations on accelerators without code changes
- Functional Programming: Pure functions enable better optimization and parallelization
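A minimal sketch of the autodiff and JIT items above, assuming operator application traces through JAX transformations like a regular jnp matmul (which is what the JAX Integration feature claims):

```python
import jax
import jax.numpy as jnp
from linox import Matrix

A = Matrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))

# Quadratic form x -> x.T @ A @ x, built from the lazy operator.
def quadratic(x):
    return x @ (A @ x)

grad_fn = jax.grad(quadratic)  # gradient through the operator
fast_fn = jax.jit(quadratic)   # JIT-compiled version

x = jnp.array([1.0, -1.0])
print(grad_fn(x), fast_fn(x))
```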
```python
import jax
import jax.numpy as jnp
from linox import Matrix, Diagonal, BlockMatrix

# Create operators
A = Matrix(jnp.array([[1, 2], [3, 4]], dtype=jnp.float32))
D = Diagonal(jnp.array([1, 2], dtype=jnp.float32))

# Compose operators into a lazy 4x4 block operator
B = BlockMatrix([[A, D], [D, A]])

# Apply to a vector; the product is computed without materializing B densely
x = jnp.ones((4,), dtype=jnp.float32)
y = B @ x

# Parallelize over a batch of vectors
x_batched = jnp.ones((10, 4), dtype=jnp.float32)
y_batched = jax.vmap(B)(x_batched)
```