Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ArkaJU authored Oct 13, 2018
1 parent 67fb8c9 commit 4eeecfa
Showing 1 changed file with 302 additions and 0 deletions.
302 changes: 302 additions & 0 deletions SketchRetrieval.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Wk_g0iEgXk3g"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\python36\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"from utils import *\n",
"from model import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "DCyCx7aqXk4j",
"outputId": "252cc6da-0254-4620-bc78-8c97b04229c2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(?, 100, 100, 3)\n",
"(?, 86, 86, 32)\n",
"(?, 43, 43, 32)\n",
"(?, 36, 36, 64)\n",
"(?, 12, 12, 64)\n",
"(?, 8, 8, 256)\n",
"(?, 4, 4, 256)\n",
"(?, 4096)\n",
"(?, 64)\n",
"(?, 100, 100, 3)\n",
"(?, 86, 86, 32)\n",
"(?, 43, 43, 32)\n",
"(?, 36, 36, 64)\n",
"(?, 12, 12, 64)\n",
"(?, 8, 8, 256)\n",
"(?, 4, 4, 256)\n",
"(?, 4096)\n",
"(?, 64)\n"
]
}
],
"source": [
"left = tf.placeholder(tf.float32, [None, 100, 100, 3], name='left')\n",
"right = tf.placeholder(tf.float32, [None, 100, 100, 3], name='right')\n",
"\n",
"label = tf.placeholder(tf.int32, [None, 1], name='label') # 1 if same, 0 if different\n",
"label = tf.to_float(label)\n",
"\n",
"margin = 1\n",
"\n",
"left_output = mynet(left, reuse=False)\n",
"right_output = mynet(right, reuse=True)\n",
"\n",
"loss = contrastive_loss(left_output, right_output, label, margin)\n",
"optim = tf.train.AdamOptimizer(0.0005).minimize(loss)\n",
"\n",
"init_op = tf.global_variables_initializer()\n",
"sess = tf.Session()\n",
"sess.run(init_op)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "WkXfLemtXk40",
"outputId": "3645b9b9-d2af-4c42-821e-e652da940353"
},
"outputs": [
{
"data": {
"text/plain": [
"2500"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_image_paths = []\n",
"for category in os.listdir(photo_path):\n",
" category_path = os.path.join(photo_path, category + '/')\n",
" image_paths = np.random.choice(os.listdir(category_path), size=20, replace=False)\n",
" for i in range(20):\n",
" test_image_paths.append(category_path + image_paths[i])\n",
"np.random.shuffle(test_image_paths)\n",
"len(np.unique(test_image_paths))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "zsubqiv0Xk5D",
"outputId": "f54ac24e-77aa-4fd8-a98d-35d68641fefb"
},
"outputs": [
{
"data": {
"text/plain": [
"'Sketchy/sketch/bell/n03028596_5727-4.png'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_sketch_path = sketch_path + np.random.choice(os.listdir(sketch_path)) \n",
"test_sketch_path = test_sketch_path + '/' + np.random.choice(os.listdir(test_sketch_path))\n",
"test_sketch_path"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "98rVKbjJXk5d",
"outputId": "fae4936d-6af0-416c-9509-e2ab79a3d9b5"
},
"outputs": [
{
"data": {
"text/plain": [
"(1, 100, 100, 3)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_sketch = load_img(test_sketch_path)\n",
"test_sketch = np.expand_dims(test_sketch, 0)\n",
"test_sketch.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "1LzbcZNsXk54"
},
"outputs": [
{
"data": {
"text/plain": [
"(2500, 100, 100, 3)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_images = []\n",
"for path in test_image_paths:\n",
" test_images.append(load_img(path))\n",
"test_images = np.array(test_images)\n",
"test_images.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "qOCZmxEYXk6M"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAAQI0lEQVR4nO3dW2gcVRgH8O/M3pLdkN1cNl7SBG1t09rUJL0oSo0WidQHFSrVItIX7YPoixQD+hAQwUAsJYgv5kEQIkUQfRGx9UGDaMEmEmuESBWNLRa1Nknt2s1lZ3z4mmG6STeX7pxzZs7/9yDjZtP9kpz/fGfuwnEcAjCVpboAAJUQADAaAgBGQwDAaAgAGA0BAKMhAGA0BACMhgCA0RAAMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAIDREAAwGgIARkMAwGgIABgNAQCjIQBgNAQAjIYAgNEQADAaAgBGQwDAaAgAGA0BAKMhAGA0BACMhgCA0RAAMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAIDREAAwGgIARkMAwGgIABgNAQCjIQBgNAQAjIYAgNEQADAaAgBGQwDAaAgAGA0BAKMhAGA0BACMhgCA0RAAMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAIDREAAwGgIARkMAwGgIABgNAQCjIQBgNAQAjIYAgNEQADAaAgBGQwDAaAgAGA0BAKMhAGA0BACMhgCA0RAAMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAATbPffck8lk6urqstlsfX19TU3NwYMHVRcVJMJxHNU1wFrs37//ww8/XPJLlmUVCgXJ9QQUOkBQuaNfXIuIbNsWQgwODiotMBjQAQIpGo3yOn7xn2/z5s0//fQTEQkhbNtWUFygoAMEEo/+pqamxV8aHx//5JNPiMhxnN9//112ZUGDDhA8sVhsfn6ellr9uyzLchwHTWBZ6ADBw6M/m82WeM/zzz9PRI7jvPrqq5LKCiZ0gIBJp9OXLl2ikqt/xk0gkUjk83kppQUSOkDA8OivrKzk/21sbLQsSwhhWdaRI0e876ytrSWi2dlZ+UUGCDpAkLS2tv7444+0sPrPZrMXLlzwvqHor8l7RfEnLgEdIEjGx8dpYVgTEY9+IcTOnTv5lWQyqaq2gIqqLgBWgXfp3H777UQUjUbJs7Ofu8GVK1fUVhg46ABBwpOZvXv30sKhAHfd//fff/PCwMCAouoCCdsAQeLO6ZPJ5JUrV4p28/Nun1gs5m74YhtgWegAgcR7NjOZjPfFiooKWjhKACuEAATP4OAgr9QvXrzofX3Lli2E9f0qIQDBc+jQIfLsC3JVV1erKCfYEIDgmZmZIaJYLFb0+uJIwLIQgKC68847i14pmvzs27dPYjlBhQAEDw/0vr6+otctyyJPH+CTotEWSkMAgqqrq6voFT75x+0DvDsIGwal4ThAkHhX54v/cJFIhC+GtG27ubn57NmzS74NvNABQqKhoYEPij3xxBNEdO7cOVo4XQJKQAcIkhIdgL/E94Pg48SL3wOLYQ0RPIu3axOJBC8UCoVMJsOjPx6Py64sgBCA4Cm6zHfDhg188k86nebNACISQvDhAigNU6AgWfLkNn7R+yXcGGvlsBGsnZMnT1ZXV1tL4TcUvcj3w3Icxw2GbdvuVxOJxNtvv63up9EdOoAW2traxsbGvIO4vDKZzOTkpB//ctChAygzODiYSqV4PX369Gnbtr2jXwgRjUZjsVgkEokuiMVi7jKv+/nFWCy2bt06Z8HExERDQwO/jadGU1NTQoiqqip1P66m0AEUmJqaqq+vL5qm850dWltbR0dHy/tx7n0UiSgSieCCAS90ANnq6upqamrcEWlZViQSOXTokG3b8/PzZR/9RDQ/P+84TiQSIaJCoeBuSwChA8jn3rQwm83++eefMj+6ra3t9OnTRJRMJnO5nMyP1hYCINXY2Ni2bdtI3THaldxX1CjohlLxOfoKT1Gem5vjhe7ublU1aAUdQCqe/6jdEuUasGOUoQPI09vby6ubPXv2KCyDzxGanp5WWIM+0AHkqayszOfzyu/Zf/jw4aNHjxI2A4gIAZCJz1TTYU88b4Rs3759ZGREbSXKYQokD69rdLh/LQfgzJkzqgtRDwGQhwNw2223qS7kKnePkMkQAEn4tm1ExIeidIBDwoQASMMPbMRNSnSDAEjC8x+tAoD9H4QASMNXLWo160AACAGQhvf9p1Ip1YXANRAASXh1y0830gSfIG04BEASnv1rdU0WpkCEAEjDo829gY8O1J6RoQkEQCqtVrpabZGrgl+BVFoFQKtiVEEApMJxAN0gAFIpPw/US6s0qoIASLVr1y7VJcA1EAAZ+vv7eeHIkSNKC7mK1/04DkAIgBz8uC598Oxfq/mYKgiADHrucUcHIARADj33uOsZS8l0/MOEj55360cACAGQQ8/JxuJnzRsIAZDhn3/+UV3CEnAgjBAAOfj+C7odeMIUiBAAOfixjbptCqMDEAIgBw813Z5bqltHUgIBkEefDuA+VVt1IerhVyCPPlMOrgTbAIQAyKRPAMCFAMij25wbUyBCAGTSLQCYAhECYDJMyQgBkAmnH2sIAZBHt3Nv0AEIAZBJt3NC9TxFTzIEwFzoAIQAyKTbbkfd9kopodefBGRCByAEQCbd9gKhAxACIJNuUyB0AEIAZNJtwOlWjxIIgDy6TTl0q0cJBEAe3QYcjgMQAiCTblMOnAxHCIAcul2BxfXotldKCV3+JOGm27qfJz9zc3OqC1EPAZAnmUyqLuGqbDarugRdIAAycAfo6OhQXchV69evJ/36khIIgDzHjx9XXcJVe/bsUV2CLgRWAxLwRqdWv2ou6b777vv6669V16ISAuC7jo6O0dFRIYRWux0ty3IcR7eq5MMUyHe//vqr6hKW8NJLL5FmTUkJdADfxePxubk5Dde1PAvasWPH8PCw6lqUQQB8x5ONRCKRz+dV13INLiwajZp8QABTIN/xKubBBx9UXUix2tpaMv54MALgr56eHl747LPP1Fay2IULF3hh+/btaitRCFMgf911110//PAD6bq5ybMgy7J0u2OFNOgA/hofHyf9ToR2bdy4kcw+LRQdwF889Gtqai5evKi6lqVxhZWVlf/995/qWhRAB/DRQw89xAvajn4iqqiooIWHOBkIHcBHQTnaavKWADqAj3jlsnv3btWFLKO5uZlM3RJAB/BLJpOZnp7Wf/XPeEtAw6N1fkMA/MLzikgkEogjTRUVFTMzM6Tr7lr/YArki7GxMR5Jvb29qmtZEXfFH41G1VYiGTqAL6qqqnK5XFDmPyybzfKx4QMHDhw7dkx1OZIgAL7g+U8ymczlcqprWYVIJGLbdrBye4MwBSq/np4eXq0Ea/QT0ffff09EjuNs27ZNdS2SoAOUX1B2/y8p0MWvATpAmcXjcV6n7N+/X3Uta8Hjnvdfqa5FBgSgnN58802+uCQej3/wwQeqy1mj6upqIrJtm8+SCDdMgcopNPOHWCzGhy/6+vpefvll1eX4CAEoG772l4gOHjz43nvvqS7nRhlywSSmQOXR2dnJAyUWi4Vg9BPRunXriGh+fn5oaEh1LT5CByiDw4cPHz16lIhCMPnxCs2MrgR0gBu1adMmHv0UuhMqv/jiCyLiDHR3d6suxxfoADeksbHxjz/+oNCt+13uhk1Yf0B0gLVbv359uEc/Ec3OzvKC4zjxeFxtMX5AANaorq6O73kY4tHP3Lu6z83N7du3T20xZYcp0Fokk0m+iDb0o5/x1jCF8edFB1g1y7KMGv20cM0kETmOc/LkSbXFlBc6wCqMjIzs3LmTl4NyqVcZcR+Ix+N87Vg4oAOs1N69e93R39zcbNroJ6JMJkOezeJwQABWZOPGje4Djk6cODExMaG2HiXcuxsdOHBAbSVlZNYFoGvT0tLy888/k9xJ/+joaHt7e9GL6XR6enpaTgEl8N1OwwHbAMvYtGnTmTNnaKnR/+ijjx4/ftynuRAffy16hVZ2m1EhhBBi69at3d3dDQ0NRNTV1VWWqngzIBaLhWci5MD1ueNGCMGvfPTRRwG9UkQIEYvF2tvbb+QXkkqlyPPbCAF0gFLc/d/XIxaUeKd3nW3btmVZ/F93GJFnSBER35+QY2ZZlvd7+UvRaNRxnMWf6PYo969buuympqbffvtt2V+C17vvvvvss88S0RtvvPHKK6+s6nv1hACUsuSwFkJUVFT4dC/laDRaKBSut7Gx5setdnd39/f3z8/PL/njdHZ2fvnllyv8p/h3ks1m//rrr9WWoSHsBSplYGCAF9LptLtatW3bvzuJ87hvaWkp7z/b19c3Oztr27bjOCdOnEin025fchxnaGiIU73yf3Bqaqq8FaqCAJTy3HPP8YJ7BMBvvIZ+7bXX/PuIrq6uqakpDsPWrVvdJMzMzLjTuc8///x6315ZWUkherIYArAiku8b/uSTT8r5oLGxMU5CIpFwX3Qc5+GHH7Ys6+mnn178LZ9++imF6BaiCMCKhGaFdz35fN5xnG+++ebmm292tzSOHTt2xx13FL3zgQce4IVbbrlFdpU+QABWJJlMqi5Bhnvvvff8+fO2be/evZtj8Msvv/D+Iu/b+Evh2AjGXqBlSN7pUXo/z5r3Aq0N3yqUl6uqqv79919e3rFjx3fffSezEv+gAyyD98e7j9SV49Zbb1384jPPPCOzBiIqFAp8wIGILl++bFmWZVk33XTTyMgIv2HDhg2SSyo7dIBltLe3u7eMlfBx7jW4iz9O4T0aEonEkuc+hOCJMugAyxgdHeUF70FZ/7jjrOjj3ENyTz31lIQyiszMzDiOU1tbW1TV448/Lr+Y8kIHWN7dd9996tQpknURTGdn51dffUVEQoiamprZ2dlcLsd/pkwmMzk56XcBpbkPU5qcnOQrBAINAVgR916ZqVTq8uXLfn9cU1PTuXPnil5saWnh584rlEql+Ch4aG6ZiCnQiszNzfEemFwuZ1mWuxXok7Nnz27evJnnG3xodnh4WO3ov//++y3L4tFvWVY4Rj+hA6yKd7fgrl27vv32W7X1SMOn6PFyyC6GRgdYhUKh0NbWxsunTp0K/QMVh4aG+LRtHv1CiImJiTCNfkIHWBvvadKWZUk+U0gO71qfiNLpdGjOAPVCB1gL27bdM2H4sYqWZT322GNqqyqXhoYG71q/tbXVcZxQjn5CB7hB1dXV7gkCRCSEqKysDNzDIdk777zzwgsv8Mmh/ErgHvO6BghAGWQymUuXLnl/k0KIRCLBN5DTXz6fT6VS3gPMIdvSLQFToDJwry9xD5Q6jpPP54UQ0Wh0y5YtbW1t7e3tr7/+uto6vXp7ezs6OmKxGHctHv1CiGQyOTw8bMjoJ3QAP7z44osDAwPX21Mej8cfeeSRSCTS2Nj41ltv+V3MyMjI+++/Pz09zXvxo9Hoxx9/7J22uXQ4zCwfAuAj9ybSJazkPj9rtuwflzffe3t7w/0oyBIQAKnckz3VMmeKvywEQKWenp7+/n5vl3A3JK7etkkIuvY6GHfZe1sHvq0Q/5cW7jJERPX19efPn5f9UwUKAgBGw14gMBoCAEZDAMBoCAAYDQEAoyEAYDQEAIyGAIDREAAwGgIARkMAwGgIABgNAQCjIQBgNAQAjIYAgNH+BwDwycoaTZzFAAAAAElFTkSuQmCC\n",
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGB size=256x256 at 0x12651C50>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from PIL import Image\n",
"Image.open(test_sketch_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "wOw35XSMXk6e"
},
"outputs": [],
"source": [
"for epoch in range(1000000):\n",
" photo_dictionary, sketch_dictionary = get_dict()\n",
" p_batch, s_batch, lab_batch = get_batch(photo_dictionary, sketch_dictionary)\n",
" lab_batch = lab_batch.reshape(-1, 1)\n",
" [_, loss_] = sess.run([optim,loss], {left: p_batch, right: s_batch, label:lab_batch})\n",
" \n",
" \n",
" if epoch%100==0:\n",
" \n",
" sketch_repr = sess.run([left_output], {left: test_sketch})\n",
" sketch_repr = np.squeeze(np.array(sketch_repr), 1)\n",
" print(sketch_repr.shape)\n",
" sketch_representations = np.tile(sketch_repr, 2496).reshape(2496, 64)\n",
" print(sketch_representations.shape)\n",
" \n",
" batch_size = 8\n",
" n_batches = len(test_images) // batch_size\n",
" image_representations = []\n",
"\n",
" for i in range(n_batches):\n",
" img_repr = sess.run([left_output], {left: test_images[i*batch_size : (i+1)*batch_size]})\n",
" img_repr = np.squeeze(np.array(img_repr), 0)\n",
" image_representations.append(img_repr)\n",
" image_representations = np.vstack(image_representations)\n",
"\n",
" diff = np.sqrt(np.mean((sketch_representations - image_representations)**2, -1))\n",
" top_k = np.argsort(diff)[:5]\n",
"\n",
" print ('##' + str(epoch) + ' : loss == ' + str(loss_))\n",
"\n",
" plt.figure(figsize=(20, 20))\n",
" for i in range(5): \n",
" img = mpimg.imread(test_image_paths[top_k[i]])\n",
" plt.subplot(1, 5, i+1)\n",
" plt.imshow(img)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.show()"
]
}
],
"metadata": {
"colab": {
"name": "SketchRetrieval.ipynb",
"provenance": [],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

0 comments on commit 4eeecfa

Please sign in to comment.