Skip to content

Commit 718d11a

Browse files
hjmjohnsonBenjamin Gormancrtrentz
authored
Fix pytype static checks (Project-MONAI#449)
* Fix pytype warnings File "monai/monai/inferers/utils.py", line 67, in sliding_window_inference: Cannot unpack 2 values into 3 variables [bad-unpacking] File "monai/monai/inferers/utils.py", line 97, in sliding_window_inference: Cannot unpack 2 values into 3 variables [bad-unpacking] * Avoid pytypes warning File "monai/monai/data/dataset.py", line 29, in <module>: No attribute 'Dataset' on module 'torch.utils.data' [module-attr] File "monai/monai/data/dataset.py", line 293, in <module>: No attribute 'Dataset' on module 'torch.utils.data' [module-attr] * Multiple inheritance demands identifying superclass pytype monai/engines/workflow.py File "monai/monai/engines/workflow.py", line 57, in __init__: Function Workflow.__init__ expects 1 arg(s), got 2 [wrong-arg-count] Expected: (self) Actually passed: (self, _) Disable failing pytype test. This needs further investigation by someone who better understands the intended behavior. The expected behavior is unclear. Which superclass is intended to be called for this multiple inheritance case? For the superclass ABC and Engine, why is Workflow the function that gives the warning? FIXME: pytype . File "monai/engines/workflow.py", line 57: Invalid directive syntax. [invalid-directive] File "monai/engines/workflow.py", line 57, in __init__: Function Workflow.__init__ expects 1 arg(s), got 2 [wrong-arg-count] Expected: (self) Actually passed: (self, _) For more details, see https://google.github.io/pytype/errors.html. * Import as local name for pytypes warning suppression File "monai/monai/transforms/adaptors.py", line 102, in <module>: No attribute 'export' on module 'monai.utils' [module-attr] File "monai/monai/transforms/adaptors.py", line 187, in <module>: No attribute 'export' on module 'monai.utils' [module-attr] File "monai/monai/transforms/adaptors.py", line 208, in <module>: No attribute 'export' on module 'monai.utils' [module-attr] * Resolve conflicting type definitions File "monai/monai/visualize/img2tensorboard.py", line 23, in <module>: Invalid type annotation '<instance of Callable>' for image [invalid-annotation] Not a type File "monai/monai/visualize/img2tensorboard.py", line 55, in <module>: Invalid type annotation '<instance of Callable>' for image [invalid-annotation] Not a type File "monai/monai/visualize/img2tensorboard.py", line 90, in make_animated_gif_summary: Invalid type annotation '<instance of Callable>' for one_channel_img [invalid-annotation] Not a type File "monai/monai/visualize/img2tensorboard.py", line 101, in <module>: Invalid type annotation '<instance of Callable>' for image_tensor [invalid-annotation] Not a type File "monai/monai/visualize/img2tensorboard.py", line 128, in <module>: Invalid type annotation '<instance of Callable>' for image_tensor [invalid-annotation] Not a type * Fix callable object type matching File "monai/monai/networks/nets/densenet.py", line 169, in __init__: Built-in function isinstance was called with the wrong arguments [wrong-arg-types] Expected: (object, class_or_type_or_tuple: Union[Tuple[Union[Tuple[type, ...], type], ...], type]) Actually passed: (object, class_or_type_or_tuple: Callable) File "monai/monai/networks/nets/densenet.py", line 171, in __init__: Built-in function isinstance was called with the wrong arguments [wrong-arg-types] Expected: (object, class_or_type_or_tuple: Union[Tuple[Union[Tuple[type, ...], type], ...], type]) Actually passed: (object, class_or_type_or_tuple: Callable) T484 Returning Any from function declared to return Callable[..., Any] * ENH: Support pytype static analysis. * Remove unused import monai/monai/transforms/adaptors.py:99:1: F401 'monai' imported but unused * ENH: Adding .pytype and .mypy_cache to ignored black file list. * add type hints Co-authored-by: Cameron Trentz <cameron-trentz@uiowa.edu> * Annotate dtype as np.dtype * add type hints [second pass] Co-authored-by: Cameron Trentz <cameron-trentz@uiowa.edu> * Suppress some pytype warnings Some pytype warnings will require more extensive corrections to fully resolve them. These few cases are suppressed in the interest of allowing more static checks without imposing a huge burden at the moment. File "monai/monai/transforms/io/array.py", line 93, in __call__: No attribute 'dtype' on bool [attribute-error] In Union[Any, bool] For more details, see https://google.github.io/pytype/errors.html#attribute-error. File "monai/monai/transforms/intensity/array.py", line 142, in __call__: Function ScaleIntensity.__init__ was called with the wrong arguments [wrong-arg-types] Expected: (self, minv, maxv: Union[float, int] = ..., ...) Actually passed: (self, minv, maxv: None, ...) For more details, see https://google.github.io/pytype/errors.html#wrong-arg-types. File "monai/monai/transforms/intensity/dictionary.py", line 190, in __call__: Function ScaleIntensity.__init__ was called with the wrong arguments [wrong-arg-types] Expected: (self, minv: Union[float, int] = ..., ...) Actually passed: (self, minv: None, ...) For more details, see https://google.github.io/pytype/errors.html#wrong-arg-types. File "monai/monai/data/dataset.py", line 124, in _pre_first_random_transform: No attribute 'transforms' on Callable [attribute-error] In Union[Any, Callable] File "monai/monai/data/dataset.py", line 140, in _first_random_and_beyond_transform: No attribute 'transforms' on Callable [attribute-error] In Union[Any, Callable] File "monai/monai/data/dataset.py", line 285, in __getitem__: No attribute 'transforms' on Callable [attribute-error] In Union[Any, Callable] For more details, see https://google.github.io/pytype/errors.html#attribute-error. * Fix type mismatch warning for range object as list monai/monai/transforms/spatial/array.py:913:22: T484 Invalid index type "range" for "Union[Any, Tensor]"; expected type "Union[None, int, slice, Tensor, List[Any], Tuple[Any, ...]]" Co-authored-by: Benjamin Gorman <benjamin-gorman@uiowa.edu> Co-authored-by: Cameron Trentz <cameron-trentz@uiowa.edu>
1 parent 0e5c3ae commit 718d11a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+835
-466
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ venv.bak/
100100
# mkdocs documentation
101101
/site
102102

103+
# pytype cache
104+
.pytype/
105+
103106
# mypy
104107
.mypy_cache/
105108
examples/scd_lvsegs.npz

examples/notebooks/3d_image_transforms.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@
156156
"metadata": {},
157157
"outputs": [],
158158
"source": [
159-
"loader = LoadNifti(dtype=np.float32)"
159+
"loader = LoadNifti(dtype: np.dtype = np.float32)"
160160
]
161161
},
162162
{

monai/data/csv_saver.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from collections import OrderedDict
13+
from typing import Union
14+
1215
import os
1316
import csv
1417
import numpy as np
1518
import torch
16-
from collections import OrderedDict
1719

1820

1921
class CSVSaver:
@@ -24,7 +26,7 @@ class CSVSaver:
2426
the cached data into CSV file. If no meta data provided, use index from 0 to save data.
2527
"""
2628

27-
def __init__(self, output_dir="./", filename="predictions.csv", overwrite=True):
29+
def __init__(self, output_dir: str = "./", filename: str = "predictions.csv", overwrite: bool = True):
2830
"""
2931
Args:
3032
output_dir (str): output CSV file directory.
@@ -34,7 +36,7 @@ def __init__(self, output_dir="./", filename="predictions.csv", overwrite=True):
3436
3537
"""
3638
self.output_dir = output_dir
37-
self._cache_dict = OrderedDict()
39+
self._cache_dict: OrderedDict = OrderedDict()
3840
assert isinstance(filename, str) and filename[-4:] == ".csv", "filename must be a string with CSV format."
3941
self._filepath = os.path.join(output_dir, filename)
4042
self.overwrite = overwrite
@@ -60,7 +62,7 @@ def finalize(self):
6062
f.write("," + str(result))
6163
f.write("\n")
6264

63-
def save(self, data, meta_data=None):
65+
def save(self, data: np.ndarray, meta_data=None):
6466
"""Save data into the cache dictionary. The metadata should have the following key:
6567
- ``'filename_or_obj'`` -- save the data corresponding to file name or object.
6668
If meta_data is None, use the default index from 0 to save data instead.
@@ -76,7 +78,7 @@ def save(self, data, meta_data=None):
7678
data = data.detach().cpu().numpy()
7779
self._cache_dict[save_key] = data.astype(np.float32)
7880

79-
def save_batch(self, batch_data, meta_data=None):
81+
def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data=None):
8082
"""Save a batch of data into the cache dictionary.
8183
8284
args:

monai/data/dataset.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from typing import Optional, Callable
1213
import sys
13-
1414
import hashlib
1515
import json
1616
from pathlib import Path
@@ -25,8 +25,10 @@
2525
from monai.transforms.utils import apply_transform
2626
from monai.utils import process_bar, get_seed
2727

28+
from torch.utils.data import Dataset as _TorchDataset
29+
2830

29-
class Dataset(torch.utils.data.Dataset):
31+
class Dataset(_TorchDataset):
3032
"""
3133
A generic dataset with a length property and an optional callable data transform
3234
when fetching a data sample.
@@ -39,7 +41,7 @@ class Dataset(torch.utils.data.Dataset):
3941
}, }, }]
4042
"""
4143

42-
def __init__(self, data, transform=None):
44+
def __init__(self, data, transform: Optional[Callable] = None):
4345
"""
4446
Args:
4547
data (Iterable): input data to load and transform to generate dataset for model.
@@ -51,7 +53,7 @@ def __init__(self, data, transform=None):
5153
def __len__(self):
5254
return len(self.data)
5355

