Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 149 additions & 72 deletions DeepImageSearch/DeepImageSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from PIL import ImageOps
import math
import faiss
from typing import Callable, Optional


class Load_Data:
"""A class for loading data from single/multiple folders or a CSV file"""
Expand All @@ -21,7 +23,7 @@ def __init__(self):
Initializes an instance of LoadData class
"""
pass

def from_folder(self, folder_list: list):
"""
Adds images from the specified folders to the image_list.
Expand All @@ -36,7 +38,7 @@ def from_folder(self, folder_list: list):
for folder in self.folder_list:
for root, dirs, files in os.walk(folder):
for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
image_path.append(os.path.join(root, file))
return image_path

Expand All @@ -55,60 +57,101 @@ def from_csv(self, csv_file_path: str, images_column_name: str):
self.images_column_name = images_column_name
return pd.read_csv(self.csv_file_path)[self.images_column_name].to_list()


class Search_Setup:
""" A class for setting up and running image similarity search."""
def __init__(self, image_list: list, model_name='vgg19', pretrained=True, image_count: int = None):
"""A class for setting up and running image similarity search."""

def __init__(
self,
image_list: list,
model_name: str = "vgg19",
pretrained: bool = True,
image_count: Optional[int] = None,
custom_feature_extractor: Optional[Callable] = None,
custom_feature_extractor_name: Optional[str] = None,
):
"""
Parameters:
-----------
image_list : list
A list of images to be indexed and searched.
model_name : str, optional (default='vgg19')
The name of the pre-trained model to use for feature extraction.
pretrained : bool, optional (default=True)
Whether to use the pre-trained weights for the chosen model.
image_count : int, optional (default=None)
The number of images to be indexed and searched. If None, all images in the image_list will be used.
A list of images to be indexed and searched.
model_name : str, optional
The name of the pre-trained model to use for feature extraction (default='vgg19').
pretrained : bool, optional
Whether to use the pre-trained weights for the chosen model (default=True).
image_count : int, optional
The number of images to be indexed and searched. If None, all images in the image_list will be used (default=None).
model : torch.nn.Module, optional
Custom model for feature extraction (default=None).
custom_model_name : str, optional
Name of the custom model (default=None).
preprocess_fn : Callable, optional
Custom preprocess function (default=None).
"""
self.model_name = model_name
self.pretrained = pretrained
self.image_data = pd.DataFrame()
self.d = None
if image_count==None:
self.image_list = image_list
else:
self.image_list = image_list[:image_count]
self.image_list = (
image_list[:image_count] if image_count is not None else image_list
)

if f'metadata-files/{self.model_name}' not in os.listdir():
try:
os.makedirs(f'metadata-files/{self.model_name}')
except Exception as e:
pass
#print(f'\033[91m file already exists: metadata-files/{self.model_name}')
# Load relevant model
if custom_feature_extractor is None:
# Load the pre-trained model and remove the last layer
print("\033[91m Please wait, model is loading or downloading from server!")
base_model = timm.create_model(self.model_name, pretrained=self.pretrained)
self.model = torch.nn.Sequential(*list(base_model.children())[:-1])
self.model.eval()
print(f"\033[92m Model loaded successfully: {model_name}")
self.using_custom_feature_extractor = False

elif custom_feature_extractor is not None:
self.model = custom_feature_extractor
self.model_name = (
custom_feature_extractor_name or "custom_feature_extractor"
)
self.using_custom_feature_extractor = True

# Define preprocess function
self.default_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)

# Load the pre-trained model and remove the last layer
print("\033[91m Please Wait Model Is Loading or Downloading From Server!")
base_model = timm.create_model(self.model_name, pretrained=self.pretrained)
self.model = torch.nn.Sequential(*list(base_model.children())[:-1])
self.model.eval()
print(f"\033[92m Model Loaded Successfully: {model_name}")
# Create metadata directory
self.metadata_dir = os.path.join(os.getcwd(), "metadata_dir")
os.makedirs(self.metadata_dir, exist_ok=True)

def _default_preprocess_fn(self, img, transforms=None):
"""Default preprocess function to preprocess the image."""
transforms = transforms or self.default_transforms
x = transforms(img)
x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False)
return x

def _extract(self, img):
"""Extract features from the image."""
# Resize and convert the image
img = img.resize((224, 224))
img = img.convert('RGB')

# Preprocess the image
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229,0.224, 0.225]),
])
x = preprocess(img)
x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False)
img = img.convert("RGB")

# Extract features from the image using the custom feature extractor
if self.using_custom_feature_extractor:
feature = self.model(img)

