Skip to content

Commit e8d4fca

Browse files
committed
Fix linter issues
1 parent 768d24f commit e8d4fca

File tree

9 files changed

+19
-15
lines changed

9 files changed

+19
-15
lines changed

.github/workflows/python-app.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ jobs:
3737
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
3838
- name: Type check with mypy (training script)
3939
run: |
40-
mypy train.py +experiment=exp_a run.epochs=2 run.viz_every=0
40+
mypy train.py
4141
- name: Type check with mypy (test script)
4242
run: |
43-
mypy test.py +experiment=exp_a run.viz_every=0
43+
mypy test.py
4444
- name: Test with pytest
4545
run: |
4646
pytest

launch_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from hydra_zen import just
1919
from hydra_zen.typing import Partial
2020

21-
import conf.experiment # Must import the config to add all components to the store!
21+
import conf.experiment # Must import the config to add all components to the store! # noqa
2222
from conf import project as project_conf
2323
from model import TransparentDataParallel
2424
from src.base_trainer import BaseTrainer

src/base_tester.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test(self, visualize_every: int = 0, **kwargs):
7979
visualize_every (int, optional): Visualize the model predictions every n batches.
8080
Defaults to 0 (no visualization).
8181
"""
82-
test_loss, test_loss_components = MeanMetric(), defaultdict(MeanMetric)
82+
test_loss, test_metrics = MeanMetric(), defaultdict(MeanMetric)
8383
self._model.eval()
8484
self._pbar.reset()
8585
self._pbar.set_description("Testing")
@@ -92,7 +92,7 @@ def test(self, visualize_every: int = 0, **kwargs):
9292
loss, metrics = self._test_iteration(batch)
9393
test_loss.update(loss.item())
9494
for k, v in metrics.items():
95-
metrics[k].update(v.item())
95+
test_metrics[k].update(v.item())
9696
update_pbar_str(
9797
self._pbar,
9898
f"Testing [loss={test_loss.compute():.4f}]",
@@ -107,7 +107,7 @@ def test(self, visualize_every: int = 0, **kwargs):
107107
print("=" * 81)
108108
print("==" + " " * 31 + " Test results " + " " * 31 + "==")
109109
print("=" * 81)
110-
for k, v in metrics.items():
110+
for k, v in test_metrics.items():
111111
print(f"\t -> {k}: {v.compute().item():.2f}")
112112
print(f"\t -> Average loss: {test_loss:.4f}")
113113
print("_" * 81)

src/base_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]):
323323
plt.theme("dark")
324324
plt.xlabel("Epoch")
325325
if project_conf.LOG_SCALE_PLOT:
326-
if any(l <= 0 for l in train_losses + val_losses):
326+
if any(loss_val <= 0 for loss_val in train_losses + val_losses):
327327
raise ValueError(
328328
"Cannot plot on a log scale if there are non-positive losses."
329329
)

test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from hydra_zen import store, zen
1010

11-
import conf.experiment # Must import the config to add all components to the store!
11+
import conf.experiment # Must import the config to add all components to the store! # noqa
1212
from conf import project as project_conf
1313
from launch_experiment import launch_experiment
1414
from utils import seed_everything

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from hydra_zen import store, zen
1010

11-
import conf.experiment # Must import the config to add all components to the store!
11+
import conf.experiment # Must import the config to add all components to the store! # noqa
1212
from conf import project as project_conf
1313
from launch_experiment import launch_experiment
1414
from utils import seed_everything

utils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def blink_pbar(i: int, pbar: tqdm.tqdm, n: int) -> None:
102102

103103
@contextmanager
104104
def colorize_prints(ansii_code: Union[int, str]):
105-
if type(ansii_code) == str:
105+
if type(ansii_code) is str:
106106
ansii_code = project_conf.ANSI_COLORS[ansii_code]
107107
print(f"\033[{ansii_code}m", end="")
108108
try:
@@ -180,7 +180,7 @@ def wrapper(*args, **kwargs):
180180
)
181181

182182
if reload.lower() not in ("l", "", "r"):
183-
print(f"[!] Aborting")
183+
print("[!] Aborting")
184184
# TODO: Why can't I just raise the exception? It's weird but it gets caught by
185185
# the wrapper a few times until it finally gets raised.
186186
sys.exit(1)
@@ -196,7 +196,7 @@ def wrapper(*args, **kwargs):
196196
cfg=IPython.terminal.embed.load_default_config(),
197197
banner1=colorize(
198198
f"[*] Dropping into an IPython shell to inspect {callable} "
199-
+ f"with the locals as they were at the time of the exception "
199+
+ "with the locals as they were at the time of the exception "
200200
+ f"thrown at line {frame.f_lineno} of {frame.f_code.co_filename}."
201201
+ "\n============================== TIPS =============================="
202202
+ "\n -> Use '%whos' to list variables in the current scope."

utils/anim.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@ def __init__(
7676
super().__init__()
7777
try:
7878
import scenepic as sp
79-
except:
79+
except ImportError:
8080
raise Exception(
81-
"scenepic not installed. Some visualization functions will not work. (I know it's not available on Apple Silicon :("
81+
"scenepic not installed. "
82+
+ "Some visualization functions will not work. "
83+
+ "(I know it's not available on Apple Silicon :("
8284
)
8385
pv.start_xvfb()
8486
self.scene = sp.Scene()

utils/helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def __call__(
4545
): # Either val_loss or min_val_metric
4646
ckpt_path = osp.join(
4747
HydraConfig.get().runtime.output_dir,
48-
f"epoch_{epoch:03d}_{minimize_metric if minimize_metric in metrics.keys() else 'val-loss'}_{metrics.get(minimize_metric, val_loss):06f}.ckpt",
48+
f"epoch_{epoch:03d}_"
49+
+ f"{minimize_metric if minimize_metric in metrics.keys() else 'val-loss'}"
50+
+ f"_{metrics.get(minimize_metric, val_loss):06f}.ckpt",
4951
)
5052
self._save_if_best_model(
5153
metrics.get(minimize_metric, val_loss), ckpt_path

0 commit comments

Comments
 (0)