
Commit 3b0ef23

Migrate to ruff and pyright
1 parent 24a62f5 commit 3b0ef23

11 files changed: +191 additions, -91 deletions

.github/workflows/python-app.yml

Lines changed: 5 additions & 11 deletions
@@ -26,21 +26,15 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install flake8 pytest mypy types-PyYAML types-tqdm
+        pip install ruff pytest pyright types-PyYAML types-tqdm
         pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-    - name: Lint with flake8
+    - name: Lint with ruff
       run: |
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-    - name: Type check with mypy (training script)
+        ruff
+    - name: Type check with pyright
       run: |
-        mypy train.py --disable-error-code=import-untyped
-    - name: Type check with mypy (test script)
-      run: |
-        mypy test.py --disable-error-code=import-untyped
+        pyright
     #- name: Test with pytest
     #  run: |
     #    pytest
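Note on the new lint step: as committed it invokes bare ruff, which on recent Ruff releases does not run the linter (the check subcommand is expected), so this step may warn or refuse to run rather than lint the tree. The usual CI invocation is "ruff check ." (plus "ruff format --check ." if formatting should also gate the build). Bare pyright is fine as long as its include paths come from pyproject.toml or pyrightconfig.json.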

.pre-commit-config.yaml

Lines changed: 8 additions & 12 deletions
@@ -8,16 +8,12 @@ repos:
       - id: end-of-file-fixer
       - id: check-yaml
       - id: check-added-large-files
-  - repo: https://github.com/PyCQA/autoflake
-    rev: v2.2.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.4.7
     hooks:
-      - id: autoflake
-        args: [--in-place, --remove-all-unused-imports, -r]
-  - repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-  - repo: https://github.com/psf/black
-    rev: 23.3.0
-    hooks:
-      - id: black
+      # Run the linter.
+      - id: ruff
+        args: [ --select, I, --fix ] # Sort imports too
+      # Run the formatter.
+      - id: ruff-format
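A note on the hook arguments: a command-line --select overrides the configured rule selection for that run, so with --select I this hook applies only the import-sorting rules. ruff-format takes over black's job and --fix auto-applies the sort (replacing isort), but autoflake's unused-import removal (Pyflakes rule F401) is not selected here. Keeping the default rules as well usually means --extend-select I instead, or enabling I in the ruff configuration.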

dataset/base/__init__.py

Lines changed: 8 additions & 4 deletions
@@ -12,18 +12,17 @@
 transforming it may be extended through class inheritance in a specific dataset file.
 """
 
-
 import abc
 import os
 import os.path as osp
-from typing import Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 from hydra.utils import get_original_cwd
 from torch.utils.data import Dataset
 
 
-class BaseDataset(Dataset, abc.ABC):
+class BaseDataset(Dataset[Any], abc.ABC):
     def __init__(
         self,
         dataset_root: str,
@@ -36,6 +35,8 @@ def __init__(
         tiny: bool = False,
     ) -> None:
         super().__init__()
+        self._samples: Union[Dict[Any, Any], List[Any], torch.Tensor]
+        self._labels: Union[Dict[Any, Any], List[Any], torch.Tensor]
         self._samples, self._labels = self._load(dataset_root, tiny, split, seed)
         self._augment = augment and split == "train"
         self._normalize = normalize
@@ -49,7 +50,10 @@ def __init__(
     @abc.abstractmethod
     def _load(
         self, dataset_root: str, tiny: bool, split: str, seed: int
-    ) -> Tuple[Union[dict, list, torch.Tensor], Union[dict, list, torch.Tensor]]:
+    ) -> Tuple[
+        Union[Dict[str, Any], List[Any], torch.Tensor],
+        Union[Dict[str, Any], List[Any], torch.Tensor],
+    ]:
         # Implement this
         raise NotImplementedError
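torch.utils.data.Dataset is generic over its sample type, which the previous mypy setup never forced anyone to spell out; pyright does, hence Dataset[Any] on the base class and the pre-declared Union annotations on _samples and _labels so the assignment from the abstract _load type-checks. A minimal sketch of a concrete subclass under the widened signature (names and tensors are illustrative, assuming the dataset.base import path shown above):

from typing import Any, Dict, List, Tuple, Union

import torch

from dataset.base import BaseDataset


class ToyDataset(BaseDataset):
    def _load(
        self, dataset_root: str, tiny: bool, split: str, seed: int
    ) -> Tuple[
        Union[Dict[str, Any], List[Any], torch.Tensor],
        Union[Dict[str, Any], List[Any], torch.Tensor],
    ]:
        # Samples and labels may each be a dict, a list or a tensor; the base
        # __init__ stores whatever _load returns in the annotated attributes.
        n = 4 if tiny else 64
        samples = torch.randn(n, 3)
        labels = torch.zeros(n, dtype=torch.long)
        return samples, labels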

dataset/base/image.py

Lines changed: 19 additions & 14 deletions
@@ -9,13 +9,12 @@
 Base dataset for images.
 """
 
-
 import abc
-from typing import List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from torchvision.io.image import read_image
-from torchvision.transforms import transforms
+from torchvision.io.image import read_image  # type: ignore
+from torchvision.transforms import transforms  # type: ignore
 
 from dataset.base import BaseDataset
 
@@ -32,7 +31,7 @@ def __init__(
         dataset_name: str,
         split: str,
         seed: int,
-        img_size: Optional[tuple] = None,
+        img_size: Optional[tuple[int, ...]] = None,
         augment: bool = False,
         normalize: bool = False,
         tiny: bool = False,
@@ -49,21 +48,21 @@ def __init__(
             tiny=tiny,
         )
         self._img_size = self.IMG_SIZE if img_size is None else img_size
-        self._transforms = transforms.Compose(
+        self._transforms: Callable[[torch.Tensor], torch.Tensor] = transforms.Compose(
             [
                 transforms.Resize(self._img_size),
             ]
         )
-        self._normalization = transforms.Normalize(
-            self.IMAGE_NET_MEAN, self.IMAGE_NET_STD
+        self._normalization: Callable[[torch.Tensor], torch.Tensor] = (
+            transforms.Normalize(self.IMAGE_NET_MEAN, self.IMAGE_NET_STD)
         )
         try:
             import albumentations as A  # type: ignore
         except ImportError:
             raise ImportError(
                 "Please install albumentations to use the augmentation pipeline."
             )
-        self._augs = A.Compose(
+        self._augs: Callable[..., Dict[str, Any]] = A.Compose(
             [
                 A.RandomCropFromBorders(),
                 A.RandomBrightnessContrast(),
@@ -74,22 +73,28 @@ def __init__(
     @abc.abstractmethod
     def _load(
         self, dataset_root: str, tiny: bool, split: str, seed: int
-    ) -> Tuple[Union[dict, list, torch.Tensor], Union[dict, list, torch.Tensor]]:
+    ) -> Tuple[
+        Union[Dict[str, Any], List[Any], torch.Tensor],
+        Union[Dict[str, Any], List[Any], torch.Tensor],
+    ]:
         # Implement this
         raise NotImplementedError
 
-    def __getitem__(self, index: int):
+    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         This should be common to all image datasets!
         Override if you need something else.
         """
         # ==== Load image and apply transforms ===
-        img = read_image(self._samples[index])  # Returns a Tensor
+        img: torch.Tensor
+        img = read_image(self._samples[index])  # type: ignore
+        if not isinstance(img, torch.Tensor):
+            raise ValueError("Image not loaded as a Tensor.")
         img = self._transforms(img)
         if self._normalize:
             img = self._normalization(img)
         if self._augment:
-            img = self._augs(image=img)
+            img = self._augs(image=img)["image"]
         # ==== Load label and apply transforms ===
-        label = self._labels[index]
+        label: Any = self._labels[index]
         return img, label
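Two of these changes are behavioral rather than cosmetic: the # type: ignore comments on the torchvision imports and on read_image sidestep missing or partial stubs, while the isinstance guard narrows the result to a Tensor for both pyright and runtime; and since an albumentations Compose pipeline returns a dict of named targets, the old expression assigned that whole dict to img, so the new ["image"] lookup is the actual fix. A tiny check of that return shape (toy input, assuming albumentations is installed):

import albumentations as A
import numpy as np

augs = A.Compose([A.RandomBrightnessContrast(p=1.0)])
out = augs(image=np.zeros((32, 32, 3), dtype=np.uint8))
print(type(out))          # <class 'dict'>
augmented = out["image"]  # the transformed image lives under the "image" key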

dataset/example.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def __init__(
             dataset_name,
             split,
             seed,
-            (img_dim, img_dim),
+            (img_dim, img_dim) if img_dim is not None else None,
             augment=augment,
             normalize=normalize,
             debug=debug,
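This keeps img_size genuinely optional: with the old call, an unset img_dim produced (None, None), which neither satisfies the new Optional[tuple[int, ...]] annotation nor triggers the IMG_SIZE fallback in the image base class, since that fallback only fires when img_size is exactly None.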

launch_experiment.py

Lines changed: 35 additions & 18 deletions
@@ -8,6 +8,7 @@
 
 import os
 from dataclasses import asdict
+from typing import Any, Optional
 
 import hydra_zen
 import torch
@@ -17,8 +18,9 @@
 from hydra.utils import to_absolute_path
 from hydra_zen import just
 from hydra_zen.typing import Partial
+from torch.utils.data import DataLoader, Dataset
 
-import conf.experiment  # Must import the config to add all components to the store! # noqa
+import conf.experiment as exp_conf  # type: ignore
 from conf import project as project_conf
 from model import TransparentDataParallel
 from src.base_tester import BaseTester
@@ -27,13 +29,13 @@
 
 
 def launch_experiment(
-    run,
-    data_loader: Partial[torch.utils.data.DataLoader],
+    run: exp_conf.RunConfig,
+    data_loader: Partial[torch.utils.data.DataLoader],  # type: ignore
     optimizer: Partial[torch.optim.Optimizer],
-    scheduler: Partial[torch.optim.lr_scheduler._LRScheduler],
+    scheduler: Partial[torch.optim.lr_scheduler.LRScheduler],
     trainer: Partial[BaseTrainer],
     tester: Partial[BaseTester],
-    dataset: Partial[torch.utils.data.Dataset],
+    dataset: Partial[Dataset[Any]],
     model: Partial[torch.nn.Module],
     training_loss: Partial[torch.nn.Module],
 ):
@@ -65,19 +67,19 @@ def launch_experiment(
 
     "============ Partials instantiation ============"
     model_inst = model(
-        encoder_input_dim=just(dataset).img_dim ** 2
+        encoder_input_dim=just(dataset).img_dim ** 2  # type: ignore
     )  # Use just() to get the config out of the Zen-Partial
     print(model_inst)
     print(f"Number of parameters: {sum(p.numel() for p in model_inst.parameters())}")
     print(
         f"Number of trainable parameters: {sum(p.numel() for p in model_inst.parameters() if p.requires_grad)}"
     )
-    train_dataset, val_dataset, test_dataset = None, None, None
+    train_dataset: Optional[Dataset[Any]] = None
+    val_dataset: Optional[Dataset[Any]] = None
+    test_dataset: Optional[Dataset[Any]] = None
     if run.training_mode:
-        train_dataset, val_dataset = (
-            dataset(split="train", seed=run.seed),
-            dataset(split="val", seed=run.seed),
-        )
+        train_dataset = dataset(split="train", seed=run.seed)
+        val_dataset = dataset(split="val", seed=run.seed)
     else:
         test_dataset = dataset(split="test", augment=False, seed=run.seed)
 
@@ -104,38 +106,45 @@ def launch_experiment(
     )
     model_inst = TransparentDataParallel(model_inst)
 
-    if not run.training_mode:
-        training_loss_inst = None
-    else:
+    training_loss_inst: Optional[torch.nn.Module] = None
+    if run.training_mode:
         training_loss_inst = training_loss()
 
     "============ CUDA ============"
     model_inst: torch.nn.Module = to_cuda_(model_inst)  # type: ignore
-    training_loss_inst: torch.nn.Module = to_cuda_(training_loss_inst)  # type: ignore
+    training_loss_inst = to_cuda_(training_loss_inst)  # type: ignore
 
     "============ Weights & Biases ============"
    if project_conf.USE_WANDB:
         # exp_conf is a string, so we need to load it back to a dict:
         exp_conf = yaml.safe_load(exp_conf)
-        wandb.init(
+        wandb.init(  # type: ignore
             project=project_conf.PROJECT_NAME,
             name=run_name,
             config=exp_conf,
         )
-        wandb.watch(model_inst, log="all", log_graph=True)
+        wandb.watch(model_inst, log="all", log_graph=True)  # type: ignore
     " ============ Reproducibility of data loaders ============ "
     g = None
     if project_conf.REPRODUCIBLE:
         g = torch.Generator()
         g.manual_seed(run.seed)
 
-    train_loader_inst, val_loader_inst, test_loader_inst = None, None, None
+    train_loader_inst: Optional[DataLoader[Any]] = None
+    val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None
+    test_loader_inst: Optional[DataLoader[Any]] = None
     if run.training_mode:
+        if train_dataset is None or val_dataset is None:
+            raise ValueError(
+                "train_dataset and val_dataset must be defined in training mode!"
+            )
         train_loader_inst = data_loader(train_dataset, generator=g)
         val_loader_inst = data_loader(
             val_dataset, generator=g, shuffle=False, drop_last=False
         )
     else:
+        if test_dataset is None:
+            raise ValueError("test_dataset must be defined in testing mode!")
        test_loader_inst = data_loader(
             test_dataset, generator=g, shuffle=False, drop_last=False
         )
@@ -167,6 +176,12 @@ def launch_experiment(
     )
 
     if run.training_mode:
+        if training_loss_inst is None:
+            raise ValueError("training_loss must be defined in training mode!")
+        if val_loader_inst is None or train_loader_inst is None:
+            raise ValueError(
+                "val_loader and train_loader must be defined in training mode!"
+            )
         trainer(
             run_name=run_name,
             model=model_inst,
@@ -187,6 +202,8 @@ def launch_experiment(
             model_ckpt_path=model_ckpt_path,
         )
     else:
+        if test_loader_inst is None:
+            raise ValueError("test_loader must be defined in testing mode!")
         tester(
             run_name=run_name,
             model=model_inst,
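The recurring pattern in this file is making implicit Optionals explicit so pyright can narrow them: tuple assignments become individually annotated Optional variables, and each consuming branch first rules out None with an explicit raise. The scheduler annotation also moves from the private _LRScheduler to the public LRScheduler alias available in recent PyTorch releases. The narrowing idiom in isolation (illustrative names, not from the repo):

from typing import List, Optional


def run_training(
    train_loader: Optional[List[int]], val_loader: Optional[List[int]]
) -> int:
    if train_loader is None or val_loader is None:
        raise ValueError("train_loader and val_loader must be defined in training mode!")
    # Past the raise, pyright narrows both names from Optional[List[int]] to List[int].
    return len(train_loader) + len(val_loader)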