54-
def __getitem__(self, index):
56+
def __getitem__(self, index: int):
5557
data = self.data[index]
5658
if self.transform is not None:
5759
data = self.transform(data)
@@ -93,7 +95,7 @@ class PersistentDataset(Dataset):
9395
followed by applying the random dependant parts of transform processing.
9496
"""
9597

96-
def __init__(self, data, transform=None, cache_dir=None):
98+
def __init__(self, data, transform: Optional[Callable] = None, cache_dir=None):
9799
"""
98100
Args:
99101
data (Iterable): input data to load and transform to generate dataset for model.
@@ -119,7 +121,7 @@ def _pre_first_random_transform(self, item_transformed):
119121
the transformed element up to the first identified
120122
random transform object
121123
"""
122-
for _transform in self.transform.transforms:
124+
for _transform in self.transform.transforms: # pytype: disable=attribute-error
123125
# execute all the deterministic transforms before the first random transform
124126
if isinstance(_transform, Randomizable):
125127
break
@@ -135,7 +137,7 @@ def _first_random_and_beyond_transform(self, item_transformed):
135137
the transformed element through the random transforms
136138
"""
137139
start_post_randomize_run = False
138-
for _transform in self.transform.transforms:
140+
for _transform in self.transform.transforms: # pytype: disable=attribute-error
139141
if start_post_randomize_run or isinstance(_transform, Randomizable):
140142
start_post_randomize_run = True
141143
item_transformed = apply_transform(_transform, item_transformed)
@@ -226,7 +228,9 @@ class CacheDataset(Dataset):
226228
and the outcome not cached.
227229
"""
228230

229-
def __init__(self, data, transform, cache_num=sys.maxsize, cache_rate=1.0, num_workers=0):
231+
def __init__(
232+
self, data, transform: Callable, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0
233+
):
230234
"""
231235
Args:
232236
data (Iterable): input data to load and transform to generate dataset for model.
@@ -278,7 +282,7 @@ def __getitem__(self, index):
278282
# load data from cache and execute from the first random transform
279283
start_run = False
280284
data = self._cache[index]
281-
for _transform in self.transform.transforms:
285+
for _transform in self.transform.transforms: # pytype: disable=attribute-error
282286
if not start_run and not isinstance(_transform, Randomizable):
283287
continue
284288
else:
@@ -290,7 +294,7 @@ def __getitem__(self, index):
290294
return data
291295

