Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/test_cleanup_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
from unittest.mock import patch, MagicMock
from tvln.extract import FeatureExtractor, ImageFile, FloraModels, OpenClipModels
from tvln.clip_features import FloraEncoder
import torch


@pytest.fixture
def dupe_image():
    """Build a throwaway ImageFile whose tensor is a random batch-of-one RGB image.

    :returns: An ImageFile carrying a (1, 3, 256, 256) random tensor."""
    stub = ImageFile()
    stub.tensor = torch.randn((1, 3, 256, 256))
    return stub


class TestFeatureExtractor:
    """Unit tests around FeatureExtractor resource cleanup."""

    @patch("tvln.extract.snapshot_download")
    def test_cleanup(self, mock_download, dupe_image):
        """cleanup() must run a garbage-collection pass, and empty the CUDA
        cache when the image tensor does not live on the CPU.

        ``snapshot_download`` is patched out so constructing the extractor
        never touches the network."""
        extractor = FeatureExtractor(dupe_image)
        extractor.encoder = FloraEncoder()
        with patch("torch.cuda.empty_cache", return_value=None) as cache_mock:
            with patch("gc.collect") as gc_mock:
                extractor.cleanup()
                # empty_cache is only expected for GPU-resident tensors;
                # with the CPU fixture this branch is skipped.
                if dupe_image.tensor.device.type != "cpu":
                    cache_mock.assert_called_once()
                gc_mock.assert_called_once()
74 changes: 0 additions & 74 deletions tests/test_extract_calls.py

This file was deleted.

5 changes: 4 additions & 1 deletion tvln/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class ImageFile:

def __init__(self) -> None:
"""Initializes an ImageFile instance with a default"""
self._default_path = Path(__file__).resolve().parent / "assets" / "DSC_0047.png"
self._default_path: Path = Path(__file__).resolve().parent / "assets" / "DSC_0047.png"
self._default_path.resolve()
self._default_path.as_posix()

Expand Down Expand Up @@ -55,3 +55,6 @@ def as_tensor(self, dtype: torch.dtype, device: str, normalize: bool = False) ->
@property
def image_path(self) -> str:
    """Return the path of the image currently held by this instance."""
    return self._image_path

def set_image_path(self, image) -> None:
    """Replace the stored image path.

    :param image: New value for the image path — presumably a filesystem
        path string, matching the ``str`` return of ``image_path``; TODO
        confirm against callers."""
    self._image_path = image
140 changes: 54 additions & 86 deletions tvln/clip_features.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,24 @@
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

from enum import Enum

from torch import Tensor, nn
from torch import Tensor, nn, dtype, float32, device

from tvln.batch import ImageFile
from tvln.options import PrecisionType, ModelType, DeviceName
from tvln.options import DeviceName, FloraModels, OpenClipModels


def get_model_and_pretrained(member: Enum) -> tuple[str, str]:
"""Return the raw strings for a member.\n
:param member: Enum member representing a model and its pretrained variant.
:returns: The model type and pretrained string."""
return member.value


class CLIPEncoder(nn.Module):
class FloraEncoder(nn.Module):
"""CLIP wrapper\n\n
MIT licensed by ncclab-sustech/BrainFLORA
"""

def __init__(self, device: str = "cpu", model: str = "openai/clip-vit-large-patch14") -> None:
def __init__(self, device: str = DeviceName.CPU) -> None:
"""Instantiate the encoder with a specific device and model\n
:param device: The graphics device to allocate, Default is cpu"""
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Normalize, Resize
from transformers import CLIPVisionModel

super().__init__()
self.clip = CLIPVisionModel.from_pretrained(model).to(device)
self.clip_size = (224, 224)
self.preprocess = Compose(
[
Resize(size=self.clip_size[0], interpolation=InterpolationMode.BICUBIC),
CenterCrop(size=self.clip_size),
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
]
)
self._device = device
self.flora_model, _ = FloraModels.VIT_L_14_LAION2B_S32B_B82K.value # type: ignore dynamic
self.device = device

