
Commit

small fixes
Howuhh committed Jan 18, 2024
1 parent e3209ee commit 658e69f

Showing 4 changed files with 44 additions and 6 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -165,6 +165,11 @@ cython_debug/
**/.ipynb_checkpoints
*-preset.yml
*_run.sh
+training/test_train_meta_task.py
+scripts/*.pdf
+scripts/*.jpg
+scripts/*.png
+src/xminigrid/envs/xland_tmp.py

# will remove later
scripts/*testing*
7 changes: 4 additions & 3 deletions scripts/benchmark_xland_all.py
@@ -14,10 +14,10 @@

jax.config.update("jax_threefry_partitionable", True)

-NUM_ENVS = (512, 1024, 2048, 4096, 8192)
+NUM_ENVS = (128, 256, 512, 1024, 2048, 4096, 8192, 16384)

parser = argparse.ArgumentParser()
-parser.add_argument("--benchmark-id", type=str, default="Trivial")
+parser.add_argument("--benchmark-id", type=str, default="trivial-1m")
parser.add_argument("--timesteps", type=int, default=1000)
parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats")
parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)")
@@ -70,10 +70,11 @@ def timeit_benchmark(args, benchmark_fn):
args = parser.parse_args()
print("Num devices:", num_devices)

+environments = xminigrid.registered_environments()
summary = {}
for num_envs in tqdm(NUM_ENVS, desc="Benchmark", leave=False):
results = {}
-for env_id in tqdm(xminigrid.registered_environments(), desc="Envs.."):
+for env_id in tqdm(environments, desc="Envs.."):
assert num_envs % num_devices == 0
# building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps)
benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id)
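For context on the comment above: build_benchmark composes jax.vmap over environments within a device with jax.pmap over devices. The sketch below is not part of the commit and only illustrates that pattern under stated assumptions: it uses the public xminigrid.make / env.reset / env.step calls, hard-codes an illustrative action count, and leaves out the benchmark/ruleset handling that the real build_benchmark performs.

# Hypothetical pmap-over-vmap rollout, simplified from what build_benchmark does.
import jax
import xminigrid

NUM_ACTIONS = 6  # illustrative assumption about the action-space size

def build_benchmark_sketch(env_id, num_envs_per_device, timesteps):
    env, env_params = xminigrid.make(env_id)

    def rollout(key):
        reset_keys = jax.random.split(key, num_envs_per_device)
        # vmap: step num_envs_per_device environments in parallel on one device
        timestep = jax.vmap(env.reset, in_axes=(None, 0))(env_params, reset_keys)

        def step_fn(carry, _):
            timestep, key = carry
            key, actions_key = jax.random.split(key)
            actions = jax.random.randint(actions_key, (num_envs_per_device,), 0, NUM_ACTIONS)
            timestep = jax.vmap(env.step, in_axes=(None, 0, 0))(env_params, timestep, actions)
            return (timestep, key), None

        (timestep, _), _ = jax.lax.scan(step_fn, (timestep, key), None, length=timesteps)
        return timestep

    # pmap: one replica per device, so total envs = num_devices * num_envs_per_device
    return jax.pmap(rollout)

# usage sketch: one PRNG key per device (env id is illustrative)
# rollout_fn = build_benchmark_sketch("XLand-MiniGrid-R1-9x9", num_envs // num_devices, timesteps)
# rollout_fn(jax.random.split(jax.random.PRNGKey(0), num_devices))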
23 changes: 22 additions & 1 deletion scripts/generate_benchmarks.sh
@@ -1,5 +1,4 @@
# This can take a lot of time. Generate only what you need!
-# TODO: provide same for 5M benchmarks

# trivial
python scripts/ruleset_generator.py \
@@ -54,3 +53,25 @@ python scripts/ruleset_generator.py \
--num_distractor_objects=2 \
--total_rulesets=1_000_000 \
--save_path="medium_dist_1m"

+# medium 3M
+python scripts/ruleset_generator.py \
+--prune_chain \
+--prune_prob=0.3 \
+--chain_depth=2 \
+--sample_distractor_rules \
+--num_distractor_rules=3 \
+--num_distractor_objects=0 \
+--total_rulesets=3_000_000 \
+--save_path="medium_3m"

+# high 3M
+python scripts/ruleset_generator.py \
+--prune_chain \
+--prune_prob=0.1 \
+--chain_depth=3 \
+--sample_distractor_rules \
+--num_distractor_rules=4 \
+--num_distractor_objects=1 \
+--total_rulesets=3_000_000 \
+--save_path="high_3m"
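These pre-generated ruleset collections are what the benchmark IDs (e.g. trivial-1m above) refer to when benchmarking or training. The rough usage sketch below is not part of the commit; it assumes the xminigrid benchmark API (load_benchmark, sample_ruleset) and an illustrative environment id, and exact names or signatures may differ.

# Rough sketch of consuming a generated benchmark; API details are assumptions.
import jax
import xminigrid

benchmark = xminigrid.load_benchmark("trivial-1m")          # pre-generated ruleset collection
ruleset = benchmark.sample_ruleset(jax.random.PRNGKey(0))   # pick one ruleset

env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")   # illustrative env id
env_params = env_params.replace(ruleset=ruleset)            # bind the sampled ruleset
timestep = env.reset(env_params, jax.random.PRNGKey(1))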
15 changes: 13 additions & 2 deletions scripts/ruleset_generator.py
@@ -1,7 +1,7 @@
# This is not the fastest implementation, but c'mon,
# I only have to run it once in forever...
# Meanwhile, make yourself a cup of tea and relax, tqdm go brrr...
-# P.S. If you are willing to improve this, submit a PR!
+# P.S. If you are willing to improve this, submit a PR! Beware that generation should remain deterministic!
import argparse
import random
from itertools import product
@@ -179,6 +179,7 @@ def sample_ruleset(
# one empty rule as a placeholder, to fill up "rule" key, this will not introduce overhead under jit
rules.append(EmptyRule())

+# for logging
for level in range(num_levels):
next_chain_tiles = []

@@ -214,7 +215,7 @@ def sample_ruleset(
rules.append(rule)
init_tiles.extend(rule_tiles)

-# if for some reason there are no rules, add one empty
+# if for some reason there are no rules, add one empty (we will ignore it later)
if len(rules) == 0:
rules.append(EmptyRule())

@@ -224,7 +225,11 @@ def sample_ruleset(
"init_tiles": init_tiles,
# additional info (for example for biasing sampling by number of rules)
# you can add other fields if needed, just copy-paste this file!
+# saving counts, as rules and init tiles will later be padded to the same size
"num_rules": len([r for r in rules if not isinstance(r, EmptyRule)]),
+"num_init_tiles": len(init_tiles),
+"max_chain_depth": num_levels,
+"num_distractor_rules": num_distractor_rules,
}


@@ -276,6 +281,9 @@ def sample_ruleset(
"rules": jnp.vstack([r.encode() for r in ruleset["rules"]]),
"init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8),
"num_rules": jnp.asarray(ruleset["num_rules"], dtype=jnp.uint8),
+"num_init_tiles": jnp.asarray(ruleset["num_init_tiles"], dtype=jnp.uint8),
+"max_chain_depth": jnp.asarray(ruleset["max_chain_depth"], dtype=jnp.uint8),
+"num_distractor_rules": jnp.asarray(ruleset["num_distractor_rules"], dtype=jnp.uint8),
}
)
unique_rulesets_encodings.add(encode(ruleset))
@@ -298,6 +306,9 @@ def sample_ruleset(
"rules": jnp.vstack([pad_along_axis(r["rules"], pad_to=max_rules)[None, ...] for r in rulesets]),
"init_tiles": jnp.vstack([pad_along_axis(r["init_tiles"], pad_to=max_tiles)[None, ...] for r in rulesets]),
"num_rules": jnp.vstack([r["num_rules"] for r in rulesets]),
+"num_init_tiles": jnp.vstack([r["num_init_tiles"] for r in rulesets]),
+"max_chain_depth": jnp.vstack([r["max_chain_depth"] for r in rulesets]),
+"num_distractor_rules": jnp.vstack([r["num_distractor_rules"] for r in rulesets]),
}
print("Saving...")
save_bz2_pickle(concat_rulesets, args.save_path, protocol=-1)
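A note on the counts added in this commit (num_rules, num_init_tiles, max_chain_depth, num_distractor_rules): individual rulesets contain different numbers of rules and tiles, and before stacking they are padded to the collection-wide max_rules / max_tiles, so the true sizes would otherwise be lost. The padding helper itself is not shown in this diff; the sketch below is only an assumption about what a pad_along_axis-style function could look like, not the repo's implementation.

import jax.numpy as jnp

def pad_along_axis(arr, pad_to, axis=0, fill_value=0):
    # pad arr with fill_value along axis until its size there equals pad_to
    pad_size = pad_to - arr.shape[axis]
    if pad_size <= 0:
        return arr
    pad_width = [(0, 0)] * arr.ndim
    pad_width[axis] = (0, pad_size)
    return jnp.pad(arr, pad_width, constant_values=fill_value)

# e.g. encoded rule matrices of shapes (3, N) and (5, N) both become (max_rules, N)
# before jnp.vstack, while num_rules keeps the true counts 3 and 5.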
