From 9729d304fd8df8f27b09a5d6ec0967a35cb02f67 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Tue, 25 Nov 2025 18:26:42 +0100 Subject: [PATCH 1/8] Update .gitignore --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ec98efa..3bc2a3a 100644 --- a/.gitignore +++ b/.gitignore @@ -85,4 +85,8 @@ data/ !tests/data/ # conceptarium logs -outputs/ \ No newline at end of file +outputs/ + +CUB200/ + +.DS_Store \ No newline at end of file From 1f0a162bff984c2aec66cfe81fb8008656d618f0 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Tue, 25 Nov 2025 18:30:09 +0100 Subject: [PATCH 2/8] Fix cub -- no embeddings --- torch_concepts/data/datasets/cub.py | 704 ++++++---------------------- 1 file changed, 149 insertions(+), 555 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 1b6636d..3574a0b 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -1,381 +1,15 @@ -""" -CUB-200 Dataset Loader -** THIS DATASET NEEDS TO BE DOWNLOADED BEFORE BEING ABLE TO USE THE LOADER ** - -################################################################################ -## DOWNLOAD INSTRUCTIONS -################################################################################ - -**** OPTION #1 ***** -The simplest way to get the CUB dataset, is to download the pre-processed CUB -dataset by Koh et al. [CBM Paper]. This can be downloaded from their -public colab notebook at: https://worksheets.codalab.org/worksheets/0x362911581fcd4e048ddfd84f47203fd2. -You will need to download the original CUB dataset from that notebook (found -here: https://worksheets.codalab.org/bundles/0xd013a7ba2e88481bbc07e787f73109f5) -and the preprocessed "CUB_preprocessed" dataset (which can be directly accessed -here: https://worksheets.codalab.org/bundles/0x5b9d528d2101418b87212db92fea6683) - -**** OPTION #2 ***** -Follow the download the preprocess instructions found in Koh et al.'s original -repository here: https://github.com/yewsiang/ConceptBottleneck. -Specifically, here: https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/ - -################################################################################ - -[IMPORTANT] After downloading the files, they need to follow the following -structure: - - -This loader has been adapted/inspired by that of found in Koh et al.'s -repository (https://github.com/yewsiang/ConceptBottleneck/blob/master/CUB/cub_loader.py) -as well as in Espinosa Zarlenga and Barbiero's et al.'s repository -(https://github.com/mateoespinosa/cem). -""" - - -import numpy as np import os -import pickle +import tarfile import torch -import torchvision.transforms as transforms - -from collections import defaultdict +import pandas as pd +import numpy as np +from typing import List, Optional from PIL import Image -from torch.utils.data import Dataset - -######################################################## -## GENERAL DATASET GLOBAL VARIABLES -######################################################## - -N_CLASSES = 200 - -# CAN BE OVERWRITTEN WITH AN ENV VARIABLE CUB_DIR -CUB_DIR = os.environ.get("CUB_DIR", './CUB200/') - - -######################################################### -## CONCEPT INFORMATION REGARDING CUB -######################################################### - -# CUB Class names - -CLASS_NAMES = [ - "Black_footed_Albatross", - "Laysan_Albatross", - "Sooty_Albatross", - "Groove_billed_Ani", - "Crested_Auklet", - "Least_Auklet", - "Parakeet_Auklet", - "Rhinoceros_Auklet", - "Brewer_Blackbird", - "Red_winged_Blackbird", - "Rusty_Blackbird", - "Yellow_headed_Blackbird", - "Bobolink", - "Indigo_Bunting", - "Lazuli_Bunting", - "Painted_Bunting", - "Cardinal", - "Spotted_Catbird", - "Gray_Catbird", - "Yellow_breasted_Chat", - "Eastern_Towhee", - "Chuck_will_Widow", - "Brandt_Cormorant", - "Red_faced_Cormorant", - "Pelagic_Cormorant", - "Bronzed_Cowbird", - "Shiny_Cowbird", - "Brown_Creeper", - "American_Crow", - "Fish_Crow", - "Black_billed_Cuckoo", - "Mangrove_Cuckoo", - "Yellow_billed_Cuckoo", - "Gray_crowned_Rosy_Finch", - "Purple_Finch", - "Northern_Flicker", - "Acadian_Flycatcher", - "Great_Crested_Flycatcher", - "Least_Flycatcher", - "Olive_sided_Flycatcher", - "Scissor_tailed_Flycatcher", - "Vermilion_Flycatcher", - "Yellow_bellied_Flycatcher", - "Frigatebird", - "Northern_Fulmar", - "Gadwall", - "American_Goldfinch", - "European_Goldfinch", - "Boat_tailed_Grackle", - "Eared_Grebe", - "Horned_Grebe", - "Pied_billed_Grebe", - "Western_Grebe", - "Blue_Grosbeak", - "Evening_Grosbeak", - "Pine_Grosbeak", - "Rose_breasted_Grosbeak", - "Pigeon_Guillemot", - "California_Gull", - "Glaucous_winged_Gull", - "Heermann_Gull", - "Herring_Gull", - "Ivory_Gull", - "Ring_billed_Gull", - "Slaty_backed_Gull", - "Western_Gull", - "Anna_Hummingbird", - "Ruby_throated_Hummingbird", - "Rufous_Hummingbird", - "Green_Violetear", - "Long_tailed_Jaeger", - "Pomarine_Jaeger", - "Blue_Jay", - "Florida_Jay", - "Green_Jay", - "Dark_eyed_Junco", - "Tropical_Kingbird", - "Gray_Kingbird", - "Belted_Kingfisher", - "Green_Kingfisher", - "Pied_Kingfisher", - "Ringed_Kingfisher", - "White_breasted_Kingfisher", - "Red_legged_Kittiwake", - "Horned_Lark", - "Pacific_Loon", - "Mallard", - "Western_Meadowlark", - "Hooded_Merganser", - "Red_breasted_Merganser", - "Mockingbird", - "Nighthawk", - "Clark_Nutcracker", - "White_breasted_Nuthatch", - "Baltimore_Oriole", - "Hooded_Oriole", - "Orchard_Oriole", - "Scott_Oriole", - "Ovenbird", - "Brown_Pelican", - "White_Pelican", - "Western_Wood_Pewee", - "Sayornis", - "American_Pipit", - "Whip_poor_Will", - "Horned_Puffin", - "Common_Raven", - "White_necked_Raven", - "American_Redstart", - "Geococcyx", - "Loggerhead_Shrike", - "Great_Grey_Shrike", - "Baird_Sparrow", - "Black_throated_Sparrow", - "Brewer_Sparrow", - "Chipping_Sparrow", - "Clay_colored_Sparrow", - "House_Sparrow", - "Field_Sparrow", - "Fox_Sparrow", - "Grasshopper_Sparrow", - "Harris_Sparrow", - "Henslow_Sparrow", - "Le_Conte_Sparrow", - "Lincoln_Sparrow", - "Nelson_Sharp_tailed_Sparrow", - "Savannah_Sparrow", - "Seaside_Sparrow", - "Song_Sparrow", - "Tree_Sparrow", - "Vesper_Sparrow", - "White_crowned_Sparrow", - "White_throated_Sparrow", - "Cape_Glossy_Starling", - "Bank_Swallow", - "Barn_Swallow", - "Cliff_Swallow", - "Tree_Swallow", - "Scarlet_Tanager", - "Summer_Tanager", - "Artic_Tern", - "Black_Tern", - "Caspian_Tern", - "Common_Tern", - "Elegant_Tern", - "Forsters_Tern", - "Least_Tern", - "Green_tailed_Towhee", - "Brown_Thrasher", - "Sage_Thrasher", - "Black_capped_Vireo", - "Blue_headed_Vireo", - "Philadelphia_Vireo", - "Red_eyed_Vireo", - "Warbling_Vireo", - "White_eyed_Vireo", - "Yellow_throated_Vireo", - "Bay_breasted_Warbler", - "Black_and_white_Warbler", - "Black_throated_Blue_Warbler", - "Blue_winged_Warbler", - "Canada_Warbler", - "Cape_May_Warbler", - "Cerulean_Warbler", - "Chestnut_sided_Warbler", - "Golden_winged_Warbler", - "Hooded_Warbler", - "Kentucky_Warbler", - "Magnolia_Warbler", - "Mourning_Warbler", - "Myrtle_Warbler", - "Nashville_Warbler", - "Orange_crowned_Warbler", - "Palm_Warbler", - "Pine_Warbler", - "Prairie_Warbler", - "Prothonotary_Warbler", - "Swainson_Warbler", - "Tennessee_Warbler", - "Wilson_Warbler", - "Worm_eating_Warbler", - "Yellow_Warbler", - "Northern_Waterthrush", - "Louisiana_Waterthrush", - "Bohemian_Waxwing", - "Cedar_Waxwing", - "American_Three_toed_Woodpecker", - "Pileated_Woodpecker", - "Red_bellied_Woodpecker", - "Red_cockaded_Woodpecker", - "Red_headed_Woodpecker", - "Downy_Woodpecker", - "Bewick_Wren", - "Cactus_Wren", - "Carolina_Wren", - "House_Wren", - "Marsh_Wren", - "Rock_Wren", - "Winter_Wren", - "Common_Yellowthroat", -] -# Set of CUB attributes selected by Koh et al. [CBM Paper] -SELECTED_CONCEPTS = [ - 1, - 4, - 6, - 7, - 10, - 14, - 15, - 20, - 21, - 23, - 25, - 29, - 30, - 35, - 36, - 38, - 40, - 44, - 45, - 50, - 51, - 53, - 54, - 56, - 57, - 59, - 63, - 64, - 69, - 70, - 72, - 75, - 80, - 84, - 90, - 91, - 93, - 99, - 101, - 106, - 110, - 111, - 116, - 117, - 119, - 125, - 126, - 131, - 132, - 134, - 145, - 149, - 151, - 152, - 153, - 157, - 158, - 163, - 164, - 168, - 172, - 178, - 179, - 181, - 183, - 187, - 188, - 193, - 194, - 196, - 198, - 202, - 203, - 208, - 209, - 211, - 212, - 213, - 218, - 220, - 221, - 225, - 235, - 236, - 238, - 239, - 240, - 242, - 243, - 244, - 249, - 253, - 254, - 259, - 260, - 262, - 268, - 274, - 277, - 283, - 289, - 292, - 293, - 294, - 298, - 299, - 304, - 305, - 308, - 309, - 310, - 311, -] +import torchvision.transforms as T +from torch_concepts import Annotations +from torch_concepts.annotations import AxisAnnotation +from torch_concepts.data.base import ConceptDataset +from torch_concepts.data.io import download_url # Names of all CUB attributes CONCEPT_SEMANTICS = [ @@ -693,196 +327,156 @@ "has_wing_pattern::multi-colored", ] -# Generate a mapping containing all concept groups in CUB generated -# using a simple prefix tree -CONCEPT_GROUP_MAP = defaultdict(list) -for i, concept_name in enumerate(list( - np.array(CONCEPT_SEMANTICS)[SELECTED_CONCEPTS] -)): - group = concept_name[:concept_name.find("::")] - CONCEPT_GROUP_MAP[group].append(i) - - - -# Definitions from CUB (certainties.txt) -# 1 not visible -# 2 guessing -# 3 probably -# 4 definitely -# Unc map represents a mapping from the discrete score to a "mental probability" -DEFAULT_UNC_MAP = [ - {0: 0.5, 1: 0.5, 2: 0.5, 3:0.75, 4:1.0}, - {0: 0.5, 1: 0.5, 2: 0.5, 3:0.75, 4:1.0}, -] - -########################################################## -## Helper Functions -########################################################## - - -def discrete_to_continuous_unc(unc_val, attr_label, unc_map): - ''' - Yield a continuous prob representing discrete conf val - Inspired by CBM data processing - - The selected probability should account for whether the concept is on or off - E.g., if a human is "probably" sure the concept is off - flip the prob in unc_map - ''' - unc_val = unc_val.item() - attr_label = attr_label.item() - return float(unc_map[int(attr_label)][unc_val]) - - -########################################################## -## Data Loaders -########################################################## +CUB_DIR = os.environ.get("CUB_DIR", './CUB200/') -class CUBDataset(Dataset): +class CUB(ConceptDataset): """ - TODO + The CUB dataset is a dataset of bird images with annotated attributes. + Each image is associated with a set of concept labels (attributes) and + task labels (bird species). + + Attributes: + concept_attr_names: The names of the concept labels (attributes). + task_attr_names: The names of the task labels (bird species). + root: The root directory where the dataset is stored. + split: The dataset split to use ('train' or 'test'). + uncertain_concept_labels: Whether to treat uncertain concept labels as + positive. + path_transform: A function to transform the image paths. """ + name = "cub" + n_concepts = 312 + n_tasks = 200 + + concept_attr_names: List[str] = [] + task_attr_names: List[str] = [] def __init__( self, - split='train', - uncertain_concept_labels=False, - root=CUB_DIR, - path_transform=None, - sample_transform=None, - concept_transform=None, - label_transform=None, - uncertainty_based_random_labels=False, - unc_map=DEFAULT_UNC_MAP, - selected_concepts=None, - training_augment=True, - ): - """ - TODO: Define different arguments - """ - if not (os.path.exists(root) and os.path.isdir(root)): - raise ValueError( - f'Provided CUB data directory "{root}" is not a valid or ' - f'an existing directory.' - ) - assert split in ['train', 'val', 'test'], ( - f"CUB split must be in ['train', 'val', 'test'] but got '{split}'" - ) - self.split = split - base_dir = os.path.join(root, 'class_attr_data_10') - self.pkl_file_path = os.path.join(base_dir, f'{split}.pkl') - self.name = 'CUB' - - self.data = [] - with open(self.pkl_file_path, 'rb') as f: - self.data.extend(pickle.load(f)) - image_size = 299 - if (split == 'train') and training_augment: - self.sample_transform = transforms.Compose([ - transforms.ColorJitter(brightness=32/255, saturation=(0.5, 1.5)), - transforms.RandomResizedCrop(image_size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), #implicitly divides by 255 - transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]), - sample_transform or (lambda x: x), - ]) - else: - self.sample_transform = transforms.Compose([ - transforms.CenterCrop(image_size), - transforms.ToTensor(), #implicitly divides by 255 - transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [2, 2, 2]), - sample_transform or (lambda x: x), - ]) - self.concept_transform = concept_transform or (lambda x: x) - self.label_transform = label_transform or (lambda x: x) - self.uncertain_concept_labels = uncertain_concept_labels + name : str = "cub", + precision : int = 32, + input_data : np.ndarray | pd.DataFrame | torch.Tensor = None, + concepts : np.ndarray | pd.DataFrame | torch.Tensor = None, + annotations : Annotations | None = None, + graph : pd.DataFrame | None = None, + concept_names_subset : List[str] | None = None, + root : str = CUB_DIR, + image_transform: Optional[object] = None, + ) -> None: self.root = root - self.path_transform = path_transform - self.uncertainty_based_random_labels = uncertainty_based_random_labels - self.unc_map = unc_map - if selected_concepts is None: - selected_concepts = list(range(len(SELECTED_CONCEPTS))) - self.selected_concepts = selected_concepts - self.concept_names = self.concept_attr_names = list( - np.array( - CONCEPT_SEMANTICS - )[CONCEPT_SEMANTICS][selected_concepts] + self.image_transform = image_transform or T.ToTensor() + + input_data, concepts, annotations, graph, image_paths = self.load() + + super().__init__( + name=name, + precision=precision, + input_data=input_data, + concepts=concepts, + annotations=annotations, + graph=graph, + concept_names_subset=concept_names_subset, ) - self.task_names = self.task_attr_names = CLASS_NAMES - - def __len__(self): - return len(self.data) + self.image_paths = image_paths + + @property + def raw_filenames(self) -> List[str]: + """List of raw filenames that need to be present in the raw directory + for the dataset to be considered present.""" + return [ + "CUB_200_2011/images.txt", + "CUB_200_2011/image_class_labels.txt", + "CUB_200_2011/train_test_split.txt", + "CUB_200_2011/bounding_boxes.txt", + "CUB_200_2011/classes.txt", + "CUB_200_2011/attributes/image_attribute_labels.txt", + "CUB_200_2011/attributes/class_attribute_labels_continuous.txt", + "CUB_200_2011/attributes/certainties.txt", + ] + + @property + def processed_filenames(self) -> List[str]: + """List of processed filenames that will be created during build step.""" + return [ + "cub_inputs.pt", + "cub_concepts.pt", + "cub_annotations.pt", + "cub_graph.h5", + ] + + def download(self) -> None: + """Downloads the CUB dataset if it is not already present.""" + if not os.path.exists(self.root): + os.makedirs(self.root) + + url = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1" + tgz_path = download_url(url, self.root) + + with tarfile.open(tgz_path, "r:gz") as tar: + tar.extractall(path=self.root) + os.unlink(tgz_path) + + def build(self): + self.maybe_download() + + images = pd.read_csv(self.raw_paths[0], sep=r"\s+", header=None, names=['image_id', 'path']) + image_paths = images.set_index('image_id')['path'] + image_paths = image_paths.apply(lambda p: os.path.join(self.root, "CUB_200_2011", "images", p)) + + # attribute names: use canonical order from CONCEPT_SEMANTICS (matches attr_id 1..312) + concept_names = CONCEPT_SEMANTICS + + # image_attribute_labels.txt has 6 columns; we only need is_present (col 3) + attr_labels = pd.read_csv( + self.raw_paths[5], + header=None, + names=['image_id', 'attr_id', 'is_present', 'certainty', 'time_ms', 'extra'], + usecols=[0, 1, 2], + delim_whitespace=True, + engine="python", + ) + concepts_df = attr_labels.pivot(index='image_id', columns='attr_id', values='is_present').fillna(0) + concepts_df = concepts_df.loc[image_paths.index] + concepts_tensor = torch.tensor(concepts_df.values, dtype=torch.float32) + + concept_metadata = {name: {'type': 'discrete'} for name in concept_names} + cardinalities = tuple(1 for _ in concept_names) # binary concepts + annotations = Annotations({ + 1: AxisAnnotation(labels=concept_names, + cardinalities=cardinalities, + metadata=concept_metadata) + }) + + torch.save(list(image_paths.values), self.processed_paths[0]) + torch.save(concepts_tensor, self.processed_paths[1]) + torch.save(annotations, self.processed_paths[2]) + + def load_raw(self): + self.maybe_build() + # PyTorch 2.6 switches torch.load default to weights_only=True; set False to load metadata objects + image_paths = torch.load(self.processed_paths[0], weights_only=False) + concepts = torch.load(self.processed_paths[1], weights_only=False) + annotations = torch.load(self.processed_paths[2], weights_only=False) + return image_paths, concepts, annotations, None + + def load(self): + image_paths, concepts, annotations, graph = self.load_raw() + input_indices = torch.arange(len(image_paths), dtype=torch.long) + return input_indices, concepts, annotations, graph, image_paths def __getitem__(self, idx): - img_data = self.data[idx] - img_path = img_data['img_path'] - if self.path_transform is None: - # This is needed if the dataset is downloaded from the original - # CBM paper's repository/experiment code - img_path = img_path.replace( - '/juice/scr/scr102/scr/thaonguyen/CUB_supervision/datasets/', - self.root - ) - try: - img = Image.open(img_path).convert('RGB') - except: - img_path_split = img_path.split('/') - img_path = '/'.join( - img_path_split[:2] + [self.split] + img_path_split[2:] - ) - img = Image.open(img_path).convert('RGB') - else: - img = Image.open(self.path_transform(img_path)).convert('RGB') - - class_label = self.label_transform(img_data['class_label']) - img = self.sample_transform(img) - - if self.uncertain_concept_labels: - attr_label = img_data['uncertain_attribute_label'] - else: - attr_label = img_data['attribute_label'] - attr_label = self.concept_transform( - np.array(attr_label)[self.selected_concepts] - ) + img_path = self.image_paths[idx] + image = Image.open(img_path).convert("RGB") + if self.image_transform is not None: + image = self.image_transform(image) - # We may want to randomly sample concept labels based on their provided - # annotator uncertainty - if self.uncertainty_based_random_labels: - discrete_unc_label = np.array( - img_data['attribute_certainty'] - )[self.selected_concepts] - instance_attr_label = np.array(img_data['attribute_label']) - competencies = [] - for (discrete_unc_val, hard_concept_val) in zip( - discrete_unc_label, - instance_attr_label, - ): - competencies.append( - discrete_to_continuous_unc( - discrete_unc_val, - hard_concept_val, - self.unc_map, - ) - ) - attr_label = np.random.binomial(1, competencies) + concepts = self.concepts[idx] + sample = { + 'inputs': {'x': image}, + 'concepts': {'c': concepts}, + } + return sample - return img, torch.FloatTensor(attr_label), class_label - def concept_weights(self): - """ - Calculate class imbalance ratio for binary attribute labels - """ - imbalance_ratio = [] - with open(self.pkl_file_path, 'rb') as f: - data = pickle.load(f) - n = len(data) - n_attr = len(data[0]['attribute_label']) - n_ones = [0] * n_attr - total = [n] * n_attr - for d in data: - labels = d['attribute_label'] - for i in range(n_attr): - n_ones[i] += labels[i] - for j in range(len(n_ones)): - imbalance_ratio.append(total[j]/n_ones[j] - 1) - return np.array(imbalance_ratio)[self.selected_concepts] \ No newline at end of file +# test +cub = CUB() From 364356a53a2e7ab4f12c93949cd5bfacf6f08079 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Thu, 27 Nov 2025 12:51:11 +0100 Subject: [PATCH 3/8] Fix backbone + add embeddings computation in cub --- torch_concepts/data/backbone.py | 19 +++-- torch_concepts/data/base/dataset.py | 2 +- torch_concepts/data/datasets/cub.py | 114 +++++++++++++++++----------- 3 files changed, 84 insertions(+), 51 deletions(-) diff --git a/torch_concepts/data/backbone.py b/torch_concepts/data/backbone.py index 86ec3c5..d73bb52 100644 --- a/torch_concepts/data/backbone.py +++ b/torch_concepts/data/backbone.py @@ -12,6 +12,18 @@ logger = logging.getLogger(__name__) +def _collate_inputs(batch): + """Collate only the input images, ignoring other fields.""" + first = batch[0] + if isinstance(first, dict): + if 'inputs' in first and isinstance(first['inputs'], dict) and 'x' in first['inputs']: + xs = [b['inputs']['x'] for b in batch] + else: + raise KeyError("Batch items must contain 'inputs'['x'].") + else: + xs = batch + return torch.stack(xs, dim=0) + def compute_backbone_embs( dataset, backbone: nn.Module, @@ -64,6 +76,7 @@ def compute_backbone_embs( batch_size=batch_size, shuffle=False, # Important: maintain order num_workers=workers, + collate_fn=_collate_inputs, ) embeddings_list = [] @@ -73,11 +86,7 @@ def compute_backbone_embs( with torch.no_grad(): iterator = tqdm(dataloader, desc="Extracting embeddings") if verbose else dataloader for batch in iterator: - # Handle both {'x': tensor} and {'inputs': {'x': tensor}} structures - if 'inputs' in batch: - x = batch['inputs']['x'].to(device) - else: - x = batch['x'].to(device) + x = batch.to(device) # batch already collated to only inputs embeddings = backbone(x) # Forward pass through backbone embeddings_list.append(embeddings.cpu()) # Move back to CPU and store diff --git a/torch_concepts/data/base/dataset.py b/torch_concepts/data/base/dataset.py index d67b6f8..43f9ed8 100644 --- a/torch_concepts/data/base/dataset.py +++ b/torch_concepts/data/base/dataset.py @@ -42,7 +42,7 @@ class ConceptDataset(Dataset): Args: input_data: Input features as numpy array, pandas DataFrame, or Tensor. concepts: Concept annotations as numpy array, pandas DataFrame, or Tensor. - annotations: Optional Annotations object with concept metadata. + annotations: Optional Annotations object with concept metadata. (TODO: this can't be optional, since we need concept names in set_concepts(.)) graph: Optional concept graph as pandas DataFrame or tensor. concept_names_subset: Optional list to select subset of concepts. precision: Numerical precision (16, 32, or 64, default: 32). diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 3574a0b..f673981 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -3,13 +3,15 @@ import torch import pandas as pd import numpy as np -from typing import List, Optional -from PIL import Image +from typing import List, Dict +from PIL import Image, ImageFile import torchvision.transforms as T from torch_concepts import Annotations from torch_concepts.annotations import AxisAnnotation from torch_concepts.data.base import ConceptDataset from torch_concepts.data.io import download_url +from torch_concepts.data.backbone import compute_backbone_embs +from torchvision.models import resnet18 # Names of all CUB attributes CONCEPT_SEMANTICS = [ @@ -328,16 +330,15 @@ ] CUB_DIR = os.environ.get("CUB_DIR", './CUB200/') +ImageFile.LOAD_TRUNCATED_IMAGES = True -class CUB(ConceptDataset): +class CUBDataset(ConceptDataset): """ The CUB dataset is a dataset of bird images with annotated attributes. Each image is associated with a set of concept labels (attributes) and task labels (bird species). Attributes: - concept_attr_names: The names of the concept labels (attributes). - task_attr_names: The names of the task labels (bird species). root: The root directory where the dataset is stored. split: The dataset split to use ('train' or 'test'). uncertain_concept_labels: Whether to treat uncertain concept labels as @@ -348,36 +349,29 @@ class CUB(ConceptDataset): n_concepts = 312 n_tasks = 200 - concept_attr_names: List[str] = [] - task_attr_names: List[str] = [] - def __init__( self, - name : str = "cub", precision : int = 32, - input_data : np.ndarray | pd.DataFrame | torch.Tensor = None, concepts : np.ndarray | pd.DataFrame | torch.Tensor = None, annotations : Annotations | None = None, - graph : pd.DataFrame | None = None, concept_names_subset : List[str] | None = None, root : str = CUB_DIR, - image_transform: Optional[object] = None, + image_transform: object | None = None, ) -> None: self.root = root - self.image_transform = image_transform or T.ToTensor() - - input_data, concepts, annotations, graph, image_paths = self.load() - + # ensure images have consistent size for batching + self.image_transform = image_transform or T.Compose([T.Resize((256, 256)), T.ToTensor()]) + + embeddings, concepts, annotations, graph = self.load() + super().__init__( - name=name, precision=precision, - input_data=input_data, + input_data=embeddings, concepts=concepts, annotations=annotations, graph=graph, concept_names_subset=concept_names_subset, ) - self.image_paths = image_paths @property def raw_filenames(self) -> List[str]: @@ -398,10 +392,9 @@ def raw_filenames(self) -> List[str]: def processed_filenames(self) -> List[str]: """List of processed filenames that will be created during build step.""" return [ - "cub_inputs.pt", "cub_concepts.pt", "cub_annotations.pt", - "cub_graph.h5", + "cub_embeddings.pt", ] def download(self) -> None: @@ -415,18 +408,23 @@ def download(self) -> None: with tarfile.open(tgz_path, "r:gz") as tar: tar.extractall(path=self.root) os.unlink(tgz_path) - + def build(self): self.maybe_download() + + # workaround to get self.n_samples() work in ConceptDataset. We will overwrite later in super().__init__() + # create a torch tensor with shape (n_samples, whatever) and set self.input_data to it temporarily + temp_input_data = torch.zeros((11788, 10)) # CUB has 11788 samples + self.input_data = temp_input_data - images = pd.read_csv(self.raw_paths[0], sep=r"\s+", header=None, names=['image_id', 'path']) - image_paths = images.set_index('image_id')['path'] - image_paths = image_paths.apply(lambda p: os.path.join(self.root, "CUB_200_2011", "images", p)) - - # attribute names: use canonical order from CONCEPT_SEMANTICS (matches attr_id 1..312) + images = pd.read_csv( + self.raw_paths[0], + sep=r"\s+", + header=None, + names=["image_id", "path"], + ) concept_names = CONCEPT_SEMANTICS - # image_attribute_labels.txt has 6 columns; we only need is_present (col 3) attr_labels = pd.read_csv( self.raw_paths[5], header=None, @@ -436,7 +434,7 @@ def build(self): engine="python", ) concepts_df = attr_labels.pivot(index='image_id', columns='attr_id', values='is_present').fillna(0) - concepts_df = concepts_df.loc[image_paths.index] + concepts_df = concepts_df.loc[images["image_id"]] concepts_tensor = torch.tensor(concepts_df.values, dtype=torch.float32) concept_metadata = {name: {'type': 'discrete'} for name in concept_names} @@ -447,30 +445,55 @@ def build(self): metadata=concept_metadata) }) - torch.save(list(image_paths.values), self.processed_paths[0]) - torch.save(concepts_tensor, self.processed_paths[1]) - torch.save(annotations, self.processed_paths[2]) + torch.save(concepts_tensor, self.processed_paths[0]) + torch.save(annotations, self.processed_paths[1]) + + annotations = torch.load(self.processed_paths[1], weights_only=False) + self._annotations = annotations + self.maybe_reduce_annotations(annotations, None) + concepts = torch.load(self.processed_paths[0], weights_only=False) + # temporary placeholder so set_concepts has a length reference + self.input_data = torch.zeros((concepts.shape[0], 1)) + self.precision = 32 # set precision before calling set_concepts + self.set_concepts(concepts) + + # Compute embeddings using a pretrained model (e.g., ResNet) as backbone from torch_concepts.data.backbone + backbone = torch.nn.Sequential(*list(resnet18(pretrained=True).children())[:-1]) + embeddings = compute_backbone_embs( + self, + backbone, + batch_size=64, + workers=4, + verbose=True + ) + + torch.save(embeddings, self.processed_paths[2]) def load_raw(self): self.maybe_build() - # PyTorch 2.6 switches torch.load default to weights_only=True; set False to load metadata objects - image_paths = torch.load(self.processed_paths[0], weights_only=False) - concepts = torch.load(self.processed_paths[1], weights_only=False) - annotations = torch.load(self.processed_paths[2], weights_only=False) - return image_paths, concepts, annotations, None + concepts = torch.load(self.processed_paths[0], weights_only=False) + annotations = torch.load(self.processed_paths[1], weights_only=False) + embeddings = torch.load(self.processed_paths[2], weights_only=False) + return embeddings, concepts, annotations, None def load(self): - image_paths, concepts, annotations, graph = self.load_raw() - input_indices = torch.arange(len(image_paths), dtype=torch.long) - return input_indices, concepts, annotations, graph, image_paths + embeddings, concepts, annotations, graph = self.load_raw() + return embeddings, concepts, annotations, graph - def __getitem__(self, idx): - img_path = self.image_paths[idx] + def __getitem__(self, idx) -> Dict[str, Dict[str, torch.Tensor]]: + img_rel_path = pd.read_csv( # TODO: optimize by reading this once in __init__ + self.raw_paths[0], + header=None, + names=['image_id', 'img_path'], + delim_whitespace=True, + engine="python", + ).set_index('image_id').loc[idx + 1, 'img_path'] # idx +1 because image_id starts from 1 + img_path = os.path.join(self.root, "CUB_200_2011/images", img_rel_path) image = Image.open(img_path).convert("RGB") if self.image_transform is not None: image = self.image_transform(image) - concepts = self.concepts[idx] + concepts = self.concepts[idx].clone() sample = { 'inputs': {'x': image}, 'concepts': {'c': concepts}, @@ -478,5 +501,6 @@ def __getitem__(self, idx): return sample -# test -cub = CUB() +if __name__ == "__main__": + dataset = CUBDataset() + print(f"Dataset loaded with {dataset.n_samples} samples.") From 1b6ed3522752ece0c4b7ecced1b174bd71d0b133 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Thu, 27 Nov 2025 12:51:41 +0100 Subject: [PATCH 4/8] Remove test in cub --- torch_concepts/data/datasets/cub.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index f673981..465a981 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -499,8 +499,3 @@ def __getitem__(self, idx) -> Dict[str, Dict[str, torch.Tensor]]: 'concepts': {'c': concepts}, } return sample - - -if __name__ == "__main__": - dataset = CUBDataset() - print(f"Dataset loaded with {dataset.n_samples} samples.") From 6db134554bd1b2b118648510442af3a0add65ca3 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Thu, 27 Nov 2025 18:28:19 +0100 Subject: [PATCH 5/8] Add typing annotations --- torch_concepts/data/datasets/cub.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 465a981..eeeb522 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -3,7 +3,7 @@ import torch import pandas as pd import numpy as np -from typing import List, Dict +from typing import List, Dict, Tuple from PIL import Image, ImageFile import torchvision.transforms as T from torch_concepts import Annotations @@ -409,7 +409,7 @@ def download(self) -> None: tar.extractall(path=self.root) os.unlink(tgz_path) - def build(self): + def build(self) -> None: self.maybe_download() # workaround to get self.n_samples() work in ConceptDataset. We will overwrite later in super().__init__() @@ -469,18 +469,18 @@ def build(self): torch.save(embeddings, self.processed_paths[2]) - def load_raw(self): + def load_raw(self) -> Tuple[torch.Tensor, pd.DataFrame, Annotations, None]: self.maybe_build() concepts = torch.load(self.processed_paths[0], weights_only=False) annotations = torch.load(self.processed_paths[1], weights_only=False) embeddings = torch.load(self.processed_paths[2], weights_only=False) return embeddings, concepts, annotations, None - def load(self): + def load(self) -> Tuple[torch.Tensor, pd.DataFrame, Annotations, None]: embeddings, concepts, annotations, graph = self.load_raw() return embeddings, concepts, annotations, graph - def __getitem__(self, idx) -> Dict[str, Dict[str, torch.Tensor]]: + def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]: img_rel_path = pd.read_csv( # TODO: optimize by reading this once in __init__ self.raw_paths[0], header=None, From 26091a1c081f2640d526855fb02446c018047fb1 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Thu, 27 Nov 2025 18:32:44 +0100 Subject: [PATCH 6/8] Fix style --- torch_concepts/data/datasets/cub.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index eeeb522..7810518 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -351,11 +351,11 @@ class CUBDataset(ConceptDataset): def __init__( self, - precision : int = 32, - concepts : np.ndarray | pd.DataFrame | torch.Tensor = None, - annotations : Annotations | None = None, - concept_names_subset : List[str] | None = None, - root : str = CUB_DIR, + precision: int = 32, + concepts: np.ndarray | pd.DataFrame | torch.Tensor = None, + annotations: Annotations | None = None, + concept_names_subset: List[str] | None = None, + root: str = CUB_DIR, image_transform: object | None = None, ) -> None: self.root = root From e53a66244e4cdaa7d6421d45e995bd0d2ed03081 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Fri, 28 Nov 2025 15:43:34 +0100 Subject: [PATCH 7/8] Remove input_data from cub and superclass --- torch_concepts/data/base/dataset.py | 7 +------ torch_concepts/data/datasets/cub.py | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/torch_concepts/data/base/dataset.py b/torch_concepts/data/base/dataset.py index 43f9ed8..5ed23da 100644 --- a/torch_concepts/data/base/dataset.py +++ b/torch_concepts/data/base/dataset.py @@ -63,7 +63,7 @@ class ConceptDataset(Dataset): """ def __init__( self, - input_data: Union[np.ndarray, pd.DataFrame, Tensor], + input_data: Union[np.ndarray, pd.DataFrame, Tensor, None], concepts: Union[np.ndarray, pd.DataFrame, Tensor], annotations: Optional[Annotations] = None, graph: Optional[pd.DataFrame] = None, @@ -127,11 +127,6 @@ def __init__( self.maybe_reduce_annotations(annotations, concept_names_subset) - # Set dataset's input data X - # TODO: input is assumed to be a one of "np.ndarray, pd.DataFrame, Tensor" for now - # allow more complex data structures in the future with a custom parser - self.input_data: Tensor = parse_tensor(input_data, 'input', self.precision) - # Store concept data C self.concepts = None if concepts is not None: diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 7810518..7cb37f7 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -366,7 +366,6 @@ def __init__( super().__init__( precision=precision, - input_data=embeddings, concepts=concepts, annotations=annotations, graph=graph, From 4e3fda89c76a1a8304fca52f5fbcc0bb9c762377 Mon Sep 17 00:00:00 2001 From: edogab33 Date: Fri, 28 Nov 2025 19:10:09 +0100 Subject: [PATCH 8/8] Remove embedding computation in CUB's build(.) --- torch_concepts/data/datasets/cub.py | 47 +++++++---------------------- 1 file changed, 11 insertions(+), 36 deletions(-) diff --git a/torch_concepts/data/datasets/cub.py b/torch_concepts/data/datasets/cub.py index 7cb37f7..2e60902 100644 --- a/torch_concepts/data/datasets/cub.py +++ b/torch_concepts/data/datasets/cub.py @@ -408,22 +408,17 @@ def download(self) -> None: tar.extractall(path=self.root) os.unlink(tgz_path) - def build(self) -> None: + def build(self): self.maybe_download() - - # workaround to get self.n_samples() work in ConceptDataset. We will overwrite later in super().__init__() - # create a torch tensor with shape (n_samples, whatever) and set self.input_data to it temporarily - temp_input_data = torch.zeros((11788, 10)) # CUB has 11788 samples - self.input_data = temp_input_data - images = pd.read_csv( - self.raw_paths[0], - sep=r"\s+", - header=None, - names=["image_id", "path"], - ) + images = pd.read_csv(self.raw_paths[0], sep=r"\s+", header=None, names=['image_id', 'path']) + image_paths = images.set_index('image_id')['path'] + image_paths = image_paths.apply(lambda p: os.path.join(self.root, "CUB_200_2011", "images", p)) + + # attribute names: use canonical order from CONCEPT_SEMANTICS (matches attr_id 1..312) concept_names = CONCEPT_SEMANTICS + # image_attribute_labels.txt has 6 columns; we only need is_present (col 3) attr_labels = pd.read_csv( self.raw_paths[5], header=None, @@ -433,7 +428,7 @@ def build(self) -> None: engine="python", ) concepts_df = attr_labels.pivot(index='image_id', columns='attr_id', values='is_present').fillna(0) - concepts_df = concepts_df.loc[images["image_id"]] + concepts_df = concepts_df.loc[image_paths.index] concepts_tensor = torch.tensor(concepts_df.values, dtype=torch.float32) concept_metadata = {name: {'type': 'discrete'} for name in concept_names} @@ -444,30 +439,10 @@ def build(self) -> None: metadata=concept_metadata) }) - torch.save(concepts_tensor, self.processed_paths[0]) - torch.save(annotations, self.processed_paths[1]) - - annotations = torch.load(self.processed_paths[1], weights_only=False) - self._annotations = annotations - self.maybe_reduce_annotations(annotations, None) - concepts = torch.load(self.processed_paths[0], weights_only=False) - # temporary placeholder so set_concepts has a length reference - self.input_data = torch.zeros((concepts.shape[0], 1)) - self.precision = 32 # set precision before calling set_concepts - self.set_concepts(concepts) - - # Compute embeddings using a pretrained model (e.g., ResNet) as backbone from torch_concepts.data.backbone - backbone = torch.nn.Sequential(*list(resnet18(pretrained=True).children())[:-1]) - embeddings = compute_backbone_embs( - self, - backbone, - batch_size=64, - workers=4, - verbose=True - ) + torch.save(list(image_paths.values), self.processed_paths[0]) + torch.save(concepts_tensor, self.processed_paths[1]) + torch.save(annotations, self.processed_paths[2]) - torch.save(embeddings, self.processed_paths[2]) - def load_raw(self) -> Tuple[torch.Tensor, pd.DataFrame, Annotations, None]: self.maybe_build() concepts = torch.load(self.processed_paths[0], weights_only=False)