def clip_encode_image(self, x: Tensor) -> Tensor:
"""Encode image patches using CLIP vision model\n
Expand All @@ -49,45 +30,67 @@ def clip_encode_image(self, x: Tensor) -> Tensor:
x = x.reshape(x.shape[0], x.shape[1], -1) # [batchsize, 1024, 256]
x = x.permute(0, 2, 1)

class_embedding = self.clip.vision_model.embeddings.class_embedding.to(x.dtype)
class_embedding = self.model.vision_model.embeddings.class_embedding.to(x.dtype)
class_embedding = class_embedding.repeat(x.shape[0], 1, 1) # [batchsize, 1, 1024]
x = torch.cat([class_embedding, x], dim=1)

pos_embedding = self.clip.vision_model.embeddings.position_embedding
position_ids = torch.arange(0, 257).unsqueeze(0).to(self._device)
pos_embedding = self.model.vision_model.embeddings.position_embedding
position_ids = torch.arange(0, 257).unsqueeze(0).to(self.device)
x = x + pos_embedding(position_ids)

x = self.clip.vision_model.pre_layrnorm(x)
x = self.clip.vision_model.encoder(x, output_hidden_states=True)
x = self.model.vision_model.pre_layrnorm(x)
x = self.model.vision_model.encoder(x, output_hidden_states=True)

select_hidden_state_layer = -2
select_hidden_state = x.hidden_states[select_hidden_state_layer] # [1, 256, 1024]
select_hidden_state = x.hidden_states[select_hidden_state_layer] # [1, 256, 1024] #type:ignore came with code
image_features = select_hidden_state[:, 1:] # Remove class token

return image_features

def encode_image(self, x: Tensor) -> Tensor:
"""Full image encoding pipeline
:param x: the input image tensor in shape [B, C, H, W] and device-compatible dtype."""
x = x.to(self._device)
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Normalize, Resize
from transformers import CLIPVisionModel

self.model = CLIPVisionModel.from_pretrained(pretrained_model_name_or_path=self.flora_model).to(self.device) # type: ignore DeviceLikeType

self.clip_size = (224, 224)
self.preprocess = Compose(
[
Resize(size=self.clip_size[0], interpolation=InterpolationMode.BICUBIC),
CenterCrop(size=self.clip_size),
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
]
)
x = x.to(self.device)
x = self.preprocess(x) # [3, 224, 224]
x = self.clip.vision_model.embeddings.patch_embedding(x) # [1024, 16, 16]
x = self.model.vision_model.embeddings.patch_embedding(x) # [1024, 16, 16]
image_feats = self.clip_encode_image(x)
return image_feats


class CLIPFeatures:
"""Convenience wrapper around the Open-CLIP model for image feature extraction."""
class OpenClipEncoder:
FLOAT64 = "fp64"
FLOAT32 = "fp32"
BFLOAT16 = "bf16"
FLOAT16 = "fp16"

def __init__(self) -> None:
"""Create a CLIPFeatures instance with the default model configuration (VIT_L_14_LAION2B_S32B_B82K @ FP32)."""
self._images = []
model_name, dataset_name = get_model_and_pretrained(ModelType.VIT_L_14_LAION2B_S32B_B82K) # type:ignore
self._model_type: str = model_name
self._pretrained: str = dataset_name
self._precision: str = "fp32"
def __init__(self, device: str | device = DeviceName.CPU, precision: dtype = float32) -> None:
super().__init__()
self.open_clip_model, self.pretraining = OpenClipModels.VIT_L_14_LAION2B_S32B_B82K.value # type:ignore
self.precision: dtype = precision
self.device = device

def ImageEncoder(self) -> Tensor:
def convert_dtype(self, precision: dtype) -> str:
import torch

if isinstance(precision, torch.dtype):
_, torch_dtype = precision.__repr__().rsplit(".")
torch_dtype = getattr(self, torch_dtype.upper())
return torch_dtype

def encode_image(self, x: ImageFile) -> Tensor:
"""Encode a batch of images into CLIP features.\n
:param images: Paths to the image files.
:returns Concatenated image feature vectors."""
Expand All @@ -98,59 +101,24 @@ def ImageEncoder(self) -> Tensor:
from torch import no_grad as torch_no_grad
from torch import stack as torch_stack

self.images = [x.image_path]
vlmodel, preprocess_train, feature_extractor = create_model_and_transforms(
self._model_type,
pretrained=self._pretrained,
precision=self._precision,
device=self._device,
self.open_clip_model,
pretrained=self.pretraining,
precision=self.convert_dtype(self.precision),
device=self.device,
)

batch_size = 512
image_features_list = []

for i in range(0, len(self._images), batch_size):
batch_images = self._images[i : i + batch_size]
for i in range(0, len(self.images), batch_size):
batch_images = self.images[i : i + batch_size]
image_inputs = torch_stack([preprocess_train(Image.open(img).convert("RGB")) for img in batch_images]) # type:ignore

with torch_no_grad():
batch_image_features = vlmodel.encode_image(image_inputs)
batch_image_features = vlmodel.encode_image(image_inputs) # type: ignore came with code
image_features_list.append(batch_image_features)

image_features = torch_cat(image_features_list, dim=0)
return image_features

def set_device(self, device_name: DeviceName) -> None:
"""Set the computation device.\n
:param device_name : Target graphics processing device."""
self._device: str = device_name.value

def set_model_type(self, model_type: Enum) -> None:
"""Switch the underlying Open-CLIP model.
:param model_type: Desired pretrained model and dataset variant."""
model_name, pretrained = get_model_and_pretrained(model_type)
self._model_type = model_name
self._pretrained = pretrained

def set_model_link(self, model_link: Enum) -> None:
"""Switch the path to an Open-CLIP model
:param model_link: Desired pretrained model and dataset variant."""
model_link, model_hub = get_model_and_pretrained(model_link)
self._model_link: str = model_link
self._model_hub: str = model_hub

def set_precision(self, precision: PrecisionType) -> None:
"""Change the numeric precision used by the model.
:param precision: Desired float calculation precision."""
self._precision = precision.value

def extract(self, image: ImageFile, last_layer=False) -> Tensor:
"""Convenience entry-point that sets images and returns CLIP features.\n
:param image_paths: One or more image file paths.
:returns: Extracted image features"""
if not last_layer:
clip_encoder = CLIPEncoder(self._device, model=self._model_link)

return clip_encoder.encode_image(image.tensor)
else:
self._images = [image._image_path]
return self.ImageEncoder()
Loading