Skip to content

Commit

Permalink
Merge pull request #7 from DubiousCactus/ui_rehauling
Browse files Browse the repository at this point in the history
Framework rehauling
  • Loading branch information
DubiousCactus authored Jul 27, 2024
2 parents ac2f113 + 7eadbe0 commit 3a34745
Show file tree
Hide file tree
Showing 22 changed files with 1,719 additions and 612 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff pytest pyright types-PyYAML types-tqdm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install --upgrade ruff pytest pyright types-PyYAML types-tqdm
pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with ruff
run: |
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ tags.*
**/*.pickle
models/
FIGURES/
.mypy_cache/
.ruff_cache/
utils/data/MNIST/raw
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.7
rev: v0.5.5
hooks:
# Run the linter.
- id: ruff
Expand Down
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
[![pre-commit](https://img.shields.io/badge/Pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
[![Vim](https://img.shields.io/badge/VIM%20ready!-forestgreen?style=for-the-badge&logo=vim)](https://github.com/DubiousCactus/bells-and-whistles/blob/main/.vimspector.json)

A batteries-included PyTorch template with a terminal display that stays out of your way!

Click on [<kbd>Use this
Expand Down Expand Up @@ -91,6 +91,9 @@ This template writes the necessary boilerplate for you, while staying out of you

```
my-pytorch-project/
bootstrap/
factories.py <-- Factory functions for instantiating models, optimizers, etc.
launch_experiment.py <-- Bootstraps the experiment and launches the training/testing loop
conf/
experiment.py <-- experiment-level configurations
project.py <-- project-level constants
Expand Down Expand Up @@ -118,10 +121,9 @@ my-pytorch-project/
helpers.py <-- high-level utilities
training.py <-- training-related utilities
vendor/
. <-- third-party code goes here
launch_experiment.py <-- Builds the trainer and tester, instantiates all partials, etc.
train.py <-- training entry point (calls launch_experiment)
test.py <-- testing entry point (calls launch_experiment)
. <-- third-party code goes here (github submodules, etc.)
train.py <-- training entry point (calls bootstrap/launch_experiment)
test.py <-- testing entry point (calls bootstrap/launch_experiment)
```

## Setting up
Expand Down
Empty file added bootstrap/__init__.py
Empty file.
138 changes: 138 additions & 0 deletions bootstrap/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
#
# Copyright © 2024 Théo Morales <theo.morales.fr@gmail.com>
#
# Distributed under terms of the MIT license.

"""
All factories.
"""

from typing import Any, Dict, Optional, Tuple

import torch
from hydra_zen import just
from hydra_zen.typing import Partial
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress, TaskID
from torch.utils.data import DataLoader, Dataset

from conf import project as project_conf
from model import TransparentDataParallel

console = Console()


def make_datasets(
training_mode: bool, seed: int, dataset_partial: Partial[Dataset[Any]]
) -> Tuple[Optional[Dataset[Any]], Optional[Dataset[Any]], Optional[Dataset[Any]]]:
datasets: Dict[str, Optional[Dataset[Any]]] = {
"train": None,
"val": None,
"test": None,
}
status = console.status("Loading dataset...", spinner="runner")
progress = Progress(transient=False)
with Live(Panel(Group(status, progress), title="Loading datasets")):
splits = ("train", "val") if training_mode else ("test",)
for split in splits:
status.update(f"Loading {split} dataset...")
job_id: TaskID = progress.add_task(f"Processing {split} split...")
aug = {"augment": False} if split == "test" else {}
datasets[split] = dataset_partial(
split=split, seed=seed, progress=progress, job_id=job_id, **aug
)
return datasets["train"], datasets["val"], datasets["test"]


def make_dataloaders(
data_loader_partial: Partial[DataLoader[Dataset[Any]]],
train_dataset: Optional[Dataset[Any]],
val_dataset: Optional[Dataset[Any]],
test_dataset: Optional[Dataset[Any]],
training_mode: bool,
seed: int,
) -> Tuple[
Optional[DataLoader[Dataset[Any]]],
Optional[DataLoader[Dataset[Any]]],
Optional[DataLoader[Dataset[Any]]],
]:
generator = None
if project_conf.REPRODUCIBLE:
generator = torch.Generator()
generator.manual_seed(seed)

train_loader_inst: Optional[DataLoader[Any]] = None
val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None
test_loader_inst: Optional[DataLoader[Any]] = None
if training_mode:
if train_dataset is None or val_dataset is None:
raise ValueError(
"train_dataset and val_dataset must be defined in training mode!"
)
train_loader_inst = data_loader_partial(train_dataset, generator=generator)
val_loader_inst = data_loader_partial(
val_dataset, generator=generator, shuffle=False, drop_last=False
)
else:
if test_dataset is None:
raise ValueError("test_dataset must be defined in testing mode!")
test_loader_inst = data_loader_partial(
test_dataset, generator=generator, shuffle=False, drop_last=False
)
return train_loader_inst, val_loader_inst, test_loader_inst


def make_model(
model_partial: Partial[torch.nn.Module], dataset: Partial[Dataset[Any]]
) -> torch.nn.Module:
with console.status("Loading model...", spinner="runner"):
model_inst = model_partial(
encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore
) # Use just() to get the config out of the Zen-Partial
return model_inst


def parallelize_model(model: torch.nn.Module) -> torch.nn.Module:
console.print(
f"[*] Number of GPUs: {torch.cuda.device_count()}",
style="bold cyan",
)
if torch.cuda.device_count() > 1:
console.print(
f"-> Using {torch.cuda.device_count()} GPUs!",
style="bold cyan",
)
model = TransparentDataParallel(model)
return model


def make_optimizer(
optimizer_partial: Partial[torch.optim.optimizer.Optimizer], model: torch.nn.Module
) -> torch.optim.optimizer.Optimizer:
return optimizer_partial(model.parameters())


def make_scheduler(
scheduler_partial: Partial[torch.optim.lr_scheduler.LRScheduler],
optimizer: torch.optim.optimizer.Optimizer,
epochs: int,
) -> torch.optim.lr_scheduler.LRScheduler:
scheduler = scheduler_partial(
optimizer
) # TODO: less hacky way to set T_max for CosineAnnealingLR?
if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR):
scheduler.T_max = epochs
return scheduler


def make_training_loss(
training_mode: bool, training_loss_partial: Partial[torch.nn.Module]
):
training_loss: Optional[torch.nn.Module] = None
if training_mode:
training_loss = training_loss_partial()
return training_loss
Loading

0 comments on commit 3a34745

Please sign in to comment.