Skip to content

Commit 4fd7bf5

Browse files
committed
.
1 parent c918eec commit 4fd7bf5

File tree

1 file changed

+4
-48
lines changed

1 file changed

+4
-48
lines changed

model-with-mlp-1-head-aug-patch-extraction.ipynb

Lines changed: 4 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -546,10 +546,8 @@
546546
" class_folder = os.path.join(self.image_dir, str(class_label+1)) \n",
547547
" img_path = os.path.join(class_folder, img_name)\n",
548548
"\n",
549-
" # Load the image\n",
550549
" image = Image.open(img_path).convert('RGB')\n",
551550
"\n",
552-
" # Apply transformations\n",
553551
" if self.transform:\n",
554552
" image = self.transform(image)\n",
555553
"\n",
@@ -572,7 +570,7 @@
572570
},
573571
"outputs": [],
574572
"source": [
575-
"image_directory = r\"/kaggle/input/raf-db-dataset/DATASET/train\" # Directory containing class subfolders\n",
573+
"image_directory = r\"/kaggle/input/raf-db-dataset/DATASET/train\" \n",
576574
"csv_file_path = r\"/kaggle/input/raf-db-dataset/train_labels.csv\""
577575
]
578576
},
@@ -668,24 +666,6 @@
668666
"torch.cuda.empty_cache()"
669667
]
670668
},
671-
{
672-
"cell_type": "code",
673-
"execution_count": 157,
674-
"id": "246285eb",
675-
"metadata": {
676-
"execution": {
677-
"iopub.execute_input": "2024-10-27T05:51:15.818480Z",
678-
"iopub.status.busy": "2024-10-27T05:51:15.818189Z",
679-
"iopub.status.idle": "2024-10-27T05:51:15.822552Z",
680-
"shell.execute_reply": "2024-10-27T05:51:15.821714Z",
681-
"shell.execute_reply.started": "2024-10-27T05:51:15.818441Z"
682-
}
683-
},
684-
"outputs": [],
685-
"source": [
686-
"# checkpoint = torch.load(r\"/kaggle/working/models/epoch_19_acc_79.73_PE_aug_1head_unfreeze_tot20epoch_4heads.pth\")"
687-
]
688-
},
689669
{
690670
"cell_type": "code",
691671
"execution_count": 158,
@@ -703,9 +683,7 @@
703683
"outputs": [],
704684
"source": [
705685
"criterion = nn.CrossEntropyLoss()\n",
706-
"# optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
707-
"optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)\n",
708-
"# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])"
686+
"optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)"
709687
]
710688
},
711689
{
@@ -834,10 +812,8 @@
834812
"outputs": [],
835813
"source": [
836814
"def train(model, train_loader, test_loader, criterion, optimizer, device, num_epochs=20):\n",
837-
" # train_loader = pl.MpDeviceLoader(train_loader, device)\n",
838815
" best_accuracy = 75\n",
839816
" model.train()\n",
840-
" # os.makedirs('models', exist_ok=True)\n",
841817
" for epoch in range(1, num_epochs+1):\n",
842818
" running_loss = 0.0\n",
843819
"\n",
@@ -851,7 +827,6 @@
851827
"\n",
852828
" loss.backward()\n",
853829
" optimizer.step()\n",
854-
" # xm.optimizer_step(optimizer)\n",
855830
"\n",
856831
" running_loss += loss.item()\n",
857832
" epoch_loss = running_loss / len(train_loader)\n",
@@ -868,7 +843,7 @@
868843
" logits = model(images)\n",
869844
" \n",
870845
" _, predicted = torch.max(logits, 1)\n",
871-
" total += labels.size(0) # Add the batch size to total\n",
846+
" total += labels.size(0) \n",
872847
" correct += (predicted == labels).sum().item()\n",
873848
" accuracy = 100*(correct/total)\n",
874849
" print(f\"test_Acc : {accuracy}\")\n",
@@ -2650,7 +2625,7 @@
26502625
" logits = model(images)\n",
26512626
" \n",
26522627
" _, predicted = torch.max(logits, 1)\n",
2653-
" total += labels.size(0) # Add the batch size to total\n",
2628+
" total += labels.size(0) \n",
26542629
" correct += (predicted == labels).sum().item()\n",
26552630
"accuracy = 100*(correct/total)\n",
26562631
"\n",
@@ -2660,25 +2635,6 @@
26602635
"accuracy\n"
26612636
]
26622637
},
2663-
{
2664-
"cell_type": "code",
2665-
"execution_count": 297,
2666-
"id": "307848e9",
2667-
"metadata": {
2668-
"execution": {
2669-
"iopub.execute_input": "2024-10-27T08:36:01.208701Z",
2670-
"iopub.status.busy": "2024-10-27T08:36:01.208262Z",
2671-
"iopub.status.idle": "2024-10-27T08:36:01.213300Z",
2672-
"shell.execute_reply": "2024-10-27T08:36:01.212375Z",
2673-
"shell.execute_reply.started": "2024-10-27T08:36:01.208658Z"
2674-
},
2675-
"scrolled": true
2676-
},
2677-
"outputs": [],
2678-
"source": [
2679-
"# checkpoint = torch.load(r\"/kaggle/working/models/tot20epoch_4heads_epoch_13_loss_0.07219641906097725.pth\")"
2680-
]
2681-
},
26822638
{
26832639
"cell_type": "code",
26842640
"execution_count": 298,

0 commit comments

Comments
 (0)