From 658e69f202a401f0f0a9073fdae0e00d91ff862b Mon Sep 17 00:00:00 2001 From: Howuhh Date: Thu, 18 Jan 2024 22:23:52 +0300 Subject: [PATCH] small fixes --- .gitignore | 5 +++++ scripts/benchmark_xland_all.py | 7 ++++--- scripts/generate_benchmarks.sh | 23 ++++++++++++++++++++++- scripts/ruleset_generator.py | 15 +++++++++++++-- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 04863c1..4cca508 100644 --- a/.gitignore +++ b/.gitignore @@ -165,6 +165,11 @@ cython_debug/ **/.ipynb_checkpoints *-preset.yml *_run.sh +training/test_train_meta_task.py +scripts/*.pdf +scripts/*.jpg +scripts/*.png +src/xminigrid/envs/xland_tmp.py # will remove later scripts/*testing* diff --git a/scripts/benchmark_xland_all.py b/scripts/benchmark_xland_all.py index 9a5f2d3..3474820 100644 --- a/scripts/benchmark_xland_all.py +++ b/scripts/benchmark_xland_all.py @@ -14,10 +14,10 @@ jax.config.update("jax_threefry_partitionable", True) -NUM_ENVS = (512, 1024, 2048, 4096, 8192) +NUM_ENVS = (128, 256, 512, 1024, 2048, 4096, 8192, 16384) parser = argparse.ArgumentParser() -parser.add_argument("--benchmark-id", type=str, default="Trivial") +parser.add_argument("--benchmark-id", type=str, default="trivial-1m") parser.add_argument("--timesteps", type=int, default=1000) parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats") parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)") @@ -70,10 +70,11 @@ def timeit_benchmark(args, benchmark_fn): args = parser.parse_args() print("Num devices:", num_devices) + environments = xminigrid.registered_environments() summary = {} for num_envs in tqdm(NUM_ENVS, desc="Benchmark", leave=False): results = {} - for env_id in tqdm(xminigrid.registered_environments(), desc="Envs.."): + for env_id in tqdm(environments, desc="Envs.."): assert num_envs % num_devices == 0 # building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps) benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id) diff --git a/scripts/generate_benchmarks.sh b/scripts/generate_benchmarks.sh index 330f366..d11dc7d 100644 --- a/scripts/generate_benchmarks.sh +++ b/scripts/generate_benchmarks.sh @@ -1,5 +1,4 @@ # This can take a lot of time. Generate only needed! -# TODO: provide same for 5M benchmarks # trivial python scripts/ruleset_generator.py \ @@ -54,3 +53,25 @@ python scripts/ruleset_generator.py \ --num_distractor_objects=2 \ --total_rulesets=1_000_000 \ --save_path="medium_dist_1m" + +# medium 3M +python scripts/ruleset_generator.py \ + --prune_chain \ + --prune_prob=0.3 \ + --chain_depth=2 \ + --sample_distractor_rules \ + --num_distractor_rules=3 \ + --num_distractor_objects=0 \ + --total_rulesets=3_000_000 \ + --save_path="medium_3m" + +# high 3M +python scripts/ruleset_generator.py \ + --prune_chain \ + --prune_prob=0.1 \ + --chain_depth=3 \ + --sample_distractor_rules \ + --num_distractor_rules=4 \ + --num_distractor_objects=1 \ + --total_rulesets=3_000_000 \ + --save_path="high_3m" diff --git a/scripts/ruleset_generator.py b/scripts/ruleset_generator.py index ef51d0d..2f7316e 100644 --- a/scripts/ruleset_generator.py +++ b/scripts/ruleset_generator.py @@ -1,7 +1,7 @@ # This is not the fastest implementation, but c'mon, # I only have to run it once in forever... # Meanwhile, make yourself a cup of tea and relax, tqdm go brrr... -# P.S. If you are willing to improve this, submit a PR! +# P.S. If you are willing to improve this, submit a PR! Beware that generation should remain deterministic! import argparse import random from itertools import product @@ -179,6 +179,7 @@ def sample_ruleset( # one empty rule as a placeholder, to fill up "rule" key, this will not introduce overhead under jit rules.append(EmptyRule()) + # for logging for level in range(num_levels): next_chain_tiles = [] @@ -214,7 +215,7 @@ def sample_ruleset( rules.append(rule) init_tiles.extend(rule_tiles) - # if for some reason there are no rules, add one empty + # if for some reason there are no rules, add one empty (we will ignore it later) if len(rules) == 0: rules.append(EmptyRule()) @@ -224,7 +225,11 @@ def sample_ruleset( "init_tiles": init_tiles, # additional info (for example for biasing sampling by number of rules) # you can add other field if needed, just copy-paste this file! + # saving counts, as later they will be padded to the same size "num_rules": len([r for r in rules if not isinstance(r, EmptyRule)]), + "num_init_tiles": len(init_tiles), + "max_chain_depth": num_levels, + "num_distractor_rules": num_distractor_rules, } @@ -276,6 +281,9 @@ def sample_ruleset( "rules": jnp.vstack([r.encode() for r in ruleset["rules"]]), "init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8), "num_rules": jnp.asarray(ruleset["num_rules"], dtype=jnp.uint8), + "num_init_tiles": jnp.asarray(ruleset["num_init_tiles"], dtype=jnp.uint8), + "max_chain_depth": jnp.asarray(ruleset["max_chain_depth"], dtype=jnp.uint8), + "num_distractor_rules": jnp.asarray(ruleset["num_distractor_rules"], dtype=jnp.uint8), } ) unique_rulesets_encodings.add(encode(ruleset)) @@ -298,6 +306,9 @@ def sample_ruleset( "rules": jnp.vstack([pad_along_axis(r["rules"], pad_to=max_rules)[None, ...] for r in rulesets]), "init_tiles": jnp.vstack([pad_along_axis(r["init_tiles"], pad_to=max_tiles)[None, ...] for r in rulesets]), "num_rules": jnp.vstack([r["num_rules"] for r in rulesets]), + "num_init_tiles": jnp.vstack([r["num_init_tiles"] for r in rulesets]), + "max_chain_depth": jnp.vstack([r["max_chain_depth"] for r in rulesets]), + "num_distractor_rules": jnp.vstack([r["num_distractor_rules"] for r in rulesets]), } print("Saving...") save_bz2_pickle(concat_rulesets, args.save_path, protocol=-1)