Skip to content

Commit

Permalink
add new goals for extended benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Dec 14, 2023
1 parent c27d2f5 commit a8a35b2
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.1.0"
__version__ = "0.2.0"

# ---------- XLand-MiniGrid environments ----------
# TODO: reconsider grid sizes and time limits after the benchmarks are generated.
Expand Down
237 changes: 236 additions & 1 deletion src/xminigrid/core/goals.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ def check_goal(encoding, grid, agent, action, position):
lambda: TileNearGoal.decode(encoding)(grid, agent, action, position),
lambda: TileOnPositionGoal.decode(encoding)(grid, agent, action, position),
lambda: AgentOnPositionGoal.decode(encoding)(grid, agent, action, position),
# goals for the extended benchmarks
lambda: TileNearUpGoal.decode(encoding)(grid, agent, action, position),
lambda: TileNearRightGoal.decode(encoding)(grid, agent, action, position),
lambda: TileNearDownGoal.decode(encoding)(grid, agent, action, position),
lambda: TileNearLeftGoal.decode(encoding)(grid, agent, action, position),
lambda: AgentNearUpGoal.decode(encoding)(grid, agent, action, position),
lambda: AgentNearRightGoal.decode(encoding)(grid, agent, action, position),
lambda: AgentNearDownGoal.decode(encoding)(grid, agent, action, position),
lambda: AgentNearLeftGoal.decode(encoding)(grid, agent, action, position),
),
)
return check
Expand Down Expand Up @@ -128,7 +137,7 @@ def _check_fn():
equal(up, tile_b) | equal(right, tile_b) | equal(down, tile_b) | equal(left, tile_b),
equal(up, tile_a) | equal(right, tile_a) | equal(down, tile_a) | equal(left, tile_a),
),
jnp.array(False),
jnp.asarray(False),
)
return check

Expand Down Expand Up @@ -176,3 +185,229 @@ def decode(cls, encoding):
def encode(self):
encoding = jnp.hstack([jnp.asarray(6), self.position], dtype=jnp.uint8)
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)


class TileNearUpGoal(BaseGoal):
tile_a: jax.Array
tile_b: jax.Array

def __call__(self, grid, agent, action, position):
y, x = position
tile = grid[y, x]

def _check_fn():
up, _, down, _ = get_neighbouring_tiles(grid, y, x)
check = jnp.logical_or(
equal(tile, self.tile_b) & equal(down, self.tile_a), equal(tile, self.tile_a) & equal(up, self.tile_b)
)
return check

check = jax.lax.select(
jnp.equal(action, 4) & (equal(tile, self.tile_a) | equal(tile, self.tile_b)),
_check_fn(),
jnp.asarray(False),
)
return check

@classmethod
def decode(cls, encoding):
return cls(tile_a=encoding[1:3], tile_b=encoding[3:5])

def encode(self):
encoding = jnp.hstack([jnp.asarray(7), self.tile_a, self.tile_b])
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)


class TileNearRightGoal(BaseGoal):
tile_a: jax.Array
tile_b: jax.Array

def __call__(self, grid, agent, action, position):
y, x = position
tile = grid[y, x]

def _check_fn():
_, right, _, left = get_neighbouring_tiles(grid, y, x)
check = jnp.logical_or(
equal(tile, self.tile_b) & equal(left, self.tile_a),
equal(tile, self.tile_a) & equal(right, self.tile_b),
)
return check

check = jax.lax.select(
jnp.equal(action, 4) & (equal(tile, self.tile_a) | equal(tile, self.tile_b)),
_check_fn(),
jnp.asarray(False),
)
return check

@classmethod
def decode(cls, encoding):
return cls(tile_a=encoding[1:3], tile_b=encoding[3:5])

def encode(self):
encoding = jnp.hstack([jnp.asarray(8), self.tile_a, self.tile_b])
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)


class TileNearDownGoal(BaseGoal):
tile_a: jax.Array
tile_b: jax.Array

def __call__(self, grid, agent, action, position):
y, x = position
tile = grid[y, x]

def _check_fn():
up, _, down, _ = get_neighbouring_tiles(grid, y, x)
check = jnp.logical_or(
equal(tile, self.tile_b) & equal(up, self.tile_a), equal(tile, self.tile_a) & equal(down, self.tile_b)
)
return check

check = jax.lax.select(
jnp.equal(action, 4) & (equal(tile, self.tile_a) | equal(tile, self.tile_b)),
_check_fn(),
jnp.asarray(False),
)
return check

@classmethod
def decode(cls, encoding):
return cls(tile_a=encoding[1:3], tile_b=encoding[3:5])

def encode(self):
encoding = jnp.hstack([jnp.asarray(9), self.tile_a, self.tile_b])
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)


class TileNearLeftGoal(BaseGoal):
tile_a: jax.Array
tile_b: jax.Array

def __call__(self, grid, agent, action, position):
y, x = position
tile = grid[y, x]

def _check_fn():
_, right, _, left = get_neighbouring_tiles(grid, y, x)
check = jnp.logical_or(
equal(tile, self.tile_b) & equal(right, self.tile_a),
equal(tile, self.tile_a) & equal(left, self.tile_b),
)
return check

check = jax.lax.select(
jnp.equal(action, 4) & (equal(tile, self.tile_a) | equal(tile, self.tile_b)),
_check_fn(),
jnp.asarray(False),
)
return check

@classmethod
def decode(cls, encoding):
return cls(tile_a=encoding[1:3], tile_b=encoding[3:5])

def encode(self):
encoding = jnp.hstack([jnp.asarray(10), self.tile_a, self.tile_b])
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)


class AgentNearUpGoal(BaseGoal):
tile: jax.Array

