Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visualization of QDax repertoires #353

Merged
merged 2 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ jobs:
pytest tests/archives tests/emitters tests/schedulers
- name: Install extras deps
run: pip install -r pinned_reqs/extras_visualize.txt
- name: Test extras
- name: Test visualize extra
run: pytest tests/visualize
- name: Install QDax
run: pip install qdax
- name: Test visualize extra for QDax
run: pytest tests/visualize_qdax
coverage:
runs-on: ubuntu-latest
steps:
Expand All @@ -87,7 +91,11 @@ jobs:
- name: Test coverage
env:
NUMBA_DISABLE_JIT: 1
run: pytest tests
# Exclude `visualize_qdax` since we don't install QDax here. We also
# exclude `tests` since we don't want the base directory here.
run:
pytest $(find tests -maxdepth 1 -type d -not -name 'tests' -not -name
'visualize_qdax')
benchmarks:
runs-on: ubuntu-latest
steps:
Expand Down
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### API

- Drop Python 3.7 support and upgrade dependencies (#350)
- Add visualization of QDax repertoires (#353)

#### Documentation

Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,5 @@
"python": ("https://docs.python.org/3/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"sklearn": ("https://scikit-learn.org/stable/", None),
"qdax": ("https://qdax.readthedocs.io/en/latest/", None),
}
50 changes: 50 additions & 0 deletions ribs/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from matplotlib.cm import ScalarMappable
from scipy.spatial import Voronoi # pylint: disable=no-name-in-module

from ribs.archives import CVTArchive

# Matplotlib functions tend to have a ton of args.
# pylint: disable = too-many-arguments

Expand Down Expand Up @@ -782,3 +784,51 @@ def parallel_axes_plot(archive,
ax=host_ax,
pad=cbar_pad,
orientation=cbar_orientation)


def qdax_repertoire_heatmap(
repertoire,
ranges,
*args,
**kwargs,
):
# pylint: disable = line-too-long
"""Plots a heatmap of a QDax MapElitesRepertoire.

Internally, this function converts a
:class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire` into
a :class:`~ribs.archives.CVTArchive` and plots it with
:meth:`cvt_archive_heatmap`.

Args:
repertoire (qdax.core.containers.mapelites_repertoire.MapElitesRepertoire):
A MAP-Elites repertoire output by an algorithm in QDax.
ranges (array-like of (float, float)): Upper and lower bound of each
dimension of the measure space, e.g. ``[(-1, 1), (-2, 2)]``
indicates the first dimension should have bounds :math:`[-1,1]`
(inclusive), and the second dimension should have bounds
:math:`[-2,2]` (inclusive).
*args: Positional arguments to pass to :meth:`cvt_archive_heatmap`.
**kwargs: Keyword arguments to pass to :meth:`cvt_archive_heatmap`.
"""
# pylint: enable = line-too-long

# Construct a CVTArchive. We set solution_dim to 0 since we are only
# plotting and do not need to have the solutions available.
cvt_archive = CVTArchive(
solution_dim=0,
cells=repertoire.centroids.shape[0],
ranges=ranges,
custom_centroids=repertoire.centroids,
)

# Add everything to the CVTArchive.
occupied = repertoire.fitnesses != -np.inf
cvt_archive.add(
np.empty((occupied.sum(), 0)),
repertoire.fitnesses[occupied],
repertoire.descriptors[occupied],
)

# Plot the archive.
cvt_archive_heatmap(cvt_archive, *args, **kwargs)
7 changes: 6 additions & 1 deletion tests/README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# Tests

This directory contains tests and micro-benchmarks for ribs. The tests mirror
This directory contains tests and micro-benchmarks for pyribs. The tests mirror
the directory structure of `ribs`. To run these tests, install the dev
dependencies for ribs with `pip install ribs[dev]` or `pip install -e .[dev]`
(from the root directory of the repo).

For information on running tests, see [CONTRIBUTING.md](../CONTRIBUTING.md).

## Visualization Tests

We divide the visualization tests into `visualize` and `visualize_qdax`, where
`visualize_qdax` tests visualizations of QDax components.

## Additional Tests

This directory also contains:
Expand Down
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 60 additions & 0 deletions tests/visualize_qdax/visualize_qdax_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Tests for ribs.visualize that use qdax.

Instructions are identical as in visualize_test.py, but images are stored in
tests/visualize_qdax_test/baseline_images/visualize_qdax_test instead.
"""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pytest
from matplotlib.testing.decorators import image_comparison
from qdax.core.containers.mapelites_repertoire import (MapElitesRepertoire,
compute_cvt_centroids)

from ribs.visualize import qdax_repertoire_heatmap


@pytest.fixture(autouse=True)
def clean_matplotlib():
"""Cleans up matplotlib figures before and after each test."""
# Before the test.
plt.close("all")

yield

# After the test.
plt.close("all")


@image_comparison(baseline_images=["qdax_repertoire_heatmap"],
remove_text=False,
extensions=["png"])
def test_qdax_repertoire_heatmap():
plt.figure(figsize=(8, 6))

# Compute the CVT centroids.
centroids, _ = compute_cvt_centroids(
num_descriptors=2,
num_init_cvt_samples=1000,
num_centroids=100,
minval=-1,
maxval=1,
random_key=jax.random.PRNGKey(42),
)

# Create initial population.
init_pop_x, init_pop_y = jnp.meshgrid(jnp.linspace(-1, 1, 50),
jnp.linspace(-1, 1, 50))
init_pop = jnp.stack((init_pop_x.flatten(), init_pop_y.flatten()), axis=1)

# Create repertoire with the initial population inserted.
repertoire = MapElitesRepertoire.init(
genotypes=init_pop,
# Negative sphere function.
fitnesses=-jnp.sum(jnp.square(init_pop), axis=1),
descriptors=init_pop,
centroids=centroids,
)

# Plot heatmap.
qdax_repertoire_heatmap(repertoire, ranges=[(-1, 1), (-1, 1)])
Loading