From 9ddd6116a8a5ece792ccd41a29d43b5b12df16b5 Mon Sep 17 00:00:00 2001 From: Alexander Nikulin Date: Fri, 12 Jul 2024 14:59:12 +0300 Subject: [PATCH] updated to jax new key naming (#27) --- README.md | 6 +++--- examples/train_meta_standalone.ipynb | 8 ++++---- examples/train_single_standalone.ipynb | 6 +++--- examples/walkthrough.ipynb | 20 ++++++++++---------- pyproject.toml | 4 ++-- scripts/benchmark_xland.py | 4 ++-- scripts/benchmark_xland_all.py | 4 ++-- src/xminigrid/manual_control.py | 2 +- training/eval.py | 2 +- training/train_meta_task.py | 4 ++-- training/train_single_task.py | 2 +- 11 files changed, 31 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index d2a8263..d355f0a 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ import xminigrid from xminigrid.wrappers import GymAutoResetWrapper from xminigrid.experimental.img_obs import RGBImgObservationWrapper -key = jax.random.PRNGKey(0) +key = jax.random.key(0) reset_key, ruleset_key = jax.random.split(key) # to list available benchmarks: xminigrid.registered_benchmarks() @@ -196,11 +196,11 @@ benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m") benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m") # users can sample or get specific rulesets -benchmark.sample_ruleset(jax.random.PRNGKey(0)) +benchmark.sample_ruleset(jax.random.key(0)) benchmark.get_ruleset(ruleset_id=benchmark.num_rulesets() - 1) # or split them for train & test -train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8) +train, test = benchmark.shuffle(key=jax.random.key(0)).split(prop=0.8) ``` We also provide the [script](scripts/ruleset_generator.py) used to generate these benchmarks. Users can use it for their own purposes: diff --git a/examples/train_meta_standalone.ipynb b/examples/train_meta_standalone.ipynb index 9ab1e6c..4ae9366 100644 --- a/examples/train_meta_standalone.ipynb +++ b/examples/train_meta_standalone.ipynb @@ -461,7 +461,7 @@ " benchmark = xminigrid.load_benchmark(config.benchmark_id)\n", "\n", " # set up training state\n", - " rng = jax.random.PRNGKey(config.train_seed)\n", + " rng = jax.random.key(config.train_seed)\n", " rng, _rng = jax.random.split(rng)\n", "\n", " network = ActorCriticRNN(\n", @@ -629,7 +629,7 @@ " rng, train_state = runner_state[:2]\n", "\n", " # EVALUATE AGENT\n", - " eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.PRNGKey(config.eval_seed))\n", + " eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.key(config.eval_seed))\n", " eval_ruleset_rng = jax.random.split(eval_ruleset_rng, num=config.eval_num_envs_per_device)\n", " eval_reset_rng = jax.random.split(eval_reset_rng, num=config.eval_num_envs_per_device)\n", "\n", @@ -756,7 +756,7 @@ "total_reward, num_episodes = 0, 0\n", "rendered_imgs = []\n", "\n", - "rng = jax.random.PRNGKey(1)\n", + "rng = jax.random.key(1)\n", "rng, _rng = jax.random.split(rng)\n", "\n", "# initial inputs\n", @@ -823,7 +823,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/examples/train_single_standalone.ipynb b/examples/train_single_standalone.ipynb index 19d16a2..c92f1fe 100644 --- a/examples/train_single_standalone.ipynb +++ b/examples/train_single_standalone.ipynb @@ -456,7 +456,7 @@ " env = RGBImgObservationWrapper(env)\n", "\n", " # setup training state\n", - " rng = jax.random.PRNGKey(config.seed)\n", + " rng = jax.random.key(config.seed)\n", " rng, _rng = jax.random.split(rng)\n", "\n", " network = ActorCriticRNN(\n", @@ -722,7 +722,7 @@ "total_reward = 0\n", "rendered_imgs = []\n", "\n", - "rng = jax.random.PRNGKey(1)\n", + "rng = jax.random.key(1)\n", "rng, _rng = jax.random.split(rng)\n", "\n", "# initial inputs\n", @@ -786,7 +786,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/examples/walkthrough.ipynb b/examples/walkthrough.ipynb index bc535cd..c27cbba 100644 --- a/examples/walkthrough.ipynb +++ b/examples/walkthrough.ipynb @@ -177,7 +177,7 @@ "source": [ "import xminigrid\n", "\n", - "key = jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "key, reset_key = jax.random.split(key)\n", "\n", "# to list available environments: xminigrid.registered_environments()\n", @@ -345,7 +345,7 @@ "rollout_fn = jax.jit(build_rollout(env, env_params, num_steps=1000))\n", "\n", "# first execution will compile\n", - "transitions = rollout_fn(jax.random.PRNGKey(0))\n", + "transitions = rollout_fn(jax.random.key(0))\n", "\n", "print(\"Transitions shapes: \\n\", jtu.tree_map(jnp.shape, transitions))" ] @@ -418,7 +418,7 @@ "outputs": [], "source": [ "vmap_rollout = jax.jit(jax.vmap(build_rollout(env, env_params, num_steps=1000)))\n", - "rngs = jax.random.split(jax.random.PRNGKey(0), num=1024)\n", + "rngs = jax.random.split(jax.random.key(0), num=1024)\n", "\n", "vmap_transitions = vmap_rollout(rngs)\n", "\n", @@ -527,7 +527,7 @@ " benchmark_fn_pmap = build_benchmark(\"MiniGrid-EmptyRandom-8x8\", num_envs // num_devices, 1024)\n", " benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)\n", "\n", - " key = jax.random.PRNGKey(0)\n", + " key = jax.random.key(0)\n", " pmap_keys = jax.random.split(key, num=num_devices)\n", "\n", " elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_vmap, key))\n", @@ -896,7 +896,7 @@ "env, env_params = xminigrid.make(\"XLand-MiniGrid-R4-9x9\")\n", "env_params = env_params.replace(ruleset=ruleset)\n", "\n", - "timestep = env.reset(env_params, jax.random.PRNGKey(0))\n", + "timestep = env.reset(env_params, jax.random.key(0))\n", "\n", "show_img(env.render(env_params, timestep), dpi=64)" ] @@ -921,7 +921,7 @@ "benchmark = xminigrid.load_benchmark(name=\"trivial-1m\")\n", "print(\"Total rulesets:\", benchmark.num_rulesets())\n", "print(\"Ruleset with id 128: \\n\", benchmark.get_ruleset(ruleset_id=128))\n", - "print(\"Random ruleset: \\n\", benchmark.sample_ruleset(jax.random.PRNGKey(0)))" + "print(\"Random ruleset: \\n\", benchmark.sample_ruleset(jax.random.key(0)))" ] }, { @@ -942,7 +942,7 @@ "outputs": [], "source": [ "env_params = env_params.replace(ruleset=benchmark.get_ruleset(ruleset_id=128))\n", - "timestep = env.reset(env_params, jax.random.PRNGKey(0))\n", + "timestep = env.reset(env_params, jax.random.key(0))\n", "\n", "show_img(env.render(env_params, timestep), dpi=64)" ] @@ -992,7 +992,7 @@ "metadata": {}, "outputs": [], "source": [ - "train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)\n", + "train, test = benchmark.shuffle(key=jax.random.key(0)).split(prop=0.8)\n", "\n", "# or, by some function:\n", "def cond_fn(goal, rules):\n", @@ -1042,7 +1042,7 @@ "outputs": [], "source": [ "env_params = env_params.replace(ruleset=rulesets)\n", - "timestep = jax.vmap(env.reset, in_axes=(0, None))(env_params, jax.random.PRNGKey(0))" + "timestep = jax.vmap(env.reset, in_axes=(0, None))(env_params, jax.random.key(0))" ] }, { @@ -1102,7 +1102,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 5d2396e..12623ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,8 @@ classifiers = [ ] dependencies = [ - "jax>=0.4.16", - "jaxlib>=0.4.16", + "jax>=0.4.26", + "jaxlib>=0.4.26", "flax>=0.8.0", "rich>=13.4.2", "chex>=0.1.85", diff --git a/scripts/benchmark_xland.py b/scripts/benchmark_xland.py index 5af3742..b6e374f 100644 --- a/scripts/benchmark_xland.py +++ b/scripts/benchmark_xland.py @@ -40,7 +40,7 @@ def build_benchmark( # choose XLand benchmark if needed if "XLand-MiniGrid" in env_id and benchmark_id is not None: - ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0)) + ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.key(0)) env_params = env_params.replace(ruleset=ruleset) def benchmark_fn(key): @@ -98,7 +98,7 @@ def timeit_benchmark(args, benchmark_fn): ) benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap) - key = jax.random.PRNGKey(0) + key = jax.random.key(0) pmap_keys = jax.random.split(key, num=num_devices) # benchmarking diff --git a/scripts/benchmark_xland_all.py b/scripts/benchmark_xland_all.py index cae74f9..96a0319 100644 --- a/scripts/benchmark_xland_all.py +++ b/scripts/benchmark_xland_all.py @@ -38,7 +38,7 @@ def build_benchmark( # choose XLand benchmark if needed if "XLand-MiniGrid" in env_id and benchmark_id is not None: - ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0)) + ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.key(0)) env_params = env_params.replace(ruleset=ruleset) def benchmark_fn(key): @@ -93,7 +93,7 @@ def timeit_benchmark(args, benchmark_fn): benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap) # benchmarking - pmap_keys = jax.random.split(jax.random.PRNGKey(0), num=num_devices) + pmap_keys = jax.random.split(jax.random.key(0), num=num_devices) elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_pmap, pmap_keys)) pmap_fps = (args.timesteps * num_envs) // elapsed_time diff --git a/src/xminigrid/manual_control.py b/src/xminigrid/manual_control.py index 4f482d8..2641936 100644 --- a/src/xminigrid/manual_control.py +++ b/src/xminigrid/manual_control.py @@ -41,7 +41,7 @@ def __init__( self._reset = jax.jit(self.env.reset) self._step = jax.jit(self.env.step) - self._key = jax.random.PRNGKey(0) + self._key = jax.random.key(0) self.timestep = None diff --git a/training/eval.py b/training/eval.py index a98ec27..450bdb5 100644 --- a/training/eval.py +++ b/training/eval.py @@ -44,7 +44,7 @@ def main(): total_reward, num_episodes = 0, 0 rendered_imgs = [] - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) rng, _rng = jax.random.split(rng) timestep = reset_fn(env_params, _rng) diff --git a/training/train_meta_task.py b/training/train_meta_task.py index e491fcc..1ee16fd 100644 --- a/training/train_meta_task.py +++ b/training/train_meta_task.py @@ -101,7 +101,7 @@ def linear_schedule(count): benchmark = xminigrid.load_benchmark(config.benchmark_id) # set up training state - rng = jax.random.PRNGKey(config.train_seed) + rng = jax.random.key(config.train_seed) rng, _rng = jax.random.split(rng) network = ActorCriticRNN( @@ -269,7 +269,7 @@ def _update_minbatch(train_state, batch_info): rng, train_state = runner_state[:2] # EVALUATE AGENT - eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.PRNGKey(config.eval_seed)) + eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.key(config.eval_seed)) eval_ruleset_rng = jax.random.split(eval_ruleset_rng, num=config.eval_num_envs_per_device) eval_reset_rng = jax.random.split(eval_reset_rng, num=config.eval_num_envs_per_device) diff --git a/training/train_single_task.py b/training/train_single_task.py index 70d1a36..65252fa 100644 --- a/training/train_single_task.py +++ b/training/train_single_task.py @@ -89,7 +89,7 @@ def linear_schedule(count): env = RGBImgObservationWrapper(env) # setup training state - rng = jax.random.PRNGKey(config.seed) + rng = jax.random.key(config.seed) rng, _rng = jax.random.split(rng) network = ActorCriticRNN(