Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/tvln.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

# Run the project's pytest suite on pushes and pull requests targeting main.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: tvln pytest

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

permissions:
  contents: read

jobs:

  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      # Renamed from "Run Python Tests": this step only installs Python.
      # (The former "Install system deps" step ran `apt-get update` without
      # installing anything, so it has been removed as dead weight.)
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML keeps it a string (unquoted 3.10 would parse as 3.1).
          python-version: "3.13"

      - name: Install dependencies and run tests
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          uv python install 3.13
          uv python pin 3.13
          uv sync --group dev
          source .venv/bin/activate
          pytest -rPvv

2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ uv sync --group dev
```
tvln
```

[![tvln pytest](https://github.com/darkshapes/tvln/actions/workflows/tvln.yml/badge.svg)](https://github.com/darkshapes/tvln/actions/workflows/tvln.yml)
Empty file added tests/__init__.py
Empty file.
74 changes: 74 additions & 0 deletions tests/test_extract_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

import pytest
from unittest.mock import patch, MagicMock
from tvln.clip_features import CLIPFeatures
from tvln.extract import FeatureExtractor, AutoencoderKL, snapshot_download


@pytest.fixture
def dummy_image_file():
    """Provide a stand-in for an ImageFile instance."""
    stand_in = MagicMock()
    return stand_in


@pytest.fixture
def dummy_download():
    """Stub out `snapshot_download` so no network traffic ever occurs."""
    with patch("tvln.extract.snapshot_download", autospec=True) as stub:
        # Any caller receives a fake local checkout path instead of downloading.
        stub.return_value = "/tmp/vae"
        yield stub


@pytest.fixture
def dummy_encoder():
    """Replace `AutoencoderKL` so the real model is never loaded.

    The patched class's `from_pretrained` hands back a mock model whose
    `tiled_encode` method produces a placeholder tensor.
    """
    with patch("tvln.extract.AutoencoderKL", autospec=True) as mock_cls:
        fake_model = MagicMock(name="vae_model")
        # `tiled_encode` must yield something tensor-like for downstream code.
        fake_model.tiled_encode.return_value = MagicMock(name="vae_tensor")
        mock_cls.from_pretrained.return_value = fake_model
        yield mock_cls


@patch.object(FeatureExtractor, "cleanup")
@patch.object(CLIPFeatures, "extract", return_value=MagicMock(name="tensor"))
@patch.object(CLIPFeatures, "set_model_link")
@patch.object(CLIPFeatures, "set_model_type")
@patch.object(CLIPFeatures, "set_precision")
@patch.object(CLIPFeatures, "set_device")
def test_clip_features_flow(
    # NOTE: mock parameters are in reverse order of the decorators above
    # (innermost decorator binds first) — keep both lists in sync.
    mock_set_device,
    mock_set_precision,
    mock_set_model_type,
    mock_set_model_link,
    mock_extract,
    mock_cleanup,
    dummy_encoder,
    dummy_download,
    dummy_image_file,
):
    """
    Verify that running `main()` end-to-end drives the feature-extraction
    pipeline through the expected number of calls on each patched
    collaborator, with network access and model loading stubbed out by
    the fixtures.
    """
    # Import inside the test so the patches above are active before
    # `tvln.main` (and anything it triggers at import time) runs.
    from tvln.main import main

    tensor_stack = main()
    # Call-count expectations: presumably main() runs three extractions,
    # two via ModelLink and one via ModelType — TODO confirm against tvln.main.
    assert mock_set_device.call_count == 3
    assert mock_set_precision.call_count == 3
    assert mock_set_model_link.call_count == 2
    assert mock_set_model_type.call_count == 1
    assert mock_extract.call_count == 3
    assert mock_cleanup.call_count == 3
    # Every non-VAE entry should carry the mocked tensor produced by `extract`.
    for name, tensors in tensor_stack.items():
        if name != "F1 VAE":
            assert isinstance(tensors[1], MagicMock)  # extraction was triggered
20 changes: 16 additions & 4 deletions tvln/batch.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

import torch
from pathlib import Path

import torch


class ImageFile:
_image_path: str

def __init__(self) -> None:
    """Initialize with the absolute path of the bundled default sample image.

    The original code called `.resolve()` and `.as_posix()` on separate
    lines and discarded both return values (dead statements); the resolve
    is now applied to the stored path instead.
    """
    self._default_path = (Path(__file__).resolve().parent / "assets" / "DSC_0047.png").resolve()

def single_image(self) -> None:
image_path = input("Enter the path to an image file (e.g. /home/user/image.png, C:/Users/user/Pictures/...): ")
"""Set absolute path to an image file, ensuring the file exists, falling back to a default image if none is provided."""
from sys import modules as sys_modules

if "pytest" not in sys_modules:
image_path = input("Enter the path to an image file (e.g. /home/user/image.png, C:/Users/user/Pictures/...): ")
else:
image_path = None
if not image_path:
image_path = self._default_path
if not Path(image_path).resolve().is_file():
Expand All @@ -25,15 +33,15 @@ def single_image(self) -> None:
if not isinstance(self._image_path, str):
raise TypeError(f"Expected a string or list of strings for `image_paths` {self._image_path}, got {type(self._image_path)} ")

def as_tensor(self, dtype: torch.dtype, device: torch.device, normalize: bool = False) -> None:
def as_tensor(self, dtype: torch.dtype, device: str, normalize: bool = False) -> None:
"""Convert a Pillow `Image` to a batched `torch.Tensor`\n
:param image: Pillow image (RGB) to encode.
:param device: Target device for the tensor (default: ``gpu.device``).
:param normalize: Normalize tensor to [-1, 1]:
:return: Tensor of shape ``[1, 3, H, W]`` on ``device``."""

from numpy._typing import NDArray
from numpy import array as np_array
from numpy._typing import NDArray
from PIL.Image import open as open_img

with open_img(str(self._image_path)).convert("RGB") as pil_image:
Expand All @@ -43,3 +51,7 @@ def as_tensor(self, dtype: torch.dtype, device: torch.device, normalize: bool =
if normalize:
tensor = tensor * 2.0 - 1.0
self.tensor = tensor

@property
def image_path(self) -> str:
    """Read-only absolute path of the selected image.

    NOTE(review): `_image_path` appears to be assigned in `single_image`;
    accessing this property before that call would raise AttributeError —
    confirm intended usage order.
    """
    return self._image_path
48 changes: 4 additions & 44 deletions tvln/clip_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,10 @@

from enum import Enum

from open_clip import list_pretrained
from open_clip.pretrained import _PRETRAINED
from torch import Tensor, nn

from tvln.batch import ImageFile


class DeviceName(str, Enum):
    """Graphics processors usable by the CLIP pipeline."""

    # The `str` mixin lets members compare equal to their literal values
    # (e.g. DeviceName.CPU == "cpu") and be passed where a device string is expected.
    CPU = "cpu"
    CUDA = "cuda"  # NVIDIA GPUs
    MPS = "mps"  # Apple Metal Performance Shaders


class PrecisionType(str, Enum):
    """Supported numeric float precision."""

    # NOTE(review): values presumably match open_clip's precision identifiers — verify.
    FP64 = "fp64"  # 64-bit float
    FP32 = "fp32"  # 32-bit float
    BF16 = "bf16"  # bfloat16
    FP16 = "fp16"  # half precision


# Dynamically build one enum member per (architecture, pretrained-tag) pair
# reported by open_clip. Hyphens are invalid in Python identifiers, so they
# become underscores in the upper-cased member names.
ModelType = Enum(
    "ModelData",
    {
        # member name → (model_type, pretrained) value
        f"{model.replace('-', '_').upper()}_{pretrained.replace('-', '_').upper()}": (
            model,
            pretrained,
        )
        for model, pretrained in list_pretrained()
    },
)


# Dynamically build one enum member per pretrained entry in open_clip's
# registry that has a download location; the value is the
# (hf_hub repo id, direct url) pair. Entries lacking both are skipped.
ModelLink = Enum(
    "ModelData",
    {
        f"{family.replace('-', '_').upper()}_{id.replace('-', '_').upper()}": (data.get("hf_hub", "").strip("/"), data.get("url"))
        for family, name in _PRETRAINED.items()
        for id, data in name.items()
        if data.get("hf_hub") or data.get("url")
    },
)
from tvln.options import PrecisionType, ModelType, DeviceName


def get_model_and_pretrained(member: Enum) -> tuple[str, str]:
Expand All @@ -59,7 +17,9 @@ def get_model_and_pretrained(member: Enum) -> tuple[str, str]:


class CLIPEncoder(nn.Module):
"""CLIP wrapper courtesy ncclab-sustech/BrainFLORA"""
"""CLIP wrapper\n\n
MIT licensed by ncclab-sustech/BrainFLORA
"""

def __init__(self, device: str = "cpu", model: str = "openai/clip-vit-large-patch14") -> None:
"""Instantiate the encoder with a specific device and model\n
Expand Down
67 changes: 55 additions & 12 deletions tvln/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,67 @@
# <!-- // /* d a r k s h a p e s */ -->

from enum import Enum

import torch
from tvln.batch import ImageFile
from tvln.clip_features import CLIPFeatures, ModelType, ModelLink, PrecisionType, DeviceName
from tvln.clip_features import CLIPFeatures
from tvln.options import DeviceName, ModelLink, ModelType, PrecisionType
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from huggingface_hub import snapshot_download


class FeatureExtractor:
def __init__(self, text_device: DeviceName, precision: PrecisionType):
self.text_device = text_device
def __init__(self, device: DeviceName, precision: PrecisionType, image_file: ImageFile):
    """Store extraction configuration.

    :param device: Compute device the models will run on.
    :param precision: Numeric precision (dtype) used when loading models.
    :param image_file: Wrapper around the source image to extract features from."""
    self.device = device
    self.precision = precision
    self.image_file = image_file

def extract_features(self, model_info: Enum | str, last_layer: bool = False):
"""Extract features from the image using the specified model.
:param model_info: The kind of model to use
:param last_layer: Whether the features are extracted from the last layer of the model or from an intermediate layer."""

if isinstance(model_info, str):
import os

def extract_features(self, model_info: Enum, image_file: ImageFile, last_layer: bool = False):
feature_extractor = CLIPFeatures()
feature_extractor.set_device(self.text_device)
feature_extractor.set_precision(self.precision)
vae_path = snapshot_download(model_info, allow_patterns=["vae/*"])
vae_path = os.path.join(vae_path, "vae")
vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=self.precision).to(self.device.value)
vae_tensor = vae_model.tiled_encode(self.image_file.tensor, return_dict=False)
return vae_tensor, model_info
clip_extractor = CLIPFeatures()
clip_extractor.set_device(self.device)
clip_extractor.set_precision(self.precision)
if isinstance(model_info, ModelLink):
feature_extractor.set_model_link(model_info)
clip_extractor.set_model_link(model_info)
elif isinstance(model_info, ModelType):
feature_extractor.set_model_type(model_info)
tensor = feature_extractor.extract(image_file, last_layer)
data = vars(feature_extractor)
self.cleanup(model=feature_extractor, device=self.text_device)
clip_extractor.set_model_type(model_info)
print(clip_extractor._precision)
tensor = clip_extractor.extract(self.image_file, last_layer)
data = vars(clip_extractor)
self.cleanup(model=clip_extractor)
return tensor, data

def cleanup(self, model: CLIPFeatures) -> None:
    """Release the extraction model and reclaim accelerator memory.

    :param model: The model instance used for feature extraction; the
        local reference is dropped so the garbage collector can reclaim it.
    """
    import gc

    import torch

    # Flush the backend's cached allocations on accelerators (cuda / mps).
    # DeviceName is a str-Enum, so the != "cpu" comparison and getattr both
    # work whether `self.device` is an enum member or a plain string.
    if self.device != "cpu":
        backend = getattr(torch, self.device)
        backend.empty_cache()
    # Drop our reference, then force a collection pass. (The original
    # rebound `model: None = None` before `del` — a mis-annotated no-op.)
    del model
    gc.collect()

def set_device(self, device: DeviceName) -> None:
    """Select the compute device used for subsequent extractions."""
    self.device = device

def set_precision(self, precision: PrecisionType | torch.dtype):
    """Record the numeric precision, unwrapping enum members to their raw value."""
    unwrap = isinstance(precision, PrecisionType)
    self.precision = precision.value if unwrap else precision
Loading