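"""image-search: embed images (or text prompts) with CLIP or DINOv2, index the
vectors with faiss, and keep per-image filename/timestamp/exif-metadata records
in a pandas DataFrame aligned with the index."""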
from typing import Union
import os
import re
import logging
from tqdm.rich import tqdm
import torch
import PIL
import faiss
import numpy as np
import pandas as pd
import transformers


class ImageDB:
    # TODO index: quantize and train faiss index
    # TODO index: clip batch processing
    def __init__(self,
                 name:str='db',
                 fmt:str='json',
                 cache_dir:str=None,
                 dtype:torch.dtype=torch.float16,
                 device:torch.device=torch.device('cpu'),
                 model:str='openai/clip-vit-large-patch14', # 'facebook/dinov2-small'
                 debug:bool=False,
                 pbar:bool=True,
                 ):
        self.format = fmt
        self.name = name
        self.cache_dir = cache_dir
        self.processor: transformers.AutoImageProcessor = None
        self.model: transformers.AutoModel = None
        self.tokenizer: transformers.AutoTokenizer = None
        self.device: torch.device = device
        self.dtype: torch.dtype = dtype
        self.dimension = 768 if 'clip' in model else 384
        self.debug = debug
        self.pbar = pbar
        self.repo = model
        self.df = pd.DataFrame([], columns=['filename', 'timestamp', 'metadata']) # image/metadata database
        self.index = faiss.IndexFlatL2(self.dimension) # embed database
        logger = logging.getLogger(__name__)
        self.err = logger.error
        self.log = logger.info if self.debug else logger.debug
        # self.init()
        # self.load()

    def __str__(self):
        return f'db: name="{self.name}" format={self.format} device={self.device} dtype={self.dtype} dimension={self.dimension} model="{self.repo}" records={len(self.df)} index={self.index.ntotal}'

    def init(self): # initialize models
        if self.processor is None or self.model is None:
            if 'clip' in self.repo:
                self.processor = transformers.CLIPImageProcessor.from_pretrained(self.repo, cache_dir=self.cache_dir)
                self.tokenizer = transformers.CLIPTokenizer.from_pretrained(self.repo, cache_dir=self.cache_dir)
                self.model = transformers.CLIPModel.from_pretrained(self.repo, cache_dir=self.cache_dir).to(device=self.device, dtype=self.dtype)
            elif 'dino' in self.repo:
                self.processor = transformers.AutoImageProcessor.from_pretrained(self.repo, cache_dir=self.cache_dir)
                self.model = transformers.AutoModel.from_pretrained(self.repo, cache_dir=self.cache_dir).to(device=self.device, dtype=self.dtype)
            else:
                self.err(f'db: model="{self.repo}" unknown')
            self.log(f'db: load model="{self.repo}" cache="{self.cache_dir}" device={self.device} dtype={self.dtype}')

    def load(self): # load db from disk
        if self.format == 'json' and os.path.exists(f'{self.name}.json'):
            self.df = pd.read_json(f'{self.name}.json')
        elif self.format == 'csv' and os.path.exists(f'{self.name}.csv'):
            self.df = pd.read_csv(f'{self.name}.csv')
        elif self.format == 'pickle' and os.path.exists(f'{self.name}.pkl'):
            self.df = pd.read_pickle(f'{self.name}.pkl')
        if os.path.exists(f'{self.name}.index'):
            self.index = faiss.read_index(f'{self.name}.index')
        if self.index.ntotal != len(self.df):
            self.err(f'db: index={self.index.ntotal} data={len(self.df)} mismatch')
            self.index = faiss.IndexFlatL2(self.dimension)
            self.df = pd.DataFrame([], columns=['filename', 'timestamp', 'metadata'])
        self.log(f'db: load data={len(self.df)} name={self.name} format={self.format}')

    def save(self): # save db to disk
        if self.format == 'json':
            self.df.to_json(f'{self.name}.json')
        elif self.format == 'csv':
            self.df.to_csv(f'{self.name}.csv')
        elif self.format == 'pickle':
            self.df.to_pickle(f'{self.name}.pkl')
        faiss.write_index(self.index, f'{self.name}.index')
        self.log(f'db: save data={len(self.df)} name={self.name} format={self.format}')

    def normalize(self, embed) -> np.ndarray: # normalize embed before using it
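        # IndexFlatL2 stores raw L2 distances; after L2-normalization the distance is
        # monotonic in cosine similarity (d^2 = 2 - 2*cos), so text and image embeddings
        # become comparable in the same index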
        embed = embed.detach().float().cpu().numpy()
        faiss.normalize_L2(embed)
        return embed

    def embedding(self, query: Union[PIL.Image.Image, str]) -> np.ndarray: # calculate embed for prompt or image
        if self.processor is None or self.model is None:
            self.err('db: model not loaded')
            return None
        if isinstance(query, str) and os.path.exists(query):
            query = PIL.Image.open(query).convert('RGB')
        self.model = self.model.to(self.device)
        with torch.no_grad():
            if 'clip' in self.repo:
                if isinstance(query, str):
                    processed = self.tokenizer(text=query, padding=True, return_tensors="pt").to(device=self.device)
                    results = self.model.get_text_features(**processed)
                else:
                    processed = self.processor(images=query, return_tensors="pt").to(device=self.device, dtype=self.dtype)
                    results = self.model.get_image_features(**processed)
            elif 'dino' in self.repo:
                processed = self.processor(images=query, return_tensors="pt").to(device=self.device, dtype=self.dtype)
                results = self.model(**processed)
                results = results.last_hidden_state.mean(dim=1)
            else:
                self.err(f'db: model="{self.repo}" unknown')
                return None
        return self.normalize(results)

    def add(self, embed, filename=None, metadata=None): # add embed to db
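        # NOTE: DataFrame row order must mirror faiss insertion order; search() maps index hits back to rows via df.iloc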
        rec = pd.DataFrame([{'filename': filename, 'timestamp': pd.Timestamp.now(), 'metadata': metadata}])
        if len(self.df) > 0:
            self.df = pd.concat([self.df, rec], ignore_index=True)
        else:
            self.df = rec
        self.index.add(embed)

    def search(self, filename: str = None, metadata: str = None, embed: np.ndarray = None, k=10, d=1.0): # search by filename/metadata/prompt-embed/image-embed
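        # generator: yields one result dict per hit; embed hits are filtered by distance d (d<=0 disables the filter),
        # filename/metadata searches use case-insensitive substring match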
        def dct(record: tuple, mode: str, distance: float = None): # record is an (index, row) tuple from df.iterrows()
            if distance is not None:
                return {'type': mode, 'filename': record[1]['filename'], 'metadata': record[1]['metadata'], 'distance': round(distance, 2)}
            else:
                return {'type': mode, 'filename': record[1]['filename'], 'metadata': record[1]['metadata']}

        if self.index.ntotal == 0:
            return
        self.log(f'db: search k={k} d={d}')
        if embed is not None:
            distances, indexes = self.index.search(embed, k)
            records = self.df.iloc[[i for i in indexes[0] if i >= 0]] # faiss pads missing results with -1
            for record, distance in zip(records.iterrows(), distances[0]):
                if d <= 0 or distance <= d:
                    yield dct(record, distance=distance, mode='embed')
        if filename is not None:
            records = self.df[self.df['filename'].str.contains(filename, na=False, case=False)]
            for record in records.iterrows():
                yield dct(record, mode='filename')
        if metadata is not None:
            records = self.df[self.df['metadata'].str.contains(metadata, na=False, case=False)]
            for record in records.iterrows():
                yield dct(record, mode='metadata')

    def decode(self, s: bytes): # decode byte-encoded exif metadata
        remove_prefix = lambda text, prefix: text[len(prefix):] if text.startswith(prefix) else text # pylint: disable=unnecessary-lambda-assignment
        for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']: # try different encodings
            try:
                val = remove_prefix(s, b'UNICODE')
                val = remove_prefix(val, b'ASCII')
                val = remove_prefix(val, b'\x00')
                val = val.decode(encoding, errors="strict")
                val = re.sub(r'[\x00-\x1f]', '', val) # remove remaining control characters and line breaks
                val = re.sub(r'\s+', ' ', val).strip() # collapse repeated whitespace
                if len(val) == 0: # treat empty strings as missing
                    val = None
                return val
            except Exception:
                pass
        return None

    def metadata(self, image: PIL.Image.Image): # get exif metadata from image
        exif = image._getexif() if hasattr(image, '_getexif') else None # pylint: disable=protected-access # only jpeg-based formats carry exif
        if exif is None:
            return ''
        for k, v in exif.items():
            if k == 37510: # 0x9286 UserComment
                return self.decode(v)
        return ''

    def image(self, filename: str, image=None): # add file/image to db
        try:
            if image is None:
                image = PIL.Image.open(filename)
                image.load()
            embed = self.embedding(image.convert('RGB'))
            metadata = self.metadata(image)
            image.close()
            self.add(embed, filename=filename, metadata=metadata)
        except Exception as e:
            self.log(f'db: skip file="{filename}" error={e}') # unreadable or non-image files are skipped

    def folder(self, folder: str): # add all files from folder to db
        files = []
        for root, _subdir, _files in os.walk(folder):
            for f in _files:
                files.append(os.path.join(root, f))
        for f in tqdm(files) if self.pbar else files:
            self.image(filename=f)

    def offload(self): # offload model to cpu
        if self.model is not None:
            self.model = self.model.to('cpu')


if __name__ == '__main__':
    import time
    import argparse
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description='image-search')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--search', action='store_true', help='run search')
    group.add_argument('--index', action='store_true', help='run indexing')
    parser.add_argument('--db', default='db', help='database name')
    parser.add_argument('--model', default='openai/clip-vit-large-patch14', help='huggingface model')
    parser.add_argument('--cache', default='/mnt/models/huggingface', help='cache folder')
    parser.add_argument('input', nargs='*', default=[os.getcwd()], help='files/folders to index or search terms') # default must be a list: a bare string would be iterated char-by-char below
    args = parser.parse_args()
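    # example invocations (assuming this script is saved as image_db.py):
    #   python image_db.py --index ~/Pictures
    #   python image_db.py --search "a photo of a dog"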

    db = ImageDB(
        name=args.db,
        model=args.model, # 'facebook/dinov2-small'
        cache_dir=args.cache,
        dtype=torch.bfloat16,
        device=torch.device('cuda'),
        debug=True,
        pbar=True,
    )
    db.init()
    db.load()
    print(db)

    if args.index:
        t0 = time.time()
        if len(args.input) > 0:
            for fn in args.input:
                if os.path.isfile(fn):
                    db.image(filename=fn)
                elif os.path.isdir(fn):
                    db.folder(folder=fn)
        t1 = time.time()
        print('index', t1-t0)
        db.save()
        db.offload()

    if args.search:
        for ref in args.input:
            emb = db.embedding(ref)
            res = db.search(filename=ref, metadata=ref, embed=emb, k=10, d=0)
            for r in res:
                print(ref, r)