9
9
10
10
import pympler .tracker as tracker
11
11
12
- from school_project .frames import (HyperParameterFrame , TrainingFrame ,
13
- LoadModelFrame , TestMNISTFrame ,
12
+ from school_project .frames import (HyperParameterFrame , TrainingFrame ,
13
+ LoadModelFrame , TestMNISTFrame ,
14
14
TestCatRecognitionFrame , TestXORFrame )
15
15
16
16
class SchoolProjectFrame (tk .Frame ):
17
17
"""Main frame of school project."""
18
18
def __init__ (self , root : tk .Tk , width : int , height : int , bg : str ) -> None :
19
19
"""Initialise school project pages.
20
-
20
+
21
21
Args:
22
22
root (tk.Tk): the widget object that contains this widget.
23
23
width (int): the pixel width of the frame.
24
24
height (int): the pixel height of the frame.
25
25
bg (str): the hex value or name of the frame's background colour.
26
26
Raises:
27
27
TypeError: if root, width or height are not of the correct type.
28
-
28
+
29
29
"""
30
30
super ().__init__ (master = root , width = width , height = height , bg = bg )
31
31
self .root = root .title ("School Project" )
32
32
self .WIDTH = width
33
33
self .HEIGHT = height
34
34
self .BG = bg
35
-
35
+
36
36
# Setup school project frame variables
37
37
self .hyper_parameter_frame : HyperParameterFrame
38
38
self .training_frame : TrainingFrame
@@ -118,14 +118,14 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
118
118
command = self .enter_home_frame )
119
119
120
120
# Setup home frame
121
- self .home_frame = tk .Frame (master = self ,
122
- width = self .WIDTH ,
121
+ self .home_frame = tk .Frame (master = self ,
122
+ width = self .WIDTH ,
123
123
height = self .HEIGHT ,
124
124
bg = self .BG )
125
125
self .title_label = tk .Label (
126
126
master = self .home_frame ,
127
127
bg = self .BG ,
128
- font = ('Arial' , 20 ),
128
+ font = ('Arial' , 20 ),
129
129
text = "A-level Computer Science NEA Programming Project"
130
130
)
131
131
self .about_label = tk .Label (
@@ -162,7 +162,7 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
162
162
font = tkf .Font (size = 12 ),
163
163
text = "Load Model" ,
164
164
command = self .enter_load_model_frame )
165
-
165
+
166
166
# Grid home frame widgets
167
167
self .title_label .grid (row = 0 , column = 0 , columnspan = 4 , pady = (10 ,0 ))
168
168
self .about_label .grid (row = 1 , column = 0 , columnspan = 4 , pady = (10 ,50 ))
@@ -172,19 +172,19 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
172
172
self .load_model_button .grid (row = 4 , column = 2 )
173
173
174
174
self .home_frame .pack ()
175
-
175
+
176
176
# Setup frame attributes
177
177
self .grid_propagate (flag = False )
178
178
self .pack_propagate (flag = False )
179
179
180
180
@staticmethod
181
181
def setup_database () -> tuple [sqlite3 .Connection , sqlite3 .Cursor ]:
182
- """Create a connection to the pretrained_models database file and
182
+ """Create a connection to the pretrained_models database file and
183
183
setup base table if needed.
184
-
184
+
185
185
Returns:
186
186
a tuple of the database connection and the cursor for it.
187
-
187
+
188
188
"""
189
189
connection = sqlite3 .connect (
190
190
database = 'school_project/saved_models.db'
@@ -232,14 +232,14 @@ def enter_load_model_frame(self) -> None:
232
232
)
233
233
self .load_model_frame .pack ()
234
234
235
- # Don't give option to test loaded model if no models have been saved
235
+ # Don't give option to test loaded model if no models have been saved
236
236
# for the dataset.
237
237
if len (self .load_model_frame .model_options ) > 0 :
238
238
self .test_loaded_model_button .pack ()
239
239
self .delete_loaded_model_button .pack (pady = (5 ,0 ))
240
-
240
+
241
241
self .exit_load_model_frame_button .pack (pady = (5 ,0 ))
242
-
242
+
243
243
def exit_hyper_parameter_frame (self ) -> None :
244
244
"""Unpack hyper-parameter frame and pack home frame."""
245
245
self .hyper_parameter_frame .pack_forget ()
@@ -269,7 +269,7 @@ def enter_training_frame(self) -> None:
269
269
self .exit_hyper_parameter_frame_button .pack_forget ()
270
270
self .training_frame = TrainingFrame (
271
271
root = self ,
272
- width = self .WIDTH ,
272
+ width = self .WIDTH ,
273
273
height = self .HEIGHT ,
274
274
bg = self .BG ,
275
275
model = self .model ,
@@ -282,7 +282,7 @@ def enter_training_frame(self) -> None:
282
282
def manage_training (self , train_thread : threading .Thread ) -> None :
283
283
"""Wait for model training thread to finish,
284
284
then plot training losses on training frame.
285
-
285
+
286
286
Args:
287
287
train_thread (threading.Thread):
288
288
the thread running the model's train() method.
@@ -308,7 +308,7 @@ def test_created_model(self) -> None:
308
308
self .training_frame .pack_forget ()
309
309
self .test_created_model_button .pack_forget ()
310
310
if self .hyper_parameter_frame .dataset == "MNIST" :
311
- self .test_frame = TestMNISTFrame (
311
+ self .test_frame = TestMNISTFrame (
312
312
root = self ,
313
313
width = self .WIDTH ,
314
314
height = self .HEIGHT ,
@@ -319,7 +319,7 @@ def test_created_model(self) -> None:
319
319
elif self .hyper_parameter_frame .dataset == "Cat Recognition" :
320
320
self .test_frame = TestCatRecognitionFrame (
321
321
root = self ,
322
- width = self .WIDTH ,
322
+ width = self .WIDTH ,
323
323
height = self .HEIGHT ,
324
324
bg = self .BG ,
325
325
use_gpu = self .hyper_parameter_frame .use_gpu ,
@@ -335,7 +335,7 @@ def test_created_model(self) -> None:
335
335
self .manage_testing (test_thread = self .test_frame .test_thread )
336
336
337
337
def test_loaded_model (self ) -> None :
338
- """Load saved model from load model frame, unpack load model frame,
338
+ """Load saved model from load model frame, unpack load model frame,
339
339
pack test frame for the dataset and begin managing the test thread."""
340
340
self .saving_model = False
341
341
try :
@@ -347,7 +347,7 @@ def test_loaded_model(self) -> None:
347
347
self .delete_loaded_model_button .pack_forget ()
348
348
self .exit_load_model_frame_button .pack_forget ()
349
349
if self .load_model_frame .dataset == "MNIST" :
350
- self .test_frame = TestMNISTFrame (
350
+ self .test_frame = TestMNISTFrame (
351
351
root = self ,
352
352
width = self .WIDTH ,
353
353
height = self .HEIGHT ,
@@ -376,13 +376,13 @@ def test_loaded_model(self) -> None:
376
376
def manage_testing (self , test_thread : threading .Thread ) -> None :
377
377
"""Wait for model test thread to finish,
378
378
then plot results on test frame.
379
-
379
+
380
380
Args:
381
381
test_thread (threading.Thread):
382
382
the thread running the model's predict() method.
383
383
Raises:
384
384
TypeError: if test_thread is not of type threading.Thread.
385
-
385
+
386
386
"""
387
387
if not test_thread .is_alive ():
388
388
self .test_frame .plot_results (model = self .model )
@@ -395,7 +395,7 @@ def manage_testing(self, test_thread: threading.Thread) -> None:
395
395
self .after (1_000 , self .manage_testing , test_thread )
396
396
397
397
def save_model (self ) -> None :
398
- """Save the model, save the model information to the database, then
398
+ """Save the model, save the model information to the database, then
399
399
enter the home frame."""
400
400
model_name = self .save_model_name_entry .get ()
401
401
@@ -480,19 +480,19 @@ def enter_home_frame(self) -> None:
480
480
self .home_frame .pack ()
481
481
summary_tracker .create_summary () # BUG: Object summary seems to reduce
482
482
# memory leak greatly
483
-
483
+
484
484
def main () -> None :
485
485
"""Entrypoint of project."""
486
486
root = tk .Tk ()
487
487
school_project_frame = SchoolProjectFrame (root = root , width = 1280 ,
488
488
height = 835 , bg = 'white' )
489
489
school_project_frame .pack (side = 'top' , fill = 'both' , expand = True )
490
490
root .mainloop ()
491
-
491
+
492
492
# Stop model training when GUI closes
493
- if school_project_frame .model != None :
493
+ if school_project_frame .model is not None :
494
494
school_project_frame .model .set_running (value = False )
495
495
496
496
if __name__ == "__main__" :
497
497
summary_tracker = tracker .SummaryTracker () # Setup object tracker
498
- main ()
498
+ main ()
0 commit comments