Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff #43

Merged
merged 8 commits into from
Dec 21, 2023
Merged

Ruff #43

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: poetry
- uses: chartboost/ruff-action@v1
- name: Install dependencies
run: |
poetry install
poetry run pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
- name: Lint
run: poetry run flake8 -v .
- name: Test
run: poetry run pytest -v -s --cov=. --cov-report=xml tests
- name: Upload coverage reports to Codecov
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ install:
poetry install

lint:
poetry run flake8 -v
poetry run ruff check .

test:
poetry run pytest -v -s --cov=. tests
Expand Down
1 change: 1 addition & 0 deletions configs/mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ job:
trainer:
name: MNISTTrainer
num_epochs: 20
num_classes: 10

dataset:
name: MNISTDataLoader
Expand Down
1,175 changes: 34 additions & 1,141 deletions poetry.lock

Large diffs are not rendered by default.

31 changes: 27 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ torchmetrics = "^1.2.0"
tqdm = "^4.66.1"
loguru = "^0.7.2"
mlconfig = "^0.2.0"
mlflow = "^2.9.2"
mlflow-skinny = "^2.9.2"

[tool.poetry.group.dev.dependencies]
black = "^23.11.0"
flake8 = "^6.0.0"
isort = "^5.12.0"
pytest = "^7.3.1"
pytest-cov = "^4.1.0"
ruff = "^0.1.8"
toml = "^0.10.2"

[build-system]
Expand All @@ -28,3 +26,28 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
template = "template.cli:main"

[tool.ruff]
exclude = ["build"]
line-length = 120

[tool.ruff.lint]
select = [
"B", # flake8-bugbear
"C", # flake8-comprehensions
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
# "UP", # pyupgrade
"W", # pycodestyle warnings

]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401", "F403"]

[tool.ruff.isort]
force-single-line = true

[tool.pytest.ini_options]
filterwarnings = ["ignore::DeprecationWarning"]
17 changes: 0 additions & 17 deletions setup.cfg

This file was deleted.

67 changes: 39 additions & 28 deletions template/trainers/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import torch
import torch.nn.functional as F
from mlconfig import register
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchmetrics import MeanMetric
from tqdm import tqdm
Expand All @@ -13,48 +17,55 @@
@register
class MNISTTrainer(Trainer):
def __init__(
self, device, model, optimizer, scheduler, train_loader, test_loader, num_epochs
):
self,
device: torch.device,
model: Module,
optimizer: Optimizer,
scheduler: LRScheduler,
train_loader: DataLoader,
test_loader: DataLoader,
num_epochs: int,
num_classes: int,
) -> None:
self.device = device
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.train_loader = train_loader
self.test_loader = test_loader
self.num_epochs = num_epochs
self.num_classes = num_classes

Check warning on line 37 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L37

Added line #L37 was not covered by tests

self.epoch = 1
self.best_acc = 0
self.state = {"epoch": 1}

Check warning on line 40 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L40

Added line #L40 was not covered by tests

def fit(self):
for self.epoch in trange(self.epoch, self.num_epochs + 1):
def fit(self) -> None:
for epoch in trange(self.state["epoch"], self.num_epochs + 1):

Check warning on line 43 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L43

Added line #L43 was not covered by tests
train_loss, train_acc = self.train()
test_loss, test_acc = self.evaluate()
self.scheduler.step()

metrics = dict(
train_loss=train_loss,
train_acc=train_acc,
test_loss=test_loss,
test_acc=test_acc,
)
mlflow.log_metrics(metrics, step=self.epoch)

format_string = "Epoch: {}/{}, ".format(self.epoch, self.num_epochs)
format_string += "train loss: {:.4f}, train acc: {:.4f}, ".format(
train_loss, train_acc
)
format_string += "test loss: {:.4f}, test acc: {:.4f}, ".format(
test_loss, test_acc
)
metrics = {

Check warning on line 48 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L48

Added line #L48 was not covered by tests
"train_loss": train_loss,
"train_acc": train_acc,
"test_loss": test_loss,
"test_acc": test_acc,
}
mlflow.log_metrics(metrics, step=epoch)

Check warning on line 54 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L54

Added line #L54 was not covered by tests

format_string = "Epoch: {}/{}, ".format(epoch, self.num_epochs)
format_string += "train loss: {:.4f}, train acc: {:.4f}, ".format(train_loss, train_acc)
format_string += "test loss: {:.4f}, test acc: {:.4f}, ".format(test_loss, test_acc)

Check warning on line 58 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L56-L58

Added lines #L56 - L58 were not covered by tests
format_string += "best test acc: {:.4f}.".format(self.best_acc)
tqdm.write(format_string)

def train(self):
self.state["epoch"] = epoch

Check warning on line 62 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L62

Added line #L62 was not covered by tests

def train(self) -> None:
self.model.train()

loss_metric = MeanMetric()
acc_metric = Accuracy()
acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes)

Check warning on line 68 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L68

Added line #L68 was not covered by tests

for x, y in tqdm(self.train_loader):
x = x.to(self.device)
Expand All @@ -73,11 +84,11 @@
return loss_metric.compute().item(), acc_metric.compute().item()

@torch.no_grad()
def evaluate(self):
def evaluate(self) -> None:
self.model.eval()

loss_metric = MeanMetric()
acc_metric = Accuracy()
acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes)

Check warning on line 91 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L91

Added line #L91 was not covered by tests

for x, y in tqdm(self.test_loader):
x = x.to(self.device)
Expand All @@ -96,25 +107,25 @@

return loss_metric.compute().item(), test_acc

def save_checkpoint(self, f):
def save_checkpoint(self, f) -> None:
self.model.eval()

checkpoint = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"epoch": self.epoch,
"state": self.state,
"best_acc": self.best_acc,
}

torch.save(checkpoint, f)
mlflow.log_artifact(f)

def resume(self, f):
def resume(self, f) -> None:
checkpoint = torch.load(f, map_location=self.device)

self.model.load_state_dict(checkpoint["model"])
self.optimizer.load_state_dict(checkpoint["optimizer"])
self.scheduler.load_state_dict(checkpoint["scheduler"])
self.epoch = checkpoint["epoch"] + 1
self.state = checkpoint["state"]

Check warning on line 130 in template/trainers/mnist.py

View check run for this annotation

Codecov / codecov/patch

template/trainers/mnist.py#L130

Added line #L130 was not covered by tests
self.best_acc = checkpoint["best_acc"]
2 changes: 1 addition & 1 deletion template/trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
class Trainer:
def train(self):
def train(self) -> None:
raise NotImplementedError
3 changes: 1 addition & 2 deletions template/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
from pathlib import Path

import numpy as np
import torch
import yaml

from pathlib import Path


def manual_seed(seed=0):
"""https://pytorch.org/docs/stable/notes/randomness.html"""
Expand Down
1 change: 1 addition & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch

from template.models import LeNet


Expand Down