diff --git a/.ipynb_checkpoints/deep_compression_exercise-checkpoint.ipynb b/.ipynb_checkpoints/deep_compression_exercise-checkpoint.ipynb new file mode 100644 index 0000000..7b8e422 --- /dev/null +++ b/.ipynb_checkpoints/deep_compression_exercise-checkpoint.ipynb @@ -0,0 +1,599 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exercise Week 9: Pruning and Quantization\n", + "This week, we will explore some of the ideas discussed in Han, Mao, and Dally's Deep Compression. In particular, we will implement weight pruning with fine tuning, as well as k-means weight quantization. **Note that we will unfortunately not be doing this in a way that will actually lead to substantial efficiency gains: that would involve the use of sparse matrices which are not currently well-supported in pytorch.** \n", + "\n", + "## Training an MNIST classifier\n", + "For this example, we'll work with a basic multilayer perceptron with a single hidden layer. We will train it on the MNIST dataset so that it can classify handwritten digits. As usual we load the data:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision.transforms as transforms\n", + "import torchvision.datasets as datasets\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n", + "test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n", + "\n", + "batch_size = 300\n", + "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", + "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then define a model:" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "class MultilayerPerceptron(torch.nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim,mask=None):\n", + " super(MultilayerPerceptron, self).__init__()\n", + " if not mask:\n", + " self.mask = torch.nn.Parameter(torch.ones(input_dim,hidden_dim),requires_grad=False)\n", + " else:\n", + " self.mask = torch.nn.Parameter(mask)\n", + "\n", + " self.W_0 = torch.nn.Parameter(1e-3*torch.randn(input_dim,hidden_dim)*self.mask,requires_grad=True)\n", + " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim),requires_grad=True)\n", + "\n", + " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim),requires_grad=True)\n", + " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim),requires_grad=True)\n", + " \n", + " def set_mask(self,mask):\n", + " mask = torch.nn.Parameter(mask)\n", + " self.mask.data = mask.data\n", + " self.W_0.data = self.mask.data*self.W_0.data\n", + "\n", + " def forward(self, x):\n", + " hidden = torch.tanh(x@(self.W_0*self.mask) + self.b_0)\n", + " outputs = hidden@self.W_1 + self.b_1\n", + " return outputs\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the above code is a little bit different than a standard multilayer perceptron implementation.\n", + "\n", + "### Q1: What does this model have the capability of doing that a \"Vanilla\" MLP does not. Why might we want this functionality for studying pruning?\n", + "\n", + "Let's first train this model without utilizing this extra functionality. You can set the hidden layer size to whatever you'd like when instantiating the model:\n", + "\n", + "\n", + "\n", + "**Answer: This MLP accomodates a \"mask\" object, which implements pruning by creating 0 entries in a respective linear layer's weight matrix**" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "n_epochs = 10\n", + "\n", + "input_dim = 784\n", + "hidden_dim = 64\n", + "output_dim = 10\n", + "\n", + "model = MultilayerPerceptron(input_dim,hidden_dim,output_dim)\n", + "model = model.to(device)\n", + "\n", + "criterion = torch.nn.CrossEntropyLoss() # computes softmax and then the cross entropy\n", + "lr_rate = 0.001\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And then training proceeds as normal." + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.4186490774154663. Accuracy: 90.\n" + ] + } + ], + "source": [ + "iter = 0\n", + "for epoch in range(1):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(epoch, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_pretrained.h5')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pruning\n", + "\n", + "Certainly not a state of the art model, but also not a terrible one. Because we're hoping to do some weight pruning, let's inspect some of the weights directly (recall that we can act like they're images)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAVtklEQVR4nO3dW4zc5XkG8Oed0x5mDz6tvWtjTi4hcUEQtEVRSCOqKBHhAshFqnARUSmqc5FIiZSLInIRbiqhqkmaiyqSU1AISokiJVG4QG0QQiG5aMSaOtjUNHaogbUXr4979Oyc3l7sgBbY7/mWOeyM/T0/abW78843881/5p3/7rzfwdwdInL1y3S7AyKyOZTsIolQsoskQskukgglu0gicpt5Z9mhoue2b93MuxRJSvX8RdQWl2y9WEvJbmb3APgBgCyAf3P3x9j1c9u3Yvzb32jlLkWEePsffxCMNf1nvJllAfwrgM8D2A/gQTPb3+ztiUhntfI/+50ATrj76+5eBvAzAPe3p1si0m6tJPseAG+t+X26cdl7mNkBM5sys6na4lILdycirWgl2df7EOADY2/d/aC7T7r7ZHao2MLdiUgrWkn2aQB71/x+DYDTrXVHRDqllWR/CcBNZnaDmRUAfAnAM+3ploi0W9OlN3evmtnXAfwnVktvT7j7q23r2dUkV+fxdauibRK77UrkCrH2sUmTrH29kw8cQCbcucxAlTatVyPnwXIk7h1+bE1oqc7u7s8CeLZNfRGRDtJwWZFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSsanz2a9a2UixORI3Ug8GgIGhFRrPZcN1/EKuRtvWI133SL24XOUvoWwm3LfLpTxtW69lI/FILZs8tvoKv23L87ERHnvOWx2/0AE6s4skQskukgglu0gilOwiiVCyiyRCyS6SCJXeNspIqSVShsn28fJXocCnW948Nkvj2/uWg7FdffO07bYcXyqsL1Oh8YtVvvpQqR4urw1nS7RtzHK9QOPHFsaDsRMXdtC2c3ODNB7dDjVWWmPl1mhZr7lztM7sIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SiHTq7K0seRyJZ/t5nTyf53X2j4ydpfGx/kUaPzE/FowdPvuBHbnew9j4AQDXjlyk8e19vE6/u28uGIvVyQ9dvJbGXz1+DY0PnAzX+Isz/HEPD/EXxPI4b1/dXabxXH94/EJloY+2bZbO7CKJULKLJELJLpIIJbtIIpTsIolQsoskQskukoh06uyxlX1jS/+SG6hl+WEcHQ7PNweAcp0vaxxz7VC4Fr7Uz2vZo3k+p/wjxbdp/Oa+GRr/aCE8hiAfGfwwFJnv/ualLTReObU1GPPIIc8v8L4VI8tYLxT4Mtm1XWSp6tgS2U1qKdnN7CSABQA1AFV3n2xHp0Sk/dpxZv8bdz/XhtsRkQ7S/+wiiWg12R3Ab8zskJkdWO8KZnbAzKbMbKq2yMdRi0jntPpn/F3uftrMdgJ4zsxec/cX117B3Q8COAgAfdddE12nT0Q6o6Uzu7ufbnyfBfArAHe2o1Mi0n5NJ7uZFc1s+J2fAXwOwNF2dUxE2quVP+N3AfiVmb1zO//u7v/Rll51QKyOnlvg73ts2nesRF+JbD1cyPD57guVfhrPkM6Va/wpnqmO0PhI7jKN30jq6ABwZGUiGHu9HJ6HDwAvnL2ZxkuRLZ+rY+HjWim29nFVrcjXMOgf48etkA+3n1/ij6tZTSe7u78O4LY29kVEOkilN5FEKNlFEqFkF0mEkl0kEUp2kUT01hTXyLLGqIdrXFbl9a/8HH9fK8zz9jWyum8lsuxwbMLiUpVPQ42pevixXSoN0LbnF/iWy6fneWnud9l9/PYvDAVj2WleUjRekYy+en0LuYHdfPrsSGRa8m07T9P4tQMXaPzV+XBJ8tCZ8DEDAHhzU2B1ZhdJhJJdJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUT0Vp098tbDpqlml3jjvou8NlmPHInKEBkDMBrefhcAyhV+45noftLcYjk8CODC4iBtuzLL47UFXvMtXOLHdcuF8GMbPEeWUwZQjkxDvcRnwMIGwtNId++4RNv+1Y43aPyTwydofCw3T+MVsnz44aG9tG1tobkpsDqziyRCyS6SCCW7SCKU7CKJULKLJELJLpIIJbtIInqrzh7ZqtbK4XiuxNtW+dRp1AZ5rbu6LVyzHRrlywaPDPC500sVPp89Fp8vhevsK5d5TTa2hPbIn2kYo6+v0HjhXHheuBf4y8/28Ln4sX2X+4vlYOxjW/lW1PsH+Xz12/t4PHYWHcqGj1s2xyfy18vk9UDmuuvMLpIIJbtIIpTsIolQsoskQskukgglu0gilOwiieipOrvF6uwk7hleJ3ey7jsA1CNThHPF8Jz1fJbXRS+X+Y3HtnSu1SJz9cn2v8VhXuMvOR+AsPU13j5/7E0ah5G+T2ynTctFvqZ9Ky6W+Tz+5Tp/wWzLtHaefKO0LRirVvjrASxMUijaYzN7wsxmzezomsu2mdlzZna88X1r7HZEpLs28vb0YwD3vO+yhwE87+43AXi+8buI9LBosrv7iwDev5fN/QCebPz8JIAH2twvEWmzZv/x2OXuMwDQ+L4zdEUzO2BmU2Y2VVtcavLuRKRVHf803t0Puvuku09mhzr3gYuIcM0m+xkzmwCAxvfZ9nVJRDqh2WR/BsBDjZ8fAvDr9nRHRDolWmc3s6cB3A1gh5lNA/gOgMcA/NzMvgLgTQBfbEdn2LrwMZkV3jYy9RmZMq/TlxfDtfJLZX7jTvaVB4BMntfpcwUev2Hn+WDs1lE+7/qpU5+i8cJx3r56LnzfAJAb3xWMlbdH1qwvRMZdRJbbv3wufPvHsuF+AcB4P1/3/eX+aRpfitTp31oKV6u9ys/BnmUPPByLJru7PxgIfSbWVkR6h4bLiiRCyS6SCCW7SCKU7CKJULKLJKKnprh6PlJLIWHP8zJNnldSUBni7ftPh0tvxitjUZUif9zF/e+fmvBe9+38YzC2v+8UbfsUeOmtvhReChoAsltGabyybyIYW9jLy1Oxacd9FyLLhw+Gz2X1Md722Nw4jQ9kb6XxOlnSGQBOzZHjFqspNklndpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSoWQXSURv1dnp1D0AZKfa6gBvO3AmctNzPO7kSFl4JedVkbfU5XFek906yLeEvrEQXjvk+lx422IAGJjm03MzoyM0Xt/B6+znbwlvu1zPRsY2XKrTeH6JP+eVYvjAL4/y6bVz/fy4Vev8Sb1c59ts0543P9Ob0pldJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUS0VN19hi6LXOkRF+PzHcfOM9rulYP34HxpqgMRJa5Nh7/i5GzND5ZCNeE/1ThT/HA2cgS2jcEd/YCAJTGeD25tC382Ebe4Adu6K0VGq8X+LlqeReZLx95zhaW+VbWxxf5cSlk+OCLUik8Wd8jW5c3S2d2kUQo2UUSoWQXSYSSXSQRSnaRRCjZRRKhZBdJxBVVZze29XHkbavGy6YoR9aNz5Pl07MrvFYdq/GXdvOa7FhhkcZ/XyoGY//y1mdp22pkDEBlmL9Eqv28fY4ct8EzFd72Ip/Hv7SPz7Vf2U6el8geBZkML8SvVPlxGYzMh+/vDz/2apnfdrOrykfP7Gb2hJnNmtnRNZc9amanzOxw4+veJu9fRDbJRv6M/zGAe9a5/Pvufnvj69n2dktE2i2a7O7+IgC+/5CI9LxWPqD7upm90vgzf2voSmZ2wMymzGyqtrjUwt2JSCuaTfYfAtgH4HYAMwC+G7qiux9090l3n8wOhT9IEpHOairZ3f2Mu9fcvQ7gRwDubG+3RKTdmkp2M1u7D+8XABwNXVdEekO0zm5mTwO4G8AOM5sG8B0Ad5vZ7Vgt+Z0E8NUO9vFdbD57dYRvkl5yvj56rY/Xi43MMc7xcjAqkf9exq87T+P9GV6Pnlq+MRg78fYYbTtSilRtY+sE5CLHjbRfGucbsM/dGPwoCABw4RbeucJEeHzCnhH++dF4cZ7G9w2do/HY/uwXh8Pr1v8fmesOAJXl5obHRFu5+4PrXPx4U/cmIl2j4bIiiVCyiyRCyS6SCCW7SCKU7CKJuKKmuNIyUB+fkljbxUtzmet5eWtslE8zZe7Y8RaPD71B4x8tzND4koeXc/7R0l/TtrnlyFTPWmT6buQVtLQ73H7xEyXadt8uXt66rXiJxgey4ed0bz+f7jGRv0jjWVZTBHC2OkzjR2x3MFZZ5stzI0de683PAheRq4WSXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEXFl1dsKyke1/R/k81E/veZ3G79v6cjBWzPCthYvGa/ijkSmsu3Nk62EAy/WFYKwwzJc0Lo/w6ZSxl8j8Dbz1njvCYwTu2/0Kbbstx8c2ZCL7Llc83Peb+t6mbW+K3Dd/xoD/Ku2h8fPLZN5zuYVzMCn/68wukgglu0gilOwiiVCyiyRCyS6SCCW7SCKU7CKJuLLq7GSurkeW1/URvrTvWCFcqwaArIVrurE6+nCkjr4lw99z+4zXwmtkbvWtu0/TtkfHP0Lj5dHIls57+RiDT46Fxy/E6uhbsmS/ZwDlyPLg2zPh2y8aH3/AVz8ATlaHaPzFuY/S+PIKeU77Ive+wh93iM7sIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SiKumzh7dWrjO68XTJb498MvZ64OxsRyv0cfmXcfqyfH58uG663g/79t/38jn+a8s8Rr/8Bbe9+nSlmCsFjnX/OXANI2P5+ZovOThvp+s7KBtT1WrNH5iZZzGj8/zrbLLrM4eea02K3pmN7O9ZvaCmR0zs1fN7BuNy7eZ2XNmdrzxnWeLiHTVRv6MrwL4lrt/DMAnAHzNzPYDeBjA8+5+E4DnG7+LSI+KJru7z7j7y42fFwAcA7AHwP0Anmxc7UkAD3SqkyLSug/1AZ2ZXQ/g4wD+AGCXu88Aq28IAHYG2hwwsykzm6otLrXWWxFp2oaT3cyGAPwCwDfdfX6j7dz9oLtPuvtkdogssiciHbWhZDezPFYT/afu/svGxWfMbKIRnwAw25kuikg7REtvZmYAHgdwzN2/tyb0DICHADzW+P7rjvRwg6zGyxXL5wZp/LeVfTR+ZGQiGNtZ5FM168771k+2FgaAwRyfjnnHyJvB2MeKfIpr335+34fOX0vjs/N8qudrF3YFY6VRXtabKPAtmVsxV2vtr8ypueto/OS5bTReLZHUiy0l3WRlbiN19rsAfBnAETM73LjsEawm+c/N7CsA3gTwxea6ICKbIZrs7v57hN9LPtPe7ohIp2i4rEgilOwiiVCyiyRCyS6SCCW7SCKurCmuhFV58TG7xB9qbprXfM+T0X+zA9tp29hbqg/wpYO37OB1/In+8IDGz4/wbZHHcnww5OGL19D45QsDPO7h+MU5XuuOjT8YK/DjMpQLTw1+e2WEtp25PErj03M8XqvxJ51uMV7hS0V7ITKfO0BndpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSoWQXScRVU2ev9/HlmvPzvHZZnOa1S7IrMmLvmeXIdtEr23jfLlV4TfjIyO5grD+yXXSf8SWTK7XI9sCRdQQK58PtM6f52IbfLd5M49ki7ztTj4zLQGQNAmT468Ujx4XKN1dHj9GZXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEnHV1Nlj68Z75G0tf5nXNgfOhuvVlWFei64MRQ4zHyIAW+Gdn5kP1+EPGV/3fVdkS+edgzy+dA2vlV8shteVz8wWaNv8BX7cMjP8vhnP8ue7soU/KZ6L1MIjt0/HJ3Rmx2ad2UVSoWQXSYSSXSQRSnaRRCjZRRKhZBdJhJJdJBEb2Z99L4CfABjHakX4oLv/wMweBfD3AM42rvqIuz/bqY5GRcqa9cgjXdjL3/eq/eGacP8lvu77yEk+77o4wwurpa28b7Xj4b3ATznfJ/yNyDbltUgpu85L5ehjz0ts2nYk7pGp9nUyL7ze5Nrr78pHBkfEtDLfvUkbGVRTBfAtd3/ZzIYBHDKz5xqx77v7P3eueyLSLhvZn30GwEzj5wUzOwZgT6c7JiLt9aH+Zzez6wF8HMAfGhd93cxeMbMnzGxroM0BM5sys6na4lJLnRWR5m042c1sCMAvAHzT3ecB/BDAPgC3Y/XM/9312rn7QXefdPfJLNkvTUQ6a0PJbmZ5rCb6T939lwDg7mfcvebudQA/AnBn57opIq2KJruZGYDHARxz9++tuXxizdW+AOBo+7snIu2ykU/j7wLwZQBHzOxw47JHADxoZrdjtUByEsBXO9LDDfLI8ru1IV4qWRrkt18aC78vZi9HpmKW+W1nwzsLr7bnq0HTKZHZFX5c+s/SMLIV3r7ax0tIleFwvMp3e0Z1MPKcDvA4La/FTnOxyljlyhuispFP43+P9R9692rqIvKhXXlvTyLSFCW7SCKU7CKJULKLJELJLpIIJbtIIq6apaRjokv/RlQL4WmsVb6jclyka1aPFH3ZDNvI1sN8K2ogsqNzdElmz5LxDbFadmxX5dhyzSy8+TNMu05ndpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSoWQXSYS5t7ik7oe5M7OzAN5Yc9EOAOc2rQMfTq/2rVf7BahvzWpn365z97H1Apua7B+4c7Mpd5/sWgeIXu1br/YLUN+atVl905/xIolQsoskotvJfrDL98/0at96tV+A+tasTelbV/9nF5HN0+0zu4hsEiW7SCK6kuxmdo+Z/a+ZnTCzh7vRhxAzO2lmR8zssJlNdbkvT5jZrJkdXXPZNjN7zsyON76vu8del/r2qJmdahy7w2Z2b5f6ttfMXjCzY2b2qpl9o3F5V48d6demHLdN/5/dzLIA/gTgswCmAbwE4EF3/59N7UiAmZ0EMOnuXR+AYWafBrAI4Cfufkvjsn8CcMHdH2u8UW5193/okb49CmCx29t4N3Yrmli7zTiABwD8Hbp47Ei//habcNy6cWa/E8AJd3/d3csAfgbg/i70o+e5+4sALrzv4vsBPNn4+Umsvlg2XaBvPcHdZ9z95cbPCwDe2Wa8q8eO9GtTdCPZ9wB4a83v0+it/d4dwG/M7JCZHeh2Z9axy91ngNUXD4CdXe7P+0W38d5M79tmvGeOXTPbn7eqG8m+3upfvVT/u8vd7wDweQBfa/y5KhuzoW28N8s624z3hGa3P29VN5J9GsDeNb9fA+B0F/qxLnc/3fg+C+BX6L2tqM+8s4Nu4/tsl/vzrl7axnu9bcbRA8eum9ufdyPZXwJwk5ndYGYFAF8C8EwX+vEBZlZsfHACMysC+Bx6byvqZwA81Pj5IQC/7mJf3qNXtvEObTOOLh+7rm9/7u6b/gXgXqx+Iv9nAN/uRh8C/boRwB8bX692u28Ansbqn3UVrP5F9BUA2wE8D+B44/u2HurbUwCOAHgFq4k10aW+fQqr/xq+AuBw4+vebh870q9NOW4aLiuSCI2gE0mEkl0kEUp2kUQo2UUSoWQXSYSSXSQRSnaRRPw/6mByfzknyV4AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "W_0 = model.W_0.detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Q2: Based on the above image, what weights might reasonably be pruned (i.e. explicitly forced to be zero)?\n", + "\n", + "** Based on the above image, It could be reasonable to prune the weights along the margins **\n", + "\n", + "\n", + "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "def mask(mat):\n", + " mat = np.abs(mat)\n", + " std = np.std(mat)\n", + " mean = np.mean(mat)\n", + " threshold = np.abs(mean - 2.5*std)\n", + " print(\"Mean: {}, STD: {}, Threshold: {}\\n\".format(mean,std,threshold))\n", + " mask = np.zeros_like(mat)\n", + " mask[mat>threshold] = 1\n", + " print(\"Non-Null Weight Ratio : {}\".format( (np.count_nonzero(mask==0)/(np.size(mask)))*100 ))\n", + " return mask" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean: 0.03005061112344265, STD: 0.032245710492134094, Threshold: 0.050563665106892586\n", + "\n", + "Non-Null Weight Ratio : 75.25310905612244\n" + ] + } + ], + "source": [ + "W0_mask = mask(model.W_0.detach().cpu().numpy())\n", + "W0_mask = torch.Tensor(W0_mask)\n", + "W0_mask = W0_mask.to(device)\n", + "model.set_mask(W0_mask)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we have explicitly set some entries in one of the the weight matrices to zero, and ensured via the mask, that they will not be updated by gradient descent. Fine tune the model: " + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.28309938311576843. Accuracy: 91.\n", + "Iteration: 0. Loss: 0.27462732791900635. Accuracy: 92.\n", + "Iteration: 0. Loss: 0.20802313089370728. Accuracy: 92.\n", + "Iteration: 0. Loss: 0.2611789107322693. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.26243436336517334. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.22998295724391937. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.23924028873443604. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.19653643667697906. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.1895284205675125. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.19546085596084595. Accuracy: 94.\n" + ] + } + ], + "source": [ + "iter = 0\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_pruned.h5')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Q4: How much accuracy did you lose by pruning the model? How much \"compression\" did you achieve (here defined as total entries in W_0 divided by number of non-zero entries)? \n", + "\n", + "**I gained 2% accuracy after pruning about 86% of the weights.** \n", + "\n", + "\n", + "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?\n", + "\n", + "**After removing about 94% of the weights the accuracy began to degrade.**" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAD4CAYAAABxC1oQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAdAklEQVR4nO3dfZBdd33f8fdnV9LKlmxZsrBRJRnLIFpMk5iiGFoXQoIflEwHwQwudtOMmDEVZPA0NKGDA1PsMcOMCAWHmVKCsBVMBjDGhqBhVBRhoJA2OJKJakt2jIQQ8lpCQlrZlmQ97MO3f9yzzt378Ltn9969T+fzmjmz95zfefjt3Xu/+/ud83tQRGBmVgQDnc6AmVm7OOCZWWE44JlZYTjgmVlhOOCZWWHMaefFBhcsiLmXLGnnJc0KZfS5EcZPn1Yz57jptxfE8ZHxXPs+9vi5bRGxtpnrtVNTAU/SWuAzwCBwb0RsTO0/95IlXPGH/6WZS5pZwsHP3dP0OY6PjPP3267Ite/gsr1Lm75gG8044EkaBD4L3AAMAzskbYmIJ1uVOTNrvwAmmOh0NmZFMyW8a4F9EbEfQNIDwDrAAc+shwXBaOSr0vaaZgLecuCZsvVh4A2VO0naAGwAmLNocROXM7N2cQmvWq0bo1X91CJiE7AJYP7yle7HZtblgmC8T7ucNhPwhoGVZesrgEPNZcfMusFEddmlLzQT8HYAqyWtAp4FbgH+Q0tyZWYdE8C4A95UETEm6XZgG6VmKZsjYk/LcmZmHeMSXg0RsRXY2qK8mFkXCGDU9/DMrAiCcJXWzAoiYLw/450DnplNVepp0Z88WoqZVRDjOZdcZ5PWSnpa0j5Jd9RI/2NJT0p6XNIjkl5RlrZe0t5sWd/sb+YSnplNUXpo0dSAKy/J2ef+H4A1EfGipD8E/gx4l6QlwJ3Amixbj2XHnphpflzCM7MpSu3wWlbCe6nPfUScByb73P/T9SK+HxEvZqs/ptSJAeAmYHtEjGRBbjvQ1FBULuGZWZWJ/CW8pZJ2lq1vyrqTTsrV577MbcD/Shy7PG/GanHAM7MpJkt4OR2LiDWJ9Fx97gEk/UdK1dffmu6xeblKa2ZTBGKcgVxLDrn63Eu6HvgI8LaIODedY6fDAc/MqkyEci05vNTnXtI8Sn3ut5TvIOl1wOcpBbujZUnbgBslLZa0GLgx2zZjrtKa2RSBOB+DrTlXnT73ku4GdkbEFuCTwELg65IADkbE2yJiRNLHKAVNgLsjYqSZ/DjgmdkUpYbHrav81epzHxEfLXt9feLYzcDmVuXFAc/MqkzjoUVPccDrA6OrztZNmziT/hNrXpOdiF5In3/wxfolhYl56QdujWpVapD1uSf780s72yLEePTn7X0HPDOrMuESnpkVQemhRX+Ghv78rcxsxlr90KKbOOCZWZXxFg0e0G0c8MxsismeFv3IAc/Mqkz4Ka2ZFUFp8AAHPOuQiVefTqbP/emCNuWkmibS93rGE23tYk66Hd6852bvSzcxN50+MDprl+56gRhtUdeybuOAZ2ZTROCGx2ZWFHLDYzMrhsAlPDMrED+0MLNCCHIP7tlzHPDMbIrSNI39GRr687cysybkn2S71zjg9YDR54eS6QNX1R8Pb+7++cljzy1PNzibf3BeMn3BoXRbutEF9b84JxcmD23Y/nCgifaHjdrZDZ5p0L7wgqYmz+pqQWt7WkhaC3yG0hDv90bExor0NwN/Dvw6cEtEPFSWNg48ka0ejIi3NZOXpgKepAPASWAcGGswXZuZ9YhWlfAkDQKfBW6gNAvZDklbIuLJst0OAu8GPljjFGci4pqWZIbWlPB+OyKOteA8ZtYFItTKEt61wL6I2A8g6QFgHfBSwIuIA1lak8NvN9afz57NbMZKDy0Gcy3AUkk7y5YNFadbDjxTtj6cbctrfnbeH0t6e3O/WfMlvAD+RlIAn4+ITZU7ZG/ABoA5ixY3eTkzm33TmtPiWINbWbXqxtO5AXpFRBySdBXwPUlPRMTPpnH8FM0GvOuyzFwGbJf0jxHxw/IdsiC4CWD+8pX9e6fXrE+UHlq07CntMLCybH0FcCh3XiIOZT/3S/oB8DpgxgGvqSptWWaOAt+kVF83sx43zkCuJYcdwGpJqyTNA24BtuQ5UNJiSUPZ66XAdZTd+5uJGQc8SQskXTT5GrgR2N1MZsys8yZ7WuRZGp4rYgy4HdgGPAU8GBF7JN0t6W0Akn5T0jBwM/B5SXuyw18D7JT0/4DvAxsrnu5OWzNV2suBb0qaPM9XIuI7zWTGahs60ujPNPM/49Cz6YHhzq06l0wffXX6wdrAM/XbAS4YTo+5Nvpc58b56+d2dnm0chKfiNgKbK3Y9tGy1zsoVXUrj/u/wK+1LCM08U3JHjP/RgvzYmZdIAJGJ/qzAYd7WpjZFKUqrQOemRWE+9KaWSG0uFlKV3HAM7MKrtKaWYF4TgsrpHkNhodqxujC3m36cWGDvgKnV6bTNd66vLRa6Smtp2k0swLwEO9mViiu0ppZIfgprZkVip/SmlkhRIgxBzwzKwpXac2sEHwPz2wWLNqXTr/w6Fgy/fC/mb2P75In020Ej/9aOiCMLmowbNaZ+lXGOWeSh7aFA56ZFYLb4ZlZofRrO7z+fBRjZjMWAWMTA7mWPCStlfS0pH2S7qiR/mZJP5E0JumdFWnrJe3NlvXN/m4u4ZlZlVZVaSUNAp8FbqA0g9kOSVsq5qY4CLwb+GDFsUuAO4E1lJ6lPJYde2Km+XEJz8ymaOUkPpRmMtwXEfsj4jzwALBuyvUiDkTE40Dlk56bgO0RMZIFue3A2mZ+N5fwzKxK5C/hLZW0s2x9UzYX9aTlwDNl68PAG3Keu9axy/NmrBYHPDOrMo2HFsciYk0ivdaJ8o4L1syxNTngWZJecyqZPvTDi5Lpg+frfz5HL0xf+/lVs9jObk/6e3N2SYN2dpek29nF/PSAdxOpwxNt9NohoqXt8IaB8tEBVwANRhOccuxbKo79QTOZ8T08M6sgxicGci057ABWS1olaR5wC7AlZ0a2ATdKWixpMXBjtm3GHPDMrEqEci2NzxNjwO2UAtVTwIMRsUfS3ZLeBiDpNyUNAzcDn5e0Jzt2BPgYpaC5A7g72zZjrtKa2RSt7ksbEVuBrRXbPlr2egel6mqtYzcDm1uVFwc8M5sqSvfx+pEDnplV6deuZQ54ZjZFZA8t+pEDnplVcZXWrIb5J9Lt0ea+WP+b88IV6blPT13Z4NwvzLzaNfLa9LEXHUgfP+fS5gatmxhZ0NTxs20aPS16SsNyq6TNko5K2l22bYmk7dkIBtuzNjJm1gciWtcspdvkqah/keoOu3cAj0TEauCRbN3M+kQLBw/oKg0DXkT8EKhs7LcOuD97fT/w9hbny8w6KCLf0mtmeg/v8og4DBARhyVdVm9HSRuADQBzFrnma9btAjHRp09pZ/23iohNEbEmItYMLujuG7VmVhI5l14z04B3RNIygOzn0dZlycw6quAPLWrZAkyOL78e+FZrsmNmXaFPi3gN7+FJ+iqlMamWZiMa3AlsBB6UdBul8ehvns1M2sxd8Mv0f+EzL09/auf+XXq8u8WPpwevOH3VxXXTIt0Mj3kj6f/HMWf2vnGN8jYxnt7hooXpdnpnznT37Z1eLL3l0TDgRcStdZLe2uK8mFkXCGBioqABz8wKJoCilvDMrHh6sY1dHg54ZlatTwNef7YuNLMm5GuSkvfBhqS1kp6WtE9SVTdUSUOSvpalPyrpymz7lZLOSNqVLX/R7G/mEp6ZVWtRCU/SIPBZ4AZKs5DtkLQlIp4s2+024EREvErSLcAngHdlaT+LiGtakxsHvL5wydP1P50Tc9PHNhrk6PLHGuyRHsEJjdVPGzybPvb8os7Vq06tTKfH0aFk+nNn081W0kd3WEC07inttcC+iNgPIOkBSn3xywPeOuCu7PVDwP+QNCtPTVylNbMalHNhqaSdZcuGihMtB54pWx/OttXcJ5vl7Hng0ixtlaR/kPS/Jb2p2d/KJTwzq5a/cH0sItYk0muV1CrPXm+fw8AVEXFc0uuBv5b02oh4IXfuKriEZ2bVWte1bBgov0GwAjhUbx9Jc4BFwEhEnIuI4wAR8RjwM+DVM/p9Mg54ZjbVZMPjPEtjO4DVklZJmgfcQqkvfrnyvvnvBL4XESHpZdlDDyRdBawG9jfzq7lKa2ZVWtXwOCLGJN0ObAMGgc0RsUfS3cDOiNgC3Af8laR9lAYbviU7/M3A3ZLGgHHgfRGR7rzdgAOemVVrYV/aiNgKbK3Y9tGy12epMQBJRDwMPNyyjOCAZ2Y1qE97WjjgZQbPpv+jjc+v/wkYGkkfe27J7H56fvX6+mlDxxv9p07nbWIwfZv33MvTwxwduTbVHi197TkvpvM+MJpMntV2fBMLx5PpA0PpdGjQQLKTenSsuzwc8MysQu4HEj3HAc/MqrmEZ2aF0aDLYK9ywDOzqTwAqJkViZ/Smllx9GnAc9cyMysMl/AyFxxN/0t74VX100YXps+9YDidfnpFOv3cinSDs7kXnq+bNng4nbm5p9P3asYuTI/rNjA286JAozaCYxemj1/4bPrO+sii2bsPNXS40Vent79artKaWTEELe1a1k0c8Mysmkt4ZlYUrtKaWXE44JlZYTjgmVkRKFylNbMi8VPa/vbCVen00cX1J1iNK9Lt5M6smJdMHzqabuumk+k/0/zd9c9//pLm/lWPLky3Tdf4zM8/lh5Kr2F3ztMvb9Ruvk+LKW3QyhKepLXAZygN8X5vRGysSB8CvgS8HjgOvCsiDmRpf0ppou5x4D9HxLZm8tKwp4WkzZKOStpdtu0uSc9K2pUtv9dMJsysy7Ro1rJsEp7PAr8LXA3cKunqit1uA05ExKuAe4BPZMdeTWl+i9cCa4H/OTmpz0zl6Vr2xexile6JiGuyZWuNdDPrRfFP9/EaLTlcC+yLiP0RcR54AFhXsc864P7s9UPAWyUp2/5ANl3jz4F92flmrGHAi4gfUppJyMyKonXz0i4HnilbH8621dwnIsaA54FLcx47Lc0MHnC7pMezKu/iejtJ2iBpp6Sd46dPN3E5M2sXTeRbgKWT3+9s2VB5qhqnrwyV9fbJc+y0zDTgfQ54JXANcBj4VL0dI2JTRKyJiDWDCxrcpTazXnNs8vudLZsq0oeBlWXrK4BD9faRNAdYRKlWmefYaZlRwIuIIxExHhETwBdosl5tZl2mdVXaHcBqSaskzaP0EGJLxT5bgPXZ63cC34uIyLbfImlI0ipgNfD3TfxWM2uWImlZRBzOVt8B7E7tb2Y9pIUNjyNiTNLtwDZKzVI2R8QeSXcDOyNiC3Af8FeS9lEq2d2SHbtH0oPAk8AY8P6IaDT/ZVLDgCfpq8BbKNXVh4E7gbdIuoZSjD8AvLeZTHSDgfrN7AAYOpJ4q1JplP7KTRlIf/pOXVX/MzBvpLkxXhu1hTu3eObnP3dpejy7OQ3G6js3NONLN21i9YvJ9IULzibTX9y1pJXZab0WtsPLWnFsrdj20bLXZ4Gb6xz7ceDjrcpLw4AXEbfW2HxfqzJgZl2oT9tsu6eFmU0hXnoC23cc8MxsKg8eYGaF4oBnZoXhgGdmReEqrc3Ym256PJm+e+TlyfSL551Lpr9x6c/rpn3tW29OHjt2QTKZk1ekm52cW5z+ZgzWn0GyYbOTRs5dlm6SpbGZn39wSfo9b9Re5+Sp9BvbdFOl2eaAZ2aFEH5Ka2ZF4hKemRWF7+GZWXE44JlZIeQfCaXnOOCZ2RTCVVozKxAHvD531Zt+kUy/68pv1U27dmhu8tgfn023F3vsoiuT6dddsC+Zfjbq/xm/Rrod3pwzyWR4w/PJ5MHdFzc4QX3nlqXH5Bp8Pt1abf6l6cyfPTG/fuLcdLuLKy5LT+MydyD9N91/9NJketdzwDOzwnDAM7NC8GgpZlYofRrwmhv/28z60jSmaZz5NaQlkrZL2pv9rDndq6T12T57Ja0v2/4DSU9L2pUtlzW6pgOemVVR5FuadAfwSESsBh7J1qfmQ1pCaR6dN1CaHfHOisD4+xFxTbYcbXRBBzwzmyrvFI3NB7x1wP3Z6/uBt9fY5yZge0SMRMQJYDuwdqYXdMAzs2r5A95SSTvLlg3TuMrlk9O9Zj9rVUmXA8+UrQ9n2yb9ZVad/W+SGo4H5ocWmRdH5yXTv36i/lzjDzco288fGE2mj0f6/872iauT6SvnptuMNWOsiXZ2AANXn6yfNrwwfXCDj+/ZYw0G80u0tbtocXqaxVdd/Ktk+oGT6XZ2o8+n55Ds4AyTDU2zp8WxiFhT91zSd4FaAz5+ZBrZqTSZu9+PiGclXQQ8DPwB8KXUyRzwzKyKJlrzmDYirq97DemIpGURcVjSMqDWPbhhSvNiT1oB/CA797PZz5OSvkLpHl8y4LlKa2ZTte8e3hZg8qnreqBWd6ZtwI2SFmcPK24EtkmaI2kpgKS5wL8Ddje6oAOemVVp01PajcANkvYCN2TrSFoj6V6AiBgBPgbsyJa7s21DlALf48Au4FngC40u6CqtmVVrQ8PjiDgOvLXG9p3Ae8rWNwObK/Y5Dbx+utd0wDOzKu5aZmbF4YBnZoXgWcv636Hji5LpDw3Xv12gOQ0+Hc+nx8uLC9Jjqw1emB437uKF9ceFW/avDyWPPfx3/yyZ3qwzxy6sm3bp7gZzu74ife6JOelnbhdcdrpu2mCDOtvTz12eTD9+uv7vBTB0pHe/Wv084nHDp7SSVkr6vqSnJO2R9EfZ9lwdf82sB0XkW3pMnmYpY8CfRMRrgDcC75d0NTk6/ppZb2pTs5S2axjwIuJwRPwke30SeIpSX7Y8HX/NrNe0r+Fx203rRoOkK4HXAY9S0fG33lhUWWfiDQBzFrnWa9YL+vWhRe6eFpIWUuqg+4GIeCHvcRGxKSLWRMSawQULZpJHM2uzdgwA2gm5Al7WV+1h4MsR8Y1s85Gswy+Jjr9m1muCvn1o0bBKm40xdR/wVER8uixpsuPvRup3/O0ZA3sbNDOYzYufSP/fGTqRHrrq+VfUn47wxKL00FSvvu6ZZPov/s/KZHojQ0frT7V46ooGBzf6dzyY/sKdP1e/OdDoaHoKyKUL6zdpARh9It2Mqdf14gOJPPLcw7uO0jhTT0jalW37MKVA96Ck24CDwM2zk0Uza7uiBryI+FvqD8VY1fHXzHpbPzc87t3m4GY2OyJaNgBot3HAM7Nq/RnvHPDMrJqrtGZWDAH0aZXWQ7ybWbU2dC3LOwCJpO9Iek7Styu2r5L0aHb81ySl22/hEl5POLe4wScr0eJdA737n1rpUbPg4nQbw9cs/2XdtMvn158+EuBH2369wcX7W5uqtJMDkGyUdEe2/qEa+30SuBB4b8X2TwD3RMQDkv4CuA34XOqCLuGZWRVNRK6lSbkGIImIR4Ap/6GyDhG/AzzU6PhyLuGZ2VTTq64ulbSzbH1TRGzKeWyuAUjquBR4LiImR8cdpjSKU5IDnplNUWp4nDviHYuINXXPJX0XeHmNpI/MIGtTTl1jW8NMO+CZWbUWjYQSEdfXS5N0RNKyrHQ33QFIjgGXSJqTlfJWAOn5DPA9PDOrQRG5liZNDkAC0xyAJCIC+D7wzukc74BnZlO1b8TjjcANkvYCN2TrSFoj6d7JnST9CPg68FZJw5JuypI+BPyxpH2U7und1+iCrtKaWYX29KWNiOPUGIAkInYC7ylbf1Od4/cD107nmg54mVUPnUimTzz+j3XTBi5Mj6X38zt+Y0Z5ymvuqcR0hyfrj5UHsDfS0xE2bMnZQfMOpEcp/OmBVfXTWp2ZftODg3vm4YBnZlN5Im4zKxSX8MysMPoz3jngmVk1TfRnndYBz8ymClrW8LjbOOCZ2RSiJY2Ku5IDnplVc8Drb+ML02265vzzV9VNG1u6sNXZaZ0GA5utWDaSTD96sFa/b+t7DnhmVgi+h2dmReKntGZWEOEqrZkVROCAZ2YF0p81Wgc8M6vmdnhmVhxFDXiSVgJfojQRxwSlWYk+I+ku4D8Bv8p2/XBEbJ2tjM62g2vTY9rN/9WCumlzzqY/HJfuTqePzU+MZwecWp5OH79g5h/Oo4+6nZ1ViIDx/qzT5hnifQz4k4h4DfBG4P2Srs7S7omIa7KlZ4OdmVWIyLc0QdISSdsl7c1+Lq6z33ckPSfp2xXbvyjp55J2Zcs1ja7ZMOBFxOGI+En2+iTwFDnmfzSzHtaGgAfcATwSEauBR7L1Wj4J/EGdtP9aVuja1eiC05rER9KVwOuAR7NNt0t6XNLmRHTeIGmnpJ3jp09P53Jm1gkBTES+pTnrgPuz1/cDb6+ZnYhHgJPNXgymEfAkLQQeBj4QES8AnwNeCVwDHAY+Veu4iNgUEWsiYs3ggvr3wcysWwTERL4Flk4WaLJlwzQudHlEHIZSTRK4bAaZ/XhW6LpHUrpDPDmf0kqaSynYfTkivpFl8EhZ+heAb9c53Mx6STCdhxbHImJNvURJ36X0wLPSR2aQs0p/CvyS0lxTmyhN23h36oA8T2lFab7HpyLi02Xbl01GZ+AdwO4ZZtrMuk2LmqVExPX10iQdmYwjkpYBR6d57sn4c07SXwIfbHRMnhLedZRuGD4hafKm4IeBW7OnIgEcAN47ncz2mrMvm/kH4NTKdLOSxmavTdREg0+AxtPpg+fTv9vE3P5sz9X32tMObwuwntIE3OuBb03n4LJgKUr3/xoWuhoGvIj4W6DWp9rNUMz6UtsGD9gIPCjpNuAgcDOApDXA+yLiPdn6j4B/ASyUNAzcFhHbgC9Lehml+LQLeF+jC7qnhZlNFUAbhoeKiOPAW2ts3wm8p2z9TXWO/53pXtMBz8yqFbVrmZkVTf92LXPAM7OpAiIc8MysKJrvRdGVHPDMrJrv4Vk/Ghhr7ni3s+tDEW15StsJDnhmVs0lPDMrhiDGG3Sx6VEOeGY21eTwUH3IAc/MqrlZipkVQQDhEp6ZFUKES3hmVhz9+tBC0cbHz5J+BfyibNNS4FjbMjA93Zq3bs0XOG8z1cq8vSIiXtbMCSR9h1Ke8jgWEWubuV47tTXgVV1c2pkaHrqTujVv3ZovcN5mqpvz1m+mNWuZmVkvc8Azs8LodMDb1OHrp3Rr3ro1X+C8zVQ3562vdPQenplZO3W6hGdm1jYOeGZWGB0JeJLWSnpa0j5Jd3QiD/VIOiDpCUm7JO3scF42SzoqaXfZtiWStkvam/1c3EV5u0vSs9l7t0vS73UobyslfV/SU5L2SPqjbHtH37tEvrrifSuCtt/DkzQI/BS4ARgGdgC3RsSTbc1IHZIOAGsiouONVCW9GTgFfCki/mW27c+AkYjYmP2zWBwRH+qSvN0FnIqI/97u/FTkbRmwLCJ+Iuki4DFKEzW/mw6+d4l8/Xu64H0rgk6U8K4F9kXE/og4DzwArOtAPrpeRPwQGKnYvA64P3t9P6UvTNvVyVtXiIjDEfGT7PVJ4ClgOR1+7xL5sjbpRMBbDjxTtj5Md/3RA/gbSY9J2tDpzNRweUQchtIXCLisw/mpdLukx7Mqb0eq2+UkXQm8DniULnrvKvIFXfa+9atOBDzV2NZNbWOui4h/Bfwu8P6s6mb5fA54JXANcBj4VCczI2kh8DDwgYh4oZN5KVcjX131vvWzTgS8YWBl2foK4FAH8lFTRBzKfh4FvkmpCt5NjmT3gibvCR3tcH5eEhFHImI8SpOafoEOvneS5lIKKl+OiG9kmzv+3tXKVze9b/2uEwFvB7Ba0ipJ84BbgC0dyEcVSQuym8lIWgDcCOxOH9V2W4D12ev1wLc6mJcpJoNJ5h106L2TJOA+4KmI+HRZUkffu3r56pb3rQg60tMie+z+58AgsDkiPt72TNQg6SpKpToojRX4lU7mTdJXgbdQGqrnCHAn8NfAg8AVwEHg5oho+8ODOnl7C6VqWQAHgPdO3jNrc97+LfAj4AlgciTLD1O6X9ax9y6Rr1vpgvetCNy1zMwKwz0tzKwwHPDMrDAc8MysMBzwzKwwHPDMrDAc8MysMBzwzKww/j/8/keh/P7k4gAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "W_0 = model.W_0.detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.colorbar()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Quantization\n", + "\n", + "Now that we have a pruned model that appears to be performing well, let's see if we can make it even smaller by quantization. To do this, we'll need a slightly different neural network, one that corresponds to Figure 3 from the paper. Instead of having a matrix of float values, we'll have a matrix of integer labels (here called \"labels\") that correspond to entries in a (hopefully) small codebook of centroids (here called \"centroids\"). The way that I've coded it, there's still a mask that enforces our desired sparsity pattern." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "class MultilayerPerceptronQuantized(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim, hidden_dim,mask,labels,centroids):\n", + " super(MultilayerPerceptronQuantized, self).__init__()\n", + " self.mask = torch.nn.Parameter(mask,requires_grad=False)\n", + " self.labels = torch.nn.Parameter(labels,requires_grad=False)\n", + " self.centroids = torch.nn.Parameter(centroids,requires_grad=True)\n", + "\n", + " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim))\n", + "\n", + " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim))\n", + " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim))\n", + "\n", + " def forward(self, x):\n", + " W_0 = self.mask*self.centroids[self.labels]\n", + " hidden = torch.tanh(x@W_0 + self.b_0)\n", + " outputs = hidden@self.W_1 + self.b_1\n", + " return outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice what is happening in the forward method: W_0 is being reconstructed by using a matrix (self.labels) to index into a vector (self.centroids). The beauty of automatic differentiation allows backpropogation through this sort of weird indexing operation, and thus gives us gradients of the objective function with respect to the centroid values!\n", + "\n", + "### Q6: However, before we are able to use this AD magic, we need to specify the static label matrix (and an initial guess for centroids). Use the k-means algorithm (or something else if you prefer) figure out the label matrix and centroid vectors. PROTIP1: I used scikit-learns implementation of k-means. PROTIP2: only cluster the non-zero entries" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.cluster import KMeans\n", + "# convert weight and mask matrices into numpy arrays\n", + "W_0 = model.W_0.detach().cpu().numpy()\n", + "mask = model.mask.detach().cpu().numpy()\n", + "\n", + "# Figure out the indices of non-zero entries \n", + "inds = np.where(mask!=0)\n", + "# Figure out the values of non-zero entries\n", + "vals = np.expand_dims(W_0[inds],1)\n", + "num_cluster = 2\n", + "km = KMeans(n_clusters=num_cluster).fit(vals)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([784, 64])" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "labels = np.zeros_like(W_0)\n", + "labels[inds] = km.labels_\n", + "labels = torch.tensor(labels,dtype=torch.long,device=device).squeeze()\n", + "centers = km.cluster_centers_\n", + "centroids = torch.tensor(centers,device=device).squeeze()\n", + "labels.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can instantiate our quantized model and import the appropriate pre-trained weights for the other network layers. " + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,model.mask,labels,centroids)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "model_q = model_q.to(device)\n", + "\n", + "# Copy pre-trained weights from unquantized model for non-quantized layers\n", + "model_q.b_0.data = model.b_0.data\n", + "model_q.W_1.data = model.W_1.data\n", + "model_q.b_1.data = model.b_1.data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we can fine tune the quantized model. We'll adjust not only the centroids, but also the weights in the other layers." + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.2916867434978485. Accuracy: 92.\n", + "Iteration: 0. Loss: 0.205495223402977. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.20868274569511414. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.261547327041626. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.23044338822364807. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.22543968260288239. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.23166033625602722. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.24529825150966644. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.22417983412742615. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.21196895837783813. Accuracy: 94.\n" + ] + } + ], + "source": [ + "optimizer = torch.optim.Adam(model_q.parameters(), lr=lr_rate, weight_decay=1e-3)\n", + "iter = 0\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model_q(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model_q(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_quantized.h5')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After retraining, we can, just for fun, reconstruct the pruned and quantized weights and plot them as images:" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAALMElEQVR4nO3dT4ic9R3H8c+nq/awetjUJmxjqFZzqBQayxKElGKRSswlerCYg6QgrAUFRQ8Ve9BjKDW2h2JZazAtVhFUzCG0hiAED4qrpPlj2iZKqmuWbCUF415s1m8P+6SscWZnnOd55nk23/cLlpl5Znbny+g7z+z8ZvZxRAjAxe9rTQ8AYDiIHUiC2IEkiB1IgtiBJC4Z5p2NXD4al4ytGuZdAqmc+88ZLXw6707XlYrd9mZJv5U0IukPEbFjudtfMrZK33rogTJ3CWAZpx7/TdfrBn4ab3tE0u8k3SrpeknbbF8/6M8DUK8yv7NvlHQiIt6PiM8kPS9pazVjAahamdjXSvpwyeWZYtsX2J60PW17emF+vsTdASijTOydXgT40ntvI2IqIiYiYmJkdLTE3QEoo0zsM5LWLbl8laRT5cYBUJcysb8lab3ta2xfJulOSXuqGQtA1QZeeouIc7bvk/RXLS697YqIo5VNBqBSpdbZI2KvpL0VzQKgRrxdFkiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSRKHbLZ9klJZyUtSDoXERNVDAWgeqViL/w4Ij6u4OcAqBFP44EkysYekl61/bbtyU43sD1pe9r29ML8fMm7AzCosk/jN0XEKdurJe2z/feIOLD0BhExJWlKkr6+bl2UvD8AAyq1Z4+IU8XpnKSXJW2sYigA1Rs4dtujtq84f17SLZKOVDUYgGqVeRq/RtLLts//nD9HxF8qmQpA5QaOPSLel/T9CmcBUCOW3oAkiB1IgtiBJIgdSILYgSSq+CAMLmLXPfjGstef2HnjkCZBWezZgSSIHUiC2IEkiB1IgtiBJIgdSILYgSRYZ8eyWEe/eLBnB5IgdiAJYgeSIHYgCWIHkiB2IAliB5IgdiAJYgeSIHYgCWIHkiB2IAliB5IgdiAJYgeS4PPsFXjvzt8ve/21z/+81PfXqddsuHj03LPb3mV7zvaRJdtW2d5n+3hxOlbvmADK6udp/DOSNl+w7WFJ+yNivaT9xWUALdYz9og4IOnMBZu3StpdnN8t6baK5wJQsUFfoFsTEbOSVJyu7nZD25O2p21PL8zPD3h3AMqq/dX4iJiKiImImBgZHa377gB0MWjsp22PS1JxOlfdSADqMGjseyRtL85vl/RKNeMAqEvPdXbbz0m6SdKVtmckPSpph6QXbN8t6QNJd9Q5ZHZNroWv5PcA1Dn7Snx/Qs/YI2Jbl6turngWADXi7bJAEsQOJEHsQBLEDiRB7EASfMS1AmWXYVbiMk4bNLksWPZjzU1gzw4kQexAEsQOJEHsQBLEDiRB7EASxA4kwTo7Vqwyf6K7jevgdWPPDiRB7EASxA4kQexAEsQOJEHsQBLEDiTBOntyK/lPRTf981ca9uxAEsQOJEHsQBLEDiRB7EASxA4kQexAEqyzJ1fmM+FV/HwMT889u+1dtudsH1my7THbH9k+WHxtqXdMAGX18zT+GUmbO2x/IiI2FF97qx0LQNV6xh4RBySdGcIsAGpU5gW6+2wfKp7mj3W7ke1J29O2pxfm50vcHYAyBo39SUnXStogaVbS491uGBFTETERERMjo6MD3h2AsgaKPSJOR8RCRHwu6SlJG6sdC0DVBord9viSi7dLOtLttgDaoec6u+3nJN0k6UrbM5IelXST7Q2SQtJJSffUOGPr1f2ZcNaqUYWesUfEtg6bn65hFgA14u2yQBLEDiRB7EASxA4kQexAEnzEdQVo8s894+LBnh1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgnX2CtT955jbjI/frhzs2YEkiB1IgtiBJIgdSILYgSSIHUiC2IEkWGcfgot5Lfq6B99Y9voTO28c0iTohT07kASxA0kQO5AEsQNJEDuQBLEDSRA7kATr7IUy68W9vheDYQ2/Wj337LbX2X7N9jHbR23fX2xfZXuf7ePF6Vj94wIYVD9P489JeigivivpRkn32r5e0sOS9kfEekn7i8sAWqpn7BExGxHvFOfPSjomaa2krZJ2FzfbLem2uoYEUN5XeoHO9tWSbpD0pqQ1ETErLf6DIGl1l++ZtD1te3phfr7ctAAG1nfsti+X9KKkByLik36/LyKmImIiIiZGRkcHmRFABfqK3falWgz92Yh4qdh82vZ4cf24pLl6RgRQhZ5Lb7Yt6WlJxyJi55Kr9kjaLmlHcfpKLRMOCcs4g6nzceO/SbX6WWffJOkuSYdtHyy2PaLFyF+wfbekDyTdUc+IAKrQM/aIeF2Su1x9c7XjAKgLb5cFkiB2IAliB5IgdiAJYgeS4COuFei1Htzmj8Cylp0He3YgCWIHkiB2IAliB5IgdiAJYgeSIHYgCdbZh6DJdXjW0XEee3YgCWIHkiB2IAliB5IgdiAJYgeSIHYgCdbZW4C1cAwDe3YgCWIHkiB2IAliB5IgdiAJYgeSIHYgiZ6x215n+zXbx2wftX1/sf0x2x/ZPlh8bal/XACD6udNNeckPRQR79i+QtLbtvcV1z0REb+ubzwAVenn+OyzkmaL82dtH5O0tu7BAFTrK/3ObvtqSTdIerPYdJ/tQ7Z32R7r8j2TtqdtTy/Mz5caFsDg+o7d9uWSXpT0QER8IulJSddK2qDFPf/jnb4vIqYiYiIiJkZGRysYGcAg+ord9qVaDP3ZiHhJkiLidEQsRMTnkp6StLG+MQGU1c+r8Zb0tKRjEbFzyfbxJTe7XdKR6scDUJV+Xo3fJOkuSYdtHyy2PSJpm+0NkkLSSUn31DIhgEr082r865Lc4aq91Y8DoC68gw5IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJBwRw7sz+9+S/rVk05WSPh7aAF9NW2dr61wSsw2qytm+HRHf7HTFUGP/0p3b0xEx0dgAy2jrbG2dS2K2QQ1rNp7GA0kQO5BE07FPNXz/y2nrbG2dS2K2QQ1ltkZ/ZwcwPE3v2QEMCbEDSTQSu+3Ntv9h+4Tth5uYoRvbJ20fLg5DPd3wLLtsz9k+smTbKtv7bB8vTjseY6+h2VpxGO9lDjPe6GPX9OHPh/47u+0RSf+U9BNJM5LekrQtIt4d6iBd2D4paSIiGn8Dhu0fSfpU0h8j4nvFtl9JOhMRO4p/KMci4hctme0xSZ82fRjv4mhF40sPMy7pNkk/U4OP3TJz/VRDeNya2LNvlHQiIt6PiM8kPS9pawNztF5EHJB05oLNWyXtLs7v1uL/LEPXZbZWiIjZiHinOH9W0vnDjDf62C0z11A0EftaSR8uuTyjdh3vPSS9avtt25NND9PBmoiYlRb/55G0uuF5LtTzMN7DdMFhxlvz2A1y+POymoi906Gk2rT+tykifiDpVkn3Fk9X0Z++DuM9LB0OM94Kgx7+vKwmYp+RtG7J5asknWpgjo4i4lRxOifpZbXvUNSnzx9Btzida3ie/2vTYbw7HWZcLXjsmjz8eROxvyVpve1rbF8m6U5JexqY40tsjxYvnMj2qKRb1L5DUe+RtL04v13SKw3O8gVtOYx3t8OMq+HHrvHDn0fE0L8kbdHiK/LvSfplEzN0mes7kv5WfB1tejZJz2nxad1/tfiM6G5J35C0X9Lx4nRVi2b7k6TDkg5pMazxhmb7oRZ/NTwk6WDxtaXpx26ZuYbyuPF2WSAJ3kEHJEHsQBLEDiRB7EASxA4kQexAEsQOJPE/ppWOB0CAZMMAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "W_0 = (model_q.mask*model_q.centroids[model_q.labels]).detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Certainly a much more parsimonious representation. The obvious question now becomes:\n", + "\n", + "### Q7: How low can you go? How small can the centroid codebook be before we see a substantial degradation in test set accuracy?\n", + "\n", + "I got down to two centroids before things got wierd. \n", + "\n", + "\n", + "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? \n", + "\n", + "Somewhat surprisingly, the accuracy difference the between the compressed form of the model that was trained on one epoch and the compressed form of the model that was trained on 10 epochs is only about 3%." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data/MNIST/processed/test.pt b/data/MNIST/processed/test.pt new file mode 100644 index 0000000..a397d41 Binary files /dev/null and b/data/MNIST/processed/test.pt differ diff --git a/data/MNIST/processed/training.pt b/data/MNIST/processed/training.pt new file mode 100644 index 0000000..ccf1ca7 Binary files /dev/null and b/data/MNIST/processed/training.pt differ diff --git a/data/MNIST/raw/t10k-images-idx3-ubyte b/data/MNIST/raw/t10k-images-idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/data/MNIST/raw/t10k-images-idx3-ubyte differ diff --git a/data/MNIST/raw/t10k-labels-idx1-ubyte b/data/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/data/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/data/MNIST/raw/train-images-idx3-ubyte b/data/MNIST/raw/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/data/MNIST/raw/train-images-idx3-ubyte differ diff --git a/data/MNIST/raw/train-labels-idx1-ubyte b/data/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/data/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/deep_compression_exercise.ipynb b/deep_compression_exercise.ipynb index f903c10..7b8e422 100644 --- a/deep_compression_exercise.ipynb +++ b/deep_compression_exercise.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -57,9 +57,9 @@ "\n", " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim),requires_grad=True)\n", " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim),requires_grad=True)\n", - " \n", + " \n", " def set_mask(self,mask):\n", - " \n", + " mask = torch.nn.Parameter(mask)\n", " self.mask.data = mask.data\n", " self.W_0.data = self.mask.data*self.W_0.data\n", "\n", @@ -77,12 +77,16 @@ "\n", "### Q1: What does this model have the capability of doing that a \"Vanilla\" MLP does not. Why might we want this functionality for studying pruning?\n", "\n", - "Let's first train this model without utilizing this extra functionality. You can set the hidden layer size to whatever you'd like when instantiating the model:" + "Let's first train this model without utilizing this extra functionality. You can set the hidden layer size to whatever you'd like when instantiating the model:\n", + "\n", + "\n", + "\n", + "**Answer: This MLP accomodates a \"mask\" object, which implements pruning by creating 0 entries in a respective linear layer's weight matrix**" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -109,12 +113,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 37, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.4186490774154663. Accuracy: 90.\n" + ] + } + ], "source": [ "iter = 0\n", - "for epoch in range(n_epochs):\n", + "for epoch in range(1):\n", " for i, (images, labels) in enumerate(train_loader):\n", " images = images.view(-1, 28 * 28).to(device)\n", " labels = labels.to(device)\n", @@ -152,9 +164,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAVtklEQVR4nO3dW4zc5XkG8Oed0x5mDz6tvWtjTi4hcUEQtEVRSCOqKBHhAshFqnARUSmqc5FIiZSLInIRbiqhqkmaiyqSU1AISokiJVG4QG0QQiG5aMSaOtjUNHaogbUXr4979Oyc3l7sgBbY7/mWOeyM/T0/abW78843881/5p3/7rzfwdwdInL1y3S7AyKyOZTsIolQsoskQskukgglu0gicpt5Z9mhoue2b93MuxRJSvX8RdQWl2y9WEvJbmb3APgBgCyAf3P3x9j1c9u3Yvzb32jlLkWEePsffxCMNf1nvJllAfwrgM8D2A/gQTPb3+ztiUhntfI/+50ATrj76+5eBvAzAPe3p1si0m6tJPseAG+t+X26cdl7mNkBM5sys6na4lILdycirWgl2df7EOADY2/d/aC7T7r7ZHao2MLdiUgrWkn2aQB71/x+DYDTrXVHRDqllWR/CcBNZnaDmRUAfAnAM+3ploi0W9OlN3evmtnXAfwnVktvT7j7q23r2dUkV+fxdauibRK77UrkCrH2sUmTrH29kw8cQCbcucxAlTatVyPnwXIk7h1+bE1oqc7u7s8CeLZNfRGRDtJwWZFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSsanz2a9a2UixORI3Ug8GgIGhFRrPZcN1/EKuRtvWI133SL24XOUvoWwm3LfLpTxtW69lI/FILZs8tvoKv23L87ERHnvOWx2/0AE6s4skQskukgglu0gilOwiiVCyiyRCyS6SCJXeNspIqSVShsn28fJXocCnW948Nkvj2/uWg7FdffO07bYcXyqsL1Oh8YtVvvpQqR4urw1nS7RtzHK9QOPHFsaDsRMXdtC2c3ODNB7dDjVWWmPl1mhZr7lztM7sIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SiHTq7K0seRyJZ/t5nTyf53X2j4ydpfGx/kUaPzE/FowdPvuBHbnew9j4AQDXjlyk8e19vE6/u28uGIvVyQ9dvJbGXz1+DY0PnAzX+Isz/HEPD/EXxPI4b1/dXabxXH94/EJloY+2bZbO7CKJULKLJELJLpIIJbtIIpTsIolQsoskQskukoh06uyxlX1jS/+SG6hl+WEcHQ7PNweAcp0vaxxz7VC4Fr7Uz2vZo3k+p/wjxbdp/Oa+GRr/aCE8hiAfGfwwFJnv/ualLTReObU1GPPIIc8v8L4VI8tYLxT4Mtm1XWSp6tgS2U1qKdnN7CSABQA1AFV3n2xHp0Sk/dpxZv8bdz/XhtsRkQ7S/+wiiWg12R3Ab8zskJkdWO8KZnbAzKbMbKq2yMdRi0jntPpn/F3uftrMdgJ4zsxec/cX117B3Q8COAgAfdddE12nT0Q6o6Uzu7ufbnyfBfArAHe2o1Mi0n5NJ7uZFc1s+J2fAXwOwNF2dUxE2quVP+N3AfiVmb1zO//u7v/Rll51QKyOnlvg73ts2nesRF+JbD1cyPD57guVfhrPkM6Va/wpnqmO0PhI7jKN30jq6ABwZGUiGHu9HJ6HDwAvnL2ZxkuRLZ+rY+HjWim29nFVrcjXMOgf48etkA+3n1/ij6tZTSe7u78O4LY29kVEOkilN5FEKNlFEqFkF0mEkl0kEUp2kUT01hTXyLLGqIdrXFbl9a/8HH9fK8zz9jWyum8lsuxwbMLiUpVPQ42pevixXSoN0LbnF/iWy6fneWnud9l9/PYvDAVj2WleUjRekYy+en0LuYHdfPrsSGRa8m07T9P4tQMXaPzV+XBJ8tCZ8DEDAHhzU2B1ZhdJhJJdJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUT0Vp098tbDpqlml3jjvou8NlmPHInKEBkDMBrefhcAyhV+45noftLcYjk8CODC4iBtuzLL47UFXvMtXOLHdcuF8GMbPEeWUwZQjkxDvcRnwMIGwtNId++4RNv+1Y43aPyTwydofCw3T+MVsnz44aG9tG1tobkpsDqziyRCyS6SCCW7SCKU7CKJULKLJELJLpIIJbtIInqrzh7ZqtbK4XiuxNtW+dRp1AZ5rbu6LVyzHRrlywaPDPC500sVPp89Fp8vhevsK5d5TTa2hPbIn2kYo6+v0HjhXHheuBf4y8/28Ln4sX2X+4vlYOxjW/lW1PsH+Xz12/t4PHYWHcqGj1s2xyfy18vk9UDmuuvMLpIIJbtIIpTsIolQsoskQskukgglu0gilOwiieipOrvF6uwk7hleJ3ey7jsA1CNThHPF8Jz1fJbXRS+X+Y3HtnSu1SJz9cn2v8VhXuMvOR+AsPU13j5/7E0ah5G+T2ynTctFvqZ9Ky6W+Tz+5Tp/wWzLtHaefKO0LRirVvjrASxMUijaYzN7wsxmzezomsu2mdlzZna88X1r7HZEpLs28vb0YwD3vO+yhwE87+43AXi+8buI9LBosrv7iwDev5fN/QCebPz8JIAH2twvEWmzZv/x2OXuMwDQ+L4zdEUzO2BmU2Y2VVtcavLuRKRVHf803t0Puvuku09mhzr3gYuIcM0m+xkzmwCAxvfZ9nVJRDqh2WR/BsBDjZ8fAvDr9nRHRDolWmc3s6cB3A1gh5lNA/gOgMcA/NzMvgLgTQBfbEdn2LrwMZkV3jYy9RmZMq/TlxfDtfJLZX7jTvaVB4BMntfpcwUev2Hn+WDs1lE+7/qpU5+i8cJx3r56LnzfAJAb3xWMlbdH1qwvRMZdRJbbv3wufPvHsuF+AcB4P1/3/eX+aRpfitTp31oKV6u9ys/BnmUPPByLJru7PxgIfSbWVkR6h4bLiiRCyS6SCCW7SCKU7CKJULKLJKKnprh6PlJLIWHP8zJNnldSUBni7ftPh0tvxitjUZUif9zF/e+fmvBe9+38YzC2v+8UbfsUeOmtvhReChoAsltGabyybyIYW9jLy1Oxacd9FyLLhw+Gz2X1Md722Nw4jQ9kb6XxOlnSGQBOzZHjFqspNklndpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSoWQXSURv1dnp1D0AZKfa6gBvO3AmctNzPO7kSFl4JedVkbfU5XFek906yLeEvrEQXjvk+lx422IAGJjm03MzoyM0Xt/B6+znbwlvu1zPRsY2XKrTeH6JP+eVYvjAL4/y6bVz/fy4Vev8Sb1c59ts0543P9Ob0pldJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUS0VN19hi6LXOkRF+PzHcfOM9rulYP34HxpqgMRJa5Nh7/i5GzND5ZCNeE/1ThT/HA2cgS2jcEd/YCAJTGeD25tC382Ebe4Adu6K0VGq8X+LlqeReZLx95zhaW+VbWxxf5cSlk+OCLUik8Wd8jW5c3S2d2kUQo2UUSoWQXSYSSXSQRSnaRRCjZRRKhZBdJxBVVZze29XHkbavGy6YoR9aNz5Pl07MrvFYdq/GXdvOa7FhhkcZ/XyoGY//y1mdp22pkDEBlmL9Eqv28fY4ct8EzFd72Ip/Hv7SPz7Vf2U6el8geBZkML8SvVPlxGYzMh+/vDz/2apnfdrOrykfP7Gb2hJnNmtnRNZc9amanzOxw4+veJu9fRDbJRv6M/zGAe9a5/Pvufnvj69n2dktE2i2a7O7+IgC+/5CI9LxWPqD7upm90vgzf2voSmZ2wMymzGyqtrjUwt2JSCuaTfYfAtgH4HYAMwC+G7qiux9090l3n8wOhT9IEpHOairZ3f2Mu9fcvQ7gRwDubG+3RKTdmkp2M1u7D+8XABwNXVdEekO0zm5mTwO4G8AOM5sG8B0Ad5vZ7Vgt+Z0E8NUO9vFdbD57dYRvkl5yvj56rY/Xi43MMc7xcjAqkf9exq87T+P9GV6Pnlq+MRg78fYYbTtSilRtY+sE5CLHjbRfGucbsM/dGPwoCABw4RbeucJEeHzCnhH++dF4cZ7G9w2do/HY/uwXh8Pr1v8fmesOAJXl5obHRFu5+4PrXPx4U/cmIl2j4bIiiVCyiyRCyS6SCCW7SCKU7CKJuKKmuNIyUB+fkljbxUtzmet5eWtslE8zZe7Y8RaPD71B4x8tzND4koeXc/7R0l/TtrnlyFTPWmT6buQVtLQ73H7xEyXadt8uXt66rXiJxgey4ed0bz+f7jGRv0jjWVZTBHC2OkzjR2x3MFZZ5stzI0de683PAheRq4WSXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEXFl1dsKyke1/R/k81E/veZ3G79v6cjBWzPCthYvGa/ijkSmsu3Nk62EAy/WFYKwwzJc0Lo/w6ZSxl8j8Dbz1njvCYwTu2/0Kbbstx8c2ZCL7Llc83Peb+t6mbW+K3Dd/xoD/Ku2h8fPLZN5zuYVzMCn/68wukgglu0gilOwiiVCyiyRCyS6SCCW7SCKU7CKJuLLq7GSurkeW1/URvrTvWCFcqwaArIVrurE6+nCkjr4lw99z+4zXwmtkbvWtu0/TtkfHP0Lj5dHIls57+RiDT46Fxy/E6uhbsmS/ZwDlyPLg2zPh2y8aH3/AVz8ATlaHaPzFuY/S+PIKeU77Ive+wh93iM7sIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SiKumzh7dWrjO68XTJb498MvZ64OxsRyv0cfmXcfqyfH58uG663g/79t/38jn+a8s8Rr/8Bbe9+nSlmCsFjnX/OXANI2P5+ZovOThvp+s7KBtT1WrNH5iZZzGj8/zrbLLrM4eea02K3pmN7O9ZvaCmR0zs1fN7BuNy7eZ2XNmdrzxnWeLiHTVRv6MrwL4lrt/DMAnAHzNzPYDeBjA8+5+E4DnG7+LSI+KJru7z7j7y42fFwAcA7AHwP0Anmxc7UkAD3SqkyLSug/1AZ2ZXQ/g4wD+AGCXu88Aq28IAHYG2hwwsykzm6otLrXWWxFp2oaT3cyGAPwCwDfdfX6j7dz9oLtPuvtkdogssiciHbWhZDezPFYT/afu/svGxWfMbKIRnwAw25kuikg7REtvZmYAHgdwzN2/tyb0DICHADzW+P7rjvRwg6zGyxXL5wZp/LeVfTR+ZGQiGNtZ5FM168771k+2FgaAwRyfjnnHyJvB2MeKfIpr335+34fOX0vjs/N8qudrF3YFY6VRXtabKPAtmVsxV2vtr8ypueto/OS5bTReLZHUiy0l3WRlbiN19rsAfBnAETM73LjsEawm+c/N7CsA3gTwxea6ICKbIZrs7v57hN9LPtPe7ohIp2i4rEgilOwiiVCyiyRCyS6SCCW7SCKurCmuhFV58TG7xB9qbprXfM+T0X+zA9tp29hbqg/wpYO37OB1/In+8IDGz4/wbZHHcnww5OGL19D45QsDPO7h+MU5XuuOjT8YK/DjMpQLTw1+e2WEtp25PErj03M8XqvxJ51uMV7hS0V7ITKfO0BndpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSoWQXScRVU2ev9/HlmvPzvHZZnOa1S7IrMmLvmeXIdtEr23jfLlV4TfjIyO5grD+yXXSf8SWTK7XI9sCRdQQK58PtM6f52IbfLd5M49ki7ztTj4zLQGQNAmT468Ujx4XKN1dHj9GZXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEnHV1Nlj68Z75G0tf5nXNgfOhuvVlWFei64MRQ4zHyIAW+Gdn5kP1+EPGV/3fVdkS+edgzy+dA2vlV8shteVz8wWaNv8BX7cMjP8vhnP8ue7soU/KZ6L1MIjt0/HJ3Rmx2ad2UVSoWQXSYSSXSQRSnaRRCjZRRKhZBdJhJJdJBEb2Z99L4CfABjHakX4oLv/wMweBfD3AM42rvqIuz/bqY5GRcqa9cgjXdjL3/eq/eGacP8lvu77yEk+77o4wwurpa28b7Xj4b3ATznfJ/yNyDbltUgpu85L5ehjz0ts2nYk7pGp9nUyL7ze5Nrr78pHBkfEtDLfvUkbGVRTBfAtd3/ZzIYBHDKz5xqx77v7P3eueyLSLhvZn30GwEzj5wUzOwZgT6c7JiLt9aH+Zzez6wF8HMAfGhd93cxeMbMnzGxroM0BM5sys6na4lJLnRWR5m042c1sCMAvAHzT3ecB/BDAPgC3Y/XM/9312rn7QXefdPfJLNkvTUQ6a0PJbmZ5rCb6T939lwDg7mfcvebudQA/AnBn57opIq2KJruZGYDHARxz9++tuXxizdW+AOBo+7snIu2ykU/j7wLwZQBHzOxw47JHADxoZrdjtUByEsBXO9LDDfLI8ru1IV4qWRrkt18aC78vZi9HpmKW+W1nwzsLr7bnq0HTKZHZFX5c+s/SMLIV3r7ax0tIleFwvMp3e0Z1MPKcDvA4La/FTnOxyljlyhuispFP43+P9R9692rqIvKhXXlvTyLSFCW7SCKU7CKJULKLJELJLpIIJbtIIq6apaRjokv/RlQL4WmsVb6jclyka1aPFH3ZDNvI1sN8K2ogsqNzdElmz5LxDbFadmxX5dhyzSy8+TNMu05ndpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSoWQXSYS5t7ik7oe5M7OzAN5Yc9EOAOc2rQMfTq/2rVf7BahvzWpn365z97H1Apua7B+4c7Mpd5/sWgeIXu1br/YLUN+atVl905/xIolQsoskotvJfrDL98/0at96tV+A+tasTelbV/9nF5HN0+0zu4hsEiW7SCK6kuxmdo+Z/a+ZnTCzh7vRhxAzO2lmR8zssJlNdbkvT5jZrJkdXXPZNjN7zsyON76vu8del/r2qJmdahy7w2Z2b5f6ttfMXjCzY2b2qpl9o3F5V48d6demHLdN/5/dzLIA/gTgswCmAbwE4EF3/59N7UiAmZ0EMOnuXR+AYWafBrAI4Cfufkvjsn8CcMHdH2u8UW5193/okb49CmCx29t4N3Yrmli7zTiABwD8Hbp47Ei//habcNy6cWa/E8AJd3/d3csAfgbg/i70o+e5+4sALrzv4vsBPNn4+Umsvlg2XaBvPcHdZ9z95cbPCwDe2Wa8q8eO9GtTdCPZ9wB4a83v0+it/d4dwG/M7JCZHeh2Z9axy91ngNUXD4CdXe7P+0W38d5M79tmvGeOXTPbn7eqG8m+3upfvVT/u8vd7wDweQBfa/y5KhuzoW28N8s624z3hGa3P29VN5J9GsDeNb9fA+B0F/qxLnc/3fg+C+BX6L2tqM+8s4Nu4/tsl/vzrl7axnu9bcbRA8eum9ufdyPZXwJwk5ndYGYFAF8C8EwX+vEBZlZsfHACMysC+Bx6byvqZwA81Pj5IQC/7mJf3qNXtvEObTOOLh+7rm9/7u6b/gXgXqx+Iv9nAN/uRh8C/boRwB8bX692u28Ansbqn3UVrP5F9BUA2wE8D+B44/u2HurbUwCOAHgFq4k10aW+fQqr/xq+AuBw4+vebh870q9NOW4aLiuSCI2gE0mEkl0kEUp2kUQo2UUSoWQXSYSSXSQRSnaRRPw/6mByfzknyV4AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "import matplotlib.pyplot as plt\n", "W_0 = model.W_0.detach().cpu().numpy()\n", @@ -168,33 +193,58 @@ "source": [ "### Q2: Based on the above image, what weights might reasonably be pruned (i.e. explicitly forced to be zero)?\n", "\n", + "** Based on the above image, It could be reasonable to prune the weights along the margins **\n", "\n", - "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array. " + "\n", + "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ - "new_mask = model.mask" + "import numpy as np\n", + "def mask(mat):\n", + " mat = np.abs(mat)\n", + " std = np.std(mat)\n", + " mean = np.mean(mat)\n", + " threshold = np.abs(mean - 2.5*std)\n", + " print(\"Mean: {}, STD: {}, Threshold: {}\\n\".format(mean,std,threshold))\n", + " mask = np.zeros_like(mat)\n", + " mask[mat>threshold] = 1\n", + " print(\"Non-Null Weight Ratio : {}\".format( (np.count_nonzero(mask==0)/(np.size(mask)))*100 ))\n", + " return mask" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 41, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean: 0.03005061112344265, STD: 0.032245710492134094, Threshold: 0.050563665106892586\n", + "\n", + "Non-Null Weight Ratio : 75.25310905612244\n" + ] + } + ], "source": [ - "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" + "W0_mask = mask(model.W_0.detach().cpu().numpy())\n", + "W0_mask = torch.Tensor(W0_mask)\n", + "W0_mask = W0_mask.to(device)\n", + "model.set_mask(W0_mask)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "model.set_mask(new_mask)" + "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" ] }, { @@ -206,9 +256,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 42, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.28309938311576843. Accuracy: 91.\n", + "Iteration: 0. Loss: 0.27462732791900635. Accuracy: 92.\n", + "Iteration: 0. Loss: 0.20802313089370728. Accuracy: 92.\n", + "Iteration: 0. Loss: 0.2611789107322693. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.26243436336517334. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.22998295724391937. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.23924028873443604. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.19653643667697906. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.1895284205675125. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.19546085596084595. Accuracy: 94.\n" + ] + } + ], "source": [ "iter = 0\n", "for epoch in range(n_epochs):\n", @@ -244,17 +311,36 @@ "source": [ "### Q4: How much accuracy did you lose by pruning the model? How much \"compression\" did you achieve (here defined as total entries in W_0 divided by number of non-zero entries)? \n", "\n", - "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?" + "**I gained 2% accuracy after pruning about 86% of the weights.** \n", + "\n", + "\n", + "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?\n", + "\n", + "**After removing about 94% of the weights the accuracy began to degrade.**" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 43, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAD4CAYAAABxC1oQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAdAklEQVR4nO3dfZBdd33f8fdnV9LKlmxZsrBRJRnLIFpMk5iiGFoXQoIflEwHwQwudtOMmDEVZPA0NKGDA1PsMcOMCAWHmVKCsBVMBjDGhqBhVBRhoJA2OJKJakt2jIQQ8lpCQlrZlmQ97MO3f9yzzt378Ltn9969T+fzmjmz95zfefjt3Xu/+/ud83tQRGBmVgQDnc6AmVm7OOCZWWE44JlZYTjgmVlhOOCZWWHMaefFBhcsiLmXLGnnJc0KZfS5EcZPn1Yz57jptxfE8ZHxXPs+9vi5bRGxtpnrtVNTAU/SWuAzwCBwb0RsTO0/95IlXPGH/6WZS5pZwsHP3dP0OY6PjPP3267Ite/gsr1Lm75gG8044EkaBD4L3AAMAzskbYmIJ1uVOTNrvwAmmOh0NmZFMyW8a4F9EbEfQNIDwDrAAc+shwXBaOSr0vaaZgLecuCZsvVh4A2VO0naAGwAmLNocROXM7N2cQmvWq0bo1X91CJiE7AJYP7yle7HZtblgmC8T7ucNhPwhoGVZesrgEPNZcfMusFEddmlLzQT8HYAqyWtAp4FbgH+Q0tyZWYdE8C4A95UETEm6XZgG6VmKZsjYk/LcmZmHeMSXg0RsRXY2qK8mFkXCGDU9/DMrAiCcJXWzAoiYLw/450DnplNVepp0Z88WoqZVRDjOZdcZ5PWSnpa0j5Jd9RI/2NJT0p6XNIjkl5RlrZe0t5sWd/sb+YSnplNUXpo0dSAKy/J2ef+H4A1EfGipD8E/gx4l6QlwJ3Amixbj2XHnphpflzCM7MpSu3wWlbCe6nPfUScByb73P/T9SK+HxEvZqs/ptSJAeAmYHtEjGRBbjvQ1FBULuGZWZWJ/CW8pZJ2lq1vyrqTTsrV577MbcD/Shy7PG/GanHAM7MpJkt4OR2LiDWJ9Fx97gEk/UdK1dffmu6xeblKa2ZTBGKcgVxLDrn63Eu6HvgI8LaIODedY6fDAc/MqkyEci05vNTnXtI8Sn3ut5TvIOl1wOcpBbujZUnbgBslLZa0GLgx2zZjrtKa2RSBOB+DrTlXnT73ku4GdkbEFuCTwELg65IADkbE2yJiRNLHKAVNgLsjYqSZ/DjgmdkUpYbHrav81epzHxEfLXt9feLYzcDmVuXFAc/MqkzjoUVPccDrA6OrztZNmziT/hNrXpOdiF5In3/wxfolhYl56QdujWpVapD1uSf780s72yLEePTn7X0HPDOrMuESnpkVQemhRX+Ghv78rcxsxlr90KKbOOCZWZXxFg0e0G0c8MxsismeFv3IAc/Mqkz4Ka2ZFUFp8AAHPOuQiVefTqbP/emCNuWkmibS93rGE23tYk66Hd6852bvSzcxN50+MDprl+56gRhtUdeybuOAZ2ZTROCGx2ZWFHLDYzMrhsAlPDMrED+0MLNCCHIP7tlzHPDMbIrSNI39GRr687cysybkn2S71zjg9YDR54eS6QNX1R8Pb+7++cljzy1PNzibf3BeMn3BoXRbutEF9b84JxcmD23Y/nCgifaHjdrZDZ5p0L7wgqYmz+pqQWt7WkhaC3yG0hDv90bExor0NwN/Dvw6cEtEPFSWNg48ka0ejIi3NZOXpgKepAPASWAcGGswXZuZ9YhWlfAkDQKfBW6gNAvZDklbIuLJst0OAu8GPljjFGci4pqWZIbWlPB+OyKOteA8ZtYFItTKEt61wL6I2A8g6QFgHfBSwIuIA1lak8NvN9afz57NbMZKDy0Gcy3AUkk7y5YNFadbDjxTtj6cbctrfnbeH0t6e3O/WfMlvAD+RlIAn4+ITZU7ZG/ABoA5ixY3eTkzm33TmtPiWINbWbXqxtO5AXpFRBySdBXwPUlPRMTPpnH8FM0GvOuyzFwGbJf0jxHxw/IdsiC4CWD+8pX9e6fXrE+UHlq07CntMLCybH0FcCh3XiIOZT/3S/oB8DpgxgGvqSptWWaOAt+kVF83sx43zkCuJYcdwGpJqyTNA24BtuQ5UNJiSUPZ66XAdZTd+5uJGQc8SQskXTT5GrgR2N1MZsys8yZ7WuRZGp4rYgy4HdgGPAU8GBF7JN0t6W0Akn5T0jBwM/B5SXuyw18D7JT0/4DvAxsrnu5OWzNV2suBb0qaPM9XIuI7zWTGahs60ujPNPM/49Cz6YHhzq06l0wffXX6wdrAM/XbAS4YTo+5Nvpc58b56+d2dnm0chKfiNgKbK3Y9tGy1zsoVXUrj/u/wK+1LCM08U3JHjP/RgvzYmZdIAJGJ/qzAYd7WpjZFKUqrQOemRWE+9KaWSG0uFlKV3HAM7MKrtKaWYF4TgsrpHkNhodqxujC3m36cWGDvgKnV6bTNd66vLRa6Smtp2k0swLwEO9mViiu0ppZIfgprZkVip/SmlkhRIgxBzwzKwpXac2sEHwPz2wWLNqXTr/w6Fgy/fC/mb2P75In020Ej/9aOiCMLmowbNaZ+lXGOWeSh7aFA56ZFYLb4ZlZofRrO7z+fBRjZjMWAWMTA7mWPCStlfS0pH2S7qiR/mZJP5E0JumdFWnrJe3NlvXN/m4u4ZlZlVZVaSUNAp8FbqA0g9kOSVsq5qY4CLwb+GDFsUuAO4E1lJ6lPJYde2Km+XEJz8ymaOUkPpRmMtwXEfsj4jzwALBuyvUiDkTE40Dlk56bgO0RMZIFue3A2mZ+N5fwzKxK5C/hLZW0s2x9UzYX9aTlwDNl68PAG3Keu9axy/NmrBYHPDOrMo2HFsciYk0ivdaJ8o4L1syxNTngWZJecyqZPvTDi5Lpg+frfz5HL0xf+/lVs9jObk/6e3N2SYN2dpek29nF/PSAdxOpwxNt9NohoqXt8IaB8tEBVwANRhOccuxbKo79QTOZ8T08M6sgxicGci057ABWS1olaR5wC7AlZ0a2ATdKWixpMXBjtm3GHPDMrEqEci2NzxNjwO2UAtVTwIMRsUfS3ZLeBiDpNyUNAzcDn5e0Jzt2BPgYpaC5A7g72zZjrtKa2RSt7ksbEVuBrRXbPlr2egel6mqtYzcDm1uVFwc8M5sqSvfx+pEDnplV6deuZQ54ZjZFZA8t+pEDnplVcZXWrIb5J9Lt0ea+WP+b88IV6blPT13Z4NwvzLzaNfLa9LEXHUgfP+fS5gatmxhZ0NTxs20aPS16SsNyq6TNko5K2l22bYmk7dkIBtuzNjJm1gciWtcspdvkqah/keoOu3cAj0TEauCRbN3M+kQLBw/oKg0DXkT8EKhs7LcOuD97fT/w9hbny8w6KCLf0mtmeg/v8og4DBARhyVdVm9HSRuADQBzFrnma9btAjHRp09pZ/23iohNEbEmItYMLujuG7VmVhI5l14z04B3RNIygOzn0dZlycw6quAPLWrZAkyOL78e+FZrsmNmXaFPi3gN7+FJ+iqlMamWZiMa3AlsBB6UdBul8ehvns1M2sxd8Mv0f+EzL09/auf+XXq8u8WPpwevOH3VxXXTIt0Mj3kj6f/HMWf2vnGN8jYxnt7hooXpdnpnznT37Z1eLL3l0TDgRcStdZLe2uK8mFkXCGBioqABz8wKJoCilvDMrHh6sY1dHg54ZlatTwNef7YuNLMm5GuSkvfBhqS1kp6WtE9SVTdUSUOSvpalPyrpymz7lZLOSNqVLX/R7G/mEp6ZVWtRCU/SIPBZ4AZKs5DtkLQlIp4s2+024EREvErSLcAngHdlaT+LiGtakxsHvL5wydP1P50Tc9PHNhrk6PLHGuyRHsEJjdVPGzybPvb8os7Vq06tTKfH0aFk+nNn081W0kd3WEC07inttcC+iNgPIOkBSn3xywPeOuCu7PVDwP+QNCtPTVylNbMalHNhqaSdZcuGihMtB54pWx/OttXcJ5vl7Hng0ixtlaR/kPS/Jb2p2d/KJTwzq5a/cH0sItYk0muV1CrPXm+fw8AVEXFc0uuBv5b02oh4IXfuKriEZ2bVWte1bBgov0GwAjhUbx9Jc4BFwEhEnIuI4wAR8RjwM+DVM/p9Mg54ZjbVZMPjPEtjO4DVklZJmgfcQqkvfrnyvvnvBL4XESHpZdlDDyRdBawG9jfzq7lKa2ZVWtXwOCLGJN0ObAMGgc0RsUfS3cDOiNgC3Af8laR9lAYbviU7/M3A3ZLGgHHgfRGR7rzdgAOemVVrYV/aiNgKbK3Y9tGy12epMQBJRDwMPNyyjOCAZ2Y1qE97WjjgZQbPpv+jjc+v/wkYGkkfe27J7H56fvX6+mlDxxv9p07nbWIwfZv33MvTwxwduTbVHi197TkvpvM+MJpMntV2fBMLx5PpA0PpdGjQQLKTenSsuzwc8MysQu4HEj3HAc/MqrmEZ2aF0aDLYK9ywDOzqTwAqJkViZ/Smllx9GnAc9cyMysMl/AyFxxN/0t74VX100YXps+9YDidfnpFOv3cinSDs7kXnq+bNng4nbm5p9P3asYuTI/rNjA286JAozaCYxemj1/4bPrO+sii2bsPNXS40Vent79artKaWTEELe1a1k0c8Mysmkt4ZlYUrtKaWXE44JlZYTjgmVkRKFylNbMi8VPa/vbCVen00cX1J1iNK9Lt5M6smJdMHzqabuumk+k/0/zd9c9//pLm/lWPLky3Tdf4zM8/lh5Kr2F3ztMvb9Ruvk+LKW3QyhKepLXAZygN8X5vRGysSB8CvgS8HjgOvCsiDmRpf0ppou5x4D9HxLZm8tKwp4WkzZKOStpdtu0uSc9K2pUtv9dMJsysy7Ro1rJsEp7PAr8LXA3cKunqit1uA05ExKuAe4BPZMdeTWl+i9cCa4H/OTmpz0zl6Vr2xexile6JiGuyZWuNdDPrRfFP9/EaLTlcC+yLiP0RcR54AFhXsc864P7s9UPAWyUp2/5ANl3jz4F92flmrGHAi4gfUppJyMyKonXz0i4HnilbH8621dwnIsaA54FLcx47Lc0MHnC7pMezKu/iejtJ2iBpp6Sd46dPN3E5M2sXTeRbgKWT3+9s2VB5qhqnrwyV9fbJc+y0zDTgfQ54JXANcBj4VL0dI2JTRKyJiDWDCxrcpTazXnNs8vudLZsq0oeBlWXrK4BD9faRNAdYRKlWmefYaZlRwIuIIxExHhETwBdosl5tZl2mdVXaHcBqSaskzaP0EGJLxT5bgPXZ63cC34uIyLbfImlI0ipgNfD3TfxWM2uWImlZRBzOVt8B7E7tb2Y9pIUNjyNiTNLtwDZKzVI2R8QeSXcDOyNiC3Af8FeS9lEq2d2SHbtH0oPAk8AY8P6IaDT/ZVLDgCfpq8BbKNXVh4E7gbdIuoZSjD8AvLeZTHSDgfrN7AAYOpJ4q1JplP7KTRlIf/pOXVX/MzBvpLkxXhu1hTu3eObnP3dpejy7OQ3G6js3NONLN21i9YvJ9IULzibTX9y1pJXZab0WtsPLWnFsrdj20bLXZ4Gb6xz7ceDjrcpLw4AXEbfW2HxfqzJgZl2oT9tsu6eFmU0hXnoC23cc8MxsKg8eYGaF4oBnZoXhgGdmReEqrc3Ym256PJm+e+TlyfSL551Lpr9x6c/rpn3tW29OHjt2QTKZk1ekm52cW5z+ZgzWn0GyYbOTRs5dlm6SpbGZn39wSfo9b9Re5+Sp9BvbdFOl2eaAZ2aFEH5Ka2ZF4hKemRWF7+GZWXE44JlZIeQfCaXnOOCZ2RTCVVozKxAHvD531Zt+kUy/68pv1U27dmhu8tgfn023F3vsoiuT6dddsC+Zfjbq/xm/Rrod3pwzyWR4w/PJ5MHdFzc4QX3nlqXH5Bp8Pt1abf6l6cyfPTG/fuLcdLuLKy5LT+MydyD9N91/9NJketdzwDOzwnDAM7NC8GgpZlYofRrwmhv/28z60jSmaZz5NaQlkrZL2pv9rDndq6T12T57Ja0v2/4DSU9L2pUtlzW6pgOemVVR5FuadAfwSESsBh7J1qfmQ1pCaR6dN1CaHfHOisD4+xFxTbYcbXRBBzwzmyrvFI3NB7x1wP3Z6/uBt9fY5yZge0SMRMQJYDuwdqYXdMAzs2r5A95SSTvLlg3TuMrlk9O9Zj9rVUmXA8+UrQ9n2yb9ZVad/W+SGo4H5ocWmRdH5yXTv36i/lzjDzco288fGE2mj0f6/872iauT6SvnptuMNWOsiXZ2AANXn6yfNrwwfXCDj+/ZYw0G80u0tbtocXqaxVdd/Ktk+oGT6XZ2o8+n55Ds4AyTDU2zp8WxiFhT91zSd4FaAz5+ZBrZqTSZu9+PiGclXQQ8DPwB8KXUyRzwzKyKJlrzmDYirq97DemIpGURcVjSMqDWPbhhSvNiT1oB/CA797PZz5OSvkLpHl8y4LlKa2ZTte8e3hZg8qnreqBWd6ZtwI2SFmcPK24EtkmaI2kpgKS5wL8Ddje6oAOemVVp01PajcANkvYCN2TrSFoj6V6AiBgBPgbsyJa7s21DlALf48Au4FngC40u6CqtmVVrQ8PjiDgOvLXG9p3Ae8rWNwObK/Y5Dbx+utd0wDOzKu5aZmbF4YBnZoXgWcv636Hji5LpDw3Xv12gOQ0+Hc+nx8uLC9Jjqw1emB437uKF9ceFW/avDyWPPfx3/yyZ3qwzxy6sm3bp7gZzu74ife6JOelnbhdcdrpu2mCDOtvTz12eTD9+uv7vBTB0pHe/Wv084nHDp7SSVkr6vqSnJO2R9EfZ9lwdf82sB0XkW3pMnmYpY8CfRMRrgDcC75d0NTk6/ppZb2pTs5S2axjwIuJwRPwke30SeIpSX7Y8HX/NrNe0r+Fx203rRoOkK4HXAY9S0fG33lhUWWfiDQBzFrnWa9YL+vWhRe6eFpIWUuqg+4GIeCHvcRGxKSLWRMSawQULZpJHM2uzdgwA2gm5Al7WV+1h4MsR8Y1s85Gswy+Jjr9m1muCvn1o0bBKm40xdR/wVER8uixpsuPvRup3/O0ZA3sbNDOYzYufSP/fGTqRHrrq+VfUn47wxKL00FSvvu6ZZPov/s/KZHojQ0frT7V46ooGBzf6dzyY/sKdP1e/OdDoaHoKyKUL6zdpARh9It2Mqdf14gOJPPLcw7uO0jhTT0jalW37MKVA96Ck24CDwM2zk0Uza7uiBryI+FvqD8VY1fHXzHpbPzc87t3m4GY2OyJaNgBot3HAM7Nq/RnvHPDMrJqrtGZWDAH0aZXWQ7ybWbU2dC3LOwCJpO9Iek7Styu2r5L0aHb81ySl22/hEl5POLe4wScr0eJdA737n1rpUbPg4nQbw9cs/2XdtMvn158+EuBH2369wcX7W5uqtJMDkGyUdEe2/qEa+30SuBB4b8X2TwD3RMQDkv4CuA34XOqCLuGZWRVNRK6lSbkGIImIR4Ap/6GyDhG/AzzU6PhyLuGZ2VTTq64ulbSzbH1TRGzKeWyuAUjquBR4LiImR8cdpjSKU5IDnplNUWp4nDviHYuINXXPJX0XeHmNpI/MIGtTTl1jW8NMO+CZWbUWjYQSEdfXS5N0RNKyrHQ33QFIjgGXSJqTlfJWAOn5DPA9PDOrQRG5liZNDkAC0xyAJCIC+D7wzukc74BnZlO1b8TjjcANkvYCN2TrSFoj6d7JnST9CPg68FZJw5JuypI+BPyxpH2U7und1+iCrtKaWYX29KWNiOPUGIAkInYC7ylbf1Od4/cD107nmg54mVUPnUimTzz+j3XTBi5Mj6X38zt+Y0Z5ymvuqcR0hyfrj5UHsDfS0xE2bMnZQfMOpEcp/OmBVfXTWp2ZftODg3vm4YBnZlN5Im4zKxSX8MysMPoz3jngmVk1TfRnndYBz8ymClrW8LjbOOCZ2RSiJY2Ku5IDnplVc8Drb+ML02265vzzV9VNG1u6sNXZaZ0GA5utWDaSTD96sFa/b+t7DnhmVgi+h2dmReKntGZWEOEqrZkVROCAZ2YF0p81Wgc8M6vmdnhmVhxFDXiSVgJfojQRxwSlWYk+I+ku4D8Bv8p2/XBEbJ2tjM62g2vTY9rN/9WCumlzzqY/HJfuTqePzU+MZwecWp5OH79g5h/Oo4+6nZ1ViIDx/qzT5hnifQz4k4h4DfBG4P2Srs7S7omIa7KlZ4OdmVWIyLc0QdISSdsl7c1+Lq6z33ckPSfp2xXbvyjp55J2Zcs1ja7ZMOBFxOGI+En2+iTwFDnmfzSzHtaGgAfcATwSEauBR7L1Wj4J/EGdtP9aVuja1eiC05rER9KVwOuAR7NNt0t6XNLmRHTeIGmnpJ3jp09P53Jm1gkBTES+pTnrgPuz1/cDb6+ZnYhHgJPNXgymEfAkLQQeBj4QES8AnwNeCVwDHAY+Veu4iNgUEWsiYs3ggvr3wcysWwTERL4Flk4WaLJlwzQudHlEHIZSTRK4bAaZ/XhW6LpHUrpDPDmf0kqaSynYfTkivpFl8EhZ+heAb9c53Mx6STCdhxbHImJNvURJ36X0wLPSR2aQs0p/CvyS0lxTmyhN23h36oA8T2lFab7HpyLi02Xbl01GZ+AdwO4ZZtrMuk2LmqVExPX10iQdmYwjkpYBR6d57sn4c07SXwIfbHRMnhLedZRuGD4hafKm4IeBW7OnIgEcAN47ncz2mrMvm/kH4NTKdLOSxmavTdREg0+AxtPpg+fTv9vE3P5sz9X32tMObwuwntIE3OuBb03n4LJgKUr3/xoWuhoGvIj4W6DWp9rNUMz6UtsGD9gIPCjpNuAgcDOApDXA+yLiPdn6j4B/ASyUNAzcFhHbgC9Lehml+LQLeF+jC7qnhZlNFUAbhoeKiOPAW2ts3wm8p2z9TXWO/53pXtMBz8yqFbVrmZkVTf92LXPAM7OpAiIc8MysKJrvRdGVHPDMrJrv4Vk/Ghhr7ni3s+tDEW15StsJDnhmVs0lPDMrhiDGG3Sx6VEOeGY21eTwUH3IAc/MqrlZipkVQQDhEp6ZFUKES3hmVhz9+tBC0cbHz5J+BfyibNNS4FjbMjA93Zq3bs0XOG8z1cq8vSIiXtbMCSR9h1Ke8jgWEWubuV47tTXgVV1c2pkaHrqTujVv3ZovcN5mqpvz1m+mNWuZmVkvc8Azs8LodMDb1OHrp3Rr3ro1X+C8zVQ3562vdPQenplZO3W6hGdm1jYOeGZWGB0JeJLWSnpa0j5Jd3QiD/VIOiDpCUm7JO3scF42SzoqaXfZtiWStkvam/1c3EV5u0vSs9l7t0vS73UobyslfV/SU5L2SPqjbHtH37tEvrrifSuCtt/DkzQI/BS4ARgGdgC3RsSTbc1IHZIOAGsiouONVCW9GTgFfCki/mW27c+AkYjYmP2zWBwRH+qSvN0FnIqI/97u/FTkbRmwLCJ+Iuki4DFKEzW/mw6+d4l8/Xu64H0rgk6U8K4F9kXE/og4DzwArOtAPrpeRPwQGKnYvA64P3t9P6UvTNvVyVtXiIjDEfGT7PVJ4ClgOR1+7xL5sjbpRMBbDjxTtj5Md/3RA/gbSY9J2tDpzNRweUQchtIXCLisw/mpdLukx7Mqb0eq2+UkXQm8DniULnrvKvIFXfa+9atOBDzV2NZNbWOui4h/Bfwu8P6s6mb5fA54JXANcBj4VCczI2kh8DDwgYh4oZN5KVcjX131vvWzTgS8YWBl2foK4FAH8lFTRBzKfh4FvkmpCt5NjmT3gibvCR3tcH5eEhFHImI8SpOafoEOvneS5lIKKl+OiG9kmzv+3tXKVze9b/2uEwFvB7Ba0ipJ84BbgC0dyEcVSQuym8lIWgDcCOxOH9V2W4D12ev1wLc6mJcpJoNJ5h106L2TJOA+4KmI+HRZUkffu3r56pb3rQg60tMie+z+58AgsDkiPt72TNQg6SpKpToojRX4lU7mTdJXgbdQGqrnCHAn8NfAg8AVwEHg5oho+8ODOnl7C6VqWQAHgPdO3jNrc97+LfAj4AlgciTLD1O6X9ax9y6Rr1vpgvetCNy1zMwKwz0tzKwwHPDMrDAc8MysMBzwzKwwHPDMrDAc8MysMBzwzKww/j/8/keh/P7k4gAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "W_0 = model.W_0.detach().cpu().numpy()\n", "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.colorbar()\n", "plt.show()" ] }, @@ -269,7 +355,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -303,10 +389,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ + "from sklearn.cluster import KMeans\n", "# convert weight and mask matrices into numpy arrays\n", "W_0 = model.W_0.detach().cpu().numpy()\n", "mask = model.mask.detach().cpu().numpy()\n", @@ -314,13 +401,34 @@ "# Figure out the indices of non-zero entries \n", "inds = np.where(mask!=0)\n", "# Figure out the values of non-zero entries\n", - "vals = W_0[inds]\n", - "\n", - "### TODO: perform clustering on vals\n", - "\n", - "### TODO: turn the label matrix and centroids into a torch tensor\n", - "labels = torch.tensor(...,dtype=torch.long,device=device)\n", - "centroids = torch.tensor(...,device=device)" + "vals = np.expand_dims(W_0[inds],1)\n", + "num_cluster = 2\n", + "km = KMeans(n_clusters=num_cluster).fit(vals)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([784, 64])" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "labels = np.zeros_like(W_0)\n", + "labels[inds] = km.labels_\n", + "labels = torch.tensor(labels,dtype=torch.long,device=device).squeeze()\n", + "centers = km.cluster_centers_\n", + "centroids = torch.tensor(centers,device=device).squeeze()\n", + "labels.shape" ] }, { @@ -332,12 +440,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,model.mask,labels,centroids)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ - "# Instantiate quantized model\n", - "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,new_mask,labels,centroids)\n", "model_q = model_q.to(device)\n", "\n", "# Copy pre-trained weights from unquantized model for non-quantized layers\n", @@ -355,9 +470,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 49, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.2916867434978485. Accuracy: 92.\n", + "Iteration: 0. Loss: 0.205495223402977. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.20868274569511414. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.261547327041626. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.23044338822364807. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.22543968260288239. Accuracy: 93.\n", + "Iteration: 0. Loss: 0.23166033625602722. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.24529825150966644. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.22417983412742615. Accuracy: 94.\n", + "Iteration: 0. Loss: 0.21196895837783813. Accuracy: 94.\n" + ] + } + ], "source": [ "optimizer = torch.optim.Adam(model_q.parameters(), lr=lr_rate, weight_decay=1e-3)\n", "iter = 0\n", @@ -397,9 +529,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 32, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAALMElEQVR4nO3dT4ic9R3H8c+nq/awetjUJmxjqFZzqBQayxKElGKRSswlerCYg6QgrAUFRQ8Ve9BjKDW2h2JZazAtVhFUzCG0hiAED4qrpPlj2iZKqmuWbCUF415s1m8P+6SscWZnnOd55nk23/cLlpl5Znbny+g7z+z8ZvZxRAjAxe9rTQ8AYDiIHUiC2IEkiB1IgtiBJC4Z5p2NXD4al4ytGuZdAqmc+88ZLXw6707XlYrd9mZJv5U0IukPEbFjudtfMrZK33rogTJ3CWAZpx7/TdfrBn4ab3tE0u8k3SrpeknbbF8/6M8DUK8yv7NvlHQiIt6PiM8kPS9pazVjAahamdjXSvpwyeWZYtsX2J60PW17emF+vsTdASijTOydXgT40ntvI2IqIiYiYmJkdLTE3QEoo0zsM5LWLbl8laRT5cYBUJcysb8lab3ta2xfJulOSXuqGQtA1QZeeouIc7bvk/RXLS697YqIo5VNBqBSpdbZI2KvpL0VzQKgRrxdFkiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSRKHbLZ9klJZyUtSDoXERNVDAWgeqViL/w4Ij6u4OcAqBFP44EkysYekl61/bbtyU43sD1pe9r29ML8fMm7AzCosk/jN0XEKdurJe2z/feIOLD0BhExJWlKkr6+bl2UvD8AAyq1Z4+IU8XpnKSXJW2sYigA1Rs4dtujtq84f17SLZKOVDUYgGqVeRq/RtLLts//nD9HxF8qmQpA5QaOPSLel/T9CmcBUCOW3oAkiB1IgtiBJIgdSILYgSSq+CAMLmLXPfjGstef2HnjkCZBWezZgSSIHUiC2IEkiB1IgtiBJIgdSILYgSRYZ8eyWEe/eLBnB5IgdiAJYgeSIHYgCWIHkiB2IAliB5IgdiAJYgeSIHYgCWIHkiB2IAliB5IgdiAJYgeS4PPsFXjvzt8ve/21z/+81PfXqddsuHj03LPb3mV7zvaRJdtW2d5n+3hxOlbvmADK6udp/DOSNl+w7WFJ+yNivaT9xWUALdYz9og4IOnMBZu3StpdnN8t6baK5wJQsUFfoFsTEbOSVJyu7nZD25O2p21PL8zPD3h3AMqq/dX4iJiKiImImBgZHa377gB0MWjsp22PS1JxOlfdSADqMGjseyRtL85vl/RKNeMAqEvPdXbbz0m6SdKVtmckPSpph6QXbN8t6QNJd9Q5ZHZNroWv5PcA1Dn7Snx/Qs/YI2Jbl6turngWADXi7bJAEsQOJEHsQBLEDiRB7EASfMS1AmWXYVbiMk4bNLksWPZjzU1gzw4kQexAEsQOJEHsQBLEDiRB7EASxA4kwTo7Vqwyf6K7jevgdWPPDiRB7EASxA4kQexAEsQOJEHsQBLEDiTBOntyK/lPRTf981ca9uxAEsQOJEHsQBLEDiRB7EASxA4kQexAEqyzJ1fmM+FV/HwMT889u+1dtudsH1my7THbH9k+WHxtqXdMAGX18zT+GUmbO2x/IiI2FF97qx0LQNV6xh4RBySdGcIsAGpU5gW6+2wfKp7mj3W7ke1J29O2pxfm50vcHYAyBo39SUnXStogaVbS491uGBFTETERERMjo6MD3h2AsgaKPSJOR8RCRHwu6SlJG6sdC0DVBord9viSi7dLOtLttgDaoec6u+3nJN0k6UrbM5IelXST7Q2SQtJJSffUOGPr1f2ZcNaqUYWesUfEtg6bn65hFgA14u2yQBLEDiRB7EASxA4kQexAEnzEdQVo8s894+LBnh1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgnX2CtT955jbjI/frhzs2YEkiB1IgtiBJIgdSILYgSSIHUiC2IEkWGcfgot5Lfq6B99Y9voTO28c0iTohT07kASxA0kQO5AEsQNJEDuQBLEDSRA7kATr7IUy68W9vheDYQ2/Wj337LbX2X7N9jHbR23fX2xfZXuf7ePF6Vj94wIYVD9P489JeigivivpRkn32r5e0sOS9kfEekn7i8sAWqpn7BExGxHvFOfPSjomaa2krZJ2FzfbLem2uoYEUN5XeoHO9tWSbpD0pqQ1ETErLf6DIGl1l++ZtD1te3phfr7ctAAG1nfsti+X9KKkByLik36/LyKmImIiIiZGRkcHmRFABfqK3falWgz92Yh4qdh82vZ4cf24pLl6RgRQhZ5Lb7Yt6WlJxyJi55Kr9kjaLmlHcfpKLRMOCcs4g6nzceO/SbX6WWffJOkuSYdtHyy2PaLFyF+wfbekDyTdUc+IAKrQM/aIeF2Su1x9c7XjAKgLb5cFkiB2IAliB5IgdiAJYgeS4COuFei1Htzmj8Cylp0He3YgCWIHkiB2IAliB5IgdiAJYgeSIHYgCdbZh6DJdXjW0XEee3YgCWIHkiB2IAliB5IgdiAJYgeSIHYgCdbZW4C1cAwDe3YgCWIHkiB2IAliB5IgdiAJYgeSIHYgiZ6x215n+zXbx2wftX1/sf0x2x/ZPlh8bal/XACD6udNNeckPRQR79i+QtLbtvcV1z0REb+ubzwAVenn+OyzkmaL82dtH5O0tu7BAFTrK/3ObvtqSTdIerPYdJ/tQ7Z32R7r8j2TtqdtTy/Mz5caFsDg+o7d9uWSXpT0QER8IulJSddK2qDFPf/jnb4vIqYiYiIiJkZGRysYGcAg+ord9qVaDP3ZiHhJkiLidEQsRMTnkp6StLG+MQGU1c+r8Zb0tKRjEbFzyfbxJTe7XdKR6scDUJV+Xo3fJOkuSYdtHyy2PSJpm+0NkkLSSUn31DIhgEr082r865Lc4aq91Y8DoC68gw5IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJIgdSILYgSSIHUiC2IEkiB1IgtiBJBwRw7sz+9+S/rVk05WSPh7aAF9NW2dr61wSsw2qytm+HRHf7HTFUGP/0p3b0xEx0dgAy2jrbG2dS2K2QQ1rNp7GA0kQO5BE07FPNXz/y2nrbG2dS2K2QQ1ltkZ/ZwcwPE3v2QEMCbEDSTQSu+3Ntv9h+4Tth5uYoRvbJ20fLg5DPd3wLLtsz9k+smTbKtv7bB8vTjseY6+h2VpxGO9lDjPe6GPX9OHPh/47u+0RSf+U9BNJM5LekrQtIt4d6iBd2D4paSIiGn8Dhu0fSfpU0h8j4nvFtl9JOhMRO4p/KMci4hctme0xSZ82fRjv4mhF40sPMy7pNkk/U4OP3TJz/VRDeNya2LNvlHQiIt6PiM8kPS9pawNztF5EHJB05oLNWyXtLs7v1uL/LEPXZbZWiIjZiHinOH9W0vnDjDf62C0z11A0EftaSR8uuTyjdh3vPSS9avtt25NND9PBmoiYlRb/55G0uuF5LtTzMN7DdMFhxlvz2A1y+POymoi906Gk2rT+tykifiDpVkn3Fk9X0Z++DuM9LB0OM94Kgx7+vKwmYp+RtG7J5asknWpgjo4i4lRxOifpZbXvUNSnzx9Btzida3ie/2vTYbw7HWZcLXjsmjz8eROxvyVpve1rbF8m6U5JexqY40tsjxYvnMj2qKRb1L5DUe+RtL04v13SKw3O8gVtOYx3t8OMq+HHrvHDn0fE0L8kbdHiK/LvSfplEzN0mes7kv5WfB1tejZJz2nxad1/tfiM6G5J35C0X9Lx4nRVi2b7k6TDkg5pMazxhmb7oRZ/NTwk6WDxtaXpx26ZuYbyuPF2WSAJ3kEHJEHsQBLEDiRB7EASxA4kQexAEsQOJPE/ppWOB0CAZMMAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "W_0 = (model_q.mask*model_q.centroids[model_q.labels]).detach().cpu().numpy()\n", "plt.imshow(W_0[:,1].reshape((28,28)))\n", @@ -414,7 +559,12 @@ "\n", "### Q7: How low can you go? How small can the centroid codebook be before we see a substantial degradation in test set accuracy?\n", "\n", - "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? " + "I got down to two centroids before things got wierd. \n", + "\n", + "\n", + "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? \n", + "\n", + "Somewhat surprisingly, the accuracy difference the between the compressed form of the model that was trained on one epoch and the compressed form of the model that was trained on 10 epochs is only about 3%." ] }, { diff --git a/mnist_pretrained.h5 b/mnist_pretrained.h5 new file mode 100644 index 0000000..3f52777 Binary files /dev/null and b/mnist_pretrained.h5 differ diff --git a/mnist_pruned.h5 b/mnist_pruned.h5 new file mode 100644 index 0000000..3ea5e1c Binary files /dev/null and b/mnist_pruned.h5 differ diff --git a/mnist_quantized.h5 b/mnist_quantized.h5 new file mode 100644 index 0000000..428995f Binary files /dev/null and b/mnist_quantized.h5 differ