Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ wheels/
.venv
_version.py
.DS_Store
.continue/
.continue/
*ph/
.datasets/
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ license = {file = "LICENSE"}
requires-python = ">=3.11"
dependencies = [
"accelerate>=1.12.0",
"datasets>=4.5.0",
"diffusers>=0.36.0",
"numpy==2.4.1",
"open-clip-torch>=3.2.0",
"transformers>=4.57.3",
]
Expand Down
34 changes: 33 additions & 1 deletion tvln/batch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

from enum import Enum
from pathlib import Path
from typing import Callable, Iterable

import torch

Expand All @@ -14,12 +16,13 @@ def __init__(self) -> None:
self._default_path: Path = Path(__file__).resolve().parent / "assets" / "DSC_0047.png"
self._default_path.resolve()
self._default_path.as_posix()
self._image_path = ""

def single_image(self) -> None:
"""Set absolute path to an image file, ensuring the file exists, falling back to a default image if none is provided."""
from sys import modules as sys_modules

if "pytest" not in sys_modules:
if not self.image_path and "pytest" not in sys_modules:
image_path = input("Enter the path to an image file (e.g. /home/user/image.png, C:/Users/user/Pictures/...): ")
else:
image_path = None
Expand Down Expand Up @@ -54,7 +57,36 @@ def as_tensor(self, dtype: torch.dtype, device: str, normalize: bool = False) ->

@property
def image_path(self) -> str:
    """Path of the image currently tracked by this instance (empty until set)."""
    current_path: str = self._image_path
    return current_path

def set_image_path(self, image: str) -> None:
    """Change the current image path.

    :param image: Path to an image file; stored as-is with no existence check.
    """
    # No validation here — ``single_image`` handles the fallback when unset.
    self._image_path = image


# Module-level helpers for batch feature extraction.


def batch_process_images(image_paths: Iterable[str], extractor: Callable, device: str) -> dict[str, torch.Tensor]:
    """Process many images with a single FeatureExtractor instance.\n
    :param image_paths: Paths to images.\n
    :param extractor: FeatureExtractor used for every image; its cached encoder is re-used.\n
    :param device: Device identifier forwarded to ``ImageFile.as_tensor``.\n
    :returns: Mapping of image paths to feature tensors.\n
    :raises FileNotFoundError: If an image does not exist.\n
    :raises ValueError: If no image paths are supplied."""
    # Materialize first: a generator is always truthy, so ``if not image_paths``
    # alone would silently accept an empty generator.
    paths = list(image_paths)
    if not paths:
        raise ValueError("No image paths supplied")

    from tqdm import tqdm

    # One ImageFile is re-used; only its path and tensor change per iteration.
    image_file = ImageFile()

    features: dict[str, torch.Tensor] = {}
    for path in tqdm(paths, desc="processing_images..."):
        image_file.set_image_path(path)
        image_file.as_tensor(device=device, dtype=torch.float32, normalize=False)
        extractor.set_image_file(image_file)
        # NOTE(review): calls the private ``_extract_vae`` directly to skip the
        # per-image model reload; assumes ``extractor.model`` was set by the caller.
        tensor = extractor._extract_vae()
        features[path] = tensor
    return features
53 changes: 53 additions & 0 deletions tvln/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

from pathlib import Path

import torch
from datasets import Dataset, DatasetDict, Image, IterableDataset, interleave_datasets, load_dataset

from tvln.batch import ImageFile
from tvln.extract import FeatureExtractor
from tvln.options import DeviceName


def build_datasets() -> dict[str, Dataset | dict[str, IterableDataset] | IterableDataset | DatasetDict]:
    """Builds synthetic and original datasets.\n
    :returns: A dictionary containing synthetic and original datasets."""

    cache_folder = ".datasets"
    assets_folder = Path(__file__).parent / "assets" / "ph"

    slice_split = load_dataset("darkshapes/a_slice", cache_dir=cache_folder, split="train").cast_column("image", Image(decode=False))
    rnd_split = load_dataset("exdysa/rnd_synthetic_img", cache_dir=cache_folder, split="train").cast_column("image", Image(decode=False))

    # Local assets become an "original" dataset of un-decoded image paths.
    records = [{"image": str(entry)} for entry in assets_folder.iterdir() if entry.is_file()]
    return {
        "synthetic": interleave_datasets([slice_split, rnd_split]),
        "original": Dataset.from_list(records).cast_column("image", Image(decode=False)),
    }


