Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring #8

Merged
merged 6 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions bootstrap/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,15 @@ def parallelize_model(model: torch.nn.Module) -> torch.nn.Module:


def make_optimizer(
optimizer_partial: Partial[torch.optim.optimizer.Optimizer], model: torch.nn.Module
) -> torch.optim.optimizer.Optimizer:
optimizer_partial: Partial[torch.optim.Optimizer], # pyright: ignore
model: torch.nn.Module,
) -> torch.optim.Optimizer: # pyright: ignore
return optimizer_partial(model.parameters())


def make_scheduler(
scheduler_partial: Partial[torch.optim.lr_scheduler.LRScheduler],
optimizer: torch.optim.optimizer.Optimizer,
optimizer: torch.optim.Optimizer, # pyright: ignore
epochs: int,
) -> torch.optim.lr_scheduler.LRScheduler:
scheduler = scheduler_partial(
Expand Down
20 changes: 10 additions & 10 deletions bootstrap/launch_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@
make_training_loss,
parallelize_model,
)
from bootstrap.tui.training_ui import TrainingUI
from conf import project as project_conf
from src.base_tester import BaseTester
from src.base_trainer import BaseTrainer
from utils import load_model_ckpt, to_cuda_
from utils.gui import GUI

console = Console()

Expand Down Expand Up @@ -108,7 +108,7 @@ def init_wandb(
def launch_experiment(
run, # type: ignore
data_loader: Partial[DataLoader[Any]],
optimizer: Partial[torch.optim.optimizer.Optimizer],
optimizer: Partial[torch.optim.Optimizer], # pyright: ignore
scheduler: Partial[torch.optim.lr_scheduler.LRScheduler],
trainer: Partial[BaseTrainer],
tester: Partial[BaseTester],
Expand Down Expand Up @@ -161,20 +161,20 @@ def launch_experiment(
sleep(1)

async def launch_with_async_gui():
gui = GUI(run_name, project_conf.LOG_SCALE_PLOT)
task = asyncio.create_task(gui.run_async())
while not gui.is_running:
tui = TrainingUI(run_name, project_conf.LOG_SCALE_PLOT)
task = asyncio.create_task(tui.run_async())
while not tui.is_running:
await asyncio.sleep(0.01) # Wait for the app to start up
model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode)
common_args = dict(
run_name=run_name,
model=model_inst,
model_ckpt_path=model_ckpt_path,
training_loss=training_loss_inst,
gui=gui,
tui=tui,
)
if run.training_mode:
gui.print("Training started!")
tui.print("Training started!")
if training_loss_inst is None:
raise ValueError("training_loss must be defined in training mode!")
if val_loader_inst is None or train_loader_inst is None:
Expand All @@ -195,9 +195,9 @@ async def launch_with_async_gui():
visualize_train_every=run.viz_train_every,
visualize_n_samples=run.viz_num_samples,
)
gui.print("Training finished!")
tui.print("Training finished!")
else:
gui.print("Testing started!")
tui.print("Testing started!")
if test_loader_inst is None:
raise ValueError("test_loader must be defined in testing mode!")
await tester(
Expand All @@ -207,7 +207,7 @@ async def launch_with_async_gui():
visualize_every=run.viz_every,
**asdict(run),
)
gui.print("Testing finished!")
tui.print("Testing finished!")
_ = await task

asyncio.run(launch_with_async_gui())
19 changes: 19 additions & 0 deletions bootstrap/tui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict


class Task(Enum):
    """Identifiers for the kind of work the TUI is tracking.

    ``TRAINING``, ``VALIDATION`` and ``TESTING`` are the tags handed to the
    progress bar when an iterable is tracked; ``IDLE`` is presumably the
    "nothing running" state (its use is not visible from this module).
    """

    IDLE = -1
    TRAINING = 0
    VALIDATION = 1
    TESTING = 2


@dataclass
class Plot_BestModel:
    """Record describing a "best model" marker shown in the plotter widget."""

    # Epoch at which the best model was recorded.
    epoch: int
    # Loss value associated with that model.
    loss: float
    # Metric values keyed by metric name.
    metrics: Dict[str, float]
File renamed without changes.
229 changes: 229 additions & 0 deletions bootstrap/tui/training_ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import asyncio
from datetime import datetime
from itertools import cycle
from random import random
from typing import (
Any,
Callable,
Iterable,
Optional,
Tuple,
)

import numpy as np
import torch
import torch.multiprocessing as mp
from rich.console import Group, RenderableType
from rich.pretty import Pretty
from rich.text import Text
from textual.app import App, ComposeResult
from textual.reactive import var
from textual.widgets import (
Footer,
Header,
Placeholder,
RichLog,
)
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor

from bootstrap.tui import Plot_BestModel, Task
from bootstrap.tui.widgets.plotting import PlotterWidget
from bootstrap.tui.widgets.progress import DatasetProgressBar


class TrainingUI(App):
    """
    A Textual app to serve as *useful* GUI/TUI for my pytorch-based micro framework.

    The screen is composed of a header, a loss plotter, a rich log, a dataset
    progress bar, a placeholder and a footer. The public methods (``print``,
    ``track_*``, ``plot``, ``set_start_epoch``) forward to those widgets.
    """

    TITLE = "Matchbox TUI"
    CSS_PATH = "styles/training_ui.css"

    # (key, action, description) tuples understood by Textual.
    # NOTE(review): no `action_suspend_progress` handler is visible in this
    # class for the "ctrl+z" binding -- confirm it is provided elsewhere.
    BINDINGS = [
        ("q", "quit", "Quit"),
        ("d", "toggle_dark", "Toggle dark mode"),
        ("p", "marker", "Change plotter style"),
        ("ctrl+z", "suspend_progress"),
    ]

    # Plot marker identifiers mapped to the label shown as the app sub-title.
    MARKERS = {
        "dot": "Dot",
        "hd": "High Definition",
        "fhd": "Higher Definition",
        "braille": "Braille",
        "sd": "Standard Definition",
    }

    # Currently selected marker; changes are handled by `watch_marker`.
    marker: var[str] = var("hd")

    def __init__(self, run_name: str, log_scale: bool) -> None:
        """Initialise the application.

        Args:
            run_name: Name of the run, shown in the plot title.
            log_scale: Whether the plotter widget should use a log scale.
        """
        super().__init__()
        self._markers = cycle(self.MARKERS.keys())
        self._log_scale = log_scale
        self.run_name = run_name

    def compose(self) -> ComposeResult:
        """Yield the widgets that make up the screen, top to bottom."""
        yield Header()
        yield PlotterWidget(
            title=f"Training curves for {self.run_name}",
            use_log_scale=self._log_scale,
            classes="box",
        )
        yield RichLog(
            highlight=True, markup=True, wrap=True, id="logger", classes="box"
        )
        yield DatasetProgressBar()
        yield Placeholder(classes="box")
        yield Footer()

    def on_mount(self):
        """Mark the plotter as loading; `plot` clears the flag on first use."""
        self.query_one(PlotterWidget).loading = True

    def action_toggle_dark(self) -> None:
        """Toggle between dark and light mode."""
        self.dark = not self.dark  # skipcq: PYL-W0201

    def watch_marker(self) -> None:
        """React to the marker type being changed."""
        self.sub_title = self.MARKERS[self.marker]  # skipcq: PYL-W0201
        self.query_one(PlotterWidget).marker = self.marker

    def action_marker(self) -> None:
        """Cycle to the next marker type."""
        self.marker = next(self._markers)  # skipcq: PTC-W0063

    @staticmethod
    def _timestamp_text() -> Text:
        """Return the dim-cyan ``[HH:MM] `` prefix used for every log entry."""
        return Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end="")

    def print(self, message: Any) -> None:
        """Write a timestamped entry to the log widget.

        Strings and Rich renderables are logged as-is. Any other object is
        wrapped in ``rich.pretty.Pretty`` and prefixed with its type. If that
        wrapping fails, the raw object is written, and if that fails too, the
        error itself is logged instead of being raised.

        Args:
            message: The object to log.
        """
        logger: RichLog = self.query_one(RichLog)

        if isinstance(message, (RenderableType, str)):
            logger.write(Group(self._timestamp_text(), message))
            return

        # Best effort: never let a logging call crash the training loop.
        try:
            pretty = Pretty(message)
        except Exception:
            pretty = None

        if pretty is not None:
            logger.write(
                Group(
                    self._timestamp_text(),
                    Text(str(type(message)) + " ", style="italic blue", end=""),
                    pretty,
                )
            )
            return

        try:
            logger.write(Group(self._timestamp_text(), message))
        except Exception as e:
            logger.write(
                Group(
                    self._timestamp_text(),
                    Text("Logging error: ", style="bold red"),
                    Text(str(e), style="bold red"),
                )
            )

    def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]:
        """Return an iterable that tracks the progress of the training process, and a
        progress bar hook to update the loss value at each iteration."""
        return self.query_one(DatasetProgressBar).track_iterable(
            iterable, Task.TRAINING, total
        )

    def track_validation(self, iterable, total: int) -> Tuple[Iterable, Callable]:
        """Return an iterable that tracks the progress of the validation process, and a
        progress bar hook to update the loss value at each iteration."""
        return self.query_one(DatasetProgressBar).track_iterable(
            iterable, Task.VALIDATION, total
        )

    def track_testing(self, iterable, total: int) -> Tuple[Iterable, Callable]:
        """Return an iterable that tracks the progress of the testing process, and a
        progress bar hook to update the loss value at each iteration."""
        return self.query_one(DatasetProgressBar).track_iterable(
            iterable, Task.TESTING, total
        )

    def plot(
        self,
        epoch: int,
        train_loss: float,
        val_loss: Optional[float] = None,
        best_model: Optional[Plot_BestModel] = None,
    ) -> None:
        """Plot the training and validation losses for the current epoch.

        Args:
            epoch: Epoch the values belong to.
            train_loss: Training loss for that epoch.
            val_loss: Validation loss for that epoch, if available.
            best_model: Optional best-model marker forwarded to the plotter.
        """
        plotter = self.query_one(PlotterWidget)
        plotter.loading = False  # Data is arriving: drop the loading state.
        plotter.update(epoch, train_loss, val_loss, best_model)

    def set_start_epoch(self, start_epoch: int) -> None:
        """Set the starting epoch for the plotter widget."""
        self.query_one(PlotterWidget).set_start_epoch(start_epoch)