# Extract features from the image using the pre-trained model
elif not self.using_custom_feature_extractor:
x = self._default_preprocess_fn(img, transforms=self.default_transforms)
feature = self.model(x)
feature = feature.data.numpy()

# Extract features
feature = self.model(x)
feature = feature.data.numpy().flatten()
# Normalize the feature vector
feature = feature.flatten()
return feature / np.linalg.norm(feature)

def _get_feature(self, image_data: list):
Expand All @@ -120,49 +163,65 @@ def _get_feature(self, image_data: list):
feature = self._extract(img=Image.open(img_path))
features.append(feature)
except:
# If there is an error, append None to the feature list
features.append(None)
continue
# If there is an error, append None to the feature list
features.append(None)
continue
return features

def _start_feature_extraction(self):
image_data = pd.DataFrame()
image_data['images_paths'] = self.image_list
image_data["images_paths"] = self.image_list
f_data = self._get_feature(self.image_list)
image_data['features'] = f_data
image_data["features"] = f_data
image_data = image_data.dropna().reset_index(drop=True)
image_data.to_pickle(config.image_data_with_features_pkl(self.model_name))
print(f"\033[94m Image Meta Information Saved: [metadata-files/{self.model_name}/image_data_features.pkl]")

image_data.to_pickle(
config.image_data_with_features_pkl(self.metadata_dir, self.model_name)
)

print(
"\033[94m Image Meta Information Saved: [os.path.join(self.metadata_dir, self.model_name, 'image_data_features.pkl')]"
)
return image_data

def _start_indexing(self, image_data):
self.image_data = image_data
d = len(image_data['features'][0]) # Length of item vector that will be indexed
d = len(image_data["features"][0]) # Length of item vector that will be indexed
self.d = d
index = faiss.IndexFlatL2(d)
features_matrix = np.vstack(image_data['features'].values).astype(np.float32)
features_matrix = np.vstack(image_data["features"].values).astype(np.float32)
index.add(features_matrix) # Add the features matrix to the index
faiss.write_index(index, config.image_features_vectors_idx(self.model_name))
print("\033[94m Saved The Indexed File:" + f"[metadata-files/{self.model_name}/image_features_vectors.idx]")
faiss.write_index(
index, config.image_features_vectors_idx(self.metadata_dir, self.model_name)
)

print(
"\033[94m Saved The Indexed File:"
+ f"[os.path.join(self.metadata_dir, self.model_name, 'image_features_vectors.idx')]"
)

def run_index(self):
"""
Indexes the images in the image_list and creates an index file for fast similarity search.
"""
if len(os.listdir(f'metadata-files/{self.model_name}')) == 0:
if len(os.listdir(self.metadata_dir)) == 0:
data = self._start_feature_extraction()
self._start_indexing(data)
else:
print("\033[91m Metadata and Features are already present, Do you want Extract Again? Enter yes or no")
print(
"\033[91m Metadata and Features are already present, Do you want Extract Again? Enter yes or no"
)
flag = str(input())
if flag.lower() == 'yes':
if flag.lower() == "yes":
data = self._start_feature_extraction()
self._start_indexing(data)
else:
print("\033[93m Meta data already Present, Please Apply Search!")
print(os.listdir(f'metadata-files/{self.model_name}'))
self.image_data = pd.read_pickle(config.image_data_with_features_pkl(self.model_name))
self.f = len(self.image_data['features'][0])
print(os.listdir(self.metadata_dir))
self.image_data = pd.read_pickle(
config.image_data_with_features_pkl(self.metadata_dir, self.model_name)
)
self.f = len(self.image_data["features"][0])

def add_images_to_index(self, new_image_paths: list):
"""
Expand All @@ -174,8 +233,12 @@ def add_images_to_index(self, new_image_paths: list):
A list of paths to the new images to be added to the index.
"""
# Load existing metadata and index
self.image_data = pd.read_pickle(config.image_data_with_features_pkl(self.model_name))
index = faiss.read_index(config.image_features_vectors_idx(self.model_name))
self.image_data = pd.read_pickle(
config.image_data_with_features_pkl(self.metadata_dir, self.model_name)
)
index = faiss.read_index(
config.image_features_vectors_idx(self.metadta_dir, self.model_name)
)

for new_image_path in tqdm(new_image_paths):
# Extract features from the new image
Expand All @@ -187,25 +250,35 @@ def add_images_to_index(self, new_image_paths: list):
continue

