diff --git a/deep_compression_exercise.ipynb b/deep_compression_exercise.ipynb index f903c10..0c2129b 100644 --- a/deep_compression_exercise.ipynb +++ b/deep_compression_exercise.ipynb @@ -1,449 +1,732 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Exercise Week 9: Pruning and Quantization\n", - "This week, we will explore some of the ideas discussed in Han, Mao, and Dally's Deep Compression. In particular, we will implement weight pruning with fine tuning, as well as k-means weight quantization. **Note that we will unfortunately not be doing this in a way that will actually lead to substantial efficiency gains: that would involve the use of sparse matrices which are not currently well-supported in pytorch.** \n", - "\n", - "## Training an MNIST classifier\n", - "For this example, we'll work with a basic multilayer perceptron with a single hidden layer. We will train it on the MNIST dataset so that it can classify handwritten digits. As usual we load the data:" - ] + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + }, + "colab": { + "name": "deep_compression_exercise.ipynb", + "provenance": [], + "machine_shape": "hm" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torchvision.transforms as transforms\n", - "import torchvision.datasets as datasets\n", - "\n", - "device = torch.device('cuda' if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n", - "test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n", - "\n", - "batch_size = 300\n", - "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", - "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then define a model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MultilayerPerceptron(torch.nn.Module):\n", - " def __init__(self, input_dim, hidden_dim, output_dim,mask=None):\n", - " super(MultilayerPerceptron, self).__init__()\n", - " if not mask:\n", - " self.mask = torch.nn.Parameter(torch.ones(input_dim,hidden_dim),requires_grad=False)\n", - " else:\n", - " self.mask = torch.nn.Parameter(mask)\n", - "\n", - " self.W_0 = torch.nn.Parameter(1e-3*torch.randn(input_dim,hidden_dim)*self.mask,requires_grad=True)\n", - " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim),requires_grad=True)\n", - "\n", - " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim),requires_grad=True)\n", - " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim),requires_grad=True)\n", - " \n", - " def set_mask(self,mask):\n", - " \n", - " self.mask.data = mask.data\n", - " self.W_0.data = self.mask.data*self.W_0.data\n", - "\n", - " def forward(self, x):\n", - " hidden = torch.tanh(x@(self.W_0*self.mask) + self.b_0)\n", - " outputs = hidden@self.W_1 + self.b_1\n", - " return 
outputs\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that the above code is a little bit different than a standard multilayer perceptron implementation.\n", - "\n", - "### Q1: What does this model have the capability of doing that a \"Vanilla\" MLP does not. Why might we want this functionality for studying pruning?\n", - "\n", - "Let's first train this model without utilizing this extra functionality. You can set the hidden layer size to whatever you'd like when instantiating the model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n_epochs = 10\n", - "\n", - "input_dim = 784\n", - "hidden_dim = 64\n", - "output_dim = 10\n", - "\n", - "model = MultilayerPerceptron(input_dim,hidden_dim,output_dim)\n", - "model = model.to(device)\n", - "\n", - "criterion = torch.nn.CrossEntropyLoss() # computes softmax and then the cross entropy\n", - "lr_rate = 0.001\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And then training proceeds as normal." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "iter = 0\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(epoch, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_pretrained.h5')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Pruning\n", - "\n", - "Certainly not a state of the art model, but also not a terrible one. Because we're hoping to do some weight pruning, let's inspect some of the weights directly (recall that we can act like they're images)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "W_0 = model.W_0.detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Q2: Based on the above image, what weights might reasonably be pruned (i.e. explicitly forced to be zero)?\n", - "\n", - "\n", - "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "new_mask = model.mask" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model.set_mask(new_mask)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we have explicitly set some entries in one of the the weight matrices to zero, and ensured via the mask, that they will not be updated by gradient descent. Fine tune the model: " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "iter = 0\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_pruned.h5')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Q4: How much accuracy did you lose by pruning the model? How much \"compression\" did you achieve (here defined as total entries in W_0 divided by number of non-zero entries)? \n", - "\n", - "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "W_0 = model.W_0.detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Quantization\n", - "\n", - "Now that we have a pruned model that appears to be performing well, let's see if we can make it even smaller by quantization. To do this, we'll need a slightly different neural network, one that corresponds to Figure 3 from the paper. Instead of having a matrix of float values, we'll have a matrix of integer labels (here called \"labels\") that correspond to entries in a (hopefully) small codebook of centroids (here called \"centroids\"). The way that I've coded it, there's still a mask that enforces our desired sparsity pattern." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MultilayerPerceptronQuantized(torch.nn.Module):\n", - " def __init__(self, input_dim, output_dim, hidden_dim,mask,labels,centroids):\n", - " super(MultilayerPerceptronQuantized, self).__init__()\n", - " self.mask = torch.nn.Parameter(mask,requires_grad=False)\n", - " self.labels = torch.nn.Parameter(labels,requires_grad=False)\n", - " self.centroids = torch.nn.Parameter(centroids,requires_grad=True)\n", - "\n", - " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim))\n", - "\n", - " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim))\n", - " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim))\n", - "\n", - " def forward(self, x):\n", - " W_0 = self.mask*self.centroids[self.labels]\n", - " hidden = torch.tanh(x@W_0 + self.b_0)\n", - " outputs = hidden@self.W_1 + self.b_1\n", - " return outputs" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Notice what is happening in the forward method: W_0 is being reconstructed by using a matrix (self.labels) to index into a vector (self.centroids). The beauty of automatic differentiation allows backpropogation through this sort of weird indexing operation, and thus gives us gradients of the objective function with respect to the centroid values!\n", - "\n", - "### Q6: However, before we are able to use this AD magic, we need to specify the static label matrix (and an initial guess for centroids). Use the k-means algorithm (or something else if you prefer) figure out the label matrix and centroid vectors. PROTIP1: I used scikit-learns implementation of k-means. PROTIP2: only cluster the non-zero entries" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# convert weight and mask matrices into numpy arrays\n", - "W_0 = model.W_0.detach().cpu().numpy()\n", - "mask = model.mask.detach().cpu().numpy()\n", - "\n", - "# Figure out the indices of non-zero entries \n", - "inds = np.where(mask!=0)\n", - "# Figure out the values of non-zero entries\n", - "vals = W_0[inds]\n", - "\n", - "### TODO: perform clustering on vals\n", - "\n", - "### TODO: turn the label matrix and centroids into a torch tensor\n", - "labels = torch.tensor(...,dtype=torch.long,device=device)\n", - "centroids = torch.tensor(...,device=device)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can instantiate our quantized model and import the appropriate pre-trained weights for the other network layers. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Instantiate quantized model\n", - "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,new_mask,labels,centroids)\n", - "model_q = model_q.to(device)\n", - "\n", - "# Copy pre-trained weights from unquantized model for non-quantized layers\n", - "model_q.b_0.data = model.b_0.data\n", - "model_q.W_1.data = model.W_1.data\n", - "model_q.b_1.data = model.b_1.data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, we can fine tune the quantized model. We'll adjust not only the centroids, but also the weights in the other layers." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.Adam(model_q.parameters(), lr=lr_rate, weight_decay=1e-3)\n", - "iter = 0\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model_q(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model_q(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_quantized.h5')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "After retraining, we can, just for fun, reconstruct the pruned and quantized weights and plot them as images:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "W_0 = (model_q.mask*model_q.centroids[model_q.labels]).detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Certainly a much more parsimonious representation. The obvious question now becomes:\n", - "\n", - "### Q7: How low can you go? How small can the centroid codebook be before we see a substantial degradation in test set accuracy?\n", - "\n", - "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "7T8MKDWtAMw_", + "colab_type": "text" + }, + "source": [ + "# Exercise Week 9: Pruning and Quantization\n", + "This week, we will explore some of the ideas discussed in Han, Mao, and Dally's Deep Compression. In particular, we will implement weight pruning with fine tuning, as well as k-means weight quantization. **Note that we will unfortunately not be doing this in a way that will actually lead to substantial efficiency gains: that would involve the use of sparse matrices which are not currently well-supported in pytorch.** \n", + "\n", + "## Training an MNIST classifier\n", + "For this example, we'll work with a basic multilayer perceptron with a single hidden layer. 
We will train it on the MNIST dataset so that it can classify handwritten digits. As usual we load the data:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "uA2lbAk_AMxD", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import torch\n", + "import torchvision.transforms as transforms\n", + "import torchvision.datasets as datasets\n", + "import numpy as np\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n", + "test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n", + "\n", + "batch_size = 300\n", + "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", + "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jh79ItoQAMxg", + "colab_type": "text" + }, + "source": [ + "Then define a model:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3KrAiyIPAMxi", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MultilayerPerceptron(torch.nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim,mask=None):\n", + " super(MultilayerPerceptron, self).__init__()\n", + " if not mask:\n", + " self.mask = torch.nn.Parameter(torch.ones(input_dim,hidden_dim),requires_grad=False)\n", + " else:\n", + " self.mask = torch.nn.Parameter(mask)\n", + "\n", + " self.W_0 = torch.nn.Parameter(1e-3*torch.randn(input_dim,hidden_dim)*self.mask,requires_grad=True)\n", + " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim),requires_grad=True)\n", + "\n", + " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim),requires_grad=True)\n", + " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim),requires_grad=True)\n", + " \n", + " def set_mask(self,mask):\n", + " \n", + " self.mask.data = mask.data\n", + " self.W_0.data = self.mask.data*self.W_0.data\n", + "\n", + " def forward(self, x):\n", + " hidden = torch.tanh(x@(self.W_0*self.mask) + self.b_0)\n", + " outputs = hidden@self.W_1 + self.b_1\n", + " return outputs\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ec6tMW9jAMx6", + "colab_type": "text" + }, + "source": [ + "Note that the above code is a little bit different than a standard multilayer perceptron implementation.\n", + "\n", + "### Q1: What does this model have the capability of doing that a \"Vanilla\" MLP does not. Why might we want this functionality for studying pruning?\n", + "\n", + "This model has a mask. The mask prevents the model from updating weights that were pruned out
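, because W_0 is multiplied by the mask in the forward pass, so masked entries receive zero gradient and stay at zero during fine-tuning.\n",
+        "\n",
+        "A tiny sketch of that behaviour (hypothetical toy dimensions, using the class defined above):\n",
+        "\n",
+        "```python\n",
+        "m = MultilayerPerceptron(input_dim=4, hidden_dim=3, output_dim=2)\n",
+        "m.set_mask(torch.tensor([[1., 0., 1.], [0., 1., 1.], [1., 1., 0.], [1., 1., 1.]]))\n",
+        "m(torch.randn(5, 4)).sum().backward()\n",
+        "print(m.W_0.grad * (1 - m.mask))  # all zeros: pruned entries get no gradient\n",
+        "```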
\n", + "Let's first train this model without utilizing this extra functionality. You can set the hidden layer size to whatever you'd like when instantiating the model:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "W0QRXfI_AMyA", + "colab_type": "code", + "colab": {} + }, + "source": [ + "n_epochs = 10\n", + "\n", + "input_dim = 784\n", + "hidden_dim = 64\n", + "output_dim = 10\n", + "\n", + "model = MultilayerPerceptron(input_dim,hidden_dim,output_dim)\n", + "model = model.to(device)\n", + "\n", + "criterion = torch.nn.CrossEntropyLoss() # computes softmax and then the cross entropy\n", + "lr_rate = 0.001\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-3)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UQlnDzOGAMyl", + "colab_type": "text" + }, + "source": [ + "And then training proceeds as normal." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "a-BNPj6TAMys", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 442 + }, + "outputId": "b3656eac-fd7c-49e1-a363-42b06c9f6337" + }, + "source": [ + "iter = 0\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(epoch, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_pretrained.h5')\n" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.4250503182411194. Accuracy: 90.\n", + "Iteration: 1. Loss: 0.3562028110027313. Accuracy: 92.\n", + "Iteration: 2. Loss: 0.2239571362733841. Accuracy: 93.\n", + "Iteration: 3. Loss: 0.18697282671928406. 
Accuracy: 94.\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "RYdDk972AMzH",
+        "colab_type": "text"
+      },
+      "source": [
+        "## Pruning\n",
+        "\n",
+        "Certainly not a state of the art model, but also not a terrible one. Because we're hoping to do some weight pruning, let's inspect some of the weights directly (recall that we can act like they're images)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "wsotDexkAMzQ",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 265
+        },
+        "outputId": "01957ebd-8834-46fb-f6bf-fd963d20540e"
+      },
+      "source": [
+        "import matplotlib.pyplot as plt\n",
+        "W_0 = model.W_0.detach().cpu().numpy()\n",
+        "plt.imshow(W_0[:,1].reshape((28,28)))\n",
+        "plt.show()"
+      ],
+      "execution_count": 5,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "image/png": "(base64-encoded PNG of the plotted weight column omitted)",
+            "text/plain": [
+              ""
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XHDUX74dAMzy", + "colab_type": "text" + }, + "source": [ + "### Q2: Based on the above image, what weights might reasonably be pruned (i.e. explicitly forced to be zero)?\n", + "all of the ones that are on the outside of the circle
\n", + "\n", + "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array. " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WZ40V0vIAMz5", + "colab_type": "code", + "colab": {} + }, + "source": [ + "weight_threshold = .01\n", + "new_mask = model.mask\n", + "new_mask = torch.tensor(np.abs(W_0) >= weight_threshold)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KE_hID9yAM0e", + "colab_type": "text" + }, + "source": [ + "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "MSFbM3B6AM0m", + "colab_type": "code", + "colab": {} + }, + "source": [ + "model.set_mask(new_mask)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZU7ozjKyAM1N", + "colab_type": "text" + }, + "source": [ + "Now, we have explicitly set some entries in one of the the weight matrices to zero, and ensured via the mask, that they will not be updated by gradient descent. Fine tune the model: " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "p0o0NafEAM1W", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 408 + }, + "outputId": "2f29e4ee-2148-485c-fac6-215ac33f9f32" + }, + "source": [ + "iter = 0\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " W_0 = model.W_0.data\n", + " nonzero_elements = np.count_nonzero(W_0)\n", + " elements = np.prod(W_0.shape)\n", + " compression = elements / nonzero_elements\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}. compression: {}\".format(iter, loss.item(), accuracy, compression))\n", + "torch.save(model.state_dict(),'mnist_pruned.h5')" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.1628478467464447. Accuracy: 94. compression: 1.0\n", + "Iteration: 0. Loss: 0.19111663103103638. Accuracy: 95. 
compression: 1.0\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "Yl3nqNMmAM2A",
+        "colab_type": "text"
+      },
+      "source": [
+        "### Q4: How much accuracy did you lose by pruning the model? How much \"compression\" did you achieve (here defined as total entries in W_0 divided by number of non-zero entries)? \n",
+        "\n",
+        "With a weight threshold of .001 I got the same accuracy as the unpruned model (96%), along with a compression rate of 1.01.\n",
+        "\n",
+        "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?\n",
+        "\n",
+        "weight_threshold: .01. accuracy: 96. compression: 1.16
\n", + "weight_threshold: .1. accuracy: 95. compression: 1.17
\n", + "weight_threshold: .5. accuracy: 40. compression: 1.36\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zVJaQKeRbZDG", + "colab_type": "text" + }, + "source": [ + "" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "JfwmvLd4AM2J", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + }, + "outputId": "f5ae9d85-98b6-408a-bceb-300b05d59130" + }, + "source": [ + "W_0 = model.W_0.detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAUwklEQVR4nO3dbWxk5XUH8P+ZGY/tmbG9ftk1Znez\nLLBNWVICyKKRgiqqqBFBrSCqhMKHiEiomw9BIlKkFtEP4SOKGiIqVZE2BbGpUqKoCWJVoTabbVQa\nokZr0IZdXnezL7DGa6/t9ct4xvN6+sGX1ICf8wzzHj//n7Syd87ce4/v+Pjac+7zPKKqIKLtL9bp\nBIioPVjsRIFgsRMFgsVOFAgWO1EgEu08WDyT1sTISDsPSRSU8uIiKtk12SrWULGLyN0AngIQB/DP\nqvqE9fzEyAgm/vaRRg5JRIaZ7zzljNX9a7yIxAH8E4AvATgI4AEROVjv/oiotRr5m/0OAGdV9Zyq\nFgH8GMC9zUmLiJqtkWLfDeC9Tf+/FD32ISJySESmRGSqks02cDgiakTL341X1cOqOqmqk/FMptWH\nIyKHRop9GsDeTf/fEz1GRF2okWI/AeCAiOwXkSSArwA42py0iKjZ6m69qWpZRB4G8J/YaL09o6qv\nNy0zaopYacuW6/+r2vFqb7WJ2VAnNdRnV9UXAbzYpFyIqIV4uyxRIFjsRIFgsRMFgsVOFAgWO1Eg\nWOxEgWjrePZQSdnuZft64bGyvX81No8XPH12z+TClZLneuC7XBj7T+Ts3OJ5e9eVXk+8331wjdvb\nqufrqvT/4d1/wCs7USBY7ESBYLETBYLFThQIFjtRIFjsRIFg6y2iSbsHJUVPC8tQTVfsJ+TsPpCv\nTVQ1ci97JgfSPruFJJ7WWyxnx60hslL17NvTkoyXzDDEGL6b8LT1qp7KKKfsF6Wcsr+fKqn2t+54\nZScKBIudKBAsdqJAsNiJAsFiJwoEi50oECx2okCwzx6JZeym7dBgzhlLxO2e6VhqzYxnegpmfKXY\nZ8YTMffxY2L3e4sVu1/89vkJM1713EIAa/jtut1HL+6wcy+n7fPeu+D+2oo95qao9nruu/AMW7bu\nffDy3dJR5655ZScKBIudKBAsdqJAsNiJAsFiJwoEi50oECx2okCE02cfsvvoN+25bMZvHLjijC0W\nU+a2i4W0GT+7OGbGkwm7mS1GL73H6MEDQLaQNONeg/Y811pwX0+SS3aPXyp2wzmxam+fXHHHSp5x\n/r4evybt86pxT5/eGqvfQIve0lCxi8gFAKsAKgDKqjrZjKSIqPmacWX/c1Wdb8J+iKiF+Dc7USAa\nLXYF8HMReUVEDm31BBE5JCJTIjJVyWYbPBwR1avRX+PvVNVpEdkF4JiIvKWqL21+gqoeBnAYAHo/\ntbdFbz0QkU9DV3ZVnY4+zgF4HsAdzUiKiJqv7mIXkbSIDHzwOYAvAjjdrMSIqLka+TV+HMDzIvLB\nfv5VVf+jKVnVQTN2L7o/XTTjZc8c5r3Guslnlnaa285fHTDjmfR6Q9tXcsbL6PnDqXeHfWxJeOY3\nz3oGhhv3ABQH7U3Xr7Vf0/R5u89eHHLHKn32ifEtdS15zzLbnu2t45czrZlTvu5iV9VzAD7bxFyI\nqIXYeiMKBIudKBAsdqJAsNiJAsFiJwrEthni6msRXTe6aMZ90zmfXxt1xubm7R6SLtnDSJeW7bhv\nqKcV1YTdYorFPPEe+7zGVuzrhRjds5hnGurUu3ZrbW2/bx5rg+9eTk985DX76+5fsHNbutH9tZU8\nw2t901i78MpOFAgWO1EgWOxEgWCxEwWCxU4UCBY7USBY7ESB2DZ9dvX0olMJe4hrttRrxnMldy98\naMi9nDMALF21++h9s/bLUBjzDN+dNvrRnpZsvmxPc53cZX9tqZvt+xeuzrjvQSjtsK812ufpo3u+\ntlive/uqZ2iu796GwfP291PfjD0FW350xBkrDNv3F/iWqnbhlZ0oECx2okCw2IkCwWInCgSLnSgQ\nLHaiQLDYiQKxbfrsULsverVgL6s8lMyb8dmse43fpSW7V53IefrJnn5x5pzddx286O4nl1K+8eb2\nvlP77amm//2WZ834iYPueQBeXv0jc9tsxb734eyqPYV3xZge/Oz79rZVa3puAMkle/4DPf+eGY9N\nuvvsvu+HevHKThQIFjtRIFjsRIFgsRMFgsVOFAgWO1EgWOxEgdg+fXaPd+fcfU0AmBhdNuPLq+4+\nff9bfea2SXvXiBftecITOTu+bowLj5fsYw+/Y48Znx0eM+Mv3rDfjH9tcM4Z++OeX5vb/tvKbWZ8\nV3LVjP/3lQNm3JK8YpdG7HfTZlzj9v0L66PuZnol1cB8+AbvlV1EnhGRORE5vemxERE5JiJnoo/D\nLcmOiJqmll/jnwVw90ceexTAcVU9AOB49H8i6mLeYlfVlwB8dO6hewEciT4/AuC+JudFRE1W7xt0\n46o6E31+GcC464kickhEpkRkqpK15+UiotZp+N14VVUYy+Cp6mFVnVTVyXjGPZiEiFqr3mKfFZEJ\nAIg+ut9yJaKuUG+xHwXwYPT5gwBeaE46RNQq3j67iDwH4C4AYyJyCcC3ATwB4Cci8hCAiwDub2WS\nNSnZg4B7BstmfHpuhxkf/N9+Z2zkTXts8/wt9rjs4pCde2HU7rNXk+55xBNZ++f5gX+8aMaLA3Yf\nffEv7T/NXl6/7Iy9VbjR3rdnTvtC1f72Hehxj8XXRfs1GbBPCypXr5rxxJ7dZrxqLyXQEt5iV9UH\nHKEvNDk
XImoh3i5LFAgWO1EgWOxEgWCxEwWCxU4UiG0zxFUK9s+t9TVPr2PZXsJ3x1n3Er1StVtj\nuWvseGXAXoJ3/6dnzHjJmA76/dPOO5kBANWddsvxyp/auZ1c3WPG16vu8/pGdsLcdqnobncCQNUz\n5/LMinu56P7L9vfL4EV7Seb4oHvfAJD7zLV2fMJ9XrXfPueSr+8azSs7USBY7ESBYLETBYLFThQI\nFjtRIFjsRIFgsRMFYtv02RsV9/Qus7uN6ZoLdh/d+yPVs/nl5QEzvv6uOz7yut2LXj5o99m1zx4a\nfOLSPjOeHXcPJd2TWjK3nV4bsvddsIepLr/r3r4nZZ/05IK9hLek7SXA86P2VNKaMc4rl2wmokaw\n2IkCwWInCgSLnSgQLHaiQLDYiQLBYicKxLbps2ufPQa4t99eu7jYb49nz13j7ptWE3ZjtNJvL8Hr\nm+45+V92v3l41r3/WMk+L2vG1wUAkrC3r1Tsr3215F7Oeqlkj1e/fmDBjCeG7PN6ourObTVvL+G9\ncqN9b0Nyl91nLw54muVGbr77LurFKztRIFjsRIFgsRMFgsVOFAgWO1EgWOxEgWCxEwVi2/TZfcQ3\nRnjI7sMXxtzLMlfn7XHVvr7pwHk7Pvy2e+lhAIjn3LlX0vb9A3O3298C6UH72Dfvci/JDADJmHvc\n9vy6vdzzfM5esvlPRu359LNr7h5/xXNfxuJB+/6DvgX7OulbKwAxd7wnbc9ZX163709wHtL3BBF5\nRkTmROT0psceF5FpETkZ/bunrqMTUdvU8mv8swDu3uLx76nqrdG/F5ubFhE1m7fYVfUlAIttyIWI\nWqiRN+geFpHXol/zh11PEpFDIjIlIlOVbLaBwxFRI+ot9u8DuAHArQBmAHzX9URVPayqk6o6Gc/Y\nb8gQUevUVeyqOquqFVWtAvgBgDuamxYRNVtdxS4im9fa/TKA067nElF38PbZReQ5AHcBGBORSwC+\nDeAuEbkVGx3kCwC+3sIca+Ppo+8fs8dG50r2+u25krtffdU+NJKn7LHPyVW755tYdvf4AbuXvrLP\nvgcgffu8Gf/rfb8142+t2eu/70+5z/vNGbtPHhP7vPxi9iYzXsoZ9xhk7LHwBU8fvjhiXycTO+15\n51M97uNXq/a+7Zn83bzFrqoPbPHw03Uej4g6hLfLEgWCxU4UCBY7USBY7ESBYLETBWL7DHG1OyUY\n718142cKO834lcvu6ZxlzR4OmVwxw0jP2EMaq/32y1TOuFtMywfsYz9501Ezfk3cTn65Yg+3vJR3\n3kmNktrXmqonfnHBvW8A6LvobqcWR+xvGN1hD3lGr926K2XtVm7vaM4Zyy26h+YC9a/ozCs7USBY\n7ESBYLETBYLFThQIFjtRIFjsRIFgsRMFYtv02WNpe+BfvmJPqZyI2X1Xybl76ZkLnmmHF+19l1P2\n9pqxX6bigPtndjllT2l8fcKeXnBn3N7+5dnrzfjKunuIrW8oZ9zzmiST9mueHXP3wnXAM1C04Mkt\nY29fKdjd8FLJ/ZpLvjXXYF7ZiQLBYicKBIudKBAsdqJAsNiJAsFiJwoEi50oENumz97j6blaSwcD\nwOzyQN3HFntoM3rynsH2nnC8ZD8hs+Iee50fs8eb/0/+RjN+YmW/GX//jD0PQCN23WBP/+3rw68m\njXsEKnYfXIr2dbCyat+3IUV7//G4O3e18q5h3y68shMFgsVOFAgWO1EgWOxEgWCxEwWCxU4UCBY7\nUSC2TZ/dJ1uyly5eb2Subk/bs5ixf6bG7ENj4NyaGZequy87/I79Ej9x/K/MuMbsnm/qfXssfv5a\n900IPbvsZY0LJTv3vqQ9t3vfqHv/pXfT5rap9+3XLLvPvrnCN14+t+J+0evto/t4r+wisldEfiki\nb4jI6yLySPT4iIgcE5Ez0Ud7xn4i6qhafo0vA/iWqh4E8DkA3xCRgwAeBXBcVQ8AOB79n4i6lLfY\nVXVGVV+NPl8F8CaA3QDuBXAketoRAPe1KkkiatwneoNORK4DcBuA3wAYV9WZKHQZwLhjm0MiMiUi\nU5VstoFUiagRNRe7iGQA/BTAN1X1Q6v9qaoC2PKdHFU9rKqTqjoZz2QaSpaI6ldTsYtIDzYK/Ueq\n+rPo4VkRmYjiEwDmWpMiETWDt/UmIgLgaQBvquqTm0JHATwI4Ino4wstybBG1Yr9c+vC0oi9A8+U\nyZb+ec8YVY/eq3abprTDbhv2Xnb/eZTI2/se/7W97/yo3VpbHzPDGNjrXvI5l7OPLWK/JnFPvHA5\n5Q56hpGWjU1rUrDPW6vaa5Za+uyfB/BVAKdE5GT02GPYKPKfiMhDAC4CuL81KRJRM3iLXVV/Bfdt\nI19objpE1Cq8XZYoECx2okCw2IkCwWInCgSLnSgQ22aIa6Vs/9xaXvE0Tn19diM89zl724Ezds+1\nMJg046L2/pM7d7hjWfsegHjR3nfJc9NjcdQe6lkxeun9qYK977L97ZlM2Mceu969HPXCov2FrRvL\nYAMAVuzcOtFH9+GVnSgQLHaiQLDYiQLBYicKBIudKBAsdqJAsNiJArFt+ux61e5VY8iedjietPvR\n1R3u7TVnn8ZE3u5lV3s8PVnPLQDW9vF1++sq93tyz9nHTi7Y9xDoknvJ6MIBO7fb9lwy4ytFew7u\nK2vuXnoqY/f4swv2fRnd10X345WdKBAsdqJAsNiJAsFiJwoEi50oECx2okCw2IkCsW367D7xHrun\n29tn9+EH+tedsUrV/pk5X7TnrE9fsnvVpbTd1U3NuBvxGrO3jRfsJn7MPi2IFe14frd7zPlgr73z\nE+f2mfHqumdMea/72Ikeeyw8qn+InXQbr+xEgWCxEwWCxU4UCBY7USBY7ESBYLETBYLFThSIWtZn\n3wvghwDGsTGy+rCqPiUijwP4GwBXoqc+pqovtirRRvnmlS+V7F73QsE9NnpoIG9ue8tnL5jx85+y\n+/DFJXtsdTnd44wtf9oz3jzhmy/f7kfH8/Z5TWTdx18/5Z7vHgB0yL43wtcJt76y8pL7nNWyby/1\n7MGztnwr1HJTTRnAt1T1VREZAPCKiByLYt9T1X9oXXpE1Cy1rM8+A2Am+nxVRN4EsLvViRFRc32i\nv9lF5DoAtwH4TfTQwyLymog8IyLDjm0OiciUiExVstmGkiWi+tVc7CKSAfBTAN9U1RUA3wdwA4Bb\nsXHl/+5W26nqYVWdVNXJeMazcBgRtUxNxS4iPdgo9B+p6s8AQFVnVbWiqlUAPwBwR+vSJKJGeYtd\nRATA0wDeVNUnNz0+selpXwZwuvnpEVGz1PJu/OcBfBXAKRE5GT32GIAHRORWbHQ4LgD4eksybBZP\nq6UEO25ZWLSnNJ6P2y0mxDxtGE8Xp5qsv40jw/YY1XjcMxX1Fftrt1Tq37Qm4mkLtvbg7W+t+dTy\nbvyvsPW3W9f21Ino43gHHVEgWOxEgWCxEwWCxU4UCBY7USBY7ESB
CGYq6ZbyzEosFd+Ayc5NW6yL\n9lLX5TblQa3HKztRIFjsRIFgsRMFgsVOFAgWO1EgWOxEgWCxEwVCVNs37lZErgC4uOmhMQDzbUvg\nk+nW3Lo1L4C51auZue1T1Z1bBdpa7B87uMiUqk52LAFDt+bWrXkBzK1e7cqNv8YTBYLFThSIThf7\n4Q4f39KtuXVrXgBzq1dbcuvo3+xE1D6dvrITUZuw2IkC0ZFiF5G7ReRtETkrIo92IgcXEbkgIqdE\n5KSITHU4l2dEZE5ETm96bEREjonImejjlmvsdSi3x0VkOjp3J0Xkng7ltldEfikib4jI6yLySPR4\nR8+dkVdbzlvb/2YXkTiAdwD8BYBLAE4AeEBV32hrIg4icgHApKp2/AYMEfkzAFkAP1TVz0SPfQfA\noqo+Ef2gHFbVv+uS3B4HkO30Mt7RakUTm5cZB3AfgK+hg+fOyOt+tOG8deLKfgeAs6p6TlWLAH4M\n4N4O5NH1VPUlAIsfefheAEeiz49g45ul7Ry5dQVVnVHVV6PPVwF8sMx4R8+dkVdbdKLYdwN4b9P/\nL6G71ntXAD8XkVdE5FCnk9nCuKrORJ9fBjDeyWS24F3Gu50+ssx415y7epY/bxTfoPu4O1X1dgBf\nAvCN6NfVrqQbf4N1U++0pmW822WLZcZ/r5Pnrt7lzxvViWKfBrB30//3RI91BVWdjj7OAXge3bcU\n9ewHK+hGH+c6nM/vddMy3lstM44uOHedXP68E8V+AsABEdkvIkkAXwFwtAN5fIyIpKM3TiAiaQBf\nRPctRX0UwIPR5w8CeKGDuXxItyzj7VpmHB0+dx1f/lxV2/4PwD3YeEf+dwD+vhM5OPK6HsBvo3+v\ndzo3AM9h49e6Ejbe23gIwCiA4wDOAPgFgJEuyu1fAJwC8Bo2CmuiQ7ndiY1f0V8DcDL6d0+nz52R\nV1vOG2+XJQoE36AjCgSLnSgQLHaiQLDYiQLBYicKBIudKBAsdqJA/B91ElPYo2ktdgAAAABJRU5E\nrkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6msW8VO6AM2v", + "colab_type": "text" + }, + "source": [ + "## Quantization\n", + "\n", + "Now that we have a pruned model that appears to be performing well, let's see if we can make it even smaller by quantization. To do this, we'll need a slightly different neural network, one that corresponds to Figure 3 from the paper. Instead of having a matrix of float values, we'll have a matrix of integer labels (here called \"labels\") that correspond to entries in a (hopefully) small codebook of centroids (here called \"centroids\"). The way that I've coded it, there's still a mask that enforces our desired sparsity pattern." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bECtREWOAM2z", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 231 + }, + "outputId": "de24b1ae-88d0-4956-c4c7-85c6b6d5d42b" + }, + "source": [ + "class MultilayerPerceptronQuantized(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim, hidden_dim,mask,labels,centroids):\n", + " super(MultilayerPerceptronQuantized, self).__init__()\n", + " self.mask = torch.nn.Parameter(mask,requires_grad=False)\n", + " self.labels = torch.nn.Parameter(labels,requires_grad=False)\n", + " self.centroids = torch.nn.Parameter(centroids,requires_grad=True)\n", + "\n", + " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim))\n", + "\n", + " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim))\n", + " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim))\n", + "\n", + " def forward(self, x):\n", + " W_0 = self.mask*self.centroids[self.labels]\n", + " hidden = torch.tanh(x@W_0 + self.b_0)\n", + " outputs = hidden@self.W_1 + self.b_1\n", + " return outputs" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "error", + "ename": "NameError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0mMultilayerPerceptronQuantized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden_dim\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcentroids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mMultilayerPerceptronQuantized\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmask\u001b[0m \u001b[0;34m=\u001b[0m 
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "Ydx5nUSxAM3Z",
+        "colab_type": "text"
+      },
+      "source": [
+        "Notice what is happening in the forward method: W_0 is being reconstructed by using a matrix (self.labels) to index into a vector (self.centroids). The beauty of automatic differentiation is that it allows backpropagation through this sort of weird indexing operation, and thus gives us gradients of the objective function with respect to the centroid values!\n",
+        "\n",
+        "### Q6: However, before we are able to use this AD magic, we need to specify the static label matrix (and an initial guess for the centroids). Use the k-means algorithm (or something else if you prefer) to figure out the label matrix and centroid vector. PROTIP1: I used scikit-learn's implementation of k-means. PROTIP2: only cluster the non-zero entries."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "CSyaxK6gAM3i",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 231
+        },
+        "outputId": "093bbe7f-29b8-41f1-d30a-74c1229e0fea"
+      },
+      "source": [
+        "import numpy as np\n",
+        "from sklearn.cluster import KMeans\n",
+        "\n",
+        "# convert weight and mask matrices into numpy arrays\n",
+        "W_0 = model.W_0.detach().cpu().numpy()\n",
+        "mask = model.mask.detach().cpu().numpy()\n",
+        "# Figure out the indices of non-zero entries\n",
+        "inds = np.where(mask!=0)\n",
+        "# Figure out the values of non-zero entries\n",
+        "vals = W_0[inds]\n",
+        "### TODO: perform clustering on vals\n",
+        "vals = vals.reshape(-1, 1)\n",
+        "kmeans = KMeans(n_clusters=20, random_state=0).fit(vals)\n",
+        "### TODO: turn the label matrix and centroids into a torch tensor\n",
+        "# Predicting on every entry (zeros included) is fine: the pruned entries get\n",
+        "# zeroed out again by the mask inside the quantized model.\n",
+        "labels = torch.tensor(kmeans.predict(W_0.reshape(-1, 1)), dtype=torch.long, device=device)\n",
+        "# The label matrix must have the same shape as W_0, and the codebook must be a\n",
+        "# flat float vector, so that centroids[labels] reconstructs a (784, 64) matrix.\n",
+        "labels = labels.reshape(W_0.shape)\n",
+        "centroids = torch.tensor(kmeans.cluster_centers_.flatten(), dtype=torch.float32, device=device)\n",
+        "print(labels.shape)\n",
+        "print(centroids.shape)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
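+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "As a rough way to think about how small the codebook can be (see Q7 below), here is a back-of-the-envelope estimate of the storage cost of W_0 for a given codebook size. It assumes 32-bit floats for the centroids and ceil(log2(k)) bits per label for the surviving weights, and it ignores how the sparsity pattern itself would be stored, so treat the numbers as a ballpark only."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {},
+      "source": [
+        "import math\n",
+        "\n",
+        "def estimate_w0_bits(mask, n_clusters, float_bits=32):\n",
+        "    # Dense storage: one float per entry of W_0.\n",
+        "    dense_bits = mask.size * float_bits\n",
+        "    # Quantized storage: one small integer label per surviving weight,\n",
+        "    # plus the codebook itself (sparsity-pattern bookkeeping ignored).\n",
+        "    label_bits = max(1, math.ceil(math.log2(n_clusters)))\n",
+        "    quant_bits = int(np.count_nonzero(mask)) * label_bits + n_clusters * float_bits\n",
+        "    return dense_bits, quant_bits\n",
+        "\n",
+        "dense_bits, quant_bits = estimate_w0_bits(mask, n_clusters=20)\n",
+        "print('approximate compression of W_0: {:.1f}x'.format(dense_bits / quant_bits))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },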
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "RQDN_xh2AM37",
+        "colab_type": "text"
+      },
+      "source": [
+        "Now we can instantiate our quantized model and copy in the appropriate pre-trained weights for the other network layers."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "6VBiUQwgAM4C",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Instantiate quantized model\n",
+        "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,new_mask,labels,centroids)\n",
+        "model_q = model_q.to(device)\n",
+        "\n",
+        "# Copy pre-trained weights from the unquantized model for the non-quantized layers\n",
+        "model_q.b_0.data = model.b_0.data\n",
+        "model_q.W_1.data = model.W_1.data\n",
+        "model_q.b_1.data = model.b_1.data"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
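+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "Before fine tuning, it's worth a quick, optional sanity check that the reconstructed matrix mask * centroids[labels] is close to the pruned weights it was clustered from. The two won't match exactly, since every surviving weight has been snapped to its nearest centroid, but a very large relative error here usually means the label matrix or codebook shapes are wrong."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {},
+      "source": [
+        "# Optional sanity check: compare the quantized reconstruction of W_0 with\n",
+        "# the pruned weights of the original model.\n",
+        "with torch.no_grad():\n",
+        "    W_recon = model_q.mask * model_q.centroids[model_q.labels]\n",
+        "    W_pruned = model.mask * model.W_0\n",
+        "    rel_err = torch.norm(W_recon - W_pruned) / torch.norm(W_pruned)\n",
+        "print('relative reconstruction error:', rel_err.item())"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },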
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "GteClYZgAM4l",
+        "colab_type": "text"
+      },
+      "source": [
+        "Finally, we can fine tune the quantized model. We'll adjust not only the centroids, but also the weights in the other layers."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "dEuiZ990AM4s",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "optimizer = torch.optim.Adam(model_q.parameters(), lr=lr_rate, weight_decay=1e-3)\n",
+        "for epoch in range(n_epochs):\n",
+        "    for i, (images, labels) in enumerate(train_loader):\n",
+        "        images = images.view(-1, 28 * 28).to(device)\n",
+        "        labels = labels.to(device)\n",
+        "\n",
+        "        optimizer.zero_grad()\n",
+        "        outputs = model_q(images)\n",
+        "        loss = criterion(outputs, labels)\n",
+        "        loss.backward()\n",
+        "        optimizer.step()\n",
+        "\n",
+        "    # calculate Accuracy\n",
+        "    correct = 0\n",
+        "    total = 0\n",
+        "    for images, labels in test_loader:\n",
+        "        images = images.view(-1, 28*28).to(device)\n",
+        "        labels = labels.to(device)\n",
+        "        outputs = model_q(images)\n",
+        "        _, predicted = torch.max(outputs.data, 1)\n",
+        "        total += labels.size(0)\n",
+        "        # for gpu, bring the predicted and labels back to cpu for python operations to work\n",
+        "        correct += (predicted == labels).sum()\n",
+        "    accuracy = 100 * correct/total\n",
+        "    print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(epoch, loss.item(), accuracy))\n",
+        "torch.save(model_q.state_dict(),'mnist_quantized.h5')"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "k2OQuQDOAM5S",
+        "colab_type": "text"
+      },
+      "source": [
+        "After retraining, we can, just for fun, reconstruct the pruned and quantized weights and plot them as images:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "FPOAp6XvAM5Z",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 197
+        },
+        "outputId": "de2d5269-5b54-40a5-e4ef-fb7c13acd74e"
+      },
+      "source": [
+        "W_0 = (model_q.mask*model_q.centroids[model_q.labels]).detach().cpu().numpy()\n",
+        "plt.imshow(W_0[:,1].reshape((28,28)))\n",
+        "plt.show()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "R214KN6iAM6E",
+        "colab_type": "text"
+      },
+      "source": [
+        "Certainly a much more parsimonious representation. The obvious question now becomes:\n",
+        "\n",
+        "### Q7: How low can you go? How small can the centroid codebook be before we see a substantial degradation in test set accuracy?\n",
+        "\n",
+        "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? 
" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "jDlhi-dvAM6N", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file