From c20f18f61f02566ecc9e749c5e3569488798ac04 Mon Sep 17 00:00:00 2001
From: AnFreTh
Date: Wed, 19 Mar 2025 13:13:43 +0100
Subject: [PATCH 1/7] include simple utils

---
 mambular/arch_utils/simple_utils.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 mambular/arch_utils/simple_utils.py

diff --git a/mambular/arch_utils/simple_utils.py b/mambular/arch_utils/simple_utils.py
new file mode 100644
index 00000000..ba1067a8
--- /dev/null
+++ b/mambular/arch_utils/simple_utils.py
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+
+
+class MLP_Block(nn.Module):
+    """BatchNorm -> Linear -> ReLU -> Dropout -> Linear block mapping back to its input width."""
+
+    def __init__(self, d_in: int, d: int, dropout: float):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.BatchNorm1d(d_in),
+            nn.Linear(d_in, d),
+            nn.ReLU(inplace=True),
+            nn.Dropout(dropout),
+            nn.Linear(d, d_in),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.block(x)
+
+
+def make_random_batches(train_size: int, batch_size: int, device=None):
+    """Split a random permutation of range(train_size) into index batches."""
+    permutation = torch.randperm(train_size, device=device)
+    batches = permutation.split(batch_size)
+
+    # Sanity check: every index appears exactly once across all batches.
+    assert torch.equal(
+        torch.arange(train_size, device=device), permutation.sort().values
+    )
+    return batches
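Note on patch 1: both helpers are self-contained. A minimal usage sketch (illustrative only; the data and sizes here are invented for the example and are not part of the patch):

    import torch
    from mambular.arch_utils.simple_utils import MLP_Block, make_random_batches

    data = torch.randn(100, 16)                    # (n_samples, d_in)
    block = MLP_Block(d_in=16, d=64, dropout=0.1)

    # Indices come back shuffled; the last batch may be smaller (100 = 3 * 32 + 4).
    for idx in make_random_batches(train_size=100, batch_size=32):
        out = block(data[idx])                     # output keeps input width: (len(idx), 16)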
""" @@ -194,6 +227,14 @@ def training_step(self, batch, batch_idx): # type: ignore if hasattr(self.estimator, "penalty_forward"): preds, penalty = self.estimator.penalty_forward(*data) loss = self.compute_loss(preds, labels) + penalty + elif hasattr(self.estimator, "uses_nca_candidates"): + preds = self.estimator.nca_train( + *data, + targets=labels, + candidate_x=self.train_features, + candidate_y=self.train_targets, + ) + loss = self.compute_loss(preds, labels) else: preds = self(*data) loss = self.compute_loss(preds, labels) @@ -234,7 +275,12 @@ def validation_step(self, batch, batch_idx): # type: ignore """ data, labels = batch - preds = self(*data) + if hasattr(self.estimator, "nca_validate") and self.train_features is not None: + preds = self.estimator.nca_validate( + *data, candidate_x=self.train_features, candidate_y=self.train_targets + ) + else: + preds = self(*data) val_loss = self.compute_loss(preds, labels) self.log( @@ -276,7 +322,12 @@ def test_step(self, batch, batch_idx): # type: ignore Test loss. """ data, labels = batch - preds = self(*data) + if hasattr(self.estimator, "nca_predict") and self.train_features is not None: + preds = self.estimator.nca_predict( + *data, candidates_x=self.train_features, candidates_y=self.train_targets + ) + else: + preds = self(*data) test_loss = self.compute_loss(preds, labels) self.log( @@ -305,8 +356,14 @@ def predict_step(self, batch, batch_idx): Tensor Predictions. """ - - preds = self(*batch) + if hasattr(self.estimator, "nca_predict") and self.train_features is not None: + preds = self.estimator.nca_predict( + *batch, + candidate_x=self.train_features, + candidate_y=self.train_targets, + ) + else: + preds = self(*batch) return preds @@ -425,7 +482,7 @@ def pretrain_embeddings( temperature=0.1, save_path="pretrained_embeddings.pth", regression=True, - lr=1e-04 + lr=1e-04, ): """Pretrain embeddings before full model training. 
From 6ffb6b2b6852e98dd4c860ae280b05f8bc1f0272 Mon Sep 17 00:00:00 2001
From: AnFreTh
Date: Wed, 19 Mar 2025 13:14:06 +0100
Subject: [PATCH 3/7] add modernnca to basemodels

---
 mambular/base_models/__init__.py   |   2 +
 mambular/base_models/modern_nca.py | 216 +++++++++++++++++++++++++++++
 2 files changed, 218 insertions(+)
 create mode 100644 mambular/base_models/modern_nca.py

diff --git a/mambular/base_models/__init__.py b/mambular/base_models/__init__.py
index 06809938..3411d9be 100644
--- a/mambular/base_models/__init__.py
+++ b/mambular/base_models/__init__.py
@@ -14,8 +14,10 @@
 from .trompt import Trompt
 from .enode import ENODE
 from .tangos import Tangos
+from .modern_nca import ModernNCA
 
 __all__ = [
+    "ModernNCA",
     "Tangos",
     "ENODE",
     "Trompt",
diff --git a/mambular/base_models/modern_nca.py b/mambular/base_models/modern_nca.py
new file mode 100644
index 00000000..a2aea324
--- /dev/null
+++ b/mambular/base_models/modern_nca.py
@@ -0,0 +1,216 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ..utils.get_feature_dimensions import get_feature_dimensions
+from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
+from ..arch_utils.mlp_utils import MLPhead
+from ..configs.modernnca_config import DefaultModernNCAConfig
+from .utils.basemodel import BaseModel
+
+
+class ModernNCA(BaseModel):
+    def __init__(
+        self,
+        feature_information: tuple,
+        num_classes=1,
+        config: DefaultModernNCAConfig = DefaultModernNCAConfig(),  # noqa: B008
+        **kwargs,
+    ):
+        super().__init__(config=config, **kwargs)
+        self.save_hyperparameters(ignore=["feature_information"])
+
+        self.returns_ensemble = False
+        self.uses_nca_candidates = True
+
+        self.T = config.temperature
+        self.sample_rate = config.sample_rate
+        if self.hparams.use_embeddings:
+            self.embedding_layer = EmbeddingLayer(
+                *feature_information,
+                config=config,
+            )
+            input_dim = np.sum(
+                [len(info) * self.hparams.d_model for info in feature_information]
+            )
+        else:
+            input_dim = get_feature_dimensions(*feature_information)
+
+        self.encoder = nn.Linear(input_dim, config.dim)
+
+        if config.n_blocks > 0:
+            self.post_encoder = nn.Sequential(
+                *[self.make_layer(config) for _ in range(config.n_blocks)],
+                nn.BatchNorm1d(config.dim),
+            )
+
+        self.tabular_head = MLPhead(
+            input_dim=config.dim,
+            config=config,
+            output_dim=num_classes,
+        )
+
+        self.hparams.num_classes = num_classes
+
+    def make_layer(self, config):
+        return nn.Sequential(
+            nn.BatchNorm1d(config.dim),
+            nn.Linear(config.dim, config.d_block),
+            nn.ReLU(inplace=True),
+            nn.Dropout(config.dropout),
+            nn.Linear(config.d_block, config.dim),
+        )
+
+    def forward(self, *data):
+        """Standard forward pass without candidate selection (for baseline compatibility)."""
+        if self.hparams.use_embeddings:
+            x = self.embedding_layer(*data)
+            B, S, D = x.shape
+            x = x.reshape(B, S * D)
+        else:
+            x = torch.cat([t for tensors in data for t in tensors], dim=1)
+        x = self.encoder(x)
+        if hasattr(self, "post_encoder"):
+            x = self.post_encoder(x)
+        return self.tabular_head(x)
+
+    def nca_train(self, *data, targets, candidate_x, candidate_y):
+        """NCA-style training forward pass with candidate subsampling."""
+        if self.hparams.use_embeddings:
+            x = self.embedding_layer(*data)
+            B, S, D = x.shape
+            x = x.reshape(B, S * D)
+            candidate_x = self.embedding_layer(*candidate_x)
+            B, S, D = candidate_x.shape
+            candidate_x = candidate_x.reshape(B, S * D)
+        else:
+            x = torch.cat([t for tensors in data for t in tensors], dim=1)
+            candidate_x = torch.cat(
+                [t for tensors in candidate_x for t in tensors], dim=1
+            )
+
+        # Encode input and candidates
+        x = self.encoder(x)
+        candidate_x = self.encoder(candidate_x)
+
+        if hasattr(self, "post_encoder"):
+            x = self.post_encoder(x)
+            candidate_x = self.post_encoder(candidate_x)
+
+        # Select a random subset of candidates
+        data_size = candidate_x.shape[0]
+        retrieval_size = int(data_size * self.sample_rate)
+        sample_idx = torch.randperm(data_size, device=candidate_x.device)[:retrieval_size]
+        candidate_x = candidate_x[sample_idx]
+        candidate_y = candidate_y[sample_idx]
+
+        # Concatenate with training batch
+        candidate_x = torch.cat([x, candidate_x], dim=0)
+        candidate_y = torch.cat([targets, candidate_y], dim=0)
+
+        # One-hot encode if classification
+        if self.hparams.num_classes > 1:
+            candidate_y = F.one_hot(
+                candidate_y, num_classes=self.hparams.num_classes
+            ).to(x.dtype)
+        elif len(candidate_y.shape) == 1:
+            candidate_y = candidate_y.unsqueeze(-1)
+
+        # Compute distances
+        distances = torch.cdist(x, candidate_x, p=2) / self.T
+        # The batch itself sits at the front of the candidates, so the diagonal
+        # holds self-distances; masking it stops samples retrieving themselves.
+        distances = distances.fill_diagonal_(torch.inf)
+        distances = F.softmax(-distances, dim=-1)
+        logits = torch.mm(distances, candidate_y)
+        eps = 1e-7
+        if self.hparams.num_classes > 1:
+            logits = torch.log(logits + eps)
+
+        return logits
+
+    def nca_validate(self, *data, candidate_x, candidate_y):
+        """Validation forward pass with NCA-style candidate selection."""
+        if self.hparams.use_embeddings:
+            x = self.embedding_layer(*data)
+            B, S, D = x.shape
+            x = x.reshape(B, S * D)
+            candidate_x = self.embedding_layer(*candidate_x)
+            B, S, D = candidate_x.shape
+            candidate_x = candidate_x.reshape(B, S * D)
+        else:
+            x = torch.cat([t for tensors in data for t in tensors], dim=1)
+            candidate_x = torch.cat(
+                [t for tensors in candidate_x for t in tensors], dim=1
+            )
+
+        # Encode input and candidates
+        x = self.encoder(x)
+        candidate_x = self.encoder(candidate_x)
+
+        if hasattr(self, "post_encoder"):
+            x = self.post_encoder(x)
+            candidate_x = self.post_encoder(candidate_x)
+
+        # One-hot encode if classification
+        if self.hparams.num_classes > 1:
+            candidate_y = F.one_hot(
+                candidate_y, num_classes=self.hparams.num_classes
+            ).to(x.dtype)
+        elif len(candidate_y.shape) == 1:
+            candidate_y = candidate_y.unsqueeze(-1)
+
+        # Compute distances
+        distances = torch.cdist(x, candidate_x, p=2) / self.T
+        distances = F.softmax(-distances, dim=-1)
+
+        # Compute logits
+        logits = torch.mm(distances, candidate_y)
+        eps = 1e-7
+        if self.hparams.num_classes > 1:
+            logits = torch.log(logits + eps)
+
+        return logits
+
+    def nca_predict(self, *data, candidate_x, candidate_y):
+        """Prediction forward pass with candidate selection."""
+        if self.hparams.use_embeddings:
+            x = self.embedding_layer(*data)
+            B, S, D = x.shape
+            x = x.reshape(B, S * D)
+            candidate_x = self.embedding_layer(*candidate_x)
+            B, S, D = candidate_x.shape
+            candidate_x = candidate_x.reshape(B, S * D)
+        else:
+            x = torch.cat([t for tensors in data for t in tensors], dim=1)
+            candidate_x = torch.cat(
+                [t for tensors in candidate_x for t in tensors], dim=1
+            )
+
+        # Encode input and candidates
+        x = self.encoder(x)
+        candidate_x = self.encoder(candidate_x)
+
+        if hasattr(self, "post_encoder"):
+            x = self.post_encoder(x)
+            candidate_x = self.post_encoder(candidate_x)
+
+        # One-hot encode if classification
+        if self.hparams.num_classes > 1:
+            candidate_y = F.one_hot(
+                candidate_y, num_classes=self.hparams.num_classes
+            ).to(x.dtype)
+        elif len(candidate_y.shape) == 1:
+            candidate_y = candidate_y.unsqueeze(-1)
+
+        # Compute distances
+        distances = torch.cdist(x, candidate_x, p=2) / self.T
+        distances = F.softmax(-distances, dim=-1)
+
+        # Compute logits
+        logits = torch.mm(distances, candidate_y)
+        eps = 1e-7
+        if self.hparams.num_classes > 1:
+            logits = torch.log(logits + eps)
+
+        return logits
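Note on patch 3: at its core the model is a soft k-nearest-neighbour average in the learned embedding space. A standalone sketch of the retrieval math (toy shapes, no mambular dependencies; numbers are invented for illustration):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)        # 4 encoded queries
    cand_x = torch.randn(10, 8)  # 10 encoded training candidates
    cand_y = F.one_hot(torch.randint(0, 3, (10,)), num_classes=3).float()

    dist = torch.cdist(x, cand_x, p=2) / 0.75   # temperature-scaled distances
    weights = F.softmax(-dist, dim=-1)          # closer candidates weigh more
    probs = weights @ cand_y                    # (4, 3); each row sums to 1
    log_probs = torch.log(probs + 1e-7)         # what the nca_* methods return for classification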
From 2a009839b365077d8abdd1d9bf99ee3a0c601dc0 Mon Sep 17 00:00:00 2001
From: AnFreTh
Date: Wed, 19 Mar 2025 13:14:13 +0100
Subject: [PATCH 4/7] add modernnca config

---
 mambular/configs/__init__.py         |  2 ++
 mambular/configs/modernnca_config.py | 34 ++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+)
 create mode 100644 mambular/configs/modernnca_config.py

diff --git a/mambular/configs/__init__.py b/mambular/configs/__init__.py
index 7a6ae0b4..9e803ac2 100644
--- a/mambular/configs/__init__.py
+++ b/mambular/configs/__init__.py
@@ -15,8 +15,10 @@
 from .base_config import BaseConfig
 from .enode_config import DefaultENODEConfig
 from .tangos_config import DefaultTangosConfig
+from .modernnca_config import DefaultModernNCAConfig
 
 __all__ = [
+    "DefaultModernNCAConfig",
     "DefaultTangosConfig",
     "DefaultENODEConfig",
     "DefaultTromptConfig",
diff --git a/mambular/configs/modernnca_config.py b/mambular/configs/modernnca_config.py
new file mode 100644
index 00000000..bb9f758e
--- /dev/null
+++ b/mambular/configs/modernnca_config.py
@@ -0,0 +1,34 @@
+from collections.abc import Callable
+from dataclasses import dataclass, field
+import torch.nn as nn
+from .base_config import BaseConfig
+
+
+@dataclass
+class DefaultModernNCAConfig(BaseConfig):
+    """
+    Default configuration for the ModernNCA model.
+    """
+
+    # Architecture Parameters
+    dim: int = 128  # Hidden dimension for encoding
+    d_block: int = 512  # Block size for MLP layers
+    n_blocks: int = 4  # Number of MLP blocks
+    dropout: float = 0.1  # Dropout rate
+    temperature: float = 0.75  # Temperature scaling for distance weighting
+    sample_rate: float = 0.5  # Fraction of candidate samples used
+    num_embeddings: dict | None = None  # Dictionary for categorical embeddings
+
+    # Training Parameters
+    optimizer_type: str = "AdamW"  # Optimizer type
+    weight_decay: float = 1e-5  # Weight decay for optimizer
+    learning_rate: float = 1e-02  # Learning rate
+    lr_patience: int = 10  # Patience for LR scheduler
+    lr_factor: float = 0.1  # Factor for LR scheduler
+
+    # Head Parameters
+    head_layer_sizes: list = field(default_factory=list)
+    head_dropout: float = 0.5
+    head_skip_layers: bool = False
+    head_activation: Callable = nn.SELU()  # noqa: RUF009
+    head_use_batch_norm: bool = False
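Note on patch 4: the defaults above can be overridden directly on the dataclass, or through the sklearn-style wrappers introduced in the next patch (a hedged sketch; the kwarg routing into the config is assumed to follow the other mambular models):

    from mambular.configs import DefaultModernNCAConfig

    cfg = DefaultModernNCAConfig(
        dim=256,          # wider encoder
        n_blocks=2,       # shallower post-encoder
        sample_rate=0.3,  # retrieve against 30% of the cached training set
    )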
From 8c29f45ef482c3cc25a8e632df2eb79014765fa1 Mon Sep 17 00:00:00 2001
From: AnFreTh
Date: Wed, 19 Mar 2025 13:14:22 +0100
Subject: [PATCH 5/7] add modernnca to models

---
 mambular/models/__init__.py   |  4 +++
 mambular/models/modern_nca.py | 66 +++++++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+)
 create mode 100644 mambular/models/modern_nca.py

diff --git a/mambular/models/__init__.py b/mambular/models/__init__.py
index 97f74bd5..4db9342f 100644
--- a/mambular/models/__init__.py
+++ b/mambular/models/__init__.py
@@ -29,8 +29,12 @@
 from .trompt import TromptClassifier, TromptLSS, TromptRegressor
 from .enode import ENODEClassifier, ENODELSS, ENODERegressor
 from .tangos import TangosClassifier, TangosLSS, TangosRegressor
+from .modern_nca import ModernNCARegressor, ModernNCAClassifier, ModernNCALSS
 
 __all__ = [
+    "ModernNCARegressor",
+    "ModernNCAClassifier",
+    "ModernNCALSS",
     "TangosClassifier",
     "TangosLSS",
     "TangosRegressor",
diff --git a/mambular/models/modern_nca.py b/mambular/models/modern_nca.py
new file mode 100644
index 00000000..1318ead2
--- /dev/null
+++ b/mambular/models/modern_nca.py
@@ -0,0 +1,66 @@
+from ..base_models.modern_nca import ModernNCA
+from ..configs.modernnca_config import DefaultModernNCAConfig
+from ..utils.docstring_generator import generate_docstring
+from .utils.sklearn_base_classifier import SklearnBaseClassifier
+from .utils.sklearn_base_lss import SklearnBaseLSS
+from .utils.sklearn_base_regressor import SklearnBaseRegressor
+
+
+class ModernNCARegressor(SklearnBaseRegressor):
+    __doc__ = generate_docstring(
+        DefaultModernNCAConfig,
+        model_description="""
+        ModernNCA regressor. This class extends the SklearnBaseRegressor class and uses the ModernNCA model
+        with the default ModernNCA configuration.
+        """,
+        examples="""
+        >>> from mambular.models import ModernNCARegressor
+        >>> model = ModernNCARegressor(dim=128, n_blocks=4)
+        >>> model.fit(X_train, y_train)
+        >>> preds = model.predict(X_test)
+        >>> model.evaluate(X_test, y_test)
+        """,
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(model=ModernNCA, config=DefaultModernNCAConfig, **kwargs)
+
+
+class ModernNCAClassifier(SklearnBaseClassifier):
+    __doc__ = generate_docstring(
+        DefaultModernNCAConfig,
+        model_description="""
+        ModernNCA classifier. This class extends the SklearnBaseClassifier class and uses the ModernNCA model
+        with the default ModernNCA configuration.
+        """,
+        examples="""
+        >>> from mambular.models import ModernNCAClassifier
+        >>> model = ModernNCAClassifier(dim=128, n_blocks=4)
+        >>> model.fit(X_train, y_train)
+        >>> preds = model.predict(X_test)
+        >>> model.evaluate(X_test, y_test)
+        """,
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(model=ModernNCA, config=DefaultModernNCAConfig, **kwargs)
+
+
+class ModernNCALSS(SklearnBaseLSS):
+    __doc__ = generate_docstring(
+        DefaultModernNCAConfig,
+        model_description="""
+        ModernNCA model for distributional regression. This class extends the SklearnBaseLSS class and uses the ModernNCA model
+        with the default ModernNCA configuration.
+        """,
+        examples="""
+        >>> from mambular.models import ModernNCALSS
+        >>> model = ModernNCALSS(dim=128, n_blocks=4)
+        >>> model.fit(X_train, y_train, family='normal')
+        >>> preds = model.predict(X_test)
+        >>> model.evaluate(X_test, y_test)
+        """,
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(model=ModernNCA, config=DefaultModernNCAConfig, **kwargs)

From faee827ee2e31ab34431518e34f8b85aaab06bd4 Mon Sep 17 00:00:00 2001
From: AnFreTh
Date: Wed, 19 Mar 2025 13:14:57 +0100
Subject: [PATCH 6/7] increase version (patch)

---
 mambular/__version__.py | 3 +--
 pyproject.toml          | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/mambular/__version__.py b/mambular/__version__.py
index fcfe670f..43b979e7 100644
--- a/mambular/__version__.py
+++ b/mambular/__version__.py
@@ -17,5 +17,4 @@
 
 
 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "1.3.1"
-
+__version__ = "1.3.2"
diff --git a/pyproject.toml b/pyproject.toml
index 99186db9..d1fdb8a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "mambular"
 
-version = "1.3.1"
+version = "1.3.2"
 description = "A python package for tabular deep learning with mamba blocks."
 authors = ["Anton Thielmann", "Manish Kumar", "Christoph Weisser"]

From 4e6090dac90e906d0c6c7df2689587fcaba80ee6 Mon Sep 17 00:00:00 2001
From: AnFreTh
Date: Mon, 24 Mar 2025 09:57:16 +0100
Subject: [PATCH 7/7] changed config for modernNCA

---
 mambular/configs/modernnca_config.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mambular/configs/modernnca_config.py b/mambular/configs/modernnca_config.py
index bb9f758e..d079d542 100644
--- a/mambular/configs/modernnca_config.py
+++ b/mambular/configs/modernnca_config.py
@@ -32,3 +32,9 @@ class DefaultModernNCAConfig(BaseConfig):
     head_skip_layers: bool = False
     head_activation: Callable = nn.SELU()  # noqa: RUF009
     head_use_batch_norm: bool = False
+
+    # Embedding Parameters
+    embedding_type: str = "plr"
+    plr_lite: bool = True
+    n_frequencies: int = 75
+    frequencies_init_scale: float = 0.045
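Note on the series as a whole: the new estimator is meant to be used like any other mambular model. A rough end-to-end sketch on synthetic data (hedged: the exact fit() signature comes from the shared SklearnBaseRegressor, and max_epochs is assumed to be accepted there):

    import numpy as np
    import pandas as pd
    from mambular.models import ModernNCARegressor

    rng = np.random.default_rng(0)
    X = pd.DataFrame({"x1": rng.normal(size=500), "x2": rng.normal(size=500)})
    y = X["x1"] * 2.0 + rng.normal(scale=0.1, size=500)

    model = ModernNCARegressor(
        dim=128,
        n_blocks=4,
        sample_rate=0.5,   # fraction of the cached training rows used as candidates
        temperature=0.75,  # sharpness of the soft-neighbour weighting
    )
    model.fit(X, y, max_epochs=10)
    preds = model.predict(X)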