@torch.no_grad
async def process_dataset(dataset) -> dict[str, torch.Tensor]:
    """Processes a dataset to extract features.\n
    :param dataset: The dataset to process; items must carry ``["image"]["path"]``.
    :returns: A dictionary mapping image paths to their feature tensors.
    :raises ValueError: If the dataset is empty."""

    device = DeviceName.CPU
    if torch.cuda.is_available():
        device = DeviceName.CUDA
    elif torch.mps.is_available():
        device = DeviceName.MPS

    features = {}
    image_file = ImageFile()
    # Build the extractor ONCE so its cached encoder can be re-used across
    # images instead of re-loading the VAE on every iteration.
    # NOTE(review): extract() calls cleanup() internally — confirm cleanup()
    # preserves ``self.encoder`` so the re-use is actually effective.
    feature_extractor = FeatureExtractor(image=image_file)
    for image_data in dataset:
        image_path = image_data["image"]["path"]
        image_file.set_image_path(image_path)
        image_file.single_image()
        image_file.as_tensor(device=device, dtype=torch.float32)
        feature_extractor.set_image_file(image_file)
        vae_tensor, _ = feature_extractor.extract(model="black-forest-labs/FLUX.1-dev")
        # setdefault keeps the first tensor when a path repeats in the dataset.
        features.setdefault(image_path, vae_tensor)
    if not features:
        # Make the documented contract real: the docstring promised this raise.
        raise ValueError("Dataset is empty")
    return features
58 changes: 46 additions & 12 deletions tvln/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ class FeatureExtractor:
def __init__(self, image: ImageFile):
    """Create an extractor bound to *image*.\n
    :param image: Wrapped image whose tensor supplies dtype and device."""
    self.image: ImageFile = image
    # Pre-seed cached state: the _extract_* helpers probe ``self.encoder``
    # with isinstance() and extract() reads ``self.model`` — both would raise
    # AttributeError on the first call if left unset.
    self.encoder = None
    self.model: Enum | str | None = None

def extract(self, model: Enum | str) -> tuple[torch.Tensor, str | dict]:
def extract(self, model: Enum | str | None = None) -> tuple[torch.Tensor, str | dict]:
"""Extract features from the image using the specified model.
:param model_info: The kind of model to use
:param image: One or more image file paths.
:returns: Extracted image features"""

self.dtype = self.image.tensor.dtype
self.device = self.image.tensor.device
self.model = model
self.model = model or self.model

if isinstance(model, FloraModel): # type: ignore
tensor = self._extract_flora()
Expand All @@ -39,35 +39,69 @@ def extract(self, model: Enum | str) -> tuple[torch.Tensor, str | dict]:
self.cleanup()
return tensor, data

@torch.no_grad
def _extract_flora(self) -> torch.Tensor:
    """Extract features with a Flora model, re-using the cached encoder when its model matches.\n
    :returns: Encoded image features."""
    # getattr guard: ``encoder`` may be unset before the first extraction.
    cached = getattr(self, "encoder", None)
    if isinstance(cached, FloraEncoder) and cached.flora_model == self.model.value[0]:
        flora_encoder = cached
    else:
        flora_encoder = FloraEncoder(device=self.device.type)
        flora_encoder.flora_model, _ = self.model.value  # type: ignore
    tensor: torch.Tensor = flora_encoder.encode_image(self.image.tensor)
    self.encoder = flora_encoder
    return tensor