292296

293-
class ZipDataset(torch.utils.data.Dataset):
297+
class ZipDataset(_TorchDataset):
294298
"""
295299
Zip several PyTorch datasets and output data(with the same index) together in a tuple.
296300
If the output of single dataset is already a tuple, flatten it and extend to the result.
@@ -323,7 +327,7 @@ def __init__(self, datasets, transform=None):
323327
def __len__(self):
324328
return self.len
325329

326-
def __getitem__(self, index):
330+
def __getitem__(self, index: int):
327331
def to_list(x):
328332
return list(x) if isinstance(x, (tuple, list)) else [x]
329333

@@ -375,7 +379,13 @@ def __call__(self, input_):
375379
"""
376380

377381
def __init__(
378-
self, img_files, img_transform=None, seg_files=None, seg_transform=None, labels=None, label_transform=None
382+
self,
383+
img_files,
384+
img_transform: Optional[Callable] = None,
385+
seg_files=None,
386+
seg_transform: Optional[Callable] = None,
387+
labels=None,
388+
label_transform: Optional[Callable] = None,
379389
):
380390
"""
381391
Initializes the dataset with the filename lists. The transform `img_transform` is applied
@@ -396,7 +406,7 @@ def __init__(
396406
def randomize(self):
397407
self.seed = self.R.randint(np.iinfo(np.int32).max)
398408

399-
def __getitem__(self, index):
409+
def __getitem__(self, index: int):
400410
self.randomize()
401411
for dataset in self.datasets:
402412
if isinstance(dataset.transform, Randomizable):

monai/data/grid_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class GridPatchDataset(IterableDataset):
2222
Yields patches from arrays read from an input dataset. The patches are chosen in a contiguous grid sampling scheme.
2323
"""
2424

