Skip to content

Commit 5033881

Browse files
committed
Remove trailing whitespaces and add final new line for each module
1 parent aff6401 commit 5033881

31 files changed

+293
-291
lines changed

school_project/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Main package of A-level Computer Science NEA Programming Project."""
22

3-
__all__ = ['models', 'frames', 'test']
3+
__all__ = ['models', 'frames', 'test']

school_project/__main__.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,30 @@
99

1010
import pympler.tracker as tracker
1111

12-
from school_project.frames import (HyperParameterFrame, TrainingFrame,
13-
LoadModelFrame, TestMNISTFrame,
12+
from school_project.frames import (HyperParameterFrame, TrainingFrame,
13+
LoadModelFrame, TestMNISTFrame,
1414
TestCatRecognitionFrame, TestXORFrame)
1515

1616
class SchoolProjectFrame(tk.Frame):
1717
"""Main frame of school project."""
1818
def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
1919
"""Initialise school project pages.
20-
20+
2121
Args:
2222
root (tk.Tk): the widget object that contains this widget.
2323
width (int): the pixel width of the frame.
2424
height (int): the pixel height of the frame.
2525
bg (str): the hex value or name of the frame's background colour.
2626
Raises:
2727
TypeError: if root, width or height are not of the correct type.
28-
28+
2929
"""
3030
super().__init__(master=root, width=width, height=height, bg=bg)
3131
self.root = root.title("School Project")
3232
self.WIDTH = width
3333
self.HEIGHT = height
3434
self.BG = bg
35-
35+
3636
# Setup school project frame variables
3737
self.hyper_parameter_frame: HyperParameterFrame
3838
self.training_frame: TrainingFrame
@@ -118,14 +118,14 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
118118
command=self.enter_home_frame)
119119

120120
# Setup home frame
121-
self.home_frame = tk.Frame(master=self,
122-
width=self.WIDTH,
121+
self.home_frame = tk.Frame(master=self,
122+
width=self.WIDTH,
123123
height=self.HEIGHT,
124124
bg=self.BG)
125125
self.title_label = tk.Label(
126126
master=self.home_frame,
127127
bg=self.BG,
128-
font=('Arial', 20),
128+
font=('Arial', 20),
129129
text="A-level Computer Science NEA Programming Project"
130130
)
131131
self.about_label = tk.Label(
@@ -162,7 +162,7 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
162162
font=tkf.Font(size=12),
163163
text="Load Model",
164164
command=self.enter_load_model_frame)
165-
165+
166166
# Grid home frame widgets
167167
self.title_label.grid(row=0, column=0, columnspan=4, pady=(10,0))
168168
self.about_label.grid(row=1, column=0, columnspan=4, pady=(10,50))
@@ -172,19 +172,19 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
172172
self.load_model_button.grid(row=4, column=2)
173173

174174
self.home_frame.pack()
175-
175+
176176
# Setup frame attributes
177177
self.grid_propagate(flag=False)
178178
self.pack_propagate(flag=False)
179179

180180
@staticmethod
181181
def setup_database() -> tuple[sqlite3.Connection, sqlite3.Cursor]:
182-
"""Create a connection to the pretrained_models database file and
182+
"""Create a connection to the pretrained_models database file and
183183
setup base table if needed.
184-
184+
185185
Returns:
186186
a tuple of the database connection and the cursor for it.
187-
187+
188188
"""
189189
connection = sqlite3.connect(
190190
database='school_project/saved_models.db'
@@ -232,14 +232,14 @@ def enter_load_model_frame(self) -> None:
232232
)
233233
self.load_model_frame.pack()
234234

