2 changes: 2 additions & 0 deletions conceptarium/conf/dataset/_commons.yaml
@@ -1,5 +1,7 @@
batch_size: 512

seed: ${seed}

val_size: 0.1
test_size: 0.2

18 changes: 18 additions & 0 deletions conceptarium/conf/dataset/celeba.yaml
@@ -0,0 +1,18 @@
defaults:
- _commons
- _self_

_target_: torch_concepts.data.datamodules.celeba.CelebADataModule

name: celeba

backbone: resnet18
precompute_embs: true
force_recompute: false

# Task label - which CelebA attribute to predict
task_label: [Attractive]

# all CelebA attributes are binary facial features
# label_descriptions can be added here if needed
label_descriptions: null
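
Note: a config like this would typically be consumed through Hydra's instantiate. The sketch below is illustrative only; the primary config name "config" is an assumption, not part of this PR:

from hydra import compose, initialize
from hydra.utils import instantiate

# Assumes a primary config named "config" defining a top-level `seed`
# (resolved by the ${seed} interpolation in _commons.yaml).
with initialize(config_path="conceptarium/conf", version_base=None):
    cfg = compose(config_name="config", overrides=["dataset=celeba"])
    datamodule = instantiate(cfg.dataset)  # resolves _target_ to CelebADataModule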
34 changes: 12 additions & 22 deletions torch_concepts/data/backbone.py
@@ -6,15 +6,16 @@
import os
import torch
import logging
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import get_model, get_model_weights
from tqdm import tqdm


logger = logging.getLogger(__name__)

def compute_backbone_embs(
dataset,
backbone: nn.Module,
backbone: str,
batch_size: int = 512,
workers: int = 0,
device: str = None,
@@ -28,7 +29,7 @@ def compute_backbone_embs(

Args:
dataset: Dataset with __getitem__ returning a dict with the nested 'inputs'.'x' key.
backbone (nn.Module): Feature extraction model (e.g., ResNet encoder).
backbone (str): Backbone model name for feature extraction (e.g., 'resnet18').
batch_size (int, optional): Batch size for processing. Defaults to 512.
workers (int, optional): Number of DataLoader workers. Defaults to 0.
device (str, optional): Device to use ('cpu', 'cuda', 'cuda:0', etc.).
@@ -51,12 +52,9 @@ def compute_backbone_embs(
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)

# Store original training state to restore later
was_training = backbone.training

# Move backbone to device and set to eval mode
backbone = backbone.to(device)
backbone.eval()
backbone_model = get_model(backbone, weights="DEFAULT").to(device).eval()  # "DEFAULT" resolves to the best available weights
weights = get_model_weights(backbone).DEFAULT  # the weights enum carries the matching preprocessing
preprocess = weights.transforms()

# Create dataloader
dataloader = DataLoader(
@@ -73,25 +71,17 @@
with torch.no_grad():
iterator = tqdm(dataloader, desc="Extracting embeddings") if verbose else dataloader
for batch in iterator:
# Handle both {'x': tensor} and {'inputs': {'x': tensor}} structures
if 'inputs' in batch:
x = batch['inputs']['x'].to(device)
else:
x = batch['x'].to(device)
embeddings = backbone(x) # Forward pass through backbone
x = batch['inputs']['x'].to(device)
embeddings = backbone_model(preprocess(x)) # Forward pass through backbone
embeddings_list.append(embeddings.cpu()) # Move back to CPU and store

all_embeddings = torch.cat(embeddings_list, dim=0) # Concatenate all embeddings

# Restore original training state
if was_training:
backbone.train()

return all_embeddings

def get_backbone_embs(path: str,
dataset,
backbone,
backbone: str,
batch_size,
force_recompute=False,
workers=0,
@@ -105,7 +95,7 @@ def get_backbone_embs(path: str,
Args:
path (str): File path for saving/loading embeddings (.pt file).
dataset: Dataset to extract embeddings from.
backbone: Backbone model for feature extraction.
backbone: Backbone model name for feature extraction.
batch_size: Batch size for computation.
force_recompute (bool, optional): Recompute even if cached. Defaults to False.
workers (int, optional): Number of DataLoader workers. Defaults to 0.
@@ -130,7 +120,7 @@
if not os.path.exists(path) or force_recompute:
# compute
embs = compute_backbone_embs(dataset,
backbone,
backbone=backbone,
batch_size=batch_size,
workers=workers,
device=device,
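
Taken together, these changes mean callers now pass a torchvision model name rather than an nn.Module. A minimal usage sketch (the cache path and dataset variable are placeholders):

from torch_concepts.data.backbone import get_backbone_embs

# `train_dataset` is illustrative; its batches must expose the nested
# 'inputs'.'x' key expected by compute_backbone_embs.
embs = get_backbone_embs(
    path="cache/celeba_resnet18.pt",  # embeddings cached as a .pt file
    dataset=train_dataset,
    backbone="resnet18",              # torchvision model name, not an nn.Module
    batch_size=512,
    force_recompute=False,
)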
83 changes: 83 additions & 0 deletions torch_concepts/data/datamodules/celeba.py
@@ -0,0 +1,83 @@
from ..datasets import CelebADataset

from ..base.datamodule import ConceptDataModule
from ...typing import BackboneType
from ..splitters import StandardSplitter, RandomSplitter


class CelebADataModule(ConceptDataModule):
"""DataModule for CelebA dataset.

Handles data loading, splitting, and batching for CelebA dataset
with support for concept-based learning.

Args:
seed: Random seed for reproducibility.
name: Dataset identifier (default: 'celeba').
root: Root directory for the dataset files.
val_size: Validation set size (fraction or absolute count).
test_size: Test set size (fraction or absolute count).
ftune_size: Fine-tuning set size (fraction or absolute count).
ftune_val_size: Fine-tuning validation set size (fraction or absolute count).
batch_size: Batch size for dataloaders.
backbone: Backbone model name for feature extraction (if applicable).
precompute_embs: Whether to precompute backbone embeddings.
force_recompute: Whether to recompute embeddings even if cached.
task_label: List of attributes to use as task labels.
concept_subset: Subset of concepts to use. If None, uses all concepts.
label_descriptions: Dictionary mapping concept names to descriptions.
splitter: Splitting strategy: 'standard' uses the official CelebA split; any other value falls back to a random split.
workers: Number of workers for dataloaders.
DATA_ROOT: Root directory for data storage.
"""

def __init__(
self,
seed: int, # seed for reproducibility
name: str, # dataset identifier
root: str, # root directory for dataset
val_size: int | float = 0.1,
test_size: int | float = 0.2,
ftune_size: int | float = 0.0,
ftune_val_size: int | float = 0.0,
batch_size: int = 512,
backbone: BackboneType = None,
precompute_embs: bool = True,
force_recompute: bool = False,
task_label: list | None = None,
concept_subset: list | None = None,
label_descriptions: dict | None = None,
splitter: str = "standard",
workers: int = 0,
DATA_ROOT=None,
**kwargs
):

dataset = CelebADataset(
name=name,
root=root,
transform=None,
task_label=task_label,
class_attributes=task_label,
concept_subset=concept_subset,
label_descriptions=label_descriptions
)

# select the splitting strategy
if splitter == "standard":
splitter = StandardSplitter()
else:
splitter = RandomSplitter(
val_size=val_size,
test_size=test_size
)

super().__init__(
dataset=dataset,
val_size=val_size,
test_size=test_size,
batch_size=batch_size,
backbone=backbone,
precompute_embs=precompute_embs,
force_recompute=force_recompute,
workers=workers,
splitter=splitter
)
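
For reference, a direct instantiation mirroring conf/dataset/celeba.yaml (the seed and root path are illustrative):

from torch_concepts.data.datamodules.celeba import CelebADataModule

dm = CelebADataModule(
    seed=42,                    # illustrative seed
    name="celeba",
    root="data/celeba",         # illustrative root directory
    backbone="resnet18",
    precompute_embs=True,
    force_recompute=False,
    task_label=["Attractive"],
    splitter="standard",        # official split; any other value falls back to random
)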
2 changes: 2 additions & 0 deletions torch_concepts/data/datasets/__init__.py
@@ -1,9 +1,11 @@
from .bnlearn import BnLearnDataset
from .toy import ToyDataset, CompletenessDataset
from .celeba import CelebADataset

__all__: list[str] = [
"BnLearnDataset",
"ToyDataset",
"CompletenessDataset",
"CelebADataset",
]
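
With the export registered in __all__, the dataset is importable from the package's public namespace:

from torch_concepts.data.datasets import CelebADataset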