25-
def __init__(self, dataset, patch_size, start_pos=(), pad_mode="wrap", **pad_opts):
25+
def __init__(self, dataset, patch_size, start_pos=(), pad_mode: str = "wrap", **pad_opts):
2626
"""
2727
Initializes this dataset in terms of the input dataset and patch size. The `patch_size` is the size of the
2828
patch to sample from the input arrays. Tt is assumed the arrays first dimension is the channel dimension which

monai/data/nifti_reader.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from typing import Optional, Callable
13+
1214
import numpy as np
1315
from torch.utils.data import Dataset
1416
from monai.transforms import LoadNifti
@@ -27,11 +29,11 @@ def __init__(
2729
image_files,
2830
seg_files=None,
2931
labels=None,
30-
as_closest_canonical=False,
31-
transform=None,
32-
seg_transform=None,
33-
image_only=True,
34-
dtype=np.float32,
32+
as_closest_canonical: bool = False,
33+
transform: Optional[Callable] = None,
34+
seg_transform: Optional[Callable] = None,
35+
image_only: bool = True,
36+
dtype: Optional[np.dtype] = np.float32,
3537
):
3638
"""
3739
Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
@@ -67,7 +69,7 @@ def __len__(self):
6769
def randomize(self):
6870
self.seed = self.R.randint(np.iinfo(np.int32).max)
6971

70-
def __getitem__(self, index):
72+
def __getitem__(self, index: int):
7173
self.randomize()
7274
meta_data = None
7375
img_loader = LoadNifti(

monai/data/nifti_saver.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from typing import Union, Optional
13+
1214
import numpy as np
1315
import torch
1416
from monai.data.nifti_writer import write_nifti
@@ -25,14 +27,14 @@ class NiftiSaver:
2527

2628
def __init__(
2729
self,
28-
output_dir="./",
29-
output_postfix="seg",
30-
output_ext=".nii.gz",
31-
resample=True,
32-
interp_order=3,
33-
mode="constant",
34-
cval=0,
35-
dtype=None,
30+
output_dir: str = "./",
31+
output_postfix: str = "seg",
32+
output_ext: str = ".nii.gz",
33+
resample: bool = True,
34+
interp_order: int = 3,
35+
mode: str = "constant",
36+
cval: Union[int, float] = 0,
37+
dtype: Optional[np.dtype] = None,
3638
):
3739
"""
3840
Args:
@@ -63,7 +65,7 @@ def __init__(
6365
self.dtype = dtype
6466
self._data_index = 0
6567

66-
def save(self, data, meta_data=None):
68+
def save(self, data: Union[torch.Tensor, np.ndarray], meta_data=None):
6769
"""
6870
Save data into a Nifti file.
6971
The metadata could optionally have the following keys:
@@ -108,7 +110,7 @@ def save(self, data, meta_data=None):
108110
dtype=self.dtype or data.dtype,
109111
)
110112

111-
def save_batch(self, batch_data, meta_data=None):
113+
def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data=None):
112114
"""Save a batch of data into Nifti format files.
113115
114116
args:

monai/data/png_saver.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from typing import Union
13+
1214
import torch
1315
import numpy as np
1416
from monai.data.png_writer import write_png
@@ -25,14 +27,14 @@ class PNGSaver:
2527

2628
def __init__(
2729
self,
28-
output_dir="./",
29-
output_postfix="seg",
30-
output_ext=".png",
31-
resample=True,
32-
interp_order=3,
33-
mode="constant",
34-
cval=0,
35-
scale=False,
30+
output_dir: str = "./",
31+
output_postfix: str = "seg",
32+
output_ext: str = ".png",
33+
resample: bool = True,
34+
interp_order: int = 3,
35+
mode: str = "constant",
36+
cval: Union[int, float] = 0,
37+
scale: bool = False,
3638
):
3739
"""
3840
Args:
@@ -61,7 +63,7 @@ def __init__(
6163
self.scale = scale
6264
self._data_index = 0
6365

64-
def save(self, data, meta_data=None):
66+
def save(self, data: Union[torch.Tensor, np.ndarray], meta_data=None):
6567
"""
6668
Save data into a png file.
6769
The metadata could optionally have the following keys:
@@ -108,7 +110,7 @@ def save(self, data, meta_data=None):
108110
scale=self.scale,
109111
)
110112

111-
def save_batch(self, batch_data, meta_data=None):
113+
def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data=None):
112114
"""Save a batch of data into png format files.
113115
114116
args:

monai/data/png_writer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414

1515

1616
def write_png(
17-
data, file_name, output_shape=None, interp_order=3, mode="constant", cval=0, scale=False, plugin=None, **plugin_args
17+
data,
18+
file_name,
19+
output_shape=None,
20+
interp_order=3,
21+
mode="constant",
22+
cval=0,
23+
scale=False,
24+
plugin=None,
25+
**plugin_args,
1826
):
1927
"""
2028
Write numpy data into png files to disk.

0 commit comments

Comments
 (0)