|
546 | 546 | " class_folder = os.path.join(self.image_dir, str(class_label+1)) \n",
|
547 | 547 | " img_path = os.path.join(class_folder, img_name)\n",
|
548 | 548 | "\n",
|
549 |
| - " # Load the image\n", |
550 | 549 | " image = Image.open(img_path).convert('RGB')\n",
|
551 | 550 | "\n",
|
552 |
| - " # Apply transformations\n", |
553 | 551 | " if self.transform:\n",
|
554 | 552 | " image = self.transform(image)\n",
|
555 | 553 | "\n",
|
|
572 | 570 | },
|
573 | 571 | "outputs": [],
|
574 | 572 | "source": [
|
575 |
| - "image_directory = r\"/kaggle/input/raf-db-dataset/DATASET/train\" # Directory containing class subfolders\n", |
| 573 | + "image_directory = r\"/kaggle/input/raf-db-dataset/DATASET/train\" \n", |
576 | 574 | "csv_file_path = r\"/kaggle/input/raf-db-dataset/train_labels.csv\""
|
577 | 575 | ]
|
578 | 576 | },
|
|
668 | 666 | "torch.cuda.empty_cache()"
|
669 | 667 | ]
|
670 | 668 | },
|
671 |
| - { |
672 |
| - "cell_type": "code", |
673 |
| - "execution_count": 157, |
674 |
| - "id": "246285eb", |
675 |
| - "metadata": { |
676 |
| - "execution": { |
677 |
| - "iopub.execute_input": "2024-10-27T05:51:15.818480Z", |
678 |
| - "iopub.status.busy": "2024-10-27T05:51:15.818189Z", |
679 |
| - "iopub.status.idle": "2024-10-27T05:51:15.822552Z", |
680 |
| - "shell.execute_reply": "2024-10-27T05:51:15.821714Z", |
681 |
| - "shell.execute_reply.started": "2024-10-27T05:51:15.818441Z" |
682 |
| - } |
683 |
| - }, |
684 |
| - "outputs": [], |
685 |
| - "source": [ |
686 |
| - "# checkpoint = torch.load(r\"/kaggle/working/models/epoch_19_acc_79.73_PE_aug_1head_unfreeze_tot20epoch_4heads.pth\")" |
687 |
| - ] |
688 |
| - }, |
689 | 669 | {
|
690 | 670 | "cell_type": "code",
|
691 | 671 | "execution_count": 158,
|
|
703 | 683 | "outputs": [],
|
704 | 684 | "source": [
|
705 | 685 | "criterion = nn.CrossEntropyLoss()\n",
|
706 |
| - "# optimizer = optim.Adam(model.parameters(), lr=0.001)\n", |
707 |
| - "optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)\n", |
708 |
| - "# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])" |
| 686 | + "optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)" |
709 | 687 | ]
|
710 | 688 | },
|
711 | 689 | {
|
|
834 | 812 | "outputs": [],
|
835 | 813 | "source": [
|
836 | 814 | "def train(model, train_loader, test_loader, criterion, optimizer, device, num_epochs=20):\n",
|
837 |
| - " # train_loader = pl.MpDeviceLoader(train_loader, device)\n", |
838 | 815 | " best_accuracy = 75\n",
|
839 | 816 | " model.train()\n",
|
840 |
| - " # os.makedirs('models', exist_ok=True)\n", |
841 | 817 | " for epoch in range(1, num_epochs+1):\n",
|
842 | 818 | " running_loss = 0.0\n",
|
843 | 819 | "\n",
|
|
851 | 827 | "\n",
|
852 | 828 | " loss.backward()\n",
|
853 | 829 | " optimizer.step()\n",
|
854 |
| - " # xm.optimizer_step(optimizer)\n", |
855 | 830 | "\n",
|
856 | 831 | " running_loss += loss.item()\n",
|
857 | 832 | " epoch_loss = running_loss / len(train_loader)\n",
|
|
868 | 843 | " logits = model(images)\n",
|
869 | 844 | " \n",
|
870 | 845 | " _, predicted = torch.max(logits, 1)\n",
|
871 |
| - " total += labels.size(0) # Add the batch size to total\n", |
| 846 | + " total += labels.size(0) \n", |
872 | 847 | " correct += (predicted == labels).sum().item()\n",
|
873 | 848 | " accuracy = 100*(correct/total)\n",
|
874 | 849 | " print(f\"test_Acc : {accuracy}\")\n",
|
|
2650 | 2625 | " logits = model(images)\n",
|
2651 | 2626 | " \n",
|
2652 | 2627 | " _, predicted = torch.max(logits, 1)\n",
|
2653 |
| - " total += labels.size(0) # Add the batch size to total\n", |
| 2628 | + " total += labels.size(0) \n", |
2654 | 2629 | " correct += (predicted == labels).sum().item()\n",
|
2655 | 2630 | "accuracy = 100*(correct/total)\n",
|
2656 | 2631 | "\n",
|
|
2660 | 2635 | "accuracy\n"
|
2661 | 2636 | ]
|
2662 | 2637 | },
|
2663 |
| - { |
2664 |
| - "cell_type": "code", |
2665 |
| - "execution_count": 297, |
2666 |
| - "id": "307848e9", |
2667 |
| - "metadata": { |
2668 |
| - "execution": { |
2669 |
| - "iopub.execute_input": "2024-10-27T08:36:01.208701Z", |
2670 |
| - "iopub.status.busy": "2024-10-27T08:36:01.208262Z", |
2671 |
| - "iopub.status.idle": "2024-10-27T08:36:01.213300Z", |
2672 |
| - "shell.execute_reply": "2024-10-27T08:36:01.212375Z", |
2673 |
| - "shell.execute_reply.started": "2024-10-27T08:36:01.208658Z" |
2674 |
| - }, |
2675 |
| - "scrolled": true |
2676 |
| - }, |
2677 |
| - "outputs": [], |
2678 |
| - "source": [ |
2679 |
| - "# checkpoint = torch.load(r\"/kaggle/working/models/tot20epoch_4heads_epoch_13_loss_0.07219641906097725.pth\")" |
2680 |
| - ] |
2681 |
| - }, |
2682 | 2638 | {
|
2683 | 2639 | "cell_type": "code",
|
2684 | 2640 | "execution_count": 298,
|
|
0 commit comments