235-
# Don't give option to test loaded model if no models have been saved
235+
# Don't give option to test loaded model if no models have been saved
236236
# for the dataset.
237237
if len(self.load_model_frame.model_options) > 0:
238238
self.test_loaded_model_button.pack()
239239
self.delete_loaded_model_button.pack(pady=(5,0))
240-
240+
241241
self.exit_load_model_frame_button.pack(pady=(5,0))
242-
242+
243243
def exit_hyper_parameter_frame(self) -> None:
244244
"""Unpack hyper-parameter frame and pack home frame."""
245245
self.hyper_parameter_frame.pack_forget()
@@ -269,7 +269,7 @@ def enter_training_frame(self) -> None:
269269
self.exit_hyper_parameter_frame_button.pack_forget()
270270
self.training_frame = TrainingFrame(
271271
root=self,
272-
width=self.WIDTH,
272+
width=self.WIDTH,
273273
height=self.HEIGHT,
274274
bg=self.BG,
275275
model=self.model,
@@ -282,7 +282,7 @@ def enter_training_frame(self) -> None:
282282
def manage_training(self, train_thread: threading.Thread) -> None:
283283
"""Wait for model training thread to finish,
284284
then plot training losses on training frame.
285-
285+
286286
Args:
287287
train_thread (threading.Thread):
288288
the thread running the model's train() method.
@@ -308,7 +308,7 @@ def test_created_model(self) -> None:
308308
self.training_frame.pack_forget()
309309
self.test_created_model_button.pack_forget()
310310
if self.hyper_parameter_frame.dataset == "MNIST":
311-
self.test_frame = TestMNISTFrame(
311+
self.test_frame = TestMNISTFrame(
312312
root=self,
313313
width=self.WIDTH,
314314
height=self.HEIGHT,
@@ -319,7 +319,7 @@ def test_created_model(self) -> None:
319319
elif self.hyper_parameter_frame.dataset == "Cat Recognition":
320320
self.test_frame = TestCatRecognitionFrame(
321321
root=self,
322-
width=self.WIDTH,
322+
width=self.WIDTH,
323323
height=self.HEIGHT,
324324
bg=self.BG,
325325
use_gpu=self.hyper_parameter_frame.use_gpu,
@@ -335,7 +335,7 @@ def test_created_model(self) -> None:
335335
self.manage_testing(test_thread=self.test_frame.test_thread)
336336

337337
def test_loaded_model(self) -> None:
338-
"""Load saved model from load model frame, unpack load model frame,
338+
"""Load saved model from load model frame, unpack load model frame,
339339
pack test frame for the dataset and begin managing the test thread."""
340340
self.saving_model = False
341341
try:
@@ -347,7 +347,7 @@ def test_loaded_model(self) -> None:
347347
self.delete_loaded_model_button.pack_forget()
348348
self.exit_load_model_frame_button.pack_forget()
349349
if self.load_model_frame.dataset == "MNIST":
350-
self.test_frame = TestMNISTFrame(
350+
self.test_frame = TestMNISTFrame(
351351
root=self,
352352
width=self.WIDTH,
353353
height=self.HEIGHT,
@@ -376,13 +376,13 @@ def test_loaded_model(self) -> None:
376376
def manage_testing(self, test_thread: threading.Thread) -> None:
377377
"""Wait for model test thread to finish,
378378
then plot results on test frame.
379-
379+
380380
Args:
381381
test_thread (threading.Thread):
382382
the thread running the model's predict() method.
383383
Raises:
384384
TypeError: if test_thread is not of type threading.Thread.
385-
385+
386386
"""
387387
if not test_thread.is_alive():
388388
self.test_frame.plot_results(model=self.model)
@@ -395,7 +395,7 @@ def manage_testing(self, test_thread: threading.Thread) -> None:
395395
self.after(1_000, self.manage_testing, test_thread)
396396

397397
def save_model(self) -> None:
398-
"""Save the model, save the model information to the database, then
398+
"""Save the model, save the model information to the database, then
399399
enter the home frame."""
400400
model_name = self.save_model_name_entry.get()
401401

@@ -480,19 +480,19 @@ def enter_home_frame(self) -> None:
480480
self.home_frame.pack()
481481
summary_tracker.create_summary() # BUG: Object summary seems to reduce
482482
# memory leak greatly
483-
483+
484484
def main() -> None:
485485
"""Entrypoint of project."""
486486
root = tk.Tk()
487487
school_project_frame = SchoolProjectFrame(root=root, width=1280,
488488
height=835, bg='white')
489489
school_project_frame.pack(side='top', fill='both', expand=True)
490490
root.mainloop()
491-
491+
492492
# Stop model training when GUI closes
493-
if school_project_frame.model != None:
493+
if school_project_frame.model is not None:
494494
school_project_frame.model.set_running(value=False)
495495

496496
if __name__ == "__main__":
497497
summary_tracker = tracker.SummaryTracker() # Setup object tracker
498-
main()
498+
main()

school_project/frames/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from .load_model import LoadModelFrame
55
from .test_model import TestMNISTFrame, TestCatRecognitionFrame, TestXORFrame
66

7-
__all__ = ['create_model', 'load_model', 'test_model']
7+
__all__ = ['create_model', 'load_model', 'test_model']

school_project/frames/create_model.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
class HyperParameterFrame(tk.Frame):
1313
"""Frame for hyper-parameter page."""
14-
def __init__(self, root: tk.Tk, width: int,
14+
def __init__(self, root: tk.Tk, width: int,
1515
height: int, bg: str, dataset: str) -> None:
1616
"""Initialise hyper-parameter frame widgets.
17-
17+
1818
Args:
1919
root (tk.Tk): the widget object that contains this widget.
2020
width (int): the pixel width of the frame.
@@ -24,21 +24,21 @@ def __init__(self, root: tk.Tk, width: int,
2424
('MNIST', 'Cat Recognition' or 'XOR')
2525
Raises:
2626
TypeError: if root, width or height are not of the correct type.
27-
27+
2828
"""
2929
super().__init__(master=root, width=width, height=height, bg=bg)
3030
self.root = root
3131
self.WIDTH = width
3232
self.HEIGHT = height
3333
self.BG = bg
34-
34+
3535
# Setup hyper-parameter frame variables
3636
self.dataset = dataset
3737
self.use_gpu: bool
3838
self.default_hyper_parameters = self.load_default_hyper_parameters(
3939
dataset=dataset
4040
)
41-
41+
4242
# Setup widgets
4343
self.title_label = tk.Label(master=self,
4444
bg=self.BG,
@@ -115,7 +115,7 @@ def __init__(self, root: tk.Tk, width: int,
115115
self.model_status_label = tk.Label(master=self,
116116
bg=self.BG,
117117
font=('Arial', 15))
118-
118+
119119
# Pack widgets
120120
self.title_label.grid(row=0, column=0, columnspan=3)
121121
self.about_label.grid(row=1, column=0, columnspan=3)
@@ -129,13 +129,13 @@ def __init__(self, root: tk.Tk, width: int,
129129
self.use_gpu_check_button.grid(row=3, column=2, pady=(30, 0))
130130
self.model_status_label.grid(row=5, column=0,
131131
columnspan=3, pady=50)
132-
132+
133133
def load_default_hyper_parameters(self, dataset: str) -> dict[
134-
str,
134+
str,
135135
str | int | list[int] | float
136136
]:
137137
"""Load the dataset's default hyper-parameters from the json file.
138-
138+
139139
Args:
140140
dataset (str): the name of the dataset to load hyper-parameters
141141
for. ('MNIST', 'Cat Recognition' or 'XOR')
@@ -144,7 +144,7 @@ def load_default_hyper_parameters(self, dataset: str) -> dict[
144144
"""
145145
with open('school_project/frames/hyper-parameter-defaults.json') as f:
146146
return json.load(f)[dataset]
147-
147+
148148
def create_model(self) -> object:
149149
"""Create and return a Model using the hyper-parameters set.
150150
@@ -171,10 +171,12 @@ def create_model(self) -> object:
171171
from school_project.models.cpu.cat_recognition import CatRecognitionModel as Model
172172
elif self.dataset == "XOR":
173173
from school_project.models.cpu.xor import XORModel as Model
174-
model = Model(hidden_layers_shape = [int(neuron_count) for neuron_count in hidden_layers_shape_input],
175-
train_dataset_size = self.train_dataset_size_scale.get(),
176-
learning_rate = self.learning_rate_scale.get(),
177-
use_relu = self.use_relu_check_button_var.get())
174+
model = Model(
175+
hidden_layers_shape = [int(neuron_count) for neuron_count in hidden_layers_shape_input],
176+
train_dataset_size = self.train_dataset_size_scale.get(),
177+
learning_rate = self.learning_rate_scale.get(),
178+
use_relu = self.use_relu_check_button_var.get()
179+
)
178180
model.create_model_values()
179181

180182
else:
@@ -197,14 +199,14 @@ def create_model(self) -> object:
197199
)
198200
raise ImportError
199201
return model
200-
202+
201203
class TrainingFrame(tk.Frame):
202204
"""Frame for training page."""
203205
def __init__(self, root: tk.Tk, width: int,
204206
height: int, bg: str,
205207
model: object, epoch_count: int) -> None:
206208
"""Initialise training frame widgets.
207-
209+
208210
Args:
209211
root (tk.Tk): the widget object that contains this widget.
210212
width (int): the pixel width of the frame.
@@ -214,14 +216,14 @@ def __init__(self, root: tk.Tk, width: int,
214216
epoch_count (int): the number of training epochs.
215217
Raises:
216218
TypeError: if root, width or height are not of the correct type.
217-
219+
218220
"""
219221
super().__init__(master=root, width=width, height=height, bg=bg)
220222
self.root = root
221223
self.WIDTH = width
222224
self.HEIGHT = height
223225
self.BG = bg
224-
226+
225227
# Setup widgets
226228
self.model_status_label = tk.Label(master=self,
227229
bg=self.BG,
@@ -234,11 +236,11 @@ def __init__(self, root: tk.Tk, width: int,
234236
figure=self.loss_figure,
235237
master=self
236238
)
237-
239+
238240
# Pack widgets
239241
self.model_status_label.pack(pady=(30,0))
240242
self.training_progress_label.pack(pady=30)
241-
243+
242244
# Start training thread
243245
self.model_status_label.configure(
244246
text="Training weights and biases...",
@@ -252,10 +254,10 @@ def __init__(self, root: tk.Tk, width: int,
252254

253255
def plot_losses(self, model: object) -> None:
254256
"""Plot losses of Model training.
255-
257+
256258
Args:
257259
model (object): the Model object thats been trained.
258-
260+
259261
"""
260262
self.model_status_label.configure(
261263
text=f"Weights and biases trained in {model.training_time}s",
@@ -267,4 +269,4 @@ def plot_losses(self, model: object) -> None:
267269
graph.set_xlabel("Epochs")
268270
graph.set_ylabel("Loss Value")
269271
graph.plot(np.squeeze(model.train_losses))
270-
self.loss_canvas.get_tk_widget().pack()
272+
self.loss_canvas.get_tk_widget().pack()

0 commit comments

Comments
 (0)