Skip to content

Commit

Permalink
update notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
BartekCupial committed Jan 26, 2025
1 parent 4b7d189 commit 79d9c58
Showing 1 changed file with 73 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -12,20 +12,22 @@
"from sample_factory.algo.utils.make_env import make_env_func_batched\n",
"from sample_factory.utils.attr_dict import AttrDict\n",
"from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components\n",
"from sf_examples.nethack.models.simba import SimBaEncoder"
"from sf_examples.nethack.models.simba import SimBaEncoder\n",
"from sf_examples.nethack.models.vit import ViTEncoder\n",
"from sf_examples.nethack.models.chaotic_dwarf import ChaoticDwarvenGPT5"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[36m[2025-01-26 09:32:39,142][825930] register_encoder_factory: <function make_nethack_encoder at 0x71519d7113f0>\u001b[0m\n",
"\u001b[36m[2025-01-26 09:32:39,144][825930] register_actor_critic_factory: <function make_nethack_actor_critic at 0x71519d7112d0>\u001b[0m\n"
"\u001b[36m[2025-01-26 09:36:31,020][1027662] register_encoder_factory: <function make_nethack_encoder at 0x7baf623fcaf0>\u001b[0m\n",
"\u001b[36m[2025-01-26 09:36:31,022][1027662] register_actor_critic_factory: <function make_nethack_actor_critic at 0x7baf623fc9d0>\u001b[0m\n"
]
}
],
Expand All @@ -37,9 +39,16 @@
"env_info = extract_env_info(env, cfg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SimBa"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -78,9 +87,16 @@
"print(pivot_table)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ChaoticDwarven"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -99,6 +115,56 @@
"print(f\"Model Size: {total_params / 10**6:.2f}M\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ViT"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hidden Dim 16 32 64 128 256 512\n",
"Depth \n",
"1 0.41M 0.82M 1.65M 3.37M 7.07M 15.43M\n",
"2 0.45M 0.89M 1.80M 3.70M 7.85M 17.53M\n",
"3 0.48M 0.96M 1.95M 4.03M 8.64M 19.63M\n",
"4 0.51M 1.03M 2.09M 4.36M 9.43M 21.74M\n"
]
}
],
"source": [
"# Parameter-count sweep: instantiate a ViTEncoder for every (hidden_dim, depth)\n",
"# pair and record its size, formatted in millions of parameters.\n",
"results = []\n",
"for hidden_dim in (16, 32, 64, 128, 256, 512):\n",
"    for depth in (1, 2, 3, 4):\n",
"        # heads is held at 8 and mlp_dim at twice hidden_dim for every configuration.\n",
"        encoder_kwargs = dict(\n",
"            obs_space=env_info.obs_space,\n",
"            hidden_dim=hidden_dim,\n",
"            depth=depth,\n",
"            heads=8,\n",
"            mlp_dim=2 * hidden_dim,\n",
"            use_prev_action=cfg.use_prev_action,\n",
"        )\n",
"        model = ViTEncoder(**encoder_kwargs)\n",
"        total_params = sum(param.numel() for param in model.parameters())\n",
"\n",
"        size_label = f\"{total_params / 10**6:.2f}M\"\n",
"        results.append({\"Hidden Dim\": hidden_dim, \"Depth\": depth, \"Model Size\": size_label})\n",
"\n",
"# Reshape the flat records into a grid: one row per depth, one column per hidden dim.\n",
"df = pd.DataFrame(results)\n",
"pivot_table = df.pivot(index=\"Depth\", columns=\"Hidden Dim\", values=\"Model Size\")\n",
"print(pivot_table)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 79d9c58

Please sign in to comment.