Skip to content

Commit

Permalink
Better spatial range (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein authored Oct 17, 2024
1 parent 4cfc3fa commit 28aec8c
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 72 deletions.
33 changes: 18 additions & 15 deletions docs/source/examples/evo_boids.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,25 @@ speed/acceleration and reward values
Updates
-------

The observation function simply counts any neighbours, and
add aggregates position and velocity data from neighbours in-range
The observation function counts any neighbours, calculates
relative heading and position, then aggregates contributions
from neighbours in-range

.. testcode:: evo_boids

@partial(
esquilax.transforms.spatial,
n_bins=10,
i_range=0.1,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), 0.0, 0.0),
include_self=False,
)
def observe(_k: chex.PRNGKey, _params: Params, _a: Boid, b: Boid):
return 1, b.pos, b.speed, b.heading
def observe(_k: chex.PRNGKey, _params: Params, a: Boid, b: Boid):
dh = esquilax.utils.shortest_vector(
a.heading, b.heading, length=2 * jnp.pi
)
dx = esquilax.utils.shortest_vector(a.pos, b.pos)
return 1, dx, b.speed, dh

The next update then aggregates the observations into an observation
array to be passed to the steering neural network
Expand All @@ -106,20 +111,18 @@ array to be passed to the steering neural network
boid, n_nb, x_nb, s_nb, h_nb = observations

def obs_to_nbs():
_x_nb = x_nb / n_nb
_dx_nb = x_nb / n_nb
_s_nb = s_nb / n_nb
_h_nb = h_nb / n_nb

dx = esquilax.utils.shortest_vector(boid.pos, _x_nb)
d = jnp.sqrt(jnp.sum(dx * dx)) / 0.1
phi = jnp.arctan2(dx[1], dx[0]) + jnp.pi
d = jnp.sqrt(jnp.sum(_dx_nb * _dx_nb)) / 0.1
phi = jnp.arctan2(_dx_nb[1], _dx_nb[0]) + jnp.pi
d_phi = esquilax.utils.shortest_vector(
boid.heading, phi, 2 * jnp.pi
) / jnp.pi
dh = esquilax.utils.shortest_vector(
boid.heading, _h_nb, 2 * jnp.pi
) / jnp.pi
ds = (_s_nb - boid.speed) / (params.max_speed - params.min_speed)
dh = _h_nb / jnp.pi
ds = (_s_nb - boid.speed)
ds = ds / (params.max_speed - params.min_speed)

return jnp.array([d, d_phi, dh, ds])

Expand All @@ -129,7 +132,7 @@ array to be passed to the steering neural network
lambda: jnp.array([-1.0, 0.0, 0.0, 0.0]),
)

if a boid has neighbours this function then converts the observation
If a boid has neighbours, this function then converts the observation
to a vector (in polar co-ordinates) to the average position of the local flock,
and polar co-ordinates to the average heading of the local flock,
taking into account the heading of the boid. If there are no neighbours
Expand Down Expand Up @@ -184,7 +187,7 @@ to calculate reward contributions

@partial(
esquilax.transforms.spatial,
n_bins=5,
i_range=0.1,
reduction=jnp.add,
default=0.0,
include_self=False,
Expand Down
2 changes: 1 addition & 1 deletion docs/source/examples/hard_coded_boids.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Firstly agents observe the state of neighbours within a given range

@partial(
esquilax.transforms.spatial,
n_bins=5,
i_range=0.2,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)),
include_self=False,
Expand Down
26 changes: 14 additions & 12 deletions docs/source/examples/rl_boids.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,35 @@ example, but wrap them up in an environment class

@partial(
esquilax.transforms.spatial,
n_bins=10,
i_range=0.1,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), 0.0, 0.0),
include_self=False,
)
def observe(_k: chex.PRNGKey, _params: Params, _a: Boid, b: Boid):
return 1, b.pos, b.speed, b.heading
def observe(_k: chex.PRNGKey, _params: Params, a: Boid, b: Boid):
dh = esquilax.utils.shortest_vector(
a.heading, b.heading, length=2 * jnp.pi
)
dx = esquilax.utils.shortest_vector(a.pos, b.pos)
return 1, dx, b.speed, dh

@esquilax.transforms.amap
def flatten_observations(_k: chex.PRNGKey, params: Params, observations):
boid, n_nb, x_nb, s_nb, h_nb = observations

def obs_to_nbs():
_x_nb = x_nb / n_nb
_dx_nb = x_nb / n_nb
_s_nb = s_nb / n_nb
_h_nb = h_nb / n_nb

dx = esquilax.utils.shortest_vector(boid.pos, _x_nb)
d = jnp.sqrt(jnp.sum(dx * dx)) / 0.1
phi = jnp.arctan2(dx[1], dx[0]) + jnp.pi
d = jnp.sqrt(jnp.sum(_dx_nb * _dx_nb)) / 0.1
phi = jnp.arctan2(_dx_nb[1], _dx_nb[0]) + jnp.pi
d_phi = esquilax.utils.shortest_vector(
boid.heading, phi, 2 * jnp.pi
) / jnp.pi
dh = esquilax.utils.shortest_vector(
boid.heading, _h_nb, 2 * jnp.pi
) / jnp.pi
ds = (_s_nb - boid.speed) / (params.max_speed - params.min_speed)
dh = _h_nb / jnp.pi
ds = (_s_nb - boid.speed)
ds = ds / (params.max_speed - params.min_speed)

return jnp.array([d, d_phi, dh, ds])

Expand Down Expand Up @@ -114,7 +116,7 @@ example, but wrap them up in an environment class

@partial(
esquilax.transforms.spatial,
n_bins=5,
i_range=0.1,
reduction=jnp.add,
default=0.0,
include_self=False,
Expand Down
4 changes: 2 additions & 2 deletions examples/boids/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Params:

@partial(
esquilax.transforms.spatial,
n_bins=10,
i_range=0.1,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), 0.0, 0.0),
include_self=False,
Expand Down Expand Up @@ -99,7 +99,7 @@ def move(_key: chex.PRNGKey, _params: Params, x):

@partial(
esquilax.transforms.spatial,
n_bins=5,
i_range=0.1,
reduction=jnp.add,
default=0.0,
include_self=False,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "esquilax"
version = "1.0.0"
version = "1.0.1"
description = "JAX multi-agent simulation and ML toolset"
authors = [
"Zombie-Einstein <zombie-einstein@proton.me>"
Expand Down
95 changes: 56 additions & 39 deletions src/esquilax/transforms/_space.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import partial
from math import floor
from typing import Any, Callable, Optional, Tuple

import chex
Expand Down Expand Up @@ -54,11 +55,11 @@ def _argument_checks(
def spatial(
f: Callable,
*,
n_bins: int,
reduction: Reduction,
default: Default,
include_self: bool = False,
topology: str = "moore",
n_bins: Optional[int] = None,
i_range: Optional[float] = None,
) -> Callable:
"""
Expand Down Expand Up @@ -108,7 +109,7 @@ def foo(_k, p, a, b):
result = esquilax.transforms.spatial(
foo,
n_bins=2,
i_range=0.5,
reduction=jnp.add,
default=0,
include_self=False,
Expand All @@ -135,7 +136,7 @@ def foo(_k, p, a, b):
@partial(
esquilax.transforms.spatial,
n_bins=2,
i_range=0.5,
reduction=(jnp.add, jnp.add),
default=(0, 0),
include_self=False,
Expand Down Expand Up @@ -165,7 +166,7 @@ def foo(_k, p, _, b):
@partial(
esquilax.transforms.spatial,
n_bins=2,
i_range=0.5,
reduction=jnp.add,
default=0,
topology="moore",
Expand Down Expand Up @@ -214,12 +215,6 @@ def f(
- ``b``: End agent in the interaction
- ``**static_kwargs``: Any arguments required at compile
time by JAX can be passed as keyword arguments.
n_bins
Number of bins each dimension is subdivided
into. Assumes that each dimension contains the
same number of cells. Each cell can only interact
with adjacent cells, so this value also consequently
also controls the number of interactions.
reduction
Binary monoidal reduction function, eg ``jax.numpy.add``.
default
Expand All @@ -234,11 +229,26 @@ def f(
cost of fidelity. Should be one of ``"same-cell"``,
``"von-neumann"`` or ``"moore"``.
i_range
Optional interaction range. By default, the width
of a cell is used as the interaction range, but this
can be increased/decreased using ``i_range`` dependent
on the use-case.
Range at which agents interact. Can be omitted, in which
case the width of a cell is used as the interaction range
(derived from ``n_bins``), but this can be increased/decreased
using ``i_range`` dependent on the use-case.
n_bins
Optional number of bins each dimension is subdivided
into. Assumes that each dimension contains the
same number of cells. Each cell can only interact
with adjacent cells, so this value consequently
also controls the number of interactions. If not provided,
the number of bins is derived from ``i_range`` as ``floor(1 / i_range)``.
"""
if n_bins is None:
assert (
i_range is not None
), "If n_bins is not provided, i_range should be provided"
n_bins = floor(1.0 / i_range)
else:
assert n_bins > 0, f"n_bins should be greater than 0, got {f}"

width = 1.0 / n_bins
i_range = width if i_range is None else i_range
i_range = i_range**2
Expand Down Expand Up @@ -355,9 +365,9 @@ def red(a, _, c):
def nearest_neighbour(
f: Callable,
*,
n_bins: int,
default: Default,
topology: str = "moore",
n_bins: Optional[int] = None,
i_range: Optional[float] = None,
) -> Callable:
"""
Expand Down Expand Up @@ -406,7 +416,7 @@ def foo(_k, p, a, b):
result = esquilax.transforms.nearest_neighbour(
foo,
n_bins=2,
i_range=0.5,
default=-1,
topology="moore"
)(
Expand All @@ -432,7 +442,7 @@ def foo(_k, p, a, b):
@partial(
esquilax.transforms.nearest_neighbour,
n_bins=2,
i_range=0.5,
default=(-1, -2),
topology="moore",
)
Expand Down Expand Up @@ -460,7 +470,7 @@ def foo(_k, p, _, b):
@partial(
esquilax.transforms.nearest_neighbour,
n_bins=2,
i_range=0.5,
default=-1,
topology="moore",
)
Expand All @@ -484,26 +494,6 @@ def foo(_, params, a, b):
Parameters
----------
n_bins
Number of bins each dimension is subdivided
into. Assumes that each dimension contains the
same number of cells. Each cell can only interact
with adjacent cells, so this value also consequently
also controls the number of interactions.
default
Default value(s) returned if no-neighbours are in
range of an agent.
topology
Topology of cells, default ``"moore"``. Since cells
interact with their neighbours, topologies with
fewer neighbours can increase performance at the
cost of fidelity. Should be one of ``"same-cell"``,
``"von-neumann"`` or ``"moore"``.
i_range
Optional interaction range. By default, the width
of a cell is used as the interaction range, but this
can be increased/decreased using ``i_range`` dependent
on the use-case.
f
Interaction to apply to in-proximity pairs, should
have the signature
Expand All @@ -528,8 +518,35 @@ def f(
- ``b``: End agent in the interaction
- ``**static_kwargs``: Any arguments required at compile
time by JAX can be passed as keyword arguments.
default
Default value(s) returned if no-neighbours are in
range of an agent.
topology
Topology of cells, default ``"moore"``. Since cells
interact with their neighbours, topologies with
fewer neighbours can increase performance at the
cost of fidelity. Should be one of ``"same-cell"``,
``"von-neumann"`` or ``"moore"``.
i_range
Range at which agents interact. Can be omitted, in which
case the width of a cell is used as the interaction range
(derived from ``n_bins``), but this can be increased/decreased
using ``i_range`` dependent on the use-case.
n_bins
Optional number of bins each dimension is subdivided
into. Assumes that each dimension contains the
same number of cells. Each cell can only interact
with adjacent cells, so this value consequently
also controls the number of interactions. If not provided,
the number of bins is derived from ``i_range`` as ``floor(1 / i_range)``.
"""
assert n_bins > 0, f"n_bins should be greater than 0, got {f}"
if n_bins is None:
assert (
i_range is not None
), "If n_bins is not provided, i_range should be provided"
n_bins = floor(1.0 / i_range)
else:
assert n_bins > 0, f"n_bins should be greater than 0, got {f}"

width = 1.0 / n_bins
i_range = width if i_range is None else i_range
Expand Down
3 changes: 1 addition & 2 deletions tests/test_transforms/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ def foo(_k, _p, a, b):
vals_b = jnp.arange(2, n_agents + 2)
results = transforms.spatial(
foo,
n_bins=10,
reduction=jnp.add,
default=0,
include_self=True,
Expand Down Expand Up @@ -413,7 +412,7 @@ def foo(_k, _p, a, b):
vals_a = jnp.arange(1, n_agents_a + 1)
vals_b = jnp.arange(2, n_agents_b + 2)
results = transforms.spatial(
foo, n_bins=10, reduction=jnp.add, default=0, topology="moore", i_range=i_range
foo, reduction=jnp.add, default=0, topology="moore", i_range=i_range
)(k, None, vals_a, vals_b, pos=xa, pos_b=xb)

d = jax.vmap(
Expand Down

0 comments on commit 28aec8c

Please sign in to comment.