def __call__(self, grid, agent, action, position):
def _check_fn():
up, _, _, _ = get_neighbouring_tiles(grid, agent.position[0], agent.position[1])
check = equal(up, self.tile)
return check

check = jax.lax.select(
jnp.equal(action, 0) | jnp.equal(action, 4),
_check_fn(),
jnp.asarray(False),
)
return check

@classmethod
def decode(cls, encoding):
return cls(tile=encoding[1:3])

def encode(self):
encoding = jnp.hstack([jnp.asarray(11), self.tile])
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)


class AgentNearRightGoal(BaseGoal):
tile: jax.Array

def __call__(self, grid, agent, action, position):
def _check_fn():
_, right, _, _ = get_neighbouring_tiles(grid, agent.position[0], agent.position[1])
check = equal(right, self.tile)
return check

check = jax.lax.select(
jnp.equal(action, 0) | jnp.equal(action, 4),
_check_fn(),
jnp.asarray(False),
)
return check

@classmethod
def decode(cls, encoding):
return cls(tile=encoding[1:3])

def encode(self):
encoding = jnp.hstack([jnp.asarray(12), self.tile])
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)


class AgentNearDownGoal(BaseGoal):
tile: jax.Array

def __call__(self, grid, agent, action, position):
def _check_fn():
_, _, down, _ = get_neighbouring_tiles(grid, agent.position[0], agent.position[1])
check = equal(down, self.tile)
return check

check = jax.lax.select(
jnp.equal(action, 0) | jnp.equal(action, 4),
_check_fn(),
jnp.asarray(False),
)
return check

@classmethod
def decode(cls, encoding):
return cls(tile=encoding[1:3])

def encode(self):
encoding = jnp.hstack([jnp.asarray(13), self.tile])
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)


class AgentNearLeftGoal(BaseGoal):
tile: jax.Array

def __call__(self, grid, agent, action, position):
def _check_fn():
_, _, _, left = get_neighbouring_tiles(grid, agent.position[0], agent.position[1])
check = equal(left, self.tile)
return check

check = jax.lax.select(
jnp.equal(action, 0) | jnp.equal(action, 4),
_check_fn(),
jnp.asarray(False),
)
return check

@classmethod
def decode(cls, encoding):
return cls(tile=encoding[1:3])

def encode(self):
encoding = jnp.hstack([jnp.asarray(14), self.tile])
return pad_along_axis(encoding, MAX_GOAL_ENCODING_LEN)
30 changes: 24 additions & 6 deletions src/xminigrid/core/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
MAX_RULE_ENCODING_LEN = 6 + 1 # +1 for idx


# this is very costly, will evaluate all under vmap. Submit a PR if you know how to do it better!
# this is very costly, will evaluate all rules under vmap. Submit a PR if you know how to do it better!
# In general, we need a way to select specific function/class based on ID number.
# We can not just decode without evaluation, as then return type will be different between branches
def check_rule(encodings, grid, agent, action, position):
def _check(carry, encoding):
grid, agent = carry
# What if use lax.cond here instead? Will it be faster?
grid, agent = jax.lax.switch(
encoding[0],
(
Expand All @@ -23,6 +24,7 @@ def _check(carry, encoding):
lambda: AgentHoldRule.decode(encoding)(grid, agent, action, position),
lambda: AgentNearRule.decode(encoding)(grid, agent, action, position),
lambda: TileNearRule.decode(encoding)(grid, agent, action, position),
# rules for the extended benchmarks
lambda: TileNearUpRule.decode(encoding)(grid, agent, action, position),
lambda: TileNearRightRule.decode(encoding)(grid, agent, action, position),
lambda: TileNearDownRule.decode(encoding)(grid, agent, action, position),
Expand Down Expand Up @@ -363,7 +365,11 @@ def _rule_fn(grid):
)
return grid

grid = jax.lax.cond(jnp.equal(action, 0) | jnp.equal(action, 4), lambda: _rule_fn(grid), lambda: grid)
grid = jax.lax.cond(
jnp.equal(action, 0) | jnp.equal(action, 4),
lambda: _rule_fn(grid),
lambda: grid,
)
return grid, agent

@classmethod
Expand All @@ -390,7 +396,11 @@ def _rule_fn(grid):
)
return grid

grid = jax.lax.cond(jnp.equal(action, 0) | jnp.equal(action, 4), lambda: _rule_fn(grid), lambda: grid)
grid = jax.lax.cond(
jnp.equal(action, 0) | jnp.equal(action, 4),
lambda: _rule_fn(grid),
lambda: grid,
)
return grid, agent

@classmethod
Expand All @@ -417,7 +427,11 @@ def _rule_fn(grid):
)
return grid

grid = jax.lax.cond(jnp.equal(action, 0) | jnp.equal(action, 4), lambda: _rule_fn(grid), lambda: grid)
grid = jax.lax.cond(
jnp.equal(action, 0) | jnp.equal(action, 4),
lambda: _rule_fn(grid),
lambda: grid,
)
return grid, agent

@classmethod
Expand All @@ -444,7 +458,11 @@ def _rule_fn(grid):
)
return grid

grid = jax.lax.cond(jnp.equal(action, 0) | jnp.equal(action, 4), lambda: _rule_fn(grid), lambda: grid)
grid = jax.lax.cond(
jnp.equal(action, 0) | jnp.equal(action, 4),
lambda: _rule_fn(grid),
lambda: grid,
)
return grid, agent

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/xminigrid/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def observation_shape(self, params: EnvParams) -> tuple[int, int, int]:

# TODO: NOT sure that this should be hardcoded like that...
def time_limit(self, params: EnvParams) -> int:
return 256
return 3 * params.height * params.width

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
return NotImplemented
Expand Down

0 comments on commit a8a35b2

Please sign in to comment.