diff --git a/Notebooks/Metrics-MixNets for Whole_Dataset.ipynb b/Notebooks/Metrics-MixNets for Whole_Dataset.ipynb deleted file mode 100644 index a7f206c..0000000 --- a/Notebooks/Metrics-MixNets for Whole_Dataset.ipynb +++ /dev/null @@ -1,964 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "EPEBQBTzcZiS" - }, - "outputs": [], - "source": [ - "import os\n", - "import re\n", - "import PIL\n", - "import sys\n", - "import sls\n", - "import cv2\n", - "import json\n", - "import time\n", - "import glob\n", - "import math\n", - "import timm\n", - "import copy\n", - "import torch\n", - "import pickle\n", - "import geffnet\n", - "import logging\n", - "import fnmatch\n", - "import argparse\n", - "import itertools\n", - "import torchvision\n", - "import numpy as np\n", - "%matplotlib inline\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "import torch.nn as nn\n", - "from PIL import Image\n", - "from glob import glob\n", - "from pathlib import Path\n", - "from copy import deepcopy\n", - "from sklearn import metrics\n", - "import torch.optim as optim\n", - "from datetime import datetime\n", - "import matplotlib.pyplot as plt\n", - "import torch.nn.functional as F\n", - "import torch.utils.data as data\n", - "from torchvision import transforms\n", - "from torch.autograd import Variable\n", - "from tqdm import tqdm, tqdm_notebook\n", - "from torch.optim import lr_scheduler\n", - "import torch.utils.model_zoo as model_zoo\n", - "from timm.models.layers.activations import *\n", - "%config InlineBackend.figure_format = 'retina'\n", - "from efficientnet_pytorch import EfficientNet\n", - "from collections import OrderedDict, defaultdict\n", - "from torchvision import transforms, models, datasets\n", - "from torch.utils.data.sampler import SubsetRandomSampler\n", - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", - "from sklearn.metrics import classification_report, confusion_matrix,accuracy_score" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 68 - }, - "colab_type": "code", - "id": "CA9wZQTds_Av", - "outputId": "e34e2986-0c5f-4723-fe0a-60e01320e951" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['COVID-19', 'normal', 'pneumonia']\n", - "{'data': 15000, 'test': 1489, 'COVID-19 Radiography Database_2020_05_18_Kaggle': 2905}\n", - "cuda:0\n", - "{0: 'COVID-19', 1: 'normal', 2: 'pneumonia'}\n" - ] - }, - { - "data": { - "text/plain": [ - "torch.Size([4, 3, 300, 300])" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_dir = '/home/linh/Downloads/Covid-19'\n", - "\n", - "# Define your transforms for the training and testing sets\n", - "data_transforms = { \n", - " 'data': transforms.Compose([\n", - " #transforms.RandomRotation(30),\n", - " #transforms.Resize(256),\n", - " transforms.Resize(340), #for EfficientNet_B3_PRUNED\n", - " #transforms.RandomResizedCrop(224), \n", - " #transforms.CenterCrop(224), \n", - " transforms.CenterCrop(300), #for EfficientNet_B3_PRUNED\n", - " #transforms.RandomHorizontalFlip(),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize([0.485, 0.456, 0.406], \n", - " [0.229, 0.224, 0.225])\n", - " ]),\n", - " 'test': transforms.Compose([\n", - " #transforms.RandomRotation(30),\n", - " #transforms.Resize(256),\n", - " transforms.Resize(340), #for EfficientNet_B3_PRUNED\n", - " #transforms.RandomResizedCrop(224), \n", - " #transforms.CenterCrop(224), \n", - " transforms.CenterCrop(300), #for EfficientNet_B3_PRUNED\n", - " #transforms.RandomHorizontalFlip(),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize([0.485, 0.456, 0.406], \n", - " [0.229, 0.224, 0.225])\n", - " ]),\n", - " 'COVID-19 Radiography Database_2020_05_18_Kaggle': transforms.Compose([\n", - " #transforms.RandomRotation(30),\n", - " #transforms.Resize(256),\n", - " transforms.Resize(340), #for EfficientNet_B3_PRUNED\n", - " #transforms.RandomResizedCrop(224), \n", - " #transforms.CenterCrop(224), \n", - " transforms.CenterCrop(300), #for EfficientNet_B3_PRUNED\n", - " #transforms.RandomHorizontalFlip(),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize([0.485, 0.456, 0.406], \n", - " [0.229, 0.224, 0.225])\n", - " ])\n", - "\n", - " }\n", - "\n", - "# Load the datasets with ImageFolder\n", - "image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),\n", - " data_transforms[x])\n", - " for x in ['data', 'test', 'COVID-19 Radiography Database_2020_05_18_Kaggle']}\n", - "\n", - "batch_size = 4\n", - "data_loader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,\n", - " shuffle=True, num_workers=4)\n", - " for x in ['data', 'test', 'COVID-19 Radiography Database_2020_05_18_Kaggle']}\n", - "dataset_sizes = {x: len(image_datasets[x]) for x in ['data', 'test', 'COVID-19 Radiography Database_2020_05_18_Kaggle']}\n", - "\n", - "class_names = image_datasets['test'].classes\n", - "print(class_names)\n", - "print(dataset_sizes)\n", - "print(device)\n", - "\n", - "\n", - "### we get the class_to_index in the data_Set but what we really need is the cat_to_names so we will create\n", - "_ = image_datasets['test'].class_to_idx\n", - "cat_to_name = {_[i]: i for i in list(_.keys())}\n", - "print(cat_to_name)\n", - " \n", - "# Run this to test the data loader\n", - "images, labels = next(iter(data_loader['test']))\n", - "images.size()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# MixNet_XL with image size = 224" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "checkpoint loaded\n", - "Training complete in 2703m 45s\n", - " precision recall f1-score support\n", - "\n", - " 0 1.000000 0.759259 0.863158 108\n", - " 1 0.967309 0.992882 0.979929 8851\n", - " 2 0.986628 0.952657 0.969345 6041\n", - "\n", - " accuracy 0.975000 15000\n", - " macro avg 0.984646 0.901599 0.937477 15000\n", - "weighted avg 0.975325 0.975000 0.974825 15000\n", - "\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAs4AAAHwCAYAAAC2dOlsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3de/xtZV0v+s93g0CgG7xVAnZQNmi7tFJJxe0FTNJ0lyke9XU00qPmBUvFzu54SS1s19bwhkf3thLTzkEPph68m0ComBfUsG2BgmSiJV5YyF34PeePMX41nczfWmP9fnOt37PWer99zddgjjmeMZ/5EuS7Pn7H81RrLQAAwNb9u82eAAAA7AoUzgAAMIHCGQAAJlA4AwDABApnAACYQOEMAAATKJwBAGAChTMAAEygcAYAgAkUzgAAMIHCGQAAJlA4AwDABApnAACYYO/NnsBG7LPvoW2z5wAbtdL8bQzQkxtvuKw2ew5J8oNvX7L0f0Hc4nZ37uK37aokzgAAMMEunTgDAOy2Vm7a7BkwR+IMAAATSJwBAHrUVjZ7BsxROAMA9GhF4dwbrRoAADCBxBkAoENNq0Z3JM4AADCBxBkAoEd6nLujcAYA6JFWje5o1QAAgAkkzgAAPbJzYHckzgAAMIHEGQCgR3qcuyNxBgCACSTOAAA9shxddxTOAAAdsnNgf7RqAADABBJnAIAeadXojsQZAAAmkDgDAPRIj3N3FM4AAD2yc2B3tGoAAMAEEmcAgB5p1eiOxBkAACaQOAMA9MhydN1ROAMA9EirRne0agAAwAQSZwCAHmnV6I7EGQAAJpA4AwB0qDUboPRG4QwA0CMPB3ZHqwYAAEwgcQYA6JGHA7sjcQYAgAkkzgAAPdLj3B2JMwAATCBxBgDo0Yrl6HqjcAYA6JFWje5o1QAAgAkkzgAAPbIcXXckzgAAMIHEGQCgR3qcu6NwBgDokVaN7mjVAACACSTOAAA9kjh3R+IMAAATSJwBADrUmp0De6NwBgDokVaN7mjVAACACSTOAAA9so5zdyTOAAAwgcQZAKBHepy7I3EGAIAJJM4AAD3S49wdhTMAQI+0anRHqwYAAEygcAYA6FFbWf5rg6rq16uqbeN108z1h23j2tO38l0nVNWnq+qqqtpSVedU1SO2cv1eVfXcqrqgqq6tqu9W1fur6ugN//CRVg0AAKb6QpKXrfHZ/ZMcm+QDCz772yTvXnD+7xbdqKpemeSkJF9P8qYk+yR5XJIzq+rZrbVT566vJKcnOT7JhUlOTXKbJI9Ncm5VPbq19p6t/7RtUzgDAPSowx7n1toXMhTPN1NVnxz/8n8s+PgLrbWXTvmOMSE+KcnFSY5qrX1vPP+KJOcneWVVvbe1dunMsMdlKJrPS/Lg1tp145g3Jvl4kjdV1Vmtte9PmcNatGoAAPRoZWX5rx2kqu6W5D5JLkvyvg3e7unj8eWrRXOSjIXy65Psm+RJc2OeMR5ftFo0j2M+k+TtSW6fobDeEIUzAMAeoqrOX+u1wVs/bTz+aWvtpgWfH1xVv1FVLxiPd9/KvY4djx9c8NkH5q5JVe2X5Ogk1yT52JQx66VVAwCgR7vIOs5V9SNJnpDkpiR/ssZlDxlfs+POSXJCa+1rM+cOSHJIkqtaa99ccJ8vj8cjZ84dnmSvJJe01m6cOGZdFM4AAHuI1to9d8Bt/9ckByV5X2vtn+Y+uybJ72d4MPCS8dzdk7w0yTFJPlpVP9tau3r87MDxuGWN71o9f9DMufWMWReFMwBAjzp8OHANq20a/33+g9bat5L87tzpc6vquAwP7d07yVOSvGaHznBJ9DgDAPSow3Wc51XVT2XoL/56kvdP/mlDS8VqW8cDZj5aTYcPzGKr56/Y4Jh1UTgDALBe23oocGsuH48HrJ4YWzYuS3LLqrrDgjFHjMeLZs5dnKG/+s5VtaibYtGYdVlqq8bYHH6fDM3Xq30kV2SY6N+01q5d5vcBAOy2Om/VGFezeGKGovVP13GL+4zHS+bOnzXe96FJ3jz32cNmrkmStNauq6rzMmzAcv8kZ29rzHotpXCuqlsneXmGH7n/GpddU1V/nmF9ve+tcQ0AALuGxyS5dZL3LngoMElSVffIsPnJytz5Byd57vj2bXPD3pihpnxhVb17ZgOUw5I8K8n1uXlB/YYMRfPJVTW7AcpRGXYPvDzJO9fxG3/IhgvnqjooySeS3DXJ1Uk+kmHZj9l+kyOS3C/D4tTHVNV9W2trPfkIAED/y9Gttmks2ilw1SlJjhgT4a+P5+6ef1tT+cWttfNmB7TWzquqU5I8L8kFVXVGhi23H5thG+1nz+0amAzbbT8qwyYnn6+qM5PcdhyzV5Knttau3P6f+MOWkTi/JEPR/KokL2mtXbXooqq6ZZLfS/KcDE9XnjTl5ltbkPsW+xyy3ZMFAGBjquonk/ynbPuhwLcm+dUkR2VombhFkn9J8o4kp7bWFm1YktbaSVX1xQwJ89OSrCT5XJJXtNbeu+D6VlWPz7Dl9pOTPDvJdUnOTXLyfHG+XtVa29gNqr6a5OLW2i9MvP6sJHdqrd1p4vVbK5zvMW2W0K+VDf4zCMBy3XjDZbXZc0iSa884een/gviR41/UxW/bVS0jcb5Dkv9nO67/mwzLlkyytYW699n3UBUHALB76vzhwD3RMpaj+06Su2zH9T85jgEAgF3GMgrnDyV5ZFU9c1sXVtWJSX45yQeX8L0AALuv1pb/YkOW0arx4iQPT/K6qjopyYczrNs8u6rGkUmOS3JYkkVbLwIAQNc2XDi31i6rqvtmWD/vIUl+I8n8H2lWG9E/nOSZrbXLNvq9AAC7NT3O3VnKBiittUuS/GJV3TnJMRl6nlf3Bd+S5MIkZ4/XAQCwLQrn7ix1y+2xMFYcAwCw21lq4QwAwJL0v3PgHmcZq2oAAMBuT+IMANAjPc7dUTgDAPTIusvd0aoBAAATSJwBAHqkVaM7EmcAAJhA4gwA0COJc3ckzgAAMIHEGQCgRzZA6Y7CGQCgQ23FcnS90aoBAAATSJwBAHrk4cDuSJwBAGACiTMAQI88HNgdhTMAQI88HNgdrRoAADCBxBkAoEceDuyOxBkAACaQOAMA9Eji3B2FMwBAj5qHA3ujVQMAACaQOAMA9EirRnckzgAAMIHEGQCgRzZA6Y7EGQAAJpA4AwD0qOlx7o3CGQCgR1o1uqNVAwAAJpA4AwB0qFmOrjsSZwAAmEDiDADQIz3O3VE4AwD0yKoa3dGqAQAAE0icAQB6pFWjOxJnAACYQOIMANAjy9F1R+EMANAjrRrd0aoBAAATSJwBAHpkObruSJwBAGACiTMAQI/0OHdH4gwAABNInAEAOtQsR9cdhTMAQI+0anRHqwYAAEwgcQYA6JHEuTsSZwAAmEDiDADQIxugdEfhDADQI60a3dGqAQAAE0icAQA61CTO3ZE4AwCw3arqwVX1rqr656q6vqq+UVUfqqpfWnDt0VX1/qr6blVdW1UXVNVzqmqvrdz/EVV1TlVtqaqrqupTVXXCNuZ0QlV9erx+yzj+Ecv4vYnCGQCgTytt+a8lqar/luSvktwryf+X5I+TvC/J7ZM8aO7aX0lybpIHJHlXklOT7JPkVUlOX+P+JyY5M8lPJ3lbkjclOTjJaVX1yjXGvDLJaUnuMF7/tiR3S3LmeL8Nq9Z23f8bYJ99D911Jw+jlV34n0GA3dGNN1xWmz2HJPn+ib+09H9B3OrU92/4t1XVU5P8jyRvSfK01toNc5/forX2g/Gv/32SryQ5MMn9WmufHc/vl+SsJPdN8vjW2ukz4w9L8g9Jrk5yz9bapeP5Wyf5TJLDkxzdWvvkzJijk3wiycVJjmqtfW/mXucnOSDJXVfvtV4SZwAAJqmqfZO8PMnXsqBoTpLVonl0fIYU+vTVonm85rokLxrfPmPuFk9Osm+SU2cL3bEY/oPx7dPnxqy+f/lq0TyOuTTJ68f7PWnbv3DrFM4AAD3qs1XjIRkK4b9MslJVD6+q/1JVv1VV911w/bHj8YMLPjs3yTVJjh4L8iljPjB3zUbGbDeragAA7CGq6vy1Pmut3XPCLY4aj9cl+XyGHuTZ+5+b5PjW2uXjqbuMx4sWfN+NVfXVJD+V5M5J/n7CmG9W1dVJDq2q/Vtr11TVAUkOSXJVa+2bC+b85fF45ITft1USZwCAHvWZOP/oePztJC3J/ZPcKsndk3w4wwOA/+/M9QeOxy1r3G/1/EHrGHPg3HF7vmNdJM4AAHuIiany1qyGrjcm+eWZHuQvVtWvJrkwyQOr6r6zD+/tLiTOAAAdaq0t/bUEV4zHz8+vUNFauybJh8a3Pz8e59Pheavnr5g5N3XMlrnj9nzHuiicAQB61GerxoXjca0idHVFix+Zu/5m/cVVtXeSO2VIry9Z8B2Lxtwhw9JyXx8L9bTWrk5yWZJbjp/PO2I83qxnenspnAEAmOqjGXqb/2NVLaojVx8W/Op4PGs8PnTBtQ9Isn+S81pr18+c39qYh81ds5Ex203hDADQow4T59baP2bY0e8nkvzW7GdVdVySX8yQRq8uC3dGkm8neVxV3Wvm2v2SnDy+fcPc17w5yfVJThw3MFkdc+skLxjfvnFuzOr7F47XrY45LMmzxvu9edKP3Ipd+uFAO66xO7j2Gx/b7CnAUux/8P03ewrAzvGsJD+X5JSqeniGZenulOSRSW5K8pTW2pYkaa1dOe40eEaSc6rq9CTfTfLLGZadOyPJ22dv3lr7alX9dpLXJvlsVb09yQ0ZNlM5NMkfzz942Fo7r6pOSfK8JBdU1RkZtvV+bJLbJHn2RncNTHbxwhkAYHfVltOTvHStta9X1T2T/G6GAvgBSa7MkET/19bap+euf3dVPTDJC5M8Osl+Gbbhfl6S17YFTy221l5XVZcmeX6SX8vQJfGlJC9qrb1ljXmdVFVfzFDYPy3JSpLPJXlFa+29G/7hSWpJT1huir33OWTXnTyMJM7sLiTO7C5+cMNltdlzSJItJzx46XXOgW/5aBe/bVelxxkAACbQqgEA0KOVzZ4A8yTOAAAwgcQZAKBDvT4cuCdTOAMA9Ejh3B2tGgAAMIHEGQCgRx4O7I7EGQAAJpA4AwB0yMOB/ZE4AwDABBJnAIAe6XHujsIZAKBDWjX6o1UDAAAmkDgDAPRIq0Z3JM4AADCBxBkAoENN4twdhTMAQI8Uzt3RqgEAABNInAEAOqRVoz8SZwAAmEDiDADQI4lzdxTOAAAd0qrRH60aAAAwgcQZAKBDEuf+SJwBAGACiTMAQIckzv1ROAMA9KjVZs+AOVo1AABgAokzAECHtGr0R+IMAAATSJwBADrUVvQ490biDAAAE0icAQA6pMe5PwpnAIAONcvRdUerBgAATCBxBgDokFaN/kicAQBgAokzAECHLEfXH4UzAECHWtvsGTBPqwYAAEwgcQYA6JBWjf5InAEAYAKJMwBAhyTO/VE4AwB0yMOB/dGqAQAAE0icAQA6pFWjPxJnAACYQOIMANCh1iTOvZE4AwDABBJnAIAOtZXNngHzFM4AAB1a0arRHa0aAAAwgcQZAKBDHg7sj8QZAAAmkDgDAHTIBij9UTgDAHSotc2eAfO0agAAwAQSZwCADmnV6I/EGQAAJlA4AwB0aKXV0l87QlU9oara+HrK3GcPmvls0esP17jnXlX13Kq6oKqurarvVtX7q+rorczjR6rqZVV1YVVdV1Xfqqp3VNVPLuu3atUAAOjQrrCOc1XdMcmpSa5KcsutXPrXSc5ZcP7jC+5ZSU5PcnySC8f73ybJY5OcW1WPbq29Z27Mvkk+kuR+ST6b5DVJ7pjkMUkeXlXHttY+tV0/bgGFMwAA220scN+c5DtJ/jLJ87dy+TmttZdOvPXjMhTN5yV5cGvtuvH73pih0H5TVZ3VWvv+zJjnZSiaz0jy2Nbayjjm7UneneTPqupuq+fXS6sGAECHWlv+a8l+M8mxSZ6U5Ool3vcZ4/FFq0VzkrTWPpPk7Ulun6GwTvKvBfzTx7f/x2xxPCbTH0vyH5M8cKMTUzgDALBdxr7hP0zymtbauROG/IeqOrGqXlBVT66qI9a4735Jjk5yTYaCd94HxuOxM+cOT/ITSS5qrX114ph10aoBANChHfEwX1Wdv9ZnrbV7TrzH3knemuRrSV4w8av/t/E1e593Jnlqa+17M6cPT7JXkktaazcuuM+Xx+ORM+fuMh4vWuO7F41ZF4kzAADb43eT/FySX2+tXbuNay9P8jtJ7pbkVhnaLB6W5PNJHp3kzKqarUcPHI9b1rjf6vmDNjhmXSTOAAAd2hGrakxNlddSVffOkDL/cWvtkxO+738m+Z8zp65K8sGqOi/JFzI80Pefk7xnwfDuSJwBADrU28OBY4vGn2doiXjxxn5buzLJ/z2+fcDMR6vp8IFZbPX8FRscsy6bUjhX1Suq6uLN+G4AANbllhn6hH8yyXWzG5kkecl4zZvGc6+ecL/Lx+MBM+cuTnJTkjuPhfq81YcKZ/uZLxyPa/UwLxqzLpvVqnG7JIdNuXBrTex73eLgZc0HAKArO2qnvw24PsmfrvHZPTL0PX88QyG7zTaOJPcZj5esnmitXTe2cdx/fJ09N+Zh4/GsmXMXZ3hQ8ciqutOClTUWjVkXPc4AAGzT+CDgUxZ9VlUvzVA4v6W19icz5+/VWvvsguufkGEnwBuSvGPu4zdkKJpPrqrZDVCOGsdcnuSdM/Nq4+Yof5Dkv1XV7AYovzLe60sZdi/ckKUUzlX159s5ZM19xudtrYl9730OWf5S3gAAHdgVttye4IyqujHDNthfT7JfkqOS/HySG5P8Rmvt0rkxpyd5VIZNTj5fVWcmuW2GonmvDEvYXTk35pQkjxjHfKqqPpphbefHZFgT+skb3TUwWV7i/IQkLcn2/Des6AUAWEOHrRrr8YYkv5Bh9YzbZagVL0tyWpJXt9b+dn7AmCA/PsOW209O8uwk1yU5N8nJrbXzFoy5vqoekmHpu8cneW6SKzNst/2S1tqXlvFjqi1h/8Wq2pLhTxHPnDjkd5Ic11rbayPfK3Fmd3DtNxZtjAS7nv0Pvv9mTwGW4gc3XNZFxfqpgx+19Drn3t/4yy5+265qWYnz3yb5mdbapN6Rqvr1JX0vAMBuSTrYn2UtR/eFJLesqsOXdD8AAOjKshLnv87wxOKhGZYE2ZZ3J7l0Sd8NALDb2U16nHcrSymcW2vvzMyyIBOuf092ka0VAQA2w26yqsZuxZbbAAAwgQ1QAAA6tOFFh1k6iTMAAEwgcQYA6FDbrn3l2BkkzgAAMIHEGQCgQyt2QOmOwhkAoEMrWjW6o1UDAAAmkDgDAHTIw4H9kTgDAMAEEmcAgA7ZAKU/CmcAgA5p1eiPVg0AAJhA4gwA0CGtGv2ROAMAwAQSZwCADkmc+6NwBgDokIcD+6NVAwAAJpA4AwB0aEXg3B2JMwAATCBxBgDo0Ioe5+5InAEAYAKJMwBAh9pmT4CbUTgDAHTIOs790aoBAAATSJwBADq0Uh4O7I3EGQAAJpA4AwB0yMOB/VE4AwB0yMOB/dGqAQAAE0icAQA6tOLZwO5InAEAYAKJMwBAh1Yicu6NwhkAoENW1eiPVg0AAJhA4gwA0CEPB/ZH4gwAABNInAEAOmQDlP5InAEAYAKJMwBAh6yq0R+FMwBAhzwc2B+tGgAAMIHEGQCgQx4O7I/EGQAAJpA4AwB0SOLcH4UzAECHmocDu6NVAwAAJpA4AwB0SKtGfyTOAAAwgcQZAKBDEuf+KJwBADpky+3+aNUAAIAJJM4AAB1asRxddyTOAAAwgcQZAKBDHg7sj8QZAIDJquqPquqjVfVPVXVtVX23qj5fVS+pqtuuMeboqnr/eO21VXVBVT2nqvbayvc8oqrOqaotVXVVVX2qqk7YxtxOqKpPj9dvGcc/YqO/eZXCGQCgQys74LUkz01yQJKPJHlNkr9IcmOSlya5oKruOHtxVf1KknOTPCDJu5KcmmSfJK9KcvqiL6iqE5OcmeSnk7wtyZuSHJzktKp65RpjXpnktCR3GK9/W5K7JTlzvN+GVWu77mIne+9zyK47eRhd+42PbfYUYCn2P/j+mz0FWIof3HBZF4/lvfInnrD0Ouf5X3vbhn9bVe3XWrtuwfmXJ3lBkje01p45nvv3Sb6S5MAk92utfXb1HknOSnLfJI9vrZ0+c5/DkvxDkquT3LO1dul4/tZJPpPk8CRHt9Y+OTPm6CSfSHJxkqNaa9+budf5GQr9u67ea70kzgAATLaoaB69YzweMXPu+CS3T3L6atE8c48XjW+fMXefJyfZN8mps4XuWAz/wfj26XNjVt+/fLVoHsdcmuT14/2etOaPmkjhDADQoZVa/msH+8/j8YKZc8eOxw8uuP7cJNckObqq9p045gNz12xkzHazqgYAwB6iqs5f67PW2j23817PT3LLDG0Y90rynzIUzX84c9ldxuNFC77vxqr6apKfSnLnJH8/Ycw3q+rqJIdW1f6ttWuq6oAkhyS5qrX2zQVT/fJ4PHJ7ft8iCmcAgA7tAsvRPT/Jj828/2CSX2+tXT5z7sDxuGWNe6yeP2g7xxwwXnfNOr9jXRTOAAAd2hErIGxvqryNe/14klTVjyU5OkPS/PmqekRr7XPL+p6e6HEGAGDdWmv/0lp7V5Ljktw2yZ/PfLya9h54s4E/fP6KdYzZMnfcnu9YF4kzbLLb/i+/sNlTgKW4/DF32fZFwGQrOyRz3nFaa/9YVV9K8rNVdbvW2reTXJih//nIDMvC/auq2jvJnTKsAX3JzEcXJrndOOaTc2PukKFN4+uttWvG7726qi5LckhV3WFBn/PqKh8365neXhJnAACW5eDxeNN4PGs8PnTBtQ9Isn+S81pr18+c39qYh81ds5Ex203hDADQoR53DqyqI6vqZi0RVfXvxg1QfjRDIby6lvIZSb6d5HFVda+Z6/dLcvL49g1zt3tzkuuTnDhuYLI65tYZNlhJkjfOjVl9/8LxutUxhyV51ni/N0/6kVuhVQMAoEOdNmr8UpL/WlUfT/LVJN/JsLLGAzMsKffPSZ66enFr7cqqemqGAvqcqjo9yXeT/HKGZefOSPL22S9orX21qn47yWuTfLaq3p7khgybqRya5I9ndw0cx5xXVackeV6Gbb/PyLCt92OT3CbJsze6a2CicAYAYLq/SvIfMqzZ/HMZlni7OkP/8FuTvLa19t3ZAa21d1fVA5O8MMmjk+yXYRvu543X3+zPCK2111XVpRmWvPu1DF0SX0ryotbaWxZNrLV2UlV9MUPC/LQMIfvnkryitfbeDf7uJApnAIAu9biOc2vt75KcuI5xn8iQVm/PmDOTnLmdY05Lctr2jNkeepwBAGACiTMAQIdWarNnwDyFMwBAh3a1dZz3BFo1AABgAokzAECH5M39kTgDAMAEEmcAgA71uBzdnk7iDAAAE0icAQA6ZFWN/iicAQA6pGzuj1YNAACYQOIMANAhDwf2R+IMAAATSJwBADrk4cD+KJwBADqkbO6PVg0AAJhA4gwA0CEPB/ZH4gwAABNInAEAOtR0OXdH4QwA0CGtGv3RqgEAABNInAEAOmQd5/5InAEAYAKJMwBAh+TN/ZE4AwDABBJnAIAO6XHuj8IZAKBDlqPrj1YNAACYQOIMANAhOwf2R+IMAAATSJwBADqkx7k/CmcAgA5p1eiPVg0AAJhA4gwA0CGtGv2ROAMAwAQSZwCADq00Pc69UTgDAHRI2dwfrRoAADCBxBkAoEMrMufuSJwBAGACiTMAQIdsgNIfiTMAAEwgcQYA6JANUPqjcAYA6JCHA/ujVQMAACaQOAMAdMjDgf2ROAMAwAQSZwCADnk4sD8KZwCADrWmVaM3WjUAAGACiTMAQIcsR9cfiTMAAEwgcQYA6JCHA/ujcAYA6JB1nPujVQMAACaQOAMAdMjDgf2ROAMAwAQSZwCADtkApT8SZwAAmEDiDADQIcvR9UfiDADQobYD/rNRVXV8Vb2uqj5WVVdWVauqt61x7WHj52u9Tt/K95xQVZ+uqquqaktVnVNVj9jK9XtV1XOr6oKquraqvltV76+qozf8o2dInAEAmOpFSX4myVVJvp7krhPG/G2Sdy84/3eLLq6qVyY5abz/m5Lsk+RxSc6sqme31k6du76SnJ7k+CQXJjk1yW2SPDbJuVX16NbaeybMc5sUzgAAHep0ObrnZihov5LkgUnOnjDmC621l065+ZgQn5Tk4iRHtda+N55/RZLzk7yyqt7bWrt0ZtjjMhTN5yV5cGvtunHMG5N8PMmbquqs1tr3p8xha7RqAAAwSWvt7Nbal9uOW/Lj6ePx5atF8/i9lyZ5fZJ9kzxpbswzxuOLVovmccxnkrw9ye0zFNYbtrTEuap+NcmDktyY5IOttY+scd0JSU5orR27rO8GANjd7IjatKrO38r33XPpXzg4uKp+I8ltk3wnySdbaxesce1qffjBBZ99IMmLx2tekiRVtV+So5Nck+Rja4x54jjmzev9Aas2XDiPfSVvT/LoJDWefk5VvS/Jr7XWrpgbcliGaB8AgDV02qqxHg8ZX/+qqs7JEKR+bebcAUkOSXJVa+2bC+7z5fF45My5w5PsleSS1tqNE8es2zIS5ydliL//Kckbk/wgyQlJHpHk41V1bGvtW+u9+db+ZLTXLQ5e720BAPY4OzBVXuSaJL+f4cHAS8Zzd0/y0iTHJPloVf1sa+3q8bMDx+OWNe63ev6gmXPrGbNuyyqcr8jQwP2tJKmqVyX5oyTPS/JXY/H87SV8FwDAHmEZy8dtprEu/N250+dW1XEZHtq7d5KnJHnNzp7bei2jcL5bkjNmU+XW2k1Jnl9VX0vy6gzF8zGzTd5Tbe1PRnvvc8iu/XcUAMAeprV2Y1X9SYbC+QH5t8J5NR0+cOHAfzs/2wa8njHrtozCeZ8k/7Log9baa6tqJclrk3ykqn5hCd8HALDbW9lhC1d04fLxeMDqidba1VV1WZJDquoOC/qcjxiPF82cuzjJTUnuXFV7L+hzXjRm3ZaxHN1lSX5irQ/HRaqfl+QeST6Utf9EAADAqO2AV8wR0UUAAAiZSURBVEfuMx4vmTt/1nh86IIxD5u7JuPyc+cl2T/J/aeM2YhlFM5fzNDgvabW2quT/J9Jjkpy4hK+EwCAjlXVParqZrVmVT04w0YqSTK/Xfcbx+MLq+rWM2MOS/KsJNfn5svKvWE8njwuT7c65qgMuwdenuSd6/sVP2wZrRrvT/LIqnp4a+19a13UWvujqtonycvS3R96AAD60uNydFX1yCSPHN/++Hi8b1WdNv71t1trzx//+pQkR1TVeRl2G0yGVTVW12p+cWvtvNn7t9bOq6pTMnQrXFBVZ2RoC35shm20nz23a2AybLf9qAyrvH2+qs7MsGb0YzMsVffU1tqV6//V/2YZhfNfZpjU1du6sLX2++MDg4ct4XsBANi5fjbDssOz7jy+kuQfk6wWzm9N8qsZOg4eluQWGZ6Le0eSU1trizYsSWvtpKr6YoaE+WlJVpJ8LskrWmvvXXB9q6rHZ2jZeHKSZye5Lsm5SU6eL843onbcjok7nlU12B3sf4t9N3sKsBRfe+Rhmz0FWIqD/uKs2vZVO959Dzlm6XXOJy87u4vftqtaRo8zAADs9pbRqgEAwJLtyl0BuyuFMwBAh3p8OHBPp1UDAAAmkDgDAHSoSZy7I3EGAIAJJM4AAB3ycGB/FM4AAB3ycGB/tGoAAMAEEmcAgA5p1eiPxBkAACaQOAMAdEiPc38UzgAAHbKOc3+0agAAwAQSZwCADq14OLA7EmcAAJhA4gwA0CE9zv2ROAMAwAQSZwCADulx7o/CGQCgQ1o1+qNVAwAAJpA4AwB0SKtGfyTOAAAwgcQZAKBDepz7o3AGAOiQVo3+aNUAAIAJJM4AAB3SqtEfiTMAAEwgcQYA6FBrK5s9BeYonAEAOrSiVaM7WjUAAGACiTMAQIea5ei6I3EGAIAJJM4AAB3S49wfiTMAAEwgcQYA6JAe5/4onAEAOrSicO6OVg0AAJhA4gwA0KHm4cDuSJwBAGACiTMAQIc8HNgfhTMAQIes49wfrRoAADCBxBkAoENaNfojcQYAgAkkzgAAHbIBSn8UzgAAHdKq0R+tGgAAMIHEGQCgQ5aj64/EGQAAJpA4AwB0SI9zfyTOAAAwgcQZAKBDlqPrj8IZAKBDzcOB3dGqAQAAE0icAQA6pFWjPxJnAACYQOIMANAhy9H1R+EMANAhDwf2R6sGAADbpaoOrao/q6pvVNX1VXVpVb26qm692XPbkSTOAAAd6rVVo6oOT3Jekh9N8p4k/5Dk55P8VpKHVtX9Wmvf2cQp7jASZwAAtsf/laFo/s3W2iNba7/TWjs2yauS3CXJyzd1djuQwhkAoEOttaW/NmpMm49LcmmS1899/JIkVyd5YlUdsOEv65DCGQCgQ20HvJbgmPH44dbayg/Nt7XvJ/lEkv2T3Gc5X9cXPc4AAHuIqjp/rc9aa/eccIu7jMeL1vj8yxkS6SOTfHT7Zte/XbpwvvGGy2qz57C7W/0HbOI/TNAlfx+zu/D38p5lR9Q5WyucJzpwPG5Z4/PV8wdt8Hu6tEsXzgAATOcPXRujxxkAgKlWE+UD1/h89fwVO2EuO53CGQCAqS4cj0eu8fkR43GtHuhdmsIZAICpzh6Px1XVD9WRVXWrJPdLck2Sv9nZE9sZFM4AAEzSWrs4yYeTHJbkWXMfvyzJAUne2lq7eidPbafwcCAAANvjmRm23H5tVT04yd8nuXeGNZ4vSvLCTZzbDlW97oMOAECfquqOSX4vyUOT3DbJN5O8K8nLWmvf28y57UgKZwAAmECPMwAATKBwBgCACRTOAAAwgcIZAAAmUDgDAMAECmcAAJhA4cxCVXVoVf1ZVX2jqq6vqkur6tVVdevNnhtMUVXHV9XrqupjVXVlVbWqettmzwu2R1XdtqqeUlXvqqqvVNW1VbWlqj5eVf/7/JbHwI5lHWdupqoOz7Aj0I8meU+Sf0jy8xl2BLowyf1aa9/ZvBnCtlXVF5L8TJKrknw9yV2T/EVr7QmbOjHYDlX19CRvyLC5xNlJvpbkx5I8KsmBSd6Z5DHNv8xhp1A4czNV9aEkxyX5zdba62bOn5LkuUn+e2vt6Zs1P5iiqo7JUDB/JckDMxQdCmd2KVV1bJIDkryvtbYyc/7Hk3w6yR2THN9ae+cmTRH2KP4vHn7ImDYfl+TSJK+f+/glSa5O8sSqOmAnTw22S2vt7NbalyVx7Mpaa2e11s6cLZrH8/+c5I3j2wft9InBHkrhzLxjxuOHF/wP9feTfCLJ/knus7MnBsAP+cF4vHFTZwF7EIUz8+4yHi9a4/Mvj8cjd8JcAFigqvZO8mvj2w9u5lxgT6JwZt6B43HLGp+vnj9oJ8wFgMX+MMlPJ3l/a+1Dmz0Z2FMonAFgF1JVv5nkpAwrHj1xk6cDexSFM/NWE+UD1/h89fwVO2EuAMyoqhOTvCbJl5Ic01r77iZPCfYoCmfmXTge1+phPmI8rtUDDcAOUFXPSfK6JH+XoWj+502eEuxxFM7MO3s8Hje/I1VV3SrJ/ZJck+RvdvbEAPZUVfVfkrwqyRcyFM3f2uQpwR5J4cwPaa1dnOTDSQ5L8qy5j1+WYSH+t7bWrt7JUwPYI1XVizM8DHh+kge31r69yVOCPZadA7mZBVtu/32Se2dY4/miJEfbcpveVdUjkzxyfPvjSX4xySVJPjae+3Zr7fmbMTeYqqpOSHJakpsytGksWvHo0tbaaTtxWrDHUjizUFXdMcnvJXloktsm+WaSdyV5WWvte5s5N5iiql6aYbfLtfxja+2wnTMbWJ8Jfx8nyV+31h6042cDKJwBAGACPc4AADCBwhkAACZQOAMAwAQKZwAAmEDhDAAAEyicAQBgAoUzAABMoHAGAIAJFM4AADCBwhkAACZQOAMAwAQKZwAAmEDhDAAAEyicAQBgAoUzAABMoHAGAIAJFM4AADDB/w99EyvPLPY4NAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "image/png": { - "height": 248, - "width": 359 - }, - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "#model = models.resnet50(pretrained=True)\n", - "#model = timm.create_model('resnet50', pretrained=True)\n", - "model = timm.create_model('mixnet_xl', pretrained=True)\n", - "#model.fc #show fully connected layer for ResNet family\n", - "model.classifier #show the classifier layer (fully connected layer) for EfficientNets\n", - "\n", - "# Create classifier\n", - "for param in model.parameters():\n", - " param.requires_grad = True\n", - "# define `classifier` for ResNet\n", - "# Otherwise, define `fc` for EfficientNet family \n", - "#because the definition of the full connection/classifier of 2 CNN families is differnt\n", - "fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(1536, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - "# connect base model (EfficientNet_B0) with modified classifier layer\n", - "model.fc = fc\n", - "criterion = nn.CrossEntropyLoss()\n", - "#optimizer = Nadam(model.parameters(), lr=0.001)\n", - "#optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)\n", - "optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - "#show our model architechture and send to GPU\n", - "model.to(device)\n", - "CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/MixNet_Extra_Large_Covid-19.pth'\n", - "try:\n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " print(\"checkpoint loaded\")\n", - "except:\n", - " checkpoint = None\n", - " print(\"checkpoint not found\")\n", - "\n", - "def load_model(path): \n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - "load_model(CHECK_POINT_PATH) \n", - "\n", - "since = round(time.monotonic() * 1000)\n", - "\n", - "#since = time.time()\n", - "model.eval()\n", - "y_test = []\n", - "y_pred = []\n", - "for images, labels in data_loader['data']:\n", - " images = Variable(images.cuda())\n", - " labels = Variable(labels.cuda())\n", - " outputs = model(images)\n", - " _, predictions = outputs.max(1)\n", - " \n", - " y_test.append(labels.data.cpu().numpy())\n", - " y_pred.append(predictions.data.cpu().numpy())\n", - " \n", - "y_test = np.concatenate(y_test)\n", - "y_pred = np.concatenate(y_pred)\n", - "pd.DataFrame({'true_label':y_test,'predicted_label':y_pred}).to_csv('/home/linh/Downloads/Covid-19/results/Modified_MixNet_Extra_Large_Covid-19_Whole_Dataset.csv',index=False)\n", - "\n", - "time_elapsed = round(time.monotonic() * 1000) - since \n", - "\n", - "#time_elapsed = time.time() - since\n", - "\n", - "print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n", - "\n", - "sns.heatmap(confusion_matrix(y_test, y_pred))\n", - "accuracy_score(y_test, y_pred)\n", - "\n", - "report = classification_report(y_test, y_pred, digits=6)\n", - "print(report)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "checkpoint loaded\n", - "Training complete in 542m 52s\n", - " precision recall f1-score support\n", - "\n", - " 0 0.8121 0.6119 0.6979 219\n", - " 1 0.5910 0.9955 0.7417 1341\n", - " 2 0.8690 0.3108 0.4578 1345\n", - "\n", - " accuracy 0.6496 2905\n", - " macro avg 0.7574 0.6394 0.6325 2905\n", - "weighted avg 0.7364 0.6496 0.6070 2905\n", - "\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAs4AAAHwCAYAAAC2dOlsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3dfbSud1kf+O8lkBfimCAw5U0bsAk6itImVggVSCgUKsWoYaBTIKJWUMACwdGRd6xtLTEICQusUqKkM8GGARbvWJNAMAglTkAWlhBCRNJoDCSBJICQfc0fz711Z2fvc+6993PO/p1zPh/Ws+7z3K+/Z3EWXPnm+v3u6u4AAAB79i27PQAAADgQKJwBAGAGhTMAAMygcAYAgBkUzgAAMIPCGQAAZlA4AwDADApnAACYQeEMAAAzKJwBAGAGhTMAAMygcAYAgBkUzgAAMMOdd3sAO3H8PU/s3R4D7NTnv3Ldbg8BluLIOx+220OApbjh5itrt8eQJN+4/qql1zl3uccDhvhtByqJMwAAzHBAJ84AAAetldt2ewSsI3EGAIAZJM4AACPqld0eAesonAEARrSicB6NVg0AAJhB4gwAMKDWqjEciTMAAMwgcQYAGJEe5+EonAEARqRVYzhaNQAAYAaJMwDAiLw5cDgSZwAAmEHiDAAwIj3Ow5E4AwDADBJnAIARWY5uOApnAIABeXPgeLRqAADADBJnAIARadUYjsQZAABmkDgDAIxIj/NwJM4AACNauW35nx2qqtOq6uyquqSqvlxVXVXnbXLucVX1S1V1YVX9RVX9TVX9VVW9vapO3uSan5zuudnnmZtcd2RVvbyqPl1VX6uq66rq96vqe3b8o9eQOAMAMNeLkvxAkpuTfCHJd+/h3F9N8qQkn0ry7iRfSvLAJE9I8oSq+jfd/ZpNrn17kss32P+x9Tuq6vAkf5DkYdPxVyf5jiRPTPIjVXVKd39k7z9t7xTOAAAjGrNV43lZFMxXJnlEkov2cO57k/x6d/9/a3dW1SOyKHRfWVX/tbuv3eDat3X3uTPH9PwsiuYLkjypp3X8qurNSd6W5D9X1YN6Cev7adUAAGCW7r6ouz/T3T3j3HPXF83T/g8kuTjJYUlO2sl4qqqSrLZv/J9ri+PufnuSS5L8b1kU+TsmcQYAGNE+WI6uqi7b7Fh3n7D0B27uG9P2m5scf3BVPTfJEUmuSXJRd39hg/O+K8l3Jrmiuz+3wfH3JPnhJKdkz+n4LApnAIARjdmqsWNV9feTPCrJrUk+uMlp/2bd99uq6neSPLe7v7Zm/wOn7RWb3Ocz0/b47Yx1PYUzAMAhYj+nyncwTeT7L0kOz6K14oZ1p3wuyXOSvD+LXuqjk/yTJP8+yTOSfFuS/2PN+UdP25s2eeTq/mN2PPgonAEAxnSQvTmwqu6U5E1ZTOR7c5Iz158z9T9/YM2uW5P816r64yQfT/Ivq+rXu/vj+2HId2ByIAAA+9RUNJ+XxRJxv5/kKXMmGK7q7r/IYkm7JHn4mkOrifLR2djq/hvnj3ZzEmcAgAF17/yFJSOoqrtk0Z7xxCT/d5Kn9fZ+3F9P26PW7Pv0tN2sh/m4abtZD/SWKJwBAEZ0EEwOrKrDskiYfzTJ7yV5+g7WU/6haXvVmn2fTfL5JMdX1f03WFnjcdP2wm0+83a0agAAsHTTRMC3ZlE0vyEziuaqOnGDfd9SVf9XkocmuT6LF6skSaZ2j9dPX/9jVX3Lmut+NIul6D6V2/dNb5vEGQBgRANODqyqU5OcOn2917R9aFWdO/35+u5+wfTn1yf551kUu9ckecnifSW3c3F3X7zm+3+vqk9mMRHwmix6lB+W5PuymCj4r7r7y+vucVaSxyc5LclHquoPs1jb+YnTNT+1jLcGJgpnAADme3CS09fte8D0SZI/T7JaON9/2t4jyUv2cM+L1/z5zCT/OIsXlnx7kpUsWjFem+Ss7r5q/cXd/fWqenSSX07yL7N4LfiXs3jd9ku7+1NzftgctYUJjcM5/p4nHriDh8nnv3Ldbg8BluLIOx+220OApbjh5ivvEIvuhq9d9ral1zlHnHDqEL/tQKXHGQAAZtCqAQAwopWDYzm6g4nCGQBgRAfBcnQHG60aAAAwg8QZAGBEAy5Hd6iTOAMAwAwSZwCAEelxHo7CGQBgRFo1hqNVAwAAZpA4AwCMSOI8HIkzAADMIHEGABhQtzcHjkbhDAAwIq0aw9GqAQAAM0icAQBGZB3n4UicAQBgBokzAMCI9DgPR+IMAAAzSJwBAEakx3k4CmcAgBFp1RiOVg0AAJhB4gwAMCKtGsOROAMAwAwSZwCAEelxHo7CGQBgRArn4WjVAACAGSTOAAAjMjlwOBJnAACYQeIMADAiPc7DUTgDAIxIq8ZwtGoAAMAMS02cq+rIJA9JcnySY6bdNya5Iskfd/dXl/k8AICDllaN4SylcK6quyX5tSRPTXLXTU67tap+L8mLuvuGZTwXAAD2lx0XzlV1TJI/SvLdSW5J8gdJPpPkpumUo5Mcl+RhSX4uyclV9dDuvmmD2wEAkOhxHtAyEueXZlE0vyrJS7v75o1OqqpvTfKKJM9N8pIkZ8y5eVVdttmx4+5xwpYHCwAA27GMyYGnJrmwu8/YrGhOku6+ubufn+TiJD++hOcCABy8VlaW/2FHlpE43zvJ/7OF8/84yUlzT+7uTWPl4+95Ym/huQAABw6F7nCWkTh/MckDt3D+90zXAADAAWMZhfP7kpxaVT+/txOr6tlJnpDkvUt4LgDAwat7+R92ZBmtGi9O8iNJzq6qM5K8P4t1m9euqnF8ksckOTbJdVlMDgQAgAPGjgvn7r6mqh6a5HVJHp3kGUnW/yNNTdv3J/n57r5mp88FADio6XEezlJegNLdVyX5Z1X1gCQnZ9HzfPR0+KYkn05y0XQeAAB7o3AezlJfuT0VxopjAAAOOkstnAEAWBJvDhzOMlbVAACAg57EGQBgRHqch6NwBgAYkXWXh6NVAwAAZpA4AwCMSKvGcCTOAAAwg8QZAGBEEufhSJwBAGAGiTMAwIi8AGU4CmcAgAH1iuXoRqNVAwAAZpA4AwCMyOTA4UicAQBgBokzAMCITA4cjsIZAGBEJgcOR6sGAADMIHEGABiRyYHDkTgDAMAMEmcAgBFJnIejcAYAGFGbHDgarRoAADCDxBkAYERaNYYjcQYAYJaqOq2qzq6qS6rqy1XVVXXeXq45qareXVVfqqqvVtUnquq5VXWnPVzz+Kq6uKpuqqqbq+ojVXX6Xp5zelV9dDr/pun6x2/3t25E4QwAMKKVXv5n516U5NlJHpzkmr2dXFU/muSDSR6e5K1JzklyWJJXJTl/k2ueneQdSb4vyXlJfjvJfZKcW1VnbnLNmUnOTXLv6fzzkjwoyTum+y2FwhkAgLmel+T4JN+W5Of2dGJVfVsWRextSR7Z3T/d3b+YRdH94SSnVdWT111zbJIzk3wpyYnd/azufl6S70/y2SRnVNVD111zUpIzpuPf393P6+5nJTlhus+Z0313TOEMADCiXln+Z6dD6r6ouz/TPWvJj9OS3DPJ+d39sTX3+FoWyXVyx+L7p5IcnuSc7r56zTU3JPl309dnrrtm9fuvTeetXnN1ktdO93v6jPHulcIZAGBEY7ZqbMUp0/a9Gxz7YJJbk5xUVYfPvOY9687ZyTXbYlUNAIBDRFVdttmx7j5hyY974LS9YoNnfbOqPpfke5M8IMmfzbjm2qq6Jcn9ququ3X1rVR2V5L5Jbu7uazcYw2em7fE7+B1/S+EMADCgPvCXozt62t60yfHV/cds8ZqjpvNu3eYztk3hDABwiNgHqfIhReEMADCi/d+TvGyrae/Rmxxf3X/jumvuMR374h6uuWnddivP2DaTAwEARjTgqhpb9Olpe4f+4qq6c5L7J/lmkqtmXnPvLNo0vtDdtyZJd9+SxXrS3zodX++4aXuHnuntUDgDALAvXDhtH7vBsYcnuWuSS7v76zOvedy6c3ZyzbYonAEARnTgL0d3QZLrkzy5qk5c3VlVRyT5t9PX16275o1Jvp7k2WtfWlJVd0vyK9PX16+7ZvX7C6fzVq85Nsmzpvu9cfs/4+/ocQYAYJaqOjXJqdPXe03bh1bVudOfr+/uFyRJd3+5qv51FgX0xVV1fhZv8ntCFsvOXZDkzWvv392fq6pfTPKaJB+rqjcn+ZssXqZyvyS/0d0fXnfNpVV1VpLnJ/lEVV2QxWu9n5Tk25M8Z+3LVHZC4QwAMKIxl6N7cJLT1+17wPRJkj9P8oLVA939tqp6RJIXJvmJJEckuTKLIvc1G72BsLvPrqqrp/s8LYsOiU8leVF3/+5Gg+ruM6rqT7NImH82yUqSP0nyyu5+5/Z+6h3VvDcmjun4e5544A4eJp//ynW7PQRYiiPvfNhuDwGW4oabr6zdHkOS3PKSJy+9zjnqFecP8dsOVHqcAQBgBq0aAAAj2v/Lx7EXEmcAAJhB4gwAMKID/82BBx2JMwAAzCBxBgAYUI+5HN0hTeEMADAirRrD0aoBAAAzSJwBAEYkcR6OxBkAAGaQOAMAjMgLUIajcAYAGJFWjeFo1QAAgBkkzgAAA2qJ83AkzgAAMIPEGQBgRBLn4SicAQBG5JXbw9GqAQAAM0icAQBGpFVjOBJnAACYQeIMADAiifNwJM4AADCDxBkAYEDdEufRKJwBAEakVWM4WjUAAGAGiTMAwIgkzsM5oAvnq266dreHADv21f95yW4PAZbiyPv88G4PAWCfOqALZwCAg1VLnIejcAYAGJHCeTgmBwIAwAwSZwCAEa3s9gBYT+IMAAAzSJwBAAZkcuB4FM4AACNSOA9HqwYAAMwgcQYAGJHJgcOROAMAwAwSZwCAAZkcOB6JMwAAzCBxBgAYkR7n4SicAQAGpFVjPFo1AABgBokzAMCItGoMR+IMAAAzSJwBAAbUEufhKJwBAEakcB6OVg0AAJhB4gwAMCCtGuOROAMAwAwSZwCAEUmch6NwBgAYkFaN8WjVAACAGSTOAAADkjiPR+IMAAAzSJwBAAYkcR6PwhkAYERduz0C1tGqAQAAM0icAQAGpFVjPBJnAACYQeIMADCgXtHjPBqJMwAAzCBxBgAYkB7n8SicAQAG1JajG45WDQAAmEHiDAAwIK0a45E4AwDADBJnAIABWY5uPBJnAIABdS//s1NV9ZNV1Xv53Lbm/GP3cu75e3jW6VX10aq6uapuqqqLq+rxO/8V2ydxBgBgrsuTvHyTYz+c5JQk79ng2MeTvG2D/Z/c6EZVdWaSM5J8IclvJzksyZOTvKOqntPd52xx3EuhcAYAGNCIrRrdfXkWxfMdVNWHpz/+pw0OX97dL5vzjKo6KYui+bNJfrC7b5j2vzLJZUnOrKp3dvfVWxv9zmnVAABgR6rqQUkekuSaJO/a4e2eOW1/bbVoTpKpUH5tksOTPH2Hz9gWiTMAwID2ReJcVZdt+rzuE3Zw65+dtm/o7ts2OH6fqnpGkrsn+WKSD3f3Jza51ynT9r0bHHtPkhdP57x0B+PdFoUzAMCAljGZb3+oqiOTPCXJbUl+Z5PTHj191l53cZLTu/vza/YdleS+SW7u7ms3uM9npu3xOxz2tiicAQAOETtMlTfzvyc5Jsm7uvsv1h27NcmvZjEx8Kpp3/cneVmSk5P8YVU9uLtvmY4dPW1v2uRZq/uPWcK4t0zhDAAwoBEnB25itU3jt9Yf6O7rkrxk3e4PVtVjknwoyQ8l+Zkkr96nI1wSkwMBANiWqvreJCdlsWzcu+de193fzN+1dTx8zaHVRPnobGx1/41bGObSSJwBAAbUfUAkznubFLgnfz1tj1rd0d23VNU1Se5bVffeoM/5uGl7xdaHunMSZwAAtqyqjkjy1CwmBb5hG7d4yLS9at3+C6ftYze45nHrztmvFM4AAAPqleV/luyJSe6W5D0bTApMklTVP6qqO9SbVfWoJM+bvp637vDrp+0Lq+pua645Nsmzknw9yRt3NPJt0qoBADCglfFbNVbbNDZ6U+Cqs5IcV1WXZtEHnSxW1Vhdq/nF3X3p2gu6+9KqOivJ85N8oqouyOKV209K8u1JnrMbbw1MFM4AAGxRVX1Pkn+SvU8KfFOSH0vyg1m0WdwlyV8l+f0k53T3JRtd1N1nVNWfZpEw/2ySlSR/kuSV3f3OZf2OrVI4AwAMaOTJgd39Z0n2OsDufkO21/+c7j43ybnbuXZf0eMMAAAzSJwBAAZ0AL0A5ZChcAYAGFD3bo+A9bRqAADADBJnAIABadUYj8QZAABmkDgDAAzoAHgByiFH4QwAMKCR13E+VGnVAACAGSTOAAADshzdeCTOAAAwg8QZAGBAJgeOR+IMAAAzSJwBAAZkVY3xKJwBAAZkcuB4dqVVo6peWVWf3Y1nAwDAduxW4nyPJMfOObGqLtvs2J3ucp9ljQcAYCgmB47H5EAAAJhhKYlzVf3eFi85ae6J3X3CZsfufNh9df8AAAclkwPHs6xWjack6SRb+W9Y0QsAsAmtGuNZVuH8lSRfSPLzM8//5SSPWdKzAQBgn1tW4fzxJD/Q3R+Yc3JV/eSSngsAcFDyr+bHs6zJgZcn+daq+q4l3Q8AAIayrMT5A0l+OMn9ksxZn/ltSa5e0rMBAA46epzHs5TCubvfkuQtWzj/7UnevoxnAwAcjKyqMR7rOAMAwAy79eZAAAD2YGW3B8AdSJwBAGAGiTMAwIB6S++VY3+QOAMAwAwSZwCAAa14A8pwFM4AAANa0aoxHK0aAAAwg8QZAGBAJgeOR+IMAAAzSJwBAAbkBSjjUTgDAAxIq8Z4tGoAAMAMEmcAgAFp1RiPxBkAAGaQOAMADEjiPB6FMwDAgEwOHI9WDQAAmEHiDAAwoBWB83AkzgAAMIPEGQBgQCt6nIcjcQYAgBkkzgAAA+rdHgB3oHAGABiQdZzHo1UDAABmkDgDAAxopUwOHI3EGQAAZpA4AwAMyOTA8SicAQAGZHLgeLRqAADADBJnAIABrZgbOByJMwAAzCBxBgAY0EpEzqNROAMADMiqGuPRqgEAADNInAEABmRy4HgkzgAAMIPEGQBgQF6AMh6JMwAAzCBxBgAYkFU1xqNwBgAYkMmB49GqAQDAbFV1dVX1Jp+/3OSak6rq3VX1par6alV9oqqeW1V32sNzHl9VF1fVTVV1c1V9pKpO33e/bO8kzgAAAxp8cuBNSX5zg/03r99RVT+a5C1JvpbkzUm+lORfJHlVkocleeIG1zw7ydlJvpjkvCR/k+S0JOdW1YO6+wXL+Rlbo3AGAGCrbuzul+3tpKr6tiS/neS2JI/s7o9N+1+c5MIkp1XVk7v7/DXXHJvkzCwK7BO7++pp/yuS/PckZ1TVW7r7w8v8QXNo1QAAGNDKPvjsgtOS3DPJ+atFc5J099eSvGj6+nPrrvmpJIcnOWe1aJ6uuSHJv5u+PnNfDXhPJM4AAAPqsScHHl5VT0nynUluSfKJJB/s7tvWnXfKtH3vBvf4YJJbk5xUVYd399dnXPOedefsVwpnAIBDRFVdttmx7j5hC7e6V5I3rdv3uap6end/YM2+B07bKzZ43jer6nNJvjfJA5L82Yxrrq2qW5Lcr6ru2t23bmHMO6ZVAwBgQAO3arwxyaOyKJ6PSvKgJL+V5Ngk76mqH1hz7tHT9qZN7rW6/5htXHP0Jsf3GYkzAMAhYoup8mb3ePm6XZ9M8syqujnJGUleluTHdvqcEUmcAQAGNHDivJnXT9uHr9m3t3R4df+N27hms0R6n1E4AwAMqPfBZx/762l71Jp9n562x68/uarunOT+Sb6Z5KqZ19x7uv8X9nd/c6JwBgBgOR4ybdcWwRdO28ducP7Dk9w1yaVrVtTY2zWPW3fOfqVwBgAY0Eot/7NTVfU9VXXUBvuPTXLO9PW8NYcuSHJ9kidX1Ylrzj8iyb+dvr5u3e3emOTrSZ493Xf1mrsl+ZXp6+uzC0wOBABgridl8ea+Dyb58yRfSfJdSX4kyRFJ3p3FW/+SJN395ar611kU0BdX1flZvBHwCVksO3dBFq/hzpprPldVv5jkNUk+VlVvzt+9cvt+SX5jN94amCicAQCGtEtv+tubi7IoeP9hkodl0W98Y5IPZbGu85u6+3bt1N39tqp6RJIXJvmJLArsK5M8P8lr1p8/XXN2VV2d5AVJnpZFl8Snkryou3933/y0vVM4AwAwy/Rykw/s9cQ7XvdHSf75Fq95R5J3bPVZ+5LCGQBgQIMmzoc0hTMAwID2w/JxbJFVNQAAYAaJMwDAgJaxfBzLJXEGAIAZJM4AAAMyOXA8CmcAgAGZHDgerRoAADDDAZ043+VOB/TwIUny1V96xm4PAZbil+/ziN0eAhxUVmTOw5E4AwDADCJbAIABmRw4HoUzAMCANGqMR6sGAADMIHEGABiQVo3xSJwBAGAGiTMAwIBWardHwHoKZwCAAVnHeTxaNQAAYAaJMwDAgOTN45E4AwDADBJnAIABWY5uPBJnAACYQeIMADAgq2qMR+EMADAgZfN4tGoAAMAMEmcAgAGZHDgeiTMAAMwgcQYAGJDJgeNROAMADEjZPB6tGgAAMIPEGQBgQCYHjkfiDAAAM0icAQAG1Lqch6NwBgAYkFaN8WjVAACAGSTOAAADso7zeCTOAAAwg8QZAGBA8ubxSJwBAGAGiTMAwID0OI9H4QwAMCDL0Y1HqwYAAMwgcQYAGJA3B45H4gwAADNInAEABqTHeTwKZwCAAWnVGI9WDQAAmEHiDAAwIK0a45E4AwDADBJnAIABrbQe59EonAEABqRsHo9WDQAAmEHiDAAwoBWZ83AkzgAAMIPEGQBgQF6AMh6JMwAAzCBxBgAYkBegjEfhDAAwIJMDx6NVAwAAZpA4AwAMyOTA8UicAQBgBokzAMCATA4cj8IZAGBA3Vo1RqNVAwAAZpA4AwAMyHJ045E4AwDADBJnAIABmRw4HokzAMCAeh/8Z6eq6u5V9TNV9daqurKqvlpVN1XVh6rqp6vqW9adf2xV9R4+5+/hWadX1Uer6ubpGRdX1eN3/CN2QOIMAMBcT0zyuiTXJrkoyeeT/L0kP57kd5I8rqqe2HdcEuTjSd62wf0+udFDqurMJGck+UKS305yWJInJ3lHVT2nu89Zwm/ZMoUzAMCABp0ceEWSJyR5V3f/bTdJVf1Kko8m+Yksiui3rLvu8u5+2ZwHVNVJWRTNn03yg919w7T/lUkuS3JmVb2zu6/e2U/ZOq0aAADM0t0Xdvc71hbN0/6/TPL66esjd/iYZ07bX1stmqdnXJ3ktUkOT/L0HT5jWyTOAAAD2hcvQKmqy/bwvBN2ePtvTNtvbnDsPlX1jCR3T/LFJB/u7k9scp9Tpu17Nzj2niQvns556Q7Gui0KZwAAdqSq7pzkadPXjQreR0+ftddcnOT07v78mn1HJblvkpu7+9oN7vOZaXv8Tse8HQpnAIAB7Yvl6JaQKm/mPyT5viTv7u73rdl/a5JfzWJi4FXTvu9P8rIkJyf5w6p6cHffMh07etretMlzVvcfs6Rxb4keZwCAAY24HN1GquoXspjM9z+SPPV2v6H7uu5+SXf/SXffOH0+mOQxST6S5B8k+Zl9MrB9QOEMAMC2VNWzk7w6yaeSnNzdX5pzXXd/M4vl65Lk4WsOrSbKR2djq/tv3OJQl0KrBgDAgAZdju5vVdVzk7wqi7WYH9Xd123xFn89bY9a3dHdt1TVNUnuW1X33qDP+bhpe8V2xrxTEmcAALakqn4pi6L58iyS5q0WzUnykGl71br9F07bx25wzePWnbNfLa1wrqofq6pXV9VvVNWj93De6VW1Kz8WAOBA0d1L/yxDVb04i8mAl2WRNF+/h3P/0frXcE/7H5XkedPX89YdXl0P+oVVdbc11xyb5FlJvp7kjdsd/07suFWjqirJm7N4U0xNu59bVe9K8rTuXt+DcmySR+z0uQAAB7MRWzWq6vQkr0hyW5JLkvzCohS8nau7+9zpz2clOa6qLs3i9dnJYlWN1bWaX9zdl669uLsvraqzkjw/ySeq6oIsXrn9pCTfnuQ5u/HWwGQ5Pc5PT3Jakr/I4p8QvpHk9CSPT/Khqjplm/F9kj0v1H3EEd+53dsCALB195+2d0ry3E3O+UCSc6c/vynJjyX5wSzaLO6S5K+S/H6Sc7r7ko1u0N1nVNWfZpEw/2wWq/P9SZJXdvc7d/4ztmdZhfONWbxL/LokqapXJfn1LP5J4b9NxfOmMT4AALe3r5aP24nuflkWazDPPf8NSd6wzWedm78rwIewjML5QUkuWJsqd/dtSV5QVZ9P8ptZFM8nr33f+Fx7Wqj7yCP//nh/owAAOCgto3A+LIvI/Q66+zVVtZLkNUn+oKr+6RKeBwBw0FtZ0mQ+lmcZhfM1STZtNu7uc6b3l5+V5H1J/mgJzwQAOKgpm8ezjML5T7N41/imuvs3q+rwJP8+yT9cwjMBAGC/WsY6zu9Ocp+q+pE9ndTdv57kpfG2QgCAvVpJL/3DziyjiP1/s1iS5Ja9ndjdvzpNGDx2Cc8FAID9ZseFc3d/KclvbeH8393pMwEADnYS4vEs7ZXbAABwMNNvDAAwoLYc3XAUzgAAA9KqMR6tGgAAMIPEGQBgQC1xHo7EGQAAZpA4AwAMyOTA8SicAQAGZHLgeLRqAADADBJnAIABadUYj8QZAABmkDgDAAxIj/N4FM4AAAOyjvN4tGoAAMAMEmcAgAGtmBw4HIkzAADMIHEGABiQHufxSJwBAGAGiTMAwID0OI9H4QwAMCCtGuPRqgEAADNInAEABqRVYzwSZwAAmEHiDAAwID3O41E4AwAMSKvGeLRqAADADBJnAIABadUYj8QZAABmkDgDAAyoe2W3h8A6CmcAgAGtaNUYjlYNAACYQeIMADCgthzdcCTOAAAwg8QZAGBAepzHI3EGAIAZJM4AAAPS4zwehTMAwIBWFM7D0aoBAAAzSJwBAAbUJgcOR+IMAAAzSJwBAAZkcuB4FM4AAAOyjvN4tGoAAMAMEmcAgAFp1RiPxBkAAGaQOAMADMgLUMajcAYAGJBWjfFo1QAAgK1TJ7wAAAQxSURBVBkkzgAAA7Ic3XgkzgAAMIPEGQBgQHqcxyNxBgCAGSTOAAADshzdeBTOAAADapMDh6NVAwAAZpA4AwAMSKvGeCTOAAAwg8QZAGBAlqMbj8IZAGBAJgeOR6sGAADMIHEGABiQVo3xSJwBAGAGiTMAwIAkzuNROAMADEjZPB6tGgAAMEP51wDsSVVdliTdfcJujwW2y99jDhb+LsPukjgDAMAMCmcAAJhB4QwAADMonAEAYAaFMwAAzKBwBgCAGSxHBwAAM0icAQBgBoUzAADMoHAGAIAZFM4AADCDwhkAAGZQOAMAwAwKZwAAmEHhzIaq6n5V9Z+r6n9W1der6uqq+s2quttujw3mqKrTqursqrqkqr5cVV1V5+32uGArquruVfUzVfXWqrqyqr5aVTdV1Yeq6qeryv+Pw37kBSjcQVV9V5JLk/yvSd6e5H8k+cdJTk7y6SQP6+4v7t4IYe+q6vIkP5Dk5iRfSPLdSf5Ldz9lVwcGW1BVz0zyuiTXJrkoyeeT/L0kP57k6CRvSfLE9n/msF8onLmDqnpfksck+YXuPnvN/rOSPC/Jb3X3M3drfDBHVZ2cRcF8ZZJHZFF0KJw5oFTVKUmOSvKu7l5Zs/9eST6a5DuSnNbdb9mlIcIhxb/i4XamtPkxSa5O8tp1h1+a5JYkT62qo/bz0GBLuvui7v6MJI4DWXdf2N3vWFs0T/v/Msnrp6+P3O8Dg0OUwpn1Tp6279/gf6i/kuSPktw1yUP298AAuJ1vTNtv7uoo4BCicGa9B07bKzY5/plpe/x+GAsAG6iqOyd52vT1vbs5FjiUKJxZ7+hpe9Mmx1f3H7MfxgLAxv5Dku9L8u7uft9uDwYOFQpnADiAVNUvJDkjixWPnrrLw4FDisKZ9VYT5aM3Ob66/8b9MBYA1qiqZyd5dZJPJTm5u7+0y0OCQ4rCmfU+PW0362E+btpu1gMNwD5QVc9NcnaST2ZRNP/lLg8JDjkKZ9a7aNo+Zv0bqarqf0nysCS3Jvnj/T0wgENVVf1SklcluTyLovm6XR4SHJIUztxOd382yfuTHJvkWesOvzyLhfjf1N237OehARySqurFWUwGvCzJo7r7+l0eEhyyvDmQO9jgldt/luSHsljj+YokJ3nlNqOrqlOTnDp9vVeSf5bkqiSXTPuu7+4X7MbYYK6qOj3JuUluy6JNY6MVj67u7nP347DgkKVwZkNV9R1JXpHksUnunuTaJG9N8vLuvmE3xwZzVNXLsnjb5Wb+vLuP3T+jge2Z8fc4ST7Q3Y/c96MBFM4AADCDHmcAAJhB4QwAADMonAEAYAaFMwAAzKBwBgCAGRTOAAAwg8IZAABmUDgDAMAMCmcAAJhB4QwAADMonAEAYAaFMwAAzKBwBgCAGRTOAAAwg8IZAABmUDgDAMAMCmcAAJjh/weRomYVyxlVEwAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "image/png": { - "height": 248, - "width": 359 - }, - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "#model = models.resnet50(pretrained=True)\n", - "#model = timm.create_model('resnet50', pretrained=True)\n", - "model = timm.create_model('mixnet_xl', pretrained=True)\n", - "#model.fc #show fully connected layer for ResNet family\n", - "model.classifier #show the classifier layer (fully connected layer) for EfficientNets\n", - "\n", - "# Create classifier\n", - "for param in model.parameters():\n", - " param.requires_grad = True\n", - "# define `classifier` for ResNet\n", - "# Otherwise, define `fc` for EfficientNet family \n", - "#because the definition of the full connection/classifier of 2 CNN families is differnt\n", - "fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(1536, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - "# connect base model (EfficientNet_B0) with modified classifier layer\n", - "model.fc = fc\n", - "criterion = nn.CrossEntropyLoss()\n", - "#optimizer = Nadam(model.parameters(), lr=0.001)\n", - "#optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)\n", - "optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - "#show our model architechture and send to GPU\n", - "model.to(device)\n", - "CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/MixNet_Extra_Large_Covid-19.pth'\n", - "try:\n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " print(\"checkpoint loaded\")\n", - "except:\n", - " checkpoint = None\n", - " print(\"checkpoint not found\")\n", - "\n", - "def load_model(path): \n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - "load_model(CHECK_POINT_PATH) \n", - "\n", - "since = round(time.monotonic() * 1000)\n", - "\n", - "#since = time.time()\n", - "model.eval()\n", - "y_test = []\n", - "y_pred = []\n", - "for images, labels in data_loader['COVID-19 Radiography Database_2020_05_18_Kaggle']:\n", - " images = Variable(images.cuda())\n", - " labels = Variable(labels.cuda())\n", - " outputs = model(images)\n", - " _, predictions = outputs.max(1)\n", - " \n", - " y_test.append(labels.data.cpu().numpy())\n", - " y_pred.append(predictions.data.cpu().numpy())\n", - " \n", - "y_test = np.concatenate(y_test)\n", - "y_pred = np.concatenate(y_pred)\n", - "pd.DataFrame({'true_label':y_test,'predicted_label':y_pred}).to_csv('/home/linh/Downloads/Covid-19/results/Modified_MixNet_Extra_Large_Covid-19_COVID-19 Radiography Database_2020_05_18_Kaggle.csv',index=False)\n", - "\n", - "time_elapsed = round(time.monotonic() * 1000) - since \n", - "\n", - "#time_elapsed = time.time() - since\n", - "\n", - "print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n", - "\n", - "sns.heatmap(confusion_matrix(y_test, y_pred))\n", - "accuracy_score(y_test, y_pred)\n", - "\n", - "report = classification_report(y_test, y_pred, digits=4)\n", - "print(report)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# MixNet_XXL with image size = 224" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Pretrained model URL is invalid, using random initialization.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "checkpoint loaded\n", - "Training complete in 3m 23s\n", - " precision recall f1-score support\n", - "\n", - " 0 0.000000 0.000000 0.000000 108\n", - " 1 0.931230 0.963846 0.947257 8851\n", - " 2 0.933893 0.902665 0.918013 6041\n", - "\n", - " accuracy 0.932267 15000\n", - " macro avg 0.621708 0.622170 0.621757 15000\n", - "weighted avg 0.925598 0.932267 0.928660 15000\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/linh/.conda/envs/CV/lib/python3.6/site-packages/sklearn/metrics/classification.py:1437: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n", - " 'precision', 'predicted', average, warn_for)\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAs4AAAHwCAYAAAC2dOlsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3de9yu13wn/s93EhHC7MShLcErZRJ6oL824hDjFCNlaKvoD68fE/xQ5yI67ctZSw9DKWLoaCvKzISJ00QJKtKEtEXQ1GiTSKQqtA5hR44kz5o/ruvRO3fuZ+/reZ5772ftvd9vr/u1cl/XWte1bq/Yvvnmu9aq1loAAIAd+zdbPQEAANgTCJwBAGACgTMAAEwgcAYAgAkEzgAAMIHAGQAAJhA4AwDABAJnAACYQOAMAAATCJwBAGACgTMAAEwgcAYAgAkEzgAAMMH+Wz2Bzdj/gEPbVs8BNqu2egIAXMcPvn9xF380/+BbFy49zrnBLW7fxW/bU8k4AwDABHt0xhkAYK+1cu1Wz4A5Ms4AADCBjDMAQI/aylbPgDkCZwCAHq0InHujVAMAACaQcQYA6FBTqtEdGWcAAJhAxhkAoEdqnLsjcAYA6JFSje4o1QAAgAlknAEAeuTkwO7IOAMAwAQyzgAAPVLj3B0ZZwAAmEDGGQCgR7aj647AGQCgQ04O7I9SDQAAmEDGGQCgR0o1uiPjDAAAE8g4AwD0SI1zdwTOAAA9cnJgd5RqAADABDLOAAA9UqrRHRlnAACYQMYZAKBHtqPrjsAZAKBHSjW6o1QDAAAmkHEGAOiRUo3uyDgDAMAEMs4AAB1qzQEovRE4AwD0yOLA7ijVAACACWScAQB6ZHFgd2ScAQBgAhlnAIAeqXHujowzAABMIOMMANCjFdvR9UbgDADQI6Ua3VGqAQAAE8g4AwD0yHZ03ZFxBgCACWScAQB6pMa5OwJnAIAeKdXojlINAACYQOAMANCjlZXlfzapqh5fVW0nn2tn+h+2k74n7eBdx1XVp6rqsqraXlWnV9VDd9B/v6p6blWdU1VXVtUlVfXBqjp60z98pFQDAICpPp/k5Wvcu3eSY5J8aMG9v03yvgXXv7DoQVX16iTHJ/lqkrckOSDJo5OcUlXPaq2dMNe/kpyU5JFJzk1yQpKbJXlUkjOq6hGttffv+KftnMAZAKBDrfV3cmBr7fMZgufrqaq/Gv/yvy24/fnW2sumvGPMEB+f5IIkR7XWvjNef1WSs5O8uqo+0Fq7aGbYozMEzWcleUBr7apxzJuTfCLJW6rqtNba96bMYS1KNQAAetRhqcZaqurOSe6R5OIkf77Jxz11bF+5GjQnyRgovzHJDZM8YW7M08b2RatB8zjm00nemeSWGQLrTRE4AwCwWU8Z2z9pi1Plt66qX62qF4ztXXbwrGPG9tQF9z401ydVdWCSo5NckeTMKWM2SqkGAECPdsE+zlV19pqva+3IDT7zRkkem+TaJH+8RrcHjp/ZcacnOa619pWZawclOTTJZa21ry94zvlje8TMtTsk2S/Jha21ayaO2RAZZwAANuP/TXJwklNba/80d++KJL+d5Mgkh4yf+yb5eJL7JfnYGCyv2ja229d41+r1gzc5ZkNknAEAerQLapI3mlXeidUyjT9a8L5vJHnJ3OUzqurYDIv27p7kSUletwvmtXQyzgAAbEhV/VSG+uKvJvng1HFjScVqWcd9Zm6tZoe3ZbHV69/d5JgNkXEGAOjRLqhx3gV2tihwR745tj8s1WitXV5VFyc5tKputaDO+fCxPW/m2gUZ6qtvX1X7L6hzXjRmQ2ScAQB61Pl2dONuFo/LELT+yQYecY+xvXDu+mlj+6AFYx481yfj9nNnJblxhkNYdjpmowTOAABsxK9kWOz3oQWLApMkVfVzVXW9eLOqHpDkuePXd8zdfvPYvrCqDpkZc1iSZyS5Oslb58a8aWxfMQb0q2OOynB64DeTvHvnP2nHlGoAAPSo/1KN1TKNRScFrnpNksOr6qwMddBJcpf8657KL26tnTU7oLV2VlW9JsnzkpxTVSdnOHL7URmO0X7W3KmByXDc9sMzHHLyuao6JcnNxzH7JXlya+3S9f/E6xI4AwCwLlX1E0n+fXa+KPDtSX45yVEZSiZukORfkrwryQmttUUHlqS1dnxV/V2GDPNTkqwk+WySV7XWPrCgf6uqx2Qo2XhikmcluSrJGUleMR+cb1S11pbxnC2x/wGH7rmTh1Ft9QQAuI4ffP/iLv5ovvJDr196nHOjBz+7i9+2p5JxBgDo0S7Yx5nNsTgQAAAmkHEGAOhR/4sD9zkyzgAAMIGMMwBAj9Q4d0fgDADQI6Ua3VGqAQAAEyw141xVN8pw7vgRSQ4eL383yXlJ/rq1duUy3wcAsNdSqtGdpQTO4znir0zyuCQ3XqPbFVX1Z0le1Fr7zjLeCwAAu8umA+eqOjjJJ5PcKcnlST6a5Pwk28cu25IcnuReSZ6W5P5Vdc/W2vYFjwMAIFHj3KFlZJxfmiFofm2Sl7bWLlvUqapukuS3kjwnyUuSHD/l4VV19lr39rvBrdc9WQAA2IhlLA58WJLTWmvHrxU0J0lr7bLW2vOSnJ7k4Ut4LwDA3mtlZfkfNmUZGedbJfmf6+j/10mOntq5tXbkWvf2P+DQto73AgDsOQS63VlGxvnbSe64jv4/MY4BAIA9xjIC5w8neVhVPX1nHavqmUl+McmpS3gvAMDeq7Xlf9iUZZRqvDjJQ5K8oaqOT/KRDPs2z+6qcUSSY5McluQbGRYHAgDAHmPTgXNr7eKqumeSNyV5YJJfTTL/jzQ1th9J8vTW2sWbfS8AwF5NjXN3lnIASmvtwiQ/X1W3T3L/DDXP28bb25Ocm+TjYz8AAHZG4NydpR65PQbGgmMAAPY6Sw2cAQBYEicHdmcZu2oAAMBeT8YZAKBHapy7I3AGAOiRfZe7o1QDAAAmkHEGAOiRUo3uyDgDAMAEMs4AAD2Sce6OjDMAAEwg4wwA0CMHoHRH4AwA0KG2Yju63ijVAACACWScAQB6ZHFgd2ScAQBgAhlnAIAeWRzYHYEzAECPLA7sjlINAACYQMYZAKBHFgd2R8YZAAAmkHEGAOiRjHN3BM4AAD1qFgf2RqkGAABMIOMMANAjpRrdkXEGAIAJZJwBAHrkAJTuyDgDAMAEMs4AAD1qapx7I3AGAOiRUo3uKNUAAIAJZJwBADrUbEfXHRlnAACYQMYZAKBHapy7I3AGAOiRXTW6o1QDAAAmkHEGAOiRUo3uyDgDAMAEMs4AAD2yHV13BM4AAD1SqtEdpRoAADCBjDMAQI9sR9cdGWcAAJhAxhkAoEdqnLsj4wwAABMInAEAOtRWVpb+WaaqekBVvbeq/rmqrq6qr1XVh6vqPy7oe3RVfbCqLqmqK6vqnKp6TlXtt4PnP7SqTq+q7VV1WVX9TVUdt5M5HVdVnxr7bx/HP3QZvzcROAMA9GmlLf+zJFX1X5L8RZK7JvnfSf4gyZ8nuWWS+831/aUkZyS5T5L3JjkhyQFJXpvkpDWe/8wkpyT56STvSPKWJLdOcmJVvXqNMa9OcmKSW43935HkzklOGZ+3adXanls/s/8Bh+65k4dRbfUEALiOH3z/4i7+aL7sNx6+9DjnJr//nk3/tqp6cpL/luRtSZ7SWvv+3P0btNZ+MP71v03ypSTbktyrtfaZ8fqBSU5Lcs8kj2mtnTQz/rAk/5Dk8iRHttYuGq8fkuTTSe6Q5OjW2l/NjDk6ySeTXJDkqNbad2aedXaSg5LcafVZGyXjDADQow4zzlV1wySvTPKVLAiak2Q1aB49MkMW+qTVoHnsc1WSF41fnzb3iCcmuWGSE2YD3TEY/p3x61Pnxqx+f+Vq0DyOuSjJG8fnPWHnv3DHBM4AAEz1wAyB8HuSrFTVQ6rqN6rq16rqngv6HzO2py64d0aSK5IcPQbkU8Z8aK7PZsasm+3oAAB6tAsOQKmqs9d8XWtHTnjEUWN7VZLPZahBnn3+GUke2Vr75njpjmN73oL3XVNVX07yU0lun+TvJ4z5elVdnuQ2VXXj1toVVXVQkkOTXNZa+/qCOZ8/tkdM+H07JOMMANCjDks1kvzI2P56kpbk3klumuQuST6SYQHg/5rpv21st6/xvNXrB29gzLa5dj3v2BAZZwCAfcTErPKOrCZdr0nyizM1yH9XVb+c5Nwk962qe84u3ttbyDgDAHSorbSlf5bgu2P7ufkdKlprVyT58Pj1bmM7nx2et3r9uzPXpo7ZPteu5x0bInAGAGCqc8d2rSB0dUeLG831v159cVXtn+THM2SvL1zwjkVjbpVha7mvjoF6WmuXJ7k4yU3G+/MOH9vr1Uyvl8AZAKBHfdY4fyxDbfNPVtWiOHJ1seCXx/a0sX3Qgr73SXLjJGe11q6eub6jMQ+e67OZMesmcAYA6NHKyvI/m9Ra+8cMJ/rdLsmvzd6rqmOT/HyGbPTqtnAnJ/lWkkdX1V1n+h6Y5BXj1zfNveatSa5O8szxAJPVMYckecH49c1zY1a/v3DstzrmsCTPGJ/31kk/cgcsDgQAYD2ekeRnk7ymqh6SYVu6H0/ysCTXJnlSa217krTWLh1PGjw5yelVdVKSS5L8YoZt505O8s7Zh7fWvlxVv57k9Uk+U1XvTPL9DIep3CbJH8wvPGytnVVVr0nyvCTnVNXJGY71flSSmyV51mZPDUwcuQ1brotzXQH4oV6O3P7e0x+89Djnpv/1Q0v5bVV1yyQvyRAA3yrJpUnOTPK7rbVPLeh/ryQvzHDE9oEZjuH+0ySvb61du8Y7fiHJ85P8XIYqiS9mOE3wbTuY1+MzBPY/mWQlyWeTvKq19oEN/dD55wucYWt18aczAD8kcGYtSjUAAHq0nMV8LJHFgQAAMIGMMwBAh/bkctq9lcAZAKBHSjW6o1QDAAAmkHEGAOiRjHN3BM6wxa742plbPQVYilsc9sCtngLALiVwBgDoUJNx7o7AGQCgRwLn7lgcCAAAE8g4AwD0aGWrJ8A8GWcAAJhAxhkAoEMWB/ZH4AwA0COBc3eUagAAwAQyzgAAPbI4sDsyzgAAMIGMMwBAhywO7I+MMwAATCDjDADQIzXO3RE4AwB0SKlGf5RqAADABDLOAAA9UqrRHRlnAACYQMYZAKBDTca5OwJnAIAeCZy7o1QDAAAmkHEGAOiQUo3+yDgDAMAEMs4AAD2Sce6OwBkAoENKNfqjVAMAACaQcQYA6JCMc39knAEAYAIZZwCADsk490fgDADQo1ZbPQPmKNUAAIAJZJwBADqkVKM/Ms4AADCBjDMAQIfaihrn3sg4AwDABDLOAAAdUuPcH4EzAECHmu3ouqNUAwAAJpBxBgDokFKN/sg4AwDABDLOAAAdsh1dfwTOAAAdam2rZ8A8pRoAADCBjDMAQIeUavRHxhkAACaQcQYA6JCMc38EzgAAHbI4sD9KNQAAYAIZZwCADinV6I+MMwAATCDjDADQodZknHsj4wwAABPIOAMAdKitbPUMmCdwBgDo0IpSje4o1QAAgAkEzgAAHWqtlv7ZFarqsVXVxs+T5u7db+beos/vrfHM/arquVV1TlVdWVWXVNUHq+roHczjRlX18qo6t6quqqpvVNW7quonlvVblWoAALAhVXXbJCckuSzJTXbQ9S+TnL7g+icWPLOSnJTkkUnOHZ9/sySPSnJGVT2itfb+uTE3TPLRJPdK8pkkr0ty2yS/kuQhVXVMa+1v1vXjFhA4AwB0qPcDUMYA961Jvp3kPUmev4Pup7fWXjbx0Y/OEDSfleQBrbWrxve9OUOg/ZaqOq219r2ZMc/LEDSfnORRrQ1LK6vqnUnel+RPq+rOq9c3SqkGAECHWlv+Z8meneSYJE9IcvkSn/u0sX3RatCcJK21Tyd5Z5JbZgisk/wwgH/q+PU/zwbHY2b6zCQ/meS+m52YwBkAgHUZ64Z/L8nrWmtnTBjy76rqmVX1gqp6YlUdvsZzD0xydJIrMgS88z40tsfMXLtDktslOa+19uWJYzZEqQYAQId2RalGVZ295vtaO3LiM/ZP8vYkX0nygomv/v/Gz+xz3p3kya2178xcvkOS/ZJc2Fq7ZsFzzh/bI2au3XFsz1vj3YvGbIiMMwAA6/GSJD+b5PGttSt30vebSX4zyZ2T3DRDmcWDk3wuySOSnFJVs/HotrHdvsbzVq8fvMkxGyLjDADQoV1xAMrUrPJaquruGbLMf9Ba+6sJ7/s/Sf7PzKXLkpxaVWcl+XyGBX2/kOT9C4Z3R8YZAKBDve3jPJZo/FmGkogXb+63tUuT/I/x631mbq1mh7dlsdXr393kmA0ROAMAMMVNMtQJ/0SSq2YPMkny0rHPW8Zrfzjhed8c24Nmrl2Q5Noktx8D9Xmriwpn65nPHdu1apgXjdkQpRoAAB3aBdvHbdbVSf5kjXs/l6Hu+RMZAtmdlnEkucfYXrh6obV21VjGce/x8/G5MQ8e29Nmrl2QYaHiEVX14wt21lg0ZkMEzgAA7NS4EPBJi+5V1csyBM5va6398cz1u7bWPrOg/2MznAT4/STvmrv9pgxB8yuqavYAlKPGMd9M8u6ZebXxcJTfSfJfqmr2AJRfGp/1xQynF26KwBkAoEO7YnHgFji5qq7JcAz2V5McmOSoJHdLck2SX22tXTQ35qQkD89wyMnnquqUJDfPEDTvl2ELu0vnxrwmyUPHMX9TVR/LsLfzr2TYE/qJmz01MBE4AwCw67wpyX/IsHvGLZJUkouTnJjkD1trfzs/YMwgPybDkdtPTPKsJFclOSPJK1prZy0Yc3VVPTDD1nePSfLcJJdmOG77pa21Ly7jx1TrsIBmqv0POHTPnTyMrvzaooORYM9zi8MeuNVTgKXYftkFXaR6P3e7X1p6nPOzX3l/F79tTyXjDADQoT04t7nX2pLt6KrqVVV1wVa8GwAANmKrMs63SHLYlI47OlN9vxvcelnzAQDoyl6yOHCv4gAUAACYYCkZ56r6s3UOOXpqxx2dqW5xIACwt9rsEdks37JKNR6bpGXYYmQqQS8AwBqUavRnWYHz9zJsav30if1/M8mxS3o3AADscssKnP82yc+01iYdZVhVj1/SewEA9kr+1Xx/lrU48PNJblJVd1jS8wAAoCvLyjj/ZZJ7J7lNkin7M78vyUVLejcAwF5HjXN/lhI4t9beneTd6+j//iTvX8a7AQD2RnbV6I99nAEAYIKtOjkQAIAdWNnqCXA9Ms4AADCBjDMAQIfaus6VY3eQcQYAgAlknAEAOrTiBJTuCJwBADq0olSjO0o1AABgAhlnAIAOWRzYHxlnAACYQMYZAKBDDkDpj8AZAKBDSjX6o1QDAAAmkHEGAOiQUo3+yDgDAMAEMs4AAB2Sce6PwBkAoEMWB/ZHqQYAAEwg4wwA0KEVCefuyDgDAMAEMs4AAB1aUePcHRlnAACYQMYZAKBDbasnwPUInAEAOmQf5/4o1QAAgAlknAEAOrRSFgf2RsYZAAAmkHEGAOiQxYH9ETgDAHTI4sD+KNUAAIAJZJwBADq0Ym1gd2ScAQBgAhlnAIAOrUTKuTcCZwCADtlVoz9KNQAAYAIZZwCADlkc2B8ZZwAAmEDGGQCgQw5A6Y+MMwAATCDjDADQIbtq9EfgDADQIYsD+6NUAwAAJpBxBgDokMWB/ZFxBgCACWScAQA6JOPcH4EzAECHmsWB3VGqAQAAE8g4AwB0SKlGf2ScAQBgAhlnAIAOyTj3R+AMANAhR273R6kGAACTVdXvV9XHquqfqurKqrqkqj5XVS+tqpuvMeboqvrg2PfKqjqnqp5TVfvt4D0PrarTq2p7VV1WVX9TVcftZG7HVdWnxv7bx/EP3exvXiVwBgDo0Eot/7Mkz01yUJKPJnldkv+e5JokL0tyTlXddrZzVf1SkjOS3CfJe5OckOSAJK9NctKiF1TVM5OckuSnk7wjyVuS3DrJiVX16jXGvDrJiUluNfZ/R5I7JzllfN6mVWt77r8I2P+AQ/fcycPoyq+dudVTgKW4xWEP3OopwFJsv+yCLnZQft3tHrv0OOfXvvKOTf+2qjqwtXbVguuvTPKCJG9qrT19vPZvk3wpybYk92qtfWb1GUlOS3LPJI9prZ0085zDkvxDksuTHNlau2i8fkiSTye5Q5KjW2t/NTPm6CSfTHJBkqNaa9+ZedbZGQL9O60+a6NknAEAOrSyCz7LsChoHr1rbA+fufbIJLdMctJq0DzzjBeNX58295wnJrlhkhNmA90xGP6d8etT58asfn/latA8jrkoyRvH5z1hzR81kcAZAIBl+IWxPWfm2jFje+qC/mckuSLJ0VV1w4ljPjTXZzNj1s2uGgAAHdoV29FV1dlr3WutHbnOZz0/yU0ylGHcNcm/zxA0/95MtzuO7XkL3ndNVX05yU8luX2Sv58w5utVdXmS21TVjVtrV1TVQUkOTXJZa+3rC6Z6/tgesZ7ft4jAGQCgQ3vAQq7nJ/nRme+nJnl8a+2bM9e2je32NZ6xev3gdY45aOx3xQbfsSECZwCAfcR6s8o7edaPJUlV/WiSozNkmj9XVQ9trX12We/piRpnAIAOdbwd3XW01v6ltfbeJMcmuXmSP5u5vZrt3Xa9gde9/t0NjNk+167nHRsicAYAYNNaa/+Y5ItJfqqqbjFePndsr1dfXFX7J/nxDHtAXzhza0djbpWhTOOrrbUrxvdenuTiJDcZ789b3eXjejXT6yVwBgDoUK/b0e3Ercf22rE9bWwftKDvfZLcOMlZrbWrZ67vaMyD5/psZsy6CZwBADrUdsFns6rqiKq6XklEVf2b8QCUH8kQCK/upXxykm8leXRV3XWm/4FJXjF+fdPc496a5OokzxwPMFkdc0iGA1aS5M1zY1a/v3DstzrmsCTPGJ/31kk/cgcsDgQAYKr/mOR3q+oTSb6c5NsZdta4b4Yt5f45yZNXO7fWLq2qJ2cIoE+vqpOSXJLkFzNsO3dyknfOvqC19uWq+vUkr0/ymap6Z5LvZzhM5TZJ/mD21MBxzFlV9Zokz8tw7PfJGY71flSSmyV51mZPDUwEzrDlDjv8F3beCfYAF/38bbd6CrBXWelzQ7q/SPLvMuzZ/LMZtni7PEP98NuTvL61dsnsgNba+6rqvklemOQRSQ7McAz388b+1/uhrbU3VNVFGba8+08ZqiS+mORFrbW3LZpYa+34qvq7DBnmp2SoTvlskle11j6wyd+dROAMAMBErbUvJHnmBsZ9MkO2ej1jTklyyjrHnJjkxPWMWQ+BMwBAh3bTYj7WQeAMANChLgs19nF21QAAgAlknAEAOqRUoz8yzgAAMIGMMwBAh1Zqq2fAPIEzAECHOt3HeZ+mVAMAACaQcQYA6JB8c39knAEAYAIZZwCADtmOrj8yzgAAMIGMMwBAh+yq0R+BMwBAh4TN/VGqAQAAE8g4AwB0yOLA/sg4AwDABDLOAAAdsjiwPwJnAIAOCZv7o1QDAAAmkHEGAOiQxYH9kXEGAIAJZJwBADrUVDl3R+AMANAhpRr9UaoBAAATyDgDAHTIPs79kXEGAIAJZJwBADok39wfGWcAAJhAxhkAoENqnPsjcAYA6JDt6PqjVAMAACaQcQYA6JCTA/sj4wwAABPIOAMAdEiNc38EzgAAHVKq0R+lGgAAMIGMMwBAh5Rq9EfGGQAAJpBxBgDo0EpT49wbgTMAQIeEzf1RqgEAABPIOAMAdGhFzrk7Ms4AADCBjDMAQIccgNIfGWcAAJhAxhkAoEMOQOmPwBkAoEMWB/ZHqQYAAEwg4wwA0CGLA/sj4wwAABPIOAMAdMjiwP4InAEAOtSaUo3eKNUAAIAJZJwBADpkO7r+yDgDAMAEMs4AAB2yOLA/AmcAgA7Zx7k/SjUAAGACGWcAgA5ZHNgfGWcAAJhA4AwA0KHW2tI/m1VVj6yqN1TVmVV1aVW1qnrHGn0PG++v9TlpB+85rqo+VVWXVdX2qjq9qh66g/77VdVzq+qcqrqyqi6pqg9W1dGb/tEzlGoAADDVi5L8TJLLknw1yZ0mjPnbJO9bcP0LizpX1auTHD8+/y1JDkjy6CSnVNWzWmsnzPWvJCcleWSSc5OckORmSR6V5IyqekRr7f0T5rlTAmcAgA51uh3dczMEtF9Kct8kH58w5vOttZdNefiYIT4+yQVJjmqtfWe8/qokZyd5dVV9oLV20cywR2cIms9K8oDW2lXjmDcn+USSt1TVaa21702Zw44o1QAA6FDbBf/Z9Jxa+3hr7fy2jLqPxZ46tq9cDZrH916U5I1JbpjkCXNjnja2L1oNmscxn07yziS3zBBYb5rAGQCAXenWVfWrVfWCsb3LDvoeM7anLrj3obk+qaoDkxyd5IokZ04ZsxlKNQAAOrQrtqOrqrPXutdaO3LpLxw8cPzMzuP0JMe11r4yc+2gJIcmuay19vUFzzl/bI+YuXaHJPslubC1ds3EMRsm4wwAwK5wRZLfTnJkkkPGz2pd9P2SfGwMlldtG9vtazxv9frBmxyzYUvLOFfVL2f4L+GaJKe21j66Rr/jMvwTxlJS5gAAe6NdUUa8C7PKi971jSQvmbt8RlUdm2HR3t2TPCnJ63bXnDZr0xnnGrwryclJnpVhteWpVfW/q2pRdH9Yhn/aAABgDStpS//0YCyp+OPx631mbq1mh7dlsdXr393kmA1bRsb5CRlWKv5Tkjcn+UGS45I8NMknquqY8Z84NmRHtTj73eDWG30sAABb55tj+8NSjdba5VV1cZJDq+pWC+qcDx/b82auXZDk2iS3r6r9F9Q5LxqzYcuocX5Chij+qNba77bWXp3k/0nymiQ/meQvquoWS3gPAMA+o8ft6JboHmN74dz108b2QQvGPHiuT8bt585KcuMk954yZjOWETjfOcl7ZrPKrbVrW2vPT/KcJD+dIXg+ZCMPb60dudZnCXMHAGAXqKqfq6rrxZpV9YAMpb1JMn9c95vH9oWzsWNVHZbkGUmuTvLWuTFvGttXjNvTrY45KsPpgd9M8u6N/YrrWkapxgFJ/mXRjdba66tqJcnrk3y0qv7DEt4HALDXW9llZ4xsXFU9LBpdYjMAAAhoSURBVMnDxq8/Nrb3rKoTx7/+1pg8TYbqg8Or6qwMpw0myV3yr3sqv7i1dtbs81trZ1XVa5I8L8k5VXVyhljzURmO0X7W3KmByXDc9sMzlA5/rqpOSXLzccx+SZ7cWrt047/6Xy0jcL44ye3WutlaO6Gq9s/wX96Hk3xyCe8EANir9Rc2JxnKcY+bu3b78ZMk/5hkNXB+e5JfTnJUhpKJG2RItr4ryQmttUUHlqS1dnxV/V2GDPNTMpw+/tkkr2qtfWBB/1ZVj8lQsvHEDJtVXJXkjCSvmA/ON6M2u9VJVb0nyd1aa7fZSb/fSPK7Gbar26+1tt+mXpxk/wMO7fTvKZjuRw9aytaSsOW+cMwtt3oKsBSHvPv02uo5JMm9D33A0uOcMy/+WBe/bU+1jBrnD2Y4SvEhO+rUWvv9JC+N0woBAHZqb92Obk+2jCD2PRnqRy7fWcfW2m9X1Vcy7OUMAAB7jE0Hzq21S5L80Tr6v22z7wQA2NvJEPdnGaUaAACw11NvDADQoc1u4MDyCZwBADqkVKM/SjUAAGACGWcAgA41GefuyDgDAMAEMs4AAB2yOLA/AmcAgA5ZHNgfpRoAADCBjDMAQIeUavRHxhkAACaQcQYA6JAa5/4InAEAOmQf5/4o1QAAgAlknAEAOrRicWB3ZJwBAGACGWcAgA6pce6PjDMAAEwg4wwA0CE1zv0ROAMAdEipRn+UagAAwAQyzgAAHVKq0R8ZZwAAmEDGGQCgQ2qc+yNwBgDokFKN/ijVAACACWScAQA6pFSjPzLOAAAwgYwzAECHWlvZ6ikwR+AMANChFaUa3VGqAQAAE8g4AwB0qNmOrjsyzgAAMIGMMwBAh9Q490fGGQAAJpBxBgDokBrn/gicAQA6tCJw7o5SDQAAmEDGGQCgQ83iwO7IOAMAwAQyzgAAHbI4sD8CZwCADtnHuT9KNQAAYAIZZwCADinV6I+MMwAATCDjDADQIQeg9EfgDADQIaUa/VGqAQAAE8g4AwB0yHZ0/ZFxBgCACWScAQA6pMa5PzLOAAAwgYwzAECHbEfXH4EzAECHmsWB3VGqAQAAE8g4AwB0SKlGf2ScAQBgAhlnAIAO2Y6uPwJnAIAOWRzYH6UaAACsS1Xdpqr+tKq+VlVXV9VFVfWHVXXIVs9tV5JxBgDoUK+lGlV1hyRnJfmRJO9P8g9J7pbk15I8qKru1Vr79hZOcZeRcQYAYD3+a4ag+dmttYe11n6ztXZMktcmuWOSV27p7HYhgTMAQIdaa0v/bNaYbT42yUVJ3jh3+6VJLk/yuKo6aNMv65DAGQCgQ20XfJbg/mP7kdbaynXm29r3knwyyY2T3GM5r+uLGmcAgH1EVZ291r3W2pETHnHHsT1vjfvnZ8hIH5HkY+ubXf/26MD5mu9fXFs9h73d6v/AJv6PCbrk72P2Fv5e3rfsijhnR4HzRNvGdvsa91evH7zJ93Rpjw6cAQCYzj90bY4aZwAAplrNKG9b4/7q9e/uhrnsdgJnAACmOndsj1jj/uFju1YN9B5N4AwAwFQfH9tjq+o6cWRV3TTJvZJckeSvd/fEdgeBMwAAk7TWLkjykSSHJXnG3O2XJzkoydtba5fv5qntFhYHAgCwHk/PcOT266vqAUn+PsndM+zxfF6SF27h3Hap6vUcdAAA+lRVt03yW0kelOTmSb6e5L1JXt5a+85Wzm1XEjgDAMAEapwBAGACgTMAAEwgcAYAgAkEzgAAMIHAGQAAJhA4AwDABAJnFqqq21TVn1bV16rq6qq6qKr+sKoO2eq5wRRV9ciqekNVnVlVl1ZVq6p3bPW8YD2q6uZV9aSqem9Vfamqrqyq7VX1iar6/+ePPAZ2Lfs4cz1VdYcMJwL9SJL3J/mHJHfLcCLQuUnu1Vr79tbNEHauqj6f5GeSXJbkq0nulOS/t9Yeu6UTg3WoqqcmeVOGwyU+nuQrSX40ycOTbEvy7iS/0vyfOewWAmeup6o+nOTYJM9urb1h5vprkjw3yR+11p66VfODKarq/hkC5i8luW+GoEPgzB6lqo5JclCSP2+trcxc/7Ekn0py2ySPbK29e4umCPsU/4qH6xizzccmuSjJG+duvzTJ5UkeV1UH7eapwbq01j7eWjtfJo49WWvttNbaKbNB83j9n5O8efx6v90+MdhHCZyZd/+x/ciCP6i/l+STSW6c5B67e2IAXMcPxvaaLZ0F7EMEzsy749iet8b988f2iN0wFwAWqKr9k/yn8eupWzkX2JcInJm3bWy3r3F/9frBu2EuACz2e0l+OskHW2sf3urJwL5C4AwAe5CqenaS4zPsePS4LZ4O7FMEzsxbzShvW+P+6vXv7oa5ADCjqp6Z5HVJvpjk/q21S7Z4SrBPETgz79yxXauG+fCxXasGGoBdoKqek+QNSb6QIWj+5y2eEuxzBM7M+/jYHjt/IlVV3TTJvZJckeSvd/fEAPZVVfUbSV6b5PMZguZvbPGUYJ8kcOY6WmsXJPlIksOSPGPu9sszbMT/9tba5bt5agD7pKp6cYbFgGcneUBr7VtbPCXYZzk5kOtZcOT23ye5e4Y9ns9LcrQjt+ldVT0sycPGrz+W5OeTXJjkzPHat1prz9+KucFUVXVckhOTXJuhTGPRjkcXtdZO3I3Tgn2WwJmFquq2SX4ryYOS3DzJ15O8N8nLW2vf2cq5wRRV9bIMp12u5R9ba4ftntnAxkz4+zhJ/rK1dr9dPxtA4AwAABOocQYAgAkEzgAAMIHAGQAAJhA4AwDABAJnAACYQOAMAAATCJwBAGACgTMAAEwgcAYAgAkEzgAAMIHAGQAAJhA4AwDABAJnAACYQOAMAAATCJwBAGACgTMAAEwgcAYAgAn+L3aVN+xK9iD+AAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "image/png": { - "height": 248, - "width": 359 - }, - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "#model = models.resnet50(pretrained=True)\n", - "#model = timm.create_model('resnet50', pretrained=True)\n", - "model = timm.create_model('mixnet_xxl', pretrained=True)\n", - "#model.fc #show fully connected layer for ResNet family\n", - "model.classifier #show the classifier layer (fully connected layer) for EfficientNets\n", - "\n", - "# Create classifier\n", - "for param in model.parameters():\n", - " param.requires_grad = True\n", - "# define `classifier` for ResNet\n", - "# Otherwise, define `fc` for EfficientNet family \n", - "#because the definition of the full connection/classifier of 2 CNN families is differnt\n", - "fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(1536, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - "# connect base model (EfficientNet_B0) with modified classifier layer\n", - "model.fc = fc\n", - "criterion = nn.CrossEntropyLoss()\n", - "#optimizer = Nadam(model.parameters(), lr=0.001)\n", - "#optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)\n", - "optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - "#show our model architechture and send to GPU\n", - "model.to(device)\n", - "CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/MixNet_Extra_Extra_Large_Covid-19.pth'\n", - "try:\n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " print(\"checkpoint loaded\")\n", - "except:\n", - " checkpoint = None\n", - " print(\"checkpoint not found\")\n", - "\n", - "def load_model(path): \n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - "load_model(CHECK_POINT_PATH) \n", - "since = round(time.monotonic() * 1000)\n", - "\n", - "#since = time.time()\n", - "model.eval()\n", - "y_test = []\n", - "y_pred = []\n", - "for images, labels in data_loader['data']:\n", - " images = Variable(images.cuda())\n", - " labels = Variable(labels.cuda())\n", - " outputs = model(images)\n", - " _, predictions = outputs.max(1)\n", - " \n", - " y_test.append(labels.data.cpu().numpy())\n", - " y_pred.append(predictions.data.cpu().numpy())\n", - " \n", - "y_test = np.concatenate(y_test)\n", - "y_pred = np.concatenate(y_pred)\n", - "pd.DataFrame({'true_label':y_test,'predicted_label':y_pred}).to_csv('/home/linh/Downloads/Covid-19/results/Modified_MixNet_Extra_Extra_Large_Covid-19_Whole_Dataset.csv',index=False)\n", - "time_elapsed = round(time.monotonic() * 1000) - since \n", - "\n", - "#time_elapsed = time.time() - since\n", - "\n", - "print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n", - "\n", - "sns.heatmap(confusion_matrix(y_test, y_pred))\n", - "accuracy_score(y_test, y_pred)\n", - "\n", - "report = classification_report(y_test, y_pred, digits=6)\n", - "print(report)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Ensemble voting of 6 models with image size = 224" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def get_Eff_B0():\n", - " model = timm.create_model('efficientnet_b0', pretrained=True)\n", - "\n", - " for param in model.parameters():\n", - " param.requires_grad = True\n", - " fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - " model.fc = fc\n", - " criterion = nn.CrossEntropyLoss()\n", - " optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - "\n", - " CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/EfficientNet_B0_Covid-19.pth'\n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - " model.to(device)\n", - " model.eval()\n", - " return model\n", - " \n", - "def get_MixNet_S():\n", - " model = timm.create_model('mixnet_s', pretrained=True)\n", - " model.classifier #show the classifier layer (fully connected layer) for EfficientNets\n", - "\n", - " for param in model.parameters():\n", - " param.requires_grad = True\n", - " fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(1536, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - " model.fc = fc\n", - " criterion = nn.CrossEntropyLoss()\n", - " optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - " CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/MixNet_Small_Covid-19.pth' \n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - " model.to(device)\n", - " model.eval()\n", - " \n", - " return model\n", - "\n", - "def get_MixNet_M():\n", - " model = timm.create_model('mixnet_m', pretrained=True)\n", - " model.classifier #show the classifier layer (fully connected layer) for EfficientNets\n", - "\n", - " for param in model.parameters():\n", - " param.requires_grad = True\n", - " fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(1536, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - " model.fc = fc\n", - " criterion = nn.CrossEntropyLoss()\n", - " optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - " CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/MixNet_Medium_Covid-19.pth' \n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - " model.to(device)\n", - " model.eval()\n", - " \n", - " return model\n", - "\n", - "\n", - "def get_MixNet_L():\n", - " model = timm.create_model('mixnet_l', pretrained=True)\n", - " model.classifier #show the classifier layer (fully connected layer) for EfficientNets\n", - "\n", - " for param in model.parameters():\n", - " param.requires_grad = True\n", - " fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(1536, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - " model.fc = fc\n", - " criterion = nn.CrossEntropyLoss()\n", - " optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - " CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/MixNet_Large_Covid-19.pth' \n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - " model.to(device)\n", - " model.eval()\n", - " \n", - " return model\n", - "\n", - "\n", - "def get_MixNet_XL():\n", - " model = timm.create_model('mixnet_xl', pretrained=True)\n", - " model.classifier #show the classifier layer (fully connected layer) for EfficientNets\n", - "\n", - " for param in model.parameters():\n", - " param.requires_grad = True\n", - " fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(1536, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - " model.fc = fc\n", - " criterion = nn.CrossEntropyLoss()\n", - " optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - " CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/MixNet_Extra_Large_Covid-19.pth' \n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - " model.to(device)\n", - " model.eval()\n", - " \n", - " return model\n", - "\n", - "\n", - "def get_MixNet_XXL():\n", - " model = timm.create_model('mixnet_xxl', pretrained=True)\n", - " model.classifier #show the classifier layer (fully connected layer) for EfficientNets\n", - "\n", - " for param in model.parameters():\n", - " param.requires_grad = True\n", - " fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(1536, 1000, bias=True)),\n", - "\t\t\t\t\t\t\t ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('dropout1', nn.Dropout(0.7)),\n", - " ('fc2', nn.Linear(1000, 512)),\n", - "\t\t\t\t\t\t\t\t ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t\t ('swish1', Swish()),\n", - "\t\t\t\t\t\t\t\t ('dropout2', nn.Dropout(0.5)),\n", - "\t\t\t\t\t\t\t\t ('fc3', nn.Linear(512, 128)),\n", - "\t\t\t\t\t\t\t\t ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),\n", - "\t\t\t\t\t\t\t ('swish2', Swish()),\n", - "\t\t\t\t\t\t\t\t ('fc4', nn.Linear(128, 3)),\n", - "\t\t\t\t\t\t\t\t ('output', nn.Softmax(dim=1))\n", - "\t\t\t\t\t\t\t ]))\n", - " model.fc = fc\n", - " criterion = nn.CrossEntropyLoss()\n", - " optimizer = optim.SGD(model.parameters(), \n", - " lr=0.01,momentum=0.9,\n", - " nesterov=True,\n", - " weight_decay=0.0001)\n", - " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)\n", - " CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19/weights/MixNet_Extra_Extra_Large_Covid-19.pth' \n", - " checkpoint = torch.load(CHECK_POINT_PATH)\n", - " model.load_state_dict(checkpoint['model_state_dict'])\n", - " best_model_wts = copy.deepcopy(model.state_dict())\n", - " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", - " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", - " best_loss = checkpoint['best_val_loss']\n", - " best_acc = checkpoint['best_val_accuracy']\n", - " model.to(device)\n", - " model.eval()\n", - " \n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "class Ensemble_Model(nn.Module):\n", - " def __init__(self):\n", - " super(Ensemble_Model, self).__init__()\n", - " self.Eff_B0 = get_Eff_B0()\n", - " self.MixNet_S = get_MixNet_S()\n", - " self.MixNet_M = get_MixNet_M()\n", - " self.MixNet_L = get_MixNet_L()\n", - " self.MixNet_XL = get_MixNet_XL()\n", - " self.MixNet_XXL = get_MixNet_XXL()\n", - " def forward(self, x):\n", - " x1 = self.Eff_B0(x)\n", - " x2 = self.MixNet_S(x)\n", - " x3 = self.MixNet_M(x)\n", - " x4 = self.MixNet_L(x)\n", - " x5 = self.MixNet_XL(x)\n", - " x6 = self.MixNet_XXL(x)\n", - "\n", - " x = sum([x1, x2, x3, x4, x5, x6]) / 6\n", - " return x\n", - " \n", - "model = Ensemble_Model()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "prediction time for complete in 40701 milisecond\n", - " precision recall f1-score support\n", - "\n", - " 0 1.000000 0.300000 0.461538 10\n", - " 1 0.960221 0.981921 0.970950 885\n", - " 2 0.963855 0.942761 0.953191 594\n", - "\n", - " accuracy 0.961719 1489\n", - " macro avg 0.974692 0.741561 0.795227 1489\n", - "weighted avg 0.961938 0.961719 0.960444 1489\n", - "\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsEAAAHwCAYAAABHf7LhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3de7Ru53wv8O+viVDBjlJFoiOlibZaLXGNNpI4Ulo0SE91FKnW/VZEx+lxqUspPeIeh1ZVXHoaThQnLkElEYQiFZdSQYQK6paEXEnWc/5452JZWWvvtfZ6517PXvPzMd4x887r8w4bv/31m89TrbUAAMCU/NRmDwAAAHY1RTAAAJOjCAYAYHIUwQAATI4iGACAyVEEAwAwOYpgAAAmRxEMAMDkKIIBAJgcRTAAAJOjCAYAYHIUwQAATI4iGACAydlzswewEXvutW/b7DEAMFObPQCYkx/+4Lwu/jj/8NvnzL3Oudr1b9rFb+uBJBgAgMnZrZNgAIAta+HKzR7BliYJBgBgciTBAAA9agubPYItTREMANCjBUXwmLRDAAAwOZJgAIAONe0Qo5IEAwAwOZJgAIAe6QkelSIYAKBH2iFGpR0CAIDJkQQDAPTIinGjkgQDADA5kmAAgB7pCR6VJBgAgMmRBAMA9MgUaaNSBAMAdMiKcePSDgEAwORIggEAeqQdYlSSYAAAJkcSDADQIz3Bo1IEAwD0yIpxo9IOAQDA5EiCAQB6pB1iVJJgAAAmRxIMANAjU6SNShEMANAj7RCj0g4BAMDkSIIBAHqkHWJUkmAAACZHEgwA0KHWLJYxJkUwAECPvBg3Ku0QAABMjiQYAKBHXowblSQYAIDJkQQDAPRIT/CoJMEAAEyOJBgAoEcLpkgbkyIYAKBH2iFGpR0CAIDJkQQDAPTIFGmjkgQDADA5kmAAgB7pCR6VIhgAoEfaIUalHQIAgMmRBAMA9EgSPCpJMAAAkyMJBgDoUGtWjBuTIhgAoEfaIUalHQIAgMmRBAMA9Mg8waOSBAMAMDmKYACAHi0szP+zQVX1x1XVdvC5csn5++/g3BO286yjq+ojVXVRVV1YVadV1T02/CMG2iEAAFirs5I8Y5Vjv5Xk8CTvXOHYJ5K8ZYX9n17pRlV1bJJjknw1ySuT7JXkfklOqqrHtNaOW+e4r0IRDADQow57gltrZ2VWCF9FVX1o+Me/W+HwWa21p6/lGVV1cGYF8BeT3La1dv6w/3lJzkxybFW9rbV27vpG/5O0QwAA9KjDdojVVNWvJblDkvOSvH2Dt3v4sH32YgGcJEPR+7IkV0/yoA0+QxEMAMCGPXTYvqqtvMrHjavqYVX1pGF7y+3c6/Bhe/IKx9657Jydph0CAKBHI7RDVNWZqz6utYN28p4/neT+Sa5M8vernHbX4bP0utOSHN1a+8qSfXsn2TfJRa21r69wn88P2wN3ZqxLSYIBANiI/55knyQnt9b+c9mxS5L8VZKDklx3+Nw5yalJDk3y3qHwXbRt2F64yrMW9++z0UFLggEAejRCD+/Opr07sNgK8bcrPO+bSf5y2e7Tq+qIJB9IcvskD07y4hHGtV2SYACAHu0GL8ZV1S2SHJzZVGbvWOt1rbUr8uPWiUOWHFpMerdlZYv7L1jHMFekCAYAYGft6IW47fnWsP1RO0Rr7eLMZpi4VlXdaIVrDhi2Z6/zWVehCAYA6FFbmP9njqrqGkkekNkLca/aiVvcYdies2z/KcP2bitcc/dl5+w0RTAAADvj9zN70e2dK7wQlySpqltX1VXqzaq6S5LHD19fv+zwK4btk6vqukuu2T/Jo5JcnuTVGxp5vBgHANCnERe3mJPFVoiVVohb9IIkB1TVGZn1DSfJLfPjeX6f2lo7Y+kFrbUzquoFSZ6Q5JNVdWJmyyb/QZKfSfKYja4WlyiCAQD61OGyyYuq6peT/GZ2/ELc65LcO8ltM2tluFqS/0ryxiTHtdbev9JFrbVjqupTmSW/D02ykOTfkjyvtfa2efwGRTAAAOvSWvtsklrDea/KzvULp7V2fJLjd+batZhrETysGHKHzFbxWJzE+ILM3uD7cGvt0nk+DwBgy+q/HWK3NpcieGhafnZmbwhec5XTLqmq1yZ5Smvt/Hk8FwAAdsaGi+Cq2ifJB5P8UpKLk7wns3Wdl052fECSOyV5RJLDquqOrbXVlsMDAKDjnuCtYB5J8NMyK4BfmORprbWLVjqpqq6V5JlJHpfZ8nnHrOXmVXXmasf2uNqN1z1YAACYxzzBRyY5pbV2zGoFcJK01i5qrT0hyWlJ7jOH5wIAbF27wbLJu7N5JME3SvJP6zj/w5mtMb0mrbWDVju25177tnU8FwBg96FoHdU8kuDvJLn5Os7/5eEaAADYFPMogt+V5MiqeuSOTqyqRye5V5KT5/BcAICtq7X5f/iRebRDPDXJ7yZ5aVUdk+Tdmc0LvHR2iAOTHJFk/yTfzOzFOAAA2BQbLoJba+dV1R2TvDzJXZM8LMnyv2osrijy7iSPbK2dt9HnAgBsaXqCRzWXxTJaa+ck+e2qummSwzLrEd42HL4wyeeSnDqcBwDAjiiCRzXXZZOHIlehCwBA1+ZaBAMAMCdWjBvVPGaHAACA3YokGACgR3qCR6UIBgDokXl9R6UdAgCAyZEEAwD0SDvEqCTBAABMjiQYAKBHkuBRSYIBAJgcSTAAQI8sljEqRTAAQIfaginSxqQdAgCAyZEEAwD0yItxo5IEAwAwOZJgAIAeeTFuVIpgAIAeeTFuVNohAACYHEkwAECPvBg3KkkwAACTIwkGAOiRJHhUimAAgB41L8aNSTsEAACTIwkGAOiRdohRSYIBAJgcSTAAQI8sljEqSTAAAJMjCQYA6FHTEzwmRTAAQI+0Q4xKOwQAAJMjCQYA6FAzRdqoJMEAAEyOJBgAoEd6gkelCAYA6JHZIUalHQIAgMmRBAMA9Eg7xKgkwQAATI4kGACgR6ZIG5UiGACgR9ohRqUdAgCAyZEEAwD0yBRpo5IEAwAwOZJgAIAe6QkelSQYAIDJkQQDAHSomSJtVIpgAIAeaYcYlXYIAAAmRxIMANAjSfCoJMEAAEyOJBgAoEcWyxiVJBgAoEcLbf6fOaqqu1TVm6vqG1V1eVV9rareVVW/s8K5B1fVO6rqu1V1aVV9sqoeV1V7bOf+96iq06rqwqq6qKr+taqOntf4FcEAAKxLVf2vJP+S5DZJ/l+S5yd5e5KfTXLosnN/L8npSQ5J8uYkxyXZK8kLk5ywyv0fneSkJL+a5PVJXpnkxkmOr6pj5/IbWtt9m6733Gvf3XfwAFtMbfYAYE5++IPzuvjj/P3H3XPudc61X3TShn9bVT0kyd8leU2Sh7bWfrDs+NVaaz8c/vk6Sb6QZFuSO7XWPjbsv0aSU5LcMckfttZOWHL9/kn+I8nFSQ5qrZ077L9uko8muVmSg1trH9rI75AEAwCwJlV19STPTvKVrFAAJ8liATw4KrN0+ITFAng457IkTxm+PmLZLf4kydWTHLdYAA/XnJ/kr4evD9/YL/FiHABAn0aYIq2qzlztWGvtoDXc4q6ZFbUvSrJQVb+bWcvCZUk+skI6e/iwPXmFe52e5JIkB1fV1Vtrl6/hmncuO2enKYIBAHrU57LJtx22lyX5eGYF8I9U1elJjmqtfWvYdfNhe/byG7XWrqiqLyW5RZKbJvnsGq75elVdnGS/qrpma+2Snf0himAAgIlYY9q7PTcYtn+e5DNJfivJWUl+IcmxSY5I8n/z45fjtg3bC1e53+L+fZbsW8s1ew/nKYIBALaUPleMW3yf7Iok91rSs/upqrp3ks8luXNV3XGjL66NzYtxAACs1QXD9uNLX1pLkqE14V3D19sN28U0d1tWtrj/giX71nrNaknxmiiCAQB61OdiGZ8bthescvz8YfvTy84/cPmJVbVnZm0UVyQ5Z4VnrHTNjTJrhfjqRvqBE0UwAABr994kLcmvVNVKdeTii3JfGranDNu7rXDuIUmumeSMJTND7Oiauy87Z6cpggEAOtRam/tnDmP6cmYruf18kj9beqyqjkjy25mlxIvTm52Y5NtJ7ldVt1ly7jWSPGv4+vJlj3l1ksuTPHpYOGPxmusmedLw9RUb/S1ejAMA6FGfL8YlyaOS3CrJC4Z5gj+eWVvDkUmuTPLg1tqFSdJa+96wwtyJSU6rqhOSfDfJvTKbCu3EJG9YevPW2peq6s+TvCTJx6rqDUl+kNnCG/slef48XrpTBAMAsGatta9W1UFJ/jKzYvaQJN/LLCF+TmvtI8vOf0tV3TnJk5PcN8k1MltK+QlJXtJWiKhbay+tqnOTPDHJAzPrXvhMkqe01l4zj99R84jGN8uee+27+w4eYIupzR4AzMkPf3BeF3+cv/end517nXOdV72ni9/WA0kwbLJLv/b+zR4CzMW19zt0s4cAsGaKYACADrV+e4K3BEUwAECPFMGjMkUaAACTIwkGAOjRwmYPYGuTBAMAMDmSYACADnkxblyKYACAHimCR6UdAgCAyZEEAwD0yItxo5IEAwAwOZJgAIAOeTFuXJJgAAAmRxIMANAjPcGjUgQDAHRIO8S4tEMAADA5kmAAgB5phxiVJBgAgMmRBAMAdKhJgkelCAYA6JEieFTaIQAAmBxJMABAh7RDjEsSDADA5EiCAQB6JAkelSIYAKBD2iHGpR0CAIDJkQQDAHRIEjwuSTAAAJMjCQYA6JAkeFyKYACAHrXa7BFsadohAACYHEkwAECHtEOMSxIMAMDkSIIBADrUFvQEj0kSDADA5EiCAQA6pCd4XIpgAIAONVOkjUo7BAAAkyMJBgDokHaIcUmCAQCYHEkwAECHTJE2LkUwAECHWtvsEWxt2iEAAJgcSTAAQIe0Q4xLEgwAwORIggEAOiQJHpciGACgQ16MG5d2CAAAJkcSDADQIe0Q45IEAwAwOZJgAIAOtSYJHpMkGACAyZEEAwB0qC1s9gi2NkUwAECHFrRDjEo7BAAAkyMJBgDokBfjxiUJBgBgciTBAAAdsljGuBTBAAAdam2zR7C1aYcAAGByFMEAAB1qCzX3zxiq6v5V1YbPg5cdO3TJsZU+z13lnntU1eOr6pNVdWlVfbeq3lFVB89r3NohAADYKVV1kyTHJbkoybW2c+r7kpy2wv4PrHDPSnJCkqOSfG64/88k+YMkp1fVfVtrb93YyBXBAABd6n2xjKFYfXWS7yT55yRP3M7pp7XWnr7GW98vswL4jCR3aa1dNjzvFZkVza+sqlNaa9/f2bEn2iEAALrUWs39M2ePTXJ4kgcluXiO933EsH3KYgGcJK21jyZ5Q5KfzaxI3hBJMADARFTVmasda60dtI77/HKS5yZ5cWvt9Ko6fAeX/GJVPTrJdZJ8I8n7W2ufX+G+10hycJJLkrx/hfu8M8kDMiu+X73W8a5EEQwA0KFep0irqj2TvC7JV5I8aY2X/dHwWXqfNyV5SGvt/CW7b5ZkjyTntNauWOE+i4Xzgesa9AoUwQAAE7GetHc7/jLJrZL8Zmvt0h2c+60kf5Hk7UnOTXKNJLdJ8tdJ7pvkhlV1SGttYTh/27C9cJX7Le7fZ+eG/mOKYACADvX4YlxV3T6z9Pf5rbUP7ej81tq/J/n3JbsuSnJyVZ2R5Kwkd0pyzyQbnu1hvbwYBwDADg1tEK9NcnaSp27kXq217yX5P8PXQ5YcWkx6t2Vli/sv2MjzE0kwAECXRpjNYaOulR/34l42myHtKl5ZVa/M7IW5x+3gft8atnsv2ffFJFcmuWlV7blCX/ABw/bstQ97ZYpgAIAOdfhi3OVJXrXKsVtn1if8gcwWuNhhq0SSOwzbcxZ3tNYuG1olfmv4nLrsmrsP21PWOOZVbUoRXFXPS3Kf1trNNuP5AACsz/AS3INXOlZVT8+sCH5Na+3vl+y/TWvtYyucf//MVoD7QZI3Ljv88swK4GdV1dLFMm47XPOtJG/a6O/ZrCT4+kn2X8uJ25vPbo+r3Xhe4wEA6EqPL8bthBOr6ookH0vy1cxmh7htktsluSLJw1pr5y675oQk98lsQYyPV9VJSa6XWQG8R2bTqn1vowPTDgEAwFhenuS/ZTYLxPWTVJLzkhyf5EWttU8sv6C11qrqDzNbNvlPkjwmyWVJTk/yrNbaGfMYWLU5NJxU1WvXecnBSX6htbbHRp6751779tctA+t06ddWWhAHdj/X3u/QzR4CzMVll32liwj2o/vee+51zm3Pe3MXv60H80qC75+kZVbdr5UCFgBgFVukHaJb8yqCv59Zn8cj13j+XyQ5Yk7PBgCAdZlXEfyJJL/eWnvfWk6uqj+e03MBALYk/5f5uOa1YtxZSa5VVaY8AwCge/NKgt+X2Xxu+2W20seOvCXJuXN6NgDAlqMneFxzKYJba2/KOiYtbq29Nclb5/FsAICtqMNlk7eUebVDAADAbsNiGQAAHVrY7AFscZJgAAAmRxIMANChtq41yFgvSTAAAJMjCQYA6NCC1TJGpQgGAOjQgnaIUWmHAABgciTBAAAd8mLcuCTBAABMjiQYAKBDFssYlyIYAKBD2iHGpR0CAIDJkQQDAHRIO8S4JMEAAEyOJBgAoEOS4HEpggEAOuTFuHFphwAAYHIkwQAAHVoQBI9KEgwAwORIggEAOrSgJ3hUkmAAACZHEgwA0KG22QPY4hTBAAAdMk/wuLRDAAAwOZJgAIAOLZQX48YkCQYAYHIkwQAAHfJi3LgUwQAAHfJi3Li0QwAAMDmSYACADi14L25UkmAAACZHEgwA0KGFiILHpAgGAOiQ2SHGpR0CAIDJkQQDAHTIi3HjkgQDADA5kmAAgA5ZLGNckmAAACZHEgwA0CGzQ4xLEQwA0CEvxo1LOwQAAJMjCQYA6JAX48YlCQYAYHIkwQAAHZIEj0sRDADQoebFuFFphwAAYHIkwQAAHdIOMS5JMAAAkyMJBgDokCR4XIpgAIAOWTZ5XNohAACYHEkwAECHFkyRNipJMAAAk6MIBgDo0MIIn3moqr+pqvdW1X9W1aVV9d2q+nhVPa2qrrfKNQdX1TuGcy+tqk9W1eOqao/tPOceVXVaVV1YVRdV1b9W1dFz+hmKYAAA1uXxSfZO8p4kL07yj0muSPL0JJ+sqpssPbmqfi/J6UkOSfLmJMcl2SvJC5OcsNIDqurRSU5K8qtJXp/klUlunOT4qjp2Hj9CTzAAQIc6niLtOq21y5bvrKpnJ3lSkv+Z5JHDvutkVsBemeTQ1trHhv1PTXJKkqOq6n6ttROW3Gf/JMcm+W6S27TWzh32PzPJR5McU1Vvaq19aCM/QhIMANChNsJnLuNaoQAevHHYHrBk31FJfjbJCYsF8JJ7PGX4+ohl9/mTJFdPctxiATxcc36Svx6+PnynBr+EJBgAYCKq6szVjrXWDtrg7e85bD+5ZN/hw/bkFc4/PcklSQ6uqqu31i5fwzXvXHbOTlMEAwB0qPcp0qrqiUmulWRbktsk+c3MCuDnLjnt5sP27OXXt9auqKovJblFkpsm+ewarvl6VV2cZL+qumZr7ZKdHb8iGABgIuaQ9i71xCQ/t+T7yUn+uLX2rSX7tg3bC1e5x+L+fdZ5zd7DeTtdBOsJBgDoUK9TpC1qrd2wtVZJbpjkPpmluR+vqlvP+VGjUAQDAHSo1xfjrjLO1v6rtfbmJEckuV6S1y45vJjmbrvKhT+5/4KduGa1pHhNFMEAAGxYa+3LST6T5BZVdf1h9+eG7YHLz6+qPZP8QmZzDJ+z5ND2rrlRZq0QX91IP3CiJxg23Q32P2KzhwBz8fV73nSzhwBbysJo2e2objxsrxy2pyT5oyR3S/JPy849JMk1k5y+ZGaIxWvuNFyzfC7guy85Z0MkwQAArElVHVhVV2lTqKqfGhbLuEGSM4Y5fZPkxCTfTnK/qrrNkvOvkeRZw9eXL7vdq5NcnuTRw8IZi9dcN7PFOJLkFRv9LZJgAIAOdbpi3O8keU5VfSDJl5J8J7MZIu6c2Ytx30jykMWTW2vfq6qHZFYMn1ZVJ2S2Ety9MpsK7cQkb1j6gNbal6rqz5O8JMnHquoNSX6Q2cIb+yV5/kZXi0sUwQAAXeq0GeJfkvxiZnMC3yqzqc0uzmxO39cleUlr7btLL2itvaWq7pzkyUnum+QaSb6Q5AnD+Vf5qa21l1bVuZlNw/bAzLoXPpPkKa2118zjhyiCAQBYk9bap5M8eieu+2BmKfJ6rjkpyUnrfdZaKYIBADrUaTvEluHFOAAAJkcSDADQoYXa7BFsbYpgAIAO7abzBO82tEMAADA5kmAAgA7JgcclCQYAYHIkwQAAHTJF2rgkwQAATI4kGACgQ2aHGJciGACgQ0rgcWmHAABgciTBAAAd8mLcuCTBAABMjiQYAKBDXowblyIYAKBDSuBxaYcAAGByJMEAAB3yYty4JMEAAEyOJBgAoENNV/CoFMEAAB3SDjEu7RAAAEyOJBgAoEPmCR6XJBgAgMmRBAMAdEgOPC5JMAAAkyMJBgDokJ7gcSmCAQA6ZIq0cWmHAABgciTBAAAdsmLcuCTBAABMjiQYAKBDeoLHpQgGAOiQdohxaYcAAGByJMEAAB3SDjEuSTAAAJMjCQYA6NBC0xM8JkUwAECHlMDj0g4BAMDkSIIBADq0IAselSQYAIDJkQQDAHTIYhnjkgQDADA5kmAAgA5ZLGNcimAAgA55MW5c2iEAAJgcSTAAQIe8GDcuSTAAAJMjCQYA6JAX48alCAYA6FBr2iHGpB0CAIDJkQQDAHTIFGnjkgQDADA5kmAAgA55MW5cimAAgA6ZJ3hc2iEAAJgcSTAAQIe8GDcuSTAAAJMjCQYA6JDFMsYlCQYAYHIUwQAAHVoY4bNRVXVUVb20qt5fVd+rqlZVr1/l3P2H46t9TtjOc46uqo9U1UVVdWFVnVZV95jDT/gR7RAAAB3qdIq0pyT59SQXJflqkl9awzWfSPKWFfZ/eqWTq+rYJMcM939lkr2S3C/JSVX1mNbacTsx7qtQBAMAsFaPz6w4/UKSOyc5dQ3XnNVae/pabl5VB2dWAH8xyW1ba+cP+5+X5Mwkx1bV21pr565/6D9JOwQAQIcW0ub+2ajW2qmttc+38d7ae/iwffZiATw899wkL0ty9SQPmseDFMEAAIzpxlX1sKp60rC95XbOPXzYnrzCsXcuO2dD5tYOUVX3TnJokiuSnNxae88q5x2d5OjW2lx+AADAVjRG2FpVZ27neQfN/YEzdx0+S8dxWmb14FeW7Ns7yb5JLmqtfX2F+3x+2B44j0FtuAiuqkryhiT3TVLD7sdV1duTPLC1dsGyS/bPrIcEAIBVbIEV4y5J8leZvRR3zrDvlkmenuSwJO+tqt9orV08HNs2bC9c5X6L+/eZx+DmkQQ/KMlRSf4zySuS/DDJ0UnukeQDVXV4a+2bO3vz7f2NZY+r3XhnbwsAMDkjpr0rPeubSf5y2e7Tq+qIJB9IcvskD07y4l01pqXm0RP8oCQXZPYG33Naa8cm+Y0kL0jyK0n+paquP4fnAABMRhvhXz1orV2R5O+Hr4csObSY9G7Lyhb3L+8y2CnzSIJ/LcmJS9Pe1tqVSZ5YVV9J8qLMCuHDlr7lt1bb+xvLnnvt28e/mwAArMe3hu3eiztaaxdX1XlJ9q2qG63QF3zAsD17HgOYRxK8V5L/WulAa+0lSR6bWf/He6pqLj0cAABb3UJrc/905A7D9pxl+08Ztndb4Zq7LztnQ+ZRBJ+X5OdXOzis6vGEJLdO8q6sHnEDADBoI3x2paq6dVVdpdasqrtktuhGkixfcvkVw/bJVXXdJdfsn+RRSS5P8up5jG8e7RCfyuwNv1W11l5UVVdP8pwkt5rDMwEA2MWq6sgkRw5fbzhs71hVxw///O3W2hOHf35BkgOq6ozMVplLZt0Bi9PkPrW1dsbS+7fWzqiqF2QWoH6yqk7MrOvgD5L8TJLHzGO1uGQ+RfA7khxZVb/bWnv7aie11v6mqvZK8ozs+r+MAADsVjqdIu03MpsFbKmbDp8k+XKSxSL4dUnuneS2mbUyXC2zFto3Jjmutfb+lR7QWjumqj6VWfL70CQLSf4tyfNaa2+b1w+ZRxH8z0n2SHLxjk5srf3V8LLc/nN4LgAAu1Br7emZzfO7lnNfleRVO/mc45McvzPXrtWGi+DW2neT/O06zn/NRp8JALDVdZoEbxnzeDEOAAB2K/NohwAAYM5aX1OabTmKYACADmmHGJd2CAAAJkcSDADQoSYJHpUkGACAyZEEAwB0yItx41IEAwB0yItx49IOAQDA5EiCAQA6pB1iXJJgAAAmRxIMANAhPcHjUgQDAHTIPMHj0g4BAMDkSIIBADq04MW4UUmCAQCYHEkwAECH9ASPSxIMAMDkSIIBADqkJ3hcimAAgA5phxiXdggAACZHEgwA0CHtEOOSBAMAMDmSYACADukJHpciGACgQ9ohxqUdAgCAyZEEAwB0SDvEuCTBAABMjiQYAKBDrS1s9hC2NEUwAECHFrRDjEo7BAAAkyMJBgDoUDNF2qgkwQAATI4kGACgQ3qCxyUJBgBgciTBAAAd0hM8LkUwAECHFhTBo9IOAQDA5EiCAQA61LwYNypJMAAAkyMJBgDokBfjxqUIBgDokHmCx6UdAgCAyZEEAwB0SDvEuCTBAABMjiQYAKBDFssYlyIYAKBD2iHGpR0CAIDJkQQDAHTIFGnjkgQDADA5kmAAgA7pCR6XJBgAgMmRBAMAdMgUaeNSBAMAdKh5MW5U2iEAAJgcSTAAQIe0Q4xLEgwAwORIggEAOmSKtHEpggEAOuTFuHFphwAAYF2qar+q+oeq+lpVXV5V51bVi6rqups9trWSBAMAdKjXdoiqulmSM5LcIMlbk/xHktsl+bMkd6uqO7XWvrOJQ1wTSTAAAOvxvzMrgB/bWjuytfYXrbXDk7wwyc2TPHtTR7dGimAAgA611ub+2aghBT4iyblJXrbs8NOSXJzkAVW194YfNjJFMABAh9oInzk4bNi+u7W28BPjbe37ST6Y5JpJ7jCfx41HTzAAwERU1ZmrHWutHbSGW9x82J69yvHPZ5YUH0yiBV0AAARdSURBVJjkvesb3a61WxfBV/zgvNrsMWx1i/9hWeN/MKBL/hyzVfizPC1j1DnbK4LXaNuwvXCV44v799ngc0a3WxfBAACsnb9A/ZieYAAA1mox6d22yvHF/RfsgrFsiCIYAIC1+tywPXCV4wcM29V6hruhCAYAYK1OHbZHVNVP1JFVde0kd0pySZIP7+qBrZciGACANWmtfTHJu5Psn+RRyw4/I8neSV7XWrt4Fw9t3bwYBwDAejwys2WTX1JVd0ny2SS3z2wO4bOTPHkTx7Zm1eu61AAA9KmqbpLkmUnuluR6Sb6e5M1JntFaO38zx7ZWimAAACZHTzAAAJOjCAYAYHIUwQAATI4iGACAyVEEAwAwOYpgAAAmRxHMiqpqv6r6h6r6WlVdXlXnVtWLquq6mz02WIuqOqqqXlpV76+q71VVq6rXb/a4YD2q6npV9eCqenNVfaGqLq2qC6vqA1X1p8uXrQXWzjzBXEVV3SyzlWBukOStSf4jye0yWwnmc0nu1Fr7zuaNEHasqs5K8utJLkry1SS/lOQfW2v339SBwTpU1cOTvDyzhQhOTfKVJD+X5D5JtiV5U5Lfb/7HHNZNEcxVVNW7khyR5LGttZcu2f+CJI9P8rettYdv1vhgLarqsMyK3y8kuXNmBYQimN1KVR2eZO8kb2+tLSzZf8MkH0lykyRHtdbetElDhN2W/xuFnzCkwEckOTfJy5YdflqSi5M8oKr23sVDg3VprZ3aWvu8hIzdWWvtlNbaSUsL4GH/N5K8Yvh66C4fGGwBimCWO2zYvnuF/9L9fpIPJrlmkjvs6oEB8BN+OGyv2NRRwG5KEcxyNx+2Z69y/PPD9sBdMBYAVlBVeyZ54PD15M0cC+yuFMEst23YXrjK8cX9++yCsQCwsucm+dUk72itvWuzBwO7I0UwAOxGquqxSY7JbOaeB2zycGC3pQhmucWkd9sqxxf3X7ALxgLAElX16CQvTvKZJIe11r67yUOC3ZYimOU+N2xX6/k9YNiu1jMMwAiq6nFJXprk05kVwN/Y5CHBbk0RzHKnDtsjlq9EVFXXTnKnJJck+fCuHhjAVFXV/0jywiRnZVYAf3OThwS7PUUwP6G19sUk706yf5JHLTv8jMwmbX9da+3iXTw0gEmqqqdm9iLcmUnu0lr79iYPCbYEK8ZxFSssm/zZJLfPbA7hs5McbNlkeldVRyY5cvh6wyS/neScJO8f9n27tfbEzRgbrFVVHZ3k+CRXZtYKsdLMPee21o7fhcOCLUERzIqq6iZJnpnkbkmul9m69W9O8ozW2vmbOTZYi6p6emarHK7my621/XfNaGDnrOHPcZK8r7V26Pijga1FEQwAwOToCQYAYHIUwQAATI4iGACAyVEEAwAwOYpgAAAmRxEMAMDkKIIBAJgcRTAAAJOjCAYAYHIUwQAATI4iGACAyVEEAwAwOYpgAAAmRxEMAMDkKIIBAJgcRTAAAJOjCAYAYHL+PwGdvk2HvgHeAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "image/png": { - "height": 248, - "width": 352 - }, - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "since = time.time()\n", - "model.eval()\n", - "y_test = []\n", - "y_pred = []\n", - "for images, labels in data_loader['test']:\n", - " images = Variable(images.cuda())\n", - " labels = Variable(labels.cuda())\n", - " outputs = model(images)\n", - " _, predictions = outputs.max(1)\n", - " \n", - " y_test.append(labels.data.cpu().numpy())\n", - " y_pred.append(predictions.data.cpu().numpy())\n", - "\n", - "\n", - "time_elapsed = time.time() - since\n", - "\n", - "print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n", - "#concat predicted results to be dataframe\n", - "y_test = np.concatenate(y_test)\n", - "y_pred = np.concatenate(y_pred)\n", - "\n", - "pd.DataFrame({'true_label':y_test,'predicted_label':y_pred}).to_csv('/home/linh/Downloads/Covid-19/results/Ensemble_Eff_B0_MixNet_S_MixNet_M_MixNet_L_XL_XXL_Testset.csv',index=False)\n", - "\n", - "sns.heatmap(confusion_matrix(y_test, y_pred))\n", - "# set accuracy score to control processes\n", - "accuracy_score(y_test, y_pred)\n", - "\n", - "# Generate a classification report\n", - "from sklearn.metrics import classification_report\n", - "\n", - "report = classification_report(y_test, y_pred, digits=4)\n", - "print(report)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "Metrics_for_EfficientNets.ipynb", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}