6 changes: 5 additions & 1 deletion .gitignore
@@ -85,4 +85,8 @@ data/
 !tests/data/
 
 # conceptarium logs
-outputs/
+outputs/
+
+CUB200/
+
+.DS_Store
19 changes: 14 additions & 5 deletions torch_concepts/data/backbone.py
@@ -12,6 +12,18 @@
 
 logger = logging.getLogger(__name__)
 
+def _collate_inputs(batch):
+    """Collate only the input images, ignoring other fields."""
+    first = batch[0]
+    if isinstance(first, dict):
+        if 'inputs' in first and isinstance(first['inputs'], dict) and 'x' in first['inputs']:
+            xs = [b['inputs']['x'] for b in batch]
+        else:
+            raise KeyError("Batch items must contain 'inputs'['x'].")
+    else:
+        xs = batch
+    return torch.stack(xs, dim=0)
+
 def compute_backbone_embs(
     dataset,
     backbone: nn.Module,
@@ -64,6 +76,7 @@ def compute_backbone_embs(
         batch_size=batch_size,
         shuffle=False, # Important: maintain order
         num_workers=workers,
+        collate_fn=_collate_inputs,
     )
 
     embeddings_list = []
@@ -73,11 +86,7 @@ def compute_backbone_embs(
     with torch.no_grad():
         iterator = tqdm(dataloader, desc="Extracting embeddings") if verbose else dataloader
         for batch in iterator:
-            # Handle both {'x': tensor} and {'inputs': {'x': tensor}} structures
-            if 'inputs' in batch:
-                x = batch['inputs']['x'].to(device)
-            else:
-                x = batch['x'].to(device)
+            x = batch.to(device) # batch already collated to only inputs
             embeddings = backbone(x) # Forward pass through backbone
             embeddings_list.append(embeddings.cpu()) # Move back to CPU and store
 
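A minimal usage sketch (not part of the diff) of what the new _collate_inputs helper does with the nested batch structure it expects. The import path follows the file touched above; the tensor shapes and the extra 'concepts' field are illustrative assumptions.

import torch
from torch_concepts.data.backbone import _collate_inputs  # module path as in this PR

# Two dataset items shaped like {'inputs': {'x': tensor}, ...}; any other fields
# (here an assumed 'concepts' entry) are ignored by the collate function.
batch = [
    {'inputs': {'x': torch.randn(3, 224, 224)}, 'concepts': torch.zeros(5)},
    {'inputs': {'x': torch.randn(3, 224, 224)}, 'concepts': torch.ones(5)},
]
x = _collate_inputs(batch)
print(x.shape)  # torch.Size([2, 3, 224, 224]) -- only the image inputs are stacked

# Items that are plain tensors (no dict wrapper) are stacked as-is.
x_plain = _collate_inputs([torch.randn(3, 32, 32), torch.randn(3, 32, 32)])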
9 changes: 2 additions & 7 deletions torch_concepts/data/base/dataset.py
@@ -42,7 +42,7 @@ class ConceptDataset(Dataset):
     Args:
         input_data: Input features as numpy array, pandas DataFrame, or Tensor.
         concepts: Concept annotations as numpy array, pandas DataFrame, or Tensor.
-        annotations: Optional Annotations object with concept metadata.
+        annotations: Optional Annotations object with concept metadata. (TODO: this can't be optional, since we need concept names in set_concepts(.))
         graph: Optional concept graph as pandas DataFrame or tensor.
         concept_names_subset: Optional list to select subset of concepts.
         precision: Numerical precision (16, 32, or 64, default: 32).
@@ -63,7 +63,7 @@ class ConceptDataset(Dataset):
     """
     def __init__(
         self,
-        input_data: Union[np.ndarray, pd.DataFrame, Tensor],
+        input_data: Union[np.ndarray, pd.DataFrame, Tensor, None],
         concepts: Union[np.ndarray, pd.DataFrame, Tensor],
         annotations: Optional[Annotations] = None,
         graph: Optional[pd.DataFrame] = None,
@@ -127,11 +127,6 @@ def __init__(
         self.maybe_reduce_annotations(annotations,
                                       concept_names_subset)
 
-        # Set dataset's input data X
-        # TODO: input is assumed to be a one of "np.ndarray, pd.DataFrame, Tensor" for now
-        # allow more complex data structures in the future with a custom parser
-        self.input_data: Tensor = parse_tensor(input_data, 'input', self.precision)
-
         # Store concept data C
         self.concepts = None
         if concepts is not None: