Skip to content

Commit

Permalink
feat: add reverse transformation possibility in inference notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
ttrenty committed May 5, 2024
1 parent 2ef6e6b commit eed41ba
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 18 deletions.
2 changes: 1 addition & 1 deletion GANs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
"metadata": {},
"outputs": [],
"source": [
"start_epoch = 1 # epoch to start training from\n",
"start_epoch = 0 # epoch to start training from\n",
"n_epochs = 200 # number of epochs of training\n",
"decay_epoch = 100 # epoch from which to start lr decay\n",
"n_cpu = 8 # number of cpu threads to use during batch generation\n",
Expand Down
101 changes: 84 additions & 17 deletions Inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,31 @@
"\n",
"# -- Adapted to both noise vectors and images input -- #\n",
"\n",
"mother_dataset = \"apple2orange64\"\n",
"dataset = \"apple2orange64\"\n",
"# dataset = \"orange2apple64\""
"\n",
"# -- Opposite transformation\n",
"\n",
"# mother_dataset = \"apple2orange64\"\n",
"# dataset = \"custom\"\n",
"# dataset_path = \"inferences/cycle-gan/\"\n",
"# source_dataset_name = \"apple2orange64\"\n",
"# dataset_name = \"orange2apple64\"\n",
"# my_class_A = 1\n",
"# my_class_B = \"A\"\n",
"\n",
"# mother_dataset = \"apple2orange64\"\n",
"# dataset = \"orange2apple64\"\n",
"\n",
"# -- Opposite transformation\n",
"\n",
"# mother_dataset = \"apple2orange64\"\n",
"# dataset = \"custom\"\n",
"# dataset_path = \"inferences/cycle-gan/\"\n",
"# source_dataset_name = \"apple2orange64\"\n",
"# dataset_name = \"apple2orange64\"\n",
"# my_class_A = 0\n",
"# my_class_B = \"B\"\n"
]
},
{
Expand All @@ -84,8 +107,8 @@
"import os\n",
"\n",
"# generator_name = \"generator_199\" # no cycle-gan\n",
"generator_name = \"G_AB_136\" # cycle-gan\n",
"# generator_name = \"G_BA_136\" # cycle-gan\n",
"generator_name = \"G_AB_185\" # cycle-gan\n",
"# generator_name = \"G_BA_185\" # cycle-gan\n",
"\n",
"# is_trained_with_cycleGAN = False\n",
"is_trained_with_cycleGAN = True\n",
Expand All @@ -94,12 +117,16 @@
"if is_trained_with_cycleGAN:\n",
" path_model = \"cycle-gan/\" + model\n",
"\n",
"model_path = \"saved_models/\" + path_model + \"/\" + dataset + \"/\" + generator_name + \".pth\"\n",
"model_path = \"saved_models/\" + path_model + \"/\" + mother_dataset + \"/\" + generator_name + \".pth\"\n",
"if dataset == \"custom\":\n",
" model_path = \"saved_models/\" + path_model + \"/\" + dataset_name + \"/\" + generator_name + \".pth\"\n",
" if is_trained_with_cycleGAN:\n",
" model_path = \"saved_models/\" + path_model + \"/\" + source_dataset_name + \"/\" + generator_name + \".pth\"\n",
"\n",
"# Make sure that model exists\n",
"\n",
"if not os.path.isfile(model_path):\n",
" print(\"Model does not exist\")\n",
" print(f\"Model {model_path} does not exist.\")\n",
" exit()\n",
"else:\n",
" print(\"Loaded model\", model_path)"
Expand All @@ -123,6 +150,7 @@
"import time\n",
"import glob\n",
"import random\n",
"import classifier_data\n",
"\n",
"import numpy as np\n",
"\n",
Expand Down Expand Up @@ -186,6 +214,11 @@
" latent_dim = 300\n",
" img_size = 64\n",
"\n",
"elif dataset == \"custom\":\n",
" channels = 3\n",
" latent_dim = 300\n",
" img_size = 64\n",
"\n",
"else:\n",
" raise Exception(\"Unknown dataset\")"
]
Expand All @@ -210,7 +243,7 @@
" from dcgan import Generator\n",
"\n",
"elif model == \"fdcgan\" or model == \"resnet\":\n",
" if dataset not in [\"apple2orange64\", \"orange2apple64\"]:\n",
" if dataset not in [\"apple2orange64\", \"orange2apple64\", \"custom\"]:\n",
" raise Exception(f\"Dataset {dataset} has no input image for the generator\")\n",
" if model == \"fdcgan\":\n",
" from fdcgan import Generator\n",
Expand Down Expand Up @@ -322,8 +355,6 @@
"\n",
"transforms_ = [\n",
" transforms.Resize(int(img_size * 1.12), Image.BICUBIC),\n",
" transforms.RandomCrop((img_size, img_size)),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
"]\n",
Expand All @@ -347,11 +378,11 @@
" )\n",
"\n",
"if dataset == \"apple2orange64\":\n",
" my_class_A = \"A\"\n",
" my_class_A = 0\n",
" my_class_B = \"B\"\n",
"\n",
"if dataset == \"orange2apple64\":\n",
" my_class_A = \"B\"\n",
" my_class_A = 1\n",
" my_class_B = \"A\"\n",
"\n",
"if dataset == \"apple2orange64\" or dataset == \"orange2apple64\":\n",
Expand All @@ -361,12 +392,34 @@
" subprocess.run(command, shell=True)\n",
"\n",
" # Test data loader\n",
" input_dataset = classifier_data.BinaryClassificationImageDataset(\n",
" \"./datasets/apple2orange64/\",\n",
" transformations=transforms_,\n",
" mode=\"validation\"\n",
" )\n",
"\n",
" dataloader = DataLoader(\n",
" ImageDataset(\"./datasets/apple2orange64\", transforms_=transforms_, unaligned=True, mode=\"validation\"),\n",
" input_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" shuffle=False,\n",
" num_workers=n_cpu,\n",
" )"
" )\n",
"\n",
"if dataset == \"custom\":\n",
"\n",
" input_dataset = classifier_data.BinaryClassificationImageDataset(\n",
" \"./inferences/\" + model,\n",
" transformations=transforms_,\n",
" mode=source_dataset_name\n",
" )\n",
" dataloader = DataLoader(\n",
" input_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=n_cpu,\n",
" )\n",
"\n",
"print(\"Number of mini-batches: \", len(dataloader))"
]
},
{
Expand All @@ -390,6 +443,9 @@
"outputs": [],
"source": [
"image_inference_folder = \"inferences/%s/%s/%s\" % (path_model, dataset, my_class_B)\n",
"if dataset == \"custom\":\n",
" image_inference_folder = \"inferences/%s/%s/%s\" % (path_model, dataset_name, my_class_B)\n",
"\n",
"os.makedirs(image_inference_folder, exist_ok=True)\n",
"\n",
"print(f\"Using '{model}' with '{dataset}', saving results to '{image_inference_folder}'.\")"
Expand Down Expand Up @@ -581,12 +637,24 @@
"metadata": {},
"outputs": [],
"source": [
"if (dataset == \"orange2apple64\" or dataset == \"apple2orange64\") and (model == \"fdcgan\" or model == \"resnet\"):\n",
"if (dataset == \"orange2apple64\" or dataset == \"apple2orange64\" or dataset == \"custom\") and (model == \"fdcgan\" or model == \"resnet\"):\n",
" prev_time = time.time()\n",
" print(dataloader)\n",
" for i, batch in enumerate(dataloader):\n",
"\n",
" # Set model input\n",
" real_A = Variable(batch[my_class_A].type(Tensor))\n",
" if dataset == \"custom\":\n",
" x, y = batch\n",
" x = x.to(device)\n",
" # batch_tensor = torch.stack(x)\n",
" real_A = Variable(x)\n",
" else:\n",
" x, y = batch\n",
" if (y != my_class_A):\n",
" continue\n",
" x = x.to(device)\n",
" # batch_tensor = torch.stack(x)\n",
" real_A = Variable(x)\n",
"\n",
" # ------------------\n",
" # Use Generators\n",
Expand Down Expand Up @@ -621,8 +689,7 @@
" # --------------\n",
"\n",
" image_name = image_inference_folder + \"/%d.png\" % i\n",
" save_image(fake_B, image_name, normalize=True)\n",
" \n"
" save_image(fake_B, image_name, normalize=True)"
]
}
],
Expand Down
2 changes: 2 additions & 0 deletions apple_orange_classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,11 @@
" input_dataset_folder,\n",
" transformations=transformations,\n",
" mode=\"apple2orange64\"\n",
" # mode=\"orange2apple64\"\n",
")\n",
"\n",
"input_dataset_loader = DataLoader(input_dataset, batch_size=1, shuffle=False)\n",
"\n",
"print(f\"Number of images: {len(input_dataset)}\")\n",
"print(type(input_dataset_loader))\n",
"# x_test_img, y_test = next(iter(input_dataset_loader))\n",
Expand Down

0 comments on commit eed41ba

Please sign in to comment.