@@ -45,6 +45,7 @@ def load_dataset(
45
45
return_vectorizer : bool = False ,
46
46
image_size : int = 64 ,
47
47
preprocessed_covidx : bool = False ,
48
+ preprocessing_bsize : int = 2048 ,
48
49
) -> Tuple [object , object ]:
49
50
"""
50
51
Returns a tuple of torchvision dataset objects.
@@ -85,6 +86,8 @@ def load_dataset(
85
86
Whether to use the preprocessed COVID19 datasets or not.
86
87
This requires the use of `modules/export_covid19_dataset`
87
88
in the package repository.
89
+ preprocessing_bsize: int
90
+ The batch size to use for preprocessing the COVID19 dataset.
88
91
89
92
Returns
90
93
-------
@@ -212,11 +215,17 @@ def load_dataset(
212
215
train_dataset , test_dataset = load_wdbc ()
213
216
elif name == "binary_covid" :
214
217
train_dataset , test_dataset = load_binary_covid19 (
215
- transform = transform , size = image_size , preprocessed = preprocessed_covidx
218
+ transform = transform ,
219
+ size = image_size ,
220
+ preprocessed = preprocessed_covidx ,
221
+ preprocessing_bsize = preprocessing_bsize ,
216
222
)
217
223
elif name == "multi_covid" :
218
224
train_dataset , test_dataset = load_multi_covid19 (
219
- transform = transform , size = image_size , preprocessed = preprocessed_covidx
225
+ transform = transform ,
226
+ size = image_size ,
227
+ preprocessed = preprocessed_covidx ,
228
+ preprocessing_bsize = preprocessing_bsize ,
220
229
)
221
230
return (
222
231
(train_dataset , test_dataset , vectorizer )
@@ -403,7 +412,10 @@ def load_wdbc(test_size: float = 3e-1, seed: int = 42):
403
412
404
413
405
414
def load_binary_covid19 (
406
- transform : torchvision .transforms , size : int = 64 , preprocessed : bool = False
415
+ transform : torchvision .transforms ,
416
+ size : int = 64 ,
417
+ preprocessed : bool = False ,
418
+ preprocessing_bsize : int = 2048 ,
407
419
) -> Tuple [torch .utils .data .Dataset , torch .utils .data .Dataset ]:
408
420
"""
409
421
Returns a tuple of the tensor datasets for the
@@ -417,6 +429,8 @@ def load_binary_covid19(
417
429
The size to use for image resizing.
418
430
preprocessed: bool
419
431
Whether to load preprocessed dataset or not.
432
+ preprocessing_bsize: int
433
+ The batch size to use for preprocessing the dataset.
420
434
421
435
Returns
422
436
-------
@@ -432,14 +446,27 @@ def load_binary_covid19(
432
446
download_binary_covid19_dataset ()
433
447
unzip_dataset (os .path .join (dataset_path , "BinaryCOVID19Dataset.tar.xz" ))
434
448
(train_data , test_data ) = (
435
- BinaryCOVID19Dataset (train = True , preprocessed = preprocessed , size = size ),
436
- BinaryCOVID19Dataset (train = False , preprocessed = preprocessed , size = size ),
449
+ BinaryCOVID19Dataset (
450
+ train = True ,
451
+ preprocessed = preprocessed ,
452
+ size = size ,
453
+ preprocessing_bsize = preprocessing_bsize ,
454
+ ),
455
+ BinaryCOVID19Dataset (
456
+ train = False ,
457
+ preprocessed = preprocessed ,
458
+ size = size ,
459
+ preprocessing_bsize = preprocessing_bsize ,
460
+ ),
437
461
)
438
462
return train_data , test_data
439
463
440
464
441
465
def load_multi_covid19 (
442
- transform : torchvision .transforms , size : int = 64 , preprocessed : bool = False
466
+ transform : torchvision .transforms ,
467
+ size : int = 64 ,
468
+ preprocessed : bool = False ,
469
+ preprocessing_bsize : int = 2048 ,
443
470
) -> Tuple [torch .utils .data .Dataset , torch .utils .data .Dataset ]:
444
471
"""
445
472
Returns a tuple of the tensor datasets for the
@@ -453,6 +480,8 @@ def load_multi_covid19(
453
480
The size to use for image resizing.
454
481
preprocessed: bool
455
482
Whether to load preprocessed dataset or not.
483
+ preprocessing_bsize: int
484
+ The batch size to use for preprocessing the dataset.
456
485
457
486
Returns
458
487
-------
@@ -468,7 +497,17 @@ def load_multi_covid19(
468
497
download_covidx5_dataset ()
469
498
unzip_dataset (os .path .join (dataset_path , "MultiCOVID19Dataset.tar.xz" ))
470
499
(train_data , test_data ) = (
471
- MultiCOVID19Dataset (train = True , preprocessed = preprocessed , size = size ),
472
- MultiCOVID19Dataset (train = False , preprocessed = preprocessed , size = size ),
500
+ MultiCOVID19Dataset (
501
+ train = True ,
502
+ preprocessed = preprocessed ,
503
+ size = size ,
504
+ preprocessing_bsize = preprocessing_bsize ,
505
+ ),
506
+ MultiCOVID19Dataset (
507
+ train = False ,
508
+ preprocessed = preprocessed ,
509
+ size = size ,
510
+ preprocessing_bsize = preprocessing_bsize ,
511
+ ),
473
512
)
474
513
return train_data , test_data
0 commit comments