Skip to content

Commit 8ef5e00

Browse files
committed
fix mypy errors
1 parent 6e622c5 commit 8ef5e00

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/imitation/algorithms/adversarial/common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import abc
33
import dataclasses
44
import logging
5-
from typing import Iterable, Iterator, Mapping, Optional, Type, overload
5+
from typing import Iterable, Iterator, Mapping, Optional, Type, List, overload
66

77
import numpy as np
88
import torch as th
@@ -414,7 +414,10 @@ def train_gen(
414414
if learn_kwargs is None:
415415
learn_kwargs = {}
416416

417-
callbacks = [self.gen_callback]
417+
callbacks: List[BaseCallback] = []
418+
419+
if self.gen_callback:
420+
callbacks.append(self.gen_callback)
418421

419422
if isinstance(callback, list):
420423
callbacks.extend(callback)

src/imitation/scripts/train_adversarial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
interval: int,
3434
):
3535
"""Creates new Checkpoint callback."""
36-
super().__init__(self)
36+
super().__init__()
3737
self.trainer = trainer
3838
self.log_dir = log_dir
3939
self.interval = interval

0 commit comments

Comments
 (0)