@torch.no_grad
def _extract_openclip(self) -> torch.Tensor:
    """Extract features with an OpenClip model, re-using the cached encoder when its model matches.\n
    :returns: Encoded image features."""
    # getattr guard: ``encoder`` may be unset before the first extraction.
    cached = getattr(self, "encoder", None)
    if isinstance(cached, OpenClipEncoder) and cached.open_clip_model == self.model.value[0]:
        open_clip_encoder = cached
    else:
        open_clip_encoder = OpenClipEncoder(device=self.device.type, precision=self.dtype)
        open_clip_encoder.open_clip_model, open_clip_encoder.pretraining = self.model.value  # type: ignore
    # Keep precision aligned with the current image tensor even on re-use.
    open_clip_encoder.precision = self.image.tensor.dtype
    tensor: torch.Tensor = open_clip_encoder.encode_image(self.image)
    self.encoder = open_clip_encoder
    return tensor

@torch.no_grad
def _extract_vae(self) -> torch.Tensor:
    """Extract features using a VAE model, re-using the cached model when possible.\n
    :returns: Sampled latent tensor from the VAE's tiled encoding."""
    # getattr guard: ``encoder`` may be unset before the first extraction.
    if isinstance(getattr(self, "encoder", None), AutoencoderKL):
        vae_model = self.encoder
    else:
        import os

        # Download only the VAE subfolder of the repo named by ``self.model``.
        vae_path = snapshot_download(self.model, allow_patterns=["vae/*"])  # type: ignore
        vae_path = os.path.join(vae_path, "vae")
        vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=self.dtype).to(self.device.type)  # type: ignore DeviceLike
        self.encoder = vae_model

    vae_tensor = vae_model.tiled_encode(self.image.tensor, return_dict=False)
    tensor = vae_tensor[0].sample()
    return tensor

@property
def model_name(self) -> Enum | str | None:
"""Reveal the current model"""

return self.model

def set_model(self, model) -> None:
    """Replace the model identifier used by subsequent extractions.\n
    :param model: New model enum member or repository id."""
    self.model = model

@property
def image_file(self) -> ImageFile:
    """The ImageFile instance currently bound to this extractor."""
    bound_image = self.image
    return bound_image

def set_image_file(self, image_file: ImageFile) -> None:
    """Bind a new ImageFile and mirror its tensor's dtype and device.\n
    :param image_file: Replacement image whose tensor drives dtype/device."""
    self.image = image_file
    tensor = self.image.tensor
    self.dtype = tensor.dtype
    self.device = tensor.device

def cleanup(self) -> None: # type:ignore
"""Cleans up the model and frees GPU memory
:param model: The model instance used for feature extraction"""
Expand Down
33 changes: 33 additions & 0 deletions tvln/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
from tvln.datasets import build_datasets, process_dataset
from tvln.extract import FeatureExtractor
from tvln.batch import batch_process_images, ImageFile
from tvln.options import DeviceName


def main():
    """Extract VAE features for the synthetic and original datasets.\n
    :returns: Mapping with ``synthetic`` and ``original`` feature dicts."""
    device = DeviceName.CPU
    if torch.cuda.is_available():
        device = DeviceName.CUDA
    elif torch.mps.is_available():
        device = DeviceName.MPS

    datasets = build_datasets()
    synthetic_data = datasets["synthetic"]
    original_data = datasets["original"]

    # One extractor shared across both datasets so the VAE stays cached.
    model = "black-forest-labs/FLUX.1-dev"
    image_file = ImageFile()
    extractor = FeatureExtractor(image_file)
    extractor.set_model(model)

    synthetic_paths = [image_data["image"]["path"] for image_data in synthetic_data]  # type: ignore
    original_paths = [image_data["image"]["path"] for image_data in original_data]  # type: ignore

    # NOTE(review): ``device`` is a DeviceName enum while batch_process_images
    # annotates ``device: str`` — confirm ImageFile.as_tensor accepts the enum.
    synthetic_features = batch_process_images(image_paths=synthetic_paths, extractor=extractor, device=device)  # type: ignore
    original_features = batch_process_images(image_paths=original_paths, extractor=extractor, device=device)  # type: ignore

    # Return the extracted features instead of silently discarding them, so
    # callers (and future training code) can actually use the results.
    return {"synthetic": synthetic_features, "original": original_features}


if __name__ == "__main__":
    main()
Loading