
Commit 36d33c8

Fix examples in docstrings under nn
1 parent d3634ae commit 36d33c8

11 files changed: 115 additions & 62 deletions

torch_concepts/distributions/delta.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ class Delta(Distribution):
     mean: Returns the deterministic value.

 Examples:
+    >>> import torch
+    >>> from torch_concepts.distributions import Delta
     >>> dist = Delta(torch.tensor([1.0, 2.0, 3.0]))
     >>> sample = dist.sample()
     >>> print(sample)  # tensor([1., 2., 3.])
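With its imports in place, the example runs as a standalone doctest; the expected output follows the inline comment (Delta is deterministic, so sample() returns the value as given):

>>> import torch
>>> from torch_concepts.distributions import Delta
>>> dist = Delta(torch.tensor([1.0, 2.0, 3.0]))
>>> dist.sample()
tensor([1., 2., 3.])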

torch_concepts/nn/modules/high/base/learner.py

Lines changed: 0 additions & 10 deletions
@@ -257,16 +257,6 @@ def configure_optimizers(self):
     Union[Optimizer, dict, None]: Returns optimizer directly, or dict with
     'optimizer' and optionally 'lr_scheduler' and 'monitor' keys,
     or None if no optimizer is configured.
-
-Example:
-    >>> # With scheduler monitoring validation loss
-    >>> model = ConceptBottleneckModel(
-    ...     ...,
-    ...     optim_class=torch.optim.Adam,
-    ...     optim_kwargs={'lr': 0.001},
-    ...     scheduler_class=torch.optim.lr_scheduler.ReduceLROnPlateau,
-    ...     scheduler_kwargs={'mode': 'min', 'patience': 5, 'monitor': 'val_loss'}
-    ... )
 """
 # No optimizer configured
 if self.optim_class is None:

torch_concepts/nn/modules/low/dense_layers.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def get_layer_activation(activation):
     ValueError: If activation name is not recognized.

 Example:
+    >>> from torch_concepts.nn.modules.low.dense_layers import get_layer_activation
     >>> act_class = get_layer_activation('relu')
     >>> activation = act_class()  # ReLU()
     >>> act_class = get_layer_activation(None)
torch_concepts/nn/modules/low/lazy.py

Lines changed: 10 additions & 10 deletions
@@ -27,7 +27,7 @@ def _filter_kwargs_for_ctor(cls, **kwargs):

 Example:
     >>> import torch.nn as nn
-    >>> from torch_concepts.nn.modules.propagator import _filter_kwargs_for_ctor
+    >>> from torch_concepts.nn.modules.low.lazy import _filter_kwargs_for_ctor
     >>>
     >>> # Filter kwargs for Linear layer
     >>> kwargs = {'in_features': 10, 'out_features': 5, 'unknown_param': 42}
@@ -69,7 +69,7 @@ def instantiate_adaptive(module_cls, *args, drop_none=True, **kwargs):

 Example:
     >>> import torch.nn as nn
-    >>> from torch_concepts.nn.modules.propagator import instantiate_adaptive
+    >>> from torch_concepts.nn.modules.low.lazy import instantiate_adaptive
     >>>
     >>> # Instantiate a Linear layer with extra kwargs
     >>> kwargs = {'in_features': 10, 'out_features': 5, 'extra': None}
@@ -106,13 +106,13 @@ class LazyConstructor(torch.nn.Module):
     >>> from torch_concepts.nn import LinearCC
     >>>
     >>> # Create a propagator for a predictor
-    >>> lazy_constructorLazyConstructor(
+    >>> lazy_constructor = LazyConstructor(
     ...     LinearCC,
     ...     activation=torch.sigmoid
     ... )
     >>>
     >>> # Build the module when dimensions are known
-    >>> module = propagator.build(
+    >>> module = lazy_constructor.build(
     ...     out_features=3,
     ...     in_features_endogenous=5,
     ...     in_features=None,
@@ -121,7 +121,7 @@ class LazyConstructor(torch.nn.Module):
     >>>
     >>> # Use the module
     >>> x = torch.randn(2, 5)
-    >>> output = propagator(x)
+    >>> output = lazy_constructor(x)
     >>> print(output.shape)
     torch.Size([2, 3])
     """
@@ -227,12 +227,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:

 Example:
     >>> import torch
-    >>> from torch_concepts.nn.modules.propagator import LazyConstructor
-    >>> from torch_concepts.nn.modules.predictors.linear import LinearCC
+    >>> from torch_concepts.nn import LazyConstructor
+    >>> from torch_concepts.nn import LinearCC
     >>>
     >>> # Create and build propagator
-    >>> lazy_constructorLazyConstructor(LinearCC)
-    >>> propagator.build(
+    >>> lazy_constructor = LazyConstructor(LinearCC)
+    >>> lazy_constructor.build(
     ...     out_features=3,
     ...     in_features_endogenous=5,
     ...     in_features=None,
@@ -241,7 +241,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
     >>>
     >>> # Forward pass
     >>> x = torch.randn(2, 5)
-    >>> output = propagator(x)
+    >>> output = lazy_constructor(x)
     >>> print(output.shape)
     torch.Size([2, 3])
     """

torch_concepts/nn/modules/metrics.py

Lines changed: 16 additions & 3 deletions
@@ -36,18 +36,27 @@ class ConceptMetrics(nn.Module):
     specified concept names. Default: False.

 Example:
-    >>> from torch_concepts.nn.modules import ConceptMetrics, GroupConfig
+    >>> from torch_concepts.nn.modules.metrics import ConceptMetrics, GroupConfig
     >>> import torchmetrics
+    >>> import torch
+    >>> from torch_concepts import Annotations, AxisAnnotation
     >>>
     >>> # Three ways to specify metrics:
+    >>> concept_annotations = Annotations({1: AxisAnnotation(
+    ...     labels=['concept1', 'concept2'],
+    ...     metadata={
+    ...         'concept1': {'type': 'discrete'},
+    ...         'concept2': {'type': 'discrete'}
+    ...     },
+    ... )})
     >>> metrics = ConceptMetrics(
     ...     annotations=concept_annotations,
     ...     fn_collection=GroupConfig(
     ...         binary={
     ...             # 1. Pre-instantiated
     ...             'accuracy': torchmetrics.classification.BinaryAccuracy(),
     ...             # 2. Class + user kwargs (average='macro')
-    ...             'f1': (torchmetrics.classification.BinaryF1Score, {'average': 'macro'})
+    ...             'f1': (torchmetrics.classification.BinaryF1Score, {'multidim_average': 'global'})
     ...         },
     ...         categorical={
     ...             # 3. Class only (num_classes will be added automatically)
@@ -57,7 +66,11 @@ class ConceptMetrics(nn.Module):
     ...     summary_metrics=True,
     ...     perconcept_metrics=['concept1', 'concept2']
     ... )
-    >>>
+    >>>
+    >>> # Simulated predictions and targets
+    >>> predictions = torch.tensor([[0.8, 0.2], [0.4, 0.6]])
+    >>> targets = torch.tensor([[1, 0], [0, 1]])
+    >>>
     >>> # Update metrics during training
     >>> metrics.update(predictions, targets, split='train')
     >>>
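The f1 kwargs change is worth a note: torchmetrics' BinaryF1Score accepts multidim_average but not average, which is presumably why average='macro' was replaced (the "# 2. Class + user kwargs (average='macro')" comment above it is left unchanged by the commit). The simulated tensors line up with the two annotated concepts, one row per sample:

>>> predictions.shape == targets.shape == torch.Size([2, 2])
True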

torch_concepts/nn/modules/mid/base/model.py

Lines changed: 20 additions & 7 deletions
@@ -4,7 +4,10 @@
 This module provides the abstract base class for all concept-based models,
 defining the structure for models that use concept representations.
 """
+from typing import Union
+
 import torch
+from torch.nn import Module

 from .....annotations import Annotations
 from ...low.lazy import LazyConstructor
@@ -33,16 +36,26 @@ class BaseConstructor(torch.nn.Module):
 Example:
     >>> import torch
     >>> from torch_concepts import Annotations, AxisAnnotation
-    >>> from torch_concepts.nn import BaseModel, LazyConstructor
+    >>> from torch_concepts.nn import LazyConstructor
+    >>> from torch_concepts.nn.modules.mid.base.model import BaseConstructor
+    >>> from torch.distributions import RelaxedBernoulli
     >>>
     >>> # Create annotations for concepts
     >>> concept_labels = ('color', 'shape', 'size')
-    >>> annotations = Annotations({
-    ...     1: AxisAnnotation(labels=concept_labels)
-    ... })
+    >>> cardinalities = [1, 1, 1]
+    >>> metadata = {
+    ...     'color': {'distribution': RelaxedBernoulli},
+    ...     'shape': {'distribution': RelaxedBernoulli},
+    ...     'size': {'distribution': RelaxedBernoulli}
+    ... }
+    >>> annotations = Annotations({1: AxisAnnotation(
+    ...     labels=concept_labels,
+    ...     cardinalities=cardinalities,
+    ...     metadata=metadata
+    ... )})
     >>>
     >>> # Create a concrete model class
-    >>> class MyConceptModel(BaseModel):
+    >>> class MyConceptModel(BaseConstructor):
     ...     def __init__(self, input_size, annotations, encoder, predictor):
     ...         super().__init__(input_size, annotations, encoder, predictor)
     ...         # Build encoder and predictor
@@ -84,8 +97,8 @@ class BaseConstructor(torch.nn.Module):
 def __init__(self,
              input_size: int,
              annotations: Annotations,
-             encoder: LazyConstructor,  # layer for root concepts
-             predictor: LazyConstructor,
+             encoder: Union[LazyConstructor, Module],  # layer for root concepts
+             predictor: Union[LazyConstructor, Module],
              *args,
              **kwargs
              ):
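With the widened signature, encoder and predictor may be either LazyConstructors or ready-built Modules; a minimal sketch of both call styles, using MyConceptModel and annotations from the docstring above (layer sizes here are illustrative assumptions):

>>> import torch.nn as nn
>>> from torch_concepts.nn import LazyConstructor
>>>
>>> # Deferred: dimensions are resolved when the model builds the layer
>>> model = MyConceptModel(10, annotations, LazyConstructor(nn.Linear), LazyConstructor(nn.Linear))
>>> # Eager: pass already-built Modules directly (hypothetical sizes)
>>> model = MyConceptModel(10, annotations, nn.Linear(10, 8), nn.Linear(8, 3))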

torch_concepts/nn/modules/mid/constructors/concept_graph.py

Lines changed: 16 additions & 10 deletions
@@ -158,6 +158,8 @@ def from_sparse(cls, edge_index: Tensor, edge_weight: Tensor, n_nodes: int, node
     ConceptGraph instance

 Example:
+    >>> import torch
+    >>> from torch_concepts import ConceptGraph
     >>> edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])
     >>> edge_weight = torch.tensor([1.0, 1.0, 1.0])
     >>> graph = ConceptGraph.from_sparse(edge_index, edge_weight, n_nodes=3)
@@ -310,11 +312,6 @@ def to_networkx(self, threshold: float = 0.0) -> nx.DiGraph:

 Returns:
     nx.DiGraph: NetworkX directed graph
-
-Example:
-    >>> G = graph.to_networkx()
-    >>> list(G.nodes())
-    ['A', 'B', 'C']
 """
 # If threshold is 0.0 and we have a cache, return it
 if threshold == 0.0 and self._nx_graph_cache is not None:
@@ -357,11 +354,6 @@ def dense_to_sparse(self, threshold: float = 0.0) -> Tuple[Tensor, Tensor]:
 Returns:
     edge_index: Tensor of shape (2, num_edges) with source and target indices
     edge_weight: Tensor of shape (num_edges,) with edge weights
-
-Example:
-    >>> edge_index, edge_weight = graph.dense_to_sparse()
-    >>> edge_index.shape
-    torch.Size([2, num_edges])
 """
 if threshold > 0.0:
     # Filter edges by threshold
@@ -500,6 +492,8 @@ def dense_to_sparse(
     edge_weight: Tensor of shape (num_edges,) with edge weights

 Example:
+    >>> import torch
+    >>> from torch_concepts.nn.modules.mid.constructors.concept_graph import dense_to_sparse
     >>> adj = torch.tensor([[0., 1., 0.],
     ...                     [0., 0., 1.],
     ...                     [0., 0., 0.]])
@@ -539,6 +533,8 @@ def to_networkx_graph(
     nx.DiGraph: NetworkX directed graph

 Example:
+    >>> import torch
+    >>> from torch_concepts.nn.modules.mid.constructors.concept_graph import to_networkx_graph
     >>> adj = torch.tensor([[0., 1., 1.],
     ...                     [0., 0., 1.],
     ...                     [0., 0., 0.]])
@@ -590,6 +586,8 @@ def get_root_nodes(
     List of root node names

 Example:
+    >>> import torch
+    >>> from torch_concepts.nn.modules.mid.constructors.concept_graph import get_root_nodes
     >>> adj = torch.tensor([[0., 1., 1.],
     ...                     [0., 0., 1.],
     ...                     [0., 0., 0.]])
@@ -621,6 +619,8 @@ def get_leaf_nodes(
     List of leaf node names

 Example:
+    >>> import torch
+    >>> from torch_concepts.nn.modules.mid.constructors.concept_graph import get_leaf_nodes
     >>> adj = torch.tensor([[0., 1., 1.],
     ...                     [0., 0., 1.],
     ...                     [0., 0., 0.]])
@@ -657,6 +657,8 @@ def topological_sort(
     nx.NetworkXError: If graph contains cycles

 Example:
+    >>> import torch
+    >>> from torch_concepts.nn.modules.mid.constructors.concept_graph import topological_sort
     >>> adj = torch.tensor([[0., 1., 1.],
     ...                     [0., 0., 1.],
     ...                     [0., 0., 0.]])
@@ -692,6 +694,8 @@ def get_predecessors(
     List of predecessor node names

 Example:
+    >>> import torch
+    >>> from torch_concepts.nn.modules.mid.constructors.concept_graph import get_predecessors
     >>> adj = torch.tensor([[0., 1., 1.],
     ...                     [0., 0., 1.],
     ...                     [0., 0., 0.]])
@@ -731,6 +735,8 @@ def get_successors(
     List of successor node names

 Example:
+    >>> import torch
+    >>> from torch_concepts.nn.modules.mid.constructors.concept_graph import get_successors
     >>> adj = torch.tensor([[0., 1., 1.],
     ...                     [0., 0., 1.],
     ...                     [0., 0., 0.]])
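Most of the module-level examples now share the same three-node DAG, so they compose into one snippet; a sketch assuming, as the examples imply, that the adjacency matrix is the first positional argument (three nonzero entries, hence three edges):

>>> import torch
>>> from torch_concepts.nn.modules.mid.constructors.concept_graph import dense_to_sparse
>>> adj = torch.tensor([[0., 1., 1.],
...                     [0., 0., 1.],
...                     [0., 0., 0.]])
>>> edge_index, edge_weight = dense_to_sparse(adj)
>>> edge_index.shape  # (2, num_edges) per the docstring
torch.Size([2, 3])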

torch_concepts/nn/modules/mid/inference/forward.py

Lines changed: 8 additions & 6 deletions
@@ -712,7 +712,7 @@ class DeterministicInference(ForwardInference):
     >>> from torch.distributions import Bernoulli
     >>> from torch_concepts import InputVariable, EndogenousVariable
     >>> from torch_concepts.distributions import Delta
-    >>> from torch_concepts.nn import DeterministicInference, ParametricCPD, ProbabilisticModel
+    >>> from torch_concepts.nn import DeterministicInference, ParametricCPD, ProbabilisticModel, LinearCC
     >>>
     >>> # Create a simple PGM: latent -> A -> B
     >>> input_var = InputVariable('input', parents=[], distribution=Delta, size=10)
@@ -723,7 +723,7 @@ class DeterministicInference(ForwardInference):
     >>> from torch.nn import Identity, Linear
     >>> cpd_emb = ParametricCPD('input', parametrization=Identity())
     >>> cpd_A = ParametricCPD('A', parametrization=Linear(10, 1))
-    >>> cpd_B = ParametricCPD('B', parametrization=Linear(1, 1))
+    >>> cpd_B = ParametricCPD('B', parametrization=LinearCC(1, 1))
     >>>
     >>> # Create probabilistic model
     >>> pgm = ProbabilisticModel(
@@ -743,7 +743,7 @@ class DeterministicInference(ForwardInference):
     >>> print(results['B'].shape)  # torch.Size([4, 1]) - endogenous, not {0,1}
     >>>
     >>> # Query specific concepts - returns concatenated endogenous
-    >>> output = inference.query(['B', 'A'], evidence={'embedding': x})
+    >>> output = inference.query(['B', 'A'], evidence={'input': x})
     >>> print(output.shape)  # torch.Size([4, 2])
     >>> # output contains [logit_B, logit_A] for each sample
     >>>
@@ -794,6 +794,8 @@ class AncestralSamplingInference(ForwardInference):
     >>> from torch_concepts import InputVariable
     >>> from torch_concepts.distributions import Delta
     >>> from torch_concepts.nn import AncestralSamplingInference, ParametricCPD, ProbabilisticModel
+    >>> from torch_concepts import EndogenousVariable
+    >>> from torch_concepts.nn import LinearCC
     >>>
     >>> # Create a simple PGM: embedding -> A -> B
     >>> embedding_var = InputVariable('embedding', parents=[], distribution=Delta, size=10)
@@ -804,7 +806,7 @@ class AncestralSamplingInference(ForwardInference):
     >>> from torch.nn import Identity, Linear
     >>> cpd_emb = ParametricCPD('embedding', parametrization=Identity())
     >>> cpd_A = ParametricCPD('A', parametrization=Linear(10, 1))
-    >>> cpd_B = ParametricCPD('B', parametrization=Linear(1, 1))
+    >>> cpd_B = ParametricCPD('B', parametrization=LinearCC(1, 1))
     >>>
     >>> # Create probabilistic model
     >>> pgm = ProbabilisticModel(
@@ -838,8 +840,8 @@
     >>>
     >>> # With relaxed distributions (requires temperature)
     >>> from torch.distributions import RelaxedBernoulli
-    >>> var_A_relaxed = Variable('A', parents=['embedding'],
-    ...                          distribution=RelaxedBernoulli, size=1)
+    >>> var_A_relaxed = InputVariable('A', parents=['embedding'],
+    ...                               distribution=RelaxedBernoulli, size=1)
     >>> pgm = ProbabilisticModel(
     ...     variables=[embedding_var, var_A_relaxed, var_B],
     ...     parametric_cpds=[cpd_emb, cpd_A, cpd_B]
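One fix here is easy to miss: evidence keys must match declared variable names. In the DeterministicInference example the input variable is named 'input', so the old key 'embedding' referred to nothing in that graph. The corrected call, with x shaped to match size=10 and the batch of 4 the printed shapes imply:

>>> x = torch.randn(4, 10)
>>> output = inference.query(['B', 'A'], evidence={'input': x})
>>> output.shape  # concatenated endogenous values: [logit_B, logit_A]
torch.Size([4, 2])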

torch_concepts/nn/modules/mid/models/cpd.py

Lines changed: 2 additions & 2 deletions
@@ -60,9 +60,9 @@ class ParametricCPD(nn.Module):
     ...     parametrization=[module_a, module_b]
     ... )
     >>>
-    >>> print(cpd[0].module)
+    >>> print(cpd[0].parametrization)
     Linear(in_features=10, out_features=1, bias=True)
-    >>> print(cpd[1].module)
+    >>> print(cpd[1].parametrization)
     Sequential(...)

 Notes
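Indexed CPDs expose the wrapped module as .parametrization, not .module; a sketch with hypothetical submodules chosen to match the docstring's printed output ('concept' is an illustrative variable name):

>>> import torch.nn as nn
>>> from torch_concepts.nn import ParametricCPD
>>> module_a = nn.Linear(10, 1)
>>> module_b = nn.Sequential(nn.Linear(10, 4), nn.ReLU())  # hypothetical
>>> cpd = ParametricCPD('concept', parametrization=[module_a, module_b])
>>> cpd[0].parametrization
Linear(in_features=10, out_features=1, bias=True)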
