diff --git a/DeepImageSearch/DeepImageSearch.py b/DeepImageSearch/DeepImageSearch.py
index 8377dc0..ac58194 100644
--- a/DeepImageSearch/DeepImageSearch.py
+++ b/DeepImageSearch/DeepImageSearch.py
@@ -12,6 +12,8 @@ from PIL import ImageOps
 import math
 import faiss
+from typing import Callable, Optional
+
 
 class Load_Data:
     """A class for loading data from single/multiple folders or a CSV file"""
@@ -21,7 +23,7 @@ def __init__(self):
        Initializes an instance of LoadData class
        """
        pass
-        
+
    def from_folder(self, folder_list: list):
        """
        Adds images from the specified folders to the image_list.
@@ -36,7 +38,7 @@ def from_folder(self, folder_list: list):
        for folder in self.folder_list:
            for root, dirs, files in os.walk(folder):
                for file in files:
-                    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
+                    if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                        image_path.append(os.path.join(root, file))
        return image_path
@@ -55,60 +57,101 @@ def from_csv(self, csv_file_path: str, images_column_name: str):
        self.images_column_name = images_column_name
        return pd.read_csv(self.csv_file_path)[self.images_column_name].to_list()
 
+
 class Search_Setup:
-    """ A class for setting up and running image similarity search."""
-    def __init__(self, image_list: list, model_name='vgg19', pretrained=True, image_count: int = None):
+    """A class for setting up and running image similarity search."""
+
+    def __init__(
+        self,
+        image_list: list,
+        model_name: str = "vgg19",
+        pretrained: bool = True,
+        image_count: Optional[int] = None,
+        custom_feature_extractor: Optional[Callable] = None,
+        custom_feature_extractor_name: Optional[str] = None,
+    ):
        """
        Parameters:
        -----------
        image_list : list
-            A list of images to be indexed and searched.
-        model_name : str, optional (default='vgg19')
-            The name of the pre-trained model to use for feature extraction.
-        pretrained : bool, optional (default=True)
-            Whether to use the pre-trained weights for the chosen model.
-        image_count : int, optional (default=None)
-            The number of images to be indexed and searched. If None, all images in the image_list will be used.
+            A list of images to be indexed and searched.
+        model_name : str, optional
+            The name of the pre-trained model to use for feature extraction (default='vgg19').
+        pretrained : bool, optional
+            Whether to use the pre-trained weights for the chosen model (default=True).
+        image_count : int, optional
+            The number of images to be indexed and searched. If None, all images in the image_list will be used (default=None).
+        custom_feature_extractor : Callable, optional
+            A callable that takes a PIL image and returns a 1-D feature vector; if provided, it is used instead of the pre-trained model (default=None).
+        custom_feature_extractor_name : str, optional
+            Name of the custom feature extractor, used to name the metadata files (default=None).
""" self.model_name = model_name self.pretrained = pretrained self.image_data = pd.DataFrame() self.d = None - if image_count==None: - self.image_list = image_list - else: - self.image_list = image_list[:image_count] + self.image_list = ( + image_list[:image_count] if image_count is not None else image_list + ) - if f'metadata-files/{self.model_name}' not in os.listdir(): - try: - os.makedirs(f'metadata-files/{self.model_name}') - except Exception as e: - pass - #print(f'\033[91m file already exists: metadata-files/{self.model_name}') + # Load relevant model + if custom_feature_extractor is None: + # Load the pre-trained model and remove the last layer + print("\033[91m Please wait, model is loading or downloading from server!") + base_model = timm.create_model(self.model_name, pretrained=self.pretrained) + self.model = torch.nn.Sequential(*list(base_model.children())[:-1]) + self.model.eval() + print(f"\033[92m Model loaded successfully: {model_name}") + self.using_custom_feature_extractor = False + + elif custom_feature_extractor is not None: + self.model = custom_feature_extractor + self.model_name = ( + custom_feature_extractor_name or "custom_feature_extractor" + ) + self.using_custom_feature_extractor = True + + # Define preprocess function + self.default_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) - # Load the pre-trained model and remove the last layer - print("\033[91m Please Wait Model Is Loading or Downloading From Server!") - base_model = timm.create_model(self.model_name, pretrained=self.pretrained) - self.model = torch.nn.Sequential(*list(base_model.children())[:-1]) - self.model.eval() - print(f"\033[92m Model Loaded Successfully: {model_name}") + # Create metadata directory + self.metadata_dir = os.path.join(os.getcwd(), "metadata_dir") + os.makedirs(self.metadata_dir, exist_ok=True) + + def _default_preprocess_fn(self, img, transforms=None): + """Default preprocess function to preprocess the image.""" + transforms = transforms or self.default_transforms + x = transforms(img) + x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False) + return x def _extract(self, img): + """Extract features from the image.""" # Resize and convert the image img = img.resize((224, 224)) - img = img.convert('RGB') - - # Preprocess the image - preprocess = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229,0.224, 0.225]), - ]) - x = preprocess(img) - x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False) + img = img.convert("RGB") + + # Extract features from the image using the custom feature extractor + if self.using_custom_feature_extractor: + feature = self.model(img) + + # Extract features from the image using the pre-trained model + elif not self.using_custom_feature_extractor: + x = self._default_preprocess_fn(img, transforms=self.default_transforms) + feature = self.model(x) + feature = feature.data.numpy() - # Extract features - feature = self.model(x) - feature = feature.data.numpy().flatten() + # Normalize the feature vector + feature = feature.flatten() return feature / np.linalg.norm(feature) def _get_feature(self, image_data: list): @@ -120,49 +163,65 @@ def _get_feature(self, image_data: list): feature = self._extract(img=Image.open(img_path)) features.append(feature) except: - # If there is an error, append None to the feature list - features.append(None) - continue + # If there is an 
-                features.append(None)
-                continue
+                # If there is an error, append None to the feature list
+                features.append(None)
+                continue
        return features
 
    def _start_feature_extraction(self):
        image_data = pd.DataFrame()
-        image_data['images_paths'] = self.image_list
+        image_data["images_paths"] = self.image_list
        f_data = self._get_feature(self.image_list)
-        image_data['features'] = f_data
+        image_data["features"] = f_data
        image_data = image_data.dropna().reset_index(drop=True)
-        image_data.to_pickle(config.image_data_with_features_pkl(self.model_name))
-        print(f"\033[94m Image Meta Information Saved: [metadata-files/{self.model_name}/image_data_features.pkl]")
+
+        pkl_path = config.image_data_with_features_pkl(self.metadata_dir, self.model_name)
+        image_data.to_pickle(pkl_path)
+        print(f"\033[94m Image Meta Information Saved: [{pkl_path}]")
        return image_data
 
    def _start_indexing(self, image_data):
        self.image_data = image_data
-        d = len(image_data['features'][0]) # Length of item vector that will be indexed
+        d = len(image_data["features"][0])  # Length of item vector that will be indexed
        self.d = d
        index = faiss.IndexFlatL2(d)
-        features_matrix = np.vstack(image_data['features'].values).astype(np.float32)
+        features_matrix = np.vstack(image_data["features"].values).astype(np.float32)
        index.add(features_matrix) # Add the features matrix to the index
-        faiss.write_index(index, config.image_features_vectors_idx(self.model_name))
-        print("\033[94m Saved The Indexed File:" + f"[metadata-files/{self.model_name}/image_features_vectors.idx]")
+
+        idx_path = config.image_features_vectors_idx(self.metadata_dir, self.model_name)
+        faiss.write_index(index, idx_path)
+        print(f"\033[94m Saved The Indexed File: [{idx_path}]")
 
    def run_index(self):
        """
        Indexes the images in the image_list and creates an index file for fast similarity search.
        """
-        if len(os.listdir(f'metadata-files/{self.model_name}')) == 0:
+        model_metadata_dir = os.path.join(self.metadata_dir, self.model_name)
+        if len(os.listdir(model_metadata_dir)) == 0:
            data = self._start_feature_extraction()
            self._start_indexing(data)
        else:
-            print("\033[91m Metadata and Features are already present, Do you want Extract Again? Enter yes or no")
+            print(
+                "\033[91m Metadata and features are already present. Do you want to extract them again? Enter yes or no:"
+            )
            flag = str(input())
-            if flag.lower() == 'yes':
+            if flag.lower() == "yes":
                data = self._start_feature_extraction()
                self._start_indexing(data)
            else:
                print("\033[93m Meta data already Present, Please Apply Search!")
-                print(os.listdir(f'metadata-files/{self.model_name}'))
-        self.image_data = pd.read_pickle(config.image_data_with_features_pkl(self.model_name))
-        self.f = len(self.image_data['features'][0])
+                print(os.listdir(model_metadata_dir))
+        self.image_data = pd.read_pickle(
+            config.image_data_with_features_pkl(self.metadata_dir, self.model_name)
+        )
+        self.f = len(self.image_data["features"][0])
 
    def add_images_to_index(self, new_image_paths: list):
        """
        Adds new images to the existing index.
 
        Parameters:
        -----------
        new_image_paths : list
@@ -174,8 +233,12 @@ def add_images_to_index(self, new_image_paths: list):
            A list of paths to the new images to be added to the index.
""" # Load existing metadata and index - self.image_data = pd.read_pickle(config.image_data_with_features_pkl(self.model_name)) - index = faiss.read_index(config.image_features_vectors_idx(self.model_name)) + self.image_data = pd.read_pickle( + config.image_data_with_features_pkl(self.metadata_dir, self.model_name) + ) + index = faiss.read_index( + config.image_features_vectors_idx(self.metadta_dir, self.model_name) + ) for new_image_path in tqdm(new_image_paths): # Extract features from the new image @@ -187,25 +250,35 @@ def add_images_to_index(self, new_image_paths: list): continue # Add the new image to the metadata - new_metadata = pd.DataFrame({"images_paths": [new_image_path], "features": [feature]}) - #self.image_data = self.image_data.append(new_metadata, ignore_index=True) - self.image_data =pd.concat([self.image_data, new_metadata], axis=0, ignore_index=True) + new_metadata = pd.DataFrame( + {"images_paths": [new_image_path], "features": [feature]} + ) + # self.image_data = self.image_data.append(new_metadata, ignore_index=True) + self.image_data = pd.concat( + [self.image_data, new_metadata], axis=0, ignore_index=True + ) # Add the new image to the index index.add(np.array([feature], dtype=np.float32)) # Save the updated metadata and index - self.image_data.to_pickle(config.image_data_with_features_pkl(self.model_name)) - faiss.write_index(index, config.image_features_vectors_idx(self.model_name)) + self.image_data.to_pickle( + config.image_data_with_features_pkl(self.metadata_dir, self.model_name) + ) + faiss.write_index( + index, config.image_features_vectors_idx(self.metadta_dir, self.model_name) + ) print(f"\033[92m New images added to the index: {len(new_image_paths)}") def _search_by_vector(self, v, n: int): self.v = v self.n = n - index = faiss.read_index(config.image_features_vectors_idx(self.model_name)) + index = faiss.read_index( + config.image_features_vectors_idx(self.metadata_dir, self.model_name) + ) D, I = index.search(np.array([self.v], dtype=np.float32), self.n) - return dict(zip(I[0], self.image_data.iloc[I[0]]['images_paths'].to_list())) + return dict(zip(I[0], self.image_data.iloc[I[0]]["images_paths"].to_list())) def _get_query_vector(self, image_path: str): self.image_path = image_path @@ -227,10 +300,10 @@ def plot_similar_images(self, image_path: str, number_of_images: int = 6): input_img = Image.open(image_path) input_img_resized = ImageOps.fit(input_img, (224, 224), Image.LANCZOS) plt.figure(figsize=(5, 5)) - plt.axis('off') - plt.title('Input Image', fontsize=18) + plt.axis("off") + plt.title("Input Image", fontsize=18) plt.imshow(input_img_resized) - plt.show() + # plt.show() query_vector = self._get_query_vector(image_path) img_list = list(self._search_by_vector(query_vector, number_of_images).values()) @@ -240,14 +313,15 @@ def plot_similar_images(self, image_path: str, number_of_images: int = 6): fig = plt.figure(figsize=(20, 15)) for a in range(number_of_images): axes.append(fig.add_subplot(grid_size, grid_size, a + 1)) - plt.axis('off') + plt.axis("off") img = Image.open(img_list[a]) img_resized = ImageOps.fit(img, (224, 224), Image.LANCZOS) plt.imshow(img_resized) fig.tight_layout() fig.subplots_adjust(top=0.93) - fig.suptitle('Similar Result Found', fontsize=22) - plt.show(fig) + fig.suptitle("Similar Result Found", fontsize=22) + # plt.show(fig) + plt.show() def get_similar_images(self, image_path: str, number_of_images: int = 10): """ @@ -265,6 +339,7 @@ def get_similar_images(self, image_path: str, number_of_images: int = 10): query_vector 
        query_vector = self._get_query_vector(self.image_path)
        img_dict = self._search_by_vector(query_vector, self.number_of_images)
        return img_dict
+
    def get_image_metadata_file(self):
        """
        Returns the metadata file containing information about the indexed images.
@@ -274,5 +349,7 @@ def get_image_metadata_file(self):
        Returns:
        -------
        DataFrame
            The Panda DataFrame of the metadata file.
        """
-        self.image_data = pd.read_pickle(config.image_data_with_features_pkl(self.model_name))
+        self.image_data = pd.read_pickle(
+            config.image_data_with_features_pkl(self.metadata_dir, self.model_name)
+        )
        return self.image_data
diff --git a/DeepImageSearch/__init__.py b/DeepImageSearch/__init__.py
index bf34717..4b39dd6 100644
--- a/DeepImageSearch/__init__.py
+++ b/DeepImageSearch/__init__.py
@@ -1 +1 @@
-from DeepImageSearch.DeepImageSearch import Load_Data,Search_Setup
+from DeepImageSearch.DeepImageSearch import Load_Data, Search_Setup
diff --git a/DeepImageSearch/config.py b/DeepImageSearch/config.py
index 187a067..94f8d53 100644
--- a/DeepImageSearch/config.py
+++ b/DeepImageSearch/config.py
@@ -1,9 +1,9 @@
-import os 
+import os
+
 
-def image_data_with_features_pkl(model_name):
-    image_data_with_features_pkl = os.path.join('metadata-files/',f'{model_name}/','image_data_features.pkl')
-    return image_data_with_features_pkl
+def image_data_with_features_pkl(metadata_dir, model_name):
+    return os.path.join(metadata_dir, model_name, "image_data_features.pkl")
+
 
-def image_features_vectors_idx(model_name):
-    image_features_vectors_idx = os.path.join('metadata-files/',f'{model_name}/','image_features_vectors.idx')
-    return image_features_vectors_idx
+def image_features_vectors_idx(metadata_dir, model_name):
+    return os.path.join(metadata_dir, model_name, "image_features_vectors.idx")
diff --git a/Demo/DeepImageSearchDemo.py b/Demo/DeepImageSearchDemo.py
index a68e9df..a9cf3fa 100644
--- a/Demo/DeepImageSearchDemo.py
+++ b/Demo/DeepImageSearchDemo.py
@@ -4,10 +4,12 @@ from DeepImageSearch import Load_Data, Search_Setup
 
 # Load images from a folder
-image_list = Load_Data().from_folder(['folder_path'])
+image_list = Load_Data().from_folder(["folder_path"])
 
 # Set up the search engine
-st = Search_Setup(image_list=image_list, model_name='vgg19', pretrained=True, image_count=100)
+st = Search_Setup(
+    image_list=image_list, model_name="vgg19", pretrained=True, image_count=100
+)
 
 # Index the images
 st.run_index()
@@ -16,13 +18,13 @@ metadata = st.get_image_metadata_file()
 
 # Add New images to the index
-st.add_images_to_index(['image_path_1', 'image_path_2'])
+st.add_images_to_index(["image_path_1", "image_path_2"])
 
 # Get similar images
-st.get_similar_images(image_path='image_path', number_of_images=10)
+st.get_similar_images(image_path="image_path", number_of_images=10)
 
 # Plot similar images
-st.plot_similar_images(image_path='image_path', number_of_images=9)
+st.plot_similar_images(image_path="image_path", number_of_images=9)
 
 # Update metadata
-metadata = st.get_image_metadata_file()
\ No newline at end of file
+metadata = st.get_image_metadata_file()
diff --git a/setup.py b/setup.py
index 00d9012..8f8763b 100644
--- a/setup.py
+++ b/setup.py
@@ -9,38 +9,38 @@ README = (HERE / "README.md").read_text()
 
 setup(
-    long_description_content_type="text/markdown",
-    name = 'DeepImageSearch',
-    packages = ['DeepImageSearch'],
-    version = '2.5',
-    license='MIT',
-    description = 'DeepImageSearch is a Python library for fast and accurate image search. It offers seamless integration with Python, GPU support, and advanced capabilities for identifying complex image patterns using the Vision Transformer models.',
-    long_description=README,
-    author = 'Nilesh Verma',
-    author_email = 'me@nileshverma.com',
-    url = 'https://github.com/TechyNilesh/DeepImageSearch',
-    download_url = 'https://github.com/TechyNilesh/DeepImageSearch/archive/refs/tags/v_25.tar.gz',
-    keywords = ['Deep Image Search Engine', 'AI Image search', 'Image Search Python'],
-    install_requires=[
-        'faiss_cpu',
-        'torch',
-        'torchvision',
-        'matplotlib',
-        'pandas',
-        'numpy',
-        'tqdm',
-        'Pillow',
-        'timm'
-    ],
-    classifiers=[
-        'Development Status :: 5 - Production/Stable',
-        'Intended Audience :: Developers',
-        'Topic :: Software Development :: Build Tools',
-        'License :: OSI Approved :: MIT License',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
-        'Programming Language :: Python :: 3.10'
-    ],
+    long_description_content_type="text/markdown",
+    name="DeepImageSearch",
+    packages=["DeepImageSearch"],
+    version="2.6",
+    license="MIT",
+    description="DeepImageSearch is a Python library for fast and accurate image search. It offers seamless integration with Python, GPU support, and advanced capabilities for identifying complex image patterns using Vision Transformer models. Custom feature extractors can also be used to extract features from images.",
+    long_description=README,
+    author="Nilesh Verma",
+    author_email="me@nileshverma.com",
+    url="https://github.com/TechyNilesh/DeepImageSearch",
+    download_url="https://github.com/TechyNilesh/DeepImageSearch/archive/refs/tags/v_26.tar.gz",
+    keywords=["Deep Image Search Engine", "AI Image search", "Image Search Python"],
+    install_requires=[
+        "faiss_cpu",
+        "torch",
+        "torchvision",
+        "matplotlib",
+        "pandas",
+        "numpy",
+        "tqdm",
+        "Pillow",
+        "timm",
+    ],
+    classifiers=[
+        "Development Status :: 5 - Production/Stable",
+        "Intended Audience :: Developers",
+        "Topic :: Software Development :: Build Tools",
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+    ],
 )
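Reviewer note: a minimal usage sketch for the `custom_feature_extractor` hook introduced in this patch. The `histogram_extractor` below is a hypothetical stand-in, not part of the library; any callable that accepts the PIL image handed to it by `_extract` (already resized to 224x224 and converted to RGB) and returns a fixed-length 1-D vector should work, since `_extract` flattens and L2-normalizes the result before FAISS indexing.

```python
import numpy as np
from PIL import Image

from DeepImageSearch import Load_Data, Search_Setup


def histogram_extractor(img: Image.Image) -> np.ndarray:
    """Hypothetical extractor: a 96-dim RGB color-histogram feature."""
    arr = np.asarray(img.convert("RGB"))
    # 32-bin histogram per channel, concatenated into one vector
    hist = [
        np.histogram(arr[..., channel], bins=32, range=(0, 255))[0]
        for channel in range(3)
    ]
    return np.concatenate(hist).astype(np.float32)


image_list = Load_Data().from_folder(["folder_path"])
st = Search_Setup(
    image_list=image_list,
    custom_feature_extractor=histogram_extractor,
    custom_feature_extractor_name="rgb_histogram",
)
st.run_index()
print(st.get_similar_images(image_path="image_path", number_of_images=5))
```

Once constructed this way, the instance behaves exactly like the pre-trained path: `run_index`, `get_similar_images`, and `add_images_to_index` all route feature extraction through the supplied callable, and the metadata files are written under the `rgb_histogram` name.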