-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
302 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,302 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "Wk_g0iEgXk3g" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"c:\\python36\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", | ||
" from ._conv import register_converters as _register_converters\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import tensorflow as tf\n", | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import matplotlib.image as mpimg\n", | ||
"from utils import *\n", | ||
"from model import *" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "DCyCx7aqXk4j", | ||
"outputId": "252cc6da-0254-4620-bc78-8c97b04229c2" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"(?, 100, 100, 3)\n", | ||
"(?, 86, 86, 32)\n", | ||
"(?, 43, 43, 32)\n", | ||
"(?, 36, 36, 64)\n", | ||
"(?, 12, 12, 64)\n", | ||
"(?, 8, 8, 256)\n", | ||
"(?, 4, 4, 256)\n", | ||
"(?, 4096)\n", | ||
"(?, 64)\n", | ||
"(?, 100, 100, 3)\n", | ||
"(?, 86, 86, 32)\n", | ||
"(?, 43, 43, 32)\n", | ||
"(?, 36, 36, 64)\n", | ||
"(?, 12, 12, 64)\n", | ||
"(?, 8, 8, 256)\n", | ||
"(?, 4, 4, 256)\n", | ||
"(?, 4096)\n", | ||
"(?, 64)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"left = tf.placeholder(tf.float32, [None, 100, 100, 3], name='left')\n", | ||
"right = tf.placeholder(tf.float32, [None, 100, 100, 3], name='right')\n", | ||
"\n", | ||
"label = tf.placeholder(tf.int32, [None, 1], name='label') # 1 if same, 0 if different\n", | ||
"label = tf.to_float(label)\n", | ||
"\n", | ||
"margin = 1\n", | ||
"\n", | ||
"left_output = mynet(left, reuse=False)\n", | ||
"right_output = mynet(right, reuse=True)\n", | ||
"\n", | ||
"loss = contrastive_loss(left_output, right_output, label, margin)\n", | ||
"optim = tf.train.AdamOptimizer(0.0005).minimize(loss)\n", | ||
"\n", | ||
"init_op = tf.global_variables_initializer()\n", | ||
"sess = tf.Session()\n", | ||
"sess.run(init_op)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "WkXfLemtXk40", | ||
"outputId": "3645b9b9-d2af-4c42-821e-e652da940353" | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"2500" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"test_image_paths = []\n", | ||
"for category in os.listdir(photo_path):\n", | ||
" category_path = os.path.join(photo_path, category + '/')\n", | ||
" image_paths = np.random.choice(os.listdir(category_path), size=20, replace=False)\n", | ||
" for i in range(20):\n", | ||
" test_image_paths.append(category_path + image_paths[i])\n", | ||
"np.random.shuffle(test_image_paths)\n", | ||
"len(np.unique(test_image_paths))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "zsubqiv0Xk5D", | ||
"outputId": "f54ac24e-77aa-4fd8-a98d-35d68641fefb" | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"'Sketchy/sketch/bell/n03028596_5727-4.png'" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"test_sketch_path = sketch_path + np.random.choice(os.listdir(sketch_path)) \n", | ||
"test_sketch_path = test_sketch_path + '/' + np.random.choice(os.listdir(test_sketch_path))\n", | ||
"test_sketch_path" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "98rVKbjJXk5d", | ||
"outputId": "fae4936d-6af0-416c-9509-e2ab79a3d9b5" | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(1, 100, 100, 3)" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"test_sketch = load_img(test_sketch_path)\n", | ||
"test_sketch = np.expand_dims(test_sketch, 0)\n", | ||
"test_sketch.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "1LzbcZNsXk54" | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(2500, 100, 100, 3)" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"test_images = []\n", | ||
"for path in test_image_paths:\n", | ||
" test_images.append(load_img(path))\n", | ||
"test_images = np.array(test_images)\n", | ||
"test_images.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "qOCZmxEYXk6M" | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAAQI0lEQVR4nO3dW2gcVRgH8O/M3pLdkN1cNl7SBG1t09rUJL0oSo0WidQHFSrVItIX7YPoixQD+hAQwUAsJYgv5kEQIkUQfRGx9UGDaMEmEmuESBWNLRa1Nknt2s1lZ3z4mmG6STeX7pxzZs7/9yDjZtP9kpz/fGfuwnEcAjCVpboAAJUQADAaAgBGQwDAaAgAGA0BAKMhAGA0BACMhgCA0RAAMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAIDREAAwGgIARkMAwGgIABgNAQCjIQBgNAQAjIYAgNEQADAaAgBGQwDAaAgAGA0BAKMhAGA0BACMhgCA0RAAMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAIDREAAwGgIARkMAwGgIABgNAQCjIQBgNAQAjIYAgNEQADAaAgBGQwDAaAgAGA0BAKMhAGA0BACMhgCA0RAAMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAIDREAAwGgIARkMAwGgIABgNAQCjIQBgNAQAjIYAgNEQADAaAgBGQwDAaAgAGA0BAKMhAGA0BACMhgCA0RAAMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAATbPffck8lk6urqstlsfX19TU3NwYMHVRcVJMJxHNU1wFrs37//ww8/XPJLlmUVCgXJ9QQUOkBQuaNfXIuIbNsWQgwODiotMBjQAQIpGo3yOn7xn2/z5s0//fQTEQkhbNtWUFygoAMEEo/+pqamxV8aHx//5JNPiMhxnN9//112ZUGDDhA8sVhsfn6ellr9uyzLchwHTWBZ6ADBw6M/m82WeM/zzz9PRI7jvPrqq5LKCiZ0gIBJp9OXLl2ikqt/xk0gkUjk83kppQUSOkDA8OivrKzk/21sbLQsSwhhWdaRI0e876ytrSWi2dlZ+UUGCDpAkLS2tv7444+0sPrPZrMXLlzwvqHor8l7RfEnLgEdIEjGx8dpYVgTEY9+IcTOnTv5lWQyqaq2gIqqLgBWgXfp3H777UQUjUbJs7Ofu8GVK1fUVhg46ABBwpOZvXv30sKhAHfd//fff/PCwMCAouoCCdsAQeLO6ZPJ5JUrV4p28/Nun1gs5m74YhtgWegAgcR7NjOZjPfFiooKWjhKACuEAATP4OAgr9QvXrzofX3Lli2E9f0qIQDBc+jQIfLsC3JVV1erKCfYEIDgmZmZIaJYLFb0+uJIwLIQgKC68847i14pmvzs27dPYjlBhQAEDw/0vr6+otctyyJPH+CTotEWSkMAgqqrq6voFT75x+0DvDsIGwal4ThAkHhX54v/cJFIhC+GtG27ubn57NmzS74NvNABQqKhoYEPij3xxBNEdO7cOVo4XQJKQAcIkhIdgL/E94Pg48SL3wOLYQ0RPIu3axOJBC8UCoVMJsOjPx6Py64sgBCA4Cm6zHfDhg188k86nebNACISQvDhAigNU6AgWfLkNn7R+yXcGGvlsBGsnZMnT1ZXV1tL4TcUvcj3w3Icxw2GbdvuVxOJxNtvv63up9EdOoAW2traxsbGvIO4vDKZzOTkpB//ctChAygzODiYSqV4PX369Gnbtr2jXwgRjUZjsVgkEokuiMVi7jKv+/nFWCy2bt06Z8HExERDQwO/jadGU1NTQoiqqip1P66m0AEUmJqaqq+vL5qm850dWltbR0dHy/tx7n0UiSgSieCCAS90ANnq6upqamrcEWlZViQSOXTokG3b8/PzZR/9RDQ/P+84TiQSIaJCoeBuSwChA8jn3rQwm83++eefMj+6ra3t9OnTRJRMJnO5nMyP1hYCINXY2Ni2bdtI3THaldxX1CjohlLxOfoKT1Gem5vjhe7ublU1aAUdQCqe/6jdEuUasGOUoQPI09vby6ubPXv2KCyDzxGanp5WWIM+0AHkqayszOfzyu/Zf/jw4aNHjxI2A4gIAZCJz1TTYU88b4Rs3759ZGREbSXKYQokD69rdLh/LQfgzJkzqgtRDwGQhwNw2223qS7kKnePkMkQAEn4tm1ExIeidIBDwoQASMMPbMRNSnSDAEjC8x+tAoD9H4QASMNXLWo160AACAGQhvf9p1Ip1YXANRAASXh1y0830gSfIG04BEASnv1rdU0WpkCEAEjDo829gY8O1J6RoQkEQCqtVrpabZGrgl+BVFoFQKtiVEEApMJxAN0gAFIpPw/US6s0qoIASLVr1y7VJcA1EAAZ+vv7eeHIkSNKC7mK1/04DkAIgBz8uC598Oxfq/mYKgiADHrucUcHIARADj33uOsZS8l0/MOEj55360cACAGQQ8/JxuJnzRsIAZDhn3/+UV3CEnAgjBAAOfj+C7odeMIUiBAAOfixjbptCqMDEAIgBw813Z5bqltHUgIBkEefDuA+VVt1IerhVyCPPlMOrgTbAIQAyKRPAMCFAMij25wbUyBCAGTSLQCYAhECYDJMyQgBkAmnH2sIAZBHt3Nv0AEIAZBJt3NC9TxFTzIEwFzoAIQAyKTbbkfd9kopodefBGRCByAEQCbd9gKhAxACIJNuUyB0AEIAZNJtwOlWjxIIgDy6TTl0q0cJBEAe3QYcjgMQAiCTblMOnAxHCIAcul2BxfXotldKCV3+JOGm27qfJz9zc3OqC1EPAZAnmUyqLuGqbDarugRdIAAycAfo6OhQXchV69evJ/36khIIgDzHjx9XXcJVe/bsUV2CLgRWAxLwRqdWv2ou6b777vv6669V16ISAuC7jo6O0dFRIYRWux0ty3IcR7eq5MMUyHe//vqr6hKW8NJLL5FmTUkJdADfxePxubk5Dde1PAvasWPH8PCw6lqUQQB8x5ONRCKRz+dV13INLiwajZp8QABTIN/xKubBBx9UXUix2tpaMv54MALgr56eHl747LPP1Fay2IULF3hh+/btaitRCFMgf911110//PAD6bq5ybMgy7J0u2OFNOgA/hofHyf9ToR2bdy4kcw+LRQdwF889Gtqai5evKi6lqVxhZWVlf/995/qWhRAB/DRQw89xAvajn4iqqiooIWHOBkIHcBHQTnaavKWADqAj3jlsnv3btWFLKO5uZlM3RJAB/BLJpOZnp7Wf/XPeEtAw6N1fkMA/MLzikgkEogjTRUVFTMzM6Tr7lr/YArki7GxMR5Jvb29qmtZEXfFH41G1VYiGTqAL6qqqnK5XFDmPyybzfKx4QMHDhw7dkx1OZIgAL7g+U8ymczlcqprWYVIJGLbdrBye4MwBSq/np4eXq0Ea/QT0ffff09EjuNs27ZNdS2SoAOUX1B2/y8p0MWvATpAmcXjcV6n7N+/X3Uta8Hjnvdfqa5FBgSgnN58802+uCQej3/wwQeqy1mj6upqIrJtm8+SCDdMgcopNPOHWCzGhy/6+vpefvll1eX4CAEoG772l4gOHjz43nvvqS7nRhlywSSmQOXR2dnJAyUWi4Vg9BPRunXriGh+fn5oaEh1LT5CByiDw4cPHz16lIhCMPnxCs2MrgR0gBu1adMmHv0UuhMqv/jiCyLiDHR3d6suxxfoADeksbHxjz/+oNCt+13uhk1Yf0B0gLVbv359uEc/Ec3OzvKC4zjxeFxtMX5AANaorq6O73kY4tHP3Lu6z83N7du3T20xZYcp0Fokk0m+iDb0o5/x1jCF8edFB1g1y7KMGv20cM0kETmOc/LkSbXFlBc6wCqMjIzs3LmTl4NyqVcZcR+Ix+N87Vg4oAOs1N69e93R39zcbNroJ6JMJkOezeJwQABWZOPGje4Djk6cODExMaG2HiXcuxsdOHBAbSVlZNYFoGvT0tLy888/k9xJ/+joaHt7e9GL6XR6enpaTgEl8N1OwwHbAMvYtGnTmTNnaKnR/+ijjx4/ftynuRAffy16hVZ2m1EhhBBi69at3d3dDQ0NRNTV1VWWqngzIBaLhWci5MD1ueNGCMGvfPTRRwG9UkQIEYvF2tvbb+QXkkqlyPPbCAF0gFLc/d/XIxaUeKd3nW3btmVZ/F93GJFnSBER35+QY2ZZlvd7+UvRaNRxnMWf6PYo969buuympqbffvtt2V+C17vvvvvss88S0RtvvPHKK6+s6nv1hACUsuSwFkJUVFT4dC/laDRaKBSut7Gx5setdnd39/f3z8/PL/njdHZ2fvnllyv8p/h3ks1m//rrr9WWoSHsBSplYGCAF9LptLtatW3bvzuJ87hvaWkp7z/b19c3Oztr27bjOCdOnEin025fchxnaGiIU73yf3Bqaqq8FaqCAJTy3HPP8YJ7BMBvvIZ+7bXX/PuIrq6uqakpDsPWrVvdJMzMzLjTuc8///x6315ZWUkherIYArAiku8b/uSTT8r5oLGxMU5CIpFwX3Qc5+GHH7Ys6+mnn178LZ9++imF6BaiCMCKhGaFdz35fN5xnG+++ebmm292tzSOHTt2xx13FL3zgQce4IVbbrlFdpU+QABWJJlMqi5Bhnvvvff8+fO2be/evZtj8Msvv/D+Iu/b+Evh2AjGXqBlSN7pUXo/z5r3Aq0N3yqUl6uqqv79919e3rFjx3fffSezEv+gAyyD98e7j9SV49Zbb1384jPPPCOzBiIqFAp8wIGILl++bFmWZVk33XTTyMgIv2HDhg2SSyo7dIBltLe3u7eMlfBx7jW4iz9O4T0aEonEkuc+hOCJMugAyxgdHeUF70FZ/7jjrOjj3ENyTz31lIQyiszMzDiOU1tbW1TV448/Lr+Y8kIHWN7dd9996tQpknURTGdn51dffUVEQoiamprZ2dlcLsd/pkwmMzk56XcBpbkPU5qcnOQrBAINAVgR916ZqVTq8uXLfn9cU1PTuXPnil5saWnh584rlEql+Ch4aG6ZiCnQiszNzfEemFwuZ1mWuxXok7Nnz27evJnnG3xodnh4WO3ov//++y3L4tFvWVY4Rj+hA6yKd7fgrl27vv32W7X1SMOn6PFyyC6GRgdYhUKh0NbWxsunTp0K/QMVh4aG+LRtHv1CiImJiTCNfkIHWBvvadKWZUk+U0gO71qfiNLpdGjOAPVCB1gL27bdM2H4sYqWZT322GNqqyqXhoYG71q/tbXVcZxQjn5CB7hB1dXV7gkCRCSEqKysDNzDIdk777zzwgsv8Mmh/ErgHvO6BghAGWQymUuXLnl/k0KIRCLBN5DTXz6fT6VS3gPMIdvSLQFToDJwry9xD5Q6jpPP54UQ0Wh0y5YtbW1t7e3tr7/+uto6vXp7ezs6OmKxGHctHv1CiGQyOTw8bMjoJ3QAP7z44osDAwPX21Mej8cfeeSRSCTS2Nj41ltv+V3MyMjI+++/Pz09zXvxo9Hoxx9/7J22uXQ4zCwfAuAj9ybSJazkPj9rtuwflzffe3t7w/0oyBIQAKnckz3VMmeKvywEQKWenp7+/n5vl3A3JK7etkkIuvY6GHfZe1sHvq0Q/5cW7jJERPX19efPn5f9UwUKAgBGw14gMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAIDREAAwGgIARkMAwGgIABgNAQCjIQBgNAQAjIYAgNH+BwDwycoaTZzFAAAAAElFTkSuQmCC\n", | ||
"text/plain": [ | ||
"<PIL.PngImagePlugin.PngImageFile image mode=RGB size=256x256 at 0x12651C50>" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"from PIL import Image\n", | ||
"Image.open(test_sketch_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "wOw35XSMXk6e" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"for epoch in range(1000000):\n", | ||
" photo_dictionary, sketch_dictionary = get_dict()\n", | ||
" p_batch, s_batch, lab_batch = get_batch(photo_dictionary, sketch_dictionary)\n", | ||
" lab_batch = lab_batch.reshape(-1, 1)\n", | ||
" [_, loss_] = sess.run([optim,loss], {left: p_batch, right: s_batch, label:lab_batch})\n", | ||
" \n", | ||
" \n", | ||
" if epoch%100==0:\n", | ||
" \n", | ||
" sketch_repr = sess.run([left_output], {left: test_sketch})\n", | ||
" sketch_repr = np.squeeze(np.array(sketch_repr), 1)\n", | ||
" print(sketch_repr.shape)\n", | ||
" sketch_representations = np.tile(sketch_repr, 2496).reshape(2496, 64)\n", | ||
" print(sketch_representations.shape)\n", | ||
" \n", | ||
" batch_size = 8\n", | ||
" n_batches = len(test_images) // batch_size\n", | ||
" image_representations = []\n", | ||
"\n", | ||
" for i in range(n_batches):\n", | ||
" img_repr = sess.run([left_output], {left: test_images[i*batch_size : (i+1)*batch_size]})\n", | ||
" img_repr = np.squeeze(np.array(img_repr), 0)\n", | ||
" image_representations.append(img_repr)\n", | ||
" image_representations = np.vstack(image_representations)\n", | ||
"\n", | ||
" diff = np.sqrt(np.mean((sketch_representations - image_representations)**2, -1))\n", | ||
" top_k = np.argsort(diff)[:5]\n", | ||
"\n", | ||
" print ('##' + str(epoch) + ' : loss == ' + str(loss_))\n", | ||
"\n", | ||
" plt.figure(figsize=(20, 20))\n", | ||
" for i in range(5): \n", | ||
" img = mpimg.imread(test_image_paths[top_k[i]])\n", | ||
" plt.subplot(1, 5, i+1)\n", | ||
" plt.imshow(img)\n", | ||
" plt.xticks([])\n", | ||
" plt.yticks([])\n", | ||
" plt.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"name": "SketchRetrieval.ipynb", | ||
"provenance": [], | ||
"version": "0.3.2" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |