3 changes: 1 addition & 2 deletions mambular/__version__.py
@@ -17,5 +17,4 @@

# The following line *must* be the last in the module, exactly as formatted:

__version__ = "1.3.1"

__version__ = "1.3.2"
29 changes: 29 additions & 0 deletions mambular/arch_utils/simple_utils.py
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn


class MLP_Block(nn.Module):
    """Bottleneck MLP block: BatchNorm -> Linear -> ReLU -> Dropout -> Linear back to d_in."""

    def __init__(self, d_in: int, d: int, dropout: float):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm1d(d_in),
            nn.Linear(d_in, d),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d, d_in),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


def make_random_batches(train_size: int, batch_size: int, device=None):
    """Split a random permutation of range(train_size) into index batches."""
    permutation = torch.randperm(train_size, device=device)
    batches = permutation.split(batch_size)

    # Sanity check: together the batches cover every index exactly once.
    assert torch.equal(
        torch.arange(train_size, device=device), permutation.sort().values
    )
    return batches
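
For orientation, a minimal usage sketch of make_random_batches; the sizes and the X, y tensors here are illustrative assumptions, not part of the PR:

# Hypothetical usage: one epoch's worth of shuffled index batches.
X, y = torch.randn(10, 5), torch.randn(10)
for idx in make_random_batches(train_size=10, batch_size=4):
    x_batch, y_batch = X[idx], y[idx]  # e.g. idx = tensor([3, 7, 0, 9])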
2 changes: 2 additions & 0 deletions mambular/base_models/__init__.py
@@ -14,8 +14,10 @@
from .trompt import Trompt
from .enode import ENODE
from .tangos import Tangos
from .modern_nca import ModernNCA

__all__ = [
    "ModernNCA",
    "Tangos",
    "ENODE",
    "Trompt",
216 changes: 216 additions & 0 deletions mambular/base_models/modern_nca.py
@@ -0,0 +1,216 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..utils.get_feature_dimensions import get_feature_dimensions
from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mlp_utils import MLPhead
from ..configs.modernnca_config import DefaultModernNCAConfig
from .utils.basemodel import BaseModel


class ModernNCA(BaseModel):
    def __init__(
        self,
        feature_information: tuple,
        num_classes=1,
        config: DefaultModernNCAConfig = DefaultModernNCAConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])

        self.returns_ensemble = False
        self.uses_nca_candidates = True

        self.T = config.temperature
        self.sample_rate = config.sample_rate
        if self.hparams.use_embeddings:
            self.embedding_layer = EmbeddingLayer(
                *feature_information,
                config=config,
            )
            input_dim = np.sum(
                [len(info) * self.hparams.d_model for info in feature_information]
            )
        else:
            input_dim = get_feature_dimensions(*feature_information)

        self.encoder = nn.Linear(input_dim, config.dim)

        if config.n_blocks > 0:
            self.post_encoder = nn.Sequential(
                *[self.make_layer(config) for _ in range(config.n_blocks)],
                nn.BatchNorm1d(config.dim),
            )

        self.tabular_head = MLPhead(
            input_dim=config.dim,
            config=config,
            output_dim=num_classes,
        )

        self.hparams.num_classes = num_classes

    def make_layer(self, config):
        return nn.Sequential(
            nn.BatchNorm1d(config.dim),
            nn.Linear(config.dim, config.d_block),
            nn.ReLU(inplace=True),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_block, config.dim),
        )

    def forward(self, *data):
        """Standard forward pass without candidate selection (for baseline compatibility)."""
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
        else:
            x = torch.cat([t for tensors in data for t in tensors], dim=1)
        x = self.encoder(x)
        if hasattr(self, "post_encoder"):
            x = self.post_encoder(x)
        return self.tabular_head(x)

    def nca_train(self, *data, targets, candidate_x, candidate_y):
        """NCA-style training forward pass selecting candidates."""
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
            candidate_x = self.embedding_layer(*candidate_x)
            B, S, D = candidate_x.shape
            candidate_x = candidate_x.reshape(B, S * D)
        else:
            x = torch.cat([t for tensors in data for t in tensors], dim=1)
            candidate_x = torch.cat(
                [t for tensors in candidate_x for t in tensors], dim=1
            )

        # Encode input
        x = self.encoder(x)
        candidate_x = self.encoder(candidate_x)

        if hasattr(self, "post_encoder"):
            x = self.post_encoder(x)
            candidate_x = self.post_encoder(candidate_x)

        # Select a subset of candidates
        data_size = candidate_x.shape[0]
        retrieval_size = int(data_size * self.sample_rate)
        sample_idx = torch.randperm(data_size)[:retrieval_size]
        candidate_x = candidate_x[sample_idx]
        candidate_y = candidate_y[sample_idx]

        # Concatenate with training batch
        candidate_x = torch.cat([x, candidate_x], dim=0)
        candidate_y = torch.cat([targets, candidate_y], dim=0)

        # One-hot encode if classification
        if self.hparams.num_classes > 1:
            candidate_y = F.one_hot(
                candidate_y, num_classes=self.hparams.num_classes
            ).to(x.dtype)
        elif len(candidate_y.shape) == 1:
            candidate_y = candidate_y.unsqueeze(-1)

        # Compute distances
        distances = torch.cdist(x, candidate_x, p=2) / self.T
        # Mask self-matches: the first B candidates are the batch itself, so an
        # infinite diagonal distance gets zero softmax weight and avoids label leakage
        distances = distances.fill_diagonal_(torch.inf)
        distances = F.softmax(-distances, dim=-1)
        logits = torch.mm(distances, candidate_y)
        eps = 1e-7
        if self.hparams.num_classes > 1:
            logits = torch.log(logits + eps)

        return logits

    def nca_validate(self, *data, candidate_x, candidate_y):
        """Validation forward pass with NCA-style candidate selection."""
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
            candidate_x = self.embedding_layer(*candidate_x)
            B, S, D = candidate_x.shape
            candidate_x = candidate_x.reshape(B, S * D)
        else:
            x = torch.cat([t for tensors in data for t in tensors], dim=1)
            candidate_x = torch.cat(
                [t for tensors in candidate_x for t in tensors], dim=1
            )

        # Encode input
        x = self.encoder(x)
        candidate_x = self.encoder(candidate_x)

        if hasattr(self, "post_encoder"):
            x = self.post_encoder(x)
            candidate_x = self.post_encoder(candidate_x)

        # One-hot encode if classification
        if self.hparams.num_classes > 1:
            candidate_y = F.one_hot(
                candidate_y, num_classes=self.hparams.num_classes
            ).to(x.dtype)
        elif len(candidate_y.shape) == 1:
            candidate_y = candidate_y.unsqueeze(-1)

        # Compute distances
        distances = torch.cdist(x, candidate_x, p=2) / self.T
        distances = F.softmax(-distances, dim=-1)

        # Compute logits
        logits = torch.mm(distances, candidate_y)
        eps = 1e-7
        if self.hparams.num_classes > 1:
            logits = torch.log(logits + eps)

        return logits

    def nca_predict(self, *data, candidate_x, candidate_y):
        """Prediction forward pass with candidate selection."""
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
            candidate_x = self.embedding_layer(*candidate_x)
            B, S, D = candidate_x.shape
            candidate_x = candidate_x.reshape(B, S * D)
        else:
            x = torch.cat([t for tensors in data for t in tensors], dim=1)
            candidate_x = torch.cat(
                [t for tensors in candidate_x for t in tensors], dim=1
            )

        # Encode input
        x = self.encoder(x)
        candidate_x = self.encoder(candidate_x)

        if hasattr(self, "post_encoder"):
            x = self.post_encoder(x)
            candidate_x = self.post_encoder(candidate_x)

        # One-hot encode if classification
        if self.hparams.num_classes > 1:
            candidate_y = F.one_hot(
                candidate_y, num_classes=self.hparams.num_classes
            ).to(x.dtype)
        elif len(candidate_y.shape) == 1:
            candidate_y = candidate_y.unsqueeze(-1)

        # Compute distances
        distances = torch.cdist(x, candidate_x, p=2) / self.T
        distances = F.softmax(-distances, dim=-1)

        # Compute logits
        logits = torch.mm(distances, candidate_y)
        eps = 1e-7
        if self.hparams.num_classes > 1:
            logits = torch.log(logits + eps)

        return logits
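
All three nca_* methods above share the same soft nearest-neighbour readout: each encoded query attends over the encoded candidates with weights softmax(-||x - c||_2 / T) and averages the candidate labels. A self-contained sketch of just that readout, with toy shapes and names chosen for illustration (they are assumptions, not the PR's API):

import torch
import torch.nn.functional as F

B, M, d, T = 4, 16, 8, 0.5                    # batch, candidates, encoding dim, temperature
x = torch.randn(B, d)                         # encoded queries
candidate_x = torch.randn(M, d)               # encoded candidates
candidate_y = F.one_hot(torch.randint(0, 3, (M,)), 3).float()  # candidate labels, 3 classes

weights = F.softmax(-torch.cdist(x, candidate_x, p=2) / T, dim=-1)  # (B, M) attention over candidates
probs = weights @ candidate_y                 # (B, 3): label mixture per query
logits = torch.log(probs + 1e-7)              # matches the eps-stabilized log above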
71 changes: 64 additions & 7 deletions mambular/base_models/utils/lightning_wrapper.py
@@ -96,6 +96,39 @@ def __init__(
            **kwargs,
        )

    def setup(self, stage=None):
        if stage == "fit" and hasattr(self.estimator, "uses_nca_candidates"):
            all_train_num = []
            all_train_cat = []
            all_train_embeddings = []
            all_train_targets = []

            device = self.device  # a LightningModule always exposes .device

            for batch in self.trainer.datamodule.train_dataloader():
                (num_features, cat_features, embeddings), labels = batch

                all_train_num.append([f.to(device) for f in num_features])  # keep lists
                all_train_cat.append([f.to(device) for f in cat_features])  # keep lists
                if embeddings is not None:
                    all_train_embeddings.append([f.to(device) for f in embeddings])
                all_train_targets.append(labels.to(device))

            # Maintain structure: each feature type remains a list of tensors
            self.train_features = (
                [torch.cat(features, dim=0) for features in zip(*all_train_num)],
                [torch.cat(features, dim=0) for features in zip(*all_train_cat)],
                (
                    [
                        torch.cat(features, dim=0)
                        for features in zip(*all_train_embeddings)
                    ]
                    if all_train_embeddings
                    else None
                ),
            )
            self.train_targets = torch.cat(all_train_targets, dim=0)

    def forward(self, num_features, cat_features, embeddings):
        """Forward pass through the model.

@@ -184,7 +217,7 @@ def training_step(self, batch, batch_idx): # type: ignore
            Index of the batch.

        Returns
-        -------
+        ------
        Tensor
            Training loss.
        """
@@ -194,6 +227,14 @@ def training_step(self, batch, batch_idx): # type: ignore
        if hasattr(self.estimator, "penalty_forward"):
            preds, penalty = self.estimator.penalty_forward(*data)
            loss = self.compute_loss(preds, labels) + penalty
        elif hasattr(self.estimator, "uses_nca_candidates"):
            preds = self.estimator.nca_train(
                *data,
                targets=labels,
                candidate_x=self.train_features,
                candidate_y=self.train_targets,
            )
            loss = self.compute_loss(preds, labels)
        else:
            preds = self(*data)
            loss = self.compute_loss(preds, labels)
@@ -234,7 +275,12 @@ def validation_step(self, batch, batch_idx): # type: ignore
"""

data, labels = batch
preds = self(*data)
if hasattr(self.estimator, "nca_validate") and self.train_features is not None:
preds = self.estimator.nca_validate(
*data, candidate_x=self.train_features, candidate_y=self.train_targets
)
else:
preds = self(*data)
val_loss = self.compute_loss(preds, labels)

self.log(
@@ -276,7 +322,12 @@ def test_step(self, batch, batch_idx): # type: ignore
            Test loss.
        """
        data, labels = batch
-        preds = self(*data)
+        has_candidates = getattr(self, "train_features", None) is not None
+        if hasattr(self.estimator, "nca_predict") and has_candidates:
+            preds = self.estimator.nca_predict(
+                *data, candidate_x=self.train_features, candidate_y=self.train_targets
+            )
+        else:
+            preds = self(*data)
        test_loss = self.compute_loss(preds, labels)

        self.log(
@@ -305,8 +356,14 @@ def predict_step(self, batch, batch_idx):
        Tensor
            Predictions.
        """
-
-        preds = self(*batch)
+        has_candidates = getattr(self, "train_features", None) is not None
+        if hasattr(self.estimator, "nca_predict") and has_candidates:
+            preds = self.estimator.nca_predict(
+                *batch,
+                candidate_x=self.train_features,
+                candidate_y=self.train_targets,
+            )
+        else:
+            preds = self(*batch)

        return preds

@@ -425,7 +482,7 @@ def pretrain_embeddings(
        temperature=0.1,
        save_path="pretrained_embeddings.pth",
        regression=True,
-        lr=1e-04
+        lr=1e-04,
    ):
        """Pretrain embeddings before full model training.

@@ -594,7 +651,7 @@ def contrastive_loss(self, embeddings, knn_indices, temperature=0.1):
        )  # Shape: (N * k_neighbors)

        # Compute cosine embedding loss
-        loss += -1.0*loss_fn(embeddings_s, positive_pairs, labels)
+        loss += -1.0 * loss_fn(embeddings_s, positive_pairs, labels)

        # Average loss across all sequence steps
        loss /= S
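
Taken together, the wrapper changes cache the entire training set on-device once in setup() and then route each step through the estimator's NCA entry points when they exist. A condensed, hypothetical sketch of that dispatch (not the literal merged code, which is spread across the step methods above):

def _dispatch(self, data, labels=None, train=False):
    # Candidates only exist after setup(stage="fit") has cached them.
    has_candidates = getattr(self, "train_features", None) is not None
    if train and hasattr(self.estimator, "uses_nca_candidates"):
        return self.estimator.nca_train(
            *data, targets=labels,
            candidate_x=self.train_features, candidate_y=self.train_targets,
        )
    if not train and hasattr(self.estimator, "nca_predict") and has_candidates:
        return self.estimator.nca_predict(
            *data, candidate_x=self.train_features, candidate_y=self.train_targets
        )
    return self(*data)  # fall back to the plain forward pass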