Skip to content

Commit

Permalink
Format code and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed Jul 27, 2024
1 parent 20875de commit 8c01ae7
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 67 deletions.
1 change: 0 additions & 1 deletion conf/experiment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
Configurations for the experiments and config groups, using hydra-zen.
"""
Expand Down
6 changes: 3 additions & 3 deletions dataset/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def _load(
progress: Progress,
job_id: TaskID,
) -> Tuple[Union[dict, list, Tensor], Union[dict, list, Tensor]]:
len = 3 if self._tiny else 20
progress.update(job_id, total=len)
for _ in range(len):
length = 3 if self._tiny else 20
progress.update(job_id, total=length)
for _ in range(length):
progress.advance(job_id)
sleep(0.001 if self._tiny else 0.1)
return torch.rand(10000, self._img_dim, self._img_dim), torch.rand(10000, 8)
Expand Down
1 change: 1 addition & 0 deletions dataset/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
)
self._split = split
self._lazy = scd_lazy # TODO: Implement eager caching (rn the default is lazy)
# TODO: Refactor and reduce cyclomatic complexity
argnames = inspect.getfullargspec(self.__class__.__init__).args
found = False
frame: FrameType | None = inspect.currentframe()
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ exclude = [
line-length = 88
indent-width = 4

# Assume Python 3.8
target-version = "py38"
# Assume Python 3.11
target-version = "py311"

[tool.ruff.lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
Expand Down Expand Up @@ -69,7 +69,7 @@ line-ending = "auto"
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = false
docstring-code-format = true

# Set the line length limit used when formatting code snippets in
# docstrings.
Expand Down
15 changes: 9 additions & 6 deletions src/base_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


console = Console()
print = console.print
print = console.print # skipcq: PYL-W0603


class BaseTester(BaseTrainer):
Expand All @@ -55,7 +55,7 @@ def __init__(
_loss = training_loss
self._gui = gui
global print # skipcq: PYL-W0603
print = self._gui.print
print = self._gui.print # skipcq: PYL-W0603
self._run_name = run_name
self._model = model
if model_ckpt_path is None:
Expand Down Expand Up @@ -83,7 +83,8 @@ def _test_iteration(
batch: Union[Tuple, List, Tensor],
) -> Tuple[Tensor, Dict[str, Tensor]]:
"""Evaluation procedure for one batch. We want to keep the code DRY and avoid
making mistakes, so this code calls the BaseTrainer._train_val_iteration() method.
making mistakes, so this code calls the BaseTrainer._train_val_iteration()
method.
Args:
batch: The batch to process.
Returns:
Expand All @@ -99,15 +100,16 @@ async def test(
) -> None:
"""Computes the average loss on the test set.
Args:
visualize_every (int, optional): Visualize the model predictions every n batches.
visualize_every (int, optional): Visualize the model predictions every n
batches.
Defaults to 0 (no visualization).
"""
_ = kwargs
test_loss: MeanMetric = MeanMetric()
test_metrics: Dict[str, MeanMetric] = defaultdict(MeanMetric)
self._model.eval()
print(Text(f"[*] Testing {self._run_name}", style="bold green"))
""" ==================== Training loop for one epoch ==================== """
# ==================== Training loop for one epoch ====================
pbar, update_loss_hook = self._gui.track_testing(
self._data_loader, total=len(self._data_loader)
)
Expand All @@ -124,7 +126,8 @@ async def test(
if visualize_every > 0 and (i + 1) % visualize_every == 0:
self._visualize(batch, i)

# TODO: Report metrics in a special panel? Then hang the GUI until the user is done.
# TODO: Report metrics in a special panel? Then hang the GUI until the user is
# done.
print("=" * 81)
print("==" + " " * 31 + " Test results " + " " * 31 + "==")
print("=" * 81)
Expand Down
40 changes: 22 additions & 18 deletions src/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from utils.training import visualize_model_predictions

console = Console()
print = console.print
print = console.print # skipcq: PYL-W0603


class BaseTrainer:
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
self._n_ctrl_c = 0
self._gui = gui
global print # skipcq: PYL-W0603
print = self._gui.print
print = self._gui.print # skipcq: PYL-W0603
if model_ckpt_path is not None:
self._load_checkpoint(model_ckpt_path)
signal.signal(signal.SIGINT, self._terminator)
Expand All @@ -104,8 +104,9 @@ def _train_val_iteration(
epoch: int,
validation: bool = False,
) -> Tuple[Tensor, Dict[str, Tensor]]:
"""Training or validation procedure for one batch. We want to keep the code DRY and avoid
making mistakes, so write this code only once at the cost of many function calls!
"""Training or validation procedure for one batch. We want to keep the code DRY
and avoid making mistakes, so write this code only once at the cost of many
function calls!
Args:
batch: The batch to process.
Returns:
Expand Down Expand Up @@ -136,7 +137,7 @@ def _train_epoch(
epoch_loss: MeanMetric = MeanMetric()
epoch_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric)
has_visualized = 0
""" ==================== Training loop for one epoch ==================== """
# ==================== Training loop for one epoch ====================
pbar, update_loss_hook = self._gui.track_training(
self._train_loader,
total=len(self._train_loader),
Expand Down Expand Up @@ -188,7 +189,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float:
float: Average validation loss for the epoch.
"""
has_visualized = 0
""" ==================== Validation loop for one epoch ==================== """
# ==================== Validation loop for one epoch ====================
with torch.no_grad():
val_loss: MeanMetric = MeanMetric()
val_loss_components: Dict[str, MeanMetric] = defaultdict(MeanMetric)
Expand All @@ -212,7 +213,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float:
for k, v in loss_components.items():
val_loss_components[k].update(v.item())
update_loss_hook(val_loss.compute())
""" ==================== Visualization ==================== """
# ==================== Visualization ====================
if (
visualize
and has_visualized < self._viz_n_samples
Expand All @@ -235,8 +236,8 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float:
},
step=epoch,
)
# Set minimize_metric to a key in val_loss_components if you wish to minimize
# a specific metric instead of the validation loss:
# Set minimize_metric to a key in val_loss_components if you wish to
# minimize a specific metric instead of the validation loss:
self._model_saver(
epoch,
mean_val_loss,
Expand Down Expand Up @@ -268,7 +269,7 @@ async def train(
)
self._viz_n_samples = visualize_n_samples
self._gui.set_start_epoch(self._epoch)
""" ==================== Training loop ==================== """
# ==================== Training loop ====================
last_val_loss = float("inf")
for epoch in range(self._epoch, epochs):
print(f"Epoch: {epoch}")
Expand All @@ -294,7 +295,7 @@ async def train(
last_val_loss = val_loss
if self._scheduler is not None:
await asyncio.to_thread(self._scheduler.step)
""" ==================== Plotting ==================== """
# ==================== Plotting ====================
self._gui.plot(epoch, train_loss, last_val_loss) # , self._model_saver)
await asyncio.to_thread(
self._save_checkpoint,
Expand Down Expand Up @@ -333,18 +334,19 @@ def _save_checkpoint(self, val_loss: float, ckpt_path: str, **kwargs) -> None:
)

def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None:
"""Loads the model and optimizer state from a checkpoint file. This method should remain in
this class because it should be extendable in classes inheriting from this class, instead
of being overwritten/modified. That would be a source of bugs and a bad practice.
"""Loads the model and optimizer state from a checkpoint file. This method
should remain in this class because it should be extendable in classes
inheriting from this class, instead of being overwritten/modified. That would be
a source of bugs and a bad practice.
Args:
ckpt_path (str): The path to the checkpoint file.
model_only (bool): If True, only the model is loaded (useful for BaseTester).
model_only (bool): If True, only the model is loaded (useful for
BaseTester).
Returns:
None
"""
print(f"[*] Restoring from checkpoint: {ckpt_path}")
ckpt = torch.load(ckpt_path)
# If the model was optimized with torch.optimize() we need to remove the "_orig_mod"
# prefix:
if "_orig_mod" in list(ckpt["model_ckpt"].keys())[0]:
ckpt["model_ckpt"] = {
Expand All @@ -355,7 +357,8 @@ def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None:
except Exception:
if project_conf.PARTIALLY_LOAD_MODEL_IF_NO_FULL_MATCH:
print(
"[!] Partially loading model weights (no full match between model and checkpoint)"
"[!] Partially loading model weights "
+ "(no full match between model and checkpoint)"
)
self._model.load_state_dict(ckpt["model_ckpt"], strict=False)
if not model_only:
Expand All @@ -378,7 +381,8 @@ def _terminator(self, sig, frame):
and self._n_ctrl_c == 0
):
print(
f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}. Press Ctrl+C again to abort."
f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}."
+ " Press Ctrl+C again to abort."
)
self._n_ctrl_c += 1
elif (
Expand Down
2 changes: 1 addition & 1 deletion utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

# import sys
import traceback
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional

# import IPython
import numpy as np
Expand Down
65 changes: 30 additions & 35 deletions utils/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ def __init__(
classes: The CSS classes of the plotter widget.
disabled: Whether the plotter widget is disabled or not.
"""
super().__init__(name=name, id=id, classes=classes, disabled=disabled)
super().__init__(
name=name,
id=id,
classes=classes,
disabled=disabled,
)
self._title = title
self._log_scale = use_log_scale
self._train_losses: list[float] = []
Expand Down Expand Up @@ -171,8 +176,9 @@ def __init__(self):

def update_loss_hook(self, loss: float) -> None:
"""Update the loss value in the progress bar."""
# TODO: min_val_loss during validation, val_loss during training. Ideally the
# second parameter would be super flexible (use a dict then).
# TODO: min_val_loss during validation, val_loss during training.
# Ideally the second parameter would be super flexible (use a dict
# then).
self._loss = loss

class SeqWrapper(abc.Iterator, LossHook):
Expand Down Expand Up @@ -235,24 +241,11 @@ def reset_hook():
plabel.update(self.DESCRIPTIONS[Task.IDLE])

wrapper = None
update_p, reset_p = (
partial(update_hook),
partial(reset_hook),
)
update_p, reset_p = partial(update_hook), partial(reset_hook)
if isinstance(iterable, abc.Sequence):
wrapper = SeqWrapper(
iterable,
total,
update_p,
reset_p,
)
wrapper = SeqWrapper(iterable, total, update_p, reset_p)
elif isinstance(iterable, (abc.Iterator, DataLoader)):
wrapper = IteratorWrapper(
iterable,
total,
update_p,
reset_p,
)
wrapper = IteratorWrapper(iterable, total, update_p, reset_p)
else:
raise ValueError(
f"iterable must be a Sequence or an Iterator, got {type(iterable)}"
Expand All @@ -264,7 +257,9 @@ def reset_hook():


class GUI(App):
"""A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework."""
"""
A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.
"""

TITLE = "Matchbox TUI"
CSS_PATH = "style.css"
Expand Down Expand Up @@ -327,11 +322,7 @@ def print(self, message: Any):
if isinstance(message, (RenderableType, str)):
logger.write(
Group(
Text(
datetime.now().strftime("[%H:%M] "),
style="dim cyan",
end="",
),
Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""),
message,
),
)
Expand Down Expand Up @@ -379,28 +370,31 @@ def print(self, message: Any):
)

def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]:
"""Return an iterable that tracks the progress of the training process, and a progress bar
hook to update the loss value at each iteration."""
"""Return an iterable that tracks the progress of the training process, and a
progress bar hook to update the loss value at each iteration."""
return self.query_one(DatasetProgressBar).track_iterable(
iterable, Task.TRAINING, total
)

def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]:
"""Return an iterable that tracks the progress of the validation process, and a progress bar
hook to update the loss value at each iteration."""
"""Return an iterable that tracks the progress of the validation process, and a
progress bar hook to update the loss value at each iteration."""
return self.query_one(DatasetProgressBar).track_iterable(
iterable, Task.VALIDATION, total
)

def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]:
"""Return an iterable that tracks the progress of the testing process, and a progress bar
hook to update the loss value at each iteration."""
"""Return an iterable that tracks the progress of the testing process, and a
progress bar hook to update the loss value at each iteration."""
return self.query_one(DatasetProgressBar).track_iterable(
iterable, Task.TESTING, total
)

def plot(
self, epoch: int, train_loss: float, val_loss: Optional[float] = None
self,
epoch: int,
train_loss: float,
val_loss: Optional[float] = None,
) -> None:
"""Plot the training and validation losses for the current epoch."""
self.query_one(PlotterWidget).loading = False
Expand Down Expand Up @@ -432,8 +426,8 @@ async def run_my_app():
await asyncio.sleep(0.1)
await asyncio.sleep(2)
mnist = MNIST(root="data", train=False, download=True, transform=to_tensor)
# Somehow, the dataloader will crash if it's not forked when using multiprocessing along with
# Textual.
# Somehow, the dataloader will crash if it's not forked when using multiprocessing
# along with Textual.
mp.set_start_method("fork")
dataloader = DataLoader(mnist, 32, shuffle=True, num_workers=2)
pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader))
Expand All @@ -444,7 +438,8 @@ async def run_my_app():
update_progress_loss(random())
gui.plot(epoch=i, train_loss=random(), val_loss=random())
gui.print(
f"[{i+1}/{len(dataloader)}]: We can also iterate over PyTorch dataloaders!"
f"[{i+1}/{len(dataloader)}]: "
+ "We can also iterate over PyTorch dataloaders!"
)
if i == 0:
gui.print(batch)
Expand Down

0 comments on commit 8c01ae7

Please sign in to comment.