Commit 08302f4

Rename utilities to help.
1 parent af5b2c4 commit 08302f4

13 files changed: +55 additions, -132 deletions

breakout/example.ipynb

Lines changed: 22 additions & 62 deletions
@@ -3,21 +3,30 @@
  {
  "cell_type": "markdown",
  "source": [
- "# Visualise the pre-trained agent in action\n",
- "\n",
- "Modify the path to the weights and run the notebook."
+ "# Example of visualising the agent's training history performance"
  ],
  "metadata": {
  "collapsed": false
  },
  "id": "b3d8465ecb86eca7"
  },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "MODEL : Path to the pre-trained model\n",
+ "METRICS : Path to the training history, or None"
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "fff872a8189754af"
+ },
  {
  "cell_type": "code",
  "outputs": [],
  "source": [
- "WEIGHTS = './_output/weights-15000.pth'\n",
- "METRICS = './_output/metrics.csv'"
+ "MODEL = './results/model.pth'\n",
+ "METRICS = None"
  ],
  "metadata": {
  "collapsed": false

@@ -34,11 +43,9 @@
  "import gymnasium as gym\n",
  "import matplotlib.pyplot as plt\n",
  "\n",
- "from DQN import VisionDeepQ\n",
- "\n",
  "sys.path.append(\"../\")\n",
- "from utilities.visualisation.plot import graph # noqa\n",
- "from utilities.visualisation.movie import movie # noqa"
+ "from help.visualisation.plot import graph # noqa\n",
+ "from help.visualisation.movie import movie # noqa"
  ],
  "metadata": {
  "collapsed": false

@@ -49,48 +56,7 @@
  {
  "cell_type": "markdown",
  "source": [
- "## Parameters"
- ],
- "metadata": {
- "collapsed": false
- },
- "id": "4dddd56883444fab"
- },
- {
- "cell_type": "code",
- "outputs": [],
- "source": [
- "network = {\n",
- " \"input_channels\": 4, \"outputs\": 4,\n",
- " \"channels\": [32, 64, 64],\n",
- " \"kernels\": [8, 4, 3],\n",
- " \"padding\": [\"valid\", \"valid\", \"valid\"],\n",
- " \"strides\": [4, 2, 1],\n",
- " \"nodes\": [],\n",
- "}\n",
- "optimizer = {\n",
- " \"optimizer\": torch.optim.Adam,\n",
- " \"lr\": 1e-5,\n",
- " \"hyperparameters\": {}\n",
- "}\n",
- "shape = {\n",
- " \"original\": (1, 1, 210, 160),\n",
- " \"width\": slice(7, -7),\n",
- " \"height\": slice(31, -17),\n",
- " \"max_pooling\": 2,\n",
- "}\n",
- "skip = 4"
- ],
- "metadata": {
- "collapsed": false
- },
- "id": "16867687f37ddbca",
- "execution_count": null
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Setup"
+ "## Loading the agent and environment"
  ],
  "metadata": {
  "collapsed": false

@@ -101,13 +67,7 @@
  "cell_type": "code",
  "outputs": [],
  "source": [
- "value_agent = VisionDeepQ(\n",
- " network=network, optimizer=optimizer, shape=shape,\n",
- " exploration_rate=0.002,\n",
- ")\n",
- "\n",
- "weights = torch.load(WEIGHTS, map_location=torch.device('cpu'))\n",
- "value_agent.load_state_dict(weights)\n",
+ "agent = torch.load(MODEL, map_location=torch.device('cpu'))\n",
  "\n",
  "environment = gym.make('ALE/Breakout-v5', render_mode=\"rgb_array\",\n",
  " obs_type=\"grayscale\", frameskip=1, repeat_action_probability=0.0)\n",

@@ -132,7 +92,7 @@
  {
  "cell_type": "markdown",
  "source": [
- "### Plotting the metrics from the csv-file created during training."
+ "### Training history (if specified)"
  ],
  "metadata": {
  "collapsed": false

@@ -143,7 +103,7 @@
  "cell_type": "code",
  "outputs": [],
  "source": [
- "graph(METRICS, title=\"Training history\", window=20) if METRICS else None\n",
+ "graph(METRICS, title=\"Breakout training history\", window=20) if METRICS else None\n",
  "plt.show() if METRICS else None"
  ],
  "metadata": {

@@ -155,7 +115,7 @@
  {
  "cell_type": "markdown",
  "source": [
- "### Creating and saving a gif of the agent in action. The gif will be saved to the given path."
+ "### In action"
  ],
  "metadata": {
  "collapsed": false

@@ -166,7 +126,7 @@
  "cell_type": "code",
  "outputs": [],
  "source": [
- "movie(environment, value_agent, './_output/breakout.avi', fps=60)"
+ "movie(environment, agent, './results/breakout.mp4', fps=20)"
  ],
  "metadata": {
  "collapsed": false

cart-pole/DQN.ipynb

Lines changed: 2 additions & 2 deletions
@@ -44,8 +44,8 @@
  "from DQN import DeepQ\n",
  "\n",
  "sys.path.append(\"../\")\n",
- "from utilities.visualisation.plot import plot # noqa\n",
- "from utilities.visualisation.gif import gif2 # noqa"
+ "from help.visualisation.plot import plot # noqa\n",
+ "from help.visualisation.gif import gif2 # noqa"
  ]
  },
  {
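
The remaining notebook edits in this commit are the same one-line rename: only the package prefix in the import path changes from utilities to help. A minimal sketch of the pattern, assuming each notebook sits one directory below the repository root where the top-level help/ package lives:

    import sys

    # Make the repository root importable from the notebook's subdirectory.
    sys.path.append("../")

    # The visualisation helpers now live under help/ instead of utilities/.
    from help.visualisation.plot import plot  # noqa
    from help.visualisation.gif import gif2  # noqa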

cart-pole/REINFORCE.ipynb

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@
  "from REINFORCE import PolicyGradient\n",
  "\n",
  "sys.path.append(\"../\")\n",
- "from utilities.visualisation.plot import plot # noqa\n",
- "from utilities.visualisation.gif import gif2 # noqa"
+ "from help.visualisation.plot import plot # noqa\n",
+ "from help.visualisation.gif import gif2 # noqa"
  ]
  },
  {

enduro/example.ipynb

Lines changed: 23 additions & 60 deletions
@@ -3,21 +3,30 @@
  {
  "cell_type": "markdown",
  "source": [
- "# Visualise the pre-trained agent in action\n",
- "\n",
- "Modify the path to the weights and run the notebook."
+ "# Example of visualising the agent's training history performance"
  ],
  "metadata": {
  "collapsed": false
  },
  "id": "b3d8465ecb86eca7"
  },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "MODEL : Path to the pre-trained model\n",
+ "METRICS : Path to the training history, or None"
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "fbd0af6b11428abe"
+ },
  {
  "cell_type": "code",
  "outputs": [],
  "source": [
- "WEIGHTS = './_output/weights-0.pth'\n",
- "METRICS = None #'./_output/metrics.csv'"
+ "MODEL = './results/model.pth'\n",
+ "METRICS = None"
  ],
  "metadata": {
  "collapsed": false

@@ -34,11 +43,11 @@
  "import gymnasium as gym\n",
  "import matplotlib.pyplot as plt\n",
  "\n",
- "from DQN import VisionDeepQ\n",
+ "from train import SKIP\n",
  "\n",
  "sys.path.append(\"../\")\n",
- "from utilities.visualisation.plot import visualise_csv # noqa\n",
- "from utilities.visualisation.gif import gif # noqa"
+ "from help.visualisation.plot import graph2 # noqa\n",
+ "from help.visualisation.gif import gif # noqa"
  ],
  "metadata": {
  "collapsed": false

@@ -49,47 +58,7 @@
  {
  "cell_type": "markdown",
  "source": [
- "## Parameters"
- ],
- "metadata": {
- "collapsed": false
- },
- "id": "4dddd56883444fab"
- },
- {
- "cell_type": "code",
- "outputs": [],
- "source": [
- "network = {\n",
- " \"input_channels\": 2, \"outputs\": 9,\n",
- " \"channels\": [32, 64, 64],\n",
- " \"kernels\": [8, 4, 3],\n",
- " \"padding\": [\"valid\", \"valid\", \"valid\"],\n",
- " \"strides\": [4, 2, 1],\n",
- " \"nodes\": [512],\n",
- "}\n",
- "optimizer = {\n",
- " \"optimizer\": torch.optim.RMSprop,\n",
- " \"lr\": 0.0001,\n",
- " \"hyperparameters\": {}\n",
- "}\n",
- "shape = {\n",
- " \"original\": (1, 1, 210, 160),\n",
- " \"height\": slice(51, 155),\n",
- " \"width\": slice(8, 160)\n",
- "}\n",
- "skip = 4"
- ],
- "metadata": {
- "collapsed": false
- },
- "id": "16867687f37ddbca",
- "execution_count": null
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Setup"
+ "## Loading the agent and environment"
  ],
  "metadata": {
  "collapsed": false

@@ -100,13 +69,7 @@
  "cell_type": "code",
  "outputs": [],
  "source": [
- "value_agent = VisionDeepQ(\n",
- " network=network, optimizer=optimizer, shape=shape,\n",
- " exploration_rate=0.01,\n",
- ")\n",
- "\n",
- "weights = torch.load(WEIGHTS, map_location=torch.device('cpu'))\n",
- "value_agent.load_state_dict(weights)\n",
+ "agent = torch.load(MODEL, map_location=torch.device('cpu'))\n",
  "\n",
  "environment = gym.make('ALE/Enduro-v5', render_mode=\"rgb_array\",\n",
  " obs_type=\"grayscale\", frameskip=1, repeat_action_probability=0.0)\n",

@@ -131,7 +94,7 @@
  {
  "cell_type": "markdown",
  "source": [
- "### Plotting the metrics from the csv-file created during training."
+ "### Training history (if specified)"
  ],
  "metadata": {
  "collapsed": false

@@ -142,7 +105,7 @@
  "cell_type": "code",
  "outputs": [],
  "source": [
- "visualise_csv(METRICS, title=\"Training history\", window=20) if METRICS else None\n",
+ "graph2(METRICS, title=\"Enduro training history\", window=20) if METRICS else None\n",
  "plt.show() if METRICS else None"
  ],
  "metadata": {

@@ -154,7 +117,7 @@
  {
  "cell_type": "markdown",
  "source": [
- "### Creating and saving a gif of the agent in action. The gif will be saved to the given path."
+ "### In action"
  ],
  "metadata": {
  "collapsed": false

@@ -165,7 +128,7 @@
  "cell_type": "code",
  "outputs": [],
  "source": [
- "gif(environment, value_agent, './_output/enduro-0.gif', skip, 25)"
+ "gif(environment, agent, './results/enduro.gif', SKIP, 25)"
  ],
  "metadata": {
  "collapsed": false
File renamed without changes.
File renamed without changes.

utilities/visualisation/gif.py renamed to help/visualisation/gif.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def gif2(environment, agent, path="./live-preview.gif", duration=50):
      _ = imageio.mimsave(path, images, duration=duration)


- def gif(environment, agent, path="./live-preview.gif", skip=4, duration=50):
+ def gif(environment, agent, path="./live-preview.gif", skip=1, duration=50):
      """
      Create a GIF of the agent playing the environment.

utilities/visualisation/movie.py renamed to help/visualisation/movie.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
  import torch


- def movie(environment, agent, path="./live-preview.mp4", skip=4, fps=50):
+ def movie(environment, agent, path="./live-preview.mp4", skip=1, fps=50):
      """Created by Mistral Large."""
      states = agent.preprocess(environment.reset()[0])
      if hasattr(agent, "shape") and "reshape" in agent.shape:
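
Both gif() and movie() keep their signatures, but the default skip drops from 4 to 1, so callers that leave skip unset now render every frame instead of every fourth. A minimal sketch of keeping the old sampling rate by passing skip explicitly, assuming an environment and a loaded agent as in the example notebooks above:

    from help.visualisation.gif import gif
    from help.visualisation.movie import movie

    # New default is skip=1 (every frame); pass skip=4 to reproduce the previous output.
    gif(environment, agent, './results/preview.gif', skip=4, duration=50)
    movie(environment, agent, './results/preview.mp4', skip=4, fps=50)
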
File renamed without changes.

tetris/example.ipynb

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@
  "from DQN import DeepQ\n",
  "\n",
  "sys.path.append(\"../\")\n",
- "from utilities.visualisation.plot import graph # noqa\n",
- "from utilities.visualisation.gif import gif # noqa"
+ "from help.visualisation.plot import graph # noqa\n",
+ "from help.visualisation.gif import gif # noqa"
  ],
  "metadata": {
  "collapsed": false,

tetris/transfer-learning/DQN-ResNet18.ipynb

Lines changed: 2 additions & 2 deletions
@@ -34,8 +34,8 @@
  "from DQN import TransferDeepQ\n",
  "\n",
  "sys.path.append(\"../\")\n",
- "from utilities.visualisation.plot import plot # noqa\n",
- "from utilities.visualisation.gif import gif # noqa"
+ "from help.visualisation.plot import plot # noqa\n",
+ "from help.visualisation.gif import gif # noqa"
  ]
  },
  {
