Skip to content
This repository was archived by the owner on May 6, 2023. It is now read-only.

Commit 9db51cc

Browse files
committed
fix: add preprocessing batch size when loading dataset
Merge branch 'add-preprocessing-bsize'
2 parents 66e8068 + 413ea6b commit 9db51cc

File tree

2 files changed

+48
-9
lines changed

2 files changed

+48
-9
lines changed

pt_datasets/load_dataset.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def load_dataset(
4545
return_vectorizer: bool = False,
4646
image_size: int = 64,
4747
preprocessed_covidx: bool = False,
48+
preprocessing_bsize: int = 2048,
4849
) -> Tuple[object, object]:
4950
"""
5051
Returns a tuple of torchvision dataset objects.
@@ -85,6 +86,8 @@ def load_dataset(
8586
Whether to use the preprocessed COVID19 datasets or not.
8687
This requires the use of `modules/export_covid19_dataset`
8788
in the package repository.
89+
preprocessing_bsize: int
90+
The batch size to use for preprocessing the COVID19 dataset.
8891
8992
Returns
9093
-------
@@ -212,11 +215,17 @@ def load_dataset(
212215
train_dataset, test_dataset = load_wdbc()
213216
elif name == "binary_covid":
214217
train_dataset, test_dataset = load_binary_covid19(
215-
transform=transform, size=image_size, preprocessed=preprocessed_covidx
218+
transform=transform,
219+
size=image_size,
220+
preprocessed=preprocessed_covidx,
221+
preprocessing_bsize=preprocessing_bsize,
216222
)
217223
elif name == "multi_covid":
218224
train_dataset, test_dataset = load_multi_covid19(
219-
transform=transform, size=image_size, preprocessed=preprocessed_covidx
225+
transform=transform,
226+
size=image_size,
227+
preprocessed=preprocessed_covidx,
228+
preprocessing_bsize=preprocessing_bsize,
220229
)
221230
return (
222231
(train_dataset, test_dataset, vectorizer)
@@ -403,7 +412,10 @@ def load_wdbc(test_size: float = 3e-1, seed: int = 42):
403412

404413

405414
def load_binary_covid19(
406-
transform: torchvision.transforms, size: int = 64, preprocessed: bool = False
415+
transform: torchvision.transforms,
416+
size: int = 64,
417+
preprocessed: bool = False,
418+
preprocessing_bsize: int = 2048,
407419
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
408420
"""
409421
Returns a tuple of the tensor datasets for the
@@ -417,6 +429,8 @@ def load_binary_covid19(
417429
The size to use for image resizing.
418430
preprocessed: bool
419431
Whether to load preprocessed dataset or not.
432+
preprocessing_bsize: int
433+
The batch size to use for preprocessing the dataset.
420434
421435
Returns
422436
-------
@@ -432,14 +446,27 @@ def load_binary_covid19(
432446
download_binary_covid19_dataset()
433447
unzip_dataset(os.path.join(dataset_path, "BinaryCOVID19Dataset.tar.xz"))
434448
(train_data, test_data) = (
435-
BinaryCOVID19Dataset(train=True, preprocessed=preprocessed, size=size),
436-
BinaryCOVID19Dataset(train=False, preprocessed=preprocessed, size=size),
449+
BinaryCOVID19Dataset(
450+
train=True,
451+
preprocessed=preprocessed,
452+
size=size,
453+
preprocessing_bsize=preprocessing_bsize,
454+
),
455+
BinaryCOVID19Dataset(
456+
train=False,
457+
preprocessed=preprocessed,
458+
size=size,
459+
preprocessing_bsize=preprocessing_bsize,
460+
),
437461
)
438462
return train_data, test_data
439463

440464

441465
def load_multi_covid19(
442-
transform: torchvision.transforms, size: int = 64, preprocessed: bool = False
466+
transform: torchvision.transforms,
467+
size: int = 64,
468+
preprocessed: bool = False,
469+
preprocessing_bsize: int = 2048,
443470
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
444471
"""
445472
Returns a tuple of the tensor datasets for the
@@ -453,6 +480,8 @@ def load_multi_covid19(
453480
The size to use for image resizing.
454481
preprocessed: bool
455482
Whether to load preprocessed dataset or not.
483+
preprocessing_bsize: int
484+
The batch size to use for preprocessing the dataset.
456485
457486
Returns
458487
-------
@@ -468,7 +497,17 @@ def load_multi_covid19(
468497
download_covidx5_dataset()
469498
unzip_dataset(os.path.join(dataset_path, "MultiCOVID19Dataset.tar.xz"))
470499
(train_data, test_data) = (
471-
MultiCOVID19Dataset(train=True, preprocessed=preprocessed, size=size),
472-
MultiCOVID19Dataset(train=False, preprocessed=preprocessed, size=size),
500+
MultiCOVID19Dataset(
501+
train=True,
502+
preprocessed=preprocessed,
503+
size=size,
504+
preprocessing_bsize=preprocessing_bsize,
505+
),
506+
MultiCOVID19Dataset(
507+
train=False,
508+
preprocessed=preprocessed,
509+
size=size,
510+
preprocessing_bsize=preprocessing_bsize,
511+
),
473512
)
474513
return train_data, test_data

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def _post_install():
2222

2323
setup(
2424
name="pt-datasets",
25-
version="0.11.2",
25+
version="0.11.3",
2626
packages=["pt_datasets"],
2727
url="https://github.com/AFAgarap/pt-datasets",
2828
license="AGPL-3.0 License",

0 commit comments

Comments (0)