# Add the new image to the metadata
new_metadata = pd.DataFrame({"images_paths": [new_image_path], "features": [feature]})
#self.image_data = self.image_data.append(new_metadata, ignore_index=True)
self.image_data =pd.concat([self.image_data, new_metadata], axis=0, ignore_index=True)
new_metadata = pd.DataFrame(
{"images_paths": [new_image_path], "features": [feature]}
)
# self.image_data = self.image_data.append(new_metadata, ignore_index=True)
self.image_data = pd.concat(
[self.image_data, new_metadata], axis=0, ignore_index=True
)

# Add the new image to the index
index.add(np.array([feature], dtype=np.float32))

# Save the updated metadata and index
self.image_data.to_pickle(config.image_data_with_features_pkl(self.model_name))
faiss.write_index(index, config.image_features_vectors_idx(self.model_name))
self.image_data.to_pickle(
config.image_data_with_features_pkl(self.metadata_dir, self.model_name)
)
faiss.write_index(
index, config.image_features_vectors_idx(self.metadta_dir, self.model_name)
)

print(f"\033[92m New images added to the index: {len(new_image_paths)}")

def _search_by_vector(self, v, n: int):
self.v = v
self.n = n
index = faiss.read_index(config.image_features_vectors_idx(self.model_name))
index = faiss.read_index(
config.image_features_vectors_idx(self.metadata_dir, self.model_name)
)
D, I = index.search(np.array([self.v], dtype=np.float32), self.n)
return dict(zip(I[0], self.image_data.iloc[I[0]]['images_paths'].to_list()))
return dict(zip(I[0], self.image_data.iloc[I[0]]["images_paths"].to_list()))

def _get_query_vector(self, image_path: str):
self.image_path = image_path
Expand All @@ -227,10 +300,10 @@ def plot_similar_images(self, image_path: str, number_of_images: int = 6):
input_img = Image.open(image_path)
input_img_resized = ImageOps.fit(input_img, (224, 224), Image.LANCZOS)
plt.figure(figsize=(5, 5))
plt.axis('off')
plt.title('Input Image', fontsize=18)
plt.axis("off")
plt.title("Input Image", fontsize=18)
plt.imshow(input_img_resized)
plt.show()
# plt.show()

query_vector = self._get_query_vector(image_path)
img_list = list(self._search_by_vector(query_vector, number_of_images).values())
Expand All @@ -240,14 +313,15 @@ def plot_similar_images(self, image_path: str, number_of_images: int = 6):
fig = plt.figure(figsize=(20, 15))
for a in range(number_of_images):
axes.append(fig.add_subplot(grid_size, grid_size, a + 1))
plt.axis('off')
plt.axis("off")
img = Image.open(img_list[a])
img_resized = ImageOps.fit(img, (224, 224), Image.LANCZOS)
plt.imshow(img_resized)
fig.tight_layout()
fig.subplots_adjust(top=0.93)
fig.suptitle('Similar Result Found', fontsize=22)
plt.show(fig)
fig.suptitle("Similar Result Found", fontsize=22)
# plt.show(fig)
plt.show()

def get_similar_images(self, image_path: str, number_of_images: int = 10):
"""
Expand All @@ -265,6 +339,7 @@ def get_similar_images(self, image_path: str, number_of_images: int = 10):
query_vector = self._get_query_vector(self.image_path)
img_dict = self._search_by_vector(query_vector, self.number_of_images)
return img_dict

def get_image_metadata_file(self):
"""
Returns the metadata file containing information about the indexed images.
Expand All @@ -274,5 +349,7 @@ def get_image_metadata_file(self):
DataFrame
The Panda DataFrame of the metadata file.
"""
self.image_data = pd.read_pickle(config.image_data_with_features_pkl(self.model_name))
self.image_data = pd.read_pickle(
config.image_data_with_features_pkl(self.metadata_dir, self.model_name)
)
return self.image_data
2 changes: 1 addition & 1 deletion DeepImageSearch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from DeepImageSearch.DeepImageSearch import Load_Data,Search_Setup
from DeepImageSearch.DeepImageSearch import Load_Data, Search_Setup
12 changes: 9 additions & 3 deletions DeepImageSearch/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import os
import os


def image_data_with_features_pkl(model_name):
image_data_with_features_pkl = os.path.join('metadata-files/',f'{model_name}/','image_data_features.pkl')
image_data_with_features_pkl = os.path.join(
"metadata-files/", f"{model_name}/", "image_data_features.pkl"
)
return image_data_with_features_pkl


def image_features_vectors_idx(model_name):
image_features_vectors_idx = os.path.join('metadata-files/',f'{model_name}/','image_features_vectors.idx')
image_features_vectors_idx = os.path.join(
"metadata-files/", f"{model_name}/", "image_features_vectors.idx"
)
return image_features_vectors_idx
Loading