From 8c8c5a405ba940edcbde9f244c2c20e28d35c371 Mon Sep 17 00:00:00 2001 From: tmayer868 Date: Wed, 1 Apr 2020 12:48:23 -0600 Subject: [PATCH 1/3] Add files via upload --- deep_compression_exercise .ipynb | 1825 ++++++++++++++++++++++++++++++ 1 file changed, 1825 insertions(+) create mode 100644 deep_compression_exercise .ipynb diff --git a/deep_compression_exercise .ipynb b/deep_compression_exercise .ipynb new file mode 100644 index 0000000..81451f0 --- /dev/null +++ b/deep_compression_exercise .ipynb @@ -0,0 +1,1825 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + }, + "colab": { + "name": "deep_compression_exercise.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "be09c445b75f468e8dcfa18358ce4a36": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_5aac198c9a3d4eea9be7fb8beeb227ce", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_03e6a001de424248a4122d33143e75f0", + "IPY_MODEL_c005a5ba50db480caba161765274b3f9" + ] + } + }, + "5aac198c9a3d4eea9be7fb8beeb227ce": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": 
null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "03e6a001de424248a4122d33143e75f0": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_60e876032db7469e846f94c3ac3d5942", + "_dom_classes": [], + "description": "", + "_model_name": "IntProgressModel", + "bar_style": "info", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_00e93303f3b447a1bf56f08dfde7de4d" + } + }, + "c005a5ba50db480caba161765274b3f9": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_4403339d5717448687dcb5155c9e1239", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + 
"value": " 9920512/? [00:20<00:00, 1487669.97it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_ac8b2caa8abb45189966110e612433c1" + } + }, + "60e876032db7469e846f94c3ac3d5942": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "00e93303f3b447a1bf56f08dfde7de4d": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "4403339d5717448687dcb5155c9e1239": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "ac8b2caa8abb45189966110e612433c1": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "fc589a1423bd406581f252ecf8ea97bd": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": 
"IPY_MODEL_e7f77c796bc14d9ab6f97137424b925f", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_0e847e7b236c41b79dc518091bf7067f", + "IPY_MODEL_3e2bf6a25441443c92dc5c3c2d1fef22" + ] + } + }, + "e7f77c796bc14d9ab6f97137424b925f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "0e847e7b236c41b79dc518091bf7067f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_ac1223dd567b41589efcc1044c8fa962", + "_dom_classes": [], + "description": " 0%", + "_model_name": "IntProgressModel", + "bar_style": "info", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 0, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, 
+ "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_092573b76819403ca18edea829d1bb31" + } + }, + "3e2bf6a25441443c92dc5c3c2d1fef22": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_7a47e9e824d345fa916984efef7b2f3d", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 0/28881 [00:00<?, ?it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_08c5f10c348649a9a927c221cf2aa965" + } + }, + "ac1223dd567b41589efcc1044c8fa962": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "092573b76819403ca18edea829d1bb31": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + 
"align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "7a47e9e824d345fa916984efef7b2f3d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "08c5f10c348649a9a927c221cf2aa965": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + 
"margin": null, + "display": null, + "left": null + } + }, + "5315c5e02993418a9d70d8568a24574c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_7d973a7b57544afcb7deeed203361142", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_6ebfc1cfcbca4b41aa3214b8ab36e790", + "IPY_MODEL_fc64f32f6bf741ffb46025054188bff0" + ] + } + }, + "7d973a7b57544afcb7deeed203361142": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "6ebfc1cfcbca4b41aa3214b8ab36e790": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntProgressModel", + 
"state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_9a43b17aa8bc46508b7fa9fabaf1bf32", + "_dom_classes": [], + "description": "", + "_model_name": "IntProgressModel", + "bar_style": "info", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_438ac4bcdab249bcbbaccbe207981ffd" + } + }, + "fc64f32f6bf741ffb46025054188bff0": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_279398b30dc6479b99e352bbc4651b76", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 1654784/? [00:18<00:00, 511954.24it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_7c8a5703523f4c0f80b8fee415f3fa83" + } + }, + "9a43b17aa8bc46508b7fa9fabaf1bf32": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "438ac4bcdab249bcbbaccbe207981ffd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + 
"_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "279398b30dc6479b99e352bbc4651b76": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "7c8a5703523f4c0f80b8fee415f3fa83": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": 
null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "e4c516162b844d6fbd33cf29584aae20": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_6779960e36614cef9c512a615a2980bf", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_9b09958b45a441d29688f6eeba9707d3", + "IPY_MODEL_62a69547013748de9d80c4a992748698" + ] + } + }, + "6779960e36614cef9c512a615a2980bf": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, 
+ "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "9b09958b45a441d29688f6eeba9707d3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_f49a8270a4364784b39bc0eca205a8eb", + "_dom_classes": [], + "description": " 0%", + "_model_name": "IntProgressModel", + "bar_style": "info", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 0, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_bca6c5f7f93942e98c6ef7412e311dfb" + } + }, + "62a69547013748de9d80c4a992748698": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_5a1aef0316fd4ebbbd436bd2ecf2383c", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 0/4542 [00:00<?, ?it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_00f7f24316dc48b1adeb59b3d0dadc4b" + } + }, + "f49a8270a4364784b39bc0eca205a8eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + 
"_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "bca6c5f7f93942e98c6ef7412e311dfb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "5a1aef0316fd4ebbbd436bd2ecf2383c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "00f7f24316dc48b1adeb59b3d0dadc4b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": 
"@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "qe_xWm7z9YM6", + "colab_type": "text" + }, + "source": [ + "# Exercise Week 9: Pruning and Quantization\n", + "This week, we will explore some of the ideas discussed in Han, Mao, and Dally's Deep Compression. In particular, we will implement weight pruning with fine tuning, as well as k-means weight quantization. **Note that we will unfortunately not be doing this in a way that will actually lead to substantial efficiency gains: that would involve the use of sparse matrices which are not currently well-supported in pytorch.** \n", + "\n", + "## Training an MNIST classifier\n", + "For this example, we'll work with a basic multilayer perceptron with a single hidden layer. We will train it on the MNIST dataset so that it can classify handwritten digits. 
As usual we load the data:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Cysg2vAY9YNQ", + "colab_type": "code", + "outputId": "cbea6932-d4ff-497e-9ff3-39fa462cdcba", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 313, + "referenced_widgets": [ + "be09c445b75f468e8dcfa18358ce4a36", + "5aac198c9a3d4eea9be7fb8beeb227ce", + "03e6a001de424248a4122d33143e75f0", + "c005a5ba50db480caba161765274b3f9", + "60e876032db7469e846f94c3ac3d5942", + "00e93303f3b447a1bf56f08dfde7de4d", + "4403339d5717448687dcb5155c9e1239", + "ac8b2caa8abb45189966110e612433c1", + "fc589a1423bd406581f252ecf8ea97bd", + "e7f77c796bc14d9ab6f97137424b925f", + "0e847e7b236c41b79dc518091bf7067f", + "3e2bf6a25441443c92dc5c3c2d1fef22", + "ac1223dd567b41589efcc1044c8fa962", + "092573b76819403ca18edea829d1bb31", + "7a47e9e824d345fa916984efef7b2f3d", + "08c5f10c348649a9a927c221cf2aa965", + "5315c5e02993418a9d70d8568a24574c", + "7d973a7b57544afcb7deeed203361142", + "6ebfc1cfcbca4b41aa3214b8ab36e790", + "fc64f32f6bf741ffb46025054188bff0", + "9a43b17aa8bc46508b7fa9fabaf1bf32", + "438ac4bcdab249bcbbaccbe207981ffd", + "279398b30dc6479b99e352bbc4651b76", + "7c8a5703523f4c0f80b8fee415f3fa83", + "e4c516162b844d6fbd33cf29584aae20", + "6779960e36614cef9c512a615a2980bf", + "9b09958b45a441d29688f6eeba9707d3", + "62a69547013748de9d80c4a992748698", + "f49a8270a4364784b39bc0eca205a8eb", + "bca6c5f7f93942e98c6ef7412e311dfb", + "5a1aef0316fd4ebbbd436bd2ecf2383c", + "00f7f24316dc48b1adeb59b3d0dadc4b" + ] + } + }, + "source": [ + "import torch\n", + "import torchvision.transforms as transforms\n", + "import torchvision.datasets as datasets\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n", + "test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n", + "\n", + "batch_size = 300\n", + "train_loader = 
torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", + "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be09c445b75f468e8dcfa18358ce4a36", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fc589a1423bd406581f252ecf8ea97bd", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5315c5e02993418a9d70d8568a24574c", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" + ] + }, + "metadata": 
{ + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e4c516162b844d6fbd33cf29584aae20", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", + "Processing...\n", + "Done!\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8kdx9cAu9YOU", + "colab_type": "text" + }, + "source": [ + "Then define a model:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "UInyoax99YOf", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MultilayerPerceptron(torch.nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim,mask=None):\n", + " super(MultilayerPerceptron, self).__init__()\n", + " if not mask:\n", + " self.mask = torch.nn.Parameter(torch.ones(input_dim,hidden_dim),requires_grad=False)\n", + " else:\n", + " self.mask = torch.nn.Parameter(mask)\n", + "\n", + " self.W_0 = torch.nn.Parameter(1e-3*torch.randn(input_dim,hidden_dim)*self.mask,requires_grad=True)\n", + " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim),requires_grad=True)\n", + "\n", + " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim),requires_grad=True)\n", + " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim),requires_grad=True)\n", + " \n", + " def set_mask(self,mask):\n", + " \n", + " self.mask.data = mask.data\n", + " self.W_0.data = self.mask.data*self.W_0.data\n", + "\n", + " def 
forward(self, x):\n", + " hidden = torch.tanh(x@(self.W_0*self.mask) + self.b_0)\n", + " outputs = hidden@self.W_1 + self.b_1\n", + " return outputs\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GAM40Rnb9YPH", + "colab_type": "text" + }, + "source": [ + "Note that the above code is a little bit different than a standard multilayer perceptron implementation.\n", + "\n", + "### Q1: What does this model have the capability of doing that a \"Vanilla\" MLP does not. Why might we want this functionality for studying pruning?\n", + "\n", + "This model can \"mask\" parameters that we want to ignore, by\n", + "multiplying that paramter by 0. This has the same affect\n", + "as prunning the parameters even though we still technically\n", + "do multiplications and additions with those zeroed out parameters.\n", + "\n", + "Let's first train this model without utilizing this extra functionality. You can set the hidden layer size to whatever you'd like when instantiating the model:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "IUZhjRk79YPV", + "colab_type": "code", + "colab": {} + }, + "source": [ + "n_epochs = 10\n", + "\n", + "input_dim = 784\n", + "hidden_dim = 64\n", + "output_dim = 10\n", + "\n", + "model = MultilayerPerceptron(input_dim,hidden_dim,output_dim)\n", + "model = model.to(device)\n", + "\n", + "criterion = torch.nn.CrossEntropyLoss() # computes softmax and then the cross entropy\n", + "lr_rate = 0.001\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-3)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "44vk9p1D9YPv", + "colab_type": "text" + }, + "source": [ + "And then training proceeds as normal." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "tinBThTV9YP4", + "colab_type": "code", + "outputId": "e857ad65-f505-42c4-bd7e-6ce69c6d56fb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 185 + } + }, + "source": [ + "iter = 10\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(epoch, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_pretrained.h5')\n" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.4008619785308838. Accuracy: 90.\n", + "Iteration: 1. Loss: 0.331943154335022. Accuracy: 92.\n", + "Iteration: 2. Loss: 0.2710331082344055. Accuracy: 93.\n", + "Iteration: 3. Loss: 0.21628224849700928. Accuracy: 93.\n", + "Iteration: 4. Loss: 0.18730275332927704. Accuracy: 94.\n", + "Iteration: 5. Loss: 0.18389032781124115. Accuracy: 95.\n", + "Iteration: 6. Loss: 0.1158808246254921. Accuracy: 95.\n", + "Iteration: 7. Loss: 0.13385257124900818. Accuracy: 95.\n", + "Iteration: 8. Loss: 0.18437758088111877. Accuracy: 95.\n", + "Iteration: 9. Loss: 0.1472378671169281. 
Accuracy: 95.\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jka2xOtp9YQO", + "colab_type": "text" + }, + "source": [ + "## Pruning\n", + "\n", + "Certainly not a state of the art model, but also not a terrible one. Because we're hoping to do some weight pruning, let's inspect some of the weights directly (recall that we can act like they're images)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "IZVDikFr9YQU", + "colab_type": "code", + "outputId": "830c517c-54b2-4f0a-fad6-d014c12f0242", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + } + }, + "source": [ + "import matplotlib.pyplot as plt\n", + "W_0 = model.W_0.detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAVv0lEQVR4nO3dXWyk1XkH8P8z4/H4e73eD7MsCwub\nTQui7UIs0iQ0Io2IgF6QXBSFi5RKqJuLICVRLhrRi3CJqnyIiyrSpqCQKiWKlERwQdoAQkGpWoSX\nbvhayC4U1rvxx3rN7nr8MZ6PpxcekAGf/zHz4Rlz/j/Jsj2Pz/ueeWeeecfzvOccc3eIyEdfpt0d\nEJHNoWQXSYSSXSQRSnaRRCjZRRLRtZk7y/b3e9fIyGbuUiQp5bk5VBYWbL1YQ8luZrcAeABAFsC/\nuvv97O+7RkZw2de/2cguRYQ4/cAPgrG638abWRbAvwC4FcA1AO40s2vq3Z6ItFYj/7PfAOCku7/h\n7isAfgbg9uZ0S0SarZFk3wtgYs3vp2u3vYeZHTazcTMbrxYWGtidiDSi5Z/Gu/sRdx9z97HMQH+r\ndyciAY0k+xkA+9b8flntNhHpQI0k+3MADprZlWbWDeDLAB5rTrdEpNnqLr25e9nM7gHwn1gtvT3k\n7i83rWdbSWTgYKYUaR57FCLbt3I4ll1Zt+T6rkqeb7zcz+PZZb79DNl/Nce3bbH7XeHxSp415m0/\nihqqs7v74wAeb1JfRKSFdLmsSCKU7CKJULKLJELJLpIIJbtIIpTsIonY1PHsbVWNxBt42YvVgzOl\nSFF3hYdzBR7PLpO2i5HORcIr2/iB8Wxk86R5NcePi8dq4ZG4kcfcM/yOx+5XNcfjlX7+hDPynGDX\nJjRCZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEpFO6S3yshYrxbChnFbhpRJWGgOAriUezxV437qW\nw/Ftx+dpWzv+Oo1ndu6gcWT5gS2PDgdjhf185qKlEX5cS4OREhU5bKVtvGk1UnpjZT0AQLnzxtDq\nzC6SC
CW7SCKU7CKJULKLJELJLpIIJbtIIpTsIonYWnV2Vm6OlVyzjQ1pLG0LF1a9i2+7GOlbdp7v\nvHeab6B/Mrz/8jCbTxnIR+roXuDjayvnL9B4drkYjHWN7qdtV4Ybe3qWBsPHJTaFtkemuUbk2grv\njgxxXQqfZ2NTZMeeqyE6s4skQskukgglu0gilOwiiVCyiyRCyS6SCCW7SCK2Vp2diNXRY1MmV3sj\nxU1SS8/187mgPzY6S+On3t5O44X+yLjv68P7nynx1/Puty6n8Wo3DSN3IbJkMzms5R6+7WqkFh67\n/oBNyRwbC78yzOvk1T7+fMks8mI4Gw9fbVFWNrRZM3sTwDyACoCyu481o1Mi0nzNeA35nLvzU5eI\ntJ3+ZxdJRKPJ7gB+Y2ZHzezwen9gZofNbNzMxquFhQZ3JyL1avRt/I3ufsbMdgN4wsxedfdn1v6B\nux8BcAQA8vv2RT4mE5FWaejM7u5nat9nAPwKwA3N6JSINF/dyW5m/WY2+M7PAL4A4KVmdUxEmquR\nt/GjAH5lZu9s59/d/T+a0qs6ZCKDxmM1W0Tq9NnecjgWaTtxPjx3OgD87YH/pfFnt++n8QxZM/qv\ndpykbT95I583/lM94fHoAPDfy3y8/G8LfxqMVSOTEDw5+Sc0/sdJfn0CiuFat5HHEwBwka/JbGV+\nnmTrDKxuIByKjbWPjXcPqTvZ3f0NAH9Rb3sR2VwqvYkkQskukgglu0gilOwiiVCyiyRiaw1xZeWK\n3khpLTadcx8vxVyxey4YO7/Ex2ru6ueXCb9WGKXxwgovb12/cyIYy0XqNNfled/y1kvjF6v8vrP9\nH7t4GW27Lc/Xuu67fIbGS5Vw6W2+yMfuLvbxY16c6aPxGCen2dhy0FatbzlondlFEqFkF0mEkl0k\nEUp2kUQo2UUSoWQXSYSSXSQRW6rO7plwLT06lfRwiYYvGblI41cNngvGlvr4cMiLJV6LPja5l8Z7\n83yq6qnloWDsf6b307a/PH2IxvNd/PqDtxd5Hf78RHh4r5GpngHgkqt5HX1unk+xvTwb7ltmkD8f\nunKRcaSRsEVG0Dq761ZfHT1GZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEqFkF0lEZ9XZI+N4Wf3R\nu3njbBeP5zI8fvLizmDszCyfKtrI9QEAMNDHx21nI+2fffWq8L4jSwfnz/F47xTf98owrwn3NjAH\nwfkFXsM3MoU2APSPhsfql0r8fnvkso3cPD9PDp/g7ef3hQ9MeSAylTR5qrJDojO7SCKU7CKJULKL\nJELJLpIIJbtIIpTsIolQsoskoqPq7JGyKa/Dx2r0EYslPiZ9aSUcr0SW783l+eDmvUN8LP0fpnfR\nePdUuG+9U7wOvvsonzc+c/RVGi9+7s9p/PxV4b6V6KBuYLnIH5MrR8NzDADAfDE893sBfF74pTcH\naXzXcf5kzS3yeDXPJo6nTVElh4Ud0uiZ3cweMrMZM3tpzW0jZvaEmZ2ofY8slC0i7baRt/E/BnDL\n+277NoCn3P0ggKdqv4tIB4smu7s/A+D9ax/dDuDh2s8PA/hik/slIk1W7wd0o+4+Wft5CkBwsTIz\nO2xm42Y2Xi3w/w9FpHUa/jTe3R1A8NMIdz/i7mPuPpYZ4BMEikjr1Jvs02a2BwBq3/k0oCLSdvUm\n+2MA7qr9fBeAR5vTHRFplWid3cweAXATgJ1mdhrAdwDcD+DnZnY3gLcA3NGMzjgfYkz+WUD0ZWto\nYInGF2PrdZ8ZCO+6yAuj134yvH46AFy/jccHuoo0/mxxfzBWyPN68qVPLtK47eFrx5+7htfCq+Sw\nLu/kF0dUCnzbJ4u7adwK4ad3/wR/sg3xw4Lh1+ZpvNLHU8uq4bH6bH2ERkST3d3vDIQ+3+S+iEgL\n6XJZkUQo2UUSoWQXSYSSXSQRSnaRRHTUENcYZ6MCK7z8tbjMS2sr030
0PngyXKop7uClkt09BRp/\n+uzHaXwpMvx27+7zwdjEAh8eu3QFnwY7sxIpj/HKHpZHSfvYsOTIY4oyL5/1nwrHe2b5Y9b7Nl+T\nObPIl9FeuoRPg10aYPNB06Z8uWdCZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEqFkF0nElqqzU5Ga\nbXGeF4S7lvjrXpmU4VdG+M6fmTjAtx2Zinqwjw9x3Tf0djA2kQkvNQ0Ak5/mT4GBUzSMLO8aVR3m\nU2zH6uz9r0em/94drqV389m7kZvndfaYxV38GoDKYOS+M3S+6PB91pldJBFKdpFEKNlFEqFkF0mE\nkl0kEUp2kUQo2UUSsaXq7HRJ5yqvyVoXr4V3v83bl4bIziPjixcn+PK/XQt8AxcQnsYaAOb6RsLB\nbn6/Y2Oji8P8D5YuiWx/JDzue9cOPh3z3Cv8GoElNlYeQO5i+FyWK/C2S7siU2Tnhmi8uxCZDpo8\nX2NzM1iZxMkDqjO7SCKU7CKJULKLJELJLpIIJbtIIpTsIolQsoskYkvV2anIMreZs3ze+NwC33xp\nG2l7nr9mVviu0bXE66qxWnhpe3jstfXycdmlyPUJ1W5+32676SiNv1EI18qvHpqibR99kdfZyz38\nMa/mw/GLV/L7lSnRMLZF5k8o5/lxzV0Ij3cvDUfmrGdzL5BDEj2zm9lDZjZjZi+tue0+MztjZsdq\nX7fFtiMi7bWRt/E/BnDLOrf/wN0P1b4eb263RKTZosnu7s8AmNuEvohICzXyAd09ZvZC7W3+9tAf\nmdlhMxs3s/FqIfKPsYi0TL3J/kMABwAcAjAJ4HuhP3T3I+4+5u5jmYH+OncnIo2qK9ndfdrdK+5e\nBfAjADc0t1si0mx1JbuZ7Vnz65cAvBT6WxHpDNE6u5k9AuAmADvN7DSA7wC4ycwOYbWq9yaAr7aw\nj+9i9ebYGOBskcfz5yM121y4/fwBPgd4pshfUz3D+9Y3FRkbTR7GA389QVuO5PnnKP/1Op/z/rmz\nl9P4aF94bfoMnaAA6L86PB8+AJyf4/8WVsl8/pUuXsuunODbLg7yx6zSy+PlQbL/LD8unqtvboVo\nsrv7nevc/GCsnYh0Fl0uK5IIJbtIIpTsIolQsoskQskukoiPzhDXWHUqotzD410L4R10FfjyvFVW\nKgHQfZ7ve+ityPK+Hn4YT0zvok2/+Wcv0PjvB/bS+NRk8Erp1Xg1HH+5b08wBgD5Hj7O1CLDmrty\n4eNWjQzt9ciw5HIfb1/khwVOhh7bfCQt2fDaRoa4ishHg5JdJBFKdpFEKNlFEqFkF0mEkl0kEUp2\nkURsqTo7GxHpkTp7bDrn+at4PLtIlsIt8p0PnOKvqSPHizSeP8WHemZWwkXdSg9f7vm7lZtp3Kf4\nBQhDE/y+VcnKxyuDfFnkxe15Gr/0Y2dpfHdfeEno1+f4NNXzfXyu6MIVkSm4+V2DLZBrMyJDf8ll\nFXSIq87sIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SiM6qs0eWwc2uhGOV7ZHaJB9yjmqktgkP\nvy52X4gsz1uIjGc/y6dz9jN8aeOecnhs9A7wenL3hV4aZ3VbAOib5mPOPRvewMIl/EFZXObxs7v4\nNQTFcvjpvVJqYMw4gAqbChqIrrNtJXLdBokBkWnTNZ5dRJTsIolQsoskQskukgglu0gilOwiiVCy\niySio+rsGV6yRZX1tsF543MX+etephyubWb5cPRYyRXLe3i9uOf/IhcJWHgH3XNLtOnwMp+Tvmvm\nIt/3ucik9zvDY+37TvFJBmY/wSdfP3eQDxrftXs23HaOH3PvijyhYvHIE9KWNz/1omd2M9tnZk+b\n2Stm9rKZfb12+4iZPWFmJ2rfI9Pii0g7beRtfBnAt9z9GgB/CeBrZnYNgG8DeMrdDwJ4qva7iHSo\naLK7+6S7P1/7eR7AcQB7AdwO4OH
anz0M4Iut6qSINO5DfUBnZvsBXAfgWQCj7j5ZC00BGA20OWxm\n42Y2Xi3wa8BFpHU2nOxmNgDgFwC+4e7v+dTG3R2BTyTc/Yi7j7n7WGagv6HOikj9NpTsZpbDaqL/\n1N1/Wbt52sz21OJ7AMy0posi0gzRz//NzAA8COC4u39/TegxAHcBuL/2/dFGO1PlMwejkg+XM7KL\nkdJZZERi3ySvj/XNkA2Q0hcALO7kfStcyktI3Vfvp3EUw+Wzhf28xNR9npferLDI9x15t1bcOxyM\nLezhpbf5K/muD+6bpvEMGbac7+V13mo+UpLs4k+opSl+3LPL4edMbInvem2k2PcZAF8B8KKZHavd\ndi9Wk/znZnY3gLcA3NGSHopIU0ST3d1/h/AUBp9vbndEpFV0uaxIIpTsIolQsoskQskukgglu0gi\nOmqIa0MiL1uVSO2STXkMAMWh8A4W9vK2lV6+7/w53n7q04M03n0xvH2LXF+QnyXzcwOYvZmvZV24\nLHKNwYHw9gd28OGzn9ozQeMH+/h1XLlMuFZ+eT9fBvu56ctpfH6RXxTiXXwuas+En09smmkAdZ+i\ndWYXSYSSXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEbKk6e7YYqT8S1W5e6164lMcrA+G6aW4Xn665\n9HYPjcfWky4PRdYPHgqPze46zevBs5/g8Y9fe4rGPzF4jsbv3vlMMDYYmTs8F5mO+bdL/BqAt4rh\n5ar/uLiNtr0wz5eyLhd56lg5ch5ld63+pzmlM7tIIpTsIolQsoskQskukgglu0gilOwiiVCyiyRi\nS9XZ2dLHZIrw1bbdvFZd7ucDv3N94XHZ2wZ4nX02UpPFHK+zew/v2xWXzAVjxV1823+z92Ua/7vh\ncRo/UeL16vlq+BqDB2c/S9u+dmE3jc8t9NH48kp4Pv7iIp+r35f4Y5ZZiqxTEBuT3qJaOqMzu0gi\nlOwiiVCyiyRCyS6SCCW7SCKU7CKJULKLJGIj67PvA/ATAKNYHYV7xN0fMLP7APwDgLO1P73X3R9v\nVUeBeC2dYkV6AKjyjZcuhMd9nyvxWnYustZ3dT+/BuDWg6/S+Md6w/OnZ4xv+1APH6/+3PKlNP70\nhatp/OjsZcHYzOwQbVuN1LptmZ+rnKwVYCu8bZYPtY9NQdCRNnJRTRnAt9z9eTMbBHDUzJ6oxX7g\n7t9tXfdEpFk2sj77JIDJ2s/zZnYcwN5Wd0xEmutD/c9uZvsBXAfg2dpN95jZC2b2kJltD7Q5bGbj\nZjZeLSw01FkRqd+Gk93MBgD8AsA33P0igB8COADgEFbP/N9br527H3H3MXcfywz0N6HLIlKPDSW7\nmeWwmug/dfdfAoC7T7t7xd2rAH4E4IbWdVNEGhVNdjMzAA8COO7u319z+541f/YlAC81v3si0iwb\n+TT+MwC+AuBFMztWu+1eAHea2SGsluPeBPDVlvSwSbIF/rpm1cgwU1K58yXethRbLjoy/PbXL15L\n4yiS+xZb/XeA15i8EtnAPB8q2giL7DtaiiUVz2jbSKk2Mgt2R9rIp/G/w/pPmZbW1EWkuXQFnUgi\nlOwiiVCyiyRCyS6SCCW7SCKU7CKJ2FJTSTfCqo3N3cvqsrYS2XYsvhC7BoA3b2T132rk+oPYMyQ2\ncjgTu+9MdH7wSC2cLfEd6Vbsfm1FOrOLJELJLpIIJbtIIpTsIolQsoskQskukgglu0gizL2R+Zk/\n5M7MzgJ4a81NOwHMbloHPpxO7Vun9gtQ3+rVzL5d4e671gtsarJ/YOdm4+4+1rYOEJ3at07tF6C+\n1Wuz+qa38SKJULKLJKLdyX6kzftnOrVvndovQH2r16b0ra3/s4vI5mn3mV1ENomSXSQRbUl2M7vF\nzF4zs5Nm9u129CHEzN40sxfN7JiZjbe5Lw+Z2YyZvbTmthEze8LMTtS+r7vGXpv6dp+Znakdu2Nm\
ndlub+rbPzJ42s1fM7GUz+3rt9rYeO9KvTTlum/4/u5llAfwBwM0ATgN4DsCd7v7KpnYkwMzeBDDm\n7m2/AMPMPgugAOAn7n5t7bZ/BjDn7vfXXii3u/s/dkjf7gNQaPcy3rXVivasXWYcwBcB/D3aeOxI\nv+7AJhy3dpzZbwBw0t3fcPcVAD8DcHsb+tHx3P0ZAHPvu/l2AA/Xfn4Yq0+WTRfoW0dw90l3f772\n8zyAd5YZb+uxI/3aFO1I9r0AJtb8fhqdtd67A/iNmR01s8Pt7sw6Rt19svbzFIDRdnZmHdFlvDfT\n+5YZ75hjV8/y543SB3QfdKO7Xw/gVgBfq71d7Ui++j9YJ9VON7SM92ZZZ5nxd7Xz2NW7/Hmj2pHs\nZwDsW/P7ZbXbOoK7n6l9nwHwK3TeUtTT76ygW/s+0+b+vKuTlvFeb5lxdMCxa+fy5+1I9ucAHDSz\nK82sG8CXATzWhn58gJn11z44gZn1A/gCOm8p6scA3FX7+S4Aj7axL+/RKct4h5YZR5uPXduXP3f3\nTf8CcBtWP5F/HcA/taMPgX5dBeD3ta+X2903AI9g9W1dCaufbdwNYAeApwCcAPAkgJEO6tu/AXgR\nwAtYTaw9berbjVh9i/4CgGO1r9vafexIvzbluOlyWZFE6AM6kUQo2UUSoWQXSYSSXSQRSnaRRCjZ\nRRKhZBdJxP8DiQuUv16KLFsAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jtGIcba89YQm", + "colab_type": "text" + }, + "source": [ + "### Q2: Based on the above image, what weights might reasonably be pruned (i.e. explicitly forced to be zero)?\n", + "\n", + "The weights that are very purple. \n", + "\n", + "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array. " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4uh6-dKQ9YQs", + "colab_type": "code", + "colab": {} + }, + "source": [ + "new_mask = model.mask\n", + "ratio_threshold = .2\n", + "#drop the values that are less than ratio_threshold the size of the max weight in absolute vaue.\n", + "mask = torch.abs(model.W_0.data)/torch.abs(torch.max(model.W_0.data)) > ratio_threshold\n", + "new_mask.data = (new_mask.data)*(mask.float())\n", + "model.set_mask(new_mask)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eGuQ3t-E9YRB", + "colab_type": "text" + }, + "source": [ + "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OocvWnxK9YRW", + "colab_type": "text" + }, + "source": [ + "Now, we have explicitly set some entries in one of the the weight matrices to zero, and ensured via the mask, that they will not be updated by gradient descent. 
Fine tune the model: " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WgdWUpuD9YRa", + "colab_type": "code", + "outputId": "a4ca20e0-b690-4489-d6f7-af33980dbef6", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 101 + } + }, + "source": [ + "iter = 0\n", + "n_epochs = 5\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_pruned.h5')" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.1554800122976303. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.16709843277931213. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.1133299395442009. Accuracy: 96.\n", + "Iteration: 0. Loss: 0.14041125774383545. Accuracy: 96.\n", + "Iteration: 0. Loss: 0.12037742882966995. Accuracy: 96.\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vHGmJW2P9YRs", + "colab_type": "text" + }, + "source": [ + "### Q4: How much accuracy did you lose by pruning the model? 
How much \"compression\" did you achieve (here defined as total entries in W_0 divided by number of non-zero entries)? \n", + "\n", + "Not much I pruned the bottom weights less than 20% of the max weight in absoloute value and the accuracy degraded a couple percent but quickly recovered after\n", + "degrade at all. droppping 60 percent of weights dropped the accuracy\n", + "to 70% but it recovered into the 80s.\n", + "\n", + "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "QrhV1kL99YRw", + "colab_type": "code", + "outputId": "0ea77a0f-d18c-492d-9a47-604db53c8be1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + } + }, + "source": [ + "W_0 = model.W_0.detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOKUlEQVR4nO3dbYxc5XnG8evyxga85sWGZOOA05DU\naoKa1qQbhwpUkRAooCqGSiCsgoyK6lSCKCiRUkQ+xF9aWW1DGlVVKqcgDAQILbGgLYI4DipK1FIW\n5BgDSSDG1F6M18S8mA3xy/ruhz1GG7PzzHreztj3/yetZubc58y5NfLlM3OeOfM4IgTg2Der7gYA\n9AZhB5Ig7EAShB1IgrADSbynlzsbGByM2acs6OUugVT2v75bE+Pjnq7WVthtXyzpm5IGJP1LRKwu\nrT/7lAX64F9+qZ1dAij4v3++pWGt5bfxtgck/ZOkSySdJWm57bNafT4A3dXOZ/alkl6IiC0RsU/S\nvZKWdaYtAJ3WTthPl7RtyuPt1bLfYHul7RHbIxPj423sDkA7un42PiLWRMRwRAwPDA52e3cAGmgn\n7KOSFk15fEa1DEAfaifsT0habPtM23MkXSXpwc60BaDTWh56i4gDtm+Q9Igmh95ui4hnOtYZgI5q\na5w9Ih6S9FCHegHQRXxdFkiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k\nQdiBJHr6U9LovTPv312sH9z002J9YP788g4GmhwvFpzSsLTlz4bK26KjOLIDSRB2IAnCDiRB2IEk\nCDuQBGEHkiDsQBKMsx/jdp5bHic/6YxPFutzX3y9WD/4wkvF+qyJg4Uq4+y9xJEdSIKwA0kQdiAJ\nwg4kQdiBJAg7kARhB5JgnP0Y96uFTervn11e4VPvbbKHZnX0i7bCbnurpD2SJiQdiIjhTjQFoPM6\ncWT/dES82oHnAdBFfGYHkmg37CHp+7aftL1yuhVsr7Q9YntkYny8zd0BaFW7b+PPi4hR2++TtN72\nTyPisakrRMQaSWsk6fjTF0Wb+wPQoraO7BExWt2OSVonaWknmgLQeS2H3fag7RMP3Zd0kaTNnWoM\nQGe18zZ+SNI624ee5+6IeLgjXSVz+zX/WKxfe+cXWn7u1cvvbHlbSZrrvcX6d1/9VLH+2Jbfblib\n2H1ccds5vxwo1uUmnwrD5XoyLYc9IrZI+v0O9gKgixh6A5Ig7EAShB1IgrADSRB2IAkuce0Dowea\nTIvcxKqr7mlY2/z2GcVtzxv8ebHerLfX9p1QrB93/P6GtfFZc4rbNh1awxHhyA4kQdiBJAg7kARh\nB5Ig7EAShB1IgrADSTDOPkOlsey5s8qXgX7l7muL9a/ec3UrLb1j/GDjS0XvWveZ4rZ3+dNt7bup\nwmWmxzGO3lMc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZZ2jVvcvr23mT8ei/u+9PW3/uZj+3\nzM81HzM4sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyzJzd7T7m+6KFfFusxq3y8ePGKBYWNGaPv\npaZHdtu32R6zvXnKsgW219t+vrptb5YDAF03k7fxt0u6+LBlN0naEBGLJW2oHgPoY03DHhGPSdp9\n2OJlktZW99dKuqzDfQHosFZP0A1FxI7q/iuShhqtaHul7RHbIxPj4y3uDkC72j4bHxEhqeHVEhGx\nJiKGI2J4YHCw3d0BaFGrYd9pe6EkVbdjnWsJQDe0GvYHJa2o7q+Q9EBn2gHQLU3H2W3fI+l8SafZ\n3i7pa5JWS7rP9nWSXpJ0ZTebPNbdt+KWYn3jr8tzrP/Nd1t/+Ztdrr7/1PJHr22fLc/Pfqz64CPl\n80/7Ti7PPf/KObM72c6MNA17RDT61YYLOtwLgC7i67JAEoQdS
IKwA0kQdiAJwg4kwSWuPXDjFeWv\nIVy59ks96mQaTYbeZu2f6E0fR5lZ+w8W63UMrTXDkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCc\nvQf+4V+X1d1CQ/tOLtdfumReezsoXUPbxz8l/YEf7yvWvXd/jzrpHI7sQBKEHUiCsANJEHYgCcIO\nJEHYgSQIO5AE4+zANF4+t/xT0It+ePT9hDZHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2tGX4\nwmeL9ZEffKxh7eI/+d/itg//5yfLO6/xevhtnzkGx9lt32Z7zPbmKctW2R61vbH6u7S7bQJo10ze\nxt8u6eJpln8jIpZUfw91ti0AndY07BHxmKTdPegFQBe1c4LuBtubqrf58xutZHul7RHbIxPj423s\nDkA7Wg37tyR9RNISSTskfb3RihGxJiKGI2J4YHCwxd0BaFdLYY+InRExEREHJX1b0tLOtgWg01oK\nu+2FUx5eLmlzo3UB9Iem4+y275F0vqTTbG+X9DVJ59teosnZvbdK+nwXe0QX/d4FPyvWN234nWJ9\nZP1ZLe977NcntrwtjlzTsEfE8mkW39qFXgB0EV+XBZIg7EAShB1IgrADSRB2IAkucT0KvOdX5fqB\nua0/9wdOeKNY39T6Uzf11IaPdvHZcTiO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsfeCU5w8W\n654obz+wPxrWXv34QHHbh/+jxt8dqfGnoDPiyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3geO\nf608kD6rMI4uSXN2vd2wNu/F8r63fu7k8gpNDI0cKNZ3DvNPrF9wZAeSIOxAEoQdSIKwA0kQdiAJ\nwg4kQdiBJBgE7QN7Typfcz7v5b3F+v4FxzculofoNfREeZx87va3inWP7irWP/w/heL8k4rbbrn6\nfcU618MfmaZHdtuLbD9q+1nbz9j+YrV8ge31tp+vbud3v10ArZrJ2/gDkr4cEWdJOkfS9bbPknST\npA0RsVjShuoxgD7VNOwRsSMinqru75H0nKTTJS2TtLZaba2ky7rVJID2HdEJOtsfknS2pMclDUXE\njqr0iqShBtustD1ie2RifLyNVgG0Y8Zhtz1P0v2SboyIN6fWIiLU4FRQRKyJiOGIGB4YHGyrWQCt\nm1HYbc/WZNC/ExHfqxbvtL2wqi+UNNadFgF0QtOhN9uWdKuk5yLilimlByWtkLS6un2gKx0m8NpH\ny//nzttWHj+bNdH4EtltF51Q3Pb9j5eH3ma9Xh56iybDZ2997NSGtbE/KA85MrTWWTMZZz9X0jWS\nnra9sVp2syZDfp/t6yS9JOnK7rQIoBOahj0ifiSp0X+xF3S2HQDdwtdlgSQIO5AEYQeSIOxAEoQd\nSIJLXI8C2y6c28bWTa5xbeIXf356W9sXNRlH/+vldxXrS457uVg/rvD0T+0tXz77lbuvLdaPRhzZ\ngSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTe2Xp7PIKTYbpv3DFvxfr15+yrWFtb+wvbvtfb5e/\nX/Bvb55drE9E42PZHevyXbDJkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCc/VjX5Jrxqy//YbH+\nuZM2Fuu7Jsqz/Nz+ZuPrxu/cfk5x29Efn1Gs48hwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGYy\nP/siSXdIGtLk1c1rIuKbtldJ+gtJu6pVb46Ih7rVaGZ/+MdPF+uL5441rL12oHxN+Gw3nttdkta9\n8Yli/SdvlH9X/tlHFxfr6J2ZfKnmgKQvR8RTtk+U9KTt9VXtGxHx991rD0CnzGR+9h2SdlT399h+\nTlIXpwkB0A1H9Jnd9ocknS3p8WrRDbY32b7N9vwG26y0PWJ7ZGJ8vK1mAbRuxmG3PU/S/ZJujIg3\nJX1L0kckLdHkkf/r020XE
WsiYjgihgcGy9+jBtA9Mwq77dmaDPp3IuJ7khQROyNiIiIOSvq2pKXd\naxNAu5qG3bYl3SrpuYi4ZcryhVNWu1zS5s63B6BTZnI2/lxJ10h62vah6x1vlrTc9hJNDsdtlfT5\nrnQI/fcjHy/Xe9QHjm4zORv/I0nTXRTNmDpwFOEbdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEH\nkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQcEb3bmb1L0ktTFp0m6dWeNXBk+rW3fu1LordWdbK334qI\n905X6GnY37VzeyQihmtroKBfe+vXviR6a1WveuNtPJAEYQeSqDvsa2ref0m/9tavfUn01qqe9Fbr\nZ3YAvVP3kR1AjxB2IIlawm77Yts/s/2C7Zvq6KER21ttP217o+2Rmnu5zfaY7c1Tli2wvd7289Xt\ntHPs1dTbKtuj1Wu30falNfW2yPajtp+1/YztL1bLa33tCn315HXr+Wd22wOSfi7pQknbJT0haXlE\nPNvTRhqwvVXScETU/gUM238k6S1Jd0TE71bL/lbS7ohYXf1HOT8i/qpPelsl6a26p/GuZitaOHWa\ncUmXSbpWNb52hb6uVA9etzqO7EslvRARWyJin6R7JS2roY++FxGPSdp92OJlktZW99dq8h9LzzXo\nrS9ExI6IeKq6v0fSoWnGa33tCn31RB1hP13StimPt6u/5nsPSd+3/aTtlXU3M42hiNhR3X9F0lCd\nzUyj6TTevXTYNON989q1Mv15uzhB927nRcQnJF0i6frq7WpfisnPYP00djqjabx7ZZppxt9R52vX\n6vTn7aoj7KOSFk15fEa1rC9ExGh1OyZpnfpvKuqdh2bQrW7Hau7nHf00jfd004yrD167Oqc/ryPs\nT0habPtM23MkXSXpwRr6eBfbg9WJE9kelHSR+m8q6gclrajur5D0QI29/IZ+mca70TTjqvm1q336\n84jo+Z+kSzV5Rv4Xkr5aRw8N+vqwpJ9Uf8/U3ZukezT5tm6/Js9tXCfpVEkbJD0v6QeSFvRRb3dK\nelrSJk0Ga2FNvZ2nybfomyRtrP4urfu1K/TVk9eNr8sCSXCCDkiCsANJEHYgCcIOJEHYgSQIO5AE\nYQeS+H9oSiuxioM8OwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8Z3f8XNK9YR_", + "colab_type": "text" + }, + "source": [ + "## Quantization\n", + "\n", + "Now that we have a pruned model that appears to be performing well, let's see if we can make it even smaller by quantization. To do this, we'll need a slightly different neural network, one that corresponds to Figure 3 from the paper. Instead of having a matrix of float values, we'll have a matrix of integer labels (here called \"labels\") that correspond to entries in a (hopefully) small codebook of centroids (here called \"centroids\"). The way that I've coded it, there's still a mask that enforces our desired sparsity pattern." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2Of2GpCU9YSD", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MultilayerPerceptronQuantized(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim, hidden_dim,mask,labels,centroids):\n", + " super(MultilayerPerceptronQuantized, self).__init__()\n", + " self.mask = torch.nn.Parameter(mask,requires_grad=False)\n", + " self.labels = torch.nn.Parameter(labels,requires_grad=False)\n", + " self.centroids = torch.nn.Parameter(centroids,requires_grad=True)\n", + "\n", + " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim))\n", + "\n", + " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim))\n", + " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim))\n", + "\n", + " def forward(self, x):\n", + " W_0 = self.mask*(self.centroids[self.labels].reshape(784,64))\n", + " hidden = torch.tanh(x@W_0 + self.b_0)\n", + " outputs = hidden@self.W_1 + self.b_1\n", + " return outputs" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DnBF8PwJ9YSQ", + "colab_type": "text" + }, + "source": [ + "Notice what is happening in the forward method: W_0 is being reconstructed by using a matrix 
(self.labels) to index into a vector (self.centroids). The beauty of automatic differentiation allows backpropogation through this sort of weird indexing operation, and thus gives us gradients of the objective function with respect to the centroid values!\n", + "\n", + "### Q6: However, before we are able to use this AD magic, we need to specify the static label matrix (and an initial guess for centroids). Use the k-means algorithm (or something else if you prefer) figure out the label matrix and centroid vectors. PROTIP1: I used scikit-learns implementation of k-means. PROTIP2: only cluster the non-zero entries" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bn9LC7zY9YSV", + "colab_type": "code", + "colab": {} + }, + "source": [ + "from sklearn.cluster import KMeans\n", + "import numpy as np\n", + "# convert weight and mask matrices into numpy arrays\n", + "W_0 = model.W_0.detach().cpu().numpy()\n", + "mask = model.mask.detach().cpu().numpy()\n", + "\n", + "# Figure out the indices of non-zero entries \n", + "inds = np.where(mask!=0)\n", + "# Figure out the values of non-zero entries\n", + "vals = W_0[inds]\n", + "\n", + "### TODO: perform clustering on vals\n", + "kmean = KMeans(n_clusters=2)\n", + "clusters = kmean.fit(vals.reshape(len(vals),1))\n", + "centroids = kmean.cluster_centers_\n", + "labels = []\n", + "\n", + "for val in W_0.reshape(784*64):\n", + " label = torch.argmin(torch.abs(val - torch.from_numpy(centroids)))\n", + " labels.append(label.data)\n", + "\n", + " \n", + "\n", + "### TODO: turn the label matrix and centroids into a torch tensor\n", + "labels = torch.tensor(labels,dtype=torch.long,device=device)\n", + "centroids = torch.tensor(centroids,device=device)\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y4SE1iks9YSk", + "colab_type": "text" + }, + "source": [ + "Now, we can instantiate our quantized model and import the appropriate pre-trained weights for the other 
network layers. " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "SLJS3aTV9YSn", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Instantiate quantized model\n", + "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,new_mask,labels,centroids)\n", + "model_q = model_q.to(device)\n", + "\n", + "# Copy pre-trained weights from unquantized model for non-quantized layers\n", + "model_q.b_0.data = model.b_0.data\n", + "model_q.W_1.data = model.W_1.data\n", + "model_q.b_1.data = model.b_1.data" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ijPdAe149YSy", + "colab_type": "text" + }, + "source": [ + "Finally, we can fine tune the quantized model. We'll adjust not only the centroids, but also the weights in the other layers." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lm5n1IfC9YS1", + "colab_type": "code", + "outputId": "e80a11c4-992f-461e-af3c-7ea68cc72808", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 101 + } + }, + "source": [ + "optimizer = torch.optim.Adam(model_q.parameters(), lr=lr_rate, weight_decay=1e-3)\n", + "iter = 0\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model_q(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model_q(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " 
accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_quantized.h5')" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.14308680593967438. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.18871541321277618. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.16591843962669373. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.12692667543888092. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.14464326202869415. Accuracy: 95.\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AZvRY6vQIffp", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0ksmxQvc9YTA", + "colab_type": "text" + }, + "source": [ + "After retraining, we can, just for fun, reconstruct the pruned and quantized weights and plot them as images:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bbEWvCON9YTC", + "colab_type": "code", + "outputId": "41a66a18-6b1e-4051-9936-b848b75ef4a7", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + } + }, + "source": [ + "W_0 = (model_q.mask*model_q.centroids[model_q.labels].reshape(784,64)).detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAALm0lEQVR4nO3dT4ic9R3H8c+nq/aw6yFb2yWNabWa\nSyg0liUISrFIJeYSvQRzkBSka0HBgIWKPZhjKP49FGWtwVisIqiYQ2hNgxA8KK6SxsS0TZSISdes\nkoPZXGzWbw/7RNZkd2ec53nmeXa/7xcsM/M8M/N8ffCT3zPPd575OSIEYPn7TtMFAOgPwg4kQdiB\nJAg7kARhB5K4pJ8bGxgajEuGh/u5SSCVc6dPa2b6rOdbVyrstjdIekLSgKQ/R8SOxZ5/yfCwfvi7\nbWU2CWAR/3348QXX9XwYb3tA0p8k3SppraQtttf2+n4A6lXmM/t6Scci4qOI+FLSi5I2VVMWgKqV\nCfsqSZ/MeXyiWPYNtsdsT9iemJk+W2JzAMqo/Wx8RIxHxGhEjA4MDda9OQALKBP2k5JWz3l8ZbEM\nQAuVCfs7ktbYvtr2ZZLukLS7mrIAVK3n1ltEnLN9r6S/a7b1tjMiDldWGYBKleqzR8QeSXsqqgVA\njfi6LJAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfT1p6TRf9du\ne6vpEhZ07PHrmy4hFUZ2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE\nYQeSIOxAEoQdSILr2Zc5rhnHeaXCbvu4pDOSZiSdi4jRKooCUL0qRvZfRsTnFbwPgBrxmR1IomzY\nQ9Lrtt+1PTbfE2yP2Z6wPTEzfbbk5gD0quxh/I0RcdL2DyTttf2viNg/9wkRMS5pXJK++6PVUXJ7\nAHpUamSPiJPF7ZSkVyWtr6IoANXrOey2B21ffv6+pFskHaqqMADVKnMYPyLpVdvn3+evEfG3SqpK\n5sPNTy26/pqXflvbe7dZmf9uXKznsEfER5J+VmEtAGpE6w1IgrADSRB2IAnCDiRB2IEkuMR1GVjK\n7TX0DyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRBn71LZXrZnS7VrPNSTi4TxXmM7EAShB1IgrAD\nSRB2IAnCDiRB2IEkCDuQBH32LrW5X93m2tAejOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kAR99uSu\n3fZWre9/7PHra31/dK/jyG57p+0p24fmLBu2vdf20eJ2Rb1lAiirm8P4ZyVtuGDZA5L2RcQaSfuK\nxwBarGPYI2K/pNMXLN4kaVdxf5ek2yquC0DFej1BNxIRk8X9TyWNLPRE22O2J2xPzEyf7XFzAMoq\nfTY+IkJSLLJ+PCJGI2J0YGiw7OYA9KjXsJ+yvVKSitup6koCUIdew75b0tbi/lZJr1VTDoC6dOyz\n235B0k2SrrB9QtJDknZIesn2XZI+lrS5ziKXu7Lzqzd5PXvWPnrZ7yc0sd86hj0itiyw6uaKawFQ\nI74uCyRB2IEkCDuQBGEHkiDsQBJc4toHnVpr/BT08tPGliQjO5AEYQeSIOxAEoQdSIKwA0kQdiAJ\nwg4kQZ+9D9rcR29jP7gN6v6J7SYwsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEvTZgXl0+v7BUuzD\nM7IDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL02VFKmd/EX8q/p78Ufweg48hue6ftKduH5izbbvuk\n7QPF38Z6ywRQVjeH8c9K2jDP8sciYl3xt6fasgBUrWPYI2K/pNN9qAVAjcqcoLvX9sHiMH/FQk+y\nPWZ7wvbEzPTZEpsDUEavYX9S0jWS1kmalPTIQk+MiPGIGI2I0
YGhwR43B6CsnsIeEaciYiYivpL0\ntKT11ZYFoGo9hd32yjkPb5d0aKHnAmiHjn122y9IuknSFbZPSHpI0k2210kKSccl3V1jjahR2V53\nm3vh+KaOYY+ILfMsfqaGWgDUiK/LAkkQdiAJwg4kQdiBJAg7kASXuC4BnX62eClebinRtus3RnYg\nCcIOJEHYgSQIO5AEYQeSIOxAEoQdSII+ewuUnf53sdd36sHT686DkR1IgrADSRB2IAnCDiRB2IEk\nCDuQBGEHkqDPvszVfS38cr3WfjliZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJOizJ1f2Wvo6358e\nfbU6juy2V9t+w/YHtg/bvq9YPmx7r+2jxe2K+ssF0KtuDuPPSbo/ItZKul7SPbbXSnpA0r6IWCNp\nX/EYQEt1DHtETEbEe8X9M5KOSFolaZOkXcXTdkm6ra4iAZT3rU7Q2b5K0nWS3pY0EhGTxapPJY0s\n8Jox2xO2J2amz5YoFUAZXYfd9pCklyVti4gv5q6LiJAU870uIsYjYjQiRgeGBksVC6B3XYXd9qWa\nDfrzEfFKsfiU7ZXF+pWSpuopEUAVOrbebFvSM5KORMSjc1btlrRV0o7i9rVaKkygU4upzvZV3a23\nxdBa669u+uw3SLpT0vu2DxTLHtRsyF+yfZekjyVtrqdEAFXoGPaIeFOSF1h9c7XlAKgLX5cFkiDs\nQBKEHUiCsANJEHYgCS5xXQKa7Ec3ue0PNz/V2LaX41TWjOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7\nkAR99uTK9tGb7IWXsRz76J0wsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEvTZk2tznzxjL7xOjOxA\nEoQdSIKwA0kQdiAJwg4kQdiBJAg7kEQ387OvlvScpBFJIWk8Ip6wvV3SbyR9Vjz1wYjYU1ehmbW5\nF94JvfL26OZLNeck3R8R79m+XNK7tvcW6x6LiIfrKw9AVbqZn31S0mRx/4ztI5JW1V0YgGp9q8/s\ntq+SdJ2kt4tF99o+aHun7RULvGbM9oTtiZnps6WKBdC7rsNue0jSy5K2RcQXkp6UdI2kdZod+R+Z\n73URMR4RoxExOjA0WEHJAHrRVdhtX6rZoD8fEa9IUkScioiZiPhK0tOS1tdXJoCyOobdtiU9I+lI\nRDw6Z/nKOU+7XdKh6ssDUJVuzsbfIOlOSe/bPlAse1DSFtvrNNuOOy7p7loqBO0rVKKbs/FvSvI8\nq+ipA0sI36ADkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k\n4Yjo38bszyR9PGfRFZI+71sB305ba2trXRK19arK2n4cEd+fb0Vfw37Rxu2JiBhtrIBFtLW2ttYl\nUVuv+lUbh/FAEoQdSKLpsI83vP3FtLW2ttYlUVuv+lJbo5/ZAfRP0yM7gD4h7EASjYTd9gbb/7Z9\nzPYDTdSwENvHbb9v+4DtiYZr2Wl7yvahOcuGbe+1fbS4nXeOvYZq2277ZLHvDtje2FBtq22/YfsD\n24dt31csb3TfLVJXX/Zb3z+z2x6Q9B9Jv5J0QtI7krZExAd9LWQBto9LGo2Ixr+AYfsXkqYlPRcR\nPy2W/VHS6YjYUfxDuSIift+S2rZLmm56Gu9itqKVc6cZl3SbpF+rwX23SF2b1Yf91sTIvl7SsYj4\nKCK+lPSipE0N1NF6EbFf0ukLFm+StKu4v0uz/7P03QK1tUJETEbEe8X9M5LOTzPe6L5bpK6+aCLs\nqyR9MufxCbVrvveQ9Lrtd22PNV3MPEYiYrK4/6mkkSaLmUfHabz76YJpxluz73qZ/rwsTtBd7MaI\n+LmkWyXdUxyutlLMfgZrU++0q2m8+2Weaca/1uS+63X687KaCPtJSavnPL6yWNYKEXGyuJ2S9Kra\nNxX1qfMz6Ba3Uw3X87U2T
eM93zTjasG+a3L68ybC/o6kNbavtn2ZpDsk7W6gjovYHixOnMj2oKRb\n1L6pqHdL2lrc3yrptQZr+Ya2TOO90DTjanjfNT79eUT0/U/SRs2ekf9Q0h+aqGGBun4i6Z/F3+Gm\na5P0gmYP6/6n2XMbd0n6nqR9ko5K+oek4RbV9hdJ70s6qNlgrWyoths1e4h+UNKB4m9j0/tukbr6\nst/4uiyQBCfogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wNPDLQpJGSWBgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_6ar0vcD9YTK", + "colab_type": "text" + }, + "source": [ + "Certainly a much more parsimonious representation. The obvious question now becomes:\n", + "\n", + "### Q7: How low can you go? How small can the centroid codebook be before we see a substantial degradation in test set accuracy?\n", + "\n", + "I got great results all the way down to two. Though this may because the \n", + "bias and W_1 values weren't restricted. It would be interesting to see what would happen if we restricted those weights to a small code book as well. \n", + "\n", + "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? \n", + "\n", + "Less accurate, but not by a large amount. A final accuracy of 94% as opposed to 96%." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GGUePifO9YTM", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file From 7de5910115d4e98a6453001d7d52ca08cb5b646e Mon Sep 17 00:00:00 2001 From: tmayer868 Date: Wed, 1 Apr 2020 12:49:50 -0600 Subject: [PATCH 2/3] Delete deep_compression_exercise .ipynb --- deep_compression_exercise .ipynb | 1825 ------------------------------ 1 file changed, 1825 deletions(-) delete mode 100644 deep_compression_exercise .ipynb diff --git a/deep_compression_exercise .ipynb b/deep_compression_exercise .ipynb deleted file mode 100644 index 81451f0..0000000 --- a/deep_compression_exercise .ipynb +++ /dev/null @@ -1,1825 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.1" - }, - "colab": { - "name": "deep_compression_exercise.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "accelerator": "GPU", - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "be09c445b75f468e8dcfa18358ce4a36": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_5aac198c9a3d4eea9be7fb8beeb227ce", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_03e6a001de424248a4122d33143e75f0", - "IPY_MODEL_c005a5ba50db480caba161765274b3f9" - ] - } - }, - "5aac198c9a3d4eea9be7fb8beeb227ce": { - 
"model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "03e6a001de424248a4122d33143e75f0": { - "model_module": "@jupyter-widgets/controls", - "model_name": "IntProgressModel", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_60e876032db7469e846f94c3ac3d5942", - "_dom_classes": [], - "description": "", - "_model_name": "IntProgressModel", - "bar_style": "info", - "max": 1, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 1, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_00e93303f3b447a1bf56f08dfde7de4d" - } - }, - "c005a5ba50db480caba161765274b3f9": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "state": { - "_view_name": 
"HTMLView", - "style": "IPY_MODEL_4403339d5717448687dcb5155c9e1239", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 9920512/? [00:20<00:00, 1487669.97it/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_ac8b2caa8abb45189966110e612433c1" - } - }, - "60e876032db7469e846f94c3ac3d5942": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } - }, - "00e93303f3b447a1bf56f08dfde7de4d": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - 
"grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "4403339d5717448687dcb5155c9e1239": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "ac8b2caa8abb45189966110e612433c1": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "fc589a1423bd406581f252ecf8ea97bd": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - 
"_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_e7f77c796bc14d9ab6f97137424b925f", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_0e847e7b236c41b79dc518091bf7067f", - "IPY_MODEL_3e2bf6a25441443c92dc5c3c2d1fef22" - ] - } - }, - "e7f77c796bc14d9ab6f97137424b925f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "0e847e7b236c41b79dc518091bf7067f": { - "model_module": "@jupyter-widgets/controls", - "model_name": "IntProgressModel", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_ac1223dd567b41589efcc1044c8fa962", - "_dom_classes": [], - "description": " 0%", - "_model_name": "IntProgressModel", - "bar_style": "info", - "max": 1, - "_view_module": 
"@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 0, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_092573b76819403ca18edea829d1bb31" - } - }, - "3e2bf6a25441443c92dc5c3c2d1fef22": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_7a47e9e824d345fa916984efef7b2f3d", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 0/28881 [00:00<?, ?it/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_08c5f10c348649a9a927c221cf2aa965" - } - }, - "ac1223dd567b41589efcc1044c8fa962": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } - }, - "092573b76819403ca18edea829d1bb31": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - 
"overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "7a47e9e824d345fa916984efef7b2f3d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "08c5f10c348649a9a927c221cf2aa965": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - 
"grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "5315c5e02993418a9d70d8568a24574c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_7d973a7b57544afcb7deeed203361142", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_6ebfc1cfcbca4b41aa3214b8ab36e790", - "IPY_MODEL_fc64f32f6bf741ffb46025054188bff0" - ] - } - }, - "7d973a7b57544afcb7deeed203361142": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": 
null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "6ebfc1cfcbca4b41aa3214b8ab36e790": { - "model_module": "@jupyter-widgets/controls", - "model_name": "IntProgressModel", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_9a43b17aa8bc46508b7fa9fabaf1bf32", - "_dom_classes": [], - "description": "", - "_model_name": "IntProgressModel", - "bar_style": "info", - "max": 1, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 1, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_438ac4bcdab249bcbbaccbe207981ffd" - } - }, - "fc64f32f6bf741ffb46025054188bff0": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_279398b30dc6479b99e352bbc4651b76", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 1654784/? 
[00:18<00:00, 511954.24it/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_7c8a5703523f4c0f80b8fee415f3fa83" - } - }, - "9a43b17aa8bc46508b7fa9fabaf1bf32": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } - }, - "438ac4bcdab249bcbbaccbe207981ffd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "279398b30dc6479b99e352bbc4651b76": { - "model_module": "@jupyter-widgets/controls", - "model_name": 
"DescriptionStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "7c8a5703523f4c0f80b8fee415f3fa83": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "e4c516162b844d6fbd33cf29584aae20": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_6779960e36614cef9c512a615a2980bf", - "_model_module": 
"@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_9b09958b45a441d29688f6eeba9707d3", - "IPY_MODEL_62a69547013748de9d80c4a992748698" - ] - } - }, - "6779960e36614cef9c512a615a2980bf": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "9b09958b45a441d29688f6eeba9707d3": { - "model_module": "@jupyter-widgets/controls", - "model_name": "IntProgressModel", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_f49a8270a4364784b39bc0eca205a8eb", - "_dom_classes": [], - "description": " 0%", - "_model_name": "IntProgressModel", - "bar_style": "info", - "max": 1, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 0, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": 
"IPY_MODEL_bca6c5f7f93942e98c6ef7412e311dfb" - } - }, - "62a69547013748de9d80c4a992748698": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_5a1aef0316fd4ebbbd436bd2ecf2383c", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 0/4542 [00:00<?, ?it/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_00f7f24316dc48b1adeb59b3d0dadc4b" - } - }, - "f49a8270a4364784b39bc0eca205a8eb": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } - }, - "bca6c5f7f93942e98c6ef7412e311dfb": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": 
null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "5a1aef0316fd4ebbbd436bd2ecf2383c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "00f7f24316dc48b1adeb59b3d0dadc4b": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - } - } - 
} - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "qe_xWm7z9YM6", - "colab_type": "text" - }, - "source": [ - "# Exercise Week 9: Pruning and Quantization\n", - "This week, we will explore some of the ideas discussed in Han, Mao, and Dally's Deep Compression. In particular, we will implement weight pruning with fine tuning, as well as k-means weight quantization. **Note that we will unfortunately not be doing this in a way that will actually lead to substantial efficiency gains: that would involve the use of sparse matrices which are not currently well-supported in pytorch.** \n", - "\n", - "## Training an MNIST classifier\n", - "For this example, we'll work with a basic multilayer perceptron with a single hidden layer. We will train it on the MNIST dataset so that it can classify handwritten digits. As usual we load the data:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Cysg2vAY9YNQ", - "colab_type": "code", - "outputId": "cbea6932-d4ff-497e-9ff3-39fa462cdcba", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 313, - "referenced_widgets": [ - "be09c445b75f468e8dcfa18358ce4a36", - "5aac198c9a3d4eea9be7fb8beeb227ce", - "03e6a001de424248a4122d33143e75f0", - "c005a5ba50db480caba161765274b3f9", - "60e876032db7469e846f94c3ac3d5942", - "00e93303f3b447a1bf56f08dfde7de4d", - "4403339d5717448687dcb5155c9e1239", - "ac8b2caa8abb45189966110e612433c1", - "fc589a1423bd406581f252ecf8ea97bd", - "e7f77c796bc14d9ab6f97137424b925f", - "0e847e7b236c41b79dc518091bf7067f", - "3e2bf6a25441443c92dc5c3c2d1fef22", - "ac1223dd567b41589efcc1044c8fa962", - "092573b76819403ca18edea829d1bb31", - "7a47e9e824d345fa916984efef7b2f3d", - "08c5f10c348649a9a927c221cf2aa965", - "5315c5e02993418a9d70d8568a24574c", - "7d973a7b57544afcb7deeed203361142", - "6ebfc1cfcbca4b41aa3214b8ab36e790", - "fc64f32f6bf741ffb46025054188bff0", - "9a43b17aa8bc46508b7fa9fabaf1bf32", - "438ac4bcdab249bcbbaccbe207981ffd", - "279398b30dc6479b99e352bbc4651b76", - 
"7c8a5703523f4c0f80b8fee415f3fa83", - "e4c516162b844d6fbd33cf29584aae20", - "6779960e36614cef9c512a615a2980bf", - "9b09958b45a441d29688f6eeba9707d3", - "62a69547013748de9d80c4a992748698", - "f49a8270a4364784b39bc0eca205a8eb", - "bca6c5f7f93942e98c6ef7412e311dfb", - "5a1aef0316fd4ebbbd436bd2ecf2383c", - "00f7f24316dc48b1adeb59b3d0dadc4b" - ] - } - }, - "source": [ - "import torch\n", - "import torchvision.transforms as transforms\n", - "import torchvision.datasets as datasets\n", - "\n", - "device = torch.device('cuda' if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n", - "test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n", - "\n", - "batch_size = 300\n", - "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", - "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "be09c445b75f468e8dcfa18358ce4a36", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "stream", - "text": [ - "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - 
"model_id": "fc589a1423bd406581f252ecf8ea97bd", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "stream", - "text": [ - "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5315c5e02993418a9d70d8568a24574c", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "stream", - "text": [ - "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e4c516162b844d6fbd33cf29584aae20", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "stream", - "text": [ - "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", - "Processing...\n", - "Done!\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8kdx9cAu9YOU", - "colab_type": "text" - }, - "source": [ - "Then define a model:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "UInyoax99YOf", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class MultilayerPerceptron(torch.nn.Module):\n", 
- " def __init__(self, input_dim, hidden_dim, output_dim,mask=None):\n", - " super(MultilayerPerceptron, self).__init__()\n", - " if not mask:\n", - " self.mask = torch.nn.Parameter(torch.ones(input_dim,hidden_dim),requires_grad=False)\n", - " else:\n", - " self.mask = torch.nn.Parameter(mask)\n", - "\n", - " self.W_0 = torch.nn.Parameter(1e-3*torch.randn(input_dim,hidden_dim)*self.mask,requires_grad=True)\n", - " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim),requires_grad=True)\n", - "\n", - " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim),requires_grad=True)\n", - " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim),requires_grad=True)\n", - " \n", - " def set_mask(self,mask):\n", - " \n", - " self.mask.data = mask.data\n", - " self.W_0.data = self.mask.data*self.W_0.data\n", - "\n", - " def forward(self, x):\n", - " hidden = torch.tanh(x@(self.W_0*self.mask) + self.b_0)\n", - " outputs = hidden@self.W_1 + self.b_1\n", - " return outputs\n" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GAM40Rnb9YPH", - "colab_type": "text" - }, - "source": [ - "Note that the above code is a little bit different than a standard multilayer perceptron implementation.\n", - "\n", - "### Q1: What does this model have the capability of doing that a \"Vanilla\" MLP does not. Why might we want this functionality for studying pruning?\n", - "\n", - "This model can \"mask\" parameters that we want to ignore, by\n", - "multiplying that paramter by 0. This has the same affect\n", - "as prunning the parameters even though we still technically\n", - "do multiplications and additions with those zeroed out parameters.\n", - "\n", - "Let's first train this model without utilizing this extra functionality. 
You can set the hidden layer size to whatever you'd like when instantiating the model:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IUZhjRk79YPV", - "colab_type": "code", - "colab": {} - }, - "source": [ - "n_epochs = 10\n", - "\n", - "input_dim = 784\n", - "hidden_dim = 64\n", - "output_dim = 10\n", - "\n", - "model = MultilayerPerceptron(input_dim,hidden_dim,output_dim)\n", - "model = model.to(device)\n", - "\n", - "criterion = torch.nn.CrossEntropyLoss() # computes softmax and then the cross entropy\n", - "lr_rate = 0.001\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-3)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "44vk9p1D9YPv", - "colab_type": "text" - }, - "source": [ - "And then training proceeds as normal." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "tinBThTV9YP4", - "colab_type": "code", - "outputId": "e857ad65-f505-42c4-bd7e-6ce69c6d56fb", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 185 - } - }, - "source": [ - "iter = 10\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. 
Accuracy: {}.\".format(epoch, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_pretrained.h5')\n" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Iteration: 0. Loss: 0.4008619785308838. Accuracy: 90.\n", - "Iteration: 1. Loss: 0.331943154335022. Accuracy: 92.\n", - "Iteration: 2. Loss: 0.2710331082344055. Accuracy: 93.\n", - "Iteration: 3. Loss: 0.21628224849700928. Accuracy: 93.\n", - "Iteration: 4. Loss: 0.18730275332927704. Accuracy: 94.\n", - "Iteration: 5. Loss: 0.18389032781124115. Accuracy: 95.\n", - "Iteration: 6. Loss: 0.1158808246254921. Accuracy: 95.\n", - "Iteration: 7. Loss: 0.13385257124900818. Accuracy: 95.\n", - "Iteration: 8. Loss: 0.18437758088111877. Accuracy: 95.\n", - "Iteration: 9. Loss: 0.1472378671169281. Accuracy: 95.\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jka2xOtp9YQO", - "colab_type": "text" - }, - "source": [ - "## Pruning\n", - "\n", - "Certainly not a state of the art model, but also not a terrible one. 
Because we're hoping to do some weight pruning, let's inspect some of the weights directly (recall that we can act like they're images)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IZVDikFr9YQU", - "colab_type": "code", - "outputId": "830c517c-54b2-4f0a-fad6-d014c12f0242", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 265 - } - }, - "source": [ - "import matplotlib.pyplot as plt\n", - "W_0 = model.W_0.detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAVv0lEQVR4nO3dXWyk1XkH8P8z4/H4e73eD7MsCwub\nTQui7UIs0iQ0Io2IgF6QXBSFi5RKqJuLICVRLhrRi3CJqnyIiyrSpqCQKiWKlERwQdoAQkGpWoSX\nbvhayC4U1rvxx3rN7nr8MZ6PpxcekAGf/zHz4Rlz/j/Jsj2Pz/ueeWeeecfzvOccc3eIyEdfpt0d\nEJHNoWQXSYSSXSQRSnaRRCjZRRLRtZk7y/b3e9fIyGbuUiQp5bk5VBYWbL1YQ8luZrcAeABAFsC/\nuvv97O+7RkZw2de/2cguRYQ4/cAPgrG638abWRbAvwC4FcA1AO40s2vq3Z6ItFYj/7PfAOCku7/h\n7isAfgbg9uZ0S0SarZFk3wtgYs3vp2u3vYeZHTazcTMbrxYWGtidiDSi5Z/Gu/sRdx9z97HMQH+r\ndyciAY0k+xkA+9b8flntNhHpQI0k+3MADprZlWbWDeDLAB5rTrdEpNnqLr25e9nM7gHwn1gtvT3k\n7i83rWdbSWTgYKYUaR57FCLbt3I4ll1Zt+T6rkqeb7zcz+PZZb79DNl/Nce3bbH7XeHxSp415m0/\nihqqs7v74wAeb1JfRKSFdLmsSCKU7CKJULKLJELJLpIIJbtIIpTsIonY1PHsbVWNxBt42YvVgzOl\nSFF3hYdzBR7PLpO2i5HORcIr2/iB8Wxk86R5NcePi8dq4ZG4kcfcM/yOx+5XNcfjlX7+hDPynGDX\nJjRCZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEpFO6S3yshYrxbChnFbhpRJWGgOAriUezxV437qW\nw/Ftx+dpWzv+Oo1ndu6gcWT5gS2PDgdjhf185qKlEX5cS4OREhU5bKVtvGk1UnpjZT0AQLnzxtDq\nzC6SCCW7SCKU7CKJULKLJELJLpIIJbtIIpTsIonYWnV2Vm6OlVyzjQ1pLG0LF1a9i2+7GOlbdp7v\nvHeab6B/Mrz/8jCbTxnIR+roXuDjayvnL9B4drkYjHWN7qdtV4Ybe3qWBsPHJTaFtkemuUbk2grv\njgxxXQqfZ2NTZMeeqyE6s4skQskukgglu0gilOwiiVCyiyRCyS6SCCW7SCK2Vp2diNXRY1MmV3sj\nxU1SS8/187mgPzY6S+On3t5O44X+yLjv68P7
nynx1/Puty6n8Wo3DSN3IbJkMzms5R6+7WqkFh67\n/oBNyRwbC78yzOvk1T7+fMks8mI4Gw9fbVFWNrRZM3sTwDyACoCyu481o1Mi0nzNeA35nLvzU5eI\ntJ3+ZxdJRKPJ7gB+Y2ZHzezwen9gZofNbNzMxquFhQZ3JyL1avRt/I3ufsbMdgN4wsxedfdn1v6B\nux8BcAQA8vv2RT4mE5FWaejM7u5nat9nAPwKwA3N6JSINF/dyW5m/WY2+M7PAL4A4KVmdUxEmquR\nt/GjAH5lZu9s59/d/T+a0qs6ZCKDxmM1W0Tq9NnecjgWaTtxPjx3OgD87YH/pfFnt++n8QxZM/qv\ndpykbT95I583/lM94fHoAPDfy3y8/G8LfxqMVSOTEDw5+Sc0/sdJfn0CiuFat5HHEwBwka/JbGV+\nnmTrDKxuIByKjbWPjXcPqTvZ3f0NAH9Rb3sR2VwqvYkkQskukgglu0gilOwiiVCyiyRiaw1xZeWK\n3khpLTadcx8vxVyxey4YO7/Ex2ru6ueXCb9WGKXxwgovb12/cyIYy0XqNNfled/y1kvjF6v8vrP9\nH7t4GW27Lc/Xuu67fIbGS5Vw6W2+yMfuLvbxY16c6aPxGCen2dhy0FatbzlondlFEqFkF0mEkl0k\nEUp2kUQo2UUSoWQXSYSSXSQRW6rO7plwLT06lfRwiYYvGblI41cNngvGlvr4cMiLJV6LPja5l8Z7\n83yq6qnloWDsf6b307a/PH2IxvNd/PqDtxd5Hf78RHh4r5GpngHgkqt5HX1unk+xvTwb7ltmkD8f\nunKRcaSRsEVG0Dq761ZfHT1GZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEqFkF0lEZ9XZI+N4Wf3R\nu3njbBeP5zI8fvLizmDszCyfKtrI9QEAMNDHx21nI+2fffWq8L4jSwfnz/F47xTf98owrwn3NjAH\nwfkFXsM3MoU2APSPhsfql0r8fnvkso3cPD9PDp/g7ef3hQ9MeSAylTR5qrJDojO7SCKU7CKJULKL\nJELJLpIIJbtIIpTsIolQsoskoqPq7JGyKa/Dx2r0EYslPiZ9aSUcr0SW783l+eDmvUN8LP0fpnfR\nePdUuG+9U7wOvvsonzc+c/RVGi9+7s9p/PxV4b6V6KBuYLnIH5MrR8NzDADAfDE893sBfF74pTcH\naXzXcf5kzS3yeDXPJo6nTVElh4Ud0uiZ3cweMrMZM3tpzW0jZvaEmZ2ofY8slC0i7baRt/E/BnDL\n+277NoCn3P0ggKdqv4tIB4smu7s/A+D9ax/dDuDh2s8PA/hik/slIk1W7wd0o+4+Wft5CkBwsTIz\nO2xm42Y2Xi3w/w9FpHUa/jTe3R1A8NMIdz/i7mPuPpYZ4BMEikjr1Jvs02a2BwBq3/k0oCLSdvUm\n+2MA7qr9fBeAR5vTHRFplWid3cweAXATgJ1mdhrAdwDcD+DnZnY3gLcA3NGMzjgfYkz+WUD0ZWto\nYInGF2PrdZ8ZCO+6yAuj134yvH46AFy/jccHuoo0/mxxfzBWyPN68qVPLtK47eFrx5+7htfCq+Sw\nLu/kF0dUCnzbJ4u7adwK4ad3/wR/sg3xw4Lh1+ZpvNLHU8uq4bH6bH2ERkST3d3vDIQ+3+S+iEgL\n6XJZkUQo2UUSoWQXSYSSXSQRSnaRRHTUENcYZ6MCK7z8tbjMS2sr0300PngyXKop7uClkt09BRp/\n+uzHaXwpMvx27+7zwdjEAh8eu3QFnwY7sxIpj/HKHpZHSfvYsOTIY4oyL5/1nwrHe2b5Y9b7Nl+T\nObPIl9FeuoRPg10aYPNB06Z8uWdCZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEqFkF0nElqqzU5Ga\nbXGeF4S7lvjrXpmU4VdG+M6fmTjAtx2Zinqwjw9x3Tf0djA2kQkvNQ0Ak5/mT4GBUzSMLO8aVR3m\nU2zH6uz9
r0em/94drqV389m7kZvndfaYxV38GoDKYOS+M3S+6PB91pldJBFKdpFEKNlFEqFkF0mE\nkl0kEUp2kUQo2UUSsaXq7HRJ5yqvyVoXr4V3v83bl4bIziPjixcn+PK/XQt8AxcQnsYaAOb6RsLB\nbn6/Y2Oji8P8D5YuiWx/JDzue9cOPh3z3Cv8GoElNlYeQO5i+FyWK/C2S7siU2Tnhmi8uxCZDpo8\nX2NzM1iZxMkDqjO7SCKU7CKJULKLJELJLpIIJbtIIpTsIolQsoskYkvV2anIMreZs3ze+NwC33xp\nG2l7nr9mVviu0bXE66qxWnhpe3jstfXycdmlyPUJ1W5+32676SiNv1EI18qvHpqibR99kdfZyz38\nMa/mw/GLV/L7lSnRMLZF5k8o5/lxzV0Ij3cvDUfmrGdzL5BDEj2zm9lDZjZjZi+tue0+MztjZsdq\nX7fFtiMi7bWRt/E/BnDLOrf/wN0P1b4eb263RKTZosnu7s8AmNuEvohICzXyAd09ZvZC7W3+9tAf\nmdlhMxs3s/FqIfKPsYi0TL3J/kMABwAcAjAJ4HuhP3T3I+4+5u5jmYH+OncnIo2qK9ndfdrdK+5e\nBfAjADc0t1si0mx1JbuZ7Vnz65cAvBT6WxHpDNE6u5k9AuAmADvN7DSA7wC4ycwOYbWq9yaAr7aw\nj+9i9ebYGOBskcfz5yM121y4/fwBPgd4pshfUz3D+9Y3FRkbTR7GA389QVuO5PnnKP/1Op/z/rmz\nl9P4aF94bfoMnaAA6L86PB8+AJyf4/8WVsl8/pUuXsuunODbLg7yx6zSy+PlQbL/LD8unqtvboVo\nsrv7nevc/GCsnYh0Fl0uK5IIJbtIIpTsIolQsoskQskukoiPzhDXWHUqotzD410L4R10FfjyvFVW\nKgHQfZ7ve+ityPK+Hn4YT0zvok2/+Wcv0PjvB/bS+NRk8Erp1Xg1HH+5b08wBgD5Hj7O1CLDmrty\n4eNWjQzt9ciw5HIfb1/khwVOhh7bfCQt2fDaRoa4ishHg5JdJBFKdpFEKNlFEqFkF0mEkl0kEUp2\nkURsqTo7GxHpkTp7bDrn+at4PLtIlsIt8p0PnOKvqSPHizSeP8WHemZWwkXdSg9f7vm7lZtp3Kf4\nBQhDE/y+VcnKxyuDfFnkxe15Gr/0Y2dpfHdfeEno1+f4NNXzfXyu6MIVkSm4+V2DLZBrMyJDf8ll\nFXSIq87sIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SiM6qs0eWwc2uhGOV7ZHaJB9yjmqktgkP\nvy52X4gsz1uIjGc/y6dz9jN8aeOecnhs9A7wenL3hV4aZ3VbAOib5mPOPRvewMIl/EFZXObxs7v4\nNQTFcvjpvVJqYMw4gAqbChqIrrNtJXLdBokBkWnTNZ5dRJTsIolQsoskQskukgglu0gilOwiiVCy\niySio+rsGV6yRZX1tsF543MX+etephyubWb5cPRYyRXLe3i9uOf/IhcJWHgH3XNLtOnwMp+Tvmvm\nIt/3ucik9zvDY+37TvFJBmY/wSdfP3eQDxrftXs23HaOH3PvijyhYvHIE9KWNz/1omd2M9tnZk+b\n2Stm9rKZfb12+4iZPWFmJ2rfI9Pii0g7beRtfBnAt9z9GgB/CeBrZnYNgG8DeMrdDwJ4qva7iHSo\naLK7+6S7P1/7eR7AcQB7AdwO4OHanz0M4Iut6qSINO5DfUBnZvsBXAfgWQCj7j5ZC00BGA20OWxm\n42Y2Xi3wa8BFpHU2nOxmNgDgFwC+4e7v+dTG3R2BTyTc/Yi7j7n7WGagv6HOikj9NpTsZpbDaqL/\n1N1/Wbt52sz21OJ7AMy0posi0gzRz//NzAA8COC4u39/TegxAHcBuL/2/dFGO1PlMwejkg+XM7KL\nkdJZZERi3ySvj/XNkA2Q0hcALO7kfStcyktI3Vfvp3EUw+Wzhf28xNR9np
ferLDI9x15t1bcOxyM\nLezhpbf5K/muD+6bpvEMGbac7+V13mo+UpLs4k+opSl+3LPL4edMbInvem2k2PcZAF8B8KKZHavd\ndi9Wk/znZnY3gLcA3NGSHopIU0ST3d1/h/AUBp9vbndEpFV0uaxIIpTsIolQsoskQskukgglu0gi\nOmqIa0MiL1uVSO2STXkMAMWh8A4W9vK2lV6+7/w53n7q04M03n0xvH2LXF+QnyXzcwOYvZmvZV24\nLHKNwYHw9gd28OGzn9ozQeMH+/h1XLlMuFZ+eT9fBvu56ctpfH6RXxTiXXwuas+En09smmkAdZ+i\ndWYXSYSSXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEbKk6e7YYqT8S1W5e6164lMcrA+G6aW4Xn665\n9HYPjcfWky4PRdYPHgqPze46zevBs5/g8Y9fe4rGPzF4jsbv3vlMMDYYmTs8F5mO+bdL/BqAt4rh\n5ar/uLiNtr0wz5eyLhd56lg5ch5ld63+pzmlM7tIIpTsIolQsoskQskukgglu0gilOwiiVCyiyRi\nS9XZ2dLHZIrw1bbdvFZd7ucDv3N94XHZ2wZ4nX02UpPFHK+zew/v2xWXzAVjxV1823+z92Ua/7vh\ncRo/UeL16vlq+BqDB2c/S9u+dmE3jc8t9NH48kp4Pv7iIp+r35f4Y5ZZiqxTEBuT3qJaOqMzu0gi\nlOwiiVCyiyRCyS6SCCW7SCKU7CKJULKLJGIj67PvA/ATAKNYHYV7xN0fMLP7APwDgLO1P73X3R9v\nVUeBeC2dYkV6AKjyjZcuhMd9nyvxWnYustZ3dT+/BuDWg6/S+Md6w/OnZ4xv+1APH6/+3PKlNP70\nhatp/OjsZcHYzOwQbVuN1LptmZ+rnKwVYCu8bZYPtY9NQdCRNnJRTRnAt9z9eTMbBHDUzJ6oxX7g\n7t9tXfdEpFk2sj77JIDJ2s/zZnYcwN5Wd0xEmutD/c9uZvsBXAfg2dpN95jZC2b2kJltD7Q5bGbj\nZjZeLSw01FkRqd+Gk93MBgD8AsA33P0igB8COADgEFbP/N9br527H3H3MXcfywz0N6HLIlKPDSW7\nmeWwmug/dfdfAoC7T7t7xd2rAH4E4IbWdVNEGhVNdjMzAA8COO7u319z+541f/YlAC81v3si0iwb\n+TT+MwC+AuBFMztWu+1eAHea2SGsluPeBPDVlvSwSbIF/rpm1cgwU1K58yXethRbLjoy/PbXL15L\n4yiS+xZb/XeA15i8EtnAPB8q2giL7DtaiiUVz2jbSKk2Mgt2R9rIp/G/w/pPmZbW1EWkuXQFnUgi\nlOwiiVCyiyRCyS6SCCW7SCKU7CKJ2FJTSTfCqo3N3cvqsrYS2XYsvhC7BoA3b2T132rk+oPYMyQ2\ncjgTu+9MdH7wSC2cLfEd6Vbsfm1FOrOLJELJLpIIJbtIIpTsIolQsoskQskukgglu0gizL2R+Zk/\n5M7MzgJ4a81NOwHMbloHPpxO7Vun9gtQ3+rVzL5d4e671gtsarJ/YOdm4+4+1rYOEJ3at07tF6C+\n1Wuz+qa38SKJULKLJKLdyX6kzftnOrVvndovQH2r16b0ra3/s4vI5mn3mV1ENomSXSQRbUl2M7vF\nzF4zs5Nm9u129CHEzN40sxfN7JiZjbe5Lw+Z2YyZvbTmthEze8LMTtS+r7vGXpv6dp+Znakdu2Nm\ndlub+rbPzJ42s1fM7GUz+3rt9rYeO9KvTTlum/4/u5llAfwBwM0ATgN4DsCd7v7KpnYkwMzeBDDm\n7m2/AMPMPgugAOAn7n5t7bZ/BjDn7vfXXii3u/s/dkjf7gNQaPcy3rXVivasXWYcwBcB/D3aeOxI\nv+7AJhy3dpzZbwBw0t3fcPcVAD8DcHsb+tHx3P0ZAHPvu/l2AA/Xfn4Yq0+WTRfoW0dw90l3f772\n8zyAd5YZb+uxI/3aFO1I9r0AJtb8fh
qdtd67A/iNmR01s8Pt7sw6Rt19svbzFIDRdnZmHdFlvDfT\n+5YZ75hjV8/y543SB3QfdKO7Xw/gVgBfq71d7Ui++j9YJ9VON7SM92ZZZ5nxd7Xz2NW7/Hmj2pHs\nZwDsW/P7ZbXbOoK7n6l9nwHwK3TeUtTT76ygW/s+0+b+vKuTlvFeb5lxdMCxa+fy5+1I9ucAHDSz\nK82sG8CXATzWhn58gJn11z44gZn1A/gCOm8p6scA3FX7+S4Aj7axL+/RKct4h5YZR5uPXduXP3f3\nTf8CcBtWP5F/HcA/taMPgX5dBeD3ta+X2903AI9g9W1dCaufbdwNYAeApwCcAPAkgJEO6tu/AXgR\nwAtYTaw9berbjVh9i/4CgGO1r9vafexIvzbluOlyWZFE6AM6kUQo2UUSoWQXSYSSXSQRSnaRRCjZ\nRRKhZBdJxP8DiQuUv16KLFsAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jtGIcba89YQm", - "colab_type": "text" - }, - "source": [ - "### Q2: Based on the above image, what weights might reasonably be pruned (i.e. explicitly forced to be zero)?\n", - "\n", - "The weights that are very purple. \n", - "\n", - "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array. " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4uh6-dKQ9YQs", - "colab_type": "code", - "colab": {} - }, - "source": [ - "new_mask = model.mask\n", - "ratio_threshold = .2\n", - "#drop the values that are less than ratio_threshold the size of the max weight in absolute vaue.\n", - "mask = torch.abs(model.W_0.data)/torch.abs(torch.max(model.W_0.data)) > ratio_threshold\n", - "new_mask.data = (new_mask.data)*(mask.float())\n", - "model.set_mask(new_mask)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eGuQ3t-E9YRB", - "colab_type": "text" - }, - "source": [ - "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OocvWnxK9YRW", - "colab_type": "text" - }, - "source": [ - "Now, we have explicitly set some entries in one of the the weight matrices to zero, and ensured via the mask, that they will not be updated by gradient descent. 
Fine tune the model: " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WgdWUpuD9YRa", - "colab_type": "code", - "outputId": "a4ca20e0-b690-4489-d6f7-af33980dbef6", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 101 - } - }, - "source": [ - "iter = 0\n", - "n_epochs = 5\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_pruned.h5')" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Iteration: 0. Loss: 0.1554800122976303. Accuracy: 95.\n", - "Iteration: 0. Loss: 0.16709843277931213. Accuracy: 95.\n", - "Iteration: 0. Loss: 0.1133299395442009. Accuracy: 96.\n", - "Iteration: 0. Loss: 0.14041125774383545. Accuracy: 96.\n", - "Iteration: 0. Loss: 0.12037742882966995. Accuracy: 96.\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vHGmJW2P9YRs", - "colab_type": "text" - }, - "source": [ - "### Q4: How much accuracy did you lose by pruning the model? 
How much \"compression\" did you achieve (here defined as total entries in W_0 divided by number of non-zero entries)? \n", - "\n", - "Not much I pruned the bottom weights less than 20% of the max weight in absoloute value and the accuracy degraded a couple percent but quickly recovered after\n", - "degrade at all. droppping 60 percent of weights dropped the accuracy\n", - "to 70% but it recovered into the 80s.\n", - "\n", - "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "QrhV1kL99YRw", - "colab_type": "code", - "outputId": "0ea77a0f-d18c-492d-9a47-604db53c8be1", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 265 - } - }, - "source": [ - "W_0 = model.W_0.detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOKUlEQVR4nO3dbYxc5XnG8evyxga85sWGZOOA05DU\naoKa1qQbhwpUkRAooCqGSiCsgoyK6lSCKCiRUkQ+xF9aWW1DGlVVKqcgDAQILbGgLYI4DipK1FIW\n5BgDSSDG1F6M18S8mA3xy/ruhz1GG7PzzHreztj3/yetZubc58y5NfLlM3OeOfM4IgTg2Der7gYA\n9AZhB5Ig7EAShB1IgrADSbynlzsbGByM2acs6OUugVT2v75bE+Pjnq7WVthtXyzpm5IGJP1LRKwu\nrT/7lAX64F9+qZ1dAij4v3++pWGt5bfxtgck/ZOkSySdJWm57bNafT4A3dXOZ/alkl6IiC0RsU/S\nvZKWdaYtAJ3WTthPl7RtyuPt1bLfYHul7RHbIxPj423sDkA7un42PiLWRMRwRAwPDA52e3cAGmgn\n7KOSFk15fEa1DEAfaifsT0habPtM23MkXSXpwc60BaDTWh56i4gDtm+Q9Igmh95ui4hnOtYZgI5q\na5w9Ih6S9FCHegHQRXxdFkiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k\nQdiBJHr6U9LovTPv312sH9z002J9YP788g4GmhwvFpzSsLTlz4bK26KjOLIDSRB2IAnCDiRB2IEk\nCDuQBGEHkiDsQBKMsx/jdp5bHic/6YxPFutzX3y9WD/4wkvF+qyJg4Uq4+y9xJEdSIKwA0kQdiAJ\nwg4kQdiBJAg7kARhB5JgnP0Y96uFTervn11e4VPvbbKHZnX0i7bCbnurpD2SJiQdiIjhTjQFoPM6\ncWT/dES82oHnAdBFfGYHkmg37CHp+7aftL1yuhVsr7Q9YntkYny8zd0BaFW7b+PPi4hR2++TtN72\nTyPisakrRMQaSWsk6fjTF0Wb+wPQoraO7BExWt2OSVonaWknmgLQeS2H3fag7RMP3Zd0kaTNnWoM\nQGe18zZ+SNI624ee5+6IeLgjXSVz+zX/WKxfe+cXWn7u1cvvbHlbSZrrvcX6d1/9VLH+2Jbfblib\n2H1ccds5vxwo1uUmnwrD5XoyLYc9IrZI+v0O9gKgixh6A5Ig7EAShB1IgrADSRB2IAkuce0Dowea\nTIvcxKqr7mlY2/z2GcVtzxv8ebHerLfX9p1QrB93/P6GtfFZc4rbNh1awxHhyA4kQdiBJAg7kARh\nB5Ig7EAShB1IgrADSTDOPkOlsey5s8qXgX7l7muL9a/ec3UrLb1j/GDjS0XvWveZ4rZ3+dNt7bup\nwmWmxzGO3lMc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZZ2jVvcvr23mT8ei/u+9PW3/uZj+3\nzM81HzM4sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyzJzd7T7m+6KFfFusxq3y8ePGKBYWNGaPv\npaZHdtu32R6zvXnKsgW219t+vrptb5YDAF03k7fxt0u6+LBlN0naEBGLJW2oHgPoY03DHhGPSdp9\n2OJlktZW99dKuqzDfQHosFZP0A1FxI7q/iuShhqtaHul7RHbIxPj4y3uDkC72j4bHxEhqeHVEhGx\nJiKGI2J4YHCw3d0BaFGrYd9pe6EkVbdjnWsJQDe0GvYHJa2o7q+Q9EBn2gHQLU3H2W3fI+l8SafZ\n3i7pa5JWS7rP9nWSXpJ0ZTebPNbdt+KWYn3jr8tzrP/Nd1t/+Ztdrr7/1PJHr22fLc/Pfqz64CPl\n80/7Ti7PPf/KObM72c6MNA17RDT61YYLOtwLgC7i67JAEoQdS
IKwA0kQdiAJwg4kwSWuPXDjFeWv\nIVy59ks96mQaTYbeZu2f6E0fR5lZ+w8W63UMrTXDkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCc\nvQf+4V+X1d1CQ/tOLtdfumReezsoXUPbxz8l/YEf7yvWvXd/jzrpHI7sQBKEHUiCsANJEHYgCcIO\nJEHYgSQIO5AE4+zANF4+t/xT0It+ePT9hDZHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2tGX4\nwmeL9ZEffKxh7eI/+d/itg//5yfLO6/xevhtnzkGx9lt32Z7zPbmKctW2R61vbH6u7S7bQJo10ze\nxt8u6eJpln8jIpZUfw91ti0AndY07BHxmKTdPegFQBe1c4LuBtubqrf58xutZHul7RHbIxPj423s\nDkA7Wg37tyR9RNISSTskfb3RihGxJiKGI2J4YHCwxd0BaFdLYY+InRExEREHJX1b0tLOtgWg01oK\nu+2FUx5eLmlzo3UB9Iem4+y275F0vqTTbG+X9DVJ59teosnZvbdK+nwXe0QX/d4FPyvWN234nWJ9\nZP1ZLe977NcntrwtjlzTsEfE8mkW39qFXgB0EV+XBZIg7EAShB1IgrADSRB2IAkucT0KvOdX5fqB\nua0/9wdOeKNY39T6Uzf11IaPdvHZcTiO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsfeCU5w8W\n654obz+wPxrWXv34QHHbh/+jxt8dqfGnoDPiyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3geO\nf608kD6rMI4uSXN2vd2wNu/F8r63fu7k8gpNDI0cKNZ3DvNPrF9wZAeSIOxAEoQdSIKwA0kQdiAJ\nwg4kQdiBJBgE7QN7Typfcz7v5b3F+v4FxzculofoNfREeZx87va3inWP7irWP/w/heL8k4rbbrn6\nfcU618MfmaZHdtuLbD9q+1nbz9j+YrV8ge31tp+vbud3v10ArZrJ2/gDkr4cEWdJOkfS9bbPknST\npA0RsVjShuoxgD7VNOwRsSMinqru75H0nKTTJS2TtLZaba2ky7rVJID2HdEJOtsfknS2pMclDUXE\njqr0iqShBtustD1ie2RifLyNVgG0Y8Zhtz1P0v2SboyIN6fWIiLU4FRQRKyJiOGIGB4YHGyrWQCt\nm1HYbc/WZNC/ExHfqxbvtL2wqi+UNNadFgF0QtOhN9uWdKuk5yLilimlByWtkLS6un2gKx0m8NpH\ny//nzttWHj+bNdH4EtltF51Q3Pb9j5eH3ma9Xh56iybDZ2997NSGtbE/KA85MrTWWTMZZz9X0jWS\nnra9sVp2syZDfp/t6yS9JOnK7rQIoBOahj0ifiSp0X+xF3S2HQDdwtdlgSQIO5AEYQeSIOxAEoQd\nSIJLXI8C2y6c28bWTa5xbeIXf356W9sXNRlH/+vldxXrS457uVg/rvD0T+0tXz77lbuvLdaPRhzZ\ngSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTe2Xp7PIKTYbpv3DFvxfr15+yrWFtb+wvbvtfb5e/\nX/Bvb55drE9E42PZHevyXbDJkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCc/VjX5Jrxqy//YbH+\nuZM2Fuu7Jsqz/Nz+ZuPrxu/cfk5x29Efn1Gs48hwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGYy\nP/siSXdIGtLk1c1rIuKbtldJ+gtJu6pVb46Ih7rVaGZ/+MdPF+uL5441rL12oHxN+Gw3nttdkta9\n8Yli/SdvlH9X/tlHFxfr6J2ZfKnmgKQvR8RTtk+U9KTt9VXtGxHx991rD0CnzGR+9h2SdlT399h+\nTlIXpwkB0A1H9Jnd9ocknS3p8WrRDbY32b7N9vwG26y0PWJ7ZGJ8vK1mAbRuxmG3PU/S/ZJujIg3\nJX1L0kckLdHkkf/r020XE
WsiYjgihgcGy9+jBtA9Mwq77dmaDPp3IuJ7khQROyNiIiIOSvq2pKXd\naxNAu5qG3bYl3SrpuYi4ZcryhVNWu1zS5s63B6BTZnI2/lxJ10h62vah6x1vlrTc9hJNDsdtlfT5\nrnQI/fcjHy/Xe9QHjm4zORv/I0nTXRTNmDpwFOEbdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEH\nkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQcEb3bmb1L0ktTFp0m6dWeNXBk+rW3fu1LordWdbK334qI\n905X6GnY37VzeyQihmtroKBfe+vXviR6a1WveuNtPJAEYQeSqDvsa2ref0m/9tavfUn01qqe9Fbr\nZ3YAvVP3kR1AjxB2IIlawm77Yts/s/2C7Zvq6KER21ttP217o+2Rmnu5zfaY7c1Tli2wvd7289Xt\ntHPs1dTbKtuj1Wu30falNfW2yPajtp+1/YztL1bLa33tCn315HXr+Wd22wOSfi7pQknbJT0haXlE\nPNvTRhqwvVXScETU/gUM238k6S1Jd0TE71bL/lbS7ohYXf1HOT8i/qpPelsl6a26p/GuZitaOHWa\ncUmXSbpWNb52hb6uVA9etzqO7EslvRARWyJin6R7JS2roY++FxGPSdp92OJlktZW99dq8h9LzzXo\nrS9ExI6IeKq6v0fSoWnGa33tCn31RB1hP13StimPt6u/5nsPSd+3/aTtlXU3M42hiNhR3X9F0lCd\nzUyj6TTevXTYNON989q1Mv15uzhB927nRcQnJF0i6frq7WpfisnPYP00djqjabx7ZZppxt9R52vX\n6vTn7aoj7KOSFk15fEa1rC9ExGh1OyZpnfpvKuqdh2bQrW7Hau7nHf00jfd004yrD167Oqc/ryPs\nT0habPtM23MkXSXpwRr6eBfbg9WJE9kelHSR+m8q6gclrajur5D0QI29/IZ+mca70TTjqvm1q336\n84jo+Z+kSzV5Rv4Xkr5aRw8N+vqwpJ9Uf8/U3ZukezT5tm6/Js9tXCfpVEkbJD0v6QeSFvRRb3dK\nelrSJk0Ga2FNvZ2nybfomyRtrP4urfu1K/TVk9eNr8sCSXCCDkiCsANJEHYgCcIOJEHYgSQIO5AE\nYQeS+H9oSiuxioM8OwAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8Z3f8XNK9YR_", - "colab_type": "text" - }, - "source": [ - "## Quantization\n", - "\n", - "Now that we have a pruned model that appears to be performing well, let's see if we can make it even smaller by quantization. To do this, we'll need a slightly different neural network, one that corresponds to Figure 3 from the paper. Instead of having a matrix of float values, we'll have a matrix of integer labels (here called \"labels\") that correspond to entries in a (hopefully) small codebook of centroids (here called \"centroids\"). The way that I've coded it, there's still a mask that enforces our desired sparsity pattern." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "2Of2GpCU9YSD", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class MultilayerPerceptronQuantized(torch.nn.Module):\n", - " def __init__(self, input_dim, output_dim, hidden_dim,mask,labels,centroids):\n", - " super(MultilayerPerceptronQuantized, self).__init__()\n", - " self.mask = torch.nn.Parameter(mask,requires_grad=False)\n", - " self.labels = torch.nn.Parameter(labels,requires_grad=False)\n", - " self.centroids = torch.nn.Parameter(centroids,requires_grad=True)\n", - "\n", - " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim))\n", - "\n", - " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim))\n", - " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim))\n", - "\n", - " def forward(self, x):\n", - " W_0 = self.mask*(self.centroids[self.labels].reshape(784,64))\n", - " hidden = torch.tanh(x@W_0 + self.b_0)\n", - " outputs = hidden@self.W_1 + self.b_1\n", - " return outputs" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DnBF8PwJ9YSQ", - "colab_type": "text" - }, - "source": [ - "Notice what is happening in the forward method: W_0 is being reconstructed by using a matrix 
(self.labels) to index into a vector (self.centroids). The beauty of automatic differentiation allows backpropogation through this sort of weird indexing operation, and thus gives us gradients of the objective function with respect to the centroid values!\n", - "\n", - "### Q6: However, before we are able to use this AD magic, we need to specify the static label matrix (and an initial guess for centroids). Use the k-means algorithm (or something else if you prefer) figure out the label matrix and centroid vectors. PROTIP1: I used scikit-learns implementation of k-means. PROTIP2: only cluster the non-zero entries" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "bn9LC7zY9YSV", - "colab_type": "code", - "colab": {} - }, - "source": [ - "from sklearn.cluster import KMeans\n", - "import numpy as np\n", - "# convert weight and mask matrices into numpy arrays\n", - "W_0 = model.W_0.detach().cpu().numpy()\n", - "mask = model.mask.detach().cpu().numpy()\n", - "\n", - "# Figure out the indices of non-zero entries \n", - "inds = np.where(mask!=0)\n", - "# Figure out the values of non-zero entries\n", - "vals = W_0[inds]\n", - "\n", - "### TODO: perform clustering on vals\n", - "kmean = KMeans(n_clusters=2)\n", - "clusters = kmean.fit(vals.reshape(len(vals),1))\n", - "centroids = kmean.cluster_centers_\n", - "labels = []\n", - "\n", - "for val in W_0.reshape(784*64):\n", - " label = torch.argmin(torch.abs(val - torch.from_numpy(centroids)))\n", - " labels.append(label.data)\n", - "\n", - " \n", - "\n", - "### TODO: turn the label matrix and centroids into a torch tensor\n", - "labels = torch.tensor(labels,dtype=torch.long,device=device)\n", - "centroids = torch.tensor(centroids,device=device)\n" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "y4SE1iks9YSk", - "colab_type": "text" - }, - "source": [ - "Now, we can instantiate our quantized model and import the appropriate pre-trained weights for the other 
network layers. " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "SLJS3aTV9YSn", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Instantiate quantized model\n", - "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,new_mask,labels,centroids)\n", - "model_q = model_q.to(device)\n", - "\n", - "# Copy pre-trained weights from unquantized model for non-quantized layers\n", - "model_q.b_0.data = model.b_0.data\n", - "model_q.W_1.data = model.W_1.data\n", - "model_q.b_1.data = model.b_1.data" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ijPdAe149YSy", - "colab_type": "text" - }, - "source": [ - "Finally, we can fine tune the quantized model. We'll adjust not only the centroids, but also the weights in the other layers." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lm5n1IfC9YS1", - "colab_type": "code", - "outputId": "e80a11c4-992f-461e-af3c-7ea68cc72808", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 101 - } - }, - "source": [ - "optimizer = torch.optim.Adam(model_q.parameters(), lr=lr_rate, weight_decay=1e-3)\n", - "iter = 0\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model_q(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model_q(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " 
accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_quantized.h5')" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Iteration: 0. Loss: 0.14308680593967438. Accuracy: 95.\n", - "Iteration: 0. Loss: 0.18871541321277618. Accuracy: 95.\n", - "Iteration: 0. Loss: 0.16591843962669373. Accuracy: 95.\n", - "Iteration: 0. Loss: 0.12692667543888092. Accuracy: 95.\n", - "Iteration: 0. Loss: 0.14464326202869415. Accuracy: 95.\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "AZvRY6vQIffp", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0ksmxQvc9YTA", - "colab_type": "text" - }, - "source": [ - "After retraining, we can, just for fun, reconstruct the pruned and quantized weights and plot them as images:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "bbEWvCON9YTC", - "colab_type": "code", - "outputId": "41a66a18-6b1e-4051-9936-b848b75ef4a7", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 265 - } - }, - "source": [ - "W_0 = (model_q.mask*model_q.centroids[model_q.labels].reshape(784,64)).detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAALm0lEQVR4nO3dT4ic9R3H8c+nq/aw6yFb2yWNabWa\nSyg0liUISrFIJeYSvQRzkBSka0HBgIWKPZhjKP49FGWtwVisIqiYQ2hNgxA8KK6SxsS0TZSISdes\nkoPZXGzWbw/7RNZkd2ec53nmeXa/7xcsM/M8M/N8ffCT3zPPd575OSIEYPn7TtMFAOgPwg4kQdiB\nJAg7kARhB5K4pJ8bGxgajEuGh/u5SSCVc6dPa2b6rOdbVyrstjdIekLSgKQ/R8SOxZ5/yfCwfvi7\nbWU2CWAR/3348QXX9XwYb3tA0p8k3SppraQtttf2+n4A6lXmM/t6Scci4qOI+FLSi5I2VVMWgKqV\nCfsqSZ/MeXyiWPYNtsdsT9iemJk+W2JzAMqo/Wx8RIxHxGhEjA4MDda9OQALKBP2k5JWz3l8ZbEM\nQAuVCfs7ktbYvtr2ZZLukLS7mrIAVK3n1ltEnLN9r6S/a7b1tjMiDldWGYBKleqzR8QeSXsqqgVA\njfi6LJAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfT1p6TRf9du\ne6vpEhZ07PHrmy4hFUZ2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE\nYQeSIOxAEoQdSILr2Zc5rhnHeaXCbvu4pDOSZiSdi4jRKooCUL0qRvZfRsTnFbwPgBrxmR1IomzY\nQ9Lrtt+1PTbfE2yP2Z6wPTEzfbbk5gD0quxh/I0RcdL2DyTttf2viNg/9wkRMS5pXJK++6PVUXJ7\nAHpUamSPiJPF7ZSkVyWtr6IoANXrOey2B21ffv6+pFskHaqqMADVKnMYPyLpVdvn3+evEfG3SqpK\n5sPNTy26/pqXflvbe7dZmf9uXKznsEfER5J+VmEtAGpE6w1IgrADSRB2IAnCDiRB2IEkuMR1GVjK\n7TX0DyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRBn71LZXrZnS7VrPNSTi4TxXmM7EAShB1IgrAD\nSRB2IAnCDiRB2IEkCDuQBH32LrW5X93m2tAejOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kAR99uSu\n3fZWre9/7PHra31/dK/jyG57p+0p24fmLBu2vdf20eJ2Rb1lAiirm8P4ZyVtuGDZA5L2RcQaSfuK\nxwBarGPYI2K/pNMXLN4kaVdxf5ek2yquC0DFej1BNxIRk8X9TyWNLPRE22O2J2xPzEyf7XFzAMoq\nfTY+IkJSLLJ+PCJGI2J0YGiw7OYA9KjXsJ+yvVKSitup6koCUIdew75b0tbi/lZJr1VTDoC6dOyz\n235B0k2SrrB9QtJDknZIesn2XZI+lrS5ziKXu7Lzqzd5PXvWPnrZ7yc0sd86hj0itiyw6uaKawFQ\nI74uCyRB2IEkCDuQBGEHkiDsQBJc4toHnVpr/BT08tPGliQjO5AEYQeSIOxAEoQdSIKwA0kQdiAJ\nwg4kQZ+9D9rcR29jP7gN6v6J7SYwsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEvTZgXl0+v7BUuzD\nM7IDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL02VFKmd/EX8q/p78Ufweg48hue6ftKduH5izbbvuk\n7QPF38Z6ywRQVjeH8c9K2jDP8sciYl3xt6fasgBUrWPYI2K/pNN9qAVAjcqcoLvX9sHiMH/FQk+y\nPWZ7wvbEzPTZEpsDUEavYX9S0jWS1kmalPTIQk+MiPGIGI2I0
YGhwR43B6CsnsIeEaciYiYivpL0\ntKT11ZYFoGo9hd32yjkPb5d0aKHnAmiHjn122y9IuknSFbZPSHpI0k2210kKSccl3V1jjahR2V53\nm3vh+KaOYY+ILfMsfqaGWgDUiK/LAkkQdiAJwg4kQdiBJAg7kASXuC4BnX62eClebinRtus3RnYg\nCcIOJEHYgSQIO5AEYQeSIOxAEoQdSII+ewuUnf53sdd36sHT686DkR1IgrADSRB2IAnCDiRB2IEk\nCDuQBGEHkqDPvszVfS38cr3WfjliZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJOizJ1f2Wvo6358e\nfbU6juy2V9t+w/YHtg/bvq9YPmx7r+2jxe2K+ssF0KtuDuPPSbo/ItZKul7SPbbXSnpA0r6IWCNp\nX/EYQEt1DHtETEbEe8X9M5KOSFolaZOkXcXTdkm6ra4iAZT3rU7Q2b5K0nWS3pY0EhGTxapPJY0s\n8Jox2xO2J2amz5YoFUAZXYfd9pCklyVti4gv5q6LiJAU870uIsYjYjQiRgeGBksVC6B3XYXd9qWa\nDfrzEfFKsfiU7ZXF+pWSpuopEUAVOrbebFvSM5KORMSjc1btlrRV0o7i9rVaKkygU4upzvZV3a23\nxdBa669u+uw3SLpT0vu2DxTLHtRsyF+yfZekjyVtrqdEAFXoGPaIeFOSF1h9c7XlAKgLX5cFkiDs\nQBKEHUiCsANJEHYgCS5xXQKa7Ec3ue0PNz/V2LaX41TWjOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7\nkAR99uTK9tGb7IWXsRz76J0wsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEvTZk2tznzxjL7xOjOxA\nEoQdSIKwA0kQdiAJwg4kQdiBJAg7kEQ387OvlvScpBFJIWk8Ip6wvV3SbyR9Vjz1wYjYU1ehmbW5\nF94JvfL26OZLNeck3R8R79m+XNK7tvcW6x6LiIfrKw9AVbqZn31S0mRx/4ztI5JW1V0YgGp9q8/s\ntq+SdJ2kt4tF99o+aHun7RULvGbM9oTtiZnps6WKBdC7rsNue0jSy5K2RcQXkp6UdI2kdZod+R+Z\n73URMR4RoxExOjA0WEHJAHrRVdhtX6rZoD8fEa9IUkScioiZiPhK0tOS1tdXJoCyOobdtiU9I+lI\nRDw6Z/nKOU+7XdKh6ssDUJVuzsbfIOlOSe/bPlAse1DSFtvrNNuOOy7p7loqBO0rVKKbs/FvSvI8\nq+ipA0sI36ADkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k\n4Yjo38bszyR9PGfRFZI+71sB305ba2trXRK19arK2n4cEd+fb0Vfw37Rxu2JiBhtrIBFtLW2ttYl\nUVuv+lUbh/FAEoQdSKLpsI83vP3FtLW2ttYlUVuv+lJbo5/ZAfRP0yM7gD4h7EASjYTd9gbb/7Z9\nzPYDTdSwENvHbb9v+4DtiYZr2Wl7yvahOcuGbe+1fbS4nXeOvYZq2277ZLHvDtje2FBtq22/YfsD\n24dt31csb3TfLVJXX/Zb3z+z2x6Q9B9Jv5J0QtI7krZExAd9LWQBto9LGo2Ixr+AYfsXkqYlPRcR\nPy2W/VHS6YjYUfxDuSIift+S2rZLmm56Gu9itqKVc6cZl3SbpF+rwX23SF2b1Yf91sTIvl7SsYj4\nKCK+lPSipE0N1NF6EbFf0ukLFm+StKu4v0uz/7P03QK1tUJETEbEe8X9M5LOTzPe6L5bpK6+aCLs\nqyR9MufxCbVrvveQ9Lrtd22PNV3MPEYiYrK4/6mkkSaLmUfHabz76YJpxluz73qZ/rwsTtBd7MaI\n+LmkWyXdUxyutlLMfgZrU++0q2m8+2Weaca/1uS+63X687KaCPtJSavnPL6yWNYKEXGyuJ2S9Kra\nNxX1qfMz6Ba3Uw3X87U2T
eM93zTjasG+a3L68ybC/o6kNbavtn2ZpDsk7W6gjovYHixOnMj2oKRb\n1L6pqHdL2lrc3yrptQZr+Ya2TOO90DTjanjfNT79eUT0/U/SRs2ekf9Q0h+aqGGBun4i6Z/F3+Gm\na5P0gmYP6/6n2XMbd0n6nqR9ko5K+oek4RbV9hdJ70s6qNlgrWyoths1e4h+UNKB4m9j0/tukbr6\nst/4uiyQBCfogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wNPDLQpJGSWBgAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_6ar0vcD9YTK", - "colab_type": "text" - }, - "source": [ - "Certainly a much more parsimonious representation. The obvious question now becomes:\n", - "\n", - "### Q7: How low can you go? How small can the centroid codebook be before we see a substantial degradation in test set accuracy?\n", - "\n", - "I got great results all the way down to two. Though this may because the \n", - "bias and W_1 values weren't restricted. It would be interesting to see what would happen if we restricted those weights to a small code book as well. \n", - "\n", - "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? \n", - "\n", - "Less accurate, but not by a large amount. A final accuracy of 94% as opposed to 96%." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "GGUePifO9YTM", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} \ No newline at end of file From c9b74db515c3237e054cfe7e734224f3275e0658 Mon Sep 17 00:00:00 2001 From: tmayer868 Date: Wed, 1 Apr 2020 12:50:43 -0600 Subject: [PATCH 3/3] Add files via upload --- deep_compression_exercise.ipynb | 2270 +++++++++++++++++++++++++------ 1 file changed, 1823 insertions(+), 447 deletions(-) diff --git a/deep_compression_exercise.ipynb b/deep_compression_exercise.ipynb index f903c10..81451f0 100644 --- a/deep_compression_exercise.ipynb +++ b/deep_compression_exercise.ipynb @@ -1,449 +1,1825 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Exercise Week 9: Pruning and Quantization\n", - "This week, we will explore some of the ideas discussed in Han, Mao, and Dally's Deep Compression. 
In particular, we will implement weight pruning with fine tuning, as well as k-means weight quantization. **Note that we will unfortunately not be doing this in a way that will actually lead to substantial efficiency gains: that would involve the use of sparse matrices which are not currently well-supported in pytorch.** \n", - "\n", - "## Training an MNIST classifier\n", - "For this example, we'll work with a basic multilayer perceptron with a single hidden layer. We will train it on the MNIST dataset so that it can classify handwritten digits. As usual we load the data:" - ] + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + }, + "colab": { + "name": "deep_compression_exercise.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "be09c445b75f468e8dcfa18358ce4a36": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_5aac198c9a3d4eea9be7fb8beeb227ce", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_03e6a001de424248a4122d33143e75f0", + "IPY_MODEL_c005a5ba50db480caba161765274b3f9" + ] + } + }, + "5aac198c9a3d4eea9be7fb8beeb227ce": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + 
"_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "03e6a001de424248a4122d33143e75f0": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_60e876032db7469e846f94c3ac3d5942", + "_dom_classes": [], + "description": "", + "_model_name": "IntProgressModel", + "bar_style": "info", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_00e93303f3b447a1bf56f08dfde7de4d" + } + }, + "c005a5ba50db480caba161765274b3f9": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_4403339d5717448687dcb5155c9e1239", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 9920512/? [00:20<00:00, 1487669.97it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_ac8b2caa8abb45189966110e612433c1" + } + }, + "60e876032db7469e846f94c3ac3d5942": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "00e93303f3b447a1bf56f08dfde7de4d": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + 
"4403339d5717448687dcb5155c9e1239": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "ac8b2caa8abb45189966110e612433c1": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "fc589a1423bd406581f252ecf8ea97bd": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + 
"box_style": "", + "layout": "IPY_MODEL_e7f77c796bc14d9ab6f97137424b925f", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_0e847e7b236c41b79dc518091bf7067f", + "IPY_MODEL_3e2bf6a25441443c92dc5c3c2d1fef22" + ] + } + }, + "e7f77c796bc14d9ab6f97137424b925f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "0e847e7b236c41b79dc518091bf7067f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_ac1223dd567b41589efcc1044c8fa962", + "_dom_classes": [], + "description": " 0%", + "_model_name": "IntProgressModel", + "bar_style": "info", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 0, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + 
"description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_092573b76819403ca18edea829d1bb31" + } + }, + "3e2bf6a25441443c92dc5c3c2d1fef22": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_7a47e9e824d345fa916984efef7b2f3d", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 0/28881 [00:00<?, ?it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_08c5f10c348649a9a927c221cf2aa965" + } + }, + "ac1223dd567b41589efcc1044c8fa962": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "092573b76819403ca18edea829d1bb31": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": 
null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "7a47e9e824d345fa916984efef7b2f3d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "08c5f10c348649a9a927c221cf2aa965": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + 
"grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "5315c5e02993418a9d70d8568a24574c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_7d973a7b57544afcb7deeed203361142", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_6ebfc1cfcbca4b41aa3214b8ab36e790", + "IPY_MODEL_fc64f32f6bf741ffb46025054188bff0" + ] + } + }, + "7d973a7b57544afcb7deeed203361142": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "6ebfc1cfcbca4b41aa3214b8ab36e790": { + "model_module": "@jupyter-widgets/controls", + "model_name": 
"IntProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_9a43b17aa8bc46508b7fa9fabaf1bf32", + "_dom_classes": [], + "description": "", + "_model_name": "IntProgressModel", + "bar_style": "info", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_438ac4bcdab249bcbbaccbe207981ffd" + } + }, + "fc64f32f6bf741ffb46025054188bff0": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_279398b30dc6479b99e352bbc4651b76", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 1654784/? [00:18<00:00, 511954.24it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_7c8a5703523f4c0f80b8fee415f3fa83" + } + }, + "9a43b17aa8bc46508b7fa9fabaf1bf32": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "438ac4bcdab249bcbbaccbe207981ffd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + 
"_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "279398b30dc6479b99e352bbc4651b76": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "7c8a5703523f4c0f80b8fee415f3fa83": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + 
"grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "e4c516162b844d6fbd33cf29584aae20": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_6779960e36614cef9c512a615a2980bf", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_9b09958b45a441d29688f6eeba9707d3", + "IPY_MODEL_62a69547013748de9d80c4a992748698" + ] + } + }, + "6779960e36614cef9c512a615a2980bf": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": 
null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "9b09958b45a441d29688f6eeba9707d3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_f49a8270a4364784b39bc0eca205a8eb", + "_dom_classes": [], + "description": " 0%", + "_model_name": "IntProgressModel", + "bar_style": "info", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 0, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_bca6c5f7f93942e98c6ef7412e311dfb" + } + }, + "62a69547013748de9d80c4a992748698": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_5a1aef0316fd4ebbbd436bd2ecf2383c", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 0/4542 [00:00<?, ?it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_00f7f24316dc48b1adeb59b3d0dadc4b" + } + }, + "f49a8270a4364784b39bc0eca205a8eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + 
"_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "bca6c5f7f93942e98c6ef7412e311dfb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "5a1aef0316fd4ebbbd436bd2ecf2383c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "00f7f24316dc48b1adeb59b3d0dadc4b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + 
"right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torchvision.transforms as transforms\n", - "import torchvision.datasets as datasets\n", - "\n", - "device = torch.device('cuda' if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n", - "test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n", - "\n", - "batch_size = 300\n", - "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", - "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then define a model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], 
- "source": [ - "class MultilayerPerceptron(torch.nn.Module):\n", - " def __init__(self, input_dim, hidden_dim, output_dim,mask=None):\n", - " super(MultilayerPerceptron, self).__init__()\n", - " if not mask:\n", - " self.mask = torch.nn.Parameter(torch.ones(input_dim,hidden_dim),requires_grad=False)\n", - " else:\n", - " self.mask = torch.nn.Parameter(mask)\n", - "\n", - " self.W_0 = torch.nn.Parameter(1e-3*torch.randn(input_dim,hidden_dim)*self.mask,requires_grad=True)\n", - " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim),requires_grad=True)\n", - "\n", - " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim),requires_grad=True)\n", - " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim),requires_grad=True)\n", - " \n", - " def set_mask(self,mask):\n", - " \n", - " self.mask.data = mask.data\n", - " self.W_0.data = self.mask.data*self.W_0.data\n", - "\n", - " def forward(self, x):\n", - " hidden = torch.tanh(x@(self.W_0*self.mask) + self.b_0)\n", - " outputs = hidden@self.W_1 + self.b_1\n", - " return outputs\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that the above code is a little bit different than a standard multilayer perceptron implementation.\n", - "\n", - "### Q1: What does this model have the capability of doing that a \"Vanilla\" MLP does not. Why might we want this functionality for studying pruning?\n", - "\n", - "Let's first train this model without utilizing this extra functionality. 
You can set the hidden layer size to whatever you'd like when instantiating the model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n_epochs = 10\n", - "\n", - "input_dim = 784\n", - "hidden_dim = 64\n", - "output_dim = 10\n", - "\n", - "model = MultilayerPerceptron(input_dim,hidden_dim,output_dim)\n", - "model = model.to(device)\n", - "\n", - "criterion = torch.nn.CrossEntropyLoss() # computes softmax and then the cross entropy\n", - "lr_rate = 0.001\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And then training proceeds as normal." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "iter = 0\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. 
Accuracy: {}.\".format(epoch, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_pretrained.h5')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Pruning\n", - "\n", - "Certainly not a state of the art model, but also not a terrible one. Because we're hoping to do some weight pruning, let's inspect some of the weights directly (recall that we can act like they're images)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "W_0 = model.W_0.detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Q2: Based on the above image, what weights might reasonably be pruned (i.e. explicitly forced to be zero)?\n", - "\n", - "\n", - "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "new_mask = model.mask" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model.set_mask(new_mask)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we have explicitly set some entries in one of the the weight matrices to zero, and ensured via the mask, that they will not be updated by gradient descent. 
Fine tune the model: " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "iter = 0\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_pruned.h5')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Q4: How much accuracy did you lose by pruning the model? How much \"compression\" did you achieve (here defined as total entries in W_0 divided by number of non-zero entries)? \n", - "\n", - "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "W_0 = model.W_0.detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Quantization\n", - "\n", - "Now that we have a pruned model that appears to be performing well, let's see if we can make it even smaller by quantization. 
To do this, we'll need a slightly different neural network, one that corresponds to Figure 3 from the paper. Instead of having a matrix of float values, we'll have a matrix of integer labels (here called \"labels\") that correspond to entries in a (hopefully) small codebook of centroids (here called \"centroids\"). The way that I've coded it, there's still a mask that enforces our desired sparsity pattern." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MultilayerPerceptronQuantized(torch.nn.Module):\n", - " def __init__(self, input_dim, output_dim, hidden_dim,mask,labels,centroids):\n", - " super(MultilayerPerceptronQuantized, self).__init__()\n", - " self.mask = torch.nn.Parameter(mask,requires_grad=False)\n", - " self.labels = torch.nn.Parameter(labels,requires_grad=False)\n", - " self.centroids = torch.nn.Parameter(centroids,requires_grad=True)\n", - "\n", - " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim))\n", - "\n", - " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim))\n", - " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim))\n", - "\n", - " def forward(self, x):\n", - " W_0 = self.mask*self.centroids[self.labels]\n", - " hidden = torch.tanh(x@W_0 + self.b_0)\n", - " outputs = hidden@self.W_1 + self.b_1\n", - " return outputs" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Notice what is happening in the forward method: W_0 is being reconstructed by using a matrix (self.labels) to index into a vector (self.centroids). The beauty of automatic differentiation allows backpropogation through this sort of weird indexing operation, and thus gives us gradients of the objective function with respect to the centroid values!\n", - "\n", - "### Q6: However, before we are able to use this AD magic, we need to specify the static label matrix (and an initial guess for centroids). 
Use the k-means algorithm (or something else if you prefer) figure out the label matrix and centroid vectors. PROTIP1: I used scikit-learns implementation of k-means. PROTIP2: only cluster the non-zero entries" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# convert weight and mask matrices into numpy arrays\n", - "W_0 = model.W_0.detach().cpu().numpy()\n", - "mask = model.mask.detach().cpu().numpy()\n", - "\n", - "# Figure out the indices of non-zero entries \n", - "inds = np.where(mask!=0)\n", - "# Figure out the values of non-zero entries\n", - "vals = W_0[inds]\n", - "\n", - "### TODO: perform clustering on vals\n", - "\n", - "### TODO: turn the label matrix and centroids into a torch tensor\n", - "labels = torch.tensor(...,dtype=torch.long,device=device)\n", - "centroids = torch.tensor(...,device=device)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can instantiate our quantized model and import the appropriate pre-trained weights for the other network layers. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Instantiate quantized model\n", - "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,new_mask,labels,centroids)\n", - "model_q = model_q.to(device)\n", - "\n", - "# Copy pre-trained weights from unquantized model for non-quantized layers\n", - "model_q.b_0.data = model.b_0.data\n", - "model_q.W_1.data = model.W_1.data\n", - "model_q.b_1.data = model.b_1.data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, we can fine tune the quantized model. We'll adjust not only the centroids, but also the weights in the other layers." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.Adam(model_q.parameters(), lr=lr_rate, weight_decay=1e-3)\n", - "iter = 0\n", - "for epoch in range(n_epochs):\n", - " for i, (images, labels) in enumerate(train_loader):\n", - " images = images.view(-1, 28 * 28).to(device)\n", - " labels = labels.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model_q(images)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # calculate Accuracy\n", - " correct = 0\n", - " total = 0\n", - " for images, labels in test_loader:\n", - " images = images.view(-1, 28*28).to(device)\n", - " labels = labels.to(device)\n", - " outputs = model_q(images)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total+= labels.size(0)\n", - " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", - " correct+= (predicted == labels).sum()\n", - " accuracy = 100 * correct/total\n", - " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", - "torch.save(model.state_dict(),'mnist_quantized.h5')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "After retraining, we can, just for fun, reconstruct the pruned and quantized weights and plot them as images:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "W_0 = (model_q.mask*model_q.centroids[model_q.labels]).detach().cpu().numpy()\n", - "plt.imshow(W_0[:,1].reshape((28,28)))\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Certainly a much more parsimonious representation. The obvious question now becomes:\n", - "\n", - "### Q7: How low can you go? 
How small can the centroid codebook be before we see a substantial degradation in test set accuracy?\n", - "\n", - "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "qe_xWm7z9YM6", + "colab_type": "text" + }, + "source": [ + "# Exercise Week 9: Pruning and Quantization\n", + "This week, we will explore some of the ideas discussed in Han, Mao, and Dally's Deep Compression. In particular, we will implement weight pruning with fine tuning, as well as k-means weight quantization. **Note that we will unfortunately not be doing this in a way that will actually lead to substantial efficiency gains: that would involve the use of sparse matrices which are not currently well-supported in pytorch.** \n", + "\n", + "## Training an MNIST classifier\n", + "For this example, we'll work with a basic multilayer perceptron with a single hidden layer. We will train it on the MNIST dataset so that it can classify handwritten digits. 
As usual we load the data:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Cysg2vAY9YNQ", + "colab_type": "code", + "outputId": "cbea6932-d4ff-497e-9ff3-39fa462cdcba", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 313, + "referenced_widgets": [ + "be09c445b75f468e8dcfa18358ce4a36", + "5aac198c9a3d4eea9be7fb8beeb227ce", + "03e6a001de424248a4122d33143e75f0", + "c005a5ba50db480caba161765274b3f9", + "60e876032db7469e846f94c3ac3d5942", + "00e93303f3b447a1bf56f08dfde7de4d", + "4403339d5717448687dcb5155c9e1239", + "ac8b2caa8abb45189966110e612433c1", + "fc589a1423bd406581f252ecf8ea97bd", + "e7f77c796bc14d9ab6f97137424b925f", + "0e847e7b236c41b79dc518091bf7067f", + "3e2bf6a25441443c92dc5c3c2d1fef22", + "ac1223dd567b41589efcc1044c8fa962", + "092573b76819403ca18edea829d1bb31", + "7a47e9e824d345fa916984efef7b2f3d", + "08c5f10c348649a9a927c221cf2aa965", + "5315c5e02993418a9d70d8568a24574c", + "7d973a7b57544afcb7deeed203361142", + "6ebfc1cfcbca4b41aa3214b8ab36e790", + "fc64f32f6bf741ffb46025054188bff0", + "9a43b17aa8bc46508b7fa9fabaf1bf32", + "438ac4bcdab249bcbbaccbe207981ffd", + "279398b30dc6479b99e352bbc4651b76", + "7c8a5703523f4c0f80b8fee415f3fa83", + "e4c516162b844d6fbd33cf29584aae20", + "6779960e36614cef9c512a615a2980bf", + "9b09958b45a441d29688f6eeba9707d3", + "62a69547013748de9d80c4a992748698", + "f49a8270a4364784b39bc0eca205a8eb", + "bca6c5f7f93942e98c6ef7412e311dfb", + "5a1aef0316fd4ebbbd436bd2ecf2383c", + "00f7f24316dc48b1adeb59b3d0dadc4b" + ] + } + }, + "source": [ + "import torch\n", + "import torchvision.transforms as transforms\n", + "import torchvision.datasets as datasets\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n", + "test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n", + "\n", + "batch_size = 300\n", + "train_loader = 
torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", + "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be09c445b75f468e8dcfa18358ce4a36", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fc589a1423bd406581f252ecf8ea97bd", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5315c5e02993418a9d70d8568a24574c", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" + ] + }, + "metadata": 
{ + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e4c516162b844d6fbd33cf29584aae20", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", + "Processing...\n", + "Done!\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8kdx9cAu9YOU", + "colab_type": "text" + }, + "source": [ + "Then define a model:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "UInyoax99YOf", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MultilayerPerceptron(torch.nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim,mask=None):\n", + " super(MultilayerPerceptron, self).__init__()\n", + " if not mask:\n", + " self.mask = torch.nn.Parameter(torch.ones(input_dim,hidden_dim),requires_grad=False)\n", + " else:\n", + " self.mask = torch.nn.Parameter(mask)\n", + "\n", + " self.W_0 = torch.nn.Parameter(1e-3*torch.randn(input_dim,hidden_dim)*self.mask,requires_grad=True)\n", + " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim),requires_grad=True)\n", + "\n", + " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim),requires_grad=True)\n", + " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim),requires_grad=True)\n", + " \n", + " def set_mask(self,mask):\n", + " \n", + " self.mask.data = mask.data\n", + " self.W_0.data = self.mask.data*self.W_0.data\n", + "\n", + " def 
forward(self, x):\n", + " hidden = torch.tanh(x@(self.W_0*self.mask) + self.b_0)\n", + " outputs = hidden@self.W_1 + self.b_1\n", + " return outputs\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GAM40Rnb9YPH", + "colab_type": "text" + }, + "source": [ + "Note that the above code is a little bit different than a standard multilayer perceptron implementation.\n", + "\n", + "### Q1: What does this model have the capability of doing that a \"Vanilla\" MLP does not. Why might we want this functionality for studying pruning?\n", + "\n", + "This model can \"mask\" parameters that we want to ignore, by\n", + "multiplying that paramter by 0. This has the same affect\n", + "as prunning the parameters even though we still technically\n", + "do multiplications and additions with those zeroed out parameters.\n", + "\n", + "Let's first train this model without utilizing this extra functionality. You can set the hidden layer size to whatever you'd like when instantiating the model:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "IUZhjRk79YPV", + "colab_type": "code", + "colab": {} + }, + "source": [ + "n_epochs = 10\n", + "\n", + "input_dim = 784\n", + "hidden_dim = 64\n", + "output_dim = 10\n", + "\n", + "model = MultilayerPerceptron(input_dim,hidden_dim,output_dim)\n", + "model = model.to(device)\n", + "\n", + "criterion = torch.nn.CrossEntropyLoss() # computes softmax and then the cross entropy\n", + "lr_rate = 0.001\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-3)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "44vk9p1D9YPv", + "colab_type": "text" + }, + "source": [ + "And then training proceeds as normal." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "tinBThTV9YP4", + "colab_type": "code", + "outputId": "e857ad65-f505-42c4-bd7e-6ce69c6d56fb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 185 + } + }, + "source": [ + "iter = 10\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(epoch, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_pretrained.h5')\n" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.4008619785308838. Accuracy: 90.\n", + "Iteration: 1. Loss: 0.331943154335022. Accuracy: 92.\n", + "Iteration: 2. Loss: 0.2710331082344055. Accuracy: 93.\n", + "Iteration: 3. Loss: 0.21628224849700928. Accuracy: 93.\n", + "Iteration: 4. Loss: 0.18730275332927704. Accuracy: 94.\n", + "Iteration: 5. Loss: 0.18389032781124115. Accuracy: 95.\n", + "Iteration: 6. Loss: 0.1158808246254921. Accuracy: 95.\n", + "Iteration: 7. Loss: 0.13385257124900818. Accuracy: 95.\n", + "Iteration: 8. Loss: 0.18437758088111877. Accuracy: 95.\n", + "Iteration: 9. Loss: 0.1472378671169281. 
Accuracy: 95.\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jka2xOtp9YQO", + "colab_type": "text" + }, + "source": [ + "## Pruning\n", + "\n", + "Certainly not a state of the art model, but also not a terrible one. Because we're hoping to do some weight pruning, let's inspect some of the weights directly (recall that we can act like they're images)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "IZVDikFr9YQU", + "colab_type": "code", + "outputId": "830c517c-54b2-4f0a-fad6-d014c12f0242", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + } + }, + "source": [ + "import matplotlib.pyplot as plt\n", + "W_0 = model.W_0.detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAVv0lEQVR4nO3dXWyk1XkH8P8z4/H4e73eD7MsCwub\nTQui7UIs0iQ0Io2IgF6QXBSFi5RKqJuLICVRLhrRi3CJqnyIiyrSpqCQKiWKlERwQdoAQkGpWoSX\nbvhayC4U1rvxx3rN7nr8MZ6PpxcekAGf/zHz4Rlz/j/Jsj2Pz/ueeWeeecfzvOccc3eIyEdfpt0d\nEJHNoWQXSYSSXSQRSnaRRCjZRRLRtZk7y/b3e9fIyGbuUiQp5bk5VBYWbL1YQ8luZrcAeABAFsC/\nuvv97O+7RkZw2de/2cguRYQ4/cAPgrG638abWRbAvwC4FcA1AO40s2vq3Z6ItFYj/7PfAOCku7/h\n7isAfgbg9uZ0S0SarZFk3wtgYs3vp2u3vYeZHTazcTMbrxYWGtidiDSi5Z/Gu/sRdx9z97HMQH+r\ndyciAY0k+xkA+9b8flntNhHpQI0k+3MADprZlWbWDeDLAB5rTrdEpNnqLr25e9nM7gHwn1gtvT3k\n7i83rWdbSWTgYKYUaR57FCLbt3I4ll1Zt+T6rkqeb7zcz+PZZb79DNl/Nce3bbH7XeHxSp415m0/\nihqqs7v74wAeb1JfRKSFdLmsSCKU7CKJULKLJELJLpIIJbtIIpTsIonY1PHsbVWNxBt42YvVgzOl\nSFF3hYdzBR7PLpO2i5HORcIr2/iB8Wxk86R5NcePi8dq4ZG4kcfcM/yOx+5XNcfjlX7+hDPynGDX\nJjRCZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEpFO6S3yshYrxbChnFbhpRJWGgOAriUezxV437qW\nw/Ftx+dpWzv+Oo1ndu6gcWT5gS2PDgdjhf185qKlEX5cS4OREhU5bKVtvGk1UnpjZT0AQLnzxtDq\nzC6SC
CW7SCKU7CKJULKLJELJLpIIJbtIIpTsIonYWnV2Vm6OlVyzjQ1pLG0LF1a9i2+7GOlbdp7v\nvHeab6B/Mrz/8jCbTxnIR+roXuDjayvnL9B4drkYjHWN7qdtV4Ybe3qWBsPHJTaFtkemuUbk2grv\njgxxXQqfZ2NTZMeeqyE6s4skQskukgglu0gilOwiiVCyiyRCyS6SCCW7SCK2Vp2diNXRY1MmV3sj\nxU1SS8/187mgPzY6S+On3t5O44X+yLjv68P7nynx1/Puty6n8Wo3DSN3IbJkMzms5R6+7WqkFh67\n/oBNyRwbC78yzOvk1T7+fMks8mI4Gw9fbVFWNrRZM3sTwDyACoCyu481o1Mi0nzNeA35nLvzU5eI\ntJ3+ZxdJRKPJ7gB+Y2ZHzezwen9gZofNbNzMxquFhQZ3JyL1avRt/I3ufsbMdgN4wsxedfdn1v6B\nux8BcAQA8vv2RT4mE5FWaejM7u5nat9nAPwKwA3N6JSINF/dyW5m/WY2+M7PAL4A4KVmdUxEmquR\nt/GjAH5lZu9s59/d/T+a0qs6ZCKDxmM1W0Tq9NnecjgWaTtxPjx3OgD87YH/pfFnt++n8QxZM/qv\ndpykbT95I583/lM94fHoAPDfy3y8/G8LfxqMVSOTEDw5+Sc0/sdJfn0CiuFat5HHEwBwka/JbGV+\nnmTrDKxuIByKjbWPjXcPqTvZ3f0NAH9Rb3sR2VwqvYkkQskukgglu0gilOwiiVCyiyRiaw1xZeWK\n3khpLTadcx8vxVyxey4YO7/Ex2ru6ueXCb9WGKXxwgovb12/cyIYy0XqNNfled/y1kvjF6v8vrP9\nH7t4GW27Lc/Xuu67fIbGS5Vw6W2+yMfuLvbxY16c6aPxGCen2dhy0FatbzlondlFEqFkF0mEkl0k\nEUp2kUQo2UUSoWQXSYSSXSQRW6rO7plwLT06lfRwiYYvGblI41cNngvGlvr4cMiLJV6LPja5l8Z7\n83yq6qnloWDsf6b307a/PH2IxvNd/PqDtxd5Hf78RHh4r5GpngHgkqt5HX1unk+xvTwb7ltmkD8f\nunKRcaSRsEVG0Dq761ZfHT1GZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEqFkF0lEZ9XZI+N4Wf3R\nu3njbBeP5zI8fvLizmDszCyfKtrI9QEAMNDHx21nI+2fffWq8L4jSwfnz/F47xTf98owrwn3NjAH\nwfkFXsM3MoU2APSPhsfql0r8fnvkso3cPD9PDp/g7ef3hQ9MeSAylTR5qrJDojO7SCKU7CKJULKL\nJELJLpIIJbtIIpTsIolQsoskoqPq7JGyKa/Dx2r0EYslPiZ9aSUcr0SW783l+eDmvUN8LP0fpnfR\nePdUuG+9U7wOvvsonzc+c/RVGi9+7s9p/PxV4b6V6KBuYLnIH5MrR8NzDADAfDE893sBfF74pTcH\naXzXcf5kzS3yeDXPJo6nTVElh4Ud0uiZ3cweMrMZM3tpzW0jZvaEmZ2ofY8slC0i7baRt/E/BnDL\n+277NoCn3P0ggKdqv4tIB4smu7s/A+D9ax/dDuDh2s8PA/hik/slIk1W7wd0o+4+Wft5CkBwsTIz\nO2xm42Y2Xi3w/w9FpHUa/jTe3R1A8NMIdz/i7mPuPpYZ4BMEikjr1Jvs02a2BwBq3/k0oCLSdvUm\n+2MA7qr9fBeAR5vTHRFplWid3cweAXATgJ1mdhrAdwDcD+DnZnY3gLcA3NGMzjgfYkz+WUD0ZWto\nYInGF2PrdZ8ZCO+6yAuj134yvH46AFy/jccHuoo0/mxxfzBWyPN68qVPLtK47eFrx5+7htfCq+Sw\nLu/kF0dUCnzbJ4u7adwK4ad3/wR/sg3xw4Lh1+ZpvNLHU8uq4bH6bH2ERkST3d3vDIQ+3+S+iEgL\n6XJZkUQo2UUSoWQXSYSSXSQRSnaRRHTUENcYZ6MCK7z8tbjMS2sr030
0PngyXKop7uClkt09BRp/\n+uzHaXwpMvx27+7zwdjEAh8eu3QFnwY7sxIpj/HKHpZHSfvYsOTIY4oyL5/1nwrHe2b5Y9b7Nl+T\nObPIl9FeuoRPg10aYPNB06Z8uWdCZ3aRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEqFkF0nElqqzU5Ga\nbXGeF4S7lvjrXpmU4VdG+M6fmTjAtx2Zinqwjw9x3Tf0djA2kQkvNQ0Ak5/mT4GBUzSMLO8aVR3m\nU2zH6uz9r0em/94drqV389m7kZvndfaYxV38GoDKYOS+M3S+6PB91pldJBFKdpFEKNlFEqFkF0mE\nkl0kEUp2kUQo2UUSsaXq7HRJ5yqvyVoXr4V3v83bl4bIziPjixcn+PK/XQt8AxcQnsYaAOb6RsLB\nbn6/Y2Oji8P8D5YuiWx/JDzue9cOPh3z3Cv8GoElNlYeQO5i+FyWK/C2S7siU2Tnhmi8uxCZDpo8\nX2NzM1iZxMkDqjO7SCKU7CKJULKLJELJLpIIJbtIIpTsIolQsoskYkvV2anIMreZs3ze+NwC33xp\nG2l7nr9mVviu0bXE66qxWnhpe3jstfXycdmlyPUJ1W5+32676SiNv1EI18qvHpqibR99kdfZyz38\nMa/mw/GLV/L7lSnRMLZF5k8o5/lxzV0Ij3cvDUfmrGdzL5BDEj2zm9lDZjZjZi+tue0+MztjZsdq\nX7fFtiMi7bWRt/E/BnDLOrf/wN0P1b4eb263RKTZosnu7s8AmNuEvohICzXyAd09ZvZC7W3+9tAf\nmdlhMxs3s/FqIfKPsYi0TL3J/kMABwAcAjAJ4HuhP3T3I+4+5u5jmYH+OncnIo2qK9ndfdrdK+5e\nBfAjADc0t1si0mx1JbuZ7Vnz65cAvBT6WxHpDNE6u5k9AuAmADvN7DSA7wC4ycwOYbWq9yaAr7aw\nj+9i9ebYGOBskcfz5yM121y4/fwBPgd4pshfUz3D+9Y3FRkbTR7GA389QVuO5PnnKP/1Op/z/rmz\nl9P4aF94bfoMnaAA6L86PB8+AJyf4/8WVsl8/pUuXsuunODbLg7yx6zSy+PlQbL/LD8unqtvboVo\nsrv7nevc/GCsnYh0Fl0uK5IIJbtIIpTsIolQsoskQskukoiPzhDXWHUqotzD410L4R10FfjyvFVW\nKgHQfZ7ve+ityPK+Hn4YT0zvok2/+Wcv0PjvB/bS+NRk8Erp1Xg1HH+5b08wBgD5Hj7O1CLDmrty\n4eNWjQzt9ciw5HIfb1/khwVOhh7bfCQt2fDaRoa4ishHg5JdJBFKdpFEKNlFEqFkF0mEkl0kEUp2\nkURsqTo7GxHpkTp7bDrn+at4PLtIlsIt8p0PnOKvqSPHizSeP8WHemZWwkXdSg9f7vm7lZtp3Kf4\nBQhDE/y+VcnKxyuDfFnkxe15Gr/0Y2dpfHdfeEno1+f4NNXzfXyu6MIVkSm4+V2DLZBrMyJDf8ll\nFXSIq87sIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SiM6qs0eWwc2uhGOV7ZHaJB9yjmqktgkP\nvy52X4gsz1uIjGc/y6dz9jN8aeOecnhs9A7wenL3hV4aZ3VbAOib5mPOPRvewMIl/EFZXObxs7v4\nNQTFcvjpvVJqYMw4gAqbChqIrrNtJXLdBokBkWnTNZ5dRJTsIolQsoskQskukgglu0gilOwiiVCy\niySio+rsGV6yRZX1tsF543MX+etephyubWb5cPRYyRXLe3i9uOf/IhcJWHgH3XNLtOnwMp+Tvmvm\nIt/3ucik9zvDY+37TvFJBmY/wSdfP3eQDxrftXs23HaOH3PvijyhYvHIE9KWNz/1omd2M9tnZk+b\n2Stm9rKZfb12+4iZPWFmJ2rfI9Pii0g7beRtfBnAt9z9GgB/CeBrZnYNgG8DeMrdDwJ4qva7iHSo\naLK7+6S7P1/7eR7AcQB7AdwO4OH
anz0M4Iut6qSINO5DfUBnZvsBXAfgWQCj7j5ZC00BGA20OWxm\n42Y2Xi3wa8BFpHU2nOxmNgDgFwC+4e7v+dTG3R2BTyTc/Yi7j7n7WGagv6HOikj9NpTsZpbDaqL/\n1N1/Wbt52sz21OJ7AMy0posi0gzRz//NzAA8COC4u39/TegxAHcBuL/2/dFGO1PlMwejkg+XM7KL\nkdJZZERi3ySvj/XNkA2Q0hcALO7kfStcyktI3Vfvp3EUw+Wzhf28xNR9npferLDI9x15t1bcOxyM\nLezhpbf5K/muD+6bpvEMGbac7+V13mo+UpLs4k+opSl+3LPL4edMbInvem2k2PcZAF8B8KKZHavd\ndi9Wk/znZnY3gLcA3NGSHopIU0ST3d1/h/AUBp9vbndEpFV0uaxIIpTsIolQsoskQskukgglu0gi\nOmqIa0MiL1uVSO2STXkMAMWh8A4W9vK2lV6+7/w53n7q04M03n0xvH2LXF+QnyXzcwOYvZmvZV24\nLHKNwYHw9gd28OGzn9ozQeMH+/h1XLlMuFZ+eT9fBvu56ctpfH6RXxTiXXwuas+En09smmkAdZ+i\ndWYXSYSSXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEbKk6e7YYqT8S1W5e6164lMcrA+G6aW4Xn665\n9HYPjcfWky4PRdYPHgqPze46zevBs5/g8Y9fe4rGPzF4jsbv3vlMMDYYmTs8F5mO+bdL/BqAt4rh\n5ar/uLiNtr0wz5eyLhd56lg5ch5ld63+pzmlM7tIIpTsIolQsoskQskukgglu0gilOwiiVCyiyRi\nS9XZ2dLHZIrw1bbdvFZd7ucDv3N94XHZ2wZ4nX02UpPFHK+zew/v2xWXzAVjxV1823+z92Ua/7vh\ncRo/UeL16vlq+BqDB2c/S9u+dmE3jc8t9NH48kp4Pv7iIp+r35f4Y5ZZiqxTEBuT3qJaOqMzu0gi\nlOwiiVCyiyRCyS6SCCW7SCKU7CKJULKLJGIj67PvA/ATAKNYHYV7xN0fMLP7APwDgLO1P73X3R9v\nVUeBeC2dYkV6AKjyjZcuhMd9nyvxWnYustZ3dT+/BuDWg6/S+Md6w/OnZ4xv+1APH6/+3PKlNP70\nhatp/OjsZcHYzOwQbVuN1LptmZ+rnKwVYCu8bZYPtY9NQdCRNnJRTRnAt9z9eTMbBHDUzJ6oxX7g\n7t9tXfdEpFk2sj77JIDJ2s/zZnYcwN5Wd0xEmutD/c9uZvsBXAfg2dpN95jZC2b2kJltD7Q5bGbj\nZjZeLSw01FkRqd+Gk93MBgD8AsA33P0igB8COADgEFbP/N9br527H3H3MXcfywz0N6HLIlKPDSW7\nmeWwmug/dfdfAoC7T7t7xd2rAH4E4IbWdVNEGhVNdjMzAA8COO7u319z+541f/YlAC81v3si0iwb\n+TT+MwC+AuBFMztWu+1eAHea2SGsluPeBPDVlvSwSbIF/rpm1cgwU1K58yXethRbLjoy/PbXL15L\n4yiS+xZb/XeA15i8EtnAPB8q2giL7DtaiiUVz2jbSKk2Mgt2R9rIp/G/w/pPmZbW1EWkuXQFnUgi\nlOwiiVCyiyRCyS6SCCW7SCKU7CKJ2FJTSTfCqo3N3cvqsrYS2XYsvhC7BoA3b2T132rk+oPYMyQ2\ncjgTu+9MdH7wSC2cLfEd6Vbsfm1FOrOLJELJLpIIJbtIIpTsIolQsoskQskukgglu0gizL2R+Zk/\n5M7MzgJ4a81NOwHMbloHPpxO7Vun9gtQ3+rVzL5d4e671gtsarJ/YOdm4+4+1rYOEJ3at07tF6C+\n1Wuz+qa38SKJULKLJKLdyX6kzftnOrVvndovQH2r16b0ra3/s4vI5mn3mV1ENomSXSQRbUl2M7vF\nzF4zs5Nm9u129CHEzN40sxfN7JiZjbe5Lw+Z2YyZvbTmthEze8LMTtS+r7vGXpv6dp+Znakdu2Nm\
ndlub+rbPzJ42s1fM7GUz+3rt9rYeO9KvTTlum/4/u5llAfwBwM0ATgN4DsCd7v7KpnYkwMzeBDDm\n7m2/AMPMPgugAOAn7n5t7bZ/BjDn7vfXXii3u/s/dkjf7gNQaPcy3rXVivasXWYcwBcB/D3aeOxI\nv+7AJhy3dpzZbwBw0t3fcPcVAD8DcHsb+tHx3P0ZAHPvu/l2AA/Xfn4Yq0+WTRfoW0dw90l3f772\n8zyAd5YZb+uxI/3aFO1I9r0AJtb8fhqdtd67A/iNmR01s8Pt7sw6Rt19svbzFIDRdnZmHdFlvDfT\n+5YZ75hjV8/y543SB3QfdKO7Xw/gVgBfq71d7Ui++j9YJ9VON7SM92ZZZ5nxd7Xz2NW7/Hmj2pHs\nZwDsW/P7ZbXbOoK7n6l9nwHwK3TeUtTT76ygW/s+0+b+vKuTlvFeb5lxdMCxa+fy5+1I9ucAHDSz\nK82sG8CXATzWhn58gJn11z44gZn1A/gCOm8p6scA3FX7+S4Aj7axL+/RKct4h5YZR5uPXduXP3f3\nTf8CcBtWP5F/HcA/taMPgX5dBeD3ta+X2903AI9g9W1dCaufbdwNYAeApwCcAPAkgJEO6tu/AXgR\nwAtYTaw9berbjVh9i/4CgGO1r9vafexIvzbluOlyWZFE6AM6kUQo2UUSoWQXSYSSXSQRSnaRRCjZ\nRRKhZBdJxP8DiQuUv16KLFsAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jtGIcba89YQm", + "colab_type": "text" + }, + "source": [ + "### Q2: Based on the above image, what weights might reasonably be pruned (i.e. explicitly forced to be zero)?\n", + "\n", + "The weights that are very purple. \n", + "\n", + "### Q3: Implement some means of establishing a threshold for the (absolute value of the) weights, below which they are set to zero. Using this method, create a mask array. " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4uh6-dKQ9YQs", + "colab_type": "code", + "colab": {} + }, + "source": [ + "new_mask = model.mask\n", + "ratio_threshold = .2\n", + "#drop the values that are less than ratio_threshold the size of the max weight in absolute vaue.\n", + "mask = torch.abs(model.W_0.data)/torch.abs(torch.max(model.W_0.data)) > ratio_threshold\n", + "new_mask.data = (new_mask.data)*(mask.float())\n", + "model.set_mask(new_mask)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eGuQ3t-E9YRB", + "colab_type": "text" + }, + "source": [ + "Now that we have a mask that explicitly establishes a sparsity pattern for our model, let's update our model with this mask:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OocvWnxK9YRW", + "colab_type": "text" + }, + "source": [ + "Now, we have explicitly set some entries in one of the the weight matrices to zero, and ensured via the mask, that they will not be updated by gradient descent. 
Fine tune the model: " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WgdWUpuD9YRa", + "colab_type": "code", + "outputId": "a4ca20e0-b690-4489-d6f7-af33980dbef6", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 101 + } + }, + "source": [ + "iter = 0\n", + "n_epochs = 5\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_pruned.h5')" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.1554800122976303. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.16709843277931213. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.1133299395442009. Accuracy: 96.\n", + "Iteration: 0. Loss: 0.14041125774383545. Accuracy: 96.\n", + "Iteration: 0. Loss: 0.12037742882966995. Accuracy: 96.\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vHGmJW2P9YRs", + "colab_type": "text" + }, + "source": [ + "### Q4: How much accuracy did you lose by pruning the model? 
How much \"compression\" did you achieve (here defined as total entries in W_0 divided by number of non-zero entries)? \n", + "\n", + "Not much I pruned the bottom weights less than 20% of the max weight in absoloute value and the accuracy degraded a couple percent but quickly recovered after\n", + "degrade at all. droppping 60 percent of weights dropped the accuracy\n", + "to 70% but it recovered into the 80s.\n", + "\n", + "### Q5: Explore a few different thresholds: approximately how many weights can you prune before accuracy starts to degrade?" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "QrhV1kL99YRw", + "colab_type": "code", + "outputId": "0ea77a0f-d18c-492d-9a47-604db53c8be1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + } + }, + "source": [ + "W_0 = model.W_0.detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOKUlEQVR4nO3dbYxc5XnG8evyxga85sWGZOOA05DU\naoKa1qQbhwpUkRAooCqGSiCsgoyK6lSCKCiRUkQ+xF9aWW1DGlVVKqcgDAQILbGgLYI4DipK1FIW\n5BgDSSDG1F6M18S8mA3xy/ruhz1GG7PzzHreztj3/yetZubc58y5NfLlM3OeOfM4IgTg2Der7gYA\n9AZhB5Ig7EAShB1IgrADSbynlzsbGByM2acs6OUugVT2v75bE+Pjnq7WVthtXyzpm5IGJP1LRKwu\nrT/7lAX64F9+qZ1dAij4v3++pWGt5bfxtgck/ZOkSySdJWm57bNafT4A3dXOZ/alkl6IiC0RsU/S\nvZKWdaYtAJ3WTthPl7RtyuPt1bLfYHul7RHbIxPj423sDkA7un42PiLWRMRwRAwPDA52e3cAGmgn\n7KOSFk15fEa1DEAfaifsT0habPtM23MkXSXpwc60BaDTWh56i4gDtm+Q9Igmh95ui4hnOtYZgI5q\na5w9Ih6S9FCHegHQRXxdFkiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k\nQdiBJHr6U9LovTPv312sH9z002J9YP788g4GmhwvFpzSsLTlz4bK26KjOLIDSRB2IAnCDiRB2IEk\nCDuQBGEHkiDsQBKMsx/jdp5bHic/6YxPFutzX3y9WD/4wkvF+qyJg4Uq4+y9xJEdSIKwA0kQdiAJ\nwg4kQdiBJAg7kARhB5JgnP0Y96uFTervn11e4VPvbbKHZnX0i7bCbnurpD2SJiQdiIjhTjQFoPM6\ncWT/dES82oHnAdBFfGYHkmg37CHp+7aftL1yuhVsr7Q9YntkYny8zd0BaFW7b+PPi4hR2++TtN72\nTyPisakrRMQaSWsk6fjTF0Wb+wPQoraO7BExWt2OSVonaWknmgLQeS2H3fag7RMP3Zd0kaTNnWoM\nQGe18zZ+SNI624ee5+6IeLgjXSVz+zX/WKxfe+cXWn7u1cvvbHlbSZrrvcX6d1/9VLH+2Jbfblib\n2H1ccds5vxwo1uUmnwrD5XoyLYc9IrZI+v0O9gKgixh6A5Ig7EAShB1IgrADSRB2IAkuce0Dowea\nTIvcxKqr7mlY2/z2GcVtzxv8ebHerLfX9p1QrB93/P6GtfFZc4rbNh1awxHhyA4kQdiBJAg7kARh\nB5Ig7EAShB1IgrADSTDOPkOlsey5s8qXgX7l7muL9a/ec3UrLb1j/GDjS0XvWveZ4rZ3+dNt7bup\nwmWmxzGO3lMc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZZ2jVvcvr23mT8ei/u+9PW3/uZj+3\nzM81HzM4sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyzJzd7T7m+6KFfFusxq3y8ePGKBYWNGaPv\npaZHdtu32R6zvXnKsgW219t+vrptb5YDAF03k7fxt0u6+LBlN0naEBGLJW2oHgPoY03DHhGPSdp9\n2OJlktZW99dKuqzDfQHosFZP0A1FxI7q/iuShhqtaHul7RHbIxPj4y3uDkC72j4bHxEhqeHVEhGx\nJiKGI2J4YHCw3d0BaFGrYd9pe6EkVbdjnWsJQDe0GvYHJa2o7q+Q9EBn2gHQLU3H2W3fI+l8SafZ\n3i7pa5JWS7rP9nWSXpJ0ZTebPNbdt+KWYn3jr8tzrP/Nd1t/+Ztdrr7/1PJHr22fLc/Pfqz64CPl\n80/7Ti7PPf/KObM72c6MNA17RDT61YYLOtwLgC7i67JAEoQdS
IKwA0kQdiAJwg4kwSWuPXDjFeWv\nIVy59ks96mQaTYbeZu2f6E0fR5lZ+w8W63UMrTXDkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCc\nvQf+4V+X1d1CQ/tOLtdfumReezsoXUPbxz8l/YEf7yvWvXd/jzrpHI7sQBKEHUiCsANJEHYgCcIO\nJEHYgSQIO5AE4+zANF4+t/xT0It+ePT9hDZHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2tGX4\nwmeL9ZEffKxh7eI/+d/itg//5yfLO6/xevhtnzkGx9lt32Z7zPbmKctW2R61vbH6u7S7bQJo10ze\nxt8u6eJpln8jIpZUfw91ti0AndY07BHxmKTdPegFQBe1c4LuBtubqrf58xutZHul7RHbIxPj423s\nDkA7Wg37tyR9RNISSTskfb3RihGxJiKGI2J4YHCwxd0BaFdLYY+InRExEREHJX1b0tLOtgWg01oK\nu+2FUx5eLmlzo3UB9Iem4+y275F0vqTTbG+X9DVJ59teosnZvbdK+nwXe0QX/d4FPyvWN234nWJ9\nZP1ZLe977NcntrwtjlzTsEfE8mkW39qFXgB0EV+XBZIg7EAShB1IgrADSRB2IAkucT0KvOdX5fqB\nua0/9wdOeKNY39T6Uzf11IaPdvHZcTiO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsfeCU5w8W\n654obz+wPxrWXv34QHHbh/+jxt8dqfGnoDPiyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3geO\nf608kD6rMI4uSXN2vd2wNu/F8r63fu7k8gpNDI0cKNZ3DvNPrF9wZAeSIOxAEoQdSIKwA0kQdiAJ\nwg4kQdiBJBgE7QN7Typfcz7v5b3F+v4FxzculofoNfREeZx87va3inWP7irWP/w/heL8k4rbbrn6\nfcU618MfmaZHdtuLbD9q+1nbz9j+YrV8ge31tp+vbud3v10ArZrJ2/gDkr4cEWdJOkfS9bbPknST\npA0RsVjShuoxgD7VNOwRsSMinqru75H0nKTTJS2TtLZaba2ky7rVJID2HdEJOtsfknS2pMclDUXE\njqr0iqShBtustD1ie2RifLyNVgG0Y8Zhtz1P0v2SboyIN6fWIiLU4FRQRKyJiOGIGB4YHGyrWQCt\nm1HYbc/WZNC/ExHfqxbvtL2wqi+UNNadFgF0QtOhN9uWdKuk5yLilimlByWtkLS6un2gKx0m8NpH\ny//nzttWHj+bNdH4EtltF51Q3Pb9j5eH3ma9Xh56iybDZ2997NSGtbE/KA85MrTWWTMZZz9X0jWS\nnra9sVp2syZDfp/t6yS9JOnK7rQIoBOahj0ifiSp0X+xF3S2HQDdwtdlgSQIO5AEYQeSIOxAEoQd\nSIJLXI8C2y6c28bWTa5xbeIXf356W9sXNRlH/+vldxXrS457uVg/rvD0T+0tXz77lbuvLdaPRhzZ\ngSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTe2Xp7PIKTYbpv3DFvxfr15+yrWFtb+wvbvtfb5e/\nX/Bvb55drE9E42PZHevyXbDJkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCc/VjX5Jrxqy//YbH+\nuZM2Fuu7Jsqz/Nz+ZuPrxu/cfk5x29Efn1Gs48hwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGYy\nP/siSXdIGtLk1c1rIuKbtldJ+gtJu6pVb46Ih7rVaGZ/+MdPF+uL5441rL12oHxN+Gw3nttdkta9\n8Yli/SdvlH9X/tlHFxfr6J2ZfKnmgKQvR8RTtk+U9KTt9VXtGxHx991rD0CnzGR+9h2SdlT399h+\nTlIXpwkB0A1H9Jnd9ocknS3p8WrRDbY32b7N9vwG26y0PWJ7ZGJ8vK1mAbRuxmG3PU/S/ZJujIg3\nJX1L0kckLdHkkf/r020XE
WsiYjgihgcGy9+jBtA9Mwq77dmaDPp3IuJ7khQROyNiIiIOSvq2pKXd\naxNAu5qG3bYl3SrpuYi4ZcryhVNWu1zS5s63B6BTZnI2/lxJ10h62vah6x1vlrTc9hJNDsdtlfT5\nrnQI/fcjHy/Xe9QHjm4zORv/I0nTXRTNmDpwFOEbdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEH\nkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQcEb3bmb1L0ktTFp0m6dWeNXBk+rW3fu1LordWdbK334qI\n905X6GnY37VzeyQihmtroKBfe+vXviR6a1WveuNtPJAEYQeSqDvsa2ref0m/9tavfUn01qqe9Fbr\nZ3YAvVP3kR1AjxB2IIlawm77Yts/s/2C7Zvq6KER21ttP217o+2Rmnu5zfaY7c1Tli2wvd7289Xt\ntHPs1dTbKtuj1Wu30falNfW2yPajtp+1/YztL1bLa33tCn315HXr+Wd22wOSfi7pQknbJT0haXlE\nPNvTRhqwvVXScETU/gUM238k6S1Jd0TE71bL/lbS7ohYXf1HOT8i/qpPelsl6a26p/GuZitaOHWa\ncUmXSbpWNb52hb6uVA9etzqO7EslvRARWyJin6R7JS2roY++FxGPSdp92OJlktZW99dq8h9LzzXo\nrS9ExI6IeKq6v0fSoWnGa33tCn31RB1hP13StimPt6u/5nsPSd+3/aTtlXU3M42hiNhR3X9F0lCd\nzUyj6TTevXTYNON989q1Mv15uzhB927nRcQnJF0i6frq7WpfisnPYP00djqjabx7ZZppxt9R52vX\n6vTn7aoj7KOSFk15fEa1rC9ExGh1OyZpnfpvKuqdh2bQrW7Hau7nHf00jfd004yrD167Oqc/ryPs\nT0habPtM23MkXSXpwRr6eBfbg9WJE9kelHSR+m8q6gclrajur5D0QI29/IZ+mca70TTjqvm1q336\n84jo+Z+kSzV5Rv4Xkr5aRw8N+vqwpJ9Uf8/U3ZukezT5tm6/Js9tXCfpVEkbJD0v6QeSFvRRb3dK\nelrSJk0Ga2FNvZ2nybfomyRtrP4urfu1K/TVk9eNr8sCSXCCDkiCsANJEHYgCcIOJEHYgSQIO5AE\nYQeS+H9oSiuxioM8OwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8Z3f8XNK9YR_", + "colab_type": "text" + }, + "source": [ + "## Quantization\n", + "\n", + "Now that we have a pruned model that appears to be performing well, let's see if we can make it even smaller by quantization. To do this, we'll need a slightly different neural network, one that corresponds to Figure 3 from the paper. Instead of having a matrix of float values, we'll have a matrix of integer labels (here called \"labels\") that correspond to entries in a (hopefully) small codebook of centroids (here called \"centroids\"). The way that I've coded it, there's still a mask that enforces our desired sparsity pattern." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2Of2GpCU9YSD", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MultilayerPerceptronQuantized(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim, hidden_dim,mask,labels,centroids):\n", + " super(MultilayerPerceptronQuantized, self).__init__()\n", + " self.mask = torch.nn.Parameter(mask,requires_grad=False)\n", + " self.labels = torch.nn.Parameter(labels,requires_grad=False)\n", + " self.centroids = torch.nn.Parameter(centroids,requires_grad=True)\n", + "\n", + " self.b_0 = torch.nn.Parameter(torch.zeros(hidden_dim))\n", + "\n", + " self.W_1 = torch.nn.Parameter(1e-3*torch.randn(hidden_dim,output_dim))\n", + " self.b_1 = torch.nn.Parameter(torch.zeros(output_dim))\n", + "\n", + " def forward(self, x):\n", + " W_0 = self.mask*(self.centroids[self.labels].reshape(784,64))\n", + " hidden = torch.tanh(x@W_0 + self.b_0)\n", + " outputs = hidden@self.W_1 + self.b_1\n", + " return outputs" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DnBF8PwJ9YSQ", + "colab_type": "text" + }, + "source": [ + "Notice what is happening in the forward method: W_0 is being reconstructed by using a matrix 
(self.labels) to index into a vector (self.centroids). The beauty of automatic differentiation allows backpropagation through this sort of weird indexing operation, and thus gives us gradients of the objective function with respect to the centroid values!\n", + "\n", + "### Q6: However, before we are able to use this AD magic, we need to specify the static label matrix (and an initial guess for centroids). Use the k-means algorithm (or something else if you prefer) to figure out the label matrix and centroid vectors. PROTIP1: I used scikit-learn's implementation of k-means. PROTIP2: only cluster the non-zero entries" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bn9LC7zY9YSV", + "colab_type": "code", + "colab": {} + }, + "source": [ + "from sklearn.cluster import KMeans\n", + "import numpy as np\n", + "# convert weight and mask matrices into numpy arrays\n", + "W_0 = model.W_0.detach().cpu().numpy()\n", + "mask = model.mask.detach().cpu().numpy()\n", + "\n", + "# Figure out the indices of non-zero entries \n", + "inds = np.where(mask!=0)\n", + "# Figure out the values of non-zero entries\n", + "vals = W_0[inds]\n", + "\n", + "### TODO: perform clustering on vals\n", + "kmean = KMeans(n_clusters=2)\n", + "clusters = kmean.fit(vals.reshape(len(vals),1))\n", + "centroids = kmean.cluster_centers_\n", + "labels = []\n", + "\n", + "for val in W_0.reshape(784*64):\n", + " label = torch.argmin(torch.abs(val - torch.from_numpy(centroids)))\n", + " labels.append(label.data)\n", + "\n", + " \n", + "\n", + "### TODO: turn the label matrix and centroids into a torch tensor\n", + "labels = torch.tensor(labels,dtype=torch.long,device=device)\n", + "centroids = torch.tensor(centroids,device=device)\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y4SE1iks9YSk", + "colab_type": "text" + }, + "source": [ + "Now, we can instantiate our quantized model and import the appropriate pre-trained weights for the other 
network layers. " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "SLJS3aTV9YSn", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Instantiate quantized model\n", + "model_q = MultilayerPerceptronQuantized(input_dim,output_dim,hidden_dim,new_mask,labels,centroids)\n", + "model_q = model_q.to(device)\n", + "\n", + "# Copy pre-trained weights from unquantized model for non-quantized layers\n", + "model_q.b_0.data = model.b_0.data\n", + "model_q.W_1.data = model.W_1.data\n", + "model_q.b_1.data = model.b_1.data" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ijPdAe149YSy", + "colab_type": "text" + }, + "source": [ + "Finally, we can fine tune the quantized model. We'll adjust not only the centroids, but also the weights in the other layers." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lm5n1IfC9YS1", + "colab_type": "code", + "outputId": "e80a11c4-992f-461e-af3c-7ea68cc72808", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 101 + } + }, + "source": [ + "optimizer = torch.optim.Adam(model_q.parameters(), lr=lr_rate, weight_decay=1e-3)\n", + "iter = 0\n", + "for epoch in range(n_epochs):\n", + " for i, (images, labels) in enumerate(train_loader):\n", + " images = images.view(-1, 28 * 28).to(device)\n", + " labels = labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model_q(images)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate Accuracy\n", + " correct = 0\n", + " total = 0\n", + " for images, labels in test_loader:\n", + " images = images.view(-1, 28*28).to(device)\n", + " labels = labels.to(device)\n", + " outputs = model_q(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total+= labels.size(0)\n", + " # for gpu, bring the predicted and labels back to cpu fro python operations to work\n", + " correct+= (predicted == labels).sum()\n", + " 
accuracy = 100 * correct/total\n", + " print(\"Iteration: {}. Loss: {}. Accuracy: {}.\".format(iter, loss.item(), accuracy))\n", + "torch.save(model.state_dict(),'mnist_quantized.h5')" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Iteration: 0. Loss: 0.14308680593967438. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.18871541321277618. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.16591843962669373. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.12692667543888092. Accuracy: 95.\n", + "Iteration: 0. Loss: 0.14464326202869415. Accuracy: 95.\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AZvRY6vQIffp", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0ksmxQvc9YTA", + "colab_type": "text" + }, + "source": [ + "After retraining, we can, just for fun, reconstruct the pruned and quantized weights and plot them as images:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bbEWvCON9YTC", + "colab_type": "code", + "outputId": "41a66a18-6b1e-4051-9936-b848b75ef4a7", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + } + }, + "source": [ + "W_0 = (model_q.mask*model_q.centroids[model_q.labels].reshape(784,64)).detach().cpu().numpy()\n", + "plt.imshow(W_0[:,1].reshape((28,28)))\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAALm0lEQVR4nO3dT4ic9R3H8c+nq/aw6yFb2yWNabWa\nSyg0liUISrFIJeYSvQRzkBSka0HBgIWKPZhjKP49FGWtwVisIqiYQ2hNgxA8KK6SxsS0TZSISdes\nkoPZXGzWbw/7RNZkd2ec53nmeXa/7xcsM/M8M/N8ffCT3zPPd575OSIEYPn7TtMFAOgPwg4kQdiB\nJAg7kARhB5K4pJ8bGxgajEuGh/u5SSCVc6dPa2b6rOdbVyrstjdIekLSgKQ/R8SOxZ5/yfCwfvi7\nbWU2CWAR/3348QXX9XwYb3tA0p8k3SppraQtttf2+n4A6lXmM/t6Scci4qOI+FLSi5I2VVMWgKqV\nCfsqSZ/MeXyiWPYNtsdsT9iemJk+W2JzAMqo/Wx8RIxHxGhEjA4MDda9OQALKBP2k5JWz3l8ZbEM\nQAuVCfs7ktbYvtr2ZZLukLS7mrIAVK3n1ltEnLN9r6S/a7b1tjMiDldWGYBKleqzR8QeSXsqqgVA\njfi6LJAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfT1p6TRf9du\ne6vpEhZ07PHrmy4hFUZ2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE\nYQeSIOxAEoQdSILr2Zc5rhnHeaXCbvu4pDOSZiSdi4jRKooCUL0qRvZfRsTnFbwPgBrxmR1IomzY\nQ9Lrtt+1PTbfE2yP2Z6wPTEzfbbk5gD0quxh/I0RcdL2DyTttf2viNg/9wkRMS5pXJK++6PVUXJ7\nAHpUamSPiJPF7ZSkVyWtr6IoANXrOey2B21ffv6+pFskHaqqMADVKnMYPyLpVdvn3+evEfG3SqpK\n5sPNTy26/pqXflvbe7dZmf9uXKznsEfER5J+VmEtAGpE6w1IgrADSRB2IAnCDiRB2IEkuMR1GVjK\n7TX0DyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRBn71LZXrZnS7VrPNSTi4TxXmM7EAShB1IgrAD\nSRB2IAnCDiRB2IEkCDuQBH32LrW5X93m2tAejOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kAR99uSu\n3fZWre9/7PHra31/dK/jyG57p+0p24fmLBu2vdf20eJ2Rb1lAiirm8P4ZyVtuGDZA5L2RcQaSfuK\nxwBarGPYI2K/pNMXLN4kaVdxf5ek2yquC0DFej1BNxIRk8X9TyWNLPRE22O2J2xPzEyf7XFzAMoq\nfTY+IkJSLLJ+PCJGI2J0YGiw7OYA9KjXsJ+yvVKSitup6koCUIdew75b0tbi/lZJr1VTDoC6dOyz\n235B0k2SrrB9QtJDknZIesn2XZI+lrS5ziKXu7Lzqzd5PXvWPnrZ7yc0sd86hj0itiyw6uaKawFQ\nI74uCyRB2IEkCDuQBGEHkiDsQBJc4toHnVpr/BT08tPGliQjO5AEYQeSIOxAEoQdSIKwA0kQdiAJ\nwg4kQZ+9D9rcR29jP7gN6v6J7SYwsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEvTZgXl0+v7BUuzD\nM7IDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL02VFKmd/EX8q/p78Ufweg48hue6ftKduH5izbbvuk\n7QPF38Z6ywRQVjeH8c9K2jDP8sciYl3xt6fasgBUrWPYI2K/pNN9qAVAjcqcoLvX9sHiMH/FQk+y\nPWZ7wvbEzPTZEpsDUEavYX9S0jWS1kmalPTIQk+MiPGIGI2I0
YGhwR43B6CsnsIeEaciYiYivpL0\ntKT11ZYFoGo9hd32yjkPb5d0aKHnAmiHjn122y9IuknSFbZPSHpI0k2210kKSccl3V1jjahR2V53\nm3vh+KaOYY+ILfMsfqaGWgDUiK/LAkkQdiAJwg4kQdiBJAg7kASXuC4BnX62eClebinRtus3RnYg\nCcIOJEHYgSQIO5AEYQeSIOxAEoQdSII+ewuUnf53sdd36sHT686DkR1IgrADSRB2IAnCDiRB2IEk\nCDuQBGEHkqDPvszVfS38cr3WfjliZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJOizJ1f2Wvo6358e\nfbU6juy2V9t+w/YHtg/bvq9YPmx7r+2jxe2K+ssF0KtuDuPPSbo/ItZKul7SPbbXSnpA0r6IWCNp\nX/EYQEt1DHtETEbEe8X9M5KOSFolaZOkXcXTdkm6ra4iAZT3rU7Q2b5K0nWS3pY0EhGTxapPJY0s\n8Jox2xO2J2amz5YoFUAZXYfd9pCklyVti4gv5q6LiJAU870uIsYjYjQiRgeGBksVC6B3XYXd9qWa\nDfrzEfFKsfiU7ZXF+pWSpuopEUAVOrbebFvSM5KORMSjc1btlrRV0o7i9rVaKkygU4upzvZV3a23\nxdBa669u+uw3SLpT0vu2DxTLHtRsyF+yfZekjyVtrqdEAFXoGPaIeFOSF1h9c7XlAKgLX5cFkiDs\nQBKEHUiCsANJEHYgCS5xXQKa7Ec3ue0PNz/V2LaX41TWjOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7\nkAR99uTK9tGb7IWXsRz76J0wsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEvTZk2tznzxjL7xOjOxA\nEoQdSIKwA0kQdiAJwg4kQdiBJAg7kEQ387OvlvScpBFJIWk8Ip6wvV3SbyR9Vjz1wYjYU1ehmbW5\nF94JvfL26OZLNeck3R8R79m+XNK7tvcW6x6LiIfrKw9AVbqZn31S0mRx/4ztI5JW1V0YgGp9q8/s\ntq+SdJ2kt4tF99o+aHun7RULvGbM9oTtiZnps6WKBdC7rsNue0jSy5K2RcQXkp6UdI2kdZod+R+Z\n73URMR4RoxExOjA0WEHJAHrRVdhtX6rZoD8fEa9IUkScioiZiPhK0tOS1tdXJoCyOobdtiU9I+lI\nRDw6Z/nKOU+7XdKh6ssDUJVuzsbfIOlOSe/bPlAse1DSFtvrNNuOOy7p7loqBO0rVKKbs/FvSvI8\nq+ipA0sI36ADkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k\n4Yjo38bszyR9PGfRFZI+71sB305ba2trXRK19arK2n4cEd+fb0Vfw37Rxu2JiBhtrIBFtLW2ttYl\nUVuv+lUbh/FAEoQdSKLpsI83vP3FtLW2ttYlUVuv+lJbo5/ZAfRP0yM7gD4h7EASjYTd9gbb/7Z9\nzPYDTdSwENvHbb9v+4DtiYZr2Wl7yvahOcuGbe+1fbS4nXeOvYZq2277ZLHvDtje2FBtq22/YfsD\n24dt31csb3TfLVJXX/Zb3z+z2x6Q9B9Jv5J0QtI7krZExAd9LWQBto9LGo2Ixr+AYfsXkqYlPRcR\nPy2W/VHS6YjYUfxDuSIift+S2rZLmm56Gu9itqKVc6cZl3SbpF+rwX23SF2b1Yf91sTIvl7SsYj4\nKCK+lPSipE0N1NF6EbFf0ukLFm+StKu4v0uz/7P03QK1tUJETEbEe8X9M5LOTzPe6L5bpK6+aCLs\nqyR9MufxCbVrvveQ9Lrtd22PNV3MPEYiYrK4/6mkkSaLmUfHabz76YJpxluz73qZ/rwsTtBd7MaI\n+LmkWyXdUxyutlLMfgZrU++0q2m8+2Weaca/1uS+63X687KaCPtJSavnPL6yWNYKEXGyuJ2S9Kra\nNxX1qfMz6Ba3Uw3X87U2T
eM93zTjasG+a3L68ybC/o6kNbavtn2ZpDsk7W6gjovYHixOnMj2oKRb\n1L6pqHdL2lrc3yrptQZr+Ya2TOO90DTjanjfNT79eUT0/U/SRs2ekf9Q0h+aqGGBun4i6Z/F3+Gm\na5P0gmYP6/6n2XMbd0n6nqR9ko5K+oek4RbV9hdJ70s6qNlgrWyoths1e4h+UNKB4m9j0/tukbr6\nst/4uiyQBCfogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wNPDLQpJGSWBgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_6ar0vcD9YTK", + "colab_type": "text" + }, + "source": [ + "Certainly a much more parsimonious representation. The obvious question now becomes:\n", + "\n", + "### Q7: How low can you go? How small can the centroid codebook be before we see a substantial degradation in test set accuracy?\n", + "\n", + "I got great results all the way down to two. Though this may because the \n", + "bias and W_1 values weren't restricted. It would be interesting to see what would happen if we restricted those weights to a small code book as well. \n", + "\n", + "### Bonus question: Try establishing the sparsity pattern using a model that's only been trained for a single epoch, then fine tune the pruned model and quantize as normal. How does this compare to pruning a model that has been fully trained? \n", + "\n", + "Less accurate, but not by a large amount. A final accuracy of 94% as opposed to 96%." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GGUePifO9YTM", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file