async def run_my_app():
"""Demo driver: start a TrainingUI and exercise its logging, progress and plot APIs."""
# NOTE(review): leading indentation is missing in this view of the file, so
# the exact nesting of the loop bodies below cannot be determined from here.
gui = TrainingUI("test-run", log_scale=False)
# Run the app concurrently so this coroutine can keep feeding it.
task = asyncio.create_task(gui.run_async())
while not gui.is_running:
await asyncio.sleep(0.01)  # Wait for the app to start up
gui.print("Hello, World!")
await asyncio.sleep(2)
gui.print(Text("Let's log some tensors :)", style="bold magenta"))
await asyncio.sleep(0.5)
gui.print(torch.rand(2, 4))
await asyncio.sleep(2)
gui.print(Text("How about some numpy arrays?!", style="italic green"))
await asyncio.sleep(1)
gui.print(np.random.rand(3, 3))
# The hook returned here is not called before it is rebound by the
# track_validation() call further down.
pbar, update_progress_loss = gui.track_training(range(10), 10)
for i, e in enumerate(pbar):
gui.print(f"[{i+1}/10]: We can iterate over iterables")
gui.print(e)
await asyncio.sleep(0.1)
await asyncio.sleep(2)
mnist = MNIST(root="data", train=False, download=True, transform=to_tensor)
# Somehow, the dataloader will crash if it's not forked when using multiprocessing
# along with Textual.
# NOTE(review): the "fork" start method is not available on every platform
# -- confirm the target OS.
mp.set_start_method("fork")
dataloader = DataLoader(mnist, 32, shuffle=True, num_workers=2)
pbar, update_progress_loss = gui.track_validation(dataloader, len(dataloader))
for i, batch in enumerate(pbar):
await asyncio.sleep(0.01)
if i % 10 == 0:
gui.print(batch)
# Feed random values to the progress-bar hook and the plotter.
update_progress_loss(random())
gui.plot(epoch=i, train_loss=random(), val_loss=random())
gui.print(
f"[{i+1}/{len(dataloader)}]: "
+ "We can also iterate over PyTorch dataloaders!"
)
if i == 0:
gui.print(batch)
gui.print("Goodbye, world!")
# Wait for the app task started above to finish.
_ = await task


# Run the demo when this module is executed as a script.
if __name__ == "__main__":
asyncio.run(run_my_app())
Loading
Loading