Skip to content

Commit

Permalink
changed max steps for xland
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Dec 27, 2023
1 parent 131cc86 commit e3209ee
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,5 @@ cython_debug/
*_run.sh

# will remove later
scripts/*testing*
scripts/*testing*
configs
2 changes: 1 addition & 1 deletion src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.4.1"
__version__ = "0.5.0"

# ---------- XLand-MiniGrid environments ----------
# TODO: reconsider grid sizes and time limits after the benchmarks are generated.
Expand Down
5 changes: 2 additions & 3 deletions src/xminigrid/envs/xland.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,8 @@ def default_params(self, **kwargs) -> XLandMiniGridEnvOptions:
def time_limit(self, params: XLandMiniGridEnvOptions) -> int:
# this is just a heuristic to prevent brute force in one episode,
# the agent needs to remember what it tried in previous episodes.
# If this is too small, just increase number of trials.
coef = len(params.ruleset.init_tiles) // 3
return coef * (params.height * params.width)
# If this is too small, change it or increase the number of trials (these are not equivalent).
return 3 * (params.height * params.width)

def _generate_problem(self, params: XLandMiniGridEnvOptions, key: jax.Array) -> State:
# WARN: we can make this compatible with jit (to vmap on different layouts during training),
Expand Down
2 changes: 1 addition & 1 deletion training/train_meta_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _update_minbatch(train_state, batch_info):
runner_state = (rng, train_state, timestep, prev_action, prev_reward, hstate)
return runner_state, loss_info

# on each meta-update we reset hidden to init_hstate
# on each meta-update we reset rnn hidden to init_hstate
runner_state = (rng, train_state, timestep, prev_action, prev_reward, init_hstate)
runner_state, loss_info = jax.lax.scan(_update_step, runner_state, None, config.num_inner_updates)
# WARN: do not forget to get updated params
Expand Down

0 comments on commit e3209ee

Please sign in to comment.