Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to JAX's new key naming (jax.random.PRNGKey → jax.random.key) #27

Merged
merged 1 commit into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

key = jax.random.PRNGKey(0)
key = jax.random.key(0)
reset_key, ruleset_key = jax.random.split(key)

# to list available benchmarks: xminigrid.registered_benchmarks()
Expand Down Expand Up @@ -196,11 +196,11 @@ benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")
benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")

# users can sample or get specific rulesets
benchmark.sample_ruleset(jax.random.PRNGKey(0))
benchmark.sample_ruleset(jax.random.key(0))
benchmark.get_ruleset(ruleset_id=benchmark.num_rulesets() - 1)

# or split them for train & test
train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)
train, test = benchmark.shuffle(key=jax.random.key(0)).split(prop=0.8)
```

We also provide the [script](scripts/ruleset_generator.py) used to generate these benchmarks. Users can use it for their own purposes:
Expand Down
8 changes: 4 additions & 4 deletions examples/train_meta_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
"\n",
" # set up training state\n",
" rng = jax.random.PRNGKey(config.train_seed)\n",
" rng = jax.random.key(config.train_seed)\n",
" rng, _rng = jax.random.split(rng)\n",
"\n",
" network = ActorCriticRNN(\n",
Expand Down Expand Up @@ -629,7 +629,7 @@
" rng, train_state = runner_state[:2]\n",
"\n",
" # EVALUATE AGENT\n",
" eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.PRNGKey(config.eval_seed))\n",
" eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.key(config.eval_seed))\n",
" eval_ruleset_rng = jax.random.split(eval_ruleset_rng, num=config.eval_num_envs_per_device)\n",
" eval_reset_rng = jax.random.split(eval_reset_rng, num=config.eval_num_envs_per_device)\n",
"\n",
Expand Down Expand Up @@ -756,7 +756,7 @@
"total_reward, num_episodes = 0, 0\n",
"rendered_imgs = []\n",
"\n",
"rng = jax.random.PRNGKey(1)\n",
"rng = jax.random.key(1)\n",
"rng, _rng = jax.random.split(rng)\n",
"\n",
"# initial inputs\n",
Expand Down Expand Up @@ -823,7 +823,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions examples/train_single_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@
" env = RGBImgObservationWrapper(env)\n",
"\n",
" # setup training state\n",
" rng = jax.random.PRNGKey(config.seed)\n",
" rng = jax.random.key(config.seed)\n",
" rng, _rng = jax.random.split(rng)\n",
"\n",
" network = ActorCriticRNN(\n",
Expand Down Expand Up @@ -722,7 +722,7 @@
"total_reward = 0\n",
"rendered_imgs = []\n",
"\n",
"rng = jax.random.PRNGKey(1)\n",
"rng = jax.random.key(1)\n",
"rng, _rng = jax.random.split(rng)\n",
"\n",
"# initial inputs\n",
Expand Down Expand Up @@ -786,7 +786,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
20 changes: 10 additions & 10 deletions examples/walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@
"source": [
"import xminigrid\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"key, reset_key = jax.random.split(key)\n",
"\n",
"# to list available environments: xminigrid.registered_environments()\n",
Expand Down Expand Up @@ -345,7 +345,7 @@
"rollout_fn = jax.jit(build_rollout(env, env_params, num_steps=1000))\n",
"\n",
"# first execution will compile\n",
"transitions = rollout_fn(jax.random.PRNGKey(0))\n",
"transitions = rollout_fn(jax.random.key(0))\n",
"\n",
"print(\"Transitions shapes: \\n\", jtu.tree_map(jnp.shape, transitions))"
]
Expand Down Expand Up @@ -418,7 +418,7 @@
"outputs": [],
"source": [
"vmap_rollout = jax.jit(jax.vmap(build_rollout(env, env_params, num_steps=1000)))\n",
"rngs = jax.random.split(jax.random.PRNGKey(0), num=1024)\n",
"rngs = jax.random.split(jax.random.key(0), num=1024)\n",
"\n",
"vmap_transitions = vmap_rollout(rngs)\n",
"\n",
Expand Down Expand Up @@ -527,7 +527,7 @@
" benchmark_fn_pmap = build_benchmark(\"MiniGrid-EmptyRandom-8x8\", num_envs // num_devices, 1024)\n",
" benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)\n",
"\n",
" key = jax.random.PRNGKey(0)\n",
" key = jax.random.key(0)\n",
" pmap_keys = jax.random.split(key, num=num_devices)\n",
"\n",
" elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_vmap, key))\n",
Expand Down Expand Up @@ -896,7 +896,7 @@
"env, env_params = xminigrid.make(\"XLand-MiniGrid-R4-9x9\")\n",
"env_params = env_params.replace(ruleset=ruleset)\n",
"\n",
"timestep = env.reset(env_params, jax.random.PRNGKey(0))\n",
"timestep = env.reset(env_params, jax.random.key(0))\n",
"\n",
"show_img(env.render(env_params, timestep), dpi=64)"
]
Expand All @@ -921,7 +921,7 @@
"benchmark = xminigrid.load_benchmark(name=\"trivial-1m\")\n",
"print(\"Total rulesets:\", benchmark.num_rulesets())\n",
"print(\"Ruleset with id 128: \\n\", benchmark.get_ruleset(ruleset_id=128))\n",
"print(\"Random ruleset: \\n\", benchmark.sample_ruleset(jax.random.PRNGKey(0)))"
"print(\"Random ruleset: \\n\", benchmark.sample_ruleset(jax.random.key(0)))"
]
},
{
Expand All @@ -942,7 +942,7 @@
"outputs": [],
"source": [
"env_params = env_params.replace(ruleset=benchmark.get_ruleset(ruleset_id=128))\n",
"timestep = env.reset(env_params, jax.random.PRNGKey(0))\n",
"timestep = env.reset(env_params, jax.random.key(0))\n",
"\n",
"show_img(env.render(env_params, timestep), dpi=64)"
]
Expand Down Expand Up @@ -992,7 +992,7 @@
"metadata": {},
"outputs": [],
"source": [
"train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)\n",
"train, test = benchmark.shuffle(key=jax.random.key(0)).split(prop=0.8)\n",
"\n",
"# or, by some function:\n",
"def cond_fn(goal, rules):\n",
Expand Down Expand Up @@ -1042,7 +1042,7 @@
"outputs": [],
"source": [
"env_params = env_params.replace(ruleset=rulesets)\n",
"timestep = jax.vmap(env.reset, in_axes=(0, None))(env_params, jax.random.PRNGKey(0))"
"timestep = jax.vmap(env.reset, in_axes=(0, None))(env_params, jax.random.key(0))"
]
},
{
Expand Down Expand Up @@ -1102,7 +1102,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ classifiers = [
]

dependencies = [
"jax>=0.4.16",
"jaxlib>=0.4.16",
"jax>=0.4.26",
"jaxlib>=0.4.26",
"flax>=0.8.0",
"rich>=13.4.2",
"chex>=0.1.85",
Expand Down
4 changes: 2 additions & 2 deletions scripts/benchmark_xland.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def build_benchmark(

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.key(0))
env_params = env_params.replace(ruleset=ruleset)

def benchmark_fn(key):
Expand Down Expand Up @@ -98,7 +98,7 @@ def timeit_benchmark(args, benchmark_fn):
)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

key = jax.random.PRNGKey(0)
key = jax.random.key(0)
pmap_keys = jax.random.split(key, num=num_devices)

# benchmarking
Expand Down
4 changes: 2 additions & 2 deletions scripts/benchmark_xland_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def build_benchmark(

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.key(0))
env_params = env_params.replace(ruleset=ruleset)

def benchmark_fn(key):
Expand Down Expand Up @@ -93,7 +93,7 @@ def timeit_benchmark(args, benchmark_fn):
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

# benchmarking
pmap_keys = jax.random.split(jax.random.PRNGKey(0), num=num_devices)
pmap_keys = jax.random.split(jax.random.key(0), num=num_devices)

elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_pmap, pmap_keys))
pmap_fps = (args.timesteps * num_envs) // elapsed_time
Expand Down
2 changes: 1 addition & 1 deletion src/xminigrid/manual_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(

self._reset = jax.jit(self.env.reset)
self._step = jax.jit(self.env.step)
self._key = jax.random.PRNGKey(0)
self._key = jax.random.key(0)

self.timestep = None

Expand Down
2 changes: 1 addition & 1 deletion training/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def main():
total_reward, num_episodes = 0, 0
rendered_imgs = []

rng = jax.random.PRNGKey(0)
rng = jax.random.key(0)
rng, _rng = jax.random.split(rng)

timestep = reset_fn(env_params, _rng)
Expand Down
4 changes: 2 additions & 2 deletions training/train_meta_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def linear_schedule(count):
benchmark = xminigrid.load_benchmark(config.benchmark_id)

# set up training state
rng = jax.random.PRNGKey(config.train_seed)
rng = jax.random.key(config.train_seed)
rng, _rng = jax.random.split(rng)

network = ActorCriticRNN(
Expand Down Expand Up @@ -269,7 +269,7 @@ def _update_minbatch(train_state, batch_info):
rng, train_state = runner_state[:2]

# EVALUATE AGENT
eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.PRNGKey(config.eval_seed))
eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.key(config.eval_seed))
eval_ruleset_rng = jax.random.split(eval_ruleset_rng, num=config.eval_num_envs_per_device)
eval_reset_rng = jax.random.split(eval_reset_rng, num=config.eval_num_envs_per_device)

Expand Down
2 changes: 1 addition & 1 deletion training/train_single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def linear_schedule(count):
env = RGBImgObservationWrapper(env)

# setup training state
rng = jax.random.PRNGKey(config.seed)
rng = jax.random.key(config.seed)
rng, _rng = jax.random.split(rng)

network = ActorCriticRNN(
Expand Down
Loading