
datasig-ac-uk/baryx


Baryx

Barycenter preserving measure reduction in JAX

A JAX implementation of recombination (barycenter-preserving discrete-measure reduction) algorithms.

Installation

Requires Python 3.10+

pip install baryx

Example

import jax
import jax.random as jr

from baryx import recombine

# Required for good results with the default solvers.
jax.config.update("jax_enable_x64", True)

# Nodes and weights define a discrete measure $\mu$.
n, m = 1024, 32
key = jr.key(0)
nodes_key, weights_key = jr.split(key)
nodes = jr.normal(nodes_key, (n, m))
weights = jr.uniform(weights_key, (n,))


# Test functions $f$ used to define the push-forward $f_{\#}\mu$.
def f(x):
    return x


# Apply the test functions $f$ to yield the push-forward $f_{\#}\mu$.
pushed_forward_nodes = f(nodes)

# Solve the recombination problem for the push-forward $f_{\#}{\mu}$.
solution = recombine(pushed_forward_nodes, weights)
err = solution.error(norm=2)
print(err)

# Construct the recombined measure $\hat{\mu}$.
recombined_nodes = nodes[solution.indices]
recombined_weights = solution.weights

What is recombination

Recombination refers to both a discrete measure reduction problem and the algorithm(s) used to solve it.

The recombination problem

Given a discrete measure $\mu$ whose support is a subset of some arbitrary space $X$ and has cardinality $n$, and a set of test functions $f = \{f_1, \dots, f_m \mid f_i \colon X \to \mathbb{R}\}$, construct a discrete measure $\hat{\mu}$ whose support is a subset of the support of $\mu$ and has cardinality at most $r+1$, such that

$$\int_X f(x)\, d\mu(x) = \int_X f(x)\, d\hat{\mu}(x),$$

where $f(x) = (f_1(x), f_2(x), \dots, f_m(x))$ is an embedding $X \to \mathbb{R}^m$. Note: the number of retained support points is at most $r + 1$, where $r \le m$ is the rank of the $n \times m$ matrix of embedded nodes $(f(x_1), \dots, f(x_n))^T$ (i.e., the number of test functions that are linearly independent on the support of $\mu$).
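For a discrete measure the integrals above are finite weighted sums. As a small worked illustration (toy numbers of our own, not taken from the library), take $X = \mathbb{R}$, the single test function $f(x) = x$ (so $m = r = 1$), and $\mu = \tfrac{1}{3}(\delta_0 + \delta_1 + \delta_2)$:

$$\mu = \sum_{i=1}^{n} w_i \delta_{x_i} \quad\Longrightarrow\quad \int_X f(x)\, d\mu(x) = \sum_{i=1}^{n} w_i f(x_i),$$

$$\int_X x\, d\mu(x) = \tfrac{1}{3}(0 + 1 + 2) = 1, \qquad \hat{\mu} = \tfrac{1}{2}(\delta_0 + \delta_2): \quad \int_X x\, d\hat{\mu}(x) = \tfrac{1}{2}(0 + 2) = 1.$$

Here $\hat{\mu}$ keeps $2 = r + 1$ of the three support points and also preserves the total mass $1$; $\hat{\mu} = \delta_1$ would be another valid answer.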

A linear programming interpretation: the recombination problem can also be formulated as finding a basic feasible solution $\hat{w}$ to the linear program $$\bar{A}^T\hat{w} = \bar{b},\ \hat{w} \ge 0,$$ with a trivial (zero) objective and constraint data $\bar{A} = [\mathbf{1} \mid A]$ and $\bar{b} = \bar{A}^T w = (\sum_i w_i,\ b)$. Here $w$ are the weights of $\mu$, $\mathbf{1}$ is the all-ones vector, $A = (f(x_1), \dots, f(x_n))^T$ is an $n \times m$ matrix whose rows are the support points of the push-forward measure $f_{\#}\mu$, and $b = \int_X f(x)\, d\mu(x) = A^T w$ is the integral to be preserved. The column of ones makes $\hat{w}$ preserve the total mass $\sum_i w_i$ as well.
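The constraint data can be assembled directly from the pushed-forward nodes. Below is a minimal NumPy sketch that builds $\bar{A}$ and $\bar{b}$ and checks feasibility; the helper names `lp_constraints` and `is_feasible` are hypothetical and are not part of the baryx API.

```python
import numpy as np


def lp_constraints(A, w):
    """Return (A_bar, b_bar) for the LP  A_bar.T @ w_hat == b_bar,  w_hat >= 0.

    A: (n, m) array whose rows are the pushed-forward nodes f(x_i).
    w: (n,) array of non-negative weights of mu.
    """
    A_bar = np.hstack([np.ones((A.shape[0], 1)), A])  # [1 | A], shape (n, m + 1)
    b_bar = A_bar.T @ w  # (sum(w), b) where b = A.T @ w is the integral of f against mu
    return A_bar, b_bar


def is_feasible(A_bar, b_bar, w_hat, atol=1e-10):
    """True if w_hat is non-negative and reproduces both the mass and b."""
    return bool(np.all(w_hat >= 0) and np.allclose(A_bar.T @ w_hat, b_bar, atol=atol))


# The original weights are always feasible (but generally not basic).
rng = np.random.default_rng(0)
A = rng.normal(size=(100, 5))
w = rng.uniform(size=100)
A_bar, b_bar = lp_constraints(A, w)
print(is_feasible(A_bar, b_bar, w))  # True
```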

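In this interpretation, the weights $w$ of the initial measure $\mu$ are a given (not necessarily basic) feasible solution, which is recombined into a basic feasible solution $\hat{w}$ that satisfies the same linear constraints. The support of $\hat{w}$ is a subset of the support of $w$ and, since the nonzero entries of a basic feasible solution correspond to linearly independent rows of $\bar{A}$, it contains at most $\operatorname{rank}(\bar{A}) \le r + 1$ points.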
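One classical way to turn a feasible $w$ into a basic feasible $\hat{w}$ is Carathéodory-style elimination: while the rows of $\bar{A}$ on the current support are linearly dependent, pick a null-space direction $v$ of $\bar{A}^T$ restricted to that support and move $w$ along it until some weight reaches zero. The NumPy sketch below illustrates only that idea. It is a naive loop with one SVD per eliminated point, not the baryx `Caratheodory` solver, and its function name and tolerance handling are our own assumptions.

```python
import numpy as np


def caratheodory(A, w, tol=1e-12):
    """Naive Caratheodory elimination (illustrative sketch, not the baryx solver).

    A: (n, m) array whose rows are the pushed-forward nodes f(x_i).
    w: (n,) array of non-negative weights.
    Returns (indices, weights) of a reduced measure preserving sum(w) and w @ A.
    """
    A_bar = np.hstack([np.ones((A.shape[0], 1)), A])  # rows are (1, f(x_i))
    w = np.asarray(w, dtype=float).copy()
    support = np.flatnonzero(w > tol)
    while True:
        M = A_bar[support].T  # (m + 1, k) constraint matrix on the current support
        _, s, vt = np.linalg.svd(M)  # vt has shape (k, k)
        rank = int(np.sum(s > s[0] * max(M.shape) * np.finfo(float).eps))
        if rank == len(support):
            break  # columns are linearly independent: a basic feasible solution
        v = vt[rank]  # M @ v == 0; the all-ones first row of M forces sum(v) == 0
        pos = v > 0  # non-empty, because v != 0 and sum(v) == 0
        ratios = np.full(v.shape, np.inf)
        ratios[pos] = w[support][pos] / v[pos]
        j = int(np.argmin(ratios))
        # Moving along v keeps M @ w fixed; the largest step that keeps all
        # weights non-negative drives the weight of point j to zero.
        w[support] = w[support] - ratios[j] * v
        w[support[j]] = 0.0
        support = support[w[support] > tol]
    return support, w[support]
```

Because the first row of $\bar{A}^T$ is all ones, every null-space direction sums to zero and therefore has a positive entry, which guarantees a finite step; each pass removes at least one point.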

The recombination algorithm(s)

The process of applying a numerical algorithm which solves the recombination problem, as defined above, may be referred to as recombination, e.g., "we perform recombination to obtain a reduced measure...".

This library implements the deterministic 1-Tree and m-Tree measure reduction algorithms introduced by Litterer and Lyons 2012 and Tchernychova 2016 respectively, and additionally provides a new generalization referred to as $\alpha$-Tree measure reduction. The tree reduction algorithms can be selected by setting the solver keyword argument of recombine(nodes, weights, solver=...) to one of the following:

  • 1-Tree: Caratheodory(...) or TreeReduce(Caratheodory(...), tree_reduction_factor=n/(m+1)),
  • m-Tree reduction: TreeReduce(Caratheodory(...), tree_reduction_factor=2),
  • $\alpha$-Tree reduction (default): TreeReduce(Caratheodory(...), tree_reduction_factor=1+alpha/(m+1)).
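As we understand the tree variants, they apply the same elimination to group barycenters rather than to individual points: partition the support into groups, replace each group by its weighted barycenter, recombine the barycenters, and keep only the points of the surviving groups with rescaled weights. The NumPy sketch below illustrates that idea under our own assumptions; `tree_reduce` and its `num_groups` parameter are hypothetical and are not the baryx `TreeReduce` API. Since at most $m + 1$ of the `num_groups` groups survive a round, the support shrinks by roughly a factor of `num_groups / (m + 1)` per round. Under that reading (an assumption), `num_groups = 2 * (m + 1)` mirrors the factor 2 of m-Tree, and `num_groups = n` collapses to a single Carathéodory pass as in 1-Tree.

```python
import numpy as np


def _caratheodory(A, w, tol=1e-12):
    # Compact naive Caratheodory elimination (illustrative, not the baryx solver).
    A_bar = np.hstack([np.ones((A.shape[0], 1)), A])
    w, support = np.asarray(w, dtype=float).copy(), np.flatnonzero(np.asarray(w) > tol)
    while True:
        M = A_bar[support].T
        _, s, vt = np.linalg.svd(M)
        rank = int(np.sum(s > s[0] * max(M.shape) * np.finfo(float).eps))
        if rank == len(support):
            return support, w[support]
        v = vt[rank]  # null-space direction of M, sums to zero
        ratios = np.where(v > 0, w[support] / np.where(v > 0, v, 1.0), np.inf)
        j = int(np.argmin(ratios))
        w[support] = w[support] - ratios[j] * v
        w[support[j]] = 0.0
        support = support[w[support] > tol]


def tree_reduce(A, w, num_groups, tol=1e-12):
    """Group-based (tree) reduction sketch; `num_groups` is a hypothetical parameter."""
    n, m = A.shape
    assert num_groups > m + 1, "need more than m + 1 groups to make progress"
    idx = np.flatnonzero(np.asarray(w) > tol)
    w = np.asarray(w, dtype=float)[idx]
    while len(idx) > num_groups:
        # Arbitrary contiguous partition of the current support into groups.
        groups = np.array_split(np.arange(len(idx)), num_groups)
        mass = np.array([w[g].sum() for g in groups])
        centers = np.stack([w[g] @ A[idx[g]] / w[g].sum() for g in groups])
        # Recombine the group barycenters: at most m + 1 groups survive.
        kept, new_mass = _caratheodory(centers, mass, tol)
        keep = np.concatenate([groups[k] for k in kept])
        scale = np.concatenate(
            [np.full(len(groups[k]), nm / mass[k]) for k, nm in zip(kept, new_mass)]
        )
        idx, w = idx[keep], w[keep] * scale
    # Final direct elimination on the (now small) support.
    sub, w = _caratheodory(A[idx], w, tol)
    return idx[sub], w
```

Rescaling every point of a kept group $g$ by $\hat{W}_g / W_g$, where $W_g$ is the group mass and $\hat{W}_g$ its recombined mass, preserves both the total mass and $w^T A$, because the group's contribution $\sum_{i \in g} w_i f(x_i)$ equals $W_g$ times its barycenter.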

Notice that 1-Tree and m-Tree reduction are special cases of $\alpha$-Tree, where $\alpha = n - (m + 1)$ and $\alpha = m + 1$ respectively.
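Substituting the two special values of $\alpha$ into the default factor $1 + \alpha/(m+1)$ recovers the factors listed above:

$$\alpha = n - (m+1): \quad 1 + \frac{n - (m+1)}{m+1} = \frac{n}{m+1}, \qquad \alpha = m + 1: \quad 1 + \frac{m+1}{m+1} = 2.$$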

The randomized and hybrid algorithms introduced by Cossentino et al. 2020 are not yet implemented in this library; a JAX-free implementation can be found here. Please open an issue if there is another recombination-like algorithm you would like to see implemented in baryx.

Notice

Baryx loosely forks and extends the recombination algorithms implemented in the Coreax coreset library (licensed under Apache-2.0). Baryx provides a variety of algorithmic, numerical, and quality-of-life improvements, in addition to a simpler user API and fewer dependencies.

See also

Some other recombination-related works that you may find interesting:

Packages

  • PyRecombine - the original (CPU-only) C++ implementation with Python bindings.
  • RoughPy - a toolbox for working with streaming data as rough paths in Python.
