Skip to content
This repository has been archived by the owner on May 6, 2023. It is now read-only.

Commit

Permalink
chore: add assert statement for dataset name
Browse files Browse the repository at this point in the history
  • Loading branch information
AFAgarap committed Nov 2, 2020
1 parent 2ea0582 commit 7bcf0fe
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pt_datasets/load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,16 @@ def load_dataset(
Tuple[object, object]
A tuple consisting of the training dataset and the test dataset.
"""
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
supported_datasets = ["mnist", "fashion_mnist", "emnist", "cifar10", "svhn"]

name = name.lower()

assert (
name in supported_datasets
), f"[ERROR] Dataset {name} is not supported. Supported datasets: mnist, fashion_mnist, emnist, cifar10, svhn."

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

if name == "mnist":
train_dataset = torchvision.datasets.MNIST(
root=data_folder, train=True, download=True, transform=transform
Expand Down

0 comments on commit 7bcf0fe

Please sign in to comment.