Barycenter preserving measure reduction in JAX
A JAX implementation of recombination (barycenter-preserving discrete measure reduction) algorithms.
Requires Python 3.10+
```bash
pip install baryx
```

```python
import jax
import jax.random as jr

from baryx import recombine

# Required for good results with the default solvers.
jax.config.update("jax_enable_x64", True)

# Nodes and weights define a discrete measure $\mu$.
n, m = 1024, 32
key = jr.key(0)
nodes_key, weights_key = jr.split(key)
nodes = jr.normal(nodes_key, (n, m))
weights = jr.uniform(weights_key, (n,))

# Test functions $f$ used to define the push-forward $f_{\#}\mu$.
def f(x):
    return x

# Apply the test functions $f$ to yield the push-forward $f_{\#}\mu$.
pushed_forward_nodes = f(nodes)

# Solve the recombination problem for the push-forward $f_{\#}{\mu}$.
solution = recombine(pushed_forward_nodes, weights)
err = solution.error(norm=2)
print(err)

# Construct the recombined measure $\hat{\mu}$.
recombined_nodes = nodes[solution.indices]
recombined_weights = solution.weights
```

Recombination refers to both a discrete measure reduction problem and the algorithm(s) used to solve it.
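Given a discrete measure

$$\mu = \sum_{i=1}^{n} w_i \delta_{x_i}, \qquad x_i \in \mathbb{R}^m, \quad w_i > 0,$$

where $\delta_{x_i}$ denotes the Dirac measure at the node $x_i$, recombination seeks a reduced measure $\hat{\mu} = \sum_{i \in I} \hat{w}_i \delta_{x_i}$, supported on a subset $I \subseteq \{1, \dots, n\}$ of the original nodes with $|I| \le m + 1$ and $\hat{w}_i \ge 0$, that preserves the total mass and the barycenter of $\mu$:

$$\sum_{i \in I} \hat{w}_i = \sum_{i=1}^{n} w_i, \qquad \sum_{i \in I} \hat{w}_i x_i = \sum_{i=1}^{n} w_i x_i.$$

Such a $\hat{\mu}$ always exists by Carathéodory's theorem.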
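The recombination constraints (at most $m + 1$ of the original nodes, non-negative weights, preserved total mass, preserved barycenter) are easy to check numerically. The sketch below uses plain NumPy rather than baryx; the helper `is_recombination` and its `atol` tolerance are hypothetical and not part of the baryx API.

```python
import numpy as np


def is_recombination(nodes, weights, indices, new_weights, atol=1e-8):
    """Check whether (nodes[indices], new_weights) solves the recombination problem.

    Hypothetical helper (not part of baryx). `nodes` has shape (n, m), `weights`
    has shape (n,), and `indices` selects the support of the reduced measure.
    """
    nodes = np.asarray(nodes, dtype=float)
    weights = np.asarray(weights, dtype=float)
    new_weights = np.asarray(new_weights, dtype=float)
    m = nodes.shape[1]
    # The reduced measure may use at most m + 1 of the original nodes.
    small_support = len(indices) <= m + 1
    # The new weights must be non-negative (up to the tolerance).
    nonnegative = bool(np.all(new_weights >= -atol))
    # The total mass must be preserved.
    same_mass = np.isclose(new_weights.sum(), weights.sum(), atol=atol)
    # The unnormalized barycenter sum_i w_i x_i must be preserved.
    same_barycenter = np.allclose(new_weights @ nodes[indices], weights @ nodes, atol=atol)
    return bool(small_support and nonnegative and same_mass and same_barycenter)


# Four equally weighted points on the real line (n = 4, m = 1): mass 1, barycenter 1.5.
nodes = np.array([[0.0], [1.0], [2.0], [3.0]])
weights = np.full(4, 0.25)
# Keeping nodes 0 and 3 with weight 1/2 each preserves both mass and barycenter.
print(is_recombination(nodes, weights, [0, 3], [0.5, 0.5]))  # True
# Keeping nodes 0 and 1 moves the barycenter to 0.5.
print(is_recombination(nodes, weights, [0, 1], [0.5, 0.5]))  # False
```

Note that a solution is not unique: nodes 1 and 2 with weight 1/2 each would pass the check as well.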
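A linear programming interpretation:

The recombination problem can also be formulated as finding a basic feasible solution $\hat{w} \in \mathbb{R}^n$ of the linear system

$$A \hat{w} = A w, \qquad \hat{w} \ge 0, \qquad A = \begin{bmatrix} x_1 & x_2 & \cdots & x_n \\ 1 & 1 & \cdots & 1 \end{bmatrix} \in \mathbb{R}^{(m+1) \times n},$$

where $w = (w_1, \dots, w_n)$ are the weights of $\mu$ and each column of $A$ stacks a node $x_i$ on top of a $1$. In this interpretation, a given (not necessarily basic) feasible solution, namely the original weights $w$, is reduced to a basic feasible solution $\hat{w}$. Since $A$ has $m + 1$ rows, $\hat{w}$ has at most $m + 1$ nonzero entries, and these determine the nodes and weights of $\hat{\mu}$.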
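A classical way to turn a feasible solution into a basic one, in the spirit of Carathéodory's theorem, is to repeatedly move the weights along a null-space direction of $A$ (the $(m+1) \times n$ matrix whose columns are the nodes with a $1$ appended) until one weight vanishes; $A \hat{w}$ never changes, so mass and barycenter are preserved. The NumPy sketch below works on $m + 2$ support columns at a time to keep each step cheap. The name `caratheodory_reduce` is hypothetical, and this is not necessarily how the baryx `Caratheodory` solver is implemented.

```python
import numpy as np


def caratheodory_reduce(nodes, weights):
    """Reduce a discrete measure to at most m + 1 nodes (minimal sketch, not baryx).

    Returns the indices of the surviving nodes and their new weights.
    """
    nodes = np.asarray(nodes, dtype=float)
    w = np.asarray(weights, dtype=float).copy()
    n, m = nodes.shape
    # Columns of A are (x_i, 1): the first m rows encode the barycenter, the last the mass.
    A = np.vstack([nodes.T, np.ones(n)])
    support = np.flatnonzero(w > 0)
    while support.size > m + 1:
        # Any m + 2 columns of the (m + 1)-row matrix A are linearly dependent.
        cols = support[: m + 2]
        # The last right-singular vector spans part of the null space of A[:, cols].
        _, _, vt = np.linalg.svd(A[:, cols])
        v = vt[-1]
        # The row of ones forces sum(v) == 0, so v has at least one positive entry.
        pos = v > 0
        ratios = w[cols][pos] / v[pos]
        j = np.argmin(ratios)
        # Largest step keeping all weights non-negative; A @ w is unchanged.
        w[cols] -= ratios[j] * v
        # The entry attaining the minimum ratio is eliminated exactly.
        w[cols[np.flatnonzero(pos)[j]]] = 0.0
        support = support[w[support] > 0]
    return support, w[support]


rng = np.random.default_rng(0)
nodes = rng.normal(size=(200, 5))
weights = rng.uniform(size=200)
indices, new_weights = caratheodory_reduce(nodes, weights)
print(len(indices))  # at most m + 1 = 6
```

Each iteration removes at least one node, so the loop runs at most $n - (m + 1)$ times, each time on a small $(m+1) \times (m+2)$ system.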
The process of applying a numerical algorithm which solves the recombination problem, as defined above, may be referred to as recombination, e.g., "we perform recombination to obtain a reduced measure...".
This library implements the deterministic 1-Tree and m-Tree measure reduction algorithms introduced by Litterer and Lyons 2012 and Tchernychova 2016, respectively, in addition to providing a new generalization referred to as $\alpha$-Tree reduction. Each algorithm is selected by setting the `solver` keyword argument of `recombine(nodes, weights, solver=...)` as follows:
- 1-Tree: `Caratheodory(...)` or `TreeReduce(Caratheodory(...), tree_reduction_factor=n/(m+1))`,
- m-Tree reduction: `TreeReduce(Caratheodory(...), tree_reduction_factor=2)`,
- $\alpha$-Tree reduction (default): `TreeReduce(Caratheodory(...), tree_reduction_factor=1+alpha/(m+1))`.
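Notice that 1-Tree and m-Tree reduction are special cases of $\alpha$-Tree reduction: setting $1 + \alpha/(m+1)$ equal to $n/(m+1)$ or to $2$ gives $\alpha = n - (m+1)$ or $\alpha = m + 1$, respectively.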
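To make the role of `tree_reduction_factor` more concrete, the sketch below implements a cluster-based tree reduction in plain NumPy. It assumes (this is not confirmed by the baryx documentation) that each round splits the current support into $k = (m+1) \cdot \texttt{tree\_reduction\_factor} = m + 1 + \alpha$ clusters, applies a Carathéodory step to the measure formed by the cluster masses and cluster barycenters, and rescales the node weights of the surviving clusters. Under that reading, $\alpha = n - (m+1)$ gives a single Carathéodory call on all nodes (1-Tree) and $\alpha = m + 1$ halves the number of clusters per round (m-Tree). The names `tree_reduce` and `caratheodory_reduce`, and the contiguous-chunk clustering, are hypothetical choices, not the library's `TreeReduce` implementation.

```python
import numpy as np


def caratheodory_reduce(nodes, weights):
    """Null-space Caratheodory reduction to at most m + 1 nodes (hypothetical helper)."""
    nodes = np.asarray(nodes, dtype=float)
    w = np.asarray(weights, dtype=float).copy()
    n, m = nodes.shape
    A = np.vstack([nodes.T, np.ones(n)])  # columns (x_i, 1)
    support = np.flatnonzero(w > 0)
    while support.size > m + 1:
        cols = support[: m + 2]
        v = np.linalg.svd(A[:, cols])[2][-1]  # null vector of A[:, cols]
        pos = v > 0
        ratios = w[cols][pos] / v[pos]
        j = np.argmin(ratios)
        w[cols] -= ratios[j] * v
        w[cols[np.flatnonzero(pos)[j]]] = 0.0
        support = support[w[support] > 0]
    return support, w[support]


def tree_reduce(nodes, weights, alpha):
    """Hypothetical alpha-Tree-style reduction under the assumed semantics above."""
    nodes = np.asarray(nodes, dtype=float)
    w = np.asarray(weights, dtype=float).copy()
    n, m = nodes.shape
    if alpha < 1:
        raise ValueError("alpha must be at least 1 to guarantee progress")
    support = np.flatnonzero(w > 0)
    while support.size > m + 1:
        # Assumed: m + 1 + alpha clusters per round (fewer once the support is small).
        k = min(m + 1 + alpha, support.size)
        clusters = np.array_split(support, k)
        masses = np.array([w[c].sum() for c in clusters])
        centers = np.stack([w[c] @ nodes[c] / mass for c, mass in zip(clusters, masses)])
        # Reduce the k cluster barycenters to at most m + 1 of them.
        kept, new_masses = caratheodory_reduce(centers, masses)
        # Rescaling inside a cluster keeps its barycenter and sets its mass to new_mass.
        for i, new_mass in zip(kept, new_masses):
            w[clusters[i]] *= new_mass / masses[i]
        support = np.concatenate([clusters[i] for i in kept])
    return support, w[support]


rng = np.random.default_rng(0)
n, m = 1024, 8
nodes = rng.normal(size=(n, m))
weights = rng.uniform(size=n)
# alpha = m + 1 mirrors m-Tree reduction; alpha = n - (m + 1) mirrors 1-Tree.
indices, new_weights = tree_reduce(nodes, weights, alpha=m + 1)
print(len(indices), np.allclose(new_weights @ nodes[indices], weights @ nodes))
```

Each round shrinks $m + 1 + \alpha$ clusters to at most $m + 1$, so a larger $\alpha$ means larger Carathéodory problems per round but fewer rounds; $\alpha \ge 1$ guarantees that at least one cluster is dropped per round.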
The randomized and hybrid algorithms introduced by Cossentino et al. 2020 are not yet implemented in this library; a JAX-free implementation can be found here. Please open an issue if there is another recombination-like algorithm you would like to see implemented in baryx.
Baryx loosely forks and extends the recombination algorithms implemented in the Coreax coreset library (licensed under Apache-2.0). Baryx provides a variety of algorithmic, numerical, and quality-of-life improvements, in addition to a simpler user API and fewer dependencies.
Some other recombination-related works that you may find interesting:
Packages
- PyRecombine - the original (CPU only) C++ implementation with Python bindings.
- RoughPy - a toolbox for working with streaming data as rough paths in Python.