from typing import List, Optional, Tuple, Union
import argparse
import os
from glob import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from transformers import (
    CLIPProcessor,
    CLIPTextModel,
    CLIPVisionModel,
    CLIPModel,
)

from tqdm.auto import tqdm

PRETRAINED_MODEL_NAME_OR_PATH = "openai/clip-vit-large-patch14"


class CLIPDataset(Dataset):
    """
    Dataset for CLIP Score.
    """
    def __init__(self, img_files: List[str], captions: List[str], processor: Optional[CLIPProcessor] = None, prefix: str = 'A photo depicts '):
        assert len(img_files) == len(captions), f"Number of images ({len(img_files)}) and captions ({len(captions)}) must match"
        self.img_files = img_files
        self.captions = captions
        self.processor = processor
        self.prefix = prefix

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx: int):
        image = Image.open(self.img_files[idx])
        caption = self.prefix + self.captions[idx]
        if self.processor is not None:
            image = self.processor(images=image, return_tensors="pt")
            image["pixel_values"] = image["pixel_values"].squeeze(0)
            caption = self.processor(text=caption, return_tensors="pt", padding="max_length", truncation=True, max_length=77)
            caption["input_ids"] = caption["input_ids"].squeeze(0)
            caption["attention_mask"] = caption["attention_mask"].squeeze(0)
        return image, caption


class CLIPFeatureExtractor(nn.Module):
    """Reuse the CLIP model to reduce the memory footprint."""
    def __init__(self, base_model: Union[CLIPTextModel, CLIPVisionModel], projector: nn.Module):
        super(CLIPFeatureExtractor, self).__init__()
        self.base_model = base_model
        self.projector = projector
        # visual_projection for vision_model
        # text_projection for text_model

    def forward(self, *args, **kwargs):
        outputs = self.base_model(*args, **kwargs)
        pooled_output = outputs[1]
        return self.projector(pooled_output)


def get_model_and_processor(pretrained_model_name_or_path: str = PRETRAINED_MODEL_NAME_OR_PATH) -> Tuple[CLIPModel, CLIPProcessor]:
    """
    Get the CLIP model and processor.
    """
    clip_processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path)
    clip_model = CLIPModel.from_pretrained(pretrained_model_name_or_path)
    return clip_model, clip_processor


@torch.no_grad()
def get_clip_score(model: CLIPModel, images, captions, w: float = 2.5, device: Optional[torch.device] = None) -> float:
    """
    Calculate the CLIPScore for a batch of images and captions.
    """
    model.eval()

    pixel_values = images["pixel_values"].to(device) if device is not None else images["pixel_values"]
    input_ids = captions["input_ids"].to(device) if device is not None else captions["input_ids"]
    attention_mask = captions["attention_mask"].to(device) if device is not None else captions["attention_mask"]

    image_features = model.get_image_features(
        pixel_values=pixel_values,
    )  # (B, D)
    text_features = model.get_text_features(
        input_ids=input_ids,
        attention_mask=attention_mask,
    )  # (B, D)
    # Cosine similarities can range from -1 to 1, but in practice they mostly fall
    # between 0 and 0.4, so the CLIPScore paper rescales them by w = 2.5 to spread
    # the scores over roughly [0, 1].
    similarity = w * F.cosine_similarity(image_features, text_features, dim=1)  # (B,)
    score = similarity.mean().item()

    return score


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        prog='CLIPScore',
        description='Takes the path to images and prompts and computes the CLIPScore')
    parser.add_argument('--img_path', help='path to the generated images to be evaluated', type=str, required=True)
    parser.add_argument('--prompts_path', help='path to the txt prompts (one per line); if not provided, img_path is assumed to contain prompts.txt', type=str, required=False, default=None)
    parser.add_argument('--save_path', help='path to save results', type=str, required=False, default=None)
    parser.add_argument('--batch_size', help='batch size', type=int, default=32)
    parser.add_argument('--device', help='device to use', type=str, default='cuda:0')
    parser.add_argument('--pretrained_model_name_or_path', help='pretrained model name or path', type=str, default=PRETRAINED_MODEL_NAME_OR_PATH)
    parser.add_argument('--w', help='scaling weight for the cosine similarity (the CLIPScore paper uses 2.5)', type=float, default=1.0)
    parser.add_argument('--ext', help='image file extension', type=str, default='png')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    device = torch.device(args.device)
    clip_model, clip_processor = get_model_and_processor(args.pretrained_model_name_or_path)
    clip_model.to(device)

    if args.prompts_path is None:
        prompts_path = os.path.join(args.img_path, 'prompts.txt')
    else:
        prompts_path = args.prompts_path

    with open(prompts_path, 'r') as f:
        captions = f.readlines()
    captions = [caption.strip() for caption in captions]

    img_files = sorted(
        glob(os.path.join(args.img_path, f"*.{args.ext.replace('.', '').strip()}")),
        key=lambda x: int(os.path.splitext(os.path.basename(x))[0]),
    )
    dataset = CLIPDataset(img_files, captions, clip_processor, prefix="A photo depicts ")
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    score, n = 0.0, 0
    tbar = tqdm(dataloader)
    for images, captions in tbar:
        # get_clip_score returns the mean over the batch, so weight it by the batch
        # size before averaging over all images (the last batch may be smaller).
        batch_size = images["pixel_values"].shape[0]
        score += get_clip_score(clip_model, images, captions, args.w, device) * batch_size
        n += batch_size
        tbar.set_description(f"CLIPScore: {score / n:.4f}")
    score /= n

    print(score)


if __name__ == '__main__':
    main()
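As a rough illustration, the pieces above can also be used programmatically rather than through the CLI; the sketch below computes per-batch scores for a couple of images. The image paths and captions are hypothetical placeholders, and it assumes the functions of the script above are in scope.

import torch
from torch.utils.data import DataLoader

model, processor = get_model_and_processor()  # defaults to openai/clip-vit-large-patch14
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

img_files = ["samples/0.png", "samples/1.png"]     # hypothetical image paths
captions = ["a cat on a sofa", "a dog in a park"]  # one caption per image
dataset = CLIPDataset(img_files, captions, processor)
loader = DataLoader(dataset, batch_size=2, shuffle=False)

for images, texts in loader:
    # get_clip_score returns the mean w-scaled cosine similarity for the batch
    print(get_clip_score(model, images, texts, w=2.5, device=device))

The second changed file, below, provides image and prompt helper utilities.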
from typing import List, Union, Callable, Tuple, Optional
from collections import OrderedDict
import os
from glob import glob
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T


ALL_EXTS = ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.gif']


class ImagePathDataset(Dataset):
    def __init__(self, img_files: List[str], transform: Optional[Union[Callable, T.Compose]] = None):
        self.img_files = img_files
        self.transform = transform

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img = Image.open(self.img_files[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img


def get_transform(
    size: int = 256,
    normalize: Optional[Union[bool, Tuple[Tuple[float], Tuple[float]]]] = None,
    center_crop: bool = True,
) -> T.Compose:
    transforms = []

    if size is not None:
        transforms.append(T.Resize(size, interpolation=T.InterpolationMode.BICUBIC))
        if center_crop:
            transforms.append(T.CenterCrop(size))
    transforms.append(T.ToTensor())

    if isinstance(normalize, bool) and normalize:
        # same as T.Lambda(lambda x: (x - 0.5) * 2) for [-1, 1] normalization
        transforms.append(T.Normalize(0.5, 0.5))
    elif isinstance(normalize, tuple) and len(normalize) == 2:
        # (mean, std)
        transforms.append(T.Normalize(normalize[0], normalize[1]))

    return T.Compose(transforms)


def get_img_files(path: str, exts: Union[str, List[str]] = ALL_EXTS, sort: bool = True) -> List[str]:
    """
    Gets all files in a directory with the given extensions.
    If sort is True, the files are sorted by their numeric file-name index.
    """
    if isinstance(exts, str):
        exts = [exts]
    files = []
    for ext in exts:
        files.extend(glob(os.path.join(path, f'*{ext}')))
        files.extend(glob(os.path.join(path, f'*{ext.upper()}')))
    if sort:
        files = sorted(files, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    return files


def match_files(files1: List[str], files2: List[str]) -> Tuple[List[str], List[str]]:
    """
    Matches files in two lists by their numeric indices (file names without extension),
    keeping only the files whose index appears in both lists.
    """
    files1_ids = [int(os.path.splitext(os.path.basename(f))[0]) for f in files1]
    files2_ids = [int(os.path.splitext(os.path.basename(f))[0]) for f in files2]

    files1_ids = set(files1_ids)
    files2_ids = set(files2_ids)

    common_ids = files1_ids.intersection(files2_ids)

    files1 = [f for f in files1 if int(os.path.splitext(os.path.basename(f))[0]) in common_ids]
    files2 = [f for f in files2 if int(os.path.splitext(os.path.basename(f))[0]) in common_ids]

    return files1, files2


def gather_img_tensors(tensors: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Collect a single image tensor or a list of image tensors into one batched (N, C, H, W) tensor."""
    if isinstance(tensors, torch.Tensor) and tensors.ndim == 3:
        tensors = tensors.unsqueeze(0)
    elif isinstance(tensors, list) and isinstance(tensors[0], torch.Tensor):
        if tensors[0].ndim == 3:
            tensors = torch.stack(tensors, dim=0)
        elif tensors[0].ndim == 4:
            tensors = torch.cat(tensors, dim=0)
    return tensors


def read_prompt_to_ids(path: Optional[str] = None, prompts: Optional[List[str]] = None) -> OrderedDict:
    """Read the prompts txt to get the corresponding case numbers and prompts.
    prompts.txt should be in the format (each line corresponding to a single image):
    ```
    japan body
    japan body
    ...
    japan body
    america body
    ...
    ```
    Returns an OrderedDict mapping each prompt to a list of case numbers, e.g.:
    ```
    {
        "japan body": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        "america body": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        ...
    }
    ```
    """
    if prompts is None:
        if path is None:
            raise ValueError('Either prompts or path must be provided.')
        with open(path, 'r') as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [prompt.strip() for prompt in prompts]

    prompt_to_ids = OrderedDict()
    for idx, prompt in enumerate(prompts):
        if prompt not in prompt_to_ids:
            prompt_to_ids[prompt] = [idx]
        else:
            prompt_to_ids[prompt].append(idx)
    return prompt_to_ids
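A minimal sketch of how these helpers might be combined, assuming images are named by their numeric case number; the directory and file names below are hypothetical placeholders.

from torch.utils.data import DataLoader

# Collect and align two image sets (e.g. reference vs. generated) by their numeric file names.
real_files = get_img_files('data/real')      # hypothetical directory
gen_files = get_img_files('data/generated')  # hypothetical directory
real_files, gen_files = match_files(real_files, gen_files)

# Build a dataset of [-1, 1]-normalized 256x256 tensors for the generated images.
transform = get_transform(size=256, normalize=True, center_crop=True)
gen_dataset = ImagePathDataset(gen_files, transform=transform)
gen_loader = DataLoader(gen_dataset, batch_size=16, shuffle=False)

# Group case numbers by prompt, assuming one prompt line per generated image.
prompt_to_ids = read_prompt_to_ids(path='data/generated/prompts.txt')  # hypothetical path
for prompt, ids in prompt_to_ids.items():
    print(prompt, ids[:5])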