Skip to content

Commit

Permalink
style: black formatting for precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
dimakis committed Nov 14, 2023
1 parent 569d1c2 commit 55455b1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
19 changes: 12 additions & 7 deletions demo-notebooks/guided-demos/download_mnist_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,32 @@
from torchvision.datasets import MNIST
from torchvision import transforms


def download_mnist_dataset(destination_dir):
# Ensure the destination directory exists
if not os.path.exists(destination_dir):
os.makedirs(destination_dir)

# Define transformations
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Download the training data
train_set = MNIST(root=destination_dir, train=True, download=True, transform=transform)
train_set = MNIST(
root=destination_dir, train=True, download=True, transform=transform
)

# Download the test data
test_set = MNIST(root=destination_dir, train=False, download=True, transform=transform)
test_set = MNIST(
root=destination_dir, train=False, download=True, transform=transform
)

print(f"MNIST dataset downloaded in {destination_dir}")


# Specify the directory where you
script_dir = os.path.dirname(os.path.abspath(__file__))
destination_dir = script_dir + "/mnist_datasets"

download_mnist_dataset(destination_dir)
download_mnist_dataset(destination_dir)
4 changes: 3 additions & 1 deletion demo-notebooks/guided-demos/mnist_disconnected.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def prepare_data(self):
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform, download=False)
mnist_full = MNIST(
self.data_dir, train=True, transform=self.transform, download=False
)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

# Assign test dataset for use in dataloader(s)
Expand Down

0 comments on commit 55455b1

Please sign in to comment.