Skip to content

Commit

Permalink
[mypy] fix all train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed May 31, 2024
1 parent 2775a3b commit f5bad4b
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest mypy
mypy --install-types
pip install flake8 pytest mypy types-PyYAML types-tqdm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
Expand All @@ -38,10 +37,10 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Type check with mypy (training script)
run: |
mypy train.py
mypy train.py --disable-error-code=import-error
- name: Type check with mypy (test script)
run: |
mypy test.py
mypy test.py --disable-error-code=import-error
- name: Test with pytest
run: |
pytest
6 changes: 5 additions & 1 deletion conf/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
make_custom_builds_fn,
store,
)
from hydra_zen.typing import SupportedPrimitive
from hydra_zen.typing._builds_overloads import PBuilds
from torch.utils.data import DataLoader
from unique_names_generator import get_random_name
from unique_names_generator.data import ADJECTIVES, NAMES
Expand All @@ -47,7 +49,9 @@
group="hydra",
)
hydra_store.add_to_hydra_store()
pbuilds = make_custom_builds_fn(zen_partial=True, populate_full_signature=False)
pbuilds: PBuilds[SupportedPrimitive] = make_custom_builds_fn(
zen_partial=True, populate_full_signature=False
)

" ================== Dataset ================== "

Expand Down
4 changes: 2 additions & 2 deletions dataset/base/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class ImageDataset(BaseDataset, abc.ABC):
IAGE_NET_MEAN: List[float] = []
IMAGE_NET_MEAN: List[float] = []
IMAGE_NET_STD: List[float] = []
COCO_MEAN, COCO_STD = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
IMG_SIZE = (32, 32)
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
self.IMAGE_NET_MEAN, self.IMAGE_NET_STD
)
try:
import albumentations as A
import albumentations as A # type: ignore
except ImportError:
raise ImportError(
"Please install albumentations to use the augmentation pipeline."
Expand Down
24 changes: 13 additions & 11 deletions src/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,17 @@ def _train_epoch(
self._visualize(batch, epoch)
has_visualized += 1
self._pbar.update()
epoch_loss: float = epoch_loss.compute().item()
mean_epoch_loss: float = epoch_loss.compute().item()
if project_conf.USE_WANDB:
wandb.log({"train_loss": epoch_loss}, step=epoch)
wandb.log({"train_loss": mean_epoch_loss}, step=epoch)
wandb.log(
{
f"Detailed loss - Training/{k}": v.compute().item()
for k, v in epoch_loss_components.items()
},
step=epoch,
)
return epoch_loss
return mean_epoch_loss

def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float:
"""Validation loop for one epoch.
Expand Down Expand Up @@ -214,27 +214,28 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float:
):
self._visualize(batch, epoch)
has_visualized += 1
val_loss: float = val_loss.compute().item()
mean_val_loss: float = val_loss.compute().item()
mean_val_loss_components: Dict[str, float] = {}
for k, v in val_loss_components.items():
val_loss_components[k] = v.compute().item()
mean_val_loss_components[k] = v.compute().item()
if project_conf.USE_WANDB:
wandb.log({"val_loss": val_loss}, step=epoch)
wandb.log({"val_loss": mean_val_loss}, step=epoch)
wandb.log(
{
f"Detailed loss - Validation/{k}": v
for k, v in val_loss_components.items()
for k, v in mean_val_loss_components.items()
},
step=epoch,
)
# Set minimize_metric to a key in val_loss_components if you wish to minimize
# a specific metric instead of the validation loss:
self._model_saver(
epoch,
val_loss,
val_loss_components,
mean_val_loss,
mean_val_loss_components,
minimize_metric=self._minimize_metric,
)
return val_loss
return mean_val_loss

def train(
self,
Expand All @@ -259,7 +260,8 @@ def train(
self._setup_plot()
print(f"[*] Training for {epochs} epochs")
self._viz_n_samples = visualize_n_samples
train_losses, val_losses = [], []
train_losses: List[float] = []
val_losses: List[float] = []
" ==================== Training loop ==================== "
for epoch in range(self._epoch, epochs):
self._epoch = epoch # Update for the model saver
Expand Down

0 comments on commit f5bad4b

Please sign in to comment.