diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index acc60370..01d6fe0f 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -22,7 +22,7 @@ If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. Ubuntu]
- Python version [e.g. 3.8]
- - DeepTabular Version [e.g. 1.6.0]
+ - deeptab Version [e.g. 1.6.0]
**Additional context**
Add any other context about the problem here.
diff --git a/README.md b/README.md
index 70651f57..91c51465 100644
--- a/README.md
+++ b/README.md
@@ -2,24 +2,24 @@
-[](https://pypi.org/project/deeptabular)
-
-[](https://deeptabular.readthedocs.io/en/latest/?badge=latest)
-[](https://deeptabular.readthedocs.io/en/latest/)
-[](https://github.com/OpenTabular/DeepTabular/issues)
+[](https://pypi.org/project/deeptab)
+
+[](https://deeptab.readthedocs.io/en/latest/?badge=latest)
+[](https://deeptab.readthedocs.io/en/latest/)
+[](https://github.com/OpenTabular/deeptab/issues)
-[📘Documentation](https://deeptabular.readthedocs.io/en/latest/index.html) |
-[🛠️Installation](https://deeptabular.readthedocs.io/en/latest/installation.html) |
-[Models](https://deeptabular.readthedocs.io/en/latest/api/models/index.html) |
-[🤔Report Issues](https://github.com/OpenTabular/DeepTabular/issues)
+[📘Documentation](https://deeptab.readthedocs.io/en/latest/index.html) |
+[🛠️Installation](https://deeptab.readthedocs.io/en/latest/installation.html) |
+[Models](https://deeptab.readthedocs.io/en/latest/api/models/index.html) |
+[🤔Report Issues](https://github.com/OpenTabular/deeptab/issues)
-
DeepTabular: Tabular Deep Learning Made Simple
+ deeptab: Tabular Deep Learning Made Simple
-DeepTabular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
+deeptab is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
⚡ What's New ⚡
@@ -48,10 +48,10 @@ DeepTabular is a Python library for tabular deep learning. It includes models th
# 🏃 Quickstart
-Similar to any sklearn model, DeepTabular models can be fit as easy as this:
+Similar to any sklearn model, deeptab models can be fit as easy as this:
```python
-from deeptabular.models import MambularClassifier
+from deeptab.models import MambularClassifier
# Initialize and fit your model
model = MambularClassifier()
@@ -60,7 +60,7 @@ model.fit(X, y, max_epochs=150, lr=1e-04)
```
# 📖 Introduction
-DeepTabular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, DeepTabular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using DeepTabular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
+deeptab is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, deeptab models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using deeptab models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
# 🤖 Models
@@ -94,13 +94,13 @@ Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier` or `
# 📚 Documentation
-You can find the DeepTabular API documentation [here](https://deeptabular.readthedocs.io/en/latest/).
+You can find the deeptab API documentation [here](https://deeptab.readthedocs.io/en/latest/).
# 🛠️ Installation
-Install DeepTabular using pip:
+Install deeptab using pip:
```sh
-pip install deeptabular
+pip install deeptab
```
If you want to use the original mamba and mamba2 implementations, additionally install mamba-ssm via:
@@ -120,7 +120,7 @@ pip install mamba-ssm
Preprocessing
-DeepTabular uses pretab preprocessing: https://github.com/OpenTabular/PreTab
+deeptab uses pretab preprocessing: https://github.com/OpenTabular/PreTab
Hence, datatypes etc. are detected automatically and all preprocessing methods from pretab as well as from Sklearn.preprocessing are available.
Additionally, you can specify that each feature is preprocessed differently, according to your requirements, by setting the `feature_preprocessing={}`argument during model initialization.
@@ -144,10 +144,10 @@ For an overview over all available methods: [pretab](https://github.com/OpenTabu
Fit a Model
-Fitting a model in deeptabular is as simple as it gets. All models in deeptabular are sklearn BaseEstimators. Thus the `.fit` method is implemented for all of them. Additionally, this allows for using all other sklearn inherent methods such as their built in hyperparameter optimization tools.
+Fitting a model in deeptab is as simple as it gets. All models in deeptab are sklearn BaseEstimators. Thus the `.fit` method is implemented for all of them. Additionally, this allows for using all other sklearn inherent methods such as their built in hyperparameter optimization tools.
```python
-from deeptabular.models import MambularClassifier
+from deeptab.models import MambularClassifier
# Initialize and fit your model
model = MambularClassifier(
d_model=64,
@@ -243,12 +243,12 @@ Or use the built-in bayesian hpo simply by running:
best_params = model.optimize_hparams(X, y)
```
-This automatically sets the search space based on the default config from ``deeptabular.configs``. See the documentation for all params with regard to ``optimize_hparams()``. However, the preprocessor arguments are fixed and cannot be optimized here.
+This automatically sets the search space based on the default config from ``deeptab.configs``. See the documentation for all params with regard to ``optimize_hparams()``. However, the preprocessor arguments are fixed and cannot be optimized here.
⚖️ Distributional Regression with MambularLSS
-MambularLSS allows you to model the full distribution of a response variable, not just its mean. This is crucial when understanding variability, skewness, or kurtosis is important. All DeepTabular models are available as distributional models.
+MambularLSS allows you to model the full distribution of a response variable, not just its mean. This is crucial when understanding variability, skewness, or kurtosis is important. All deeptab models are available as distributional models.
Key Features of MambularLSS:
@@ -277,10 +277,10 @@ These distribution classes make MambularLSS versatile in modeling various data t
Getting Started with MambularLSS:
-To integrate distributional regression into your workflow with `MambularLSS`, start by initializing the model with your desired configuration, similar to other DeepTabular models:
+To integrate distributional regression into your workflow with `MambularLSS`, start by initializing the model with your desired configuration, similar to other deeptab models:
```python
-from deeptabular.models import MambularLSS
+from deeptab.models import MambularLSS
# Initialize the MambularLSS model
model = MambularLSS(
@@ -305,18 +305,18 @@ model.fit(
# 💻 Implement Your Own Model
-DeepTabular allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from DeepTabular's `BaseModel`. Each DeepTabular model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs.
+deeptab allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from deeptab's `BaseModel`. Each deeptab model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs.
-One of the key advantages of using DeepTabular is that the inputs to the forward passes are lists of tensors. While this might be unconventional, it is highly beneficial for models that treat different data types differently. For example, the TabTransformer model leverages this feature to handle categorical and numerical data separately, applying different transformations and processing steps to each type of data.
+One of the key advantages of using deeptab is that the inputs to the forward passes are lists of tensors. While this might be unconventional, it is highly beneficial for models that treat different data types differently. For example, the TabTransformer model leverages this feature to handle categorical and numerical data separately, applying different transformations and processing steps to each type of data.
-Here's how you can implement a custom model with DeepTabular:
+Here's how you can implement a custom model with deeptab:
1. **First, define your config:**
The configuration class allows you to specify hyperparameters and other settings for your model. This can be done using a simple dataclass.
```python
from dataclasses import dataclass
- from deeptabular.configs import BaseConfig
+ from deeptab.configs import BaseConfig
@dataclass
class MyConfig(BaseConfig):
@@ -332,8 +332,8 @@ Here's how you can implement a custom model with DeepTabular:
Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass.
```python
- from deeptabular.base_models.utils import BaseModel
- from deeptabular.utils.get_feature_dimensions import get_feature_dimensions
+ from deeptab.base_models.utils import BaseModel
+ from deeptab.utils.get_feature_dimensions import get_feature_dimensions
import torch
import torch.nn
@@ -372,11 +372,11 @@ Here's how you can implement a custom model with DeepTabular:
return output
```
-3. **Leverage the DeepTabular API:**
- You can build a regression, classification, or distributional regression model that can leverage all of DeepTabular's built-in methods by using the following:
+3. **Leverage the deeptab API:**
+ You can build a regression, classification, or distributional regression model that can leverage all of deeptab's built-in methods by using the following:
```python
- from deeptabular.models.utils import SklearnBaseRegressor
+ from deeptab.models.utils import SklearnBaseRegressor
class MyRegressor(SklearnBaseRegressor):
def __init__(self, **kwargs):
@@ -384,7 +384,7 @@ Here's how you can implement a custom model with DeepTabular:
```
4. **Train and evaluate your model:**
- You can now fit, evaluate, and predict with your custom model just like with any other DeepTabular model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively.
+ You can now fit, evaluate, and predict with your custom model just like with any other deeptab model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively.
```python
regressor = MyRegressor(numerical_preprocessing="ple")
diff --git a/deeptabular/__init__.py b/deeptab/__init__.py
similarity index 100%
rename from deeptabular/__init__.py
rename to deeptab/__init__.py
diff --git a/deeptabular/__version__.py b/deeptab/__version__.py
similarity index 100%
rename from deeptabular/__version__.py
rename to deeptab/__version__.py
diff --git a/deeptabular/arch_utils/__init__.py b/deeptab/arch_utils/__init__.py
similarity index 100%
rename from deeptabular/arch_utils/__init__.py
rename to deeptab/arch_utils/__init__.py
diff --git a/deeptabular/arch_utils/cnn_utils.py b/deeptab/arch_utils/cnn_utils.py
similarity index 100%
rename from deeptabular/arch_utils/cnn_utils.py
rename to deeptab/arch_utils/cnn_utils.py
diff --git a/deeptabular/arch_utils/data_aware_initialization.py b/deeptab/arch_utils/data_aware_initialization.py
similarity index 100%
rename from deeptabular/arch_utils/data_aware_initialization.py
rename to deeptab/arch_utils/data_aware_initialization.py
diff --git a/deeptabular/arch_utils/enode_utils.py b/deeptab/arch_utils/enode_utils.py
similarity index 99%
rename from deeptabular/arch_utils/enode_utils.py
rename to deeptab/arch_utils/enode_utils.py
index 4af760d1..d8116f17 100644
--- a/deeptabular/arch_utils/enode_utils.py
+++ b/deeptab/arch_utils/enode_utils.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from deeptabular.arch_utils.layer_utils.sparsemax import sparsemax, sparsemoid
+from deeptab.arch_utils.layer_utils.sparsemax import sparsemax, sparsemoid
from .data_aware_initialization import ModuleWithInit
from .numpy_utils import check_numpy
import numpy as np
diff --git a/deeptabular/arch_utils/get_norm_fn.py b/deeptab/arch_utils/get_norm_fn.py
similarity index 96%
rename from deeptabular/arch_utils/get_norm_fn.py
rename to deeptab/arch_utils/get_norm_fn.py
index e05fa06a..dfcbfcd1 100644
--- a/deeptabular/arch_utils/get_norm_fn.py
+++ b/deeptab/arch_utils/get_norm_fn.py
@@ -1,49 +1,49 @@
-from .layer_utils.normalization_layers import (
- BatchNorm,
- GroupNorm,
- InstanceNorm,
- LayerNorm,
- LearnableLayerScaling,
- RMSNorm,
-)
-
-
-def get_normalization_layer(config):
- """Function to return the appropriate normalization layer based on the configuration.
-
- Parameters:
- -----------
- config : DefaultMambularConfig
- Configuration object containing the parameters for the model including normalization.
-
- Returns:
- --------
- nn.Module:
- The normalization layer as per the config.
-
- Raises:
- -------
- ValueError:
- If an unsupported normalization layer is specified in the config.
- """
-
- norm_layer = getattr(config, "norm", None)
- d_model = getattr(config, "d_model", 128)
- layer_norm_eps = getattr(config, "layer_norm_eps", 1e-05)
-
- if norm_layer == "RMSNorm":
- return RMSNorm(d_model, eps=layer_norm_eps)
- elif norm_layer == "LayerNorm":
- return LayerNorm(d_model, eps=layer_norm_eps)
- elif norm_layer == "BatchNorm":
- return BatchNorm(d_model, eps=layer_norm_eps)
- elif norm_layer == "InstanceNorm":
- return InstanceNorm(d_model, eps=layer_norm_eps)
- elif norm_layer == "GroupNorm":
- return GroupNorm(1, d_model, eps=layer_norm_eps)
- elif norm_layer == "LearnableLayerScaling":
- return LearnableLayerScaling(d_model)
- elif norm_layer is None:
- return None
- else:
- raise ValueError(f"Unsupported normalization layer: {norm_layer}")
+from .layer_utils.normalization_layers import (
+ BatchNorm,
+ GroupNorm,
+ InstanceNorm,
+ LayerNorm,
+ LearnableLayerScaling,
+ RMSNorm,
+)
+
+
+def get_normalization_layer(config):
+ """Function to return the appropriate normalization layer based on the configuration.
+
+ Parameters:
+ -----------
+ config : DefaultMambularConfig
+ Configuration object containing the parameters for the model including normalization.
+
+ Returns:
+ --------
+ nn.Module:
+ The normalization layer as per the config.
+
+ Raises:
+ -------
+ ValueError:
+ If an unsupported normalization layer is specified in the config.
+ """
+
+ norm_layer = getattr(config, "norm", None)
+ d_model = getattr(config, "d_model", 128)
+ layer_norm_eps = getattr(config, "layer_norm_eps", 1e-05)
+
+ if norm_layer == "RMSNorm":
+ return RMSNorm(d_model, eps=layer_norm_eps)
+ elif norm_layer == "LayerNorm":
+ return LayerNorm(d_model, eps=layer_norm_eps)
+ elif norm_layer == "BatchNorm":
+ return BatchNorm(d_model, eps=layer_norm_eps)
+ elif norm_layer == "InstanceNorm":
+ return InstanceNorm(d_model, eps=layer_norm_eps)
+ elif norm_layer == "GroupNorm":
+ return GroupNorm(1, d_model, eps=layer_norm_eps)
+ elif norm_layer == "LearnableLayerScaling":
+ return LearnableLayerScaling(d_model)
+ elif norm_layer is None:
+ return None
+ else:
+ raise ValueError(f"Unsupported normalization layer: {norm_layer}")
diff --git a/deeptabular/arch_utils/layer_utils/__init__.py b/deeptab/arch_utils/layer_utils/__init__.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/__init__.py
rename to deeptab/arch_utils/layer_utils/__init__.py
diff --git a/deeptabular/arch_utils/layer_utils/attention_net_arch_utils.py b/deeptab/arch_utils/layer_utils/attention_net_arch_utils.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/attention_net_arch_utils.py
rename to deeptab/arch_utils/layer_utils/attention_net_arch_utils.py
diff --git a/deeptabular/arch_utils/layer_utils/attention_utils.py b/deeptab/arch_utils/layer_utils/attention_utils.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/attention_utils.py
rename to deeptab/arch_utils/layer_utils/attention_utils.py
diff --git a/deeptabular/arch_utils/layer_utils/batch_ensemble_layer.py b/deeptab/arch_utils/layer_utils/batch_ensemble_layer.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/batch_ensemble_layer.py
rename to deeptab/arch_utils/layer_utils/batch_ensemble_layer.py
diff --git a/deeptabular/arch_utils/layer_utils/block_diagonal.py b/deeptab/arch_utils/layer_utils/block_diagonal.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/block_diagonal.py
rename to deeptab/arch_utils/layer_utils/block_diagonal.py
diff --git a/deeptabular/arch_utils/layer_utils/embedding_layer.py b/deeptab/arch_utils/layer_utils/embedding_layer.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/embedding_layer.py
rename to deeptab/arch_utils/layer_utils/embedding_layer.py
diff --git a/deeptabular/arch_utils/layer_utils/embedding_tree.py b/deeptab/arch_utils/layer_utils/embedding_tree.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/embedding_tree.py
rename to deeptab/arch_utils/layer_utils/embedding_tree.py
diff --git a/deeptabular/arch_utils/layer_utils/importance.py b/deeptab/arch_utils/layer_utils/importance.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/importance.py
rename to deeptab/arch_utils/layer_utils/importance.py
diff --git a/deeptabular/arch_utils/layer_utils/invariance_layer.py b/deeptab/arch_utils/layer_utils/invariance_layer.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/invariance_layer.py
rename to deeptab/arch_utils/layer_utils/invariance_layer.py
diff --git a/deeptabular/arch_utils/layer_utils/normalization_layers.py b/deeptab/arch_utils/layer_utils/normalization_layers.py
similarity index 97%
rename from deeptabular/arch_utils/layer_utils/normalization_layers.py
rename to deeptab/arch_utils/layer_utils/normalization_layers.py
index 0877fbb8..9f09aba0 100644
--- a/deeptabular/arch_utils/layer_utils/normalization_layers.py
+++ b/deeptab/arch_utils/layer_utils/normalization_layers.py
@@ -1,149 +1,149 @@
-import torch
-import torch.nn as nn
-
-
-class RMSNorm(nn.Module):
- """Root Mean Square normalization layer.
-
- Attributes:
- d_model (int): The dimensionality of the input and output tensors.
- eps (float): Small value to avoid division by zero.
- weight (nn.Parameter): Learnable parameter for scaling.
- """
-
- def __init__(self, d_model: int, eps: float = 1e-5):
- super().__init__()
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(d_model))
-
- def forward(self, x):
- output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
-
- return output
-
-
-class LayerNorm(nn.Module):
- """Layer normalization layer.
-
- Attributes:
- d_model (int): The dimensionality of the input and output tensors.
- eps (float): Small value to avoid division by zero.
- weight (nn.Parameter): Learnable parameter for scaling.
- bias (nn.Parameter): Learnable parameter for shifting.
- """
-
- def __init__(self, d_model: int, eps: float = 1e-5):
- super().__init__()
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(d_model))
- self.bias = nn.Parameter(torch.zeros(d_model))
-
- def forward(self, x):
- mean = x.mean(dim=-1, keepdim=True)
- std = x.std(dim=-1, keepdim=True)
- output = (x - mean) / (std + self.eps)
- output = output * self.weight + self.bias
- return output
-
-
-class BatchNorm(nn.Module):
- """Batch normalization layer.
-
- Attributes:
- d_model (int): The dimensionality of the input and output tensors.
- eps (float): Small value to avoid division by zero.
- momentum (float): The value used for the running mean and variance computation.
- """
-
- def __init__(self, d_model: int, eps: float = 1e-5, momentum: float = 0.1):
- super().__init__()
- self.d_model = d_model
- self.eps = eps
- self.momentum = momentum
- self.register_buffer("running_mean", torch.zeros(d_model))
- self.register_buffer("running_var", torch.ones(d_model))
- self.weight = nn.Parameter(torch.ones(d_model))
- self.bias = nn.Parameter(torch.zeros(d_model))
-
- def forward(self, x):
- if self.training:
- mean = x.mean(dim=0)
- # Use unbiased=False for consistency with BatchNorm
- var = x.var(dim=0, unbiased=False)
- # Update running stats in-place
- self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
- self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
- else:
- mean = self.running_mean
- var = self.running_var
- output = (x - mean) / torch.sqrt(var + self.eps)
- output = output * self.weight + self.bias
- return output
-
-
-class InstanceNorm(nn.Module):
- """Instance normalization layer.
-
- Attributes:
- d_model (int): The dimensionality of the input and output tensors.
- eps (float): Small value to avoid division by zero.
- """
-
- def __init__(self, d_model: int, eps: float = 1e-5):
- super().__init__()
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(d_model))
- self.bias = nn.Parameter(torch.zeros(d_model))
-
- def forward(self, x):
- mean = x.mean(dim=(2, 3), keepdim=True)
- var = x.var(dim=(2, 3), keepdim=True)
- output = (x - mean) / torch.sqrt(var + self.eps)
- output = output * self.weight.unsqueeze(0).unsqueeze(2) + self.bias.unsqueeze(0).unsqueeze(2)
- return output
-
-
-class GroupNorm(nn.Module):
- """Group normalization layer.
-
- Attributes:
- num_groups (int): Number of groups to separate the channels into.
- d_model (int): The dimensionality of the input and output tensors.
- eps (float): Small value to avoid division by zero.
- """
-
- def __init__(self, num_groups: int, d_model: int, eps: float = 1e-5):
- super().__init__()
- self.num_groups = num_groups
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(d_model))
- self.bias = nn.Parameter(torch.zeros(d_model))
-
- def forward(self, x):
- b, c, h, w = x.size()
- x = x.view(b, self.num_groups, -1)
- mean = x.mean(dim=-1, keepdim=True)
- var = x.var(dim=-1, keepdim=True)
- output = (x - mean) / torch.sqrt(var + self.eps)
- output = output.view(b, c, h, w)
- output = output * self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) + self.bias.unsqueeze(0).unsqueeze(
- 2
- ).unsqueeze(3)
- return output
-
-
-class LearnableLayerScaling(nn.Module):
- """Learnable Layer Scaling (LLS) normalization layer.
-
- Attributes:
- d_model (int): The dimensionality of the input and output tensors.
- """
-
- def __init__(self, d_model: int):
- """Initialize LLS normalization layer."""
- super().__init__()
- self.weight = nn.Parameter(torch.ones(d_model))
-
- def forward(self, x):
- output = x * self.weight.unsqueeze(0)
- return output
+import torch
+import torch.nn as nn
+
+
+class RMSNorm(nn.Module):
+ """Root Mean Square normalization layer.
+
+ Attributes:
+ d_model (int): The dimensionality of the input and output tensors.
+ eps (float): Small value to avoid division by zero.
+ weight (nn.Parameter): Learnable parameter for scaling.
+ """
+
+ def __init__(self, d_model: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(d_model))
+
+ def forward(self, x):
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+
+ return output
+
+
+class LayerNorm(nn.Module):
+ """Layer normalization layer.
+
+ Attributes:
+ d_model (int): The dimensionality of the input and output tensors.
+ eps (float): Small value to avoid division by zero.
+ weight (nn.Parameter): Learnable parameter for scaling.
+ bias (nn.Parameter): Learnable parameter for shifting.
+ """
+
+ def __init__(self, d_model: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(d_model))
+ self.bias = nn.Parameter(torch.zeros(d_model))
+
+ def forward(self, x):
+ mean = x.mean(dim=-1, keepdim=True)
+ std = x.std(dim=-1, keepdim=True)
+ output = (x - mean) / (std + self.eps)
+ output = output * self.weight + self.bias
+ return output
+
+
+class BatchNorm(nn.Module):
+ """Batch normalization layer.
+
+ Attributes:
+ d_model (int): The dimensionality of the input and output tensors.
+ eps (float): Small value to avoid division by zero.
+ momentum (float): The value used for the running mean and variance computation.
+ """
+
+ def __init__(self, d_model: int, eps: float = 1e-5, momentum: float = 0.1):
+ super().__init__()
+ self.d_model = d_model
+ self.eps = eps
+ self.momentum = momentum
+ self.register_buffer("running_mean", torch.zeros(d_model))
+ self.register_buffer("running_var", torch.ones(d_model))
+ self.weight = nn.Parameter(torch.ones(d_model))
+ self.bias = nn.Parameter(torch.zeros(d_model))
+
+ def forward(self, x):
+ if self.training:
+ mean = x.mean(dim=0)
+ # Use unbiased=False for consistency with BatchNorm
+ var = x.var(dim=0, unbiased=False)
+ # Update running stats in-place
+ self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
+ self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
+ else:
+ mean = self.running_mean
+ var = self.running_var
+ output = (x - mean) / torch.sqrt(var + self.eps)
+ output = output * self.weight + self.bias
+ return output
+
+
+class InstanceNorm(nn.Module):
+ """Instance normalization layer.
+
+ Attributes:
+ d_model (int): The dimensionality of the input and output tensors.
+ eps (float): Small value to avoid division by zero.
+ """
+
+ def __init__(self, d_model: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(d_model))
+ self.bias = nn.Parameter(torch.zeros(d_model))
+
+ def forward(self, x):
+ mean = x.mean(dim=(2, 3), keepdim=True)
+ var = x.var(dim=(2, 3), keepdim=True)
+ output = (x - mean) / torch.sqrt(var + self.eps)
+ output = output * self.weight.unsqueeze(0).unsqueeze(2) + self.bias.unsqueeze(0).unsqueeze(2)
+ return output
+
+
+class GroupNorm(nn.Module):
+ """Group normalization layer.
+
+ Attributes:
+ num_groups (int): Number of groups to separate the channels into.
+ d_model (int): The dimensionality of the input and output tensors.
+ eps (float): Small value to avoid division by zero.
+ """
+
+ def __init__(self, num_groups: int, d_model: int, eps: float = 1e-5):
+ super().__init__()
+ self.num_groups = num_groups
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(d_model))
+ self.bias = nn.Parameter(torch.zeros(d_model))
+
+ def forward(self, x):
+ b, c, h, w = x.size()
+ x = x.view(b, self.num_groups, -1)
+ mean = x.mean(dim=-1, keepdim=True)
+ var = x.var(dim=-1, keepdim=True)
+ output = (x - mean) / torch.sqrt(var + self.eps)
+ output = output.view(b, c, h, w)
+ output = output * self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) + self.bias.unsqueeze(0).unsqueeze(
+ 2
+ ).unsqueeze(3)
+ return output
+
+
+class LearnableLayerScaling(nn.Module):
+ """Learnable Layer Scaling (LLS) normalization layer.
+
+ Attributes:
+ d_model (int): The dimensionality of the input and output tensors.
+ """
+
+ def __init__(self, d_model: int):
+ """Initialize LLS normalization layer."""
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(d_model))
+
+ def forward(self, x):
+ output = x * self.weight.unsqueeze(0)
+ return output
diff --git a/deeptabular/arch_utils/layer_utils/plr_layer.py b/deeptab/arch_utils/layer_utils/plr_layer.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/plr_layer.py
rename to deeptab/arch_utils/layer_utils/plr_layer.py
diff --git a/deeptabular/arch_utils/layer_utils/poly_layer.py b/deeptab/arch_utils/layer_utils/poly_layer.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/poly_layer.py
rename to deeptab/arch_utils/layer_utils/poly_layer.py
diff --git a/deeptabular/arch_utils/layer_utils/rotary_utils.py b/deeptab/arch_utils/layer_utils/rotary_utils.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/rotary_utils.py
rename to deeptab/arch_utils/layer_utils/rotary_utils.py
diff --git a/deeptabular/arch_utils/layer_utils/sn_linear.py b/deeptab/arch_utils/layer_utils/sn_linear.py
similarity index 97%
rename from deeptabular/arch_utils/layer_utils/sn_linear.py
rename to deeptab/arch_utils/layer_utils/sn_linear.py
index 429f621e..b775ccd2 100644
--- a/deeptabular/arch_utils/layer_utils/sn_linear.py
+++ b/deeptab/arch_utils/layer_utils/sn_linear.py
@@ -1,27 +1,27 @@
-import torch
-import torch.nn as nn
-from torch.nn.parameter import Parameter
-
-
-class SNLinear(nn.Module):
- """Separate linear layers for each feature embedding."""
-
- def __init__(self, n: int, in_features: int, out_features: int) -> None:
- super().__init__()
- self.weight = Parameter(torch.empty(n, in_features, out_features))
- self.bias = Parameter(torch.empty(n, out_features))
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
- d_in_rsqrt = self.weight.shape[-2] ** -0.5
- nn.init.uniform_(self.weight, -d_in_rsqrt, d_in_rsqrt)
- nn.init.uniform_(self.bias, -d_in_rsqrt, d_in_rsqrt)
-
- def forward(self, x):
- if x.ndim != 3:
- raise ValueError("SNLinear requires a 3D input (batch, features, embedding).")
- if x.shape[-(self.weight.ndim - 1) :] != self.weight.shape[:-1]:
- raise ValueError("Input shape mismatch with weight dimensions.")
-
- x = x.transpose(0, 1) @ self.weight
- return x.transpose(0, 1) + self.bias
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+
+
+class SNLinear(nn.Module):
+ """Separate linear layers for each feature embedding."""
+
+ def __init__(self, n: int, in_features: int, out_features: int) -> None:
+ super().__init__()
+ self.weight = Parameter(torch.empty(n, in_features, out_features))
+ self.bias = Parameter(torch.empty(n, out_features))
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ d_in_rsqrt = self.weight.shape[-2] ** -0.5
+ nn.init.uniform_(self.weight, -d_in_rsqrt, d_in_rsqrt)
+ nn.init.uniform_(self.bias, -d_in_rsqrt, d_in_rsqrt)
+
+ def forward(self, x):
+ if x.ndim != 3:
+ raise ValueError("SNLinear requires a 3D input (batch, features, embedding).")
+ if x.shape[-(self.weight.ndim - 1) :] != self.weight.shape[:-1]:
+ raise ValueError("Input shape mismatch with weight dimensions.")
+
+ x = x.transpose(0, 1) @ self.weight
+ return x.transpose(0, 1) + self.bias
diff --git a/deeptabular/arch_utils/layer_utils/sparsemax.py b/deeptab/arch_utils/layer_utils/sparsemax.py
similarity index 100%
rename from deeptabular/arch_utils/layer_utils/sparsemax.py
rename to deeptab/arch_utils/layer_utils/sparsemax.py
diff --git a/deeptabular/arch_utils/learnable_ple.py b/deeptab/arch_utils/learnable_ple.py
similarity index 100%
rename from deeptabular/arch_utils/learnable_ple.py
rename to deeptab/arch_utils/learnable_ple.py
diff --git a/deeptabular/arch_utils/lstm_utils.py b/deeptab/arch_utils/lstm_utils.py
similarity index 100%
rename from deeptabular/arch_utils/lstm_utils.py
rename to deeptab/arch_utils/lstm_utils.py
diff --git a/deeptabular/arch_utils/mamba_utils/__init__.py b/deeptab/arch_utils/mamba_utils/__init__.py
similarity index 100%
rename from deeptabular/arch_utils/mamba_utils/__init__.py
rename to deeptab/arch_utils/mamba_utils/__init__.py
diff --git a/deeptabular/arch_utils/mamba_utils/init_weights.py b/deeptab/arch_utils/mamba_utils/init_weights.py
similarity index 96%
rename from deeptabular/arch_utils/mamba_utils/init_weights.py
rename to deeptab/arch_utils/mamba_utils/init_weights.py
index 31dc5f65..767d4214 100644
--- a/deeptabular/arch_utils/mamba_utils/init_weights.py
+++ b/deeptab/arch_utils/mamba_utils/init_weights.py
@@ -1,28 +1,28 @@
-import math
-
-import torch
-import torch.nn as nn
-
-# taken from https://github.com/state-spaces/mamba
-
-
-def _init_weights(
- module,
- n_layer,
- initializer_range=0.02, # Now only used for embedding layer.
- rescale_prenorm_residual=True,
- n_residuals_per_layer=1, # Change to 2 if we have MLP
-):
- if isinstance(module, nn.Linear):
- if module.bias is not None:
- if not getattr(module.bias, "_no_reinit", False):
- nn.init.zeros_(module.bias)
- elif isinstance(module, nn.Embedding):
- nn.init.normal_(module.weight, std=initializer_range)
-
- if rescale_prenorm_residual:
- for name, p in module.named_parameters():
- if name in ["out_proj.weight", "fc2.weight"]:
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
- with torch.no_grad():
- p /= math.sqrt(n_residuals_per_layer * n_layer)
+import math
+
+import torch
+import torch.nn as nn
+
+# taken from https://github.com/state-spaces/mamba
+
+
+def _init_weights(
+ module,
+ n_layer,
+ initializer_range=0.02, # Now only used for embedding layer.
+ rescale_prenorm_residual=True,
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
+):
+ if isinstance(module, nn.Linear):
+ if module.bias is not None:
+ if not getattr(module.bias, "_no_reinit", False):
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ nn.init.normal_(module.weight, std=initializer_range)
+
+ if rescale_prenorm_residual:
+ for name, p in module.named_parameters():
+ if name in ["out_proj.weight", "fc2.weight"]:
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+ with torch.no_grad():
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
diff --git a/deeptabular/arch_utils/mamba_utils/mamba_arch.py b/deeptab/arch_utils/mamba_utils/mamba_arch.py
similarity index 97%
rename from deeptabular/arch_utils/mamba_utils/mamba_arch.py
rename to deeptab/arch_utils/mamba_utils/mamba_arch.py
index 0293c15d..afe78662 100644
--- a/deeptabular/arch_utils/mamba_utils/mamba_arch.py
+++ b/deeptab/arch_utils/mamba_utils/mamba_arch.py
@@ -1,551 +1,551 @@
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from ..get_norm_fn import get_normalization_layer
-from ..layer_utils.normalization_layers import LayerNorm, LearnableLayerScaling, RMSNorm
-
-# Heavily inspired and mostly taken from https://github.com/alxndrTL/mamba.py
-
-
-class Mamba(nn.Module):
- """Mamba model composed of multiple MambaBlocks.
-
- Attributes:
- config (MambaConfig): Configuration object for the Mamba model.
- layers (nn.ModuleList): List of MambaBlocks constituting the model.
- """
-
- def __init__(
- self,
- config,
- ):
- super().__init__()
-
- self.layers = nn.ModuleList(
- [
- ResidualBlock(
- d_model=getattr(config, "d_model", 128),
- expand_factor=getattr(config, "expand_factor", 4),
- bias=getattr(config, "bias", True),
- d_conv=getattr(config, "d_conv", 4),
- conv_bias=getattr(config, "conv_bias", False),
- dropout=getattr(config, "dropout", 0.0),
- dt_rank=getattr(config, "dt_rank", "auto"),
- d_state=getattr(config, "d_state", 256),
- dt_scale=getattr(config, "dt_scale", 1.0),
- dt_init=getattr(config, "dt_init", "random"),
- dt_max=getattr(config, "dt_max", 0.1),
- dt_min=getattr(config, "dt_min", 1e-04),
- dt_init_floor=getattr(config, "dt_init_floor", 1e-04),
- norm=get_normalization_layer(config), # type: ignore
- activation=getattr(config, "activation", nn.SiLU()),
- bidirectional=getattr(config, "bidirectional", False),
- use_learnable_interaction=getattr(
- config, "use_learnable_interaction", False
- ),
- layer_norm_eps=getattr(config, "layer_norm_eps", 1e-5),
- AD_weight_decay=getattr(config, "AD_weight_decay", True),
- BC_layer_norm=getattr(config, "BC_layer_norm", False),
- use_pscan=getattr(config, "use_pscan", False),
- dilation=getattr(config, "dilation", 1),
- )
- for _ in range(getattr(config, "n_layers", 6))
- ]
- )
-
- def forward(self, x):
- for layer in self.layers:
- x = layer(x)
-
- return x
-
-
-class ResidualBlock(nn.Module):
- """Residual block composed of a MambaBlock and a normalization layer.
-
- Parameters
- ----------
- d_model : int, optional
- Dimension of the model input, by default 32.
- expand_factor : int, optional
- Expansion factor for the model, by default 2.
- bias : bool, optional
- Whether to use bias in the MambaBlock, by default False.
- d_conv : int, optional
- Dimension of the convolution layer in the MambaBlock, by default 16.
- conv_bias : bool, optional
- Whether to use bias in the convolution layer, by default True.
- dropout : float, optional
- Dropout rate for the layers, by default 0.01.
- dt_rank : Union[str, int], optional
- Rank for dynamic time components, 'auto' or an integer, by default 'auto'.
- d_state : int, optional
- Dimension of the state vector, by default 32.
- dt_scale : float, optional
- Scale factor for dynamic time components, by default 1.0.
- dt_init : str, optional
- Initialization strategy for dynamic time components, by default 'random'.
- dt_max : float, optional
- Maximum value for dynamic time components, by default 0.1.
- dt_min : float, optional
- Minimum value for dynamic time components, by default 1e-03.
- dt_init_floor : float, optional
- Floor value for initialization of dynamic time components, by default 1e-04.
- norm : callable, optional
- Normalization layer, by default RMSNorm.
- activation : callable, optional
- Activation function used in the MambaBlock, by default `F.silu`.
- bidirectional : bool, optional
- Whether the block is bidirectional, by default False.
- use_learnable_interaction : bool, optional
- Whether to use learnable interactions, by default False.
- layer_norm_eps : float, optional
- Epsilon for layer normalization, by default 1e-05.
- AD_weight_decay : bool, optional
- Whether to apply weight decay in adaptive dynamics, by default False.
- BC_layer_norm : bool, optional
- Whether to use layer normalization for batch compatibility, by default False.
- use_pscan : bool, optional
- Whether to use PSCAN, by default False.
-
- Attributes
- ----------
- layers : MambaBlock
- The main MambaBlock layers for processing input.
- norm : callable
- Normalization layer applied before the MambaBlock.
-
- Methods
- -------
- forward(x)
- Performs a forward pass through the block and returns the output.
-
- Raises
- ------
- ValueError
- If the provided normalization layer is not valid.
- """
-
- def __init__(
- self,
- d_model=32,
- expand_factor=2,
- bias=False,
- d_conv=16,
- conv_bias=True,
- dropout=0.01,
- dt_rank="auto",
- d_state=32,
- dt_scale=1.0,
- dt_init="random",
- dt_max=0.1,
- dt_min=1e-03,
- dt_init_floor=1e-04,
- norm=RMSNorm,
- activation=F.silu,
- bidirectional=False,
- use_learnable_interaction=False,
- layer_norm_eps=1e-05,
- AD_weight_decay=False,
- BC_layer_norm=False,
- use_pscan=False,
- dilation=1,
- ):
- super().__init__()
-
- VALID_NORMALIZATION_LAYERS = {
- "RMSNorm": RMSNorm,
- "LayerNorm": LayerNorm,
- "LearnableLayerScaling": LearnableLayerScaling,
- }
-
- # Check if the provided normalization layer is valid
- if isinstance(norm, type) and norm.__name__ not in VALID_NORMALIZATION_LAYERS:
- raise ValueError(
- f"Invalid normalization layer: {norm.__name__}. "
- f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
- )
- elif isinstance(norm, str) and norm not in VALID_NORMALIZATION_LAYERS:
- raise ValueError(
- f"Invalid normalization layer: {norm}. "
- f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
- )
-
- if dt_rank == "auto":
- dt_rank = math.ceil(d_model / 16)
-
- self.layers = MambaBlock(
- d_model=d_model,
- expand_factor=expand_factor,
- bias=bias,
- d_conv=d_conv,
- conv_bias=conv_bias,
- dropout=dropout,
- dt_rank=dt_rank, # type: ignore
- d_state=d_state,
- dt_scale=dt_scale,
- dt_init=dt_init,
- dt_max=dt_max,
- dt_min=dt_min,
- dt_init_floor=dt_init_floor,
- activation=activation,
- bidirectional=bidirectional,
- use_learnable_interaction=use_learnable_interaction,
- layer_norm_eps=layer_norm_eps,
- AD_weight_decay=AD_weight_decay,
- BC_layer_norm=BC_layer_norm,
- use_pscan=use_pscan,
- dilation=dilation,
- )
- self.norm = norm
-
- def forward(self, x):
- """Forward pass through the residual block.
-
- Parameters
- ----------
- x : torch.Tensor
- Input tensor to the block.
-
- Returns
- -------
- torch.Tensor
- Output tensor after applying the residual connection and MambaBlock.
- """
- output = self.layers(self.norm(x)) + x
- return output
-
-
-class MambaBlock(nn.Module):
- """MambaBlock module containing the main computational components for processing input.
-
- Parameters
- ----------
- d_model : int, optional
- Dimension of the model input, by default 32.
- expand_factor : int, optional
- Factor by which the input is expanded in the block, by default 2.
- bias : bool, optional
- Whether to use bias in the linear projections, by default False.
- d_conv : int, optional
- Dimension of the convolution layer, by default 16.
- conv_bias : bool, optional
- Whether to use bias in the convolution layer, by default True.
- dropout : float, optional
- Dropout rate applied to the layers, by default 0.01.
- dt_rank : Union[str, int], optional
- Rank for dynamic time components, either 'auto' or an integer, by default 'auto'.
- d_state : int, optional
- Dimensionality of the state vector, by default 32.
- dt_scale : float, optional
- Scale factor applied to the dynamic time component, by default 1.0.
- dt_init : str, optional
- Initialization strategy for the dynamic time component, by default 'random'.
- dt_max : float, optional
- Maximum value for dynamic time component initialization, by default 0.1.
- dt_min : float, optional
- Minimum value for dynamic time component initialization, by default 1e-03.
- dt_init_floor : float, optional
- Floor value for dynamic time component initialization, by default 1e-04.
- activation : callable, optional
- Activation function applied in the block, by default `F.silu`.
- bidirectional : bool, optional
- Whether the block is bidirectional, by default False.
- use_learnable_interaction : bool, optional
- Whether to use learnable feature interaction, by default False.
- layer_norm_eps : float, optional
- Epsilon for layer normalization, by default 1e-05.
- AD_weight_decay : bool, optional
- Whether to apply weight decay in adaptive dynamics, by default False.
- BC_layer_norm : bool, optional
- Whether to use layer normalization for batch compatibility, by default False.
- use_pscan : bool, optional
- Whether to use the PSCAN mechanism, by default False.
-
- Attributes
- ----------
- in_proj : nn.Linear
- Linear projection applied to the input tensor.
- conv1d : nn.Conv1d
- 1D convolutional layer for processing input.
- x_proj : nn.Linear
- Linear projection applied to input-dependent tensors.
- dt_proj : nn.Linear
- Linear projection for the dynamical time component.
- A_log : nn.Parameter
- Logarithmically stored tensor A for internal dynamics.
- D : nn.Parameter
- Tensor for the D component of the model's dynamics.
- out_proj : nn.Linear
- Linear projection applied to the output.
- learnable_interaction : LearnableFeatureInteraction
- Layer for learnable feature interactions, if `use_learnable_interaction` is True.
-
- Methods
- -------
- forward(x)
- Performs a forward pass through the MambaBlock.
- """
-
- def __init__(
- self,
- d_model=32,
- expand_factor=2,
- bias=False,
- d_conv=16,
- conv_bias=True,
- dropout=0.01,
- dt_rank="auto",
- d_state=32,
- dt_scale=1.0,
- dt_init="random",
- dt_max=0.1,
- dt_min=1e-03,
- dt_init_floor=1e-04,
- activation=F.silu,
- bidirectional=False,
- use_learnable_interaction=False,
- layer_norm_eps=1e-05,
- AD_weight_decay=False,
- BC_layer_norm=False,
- use_pscan=False,
- dilation=1,
- ):
- super().__init__()
-
- self.use_pscan = use_pscan
-
- if self.use_pscan:
- try:
- from mambapy.pscan import pscan # type: ignore
-
- self.pscan = pscan # Store the imported pscan function
- except ImportError:
- self.pscan = None # Set to None if pscan is not available
- print(
- "The 'mambapy' package is not installed. Please install it by running:\n"
- "pip install mambapy"
- )
- else:
- self.pscan = None
-
- self.d_inner = d_model * expand_factor
- self.bidirectional = bidirectional
- self.use_learnable_interaction = use_learnable_interaction
-
- self.in_proj_fwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
- if self.bidirectional:
- self.in_proj_bwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
-
- self.conv1d_fwd = nn.Conv1d(
- in_channels=self.d_inner,
- out_channels=self.d_inner,
- kernel_size=d_conv,
- bias=conv_bias,
- groups=self.d_inner,
- padding=d_conv - 1,
- )
- if self.bidirectional:
- self.conv1d_bwd = nn.Conv1d(
- in_channels=self.d_inner,
- out_channels=self.d_inner,
- kernel_size=d_conv,
- bias=conv_bias,
- groups=self.d_inner,
- padding=d_conv - 1,
- dilation=dilation,
- )
-
- self.dropout = nn.Dropout(dropout)
- self.activation = activation
-
- if self.use_learnable_interaction:
- self.learnable_interaction = LearnableFeatureInteraction(self.d_inner)
-
- self.x_proj_fwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False) # type: ignore
- if self.bidirectional:
- self.x_proj_bwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False) # type: ignore
-
- self.dt_proj_fwd = nn.Linear(dt_rank, self.d_inner, bias=True) # type: ignore
- if self.bidirectional:
- self.dt_proj_bwd = nn.Linear(dt_rank, self.d_inner, bias=True) # type: ignore
-
- dt_init_std = dt_rank**-0.5 * dt_scale # type: ignore
- if dt_init == "constant":
- nn.init.constant_(self.dt_proj_fwd.weight, dt_init_std)
- if self.bidirectional:
- nn.init.constant_(self.dt_proj_bwd.weight, dt_init_std)
- elif dt_init == "random":
- nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
- if self.bidirectional:
- nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
- else:
- raise NotImplementedError
-
- dt_fwd = torch.exp(
- torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
- + math.log(dt_min)
- ).clamp(min=dt_init_floor)
- inv_dt_fwd = dt_fwd + torch.log(-torch.expm1(-dt_fwd))
- with torch.no_grad():
- self.dt_proj_fwd.bias.copy_(inv_dt_fwd)
-
- if self.bidirectional:
- dt_bwd = torch.exp(
- torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
- + math.log(dt_min)
- ).clamp(min=dt_init_floor)
- inv_dt_bwd = dt_bwd + torch.log(-torch.expm1(-dt_bwd))
- with torch.no_grad():
- self.dt_proj_bwd.bias.copy_(inv_dt_bwd)
-
- A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
- self.A_log_fwd = nn.Parameter(torch.log(A))
- self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
-
- if self.bidirectional:
- self.A_log_bwd = nn.Parameter(torch.log(A))
- self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
-
- if not AD_weight_decay:
- self.A_log_fwd._no_weight_decay = True # type: ignore
- self.D_fwd._no_weight_decay = True # type: ignore
-
- if self.bidirectional:
- if not AD_weight_decay:
- self.A_log_bwd._no_weight_decay = True # type: ignore
- self.D_bwd._no_weight_decay = True # type: ignore
-
- self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
- self.dt_rank = dt_rank
- self.d_state = d_state
-
- if BC_layer_norm:
- self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps) # type: ignore
- self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
- self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
- else:
- self.dt_layernorm = None
- self.B_layernorm = None
- self.C_layernorm = None
-
- def forward(self, x):
- _, L, _ = x.shape
-
- xz_fwd = self.in_proj_fwd(x)
- x_fwd, z_fwd = xz_fwd.chunk(2, dim=-1)
-
- x_fwd = x_fwd.transpose(1, 2)
- x_fwd = self.conv1d_fwd(x_fwd)[:, :, :L]
- x_fwd = x_fwd.transpose(1, 2)
-
- if self.bidirectional:
- xz_bwd = self.in_proj_bwd(x)
- x_bwd, z_bwd = xz_bwd.chunk(2, dim=-1)
-
- x_bwd = x_bwd.transpose(1, 2)
- x_bwd = self.conv1d_bwd(x_bwd)[:, :, :L]
- x_bwd = x_bwd.transpose(1, 2)
-
- if self.use_learnable_interaction:
- x_fwd = self.learnable_interaction(x_fwd)
- if self.bidirectional:
- x_bwd = self.learnable_interaction(x_bwd) # type: ignore
-
- x_fwd = self.activation(x_fwd)
- x_fwd = self.dropout(x_fwd)
- y_fwd = self.ssm(x_fwd, forward=True)
-
- if self.bidirectional:
- x_bwd = self.activation(x_bwd) # type: ignore
- x_bwd = self.dropout(x_bwd)
- y_bwd = self.ssm(torch.flip(x_bwd, [1]), forward=False)
- y = y_fwd + torch.flip(y_bwd, [1])
- y = y / 2
- else:
- y = y_fwd
-
- z_fwd = self.activation(z_fwd)
- z_fwd = self.dropout(z_fwd)
-
- output = y * z_fwd
- output = self.out_proj(output)
-
- return output
-
- def _apply_layernorms(self, dt, B, C):
- if self.dt_layernorm is not None:
- dt = self.dt_layernorm(dt)
- if self.B_layernorm is not None:
- B = self.B_layernorm(B)
- if self.C_layernorm is not None:
- C = self.C_layernorm(C)
- return dt, B, C
-
- def ssm(self, x, forward=True):
- if forward:
- A = -torch.exp(self.A_log_fwd.float())
- D = self.D_fwd.float()
- deltaBC = self.x_proj_fwd(x)
- delta, B, C = torch.split(
- deltaBC,
- [self.dt_rank, self.d_state, self.d_state], # type: ignore
- dim=-1,
- )
- delta, B, C = self._apply_layernorms(delta, B, C)
- delta = F.softplus(self.dt_proj_fwd(delta))
- else:
- A = -torch.exp(self.A_log_bwd.float())
- D = self.D_bwd.float()
- deltaBC = self.x_proj_bwd(x)
- delta, B, C = torch.split(
- deltaBC,
- [self.dt_rank, self.d_state, self.d_state], # type: ignore
- dim=-1,
- )
- delta, B, C = self._apply_layernorms(delta, B, C)
- delta = F.softplus(self.dt_proj_bwd(delta))
-
- y = self.selective_scan_seq(x, delta, A, B, C, D)
- return y
-
- def selective_scan_seq(self, x, delta, A, B, C, D):
- _, L, _ = x.shape
-
- deltaA = torch.exp(delta.unsqueeze(-1) * A)
- deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
-
- BX = deltaB * (x.unsqueeze(-1))
-
- if self.use_pscan:
- hs = self.pscan(deltaA, BX) # type: ignore
- else:
- h = torch.zeros(x.size(0), self.d_inner, self.d_state, device=deltaA.device)
- hs = []
-
- for t in range(0, L):
- h = deltaA[:, t] * h + BX[:, t]
- hs.append(h)
-
- hs = torch.stack(hs, dim=1)
-
- y = (hs @ C.unsqueeze(-1)).squeeze(3)
-
- y = y + D * x
-
- return y
-
-
-class LearnableFeatureInteraction(nn.Module):
- def __init__(self, n_vars):
- super().__init__()
- self.interaction_weights = nn.Parameter(torch.Tensor(n_vars, n_vars))
- nn.init.xavier_uniform_(self.interaction_weights)
-
- def forward(self, x):
- batch_size, n_vars, d_model = x.size()
- interactions = torch.matmul(x, self.interaction_weights)
- return interactions.view(batch_size, n_vars, d_model)
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..get_norm_fn import get_normalization_layer
+from ..layer_utils.normalization_layers import LayerNorm, LearnableLayerScaling, RMSNorm
+
+# Heavily inspired and mostly taken from https://github.com/alxndrTL/mamba.py
+
+
+class Mamba(nn.Module):
+ """Mamba model composed of multiple MambaBlocks.
+
+ Attributes:
+ config (MambaConfig): Configuration object for the Mamba model.
+ layers (nn.ModuleList): List of MambaBlocks constituting the model.
+ """
+
+ def __init__(
+ self,
+ config,
+ ):
+ super().__init__()
+
+ self.layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ d_model=getattr(config, "d_model", 128),
+ expand_factor=getattr(config, "expand_factor", 4),
+ bias=getattr(config, "bias", True),
+ d_conv=getattr(config, "d_conv", 4),
+ conv_bias=getattr(config, "conv_bias", False),
+ dropout=getattr(config, "dropout", 0.0),
+ dt_rank=getattr(config, "dt_rank", "auto"),
+ d_state=getattr(config, "d_state", 256),
+ dt_scale=getattr(config, "dt_scale", 1.0),
+ dt_init=getattr(config, "dt_init", "random"),
+ dt_max=getattr(config, "dt_max", 0.1),
+ dt_min=getattr(config, "dt_min", 1e-04),
+ dt_init_floor=getattr(config, "dt_init_floor", 1e-04),
+ norm=get_normalization_layer(config), # type: ignore
+ activation=getattr(config, "activation", nn.SiLU()),
+ bidirectional=getattr(config, "bidirectional", False),
+ use_learnable_interaction=getattr(
+ config, "use_learnable_interaction", False
+ ),
+ layer_norm_eps=getattr(config, "layer_norm_eps", 1e-5),
+ AD_weight_decay=getattr(config, "AD_weight_decay", True),
+ BC_layer_norm=getattr(config, "BC_layer_norm", False),
+ use_pscan=getattr(config, "use_pscan", False),
+ dilation=getattr(config, "dilation", 1),
+ )
+ for _ in range(getattr(config, "n_layers", 6))
+ ]
+ )
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+
+ return x
+
+
+class ResidualBlock(nn.Module):
+ """Residual block composed of a MambaBlock and a normalization layer.
+
+ Parameters
+ ----------
+ d_model : int, optional
+ Dimension of the model input, by default 32.
+ expand_factor : int, optional
+ Expansion factor for the model, by default 2.
+ bias : bool, optional
+ Whether to use bias in the MambaBlock, by default False.
+ d_conv : int, optional
+ Dimension of the convolution layer in the MambaBlock, by default 16.
+ conv_bias : bool, optional
+ Whether to use bias in the convolution layer, by default True.
+ dropout : float, optional
+ Dropout rate for the layers, by default 0.01.
+ dt_rank : Union[str, int], optional
+ Rank for dynamic time components, 'auto' or an integer, by default 'auto'.
+ d_state : int, optional
+ Dimension of the state vector, by default 32.
+ dt_scale : float, optional
+ Scale factor for dynamic time components, by default 1.0.
+ dt_init : str, optional
+ Initialization strategy for dynamic time components, by default 'random'.
+ dt_max : float, optional
+ Maximum value for dynamic time components, by default 0.1.
+ dt_min : float, optional
+ Minimum value for dynamic time components, by default 1e-03.
+ dt_init_floor : float, optional
+ Floor value for initialization of dynamic time components, by default 1e-04.
+ norm : callable, optional
+ Normalization layer, by default RMSNorm.
+ activation : callable, optional
+ Activation function used in the MambaBlock, by default `F.silu`.
+ bidirectional : bool, optional
+ Whether the block is bidirectional, by default False.
+ use_learnable_interaction : bool, optional
+ Whether to use learnable interactions, by default False.
+ layer_norm_eps : float, optional
+ Epsilon for layer normalization, by default 1e-05.
+ AD_weight_decay : bool, optional
+ Whether to apply weight decay in adaptive dynamics, by default False.
+ BC_layer_norm : bool, optional
+ Whether to use layer normalization for batch compatibility, by default False.
+ use_pscan : bool, optional
+ Whether to use PSCAN, by default False.
+
+ Attributes
+ ----------
+ layers : MambaBlock
+ The main MambaBlock layers for processing input.
+ norm : callable
+ Normalization layer applied before the MambaBlock.
+
+ Methods
+ -------
+ forward(x)
+ Performs a forward pass through the block and returns the output.
+
+ Raises
+ ------
+ ValueError
+ If the provided normalization layer is not valid.
+ """
+
+ def __init__(
+ self,
+ d_model=32,
+ expand_factor=2,
+ bias=False,
+ d_conv=16,
+ conv_bias=True,
+ dropout=0.01,
+ dt_rank="auto",
+ d_state=32,
+ dt_scale=1.0,
+ dt_init="random",
+ dt_max=0.1,
+ dt_min=1e-03,
+ dt_init_floor=1e-04,
+ norm=RMSNorm,
+ activation=F.silu,
+ bidirectional=False,
+ use_learnable_interaction=False,
+ layer_norm_eps=1e-05,
+ AD_weight_decay=False,
+ BC_layer_norm=False,
+ use_pscan=False,
+ dilation=1,
+ ):
+ super().__init__()
+
+ VALID_NORMALIZATION_LAYERS = {
+ "RMSNorm": RMSNorm,
+ "LayerNorm": LayerNorm,
+ "LearnableLayerScaling": LearnableLayerScaling,
+ }
+
+ # Check if the provided normalization layer is valid
+ if isinstance(norm, type) and norm.__name__ not in VALID_NORMALIZATION_LAYERS:
+ raise ValueError(
+ f"Invalid normalization layer: {norm.__name__}. "
+ f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
+ )
+ elif isinstance(norm, str) and norm not in VALID_NORMALIZATION_LAYERS:
+ raise ValueError(
+ f"Invalid normalization layer: {norm}. "
+ f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
+ )
+
+ if dt_rank == "auto":
+ dt_rank = math.ceil(d_model / 16)
+
+ self.layers = MambaBlock(
+ d_model=d_model,
+ expand_factor=expand_factor,
+ bias=bias,
+ d_conv=d_conv,
+ conv_bias=conv_bias,
+ dropout=dropout,
+ dt_rank=dt_rank, # type: ignore
+ d_state=d_state,
+ dt_scale=dt_scale,
+ dt_init=dt_init,
+ dt_max=dt_max,
+ dt_min=dt_min,
+ dt_init_floor=dt_init_floor,
+ activation=activation,
+ bidirectional=bidirectional,
+ use_learnable_interaction=use_learnable_interaction,
+ layer_norm_eps=layer_norm_eps,
+ AD_weight_decay=AD_weight_decay,
+ BC_layer_norm=BC_layer_norm,
+ use_pscan=use_pscan,
+ dilation=dilation,
+ )
+ self.norm = norm
+
+ def forward(self, x):
+ """Forward pass through the residual block.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor to the block.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor after applying the residual connection and MambaBlock.
+ """
+ output = self.layers(self.norm(x)) + x
+ return output
+
+
+class MambaBlock(nn.Module):
+ """MambaBlock module containing the main computational components for processing input.
+
+ Parameters
+ ----------
+ d_model : int, optional
+ Dimension of the model input, by default 32.
+ expand_factor : int, optional
+ Factor by which the input is expanded in the block, by default 2.
+ bias : bool, optional
+ Whether to use bias in the linear projections, by default False.
+ d_conv : int, optional
+ Dimension of the convolution layer, by default 16.
+ conv_bias : bool, optional
+ Whether to use bias in the convolution layer, by default True.
+ dropout : float, optional
+ Dropout rate applied to the layers, by default 0.01.
+ dt_rank : Union[str, int], optional
+ Rank for dynamic time components, either 'auto' or an integer, by default 'auto'.
+ d_state : int, optional
+ Dimensionality of the state vector, by default 32.
+ dt_scale : float, optional
+ Scale factor applied to the dynamic time component, by default 1.0.
+ dt_init : str, optional
+ Initialization strategy for the dynamic time component, by default 'random'.
+ dt_max : float, optional
+ Maximum value for dynamic time component initialization, by default 0.1.
+ dt_min : float, optional
+ Minimum value for dynamic time component initialization, by default 1e-03.
+ dt_init_floor : float, optional
+ Floor value for dynamic time component initialization, by default 1e-04.
+ activation : callable, optional
+ Activation function applied in the block, by default `F.silu`.
+ bidirectional : bool, optional
+ Whether the block is bidirectional, by default False.
+ use_learnable_interaction : bool, optional
+ Whether to use learnable feature interaction, by default False.
+ layer_norm_eps : float, optional
+ Epsilon for layer normalization, by default 1e-05.
+ AD_weight_decay : bool, optional
+ Whether to apply weight decay in adaptive dynamics, by default False.
+ BC_layer_norm : bool, optional
+ Whether to use layer normalization for batch compatibility, by default False.
+ use_pscan : bool, optional
+ Whether to use the PSCAN mechanism, by default False.
+
+ Attributes
+ ----------
+ in_proj : nn.Linear
+ Linear projection applied to the input tensor.
+ conv1d : nn.Conv1d
+ 1D convolutional layer for processing input.
+ x_proj : nn.Linear
+ Linear projection applied to input-dependent tensors.
+ dt_proj : nn.Linear
+ Linear projection for the dynamical time component.
+ A_log : nn.Parameter
+ Logarithmically stored tensor A for internal dynamics.
+ D : nn.Parameter
+ Tensor for the D component of the model's dynamics.
+ out_proj : nn.Linear
+ Linear projection applied to the output.
+ learnable_interaction : LearnableFeatureInteraction
+ Layer for learnable feature interactions, if `use_learnable_interaction` is True.
+
+ Methods
+ -------
+ forward(x)
+ Performs a forward pass through the MambaBlock.
+ """
+
+ def __init__(
+ self,
+ d_model=32,
+ expand_factor=2,
+ bias=False,
+ d_conv=16,
+ conv_bias=True,
+ dropout=0.01,
+ dt_rank="auto",
+ d_state=32,
+ dt_scale=1.0,
+ dt_init="random",
+ dt_max=0.1,
+ dt_min=1e-03,
+ dt_init_floor=1e-04,
+ activation=F.silu,
+ bidirectional=False,
+ use_learnable_interaction=False,
+ layer_norm_eps=1e-05,
+ AD_weight_decay=False,
+ BC_layer_norm=False,
+ use_pscan=False,
+ dilation=1,
+ ):
+ super().__init__()
+
+ self.use_pscan = use_pscan
+
+ if self.use_pscan:
+ try:
+ from mambapy.pscan import pscan # type: ignore
+
+ self.pscan = pscan # Store the imported pscan function
+ except ImportError:
+ self.pscan = None # Set to None if pscan is not available
+ print(
+ "The 'mambapy' package is not installed. Please install it by running:\n"
+ "pip install mambapy"
+ )
+ else:
+ self.pscan = None
+
+ self.d_inner = d_model * expand_factor
+ self.bidirectional = bidirectional
+ self.use_learnable_interaction = use_learnable_interaction
+
+ self.in_proj_fwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
+ if self.bidirectional:
+ self.in_proj_bwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
+
+ self.conv1d_fwd = nn.Conv1d(
+ in_channels=self.d_inner,
+ out_channels=self.d_inner,
+ kernel_size=d_conv,
+ bias=conv_bias,
+ groups=self.d_inner,
+ padding=d_conv - 1,
+ )
+ if self.bidirectional:
+ self.conv1d_bwd = nn.Conv1d(
+ in_channels=self.d_inner,
+ out_channels=self.d_inner,
+ kernel_size=d_conv,
+ bias=conv_bias,
+ groups=self.d_inner,
+ padding=d_conv - 1,
+ dilation=dilation,
+ )
+
+ self.dropout = nn.Dropout(dropout)
+ self.activation = activation
+
+ if self.use_learnable_interaction:
+ self.learnable_interaction = LearnableFeatureInteraction(self.d_inner)
+
+ self.x_proj_fwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False) # type: ignore
+ if self.bidirectional:
+ self.x_proj_bwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False) # type: ignore
+
+ self.dt_proj_fwd = nn.Linear(dt_rank, self.d_inner, bias=True) # type: ignore
+ if self.bidirectional:
+ self.dt_proj_bwd = nn.Linear(dt_rank, self.d_inner, bias=True) # type: ignore
+
+ dt_init_std = dt_rank**-0.5 * dt_scale # type: ignore
+ if dt_init == "constant":
+ nn.init.constant_(self.dt_proj_fwd.weight, dt_init_std)
+ if self.bidirectional:
+ nn.init.constant_(self.dt_proj_bwd.weight, dt_init_std)
+ elif dt_init == "random":
+ nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
+ if self.bidirectional:
+ nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
+ else:
+ raise NotImplementedError
+
+ dt_fwd = torch.exp(
+ torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+ + math.log(dt_min)
+ ).clamp(min=dt_init_floor)
+ inv_dt_fwd = dt_fwd + torch.log(-torch.expm1(-dt_fwd))
+ with torch.no_grad():
+ self.dt_proj_fwd.bias.copy_(inv_dt_fwd)
+
+ if self.bidirectional:
+ dt_bwd = torch.exp(
+ torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+ + math.log(dt_min)
+ ).clamp(min=dt_init_floor)
+ inv_dt_bwd = dt_bwd + torch.log(-torch.expm1(-dt_bwd))
+ with torch.no_grad():
+ self.dt_proj_bwd.bias.copy_(inv_dt_bwd)
+
+ A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
+ self.A_log_fwd = nn.Parameter(torch.log(A))
+ self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
+
+ if self.bidirectional:
+ self.A_log_bwd = nn.Parameter(torch.log(A))
+ self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+ if not AD_weight_decay:
+ self.A_log_fwd._no_weight_decay = True # type: ignore
+ self.D_fwd._no_weight_decay = True # type: ignore
+
+ if self.bidirectional:
+ if not AD_weight_decay:
+ self.A_log_bwd._no_weight_decay = True # type: ignore
+ self.D_bwd._no_weight_decay = True # type: ignore
+
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
+ self.dt_rank = dt_rank
+ self.d_state = d_state
+
+ if BC_layer_norm:
+ self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps) # type: ignore
+ self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+ self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+ else:
+ self.dt_layernorm = None
+ self.B_layernorm = None
+ self.C_layernorm = None
+
+ def forward(self, x):
+ _, L, _ = x.shape
+
+ xz_fwd = self.in_proj_fwd(x)
+ x_fwd, z_fwd = xz_fwd.chunk(2, dim=-1)
+
+ x_fwd = x_fwd.transpose(1, 2)
+ x_fwd = self.conv1d_fwd(x_fwd)[:, :, :L]
+ x_fwd = x_fwd.transpose(1, 2)
+
+ if self.bidirectional:
+ xz_bwd = self.in_proj_bwd(x)
+ x_bwd, z_bwd = xz_bwd.chunk(2, dim=-1)
+
+ x_bwd = x_bwd.transpose(1, 2)
+ x_bwd = self.conv1d_bwd(x_bwd)[:, :, :L]
+ x_bwd = x_bwd.transpose(1, 2)
+
+ if self.use_learnable_interaction:
+ x_fwd = self.learnable_interaction(x_fwd)
+ if self.bidirectional:
+ x_bwd = self.learnable_interaction(x_bwd) # type: ignore
+
+ x_fwd = self.activation(x_fwd)
+ x_fwd = self.dropout(x_fwd)
+ y_fwd = self.ssm(x_fwd, forward=True)
+
+ if self.bidirectional:
+ x_bwd = self.activation(x_bwd) # type: ignore
+ x_bwd = self.dropout(x_bwd)
+ y_bwd = self.ssm(torch.flip(x_bwd, [1]), forward=False)
+ y = y_fwd + torch.flip(y_bwd, [1])
+ y = y / 2
+ else:
+ y = y_fwd
+
+ z_fwd = self.activation(z_fwd)
+ z_fwd = self.dropout(z_fwd)
+
+ output = y * z_fwd
+ output = self.out_proj(output)
+
+ return output
+
+ def _apply_layernorms(self, dt, B, C):
+ if self.dt_layernorm is not None:
+ dt = self.dt_layernorm(dt)
+ if self.B_layernorm is not None:
+ B = self.B_layernorm(B)
+ if self.C_layernorm is not None:
+ C = self.C_layernorm(C)
+ return dt, B, C
+
+ def ssm(self, x, forward=True):
+ if forward:
+ A = -torch.exp(self.A_log_fwd.float())
+ D = self.D_fwd.float()
+ deltaBC = self.x_proj_fwd(x)
+ delta, B, C = torch.split(
+ deltaBC,
+ [self.dt_rank, self.d_state, self.d_state], # type: ignore
+ dim=-1,
+ )
+ delta, B, C = self._apply_layernorms(delta, B, C)
+ delta = F.softplus(self.dt_proj_fwd(delta))
+ else:
+ A = -torch.exp(self.A_log_bwd.float())
+ D = self.D_bwd.float()
+ deltaBC = self.x_proj_bwd(x)
+ delta, B, C = torch.split(
+ deltaBC,
+ [self.dt_rank, self.d_state, self.d_state], # type: ignore
+ dim=-1,
+ )
+ delta, B, C = self._apply_layernorms(delta, B, C)
+ delta = F.softplus(self.dt_proj_bwd(delta))
+
+ y = self.selective_scan_seq(x, delta, A, B, C, D)
+ return y
+
+ def selective_scan_seq(self, x, delta, A, B, C, D):
+ _, L, _ = x.shape
+
+ deltaA = torch.exp(delta.unsqueeze(-1) * A)
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
+
+ BX = deltaB * (x.unsqueeze(-1))
+
+ if self.use_pscan:
+ hs = self.pscan(deltaA, BX) # type: ignore
+ else:
+ h = torch.zeros(x.size(0), self.d_inner, self.d_state, device=deltaA.device)
+ hs = []
+
+ for t in range(0, L):
+ h = deltaA[:, t] * h + BX[:, t]
+ hs.append(h)
+
+ hs = torch.stack(hs, dim=1)
+
+ y = (hs @ C.unsqueeze(-1)).squeeze(3)
+
+ y = y + D * x
+
+ return y
+
+
+class LearnableFeatureInteraction(nn.Module):
+ def __init__(self, n_vars):
+ super().__init__()
+ self.interaction_weights = nn.Parameter(torch.Tensor(n_vars, n_vars))
+ nn.init.xavier_uniform_(self.interaction_weights)
+
+ def forward(self, x):
+ batch_size, n_vars, d_model = x.size()
+ interactions = torch.matmul(x, self.interaction_weights)
+ return interactions.view(batch_size, n_vars, d_model)
diff --git a/deeptabular/arch_utils/mamba_utils/mamba_original.py b/deeptab/arch_utils/mamba_utils/mamba_original.py
similarity index 97%
rename from deeptabular/arch_utils/mamba_utils/mamba_original.py
rename to deeptab/arch_utils/mamba_utils/mamba_original.py
index a5a95727..0746b9f1 100644
--- a/deeptabular/arch_utils/mamba_utils/mamba_original.py
+++ b/deeptab/arch_utils/mamba_utils/mamba_original.py
@@ -1,213 +1,213 @@
-# black: noqa
-
-import torch
-import torch.nn as nn
-
-from ..get_norm_fn import get_normalization_layer
-from ..layer_utils.normalization_layers import (
- BatchNorm,
- GroupNorm,
- InstanceNorm,
- LayerNorm,
- LearnableLayerScaling,
- RMSNorm,
-)
-from .init_weights import _init_weights
-
-
-class ResidualBlock(nn.Module):
- """Residual block composed of a MambaBlock and a normalization layer.
-
- Attributes:
- layers (MambaBlock): MambaBlock layers.
- norm (RMSNorm): Normalization layer.
- """
-
- MambaBlock = None # Declare MambaBlock at the class level
-
- def __init__(
- self,
- d_model=32,
- expand_factor=2,
- bias=False,
- d_conv=16,
- conv_bias=True,
- d_state=32,
- dt_max=0.1,
- dt_min=1e-03,
- dt_init_floor=1e-04,
- norm=RMSNorm,
- layer_idx=0,
- mamba_version="mamba1",
- ):
- super().__init__()
-
- # Lazy import for Mamba and only import if it's None
- if ResidualBlock.MambaBlock is None:
- self._lazy_import_mamba(mamba_version)
-
- VALID_NORMALIZATION_LAYERS = {
- "RMSNorm": RMSNorm,
- "LayerNorm": LayerNorm,
- "LearnableLayerScaling": LearnableLayerScaling,
- "BatchNorm": BatchNorm,
- "InstanceNorm": InstanceNorm,
- "GroupNorm": GroupNorm,
- }
-
- # Check if the provided normalization layer is valid
- if isinstance(norm, type) and norm.__name__ not in VALID_NORMALIZATION_LAYERS:
- raise ValueError(
- f"Invalid normalization layer: {norm.__name__}. "
- f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
- )
- elif isinstance(norm, str) and norm not in VALID_NORMALIZATION_LAYERS:
- raise ValueError(
- f"Invalid normalization layer: {norm}. "
- f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
- )
-
- # Use the imported MambaBlock to create layers
- self.layers = ResidualBlock.MambaBlock(
- d_model=d_model,
- d_state=d_state,
- d_conv=d_conv,
- expand=expand_factor,
- dt_min=dt_min,
- dt_max=dt_max,
- dt_init_floor=dt_init_floor,
- conv_bias=conv_bias,
- bias=bias,
- layer_idx=layer_idx,
- ) # type: ignore
- self.norm = norm
-
- def _lazy_import_mamba(self, mamba_version):
- """Lazily import Mamba or Mamba2 based on the provided version and alias it."""
- if ResidualBlock.MambaBlock is None:
- try:
- if mamba_version == "mamba1":
- from mamba_ssm import Mamba as MambaBlock # type: ignore
-
- ResidualBlock.MambaBlock = MambaBlock
- print("Successfully imported Mamba (version 1)")
- elif mamba_version == "mamba2":
- from mamba_ssm import Mamba2 as MambaBlock # type: ignore
-
- ResidualBlock.MambaBlock = MambaBlock
- print("Successfully imported Mamba2")
- else:
- raise ValueError(f"Invalid mamba_version: {mamba_version}. Choose 'mamba1' or 'mamba2'.")
- except ImportError:
- raise ImportError(
- f"Failed to import {mamba_version}. Please ensure the correct version is installed."
- ) from None
-
- def forward(self, x):
- output = self.layers(self.norm(x)) + x
- return output
-
-
-class MambaOriginal(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- VALID_NORMALIZATION_LAYERS = {
- "RMSNorm": RMSNorm,
- "LayerNorm": LayerNorm,
- "LearnableLayerScaling": LearnableLayerScaling,
- "BatchNorm": BatchNorm,
- "InstanceNorm": InstanceNorm,
- "GroupNorm": GroupNorm,
- }
-
- # Get normalization layer from config
- norm = config.norm
- self.bidirectional = config.bidirectional
- if isinstance(norm, str) and norm in VALID_NORMALIZATION_LAYERS:
- self.norm_f = VALID_NORMALIZATION_LAYERS[norm](config.d_model, eps=config.layer_norm_eps)
- else:
- raise ValueError(
- f"Invalid normalization layer: {norm}. "
- f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
- )
-
- # Initialize Mamba layers based on the configuration
-
- self.fwd_layers = nn.ModuleList(
- [
- ResidualBlock(
- mamba_version=getattr(config, "mamba_version", "mamba2"),
- d_model=getattr(config, "d_model", 128),
- d_state=getattr(config, "d_state", 256),
- d_conv=getattr(config, "d_conv", 4),
- norm=get_normalization_layer(config), # type: ignore
- expand_factor=getattr(config, "expand_factor", 2),
- dt_min=getattr(config, "dt_min", 1e-04),
- dt_max=getattr(config, "dt_max", 0.1),
- dt_init_floor=getattr(config, "dt_init_floor", 1e-04),
- conv_bias=getattr(config, "conv_bias", False),
- bias=getattr(config, "bias", True),
- layer_idx=i,
- )
- for i in range(getattr(config, "n_layers", 6))
- ]
- )
-
- if self.bidirectional:
- self.bckwd_layers = nn.ModuleList(
- [
- ResidualBlock(
- mamba_version=config.mamba_version,
- d_model=config.d_model,
- d_state=config.d_state,
- d_conv=config.d_conv,
- norm=get_normalization_layer(config), # type: ignore
- expand_factor=config.expand_factor,
- dt_min=config.dt_min,
- dt_max=config.dt_max,
- dt_init_floor=config.dt_init_floor,
- conv_bias=config.conv_bias,
- bias=config.bias,
- layer_idx=i + config.n_layers,
- )
- for i in range(config.n_layers)
- ]
- )
-
- # Apply weight initialization
- self.apply(
- lambda m: _init_weights(
- m,
- n_layer=config.n_layers,
- n_residuals_per_layer=1 if config.d_state == 0 else 2,
- )
- )
-
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
- return {
- i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
- for i, layer in enumerate(self.layers)
- }
-
- def forward(self, x):
- if self.bidirectional:
- # Reverse input and pass through backward layers
- x_reversed = torch.flip(x, [1])
- # Forward pass through forward layers
- for layer in self.fwd_layers:
- # Update x in-place as each forward layer processes it
- x = layer(x)
-
- if self.bidirectional:
- for layer in self.bckwd_layers:
- x_reversed = layer(x_reversed) # type: ignore
-
- # Reverse the output of the backward pass to original order
- x_reversed = torch.flip(x_reversed, [1]) # type: ignore
-
- # Combine forward and backward outputs by averaging
- return (x + x_reversed) / 2
-
- # Return forward output only if not bidirectional
- return x
+# black: noqa
+
+import torch
+import torch.nn as nn
+
+from ..get_norm_fn import get_normalization_layer
+from ..layer_utils.normalization_layers import (
+ BatchNorm,
+ GroupNorm,
+ InstanceNorm,
+ LayerNorm,
+ LearnableLayerScaling,
+ RMSNorm,
+)
+from .init_weights import _init_weights
+
+
+class ResidualBlock(nn.Module):
+ """Residual block composed of a MambaBlock and a normalization layer.
+
+ Attributes:
+ layers (MambaBlock): MambaBlock layers.
+ norm (RMSNorm): Normalization layer.
+ """
+
+ MambaBlock = None # Declare MambaBlock at the class level
+
+ def __init__(
+ self,
+ d_model=32,
+ expand_factor=2,
+ bias=False,
+ d_conv=16,
+ conv_bias=True,
+ d_state=32,
+ dt_max=0.1,
+ dt_min=1e-03,
+ dt_init_floor=1e-04,
+ norm=RMSNorm,
+ layer_idx=0,
+ mamba_version="mamba1",
+ ):
+ super().__init__()
+
+ # Lazy import for Mamba and only import if it's None
+ if ResidualBlock.MambaBlock is None:
+ self._lazy_import_mamba(mamba_version)
+
+ VALID_NORMALIZATION_LAYERS = {
+ "RMSNorm": RMSNorm,
+ "LayerNorm": LayerNorm,
+ "LearnableLayerScaling": LearnableLayerScaling,
+ "BatchNorm": BatchNorm,
+ "InstanceNorm": InstanceNorm,
+ "GroupNorm": GroupNorm,
+ }
+
+ # Check if the provided normalization layer is valid
+ if isinstance(norm, type) and norm.__name__ not in VALID_NORMALIZATION_LAYERS:
+ raise ValueError(
+ f"Invalid normalization layer: {norm.__name__}. "
+ f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
+ )
+ elif isinstance(norm, str) and norm not in VALID_NORMALIZATION_LAYERS:
+ raise ValueError(
+ f"Invalid normalization layer: {norm}. "
+ f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
+ )
+
+ # Use the imported MambaBlock to create layers
+ self.layers = ResidualBlock.MambaBlock(
+ d_model=d_model,
+ d_state=d_state,
+ d_conv=d_conv,
+ expand=expand_factor,
+ dt_min=dt_min,
+ dt_max=dt_max,
+ dt_init_floor=dt_init_floor,
+ conv_bias=conv_bias,
+ bias=bias,
+ layer_idx=layer_idx,
+ ) # type: ignore
+ self.norm = norm
+
+ def _lazy_import_mamba(self, mamba_version):
+ """Lazily import Mamba or Mamba2 based on the provided version and alias it."""
+ if ResidualBlock.MambaBlock is None:
+ try:
+ if mamba_version == "mamba1":
+ from mamba_ssm import Mamba as MambaBlock # type: ignore
+
+ ResidualBlock.MambaBlock = MambaBlock
+ print("Successfully imported Mamba (version 1)")
+ elif mamba_version == "mamba2":
+ from mamba_ssm import Mamba2 as MambaBlock # type: ignore
+
+ ResidualBlock.MambaBlock = MambaBlock
+ print("Successfully imported Mamba2")
+ else:
+ raise ValueError(f"Invalid mamba_version: {mamba_version}. Choose 'mamba1' or 'mamba2'.")
+ except ImportError:
+ raise ImportError(
+ f"Failed to import {mamba_version}. Please ensure the correct version is installed."
+ ) from None
+
+ def forward(self, x):
+ output = self.layers(self.norm(x)) + x
+ return output
+
+
+class MambaOriginal(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ VALID_NORMALIZATION_LAYERS = {
+ "RMSNorm": RMSNorm,
+ "LayerNorm": LayerNorm,
+ "LearnableLayerScaling": LearnableLayerScaling,
+ "BatchNorm": BatchNorm,
+ "InstanceNorm": InstanceNorm,
+ "GroupNorm": GroupNorm,
+ }
+
+ # Get normalization layer from config
+ norm = config.norm
+ self.bidirectional = config.bidirectional
+ if isinstance(norm, str) and norm in VALID_NORMALIZATION_LAYERS:
+ self.norm_f = VALID_NORMALIZATION_LAYERS[norm](config.d_model, eps=config.layer_norm_eps)
+ else:
+ raise ValueError(
+ f"Invalid normalization layer: {norm}. "
+ f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}"
+ )
+
+ # Initialize Mamba layers based on the configuration
+
+ self.fwd_layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ mamba_version=getattr(config, "mamba_version", "mamba2"),
+ d_model=getattr(config, "d_model", 128),
+ d_state=getattr(config, "d_state", 256),
+ d_conv=getattr(config, "d_conv", 4),
+ norm=get_normalization_layer(config), # type: ignore
+ expand_factor=getattr(config, "expand_factor", 2),
+ dt_min=getattr(config, "dt_min", 1e-04),
+ dt_max=getattr(config, "dt_max", 0.1),
+ dt_init_floor=getattr(config, "dt_init_floor", 1e-04),
+ conv_bias=getattr(config, "conv_bias", False),
+ bias=getattr(config, "bias", True),
+ layer_idx=i,
+ )
+ for i in range(getattr(config, "n_layers", 6))
+ ]
+ )
+
+ if self.bidirectional:
+ self.bckwd_layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ mamba_version=config.mamba_version,
+ d_model=config.d_model,
+ d_state=config.d_state,
+ d_conv=config.d_conv,
+ norm=get_normalization_layer(config), # type: ignore
+ expand_factor=config.expand_factor,
+ dt_min=config.dt_min,
+ dt_max=config.dt_max,
+ dt_init_floor=config.dt_init_floor,
+ conv_bias=config.conv_bias,
+ bias=config.bias,
+ layer_idx=i + config.n_layers,
+ )
+ for i in range(config.n_layers)
+ ]
+ )
+
+ # Apply weight initialization
+ self.apply(
+ lambda m: _init_weights(
+ m,
+ n_layer=config.n_layers,
+ n_residuals_per_layer=1 if config.d_state == 0 else 2,
+ )
+ )
+
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+ return {
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+ for i, layer in enumerate(self.layers)
+ }
+
+ def forward(self, x):
+ if self.bidirectional:
+ # Reverse input and pass through backward layers
+ x_reversed = torch.flip(x, [1])
+ # Forward pass through forward layers
+ for layer in self.fwd_layers:
+ # Update x in-place as each forward layer processes it
+ x = layer(x)
+
+ if self.bidirectional:
+ for layer in self.bckwd_layers:
+ x_reversed = layer(x_reversed) # type: ignore
+
+ # Reverse the output of the backward pass to original order
+ x_reversed = torch.flip(x_reversed, [1]) # type: ignore
+
+ # Combine forward and backward outputs by averaging
+ return (x + x_reversed) / 2
+
+ # Return forward output only if not bidirectional
+ return x
diff --git a/deeptabular/arch_utils/mamba_utils/mambattn_arch.py b/deeptab/arch_utils/mamba_utils/mambattn_arch.py
similarity index 97%
rename from deeptabular/arch_utils/mamba_utils/mambattn_arch.py
rename to deeptab/arch_utils/mamba_utils/mambattn_arch.py
index df9e32bc..bbea31e7 100644
--- a/deeptabular/arch_utils/mamba_utils/mambattn_arch.py
+++ b/deeptab/arch_utils/mamba_utils/mambattn_arch.py
@@ -1,117 +1,117 @@
-import torch.nn as nn
-
-from ..get_norm_fn import get_normalization_layer
-from .mamba_arch import ResidualBlock
-
-
-class MambAttn(nn.Module):
- """Mamba model composed of alternating MambaBlocks and Attention layers.
-
- Attributes:
- config (MambaConfig): Configuration object for the Mamba model.
- layers (nn.ModuleList): List of alternating ResidualBlock (Mamba layers) and
- attention layers constituting the model.
- """
-
- def __init__(
- self,
- config,
- ):
- super().__init__()
-
- # Define Mamba and Attention layers alternation
- self.layers = nn.ModuleList()
-
- total_blocks = config.n_layers + config.n_attention_layers # Total blocks to be created
- attention_count = 0
-
- for i in range(total_blocks):
- # Insert attention layer after N Mamba layers
- if (i + 1) % (config.n_mamba_per_attention + 1) == 0:
- self.layers.append(
- nn.MultiheadAttention(
- embed_dim=config.d_model,
- num_heads=config.n_heads,
- dropout=config.attn_dropout,
- )
- )
- attention_count += 1
- else:
- self.layers.append(
- ResidualBlock(
- d_model=config.d_model,
- expand_factor=config.expand_factor,
- bias=config.bias,
- d_conv=config.d_conv,
- conv_bias=config.conv_bias,
- dropout=config.dropout,
- dt_rank=config.dt_rank,
- d_state=config.d_state,
- dt_scale=config.dt_scale,
- dt_init=config.dt_init,
- dt_max=config.dt_max,
- dt_min=config.dt_min,
- dt_init_floor=config.dt_init_floor,
- norm=get_normalization_layer(config), # type: ignore
- activation=config.activation,
- bidirectional=config.bidirectional,
- use_learnable_interaction=config.use_learnable_interaction,
- layer_norm_eps=config.layer_norm_eps,
- AD_weight_decay=config.AD_weight_decay,
- BC_layer_norm=config.BC_layer_norm,
- use_pscan=config.use_pscan,
- )
- )
-
- # Check the type of the last layer and append the desired one if necessary
- if config.last_layer == "attn":
- if not isinstance(self.layers[-1], nn.MultiheadAttention):
- self.layers.append(
- nn.MultiheadAttention(
- embed_dim=config.d_model,
- num_heads=config.n_heads,
- dropout=config.dropout,
- )
- )
- else:
- if not isinstance(self.layers[-1], ResidualBlock):
- self.layers.append(
- ResidualBlock(
- d_model=config.d_model,
- expand_factor=config.expand_factor,
- bias=config.bias,
- d_conv=config.d_conv,
- conv_bias=config.conv_bias,
- dropout=config.dropout,
- dt_rank=config.dt_rank,
- d_state=config.d_state,
- dt_scale=config.dt_scale,
- dt_init=config.dt_init,
- dt_max=config.dt_max,
- dt_min=config.dt_min,
- dt_init_floor=config.dt_init_floor,
- norm=get_normalization_layer(config), # type: ignore
- activation=config.activation,
- bidirectional=config.bidirectional,
- use_learnable_interaction=config.use_learnable_interaction,
- layer_norm_eps=config.layer_norm_eps,
- AD_weight_decay=config.AD_weight_decay,
- BC_layer_norm=config.BC_layer_norm,
- use_pscan=config.use_pscan,
- )
- )
-
- def forward(self, x):
- for layer in self.layers:
- if isinstance(layer, nn.MultiheadAttention):
- # If it's an attention layer, handle input shape (seq_len, batch, embed_dim)
- # Switch to (seq_len, batch, embed_dim) for attention
- x = x.transpose(0, 1)
- x, _ = layer(x, x, x)
- # Switch back to (batch, seq_len, embed_dim)
- x = x.transpose(0, 1)
- else:
- # Otherwise, pass through Mamba block
- x = layer(x)
-
- return x
+import torch.nn as nn
+
+from ..get_norm_fn import get_normalization_layer
+from .mamba_arch import ResidualBlock
+
+
+class MambAttn(nn.Module):
+ """Mamba model composed of alternating MambaBlocks and Attention layers.
+
+ Attributes:
+ config (MambaConfig): Configuration object for the Mamba model.
+ layers (nn.ModuleList): List of alternating ResidualBlock (Mamba layers) and
+ attention layers constituting the model.
+ """
+
+ def __init__(
+ self,
+ config,
+ ):
+ super().__init__()
+
+ # Define Mamba and Attention layers alternation
+ self.layers = nn.ModuleList()
+
+ total_blocks = config.n_layers + config.n_attention_layers # Total blocks to be created
+ attention_count = 0
+
+ for i in range(total_blocks):
+ # Insert attention layer after N Mamba layers
+ if (i + 1) % (config.n_mamba_per_attention + 1) == 0:
+ self.layers.append(
+ nn.MultiheadAttention(
+ embed_dim=config.d_model,
+ num_heads=config.n_heads,
+ dropout=config.attn_dropout,
+ )
+ )
+ attention_count += 1
+ else:
+ self.layers.append(
+ ResidualBlock(
+ d_model=config.d_model,
+ expand_factor=config.expand_factor,
+ bias=config.bias,
+ d_conv=config.d_conv,
+ conv_bias=config.conv_bias,
+ dropout=config.dropout,
+ dt_rank=config.dt_rank,
+ d_state=config.d_state,
+ dt_scale=config.dt_scale,
+ dt_init=config.dt_init,
+ dt_max=config.dt_max,
+ dt_min=config.dt_min,
+ dt_init_floor=config.dt_init_floor,
+ norm=get_normalization_layer(config), # type: ignore
+ activation=config.activation,
+ bidirectional=config.bidirectional,
+ use_learnable_interaction=config.use_learnable_interaction,
+ layer_norm_eps=config.layer_norm_eps,
+ AD_weight_decay=config.AD_weight_decay,
+ BC_layer_norm=config.BC_layer_norm,
+ use_pscan=config.use_pscan,
+ )
+ )
+
+ # Check the type of the last layer and append the desired one if necessary
+ if config.last_layer == "attn":
+ if not isinstance(self.layers[-1], nn.MultiheadAttention):
+ self.layers.append(
+ nn.MultiheadAttention(
+ embed_dim=config.d_model,
+ num_heads=config.n_heads,
+ dropout=config.dropout,
+ )
+ )
+ else:
+ if not isinstance(self.layers[-1], ResidualBlock):
+ self.layers.append(
+ ResidualBlock(
+ d_model=config.d_model,
+ expand_factor=config.expand_factor,
+ bias=config.bias,
+ d_conv=config.d_conv,
+ conv_bias=config.conv_bias,
+ dropout=config.dropout,
+ dt_rank=config.dt_rank,
+ d_state=config.d_state,
+ dt_scale=config.dt_scale,
+ dt_init=config.dt_init,
+ dt_max=config.dt_max,
+ dt_min=config.dt_min,
+ dt_init_floor=config.dt_init_floor,
+ norm=get_normalization_layer(config), # type: ignore
+ activation=config.activation,
+ bidirectional=config.bidirectional,
+ use_learnable_interaction=config.use_learnable_interaction,
+ layer_norm_eps=config.layer_norm_eps,
+ AD_weight_decay=config.AD_weight_decay,
+ BC_layer_norm=config.BC_layer_norm,
+ use_pscan=config.use_pscan,
+ )
+ )
+
+ def forward(self, x):
+ for layer in self.layers:
+ if isinstance(layer, nn.MultiheadAttention):
+ # If it's an attention layer, handle input shape (seq_len, batch, embed_dim)
+ # Switch to (seq_len, batch, embed_dim) for attention
+ x = x.transpose(0, 1)
+ x, _ = layer(x, x, x)
+ # Switch back to (batch, seq_len, embed_dim)
+ x = x.transpose(0, 1)
+ else:
+ # Otherwise, pass through Mamba block
+ x = layer(x)
+
+ return x
diff --git a/deeptabular/arch_utils/mlp_utils.py b/deeptab/arch_utils/mlp_utils.py
similarity index 100%
rename from deeptabular/arch_utils/mlp_utils.py
rename to deeptab/arch_utils/mlp_utils.py
diff --git a/deeptabular/arch_utils/neural_decision_tree.py b/deeptab/arch_utils/neural_decision_tree.py
similarity index 100%
rename from deeptabular/arch_utils/neural_decision_tree.py
rename to deeptab/arch_utils/neural_decision_tree.py
diff --git a/deeptabular/arch_utils/node_utils.py b/deeptab/arch_utils/node_utils.py
similarity index 100%
rename from deeptabular/arch_utils/node_utils.py
rename to deeptab/arch_utils/node_utils.py
diff --git a/deeptabular/arch_utils/numpy_utils.py b/deeptab/arch_utils/numpy_utils.py
similarity index 100%
rename from deeptabular/arch_utils/numpy_utils.py
rename to deeptab/arch_utils/numpy_utils.py
diff --git a/deeptabular/arch_utils/resnet_utils.py b/deeptab/arch_utils/resnet_utils.py
similarity index 100%
rename from deeptabular/arch_utils/resnet_utils.py
rename to deeptab/arch_utils/resnet_utils.py
diff --git a/deeptabular/arch_utils/rnn_utils.py b/deeptab/arch_utils/rnn_utils.py
similarity index 97%
rename from deeptabular/arch_utils/rnn_utils.py
rename to deeptab/arch_utils/rnn_utils.py
index e550e44a..6b2ba931 100644
--- a/deeptabular/arch_utils/rnn_utils.py
+++ b/deeptab/arch_utils/rnn_utils.py
@@ -1,274 +1,274 @@
-import torch
-import torch.nn as nn
-
-from .layer_utils.batch_ensemble_layer import RNNBatchEnsembleLayer
-from .lstm_utils import mLSTMblock, sLSTMblock
-
-
-class ConvRNN(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- # Configuration parameters with defaults where needed
- # 'RNN', 'LSTM', or 'GRU'
- self.model_type = getattr(config, "model_type", "RNN")
- self.input_size = getattr(config, "d_model", 128)
- self.hidden_size = getattr(config, "dim_feedforward", 128)
- self.num_layers = getattr(config, "n_layers", 4)
- self.rnn_dropout = getattr(config, "rnn_dropout", 0.0)
- self.bias = getattr(config, "bias", True)
- self.conv_bias = getattr(config, "conv_bias", True)
- self.rnn_activation = getattr(config, "rnn_activation", "relu")
- self.d_conv = getattr(config, "d_conv", 4)
- self.residuals = getattr(config, "residuals", False)
- self.dilation = getattr(config, "dilation", 1)
-
- # Choose RNN layer based on model_type
- rnn_layer = {
- "RNN": nn.RNN,
- "LSTM": nn.LSTM,
- "GRU": nn.GRU,
- "mLSTM": mLSTMblock,
- "sLSTM": sLSTMblock,
- }[self.model_type]
-
- # Convolutional layers
- self.convs = nn.ModuleList()
- self.layernorms_conv = nn.ModuleList() # LayerNorms for Conv layers
-
- if self.residuals:
- self.residual_matrix = nn.ParameterList(
- [
- nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
- for _ in range(self.num_layers)
- ]
- )
-
- # First Conv1d layer uses input_size
- self.convs.append(
- nn.Conv1d(
- in_channels=self.input_size,
- out_channels=self.input_size,
- kernel_size=self.d_conv,
- padding=self.d_conv - 1,
- bias=self.conv_bias,
- groups=self.input_size,
- dilation=self.dilation,
- )
- )
- self.layernorms_conv.append(nn.LayerNorm(self.input_size))
-
- # Subsequent Conv1d layers use hidden_size as input
- for i in range(self.num_layers - 1):
- self.convs.append(
- nn.Conv1d(
- in_channels=self.hidden_size,
- out_channels=self.hidden_size,
- kernel_size=self.d_conv,
- padding=self.d_conv - 1,
- bias=self.conv_bias,
- groups=self.hidden_size,
- dilation=self.dilation,
- )
- )
- self.layernorms_conv.append(nn.LayerNorm(self.hidden_size))
-
- # Initialize the RNN layers
- self.rnns = nn.ModuleList()
- self.layernorms_rnn = nn.ModuleList() # LayerNorms for RNN layers
-
- for i in range(self.num_layers):
- rnn_args = {
- "input_size": self.input_size if i == 0 else self.hidden_size,
- "hidden_size": self.hidden_size,
- "num_layers": 1,
- "batch_first": True,
- "dropout": self.rnn_dropout if i < self.num_layers - 1 else 0,
- "bias": self.bias,
- }
- if self.model_type == "RNN":
- rnn_args["nonlinearity"] = self.rnn_activation
- self.rnns.append(rnn_layer(**rnn_args))
- self.layernorms_rnn.append(nn.LayerNorm(self.hidden_size))
-
- def forward(self, x):
- """Forward pass through Conv-RNN layers.
-
- Parameters
- -----------
- x : torch.Tensor
- Input tensor of shape (batch_size, seq_length, input_size).
-
- Returns
- --------
- output : torch.Tensor
- Output tensor after passing through Conv-RNN layers.
- """
- _, L, _ = x.shape
- if self.residuals:
- residual = x
-
- # Loop through the RNN layers and apply 1D convolution before each
- for i in range(self.num_layers):
- # Transpose to (batch_size, input_size, seq_length) for Conv1d
-
- x = self.layernorms_conv[i](x)
- x = x.transpose(1, 2)
-
- # Apply the 1D convolution
- x = self.convs[i](x)[:, :, :L]
-
- # Transpose back to (batch_size, seq_length, input_size)
- x = x.transpose(1, 2)
-
- # Pass through the RNN layer
- x, _ = self.rnns[i](x)
-
- # Residual connection with learnable matrix
- if self.residuals:
- if i < self.num_layers and i > 0:
- residual_proj = torch.matmul(residual, self.residual_matrix[i]) # type: ignore
- x = x + residual_proj
-
- # Update residual for next layer
- residual = x
-
- return x, _
-
-
-class EnsembleConvRNN(nn.Module):
- def __init__(
- self,
- config,
- ):
- super().__init__()
-
- self.input_size = getattr(config, "d_model", 128)
- self.hidden_size = getattr(config, "dim_feedforward", 128)
- self.ensemble_size = getattr(config, "ensemble_size", 16)
- self.num_layers = getattr(config, "n_layers", 4)
- self.rnn_dropout = getattr(config, "rnn_dropout", 0.5)
- self.bias = getattr(config, "bias", True)
- self.conv_bias = getattr(config, "conv_bias", True)
- self.rnn_activation = getattr(config, "rnn_activation", torch.tanh)
- self.d_conv = getattr(config, "d_conv", 4)
- self.residuals = getattr(config, "residuals", False)
- self.ensemble_scaling_in = getattr(config, "ensemble_scaling_in", True)
- self.ensemble_scaling_out = getattr(config, "ensemble_scaling_out", True)
- self.ensemble_bias = getattr(config, "ensemble_bias", False)
- self.scaling_init = getattr(config, "scaling_init", "ones")
- self.model_type = getattr(config, "model_type", "full")
-
- # Convolutional layers
- self.convs = nn.ModuleList()
- self.layernorms_conv = nn.ModuleList() # LayerNorms for Conv layers
-
- if self.residuals:
- self.residual_matrix = nn.ParameterList(
- [
- nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
- for _ in range(self.num_layers)
- ]
- )
-
- # First Conv1d layer uses input_size
- self.conv = nn.Conv1d(
- in_channels=self.input_size,
- out_channels=self.input_size,
- kernel_size=self.d_conv,
- padding=self.d_conv - 1,
- bias=self.conv_bias,
- groups=self.input_size,
- )
-
- self.layernorms_conv = nn.LayerNorm(self.input_size)
-
- # Initialize the RNN layers
- self.rnns = nn.ModuleList()
- self.layernorms_rnn = nn.ModuleList() # LayerNorms for RNN layers
-
- self.rnns.append(
- RNNBatchEnsembleLayer(
- input_size=self.input_size,
- hidden_size=self.hidden_size,
- ensemble_size=self.ensemble_size,
- ensemble_scaling_in=self.ensemble_scaling_in,
- ensemble_scaling_out=self.ensemble_scaling_out,
- ensemble_bias=self.ensemble_bias,
- dropout=self.rnn_dropout,
- nonlinearity=self.rnn_activation,
- scaling_init="normal",
- )
- )
-
- for i in range(1, self.num_layers):
- if self.model_type == "mini":
- rnn = RNNBatchEnsembleLayer(
- input_size=self.hidden_size,
- hidden_size=self.hidden_size,
- ensemble_size=self.ensemble_size,
- ensemble_scaling_in=False,
- ensemble_scaling_out=False,
- ensemble_bias=self.ensemble_bias,
- dropout=self.rnn_dropout if i < self.num_layers - 1 else 0,
- nonlinearity=self.rnn_activation,
- scaling_init=self.scaling_init, # type: ignore
- )
- else:
- rnn = RNNBatchEnsembleLayer(
- input_size=self.hidden_size,
- hidden_size=self.hidden_size,
- ensemble_size=self.ensemble_size,
- ensemble_scaling_in=self.ensemble_scaling_in,
- ensemble_scaling_out=self.ensemble_scaling_out,
- ensemble_bias=self.ensemble_bias,
- dropout=self.rnn_dropout if i < self.num_layers - 1 else 0,
- nonlinearity=self.rnn_activation,
- scaling_init=self.scaling_init, # type: ignore
- )
-
- self.rnns.append(rnn)
-
- def forward(self, x):
- """Forward pass through Conv-RNN layers.
-
- Parameters
- -----------
- x : torch.Tensor
- Input tensor of shape (batch_size, seq_length, input_size).
-
- Returns
- --------
- output : torch.Tensor
- Output tensor after passing through Conv-RNN layers.
- """
- _, L, _ = x.shape
- if self.residuals:
- residual = x
-
- x = self.layernorms_conv(x)
- x = x.transpose(1, 2)
-
- # Apply the 1D convolution
- x = self.conv(x)[:, :, :L]
-
- # Transpose back to (batch_size, seq_length, input_size)
- x = x.transpose(1, 2)
-
- # Loop through the RNN layers and apply 1D convolution before each
- for i, layer in enumerate(self.rnns):
- # Transpose to (batch_size, input_size, seq_length) for Conv1d
-
- # Pass through the RNN layer
- x, _ = layer(x)
-
- # Residual connection with learnable matrix
- if self.residuals:
- if i < self.num_layers and i > 0:
- residual_proj = torch.matmul(residual, self.residual_matrix[i]) # type: ignore
- x = x + residual_proj
-
- # Update residual for next layer
- residual = x
-
- return x, _
+import torch
+import torch.nn as nn
+
+from .layer_utils.batch_ensemble_layer import RNNBatchEnsembleLayer
+from .lstm_utils import mLSTMblock, sLSTMblock
+
+
+class ConvRNN(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # Configuration parameters with defaults where needed
+ # 'RNN', 'LSTM', or 'GRU'
+ self.model_type = getattr(config, "model_type", "RNN")
+ self.input_size = getattr(config, "d_model", 128)
+ self.hidden_size = getattr(config, "dim_feedforward", 128)
+ self.num_layers = getattr(config, "n_layers", 4)
+ self.rnn_dropout = getattr(config, "rnn_dropout", 0.0)
+ self.bias = getattr(config, "bias", True)
+ self.conv_bias = getattr(config, "conv_bias", True)
+ self.rnn_activation = getattr(config, "rnn_activation", "relu")
+ self.d_conv = getattr(config, "d_conv", 4)
+ self.residuals = getattr(config, "residuals", False)
+ self.dilation = getattr(config, "dilation", 1)
+
+ # Choose RNN layer based on model_type
+ rnn_layer = {
+ "RNN": nn.RNN,
+ "LSTM": nn.LSTM,
+ "GRU": nn.GRU,
+ "mLSTM": mLSTMblock,
+ "sLSTM": sLSTMblock,
+ }[self.model_type]
+
+ # Convolutional layers
+ self.convs = nn.ModuleList()
+ self.layernorms_conv = nn.ModuleList() # LayerNorms for Conv layers
+
+ if self.residuals:
+ self.residual_matrix = nn.ParameterList(
+ [
+ nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
+ for _ in range(self.num_layers)
+ ]
+ )
+
+ # First Conv1d layer uses input_size
+ self.convs.append(
+ nn.Conv1d(
+ in_channels=self.input_size,
+ out_channels=self.input_size,
+ kernel_size=self.d_conv,
+ padding=self.d_conv - 1,
+ bias=self.conv_bias,
+ groups=self.input_size,
+ dilation=self.dilation,
+ )
+ )
+ self.layernorms_conv.append(nn.LayerNorm(self.input_size))
+
+ # Subsequent Conv1d layers use hidden_size as input
+ for i in range(self.num_layers - 1):
+ self.convs.append(
+ nn.Conv1d(
+ in_channels=self.hidden_size,
+ out_channels=self.hidden_size,
+ kernel_size=self.d_conv,
+ padding=self.d_conv - 1,
+ bias=self.conv_bias,
+ groups=self.hidden_size,
+ dilation=self.dilation,
+ )
+ )
+ self.layernorms_conv.append(nn.LayerNorm(self.hidden_size))
+
+ # Initialize the RNN layers
+ self.rnns = nn.ModuleList()
+ self.layernorms_rnn = nn.ModuleList() # LayerNorms for RNN layers
+
+ for i in range(self.num_layers):
+ rnn_args = {
+ "input_size": self.input_size if i == 0 else self.hidden_size,
+ "hidden_size": self.hidden_size,
+ "num_layers": 1,
+ "batch_first": True,
+ "dropout": self.rnn_dropout if i < self.num_layers - 1 else 0,
+ "bias": self.bias,
+ }
+ if self.model_type == "RNN":
+ rnn_args["nonlinearity"] = self.rnn_activation
+ self.rnns.append(rnn_layer(**rnn_args))
+ self.layernorms_rnn.append(nn.LayerNorm(self.hidden_size))
+
+ def forward(self, x):
+ """Forward pass through Conv-RNN layers.
+
+ Parameters
+ -----------
+ x : torch.Tensor
+ Input tensor of shape (batch_size, seq_length, input_size).
+
+ Returns
+ --------
+ output : torch.Tensor
+ Output tensor after passing through Conv-RNN layers.
+ """
+ _, L, _ = x.shape
+ if self.residuals:
+ residual = x
+
+ # Loop through the RNN layers and apply 1D convolution before each
+ for i in range(self.num_layers):
+ # Transpose to (batch_size, input_size, seq_length) for Conv1d
+
+ x = self.layernorms_conv[i](x)
+ x = x.transpose(1, 2)
+
+ # Apply the 1D convolution
+ x = self.convs[i](x)[:, :, :L]
+
+ # Transpose back to (batch_size, seq_length, input_size)
+ x = x.transpose(1, 2)
+
+ # Pass through the RNN layer
+ x, _ = self.rnns[i](x)
+
+ # Residual connection with learnable matrix
+ if self.residuals:
+ if i < self.num_layers and i > 0:
+ residual_proj = torch.matmul(residual, self.residual_matrix[i]) # type: ignore
+ x = x + residual_proj
+
+ # Update residual for next layer
+ residual = x
+
+ return x, _
+
+
+class EnsembleConvRNN(nn.Module):
+ def __init__(
+ self,
+ config,
+ ):
+ super().__init__()
+
+ self.input_size = getattr(config, "d_model", 128)
+ self.hidden_size = getattr(config, "dim_feedforward", 128)
+ self.ensemble_size = getattr(config, "ensemble_size", 16)
+ self.num_layers = getattr(config, "n_layers", 4)
+ self.rnn_dropout = getattr(config, "rnn_dropout", 0.5)
+ self.bias = getattr(config, "bias", True)
+ self.conv_bias = getattr(config, "conv_bias", True)
+ self.rnn_activation = getattr(config, "rnn_activation", torch.tanh)
+ self.d_conv = getattr(config, "d_conv", 4)
+ self.residuals = getattr(config, "residuals", False)
+ self.ensemble_scaling_in = getattr(config, "ensemble_scaling_in", True)
+ self.ensemble_scaling_out = getattr(config, "ensemble_scaling_out", True)
+ self.ensemble_bias = getattr(config, "ensemble_bias", False)
+ self.scaling_init = getattr(config, "scaling_init", "ones")
+ self.model_type = getattr(config, "model_type", "full")
+
+ # Convolutional layers
+ self.convs = nn.ModuleList()
+ self.layernorms_conv = nn.ModuleList() # LayerNorms for Conv layers
+
+ if self.residuals:
+ self.residual_matrix = nn.ParameterList(
+ [
+ nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
+ for _ in range(self.num_layers)
+ ]
+ )
+
+ # First Conv1d layer uses input_size
+ self.conv = nn.Conv1d(
+ in_channels=self.input_size,
+ out_channels=self.input_size,
+ kernel_size=self.d_conv,
+ padding=self.d_conv - 1,
+ bias=self.conv_bias,
+ groups=self.input_size,
+ )
+
+ self.layernorms_conv = nn.LayerNorm(self.input_size)
+
+ # Initialize the RNN layers
+ self.rnns = nn.ModuleList()
+ self.layernorms_rnn = nn.ModuleList() # LayerNorms for RNN layers
+
+ self.rnns.append(
+ RNNBatchEnsembleLayer(
+ input_size=self.input_size,
+ hidden_size=self.hidden_size,
+ ensemble_size=self.ensemble_size,
+ ensemble_scaling_in=self.ensemble_scaling_in,
+ ensemble_scaling_out=self.ensemble_scaling_out,
+ ensemble_bias=self.ensemble_bias,
+ dropout=self.rnn_dropout,
+ nonlinearity=self.rnn_activation,
+ scaling_init="normal",
+ )
+ )
+
+ for i in range(1, self.num_layers):
+ if self.model_type == "mini":
+ rnn = RNNBatchEnsembleLayer(
+ input_size=self.hidden_size,
+ hidden_size=self.hidden_size,
+ ensemble_size=self.ensemble_size,
+ ensemble_scaling_in=False,
+ ensemble_scaling_out=False,
+ ensemble_bias=self.ensemble_bias,
+ dropout=self.rnn_dropout if i < self.num_layers - 1 else 0,
+ nonlinearity=self.rnn_activation,
+ scaling_init=self.scaling_init, # type: ignore
+ )
+ else:
+ rnn = RNNBatchEnsembleLayer(
+ input_size=self.hidden_size,
+ hidden_size=self.hidden_size,
+ ensemble_size=self.ensemble_size,
+ ensemble_scaling_in=self.ensemble_scaling_in,
+ ensemble_scaling_out=self.ensemble_scaling_out,
+ ensemble_bias=self.ensemble_bias,
+ dropout=self.rnn_dropout if i < self.num_layers - 1 else 0,
+ nonlinearity=self.rnn_activation,
+ scaling_init=self.scaling_init, # type: ignore
+ )
+
+ self.rnns.append(rnn)
+
+ def forward(self, x):
+ """Forward pass through Conv-RNN layers.
+
+ Parameters
+ -----------
+ x : torch.Tensor
+ Input tensor of shape (batch_size, seq_length, input_size).
+
+ Returns
+ --------
+ output : torch.Tensor
+ Output tensor after passing through Conv-RNN layers.
+ """
+ _, L, _ = x.shape
+ if self.residuals:
+ residual = x
+
+ x = self.layernorms_conv(x)
+ x = x.transpose(1, 2)
+
+ # Apply the 1D convolution
+ x = self.conv(x)[:, :, :L]
+
+ # Transpose back to (batch_size, seq_length, input_size)
+ x = x.transpose(1, 2)
+
+ # Loop through the RNN layers and apply 1D convolution before each
+ for i, layer in enumerate(self.rnns):
+ # Transpose to (batch_size, input_size, seq_length) for Conv1d
+
+ # Pass through the RNN layer
+ x, _ = layer(x)
+
+ # Residual connection with learnable matrix
+ if self.residuals:
+ if i < self.num_layers and i > 0:
+ residual_proj = torch.matmul(residual, self.residual_matrix[i]) # type: ignore
+ x = x + residual_proj
+
+ # Update residual for next layer
+ residual = x
+
+ return x, _
diff --git a/deeptabular/arch_utils/simple_utils.py b/deeptab/arch_utils/simple_utils.py
similarity index 100%
rename from deeptabular/arch_utils/simple_utils.py
rename to deeptab/arch_utils/simple_utils.py
diff --git a/deeptabular/arch_utils/transformer_utils.py b/deeptab/arch_utils/transformer_utils.py
similarity index 100%
rename from deeptabular/arch_utils/transformer_utils.py
rename to deeptab/arch_utils/transformer_utils.py
diff --git a/deeptabular/arch_utils/trompt_utils.py b/deeptab/arch_utils/trompt_utils.py
similarity index 100%
rename from deeptabular/arch_utils/trompt_utils.py
rename to deeptab/arch_utils/trompt_utils.py
diff --git a/deeptabular/base_models/__init__.py b/deeptab/base_models/__init__.py
similarity index 100%
rename from deeptabular/base_models/__init__.py
rename to deeptab/base_models/__init__.py
diff --git a/deeptabular/base_models/autoint.py b/deeptab/base_models/autoint.py
similarity index 100%
rename from deeptabular/base_models/autoint.py
rename to deeptab/base_models/autoint.py
diff --git a/deeptabular/base_models/enode.py b/deeptab/base_models/enode.py
similarity index 100%
rename from deeptabular/base_models/enode.py
rename to deeptab/base_models/enode.py
diff --git a/deeptabular/base_models/ft_transformer.py b/deeptab/base_models/ft_transformer.py
similarity index 100%
rename from deeptabular/base_models/ft_transformer.py
rename to deeptab/base_models/ft_transformer.py
diff --git a/deeptabular/base_models/mambatab.py b/deeptab/base_models/mambatab.py
similarity index 100%
rename from deeptabular/base_models/mambatab.py
rename to deeptab/base_models/mambatab.py
diff --git a/deeptabular/base_models/mambattn.py b/deeptab/base_models/mambattn.py
similarity index 97%
rename from deeptabular/base_models/mambattn.py
rename to deeptab/base_models/mambattn.py
index c5baefcc..e0f36d85 100644
--- a/deeptabular/base_models/mambattn.py
+++ b/deeptab/base_models/mambattn.py
@@ -1,132 +1,132 @@
-import torch
-import numpy as np
-from ..arch_utils.get_norm_fn import get_normalization_layer
-from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
-from ..arch_utils.mamba_utils.mambattn_arch import MambAttn
-from ..arch_utils.mlp_utils import MLPhead
-from ..configs.mambattention_config import DefaultMambAttentionConfig
-from .utils.basemodel import BaseModel
-
-
-class MambAttention(BaseModel):
- """A MambAttention model for tabular data, integrating feature embeddings, attention-based Mamba transformations,
- and a customizable architecture for handling categorical and numerical features.
-
- Parameters
- ----------
- cat_feature_info : dict
- Dictionary containing information about categorical features, including their names and dimensions.
- num_feature_info : dict
- Dictionary containing information about numerical features, including their names and dimensions.
- num_classes : int, optional
- The number of output classes or target dimensions for regression, by default 1.
- config : DefaultMambAttentionConfig, optional
- Configuration object with model hyperparameters such as dropout rates, head layer sizes, attention settings,
- and other architectural configurations, by default DefaultMambAttentionConfig().
- **kwargs : dict
- Additional keyword arguments for the BaseModel class.
-
- Attributes
- ----------
- pooling_method : str
- Pooling method to aggregate features after the Mamba attention layer.
- shuffle_embeddings : bool
- Flag indicating if embeddings should be shuffled, as specified in the configuration.
- mamba : MambAttn
- Mamba attention layer to process embedded features.
- norm_f : nn.Module
- Normalization layer for the processed features.
- embedding_layer : EmbeddingLayer
- Layer for embedding categorical and numerical features.
- tabular_head : MLPhead
- MLPhead layer to produce the final prediction based on the output of the Mamba attention layer.
- perm : torch.Tensor, optional
- Permutation tensor used for shuffling embeddings, if enabled.
-
- Methods
- -------
- forward(num_features, cat_features)
- Perform a forward pass through the model, including embedding, Mamba attention transformation, pooling,
- and prediction steps.
- """
-
- def __init__(
- self,
- feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
- num_classes=1,
- config: DefaultMambAttentionConfig = DefaultMambAttentionConfig(), # noqa: B008
- **kwargs,
- ):
- super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["feature_information"])
-
- self.returns_ensemble = False
-
- try:
- self.pooling_method = self.hparams.pooling_method
- except AttributeError:
- self.pooling_method = config.pooling_method
-
- try:
- self.shuffle_embeddings = self.hparams.shuffle_embeddings
- except AttributeError:
- self.shuffle_embeddings = config.shuffle_embeddings
-
- self.mamba = MambAttn(config)
- self.norm_f = get_normalization_layer(config)
-
- # embedding layer
- self.embedding_layer = EmbeddingLayer(
- *feature_information,
- config=config,
- )
-
- try:
- head_activation = self.hparams.head_activation
- except AttributeError:
- head_activation = config.head_activation
-
- try:
- input_dim = self.hparams.d_model
- except AttributeError:
- input_dim = config.d_model
-
- self.tabular_head = MLPhead(
- input_dim=input_dim,
- config=config,
- output_dim=num_classes,
- )
-
- if self.shuffle_embeddings:
- self.perm = torch.randperm(self.embedding_layer.seq_len)
-
- # pooling
- n_inputs = np.sum([len(info) for info in feature_information])
- self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
-
- def forward(self, *data):
- """Defines the forward pass of the model.
-
- Parameters
- ----------
- data : tuple
- Input tuple of tensors of num_features, cat_features, embeddings.
-
- Returns
- -------
- torch.Tensor
- Output tensor.
- """
- x = self.embedding_layer(*data)
-
- if self.shuffle_embeddings:
- x = x[:, self.perm, :]
-
- x = self.mamba(x)
-
- x = self.pool_sequence(x)
-
- x = self.norm_f(x) # type: ignore
- preds = self.tabular_head(x)
-
- return preds
+import torch
+import numpy as np
+from ..arch_utils.get_norm_fn import get_normalization_layer
+from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
+from ..arch_utils.mamba_utils.mambattn_arch import MambAttn
+from ..arch_utils.mlp_utils import MLPhead
+from ..configs.mambattention_config import DefaultMambAttentionConfig
+from .utils.basemodel import BaseModel
+
+
+class MambAttention(BaseModel):
+ """A MambAttention model for tabular data, integrating feature embeddings, attention-based Mamba transformations,
+ and a customizable architecture for handling categorical and numerical features.
+
+ Parameters
+ ----------
+ cat_feature_info : dict
+ Dictionary containing information about categorical features, including their names and dimensions.
+ num_feature_info : dict
+ Dictionary containing information about numerical features, including their names and dimensions.
+ num_classes : int, optional
+ The number of output classes or target dimensions for regression, by default 1.
+ config : DefaultMambAttentionConfig, optional
+ Configuration object with model hyperparameters such as dropout rates, head layer sizes, attention settings,
+ and other architectural configurations, by default DefaultMambAttentionConfig().
+ **kwargs : dict
+ Additional keyword arguments for the BaseModel class.
+
+ Attributes
+ ----------
+ pooling_method : str
+ Pooling method to aggregate features after the Mamba attention layer.
+ shuffle_embeddings : bool
+ Flag indicating if embeddings should be shuffled, as specified in the configuration.
+ mamba : MambAttn
+ Mamba attention layer to process embedded features.
+ norm_f : nn.Module
+ Normalization layer for the processed features.
+ embedding_layer : EmbeddingLayer
+ Layer for embedding categorical and numerical features.
+ tabular_head : MLPhead
+ MLPhead layer to produce the final prediction based on the output of the Mamba attention layer.
+ perm : torch.Tensor, optional
+ Permutation tensor used for shuffling embeddings, if enabled.
+
+ Methods
+ -------
+ forward(num_features, cat_features)
+ Perform a forward pass through the model, including embedding, Mamba attention transformation, pooling,
+ and prediction steps.
+ """
+
+ def __init__(
+ self,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
+ num_classes=1,
+ config: DefaultMambAttentionConfig = DefaultMambAttentionConfig(), # noqa: B008
+ **kwargs,
+ ):
+ super().__init__(config=config, **kwargs)
+ self.save_hyperparameters(ignore=["feature_information"])
+
+ self.returns_ensemble = False
+
+ try:
+ self.pooling_method = self.hparams.pooling_method
+ except AttributeError:
+ self.pooling_method = config.pooling_method
+
+ try:
+ self.shuffle_embeddings = self.hparams.shuffle_embeddings
+ except AttributeError:
+ self.shuffle_embeddings = config.shuffle_embeddings
+
+ self.mamba = MambAttn(config)
+ self.norm_f = get_normalization_layer(config)
+
+ # embedding layer
+ self.embedding_layer = EmbeddingLayer(
+ *feature_information,
+ config=config,
+ )
+
+ try:
+ head_activation = self.hparams.head_activation
+ except AttributeError:
+ head_activation = config.head_activation
+
+ try:
+ input_dim = self.hparams.d_model
+ except AttributeError:
+ input_dim = config.d_model
+
+ self.tabular_head = MLPhead(
+ input_dim=input_dim,
+ config=config,
+ output_dim=num_classes,
+ )
+
+ if self.shuffle_embeddings:
+ self.perm = torch.randperm(self.embedding_layer.seq_len)
+
+ # pooling
+ n_inputs = np.sum([len(info) for info in feature_information])
+ self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
+
+ def forward(self, *data):
+ """Defines the forward pass of the model.
+
+ Parameters
+ ----------
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor.
+ """
+ x = self.embedding_layer(*data)
+
+ if self.shuffle_embeddings:
+ x = x[:, self.perm, :]
+
+ x = self.mamba(x)
+
+ x = self.pool_sequence(x)
+
+ x = self.norm_f(x) # type: ignore
+ preds = self.tabular_head(x)
+
+ return preds
diff --git a/deeptabular/base_models/mambular.py b/deeptab/base_models/mambular.py
similarity index 100%
rename from deeptabular/base_models/mambular.py
rename to deeptab/base_models/mambular.py
diff --git a/deeptabular/base_models/mlp.py b/deeptab/base_models/mlp.py
similarity index 100%
rename from deeptabular/base_models/mlp.py
rename to deeptab/base_models/mlp.py
diff --git a/deeptabular/base_models/modern_nca.py b/deeptab/base_models/modern_nca.py
similarity index 100%
rename from deeptabular/base_models/modern_nca.py
rename to deeptab/base_models/modern_nca.py
diff --git a/deeptabular/base_models/ndtf.py b/deeptab/base_models/ndtf.py
similarity index 100%
rename from deeptabular/base_models/ndtf.py
rename to deeptab/base_models/ndtf.py
diff --git a/deeptabular/base_models/node.py b/deeptab/base_models/node.py
similarity index 100%
rename from deeptabular/base_models/node.py
rename to deeptab/base_models/node.py
diff --git a/deeptabular/base_models/resnet.py b/deeptab/base_models/resnet.py
similarity index 100%
rename from deeptabular/base_models/resnet.py
rename to deeptab/base_models/resnet.py
diff --git a/deeptabular/base_models/saint.py b/deeptab/base_models/saint.py
similarity index 100%
rename from deeptabular/base_models/saint.py
rename to deeptab/base_models/saint.py
diff --git a/deeptabular/base_models/tabm.py b/deeptab/base_models/tabm.py
similarity index 100%
rename from deeptabular/base_models/tabm.py
rename to deeptab/base_models/tabm.py
diff --git a/deeptabular/base_models/tabr.py b/deeptab/base_models/tabr.py
similarity index 97%
rename from deeptabular/base_models/tabr.py
rename to deeptab/base_models/tabr.py
index e2f03ca6..b603cab5 100644
--- a/deeptabular/base_models/tabr.py
+++ b/deeptab/base_models/tabr.py
@@ -1,490 +1,490 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from ..utils.get_feature_dimensions import get_feature_dimensions
-from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
-from ..configs.tabr_config import DefaultTabRConfig
-from .utils.basemodel import BaseModel
-from torch import Tensor
-import math
-
-class TabR(BaseModel):
- delu = None
- faiss = None
- faiss_torch_utils = None
-
- def __init__(
- self,
- feature_information: tuple,
- num_classes=1,
- config: DefaultTabRConfig = DefaultTabRConfig(), # noqa: B008
- **kwargs,
- ):
- super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["feature_information"])
-
- # lazy import
- if TabR.delu or TabR.faiss or TabR.faiss_torch_utils is None:
- self._lazy_import_dependencies()
-
- self.returns_ensemble = False
- self.uses_candidates = True
-
- if self.hparams.use_embeddings:
- self.embedding_layer = EmbeddingLayer(
- *feature_information,
- config=config,
- )
- print(self.embedding_layer)
- input_dim = np.sum(
- [len(info) * self.hparams.d_model for info in feature_information]
- )
- else:
- input_dim = get_feature_dimensions(*feature_information)
-
- self.hparams.num_classes = num_classes
- memory_efficient = self.hparams.memory_efficient
- mixer_normalization = self.hparams.mixer_normalization
- encoder_n_blocks = self.hparams.encoder_n_blocks
- predictor_n_blocks = self.hparams.predictor_n_blocks
- dropout0 = self.hparams.dropout1
- self.candidate_encoding_batch_size = self.hparams.candidate_encoding_batch_size
- d_main = self.hparams.d_main
- d_multiplier = self.hparams.d_multiplier
- normalization = self.hparams.normalization
- activation = self.hparams.activation
- dropout0 = self.hparams.dropout0
- dropout1 = self.hparams.dropout1
- context_dropout = self.hparams.context_dropout
-
- if memory_efficient:
- assert self.candidate_encoding_batch_size !=0
-
- if mixer_normalization == 'auto':
- mixer_normalization = encoder_n_blocks > 0
- if encoder_n_blocks == 0:
- assert not mixer_normalization
-
- # Encoder Module: E
- d_in = input_dim
- d_block = int(d_main * d_multiplier)
- Normalization = getattr(nn, normalization)
- self.linear = nn.Linear(d_in, d_main)
- self.context_size = self.hparams.context_size
-
-
- def make_block(prenorm: bool) -> nn.Sequential:
- return nn.Sequential(
- *([Normalization(d_main)] if prenorm else []),
- nn.Linear(d_main, d_block),
- activation,
- nn.Dropout(dropout0),
- nn.Linear(d_block, d_main),
- nn.Dropout(dropout1),
- )
-
- # here in the TabR paper, for first block of Encoder(E),
- # LayerNorm is omitted. In code, we omitted Normalization.
- self.blocks0 = nn.ModuleList(
- [make_block(i > 0) for i in range(encoder_n_blocks)]
- )
-
- # Retrieval Module: R
- self.normalization = Normalization(d_main) if mixer_normalization else None
-
- delu = TabR.delu
- self.label_encoder = (
- nn.Linear(1, d_main)
- if num_classes == 1
- else nn.Sequential(
- nn.Embedding(num_classes, d_main),
- # gives depreciation warning
- delu.nn.Lambda(lambda x: x.squeeze(-2)) # Removes the unnecessary extra dimension added by the embedding layer
- )
- )
- self.K = nn.Linear(d_main, d_main) # W_k in paper
- self.T = nn.Sequential(
- nn.Linear(d_main, d_block),
- activation,
- nn.Dropout(dropout0),
- nn.Linear(d_block, d_main, bias=False),
- ) # T for T(k-k_i) form the TabR paper.
- self.dropout = nn.Dropout(context_dropout)
-
- # Predictor Module : P
- self.blocks1 = nn.ModuleList(
- [make_block(True) for _ in range(predictor_n_blocks)]
- )
- self.head = nn.Sequential(
- Normalization(d_main),
- activation,
- nn.Linear(d_main, num_classes),
- )
-
- # >>>
- self.search_index = None
- self.memory_efficient = memory_efficient
- self.reset_parameters()
-
- def reset_parameters(self):
- if isinstance(self.label_encoder, nn.Linear): # if num_classes==1
- bound = 1 / math.sqrt(2.0) # He initialization (common for layers with ReLU activation)
- nn.init.uniform_(self.label_encoder.weight, -bound, bound) # type: ignore[code] # noqa: E501
- nn.init.uniform_(self.label_encoder.bias, -bound, bound) # type: ignore[code] # noqa: E501
- else:
- assert isinstance(self.label_encoder[0], nn.Embedding)
- nn.init.uniform_(self.label_encoder[0].weight, -1.0, 1.0) # type: ignore[code] # noqa: E501
-
- def _lazy_import_dependencies(self):
- """Lazily import external dependencies and store them as class attributes."""
- if TabR.delu is None:
- try:
- import delu
- TabR.delu = delu
- print("Successfully lazy imported delu dependency.")
-
- except ImportError:
- raise ImportError("Failed to import delu module for TabR. Ensure all dependencies are installed\n"
- "You can install delu running 'pip install delu'.") from None
-
- if TabR.faiss is None:
- try:
- import faiss
- import faiss.contrib.torch_utils
-
- TabR.faiss = faiss
- TabR.faiss_torch_utils = faiss.contrib.torch_utils
- print("Successfully lazy imported faiss dependency")
-
- except ImportError as e:
- raise ImportError("Failed to import faiss module for TabR. Ensure all dependencies are installed\n"
- "You can install faiss running 'pip install faiss-cpu' for CPU and 'pip install faiss-gpu' for GPU.") from None
-
- def _encode(
- self,
- a
- ):
- # x = x.double() # issue
- x = a.float()
- # x=a.clone().detach().requires_grad_(True)
- x = x.float()
- x = self.linear(x)
- for block in self.blocks0:
- x = x + block(x)
- k = self.K(x if self.normalization is None else self.normalization(x))
-
- return x, k
-
- def forward(
- self,
- *data
- ):
- """
- Standard forward pass without candidate selection (for baseline compatibility).
- """
- if self.hparams.use_embeddings:
- x = self.embedding_layer(*data)
- B, S, D = x.shape
- x = x.reshape(B, S * D)
- else:
- x = torch.cat([t for tensors in data for t in tensors], dim=1)
- x,k = self._encode(x)
- context_k = k.unsqueeze(1).expand(-1, self.context_size, -1) # using the batch itself as context
- similarities = (
- -k.square().sum(-1, keepdim=True)
- + (2 * (k[..., None, :] @ context_k.transpose(-1, -2))).squeeze(-2)
- - context_k.square().sum(-1)
- )
- probs = F.softmax(similarities, dim=-1)
- context_x = torch.sum(probs.unsqueeze(-1) * context_k, dim=1)
- t = self.T(self.dropout(context_x))
- for block in self.blocks1:
- x = x + block(x + t)
- return self.head(x)
-
- def train_with_candidates(
- self,
- *data,
- targets,
- candidate_x,
- candidate_y
- ):
- """TabR-style training forward pass selecting candidates."""
- assert targets is not None
-
- if self.hparams.use_embeddings:
- x = self.embedding_layer(*data)
- B, S, D = x.shape
- x = x.reshape(B, S * D)
- candidate_x = self.embedding_layer(*candidate_x)
- B, S, D = candidate_x.shape
- candidate_x = candidate_x.reshape(B, S * D)
- else:
-
- x = torch.cat([t for tensors in data for t in tensors], dim=1)
- candidate_x = torch.cat(
- [t for tensors in candidate_x for t in tensors], dim=1
- )
-
- with torch.set_grad_enabled(
- torch.is_grad_enabled() and not self.memory_efficient
- ):
-
- candidate_k = (
- self._encode(candidate_x)[1] # normalized candidate_x
- if self.candidate_encoding_batch_size ==0
- else torch.cat(
- [
- self._encode(x)[1] # normalized x
- # for x in delu.iter_batches(
- for x in TabR.delu.iter_batches(
- candidate_x,
- self.candidate_encoding_batch_size
- )
- ]
- )
- )
-
- # Encode input
- x, k = self._encode(x)
-
- batch_size, d_main = k.shape
- device = k.device
- context_size = self.context_size
-
- with torch.no_grad():
- # initializing the search index
- if self.search_index is None:
- self.search_index = (
- TabR.faiss.GpuIndexFlatL2(
- TabR.faiss.StandardGpuResources(),
- d_main
- )
- if device.type == 'cuda'
- else TabR.faiss.IndexFlatL2(d_main)
- )
- # Updating the index is much faster than creating a new one.
- self.search_index.reset()
- self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
- distances: Tensor
- context_idx: Tensor
- distances, context_idx = self.search_index.search( # type: ignore[code]
- k.to(torch.float32), context_size + 1
- )
- # NOTE: to avoid leakage, the index i must be removed from the i-th row,
- # (because of how candidate_k is constructed).
- distances[
- context_idx == torch.arange(batch_size, device=device)[:, None]
- ] = torch.inf
- # Not the most elegant solution to remove the argmax, but anyway.
- context_idx = context_idx.gather(-1, distances.argsort()[:, :-1])
-
- if self.memory_efficient and torch.is_grad_enabled():
- # Repeating the same computation,
- # but now only for the context objects and with autograd on.
- context_k = self._encode(
- torch.cat([x,candidate_x])[context_idx].flatten(0,1)
- )[1].reshape(batch_size, context_size, -1)
- else:
- context_k = candidate_k[context_idx]
-
- # In theory, when autograd is off, the distances obtained during the search
- # can be reused. However, this is not a bottleneck, so let's keep it simple
- # and use the same code to compute `similarities` during both
- # training and evaluation.
- similarities = (
- -k.square().sum(-1, keepdim=True)
- + (2 * (k[..., None, :] @ context_k.transpose(-1, -2))).squeeze(-2)
- - context_k.square().sum(-1)
- )
- probs = F.softmax(similarities, dim=-1)
- probs = self.dropout(probs)
-
- if self.hparams.num_classes > 1: # for classification
- context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
- else: # for regression
- context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
- if len(context_y_emb.shape) == 4:
- context_y_emb = context_y_emb[:,:,0,:]
-
- # Combine keys and labels with a transformation T.
- values = context_y_emb + self.T(k[:, None] - context_k)
- context_x = (probs[:, None] @ values).squeeze(1)
- x = x + context_x
-
- # Predictor has LayerNorm, ReLU and Linear after the N_P number of blocks.
- for block in self.blocks1:
- x = x + block(x)
- x = self.head(x)
- return x
-
- def validate_with_candidates(
- self,
- *data,
- candidate_x,
- candidate_y
- ):
- """Validation forward pass with TabR-style candidate selection."""
- if self.hparams.use_embeddings:
- x = self.embedding_layer(*data)
- B, S, D = x.shape
- x = x.reshape(B, S * D)
- candidate_x = self.embedding_layer(*candidate_x)
- B, S, D = candidate_x.shape
- candidate_x = candidate_x.reshape(B, S * D)
- else:
- x = torch.cat([t for tensors in data for t in tensors], dim=1)
- candidate_x = torch.cat(
- [t for tensors in candidate_x for t in tensors], dim=1
- )
-
- if not self.memory_efficient:
- candidate_k = (
- self._encode(candidate_x)[1] # normalized candidate_x
- if self.candidate_encoding_batch_size == 0
- else torch.cat(
- [
- self._encode(x)[1] # normalized x
- for x in TabR.delu.iter_batches(
- candidate_x,
- self.candidate_encoding_batch_size
- )
- ]
- )
- )
- else:
- candidate_x, candidate_k = self._encode(candidate_x)
-
- x, k = self._encode(x) # encoded x and k
- batch_size, d_main = k.shape
- device = k.device
- context_size = self.context_size
-
- if self.search_index is None:
- self.search_index = (
- TabR.faiss.GpuIndexFlatL2(TabR.faiss.StandardGpuResources(), d_main)
- if device.type == 'cuda'
- else TabR.faiss.IndexFlatL2(d_main)
- )
-
- # Updating the index is much faster than creating a new one.
- self.search_index.reset()
- self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
- distances: Tensor
- context_idx: Tensor
- distances, context_idx = self.search_index.search( # type: ignore[code]
- k.to(torch.float32), context_size
- )
-
- context_k = candidate_k[context_idx]
- similarities = (
- -k.square().sum(-1, keepdim=True)
- + (2 * (k[..., None, :] @ context_k.transpose(-1, -2))).squeeze(-2)
- - context_k.square().sum(-1)
- )
- probs = F.softmax(similarities, dim=-1)
- probs = self.dropout(probs)
-
- if self.hparams.num_classes > 1: # for classification
- context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
- else: # for regression
- context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
- if len(context_y_emb.shape) == 4:
- context_y_emb = context_y_emb[:,:,0,:]
-
- values = context_y_emb + self.T(k[:, None] - context_k)
- context_x = (probs[:, None] @ values).squeeze(1)
- x = x + context_x
-
- # Predictor has LayerNorm, ReLU and Linear after the N_P number of blocks.
- for block in self.blocks1:
- x = x + block(x)
- x = self.head(x)
- return x
-
-
- def predict_with_candidates(
- self,
- *data,
- candidate_x,
- candidate_y
- ):
- """Prediction forward pass with TabR-style candidate selection."""
- if self.hparams.use_embeddings:
- x = self.embedding_layer(*data)
- B, S, D = x.shape
- x = x.reshape(B, S * D)
- candidate_x = self.embedding_layer(*candidate_x)
- B, S, D = candidate_x.shape
- candidate_x = candidate_x.reshape(B, S * D)
- else:
- x = torch.cat([t for tensors in data for t in tensors], dim=1)
- candidate_x = torch.cat(
- [t for tensors in candidate_x for t in tensors], dim=1
- )
-
- if not self.memory_efficient:
- candidate_k = (
- self._encode(candidate_x)[1] # normalized candidate_x
- if self.candidate_encoding_batch_size == 0
- else torch.cat(
- [
- self._encode(x)[1] # normalized x
- for x in TabR.delu.iter_batches(
- candidate_x,
- self.candidate_encoding_batch_size
- )
- ]
- )
- )
- else:
- candidate_x, candidate_k = self._encode(candidate_x)
-
- x, k = self._encode(x) # encoded x and k
- batch_size, d_main = k.shape
- device = k.device
- context_size = self.context_size
-
- if self.search_index is None:
- self.search_index = (
- TabR.faiss.GpuIndexFlatL2(TabR.faiss.StandardGpuResources(), d_main)
- if device.type == 'cuda'
- else TabR.faiss.IndexFlatL2(d_main)
- )
-
-
- # Updating the index is much faster than creating a new one.
- self.search_index.reset()
- self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
- distances: Tensor
- context_idx: Tensor
- distances, context_idx = self.search_index.search( # type: ignore[code]
- k.to(torch.float32), context_size
- )
-
- context_k = candidate_k[context_idx]
- similarities = (
- -k.square().sum(-1, keepdim=True)
- + (2 * (k[..., None, :] @ context_k.transpose(-1, -2))).squeeze(-2)
- - context_k.square().sum(-1)
- )
- probs = F.softmax(similarities, dim=-1)
- probs = self.dropout(probs)
-
- if self.hparams.num_classes > 1: # for classification
- context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
- else: # for regression
- context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
- if len(context_y_emb.shape) == 4:
- context_y_emb = context_y_emb[:,:,0,:]
-
- values = context_y_emb + self.T(k[:, None] - context_k)
- context_x = (probs[:, None] @ values).squeeze(1)
- x = x + context_x
-
- # Predictor has LayerNorm, ReLU and Linear after the N_P number of blocks.
- for block in self.blocks1:
- x = x + block(x)
- x = self.head(x)
- return x
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ..utils.get_feature_dimensions import get_feature_dimensions
+from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
+from ..configs.tabr_config import DefaultTabRConfig
+from .utils.basemodel import BaseModel
+from torch import Tensor
+import math
+
+class TabR(BaseModel):
+ delu = None
+ faiss = None
+ faiss_torch_utils = None
+
+ def __init__(
+ self,
+ feature_information: tuple,
+ num_classes=1,
+ config: DefaultTabRConfig = DefaultTabRConfig(), # noqa: B008
+ **kwargs,
+ ):
+ super().__init__(config=config, **kwargs)
+ self.save_hyperparameters(ignore=["feature_information"])
+
+ # lazy import
+ if TabR.delu or TabR.faiss or TabR.faiss_torch_utils is None:
+ self._lazy_import_dependencies()
+
+ self.returns_ensemble = False
+ self.uses_candidates = True
+
+ if self.hparams.use_embeddings:
+ self.embedding_layer = EmbeddingLayer(
+ *feature_information,
+ config=config,
+ )
+ print(self.embedding_layer)
+ input_dim = np.sum(
+ [len(info) * self.hparams.d_model for info in feature_information]
+ )
+ else:
+ input_dim = get_feature_dimensions(*feature_information)
+
+ self.hparams.num_classes = num_classes
+ memory_efficient = self.hparams.memory_efficient
+ mixer_normalization = self.hparams.mixer_normalization
+ encoder_n_blocks = self.hparams.encoder_n_blocks
+ predictor_n_blocks = self.hparams.predictor_n_blocks
+ dropout0 = self.hparams.dropout1
+ self.candidate_encoding_batch_size = self.hparams.candidate_encoding_batch_size
+ d_main = self.hparams.d_main
+ d_multiplier = self.hparams.d_multiplier
+ normalization = self.hparams.normalization
+ activation = self.hparams.activation
+ dropout0 = self.hparams.dropout0
+ dropout1 = self.hparams.dropout1
+ context_dropout = self.hparams.context_dropout
+
+ if memory_efficient:
+ assert self.candidate_encoding_batch_size !=0
+
+ if mixer_normalization == 'auto':
+ mixer_normalization = encoder_n_blocks > 0
+ if encoder_n_blocks == 0:
+ assert not mixer_normalization
+
+ # Encoder Module: E
+ d_in = input_dim
+ d_block = int(d_main * d_multiplier)
+ Normalization = getattr(nn, normalization)
+ self.linear = nn.Linear(d_in, d_main)
+ self.context_size = self.hparams.context_size
+
+
+ def make_block(prenorm: bool) -> nn.Sequential:
+ return nn.Sequential(
+ *([Normalization(d_main)] if prenorm else []),
+ nn.Linear(d_main, d_block),
+ activation,
+ nn.Dropout(dropout0),
+ nn.Linear(d_block, d_main),
+ nn.Dropout(dropout1),
+ )
+
+ # here in the TabR paper, for first block of Encoder(E),
+ # LayerNorm is omitted. In code, we omitted Normalization.
+ self.blocks0 = nn.ModuleList(
+ [make_block(i > 0) for i in range(encoder_n_blocks)]
+ )
+
+ # Retrieval Module: R
+ self.normalization = Normalization(d_main) if mixer_normalization else None
+
+ delu = TabR.delu
+ self.label_encoder = (
+ nn.Linear(1, d_main)
+ if num_classes == 1
+ else nn.Sequential(
+ nn.Embedding(num_classes, d_main),
+ # gives depreciation warning
+ delu.nn.Lambda(lambda x: x.squeeze(-2)) # Removes the unnecessary extra dimension added by the embedding layer
+ )
+ )
+ self.K = nn.Linear(d_main, d_main) # W_k in paper
+ self.T = nn.Sequential(
+ nn.Linear(d_main, d_block),
+ activation,
+ nn.Dropout(dropout0),
+ nn.Linear(d_block, d_main, bias=False),
+ ) # T for T(k-k_i) form the TabR paper.
+ self.dropout = nn.Dropout(context_dropout)
+
+ # Predictor Module : P
+ self.blocks1 = nn.ModuleList(
+ [make_block(True) for _ in range(predictor_n_blocks)]
+ )
+ self.head = nn.Sequential(
+ Normalization(d_main),
+ activation,
+ nn.Linear(d_main, num_classes),
+ )
+
+ # >>>
+ self.search_index = None
+ self.memory_efficient = memory_efficient
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if isinstance(self.label_encoder, nn.Linear): # if num_classes==1
+ bound = 1 / math.sqrt(2.0) # He initialization (common for layers with ReLU activation)
+ nn.init.uniform_(self.label_encoder.weight, -bound, bound) # type: ignore[code] # noqa: E501
+ nn.init.uniform_(self.label_encoder.bias, -bound, bound) # type: ignore[code] # noqa: E501
+ else:
+ assert isinstance(self.label_encoder[0], nn.Embedding)
+ nn.init.uniform_(self.label_encoder[0].weight, -1.0, 1.0) # type: ignore[code] # noqa: E501
+
+ def _lazy_import_dependencies(self):
+ """Lazily import external dependencies and store them as class attributes."""
+ if TabR.delu is None:
+ try:
+ import delu
+ TabR.delu = delu
+ print("Successfully lazy imported delu dependency.")
+
+ except ImportError:
+ raise ImportError("Failed to import delu module for TabR. Ensure all dependencies are installed\n"
+ "You can install delu running 'pip install delu'.") from None
+
+ if TabR.faiss is None:
+ try:
+ import faiss
+ import faiss.contrib.torch_utils
+
+ TabR.faiss = faiss
+ TabR.faiss_torch_utils = faiss.contrib.torch_utils
+ print("Successfully lazy imported faiss dependency")
+
+ except ImportError as e:
+ raise ImportError("Failed to import faiss module for TabR. Ensure all dependencies are installed\n"
+ "You can install faiss running 'pip install faiss-cpu' for CPU and 'pip install faiss-gpu' for GPU.") from None
+
+ def _encode(
+ self,
+ a
+ ):
+ # x = x.double() # issue
+ x = a.float()
+ # x=a.clone().detach().requires_grad_(True)
+ x = x.float()
+ x = self.linear(x)
+ for block in self.blocks0:
+ x = x + block(x)
+ k = self.K(x if self.normalization is None else self.normalization(x))
+
+ return x, k
+
+ def forward(
+ self,
+ *data
+ ):
+ """
+ Standard forward pass without candidate selection (for baseline compatibility).
+ """
+ if self.hparams.use_embeddings:
+ x = self.embedding_layer(*data)
+ B, S, D = x.shape
+ x = x.reshape(B, S * D)
+ else:
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
+ x,k = self._encode(x)
+ context_k = k.unsqueeze(1).expand(-1, self.context_size, -1) # using the batch itself as context
+ similarities = (
+ -k.square().sum(-1, keepdim=True)
+ + (2 * (k[..., None, :] @ context_k.transpose(-1, -2))).squeeze(-2)
+ - context_k.square().sum(-1)
+ )
+ probs = F.softmax(similarities, dim=-1)
+ context_x = torch.sum(probs.unsqueeze(-1) * context_k, dim=1)
+ t = self.T(self.dropout(context_x))
+ for block in self.blocks1:
+ x = x + block(x + t)
+ return self.head(x)
+
+ def train_with_candidates(
+ self,
+ *data,
+ targets,
+ candidate_x,
+ candidate_y
+ ):
+ """TabR-style training forward pass selecting candidates."""
+ assert targets is not None
+
+ if self.hparams.use_embeddings:
+ x = self.embedding_layer(*data)
+ B, S, D = x.shape
+ x = x.reshape(B, S * D)
+ candidate_x = self.embedding_layer(*candidate_x)
+ B, S, D = candidate_x.shape
+ candidate_x = candidate_x.reshape(B, S * D)
+ else:
+
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
+ candidate_x = torch.cat(
+ [t for tensors in candidate_x for t in tensors], dim=1
+ )
+
+ with torch.set_grad_enabled(
+ torch.is_grad_enabled() and not self.memory_efficient
+ ):
+
+ candidate_k = (
+ self._encode(candidate_x)[1] # normalized candidate_x
+ if self.candidate_encoding_batch_size ==0
+ else torch.cat(
+ [
+ self._encode(x)[1] # normalized x
+ # for x in delu.iter_batches(
+ for x in TabR.delu.iter_batches(
+ candidate_x,
+ self.candidate_encoding_batch_size
+ )
+ ]
+ )
+ )
+
+ # Encode input
+ x, k = self._encode(x)
+
+ batch_size, d_main = k.shape
+ device = k.device
+ context_size = self.context_size
+
+ with torch.no_grad():
+ # initializing the search index
+ if self.search_index is None:
+ self.search_index = (
+ TabR.faiss.GpuIndexFlatL2(
+ TabR.faiss.StandardGpuResources(),
+ d_main
+ )
+ if device.type == 'cuda'
+ else TabR.faiss.IndexFlatL2(d_main)
+ )
+ # Updating the index is much faster than creating a new one.
+ self.search_index.reset()
+ self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
+ distances: Tensor
+ context_idx: Tensor
+ distances, context_idx = self.search_index.search( # type: ignore[code]
+ k.to(torch.float32), context_size + 1
+ )
+ # NOTE: to avoid leakage, the index i must be removed from the i-th row,
+ # (because of how candidate_k is constructed).
+ distances[
+ context_idx == torch.arange(batch_size, device=device)[:, None]
+ ] = torch.inf
+ # Not the most elegant solution to remove the argmax, but anyway.
+ context_idx = context_idx.gather(-1, distances.argsort()[:, :-1])
+
+ if self.memory_efficient and torch.is_grad_enabled():
+ # Repeating the same computation,
+ # but now only for the context objects and with autograd on.
+ context_k = self._encode(
+ torch.cat([x,candidate_x])[context_idx].flatten(0,1)
+ )[1].reshape(batch_size, context_size, -1)
+ else:
+ context_k = candidate_k[context_idx]
+
+ # In theory, when autograd is off, the distances obtained during the search
+ # can be reused. However, this is not a bottleneck, so let's keep it simple
+ # and use the same code to compute `similarities` during both
+ # training and evaluation.
+ similarities = (
+ -k.square().sum(-1, keepdim=True)
+ + (2 * (k[..., None, :] @ context_k.transpose(-1, -2))).squeeze(-2)
+ - context_k.square().sum(-1)
+ )
+ probs = F.softmax(similarities, dim=-1)
+ probs = self.dropout(probs)
+
+ if self.hparams.num_classes > 1: # for classification
+ context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
+ else: # for regression
+ context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
+ if len(context_y_emb.shape) == 4:
+ context_y_emb = context_y_emb[:,:,0,:]
+
+ # Combine keys and labels with a transformation T.
+ values = context_y_emb + self.T(k[:, None] - context_k)
+ context_x = (probs[:, None] @ values).squeeze(1)
+ x = x + context_x
+
+ # Predictor has LayerNorm, ReLU and Linear after the N_P number of blocks.
+ for block in self.blocks1:
+ x = x + block(x)
+ x = self.head(x)
+ return x
+
+ def validate_with_candidates(
+ self,
+ *data,
+ candidate_x,
+ candidate_y
+ ):
+ """Validation forward pass with TabR-style candidate selection."""
+ if self.hparams.use_embeddings:
+ x = self.embedding_layer(*data)
+ B, S, D = x.shape
+ x = x.reshape(B, S * D)
+ candidate_x = self.embedding_layer(*candidate_x)
+ B, S, D = candidate_x.shape
+ candidate_x = candidate_x.reshape(B, S * D)
+ else:
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
+ candidate_x = torch.cat(
+ [t for tensors in candidate_x for t in tensors], dim=1
+ )
+
+ if not self.memory_efficient:
+ candidate_k = (
+ self._encode(candidate_x)[1] # normalized candidate_x
+ if self.candidate_encoding_batch_size == 0
+ else torch.cat(
+ [
+ self._encode(x)[1] # normalized x
+ for x in TabR.delu.iter_batches(
+ candidate_x,
+ self.candidate_encoding_batch_size
+ )
+ ]
+ )
+ )
+ else:
+ candidate_x, candidate_k = self._encode(candidate_x)
+
+ x, k = self._encode(x) # encoded x and k
+ batch_size, d_main = k.shape
+ device = k.device
+ context_size = self.context_size
+
+ if self.search_index is None:
+ self.search_index = (
+ TabR.faiss.GpuIndexFlatL2(TabR.faiss.StandardGpuResources(), d_main)
+ if device.type == 'cuda'
+ else TabR.faiss.IndexFlatL2(d_main)
+ )
+
+ # Updating the index is much faster than creating a new one.
+ self.search_index.reset()
+ self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
+ distances: Tensor
+ context_idx: Tensor
+ distances, context_idx = self.search_index.search( # type: ignore[code]
+ k.to(torch.float32), context_size
+ )
+
+ context_k = candidate_k[context_idx]
+ similarities = (
+ -k.square().sum(-1, keepdim=True)
+ + (2 * (k[..., None, :] @ context_k.transpose(-1, -2))).squeeze(-2)
+ - context_k.square().sum(-1)
+ )
+ probs = F.softmax(similarities, dim=-1)
+ probs = self.dropout(probs)
+
+ if self.hparams.num_classes > 1: # for classification
+ context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
+ else: # for regression
+ context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
+ if len(context_y_emb.shape) == 4:
+ context_y_emb = context_y_emb[:,:,0,:]
+
+ values = context_y_emb + self.T(k[:, None] - context_k)
+ context_x = (probs[:, None] @ values).squeeze(1)
+ x = x + context_x
+
+ # Predictor has LayerNorm, ReLU and Linear after the N_P number of blocks.
+ for block in self.blocks1:
+ x = x + block(x)
+ x = self.head(x)
+ return x
+
+
+ def predict_with_candidates(
+ self,
+ *data,
+ candidate_x,
+ candidate_y
+ ):
+ """Prediction forward pass with TabR-style candidate selection."""
+ if self.hparams.use_embeddings:
+ x = self.embedding_layer(*data)
+ B, S, D = x.shape
+ x = x.reshape(B, S * D)
+ candidate_x = self.embedding_layer(*candidate_x)
+ B, S, D = candidate_x.shape
+ candidate_x = candidate_x.reshape(B, S * D)
+ else:
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
+ candidate_x = torch.cat(
+ [t for tensors in candidate_x for t in tensors], dim=1
+ )
+
+ if not self.memory_efficient:
+ candidate_k = (
+ self._encode(candidate_x)[1] # normalized candidate_x
+ if self.candidate_encoding_batch_size == 0
+ else torch.cat(
+ [
+ self._encode(x)[1] # normalized x
+ for x in TabR.delu.iter_batches(
+ candidate_x,
+ self.candidate_encoding_batch_size
+ )
+ ]
+ )
+ )
+ else:
+ candidate_x, candidate_k = self._encode(candidate_x)
+
+ x, k = self._encode(x) # encoded x and k
+ batch_size, d_main = k.shape
+ device = k.device
+ context_size = self.context_size
+
+ if self.search_index is None:
+ self.search_index = (
+ TabR.faiss.GpuIndexFlatL2(TabR.faiss.StandardGpuResources(), d_main)
+ if device.type == 'cuda'
+ else TabR.faiss.IndexFlatL2(d_main)
+ )
+
+
+ # Updating the index is much faster than creating a new one.
+ self.search_index.reset()
+ self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
+ distances: Tensor
+ context_idx: Tensor
+ distances, context_idx = self.search_index.search( # type: ignore[code]
+ k.to(torch.float32), context_size
+ )
+
+ context_k = candidate_k[context_idx]
+ similarities = (
+ -k.square().sum(-1, keepdim=True)
+ + (2 * (k[..., None, :] @ context_k.transpose(-1, -2))).squeeze(-2)
+ - context_k.square().sum(-1)
+ )
+ probs = F.softmax(similarities, dim=-1)
+ probs = self.dropout(probs)
+
+ if self.hparams.num_classes > 1: # for classification
+ context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].long())
+ else: # for regression
+ context_y_emb = self.label_encoder(candidate_y[context_idx][..., None])
+ if len(context_y_emb.shape) == 4:
+ context_y_emb = context_y_emb[:,:,0,:]
+
+ values = context_y_emb + self.T(k[:, None] - context_k)
+ context_x = (probs[:, None] @ values).squeeze(1)
+ x = x + context_x
+
+ # Predictor has LayerNorm, ReLU and Linear after the N_P number of blocks.
+ for block in self.blocks1:
+ x = x + block(x)
+ x = self.head(x)
+ return x
diff --git a/deeptabular/base_models/tabtransformer.py b/deeptab/base_models/tabtransformer.py
similarity index 100%
rename from deeptabular/base_models/tabtransformer.py
rename to deeptab/base_models/tabtransformer.py
diff --git a/deeptabular/base_models/tabularnn.py b/deeptab/base_models/tabularnn.py
similarity index 96%
rename from deeptabular/base_models/tabularnn.py
rename to deeptab/base_models/tabularnn.py
index 7b62e5e9..893acbab 100644
--- a/deeptabular/base_models/tabularnn.py
+++ b/deeptab/base_models/tabularnn.py
@@ -1,79 +1,79 @@
-from dataclasses import replace
-import torch
-import torch.nn as nn
-
-from ..arch_utils.get_norm_fn import get_normalization_layer
-from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
-from ..arch_utils.mlp_utils import MLPhead
-from ..arch_utils.rnn_utils import ConvRNN
-from ..configs.tabularnn_config import DefaultTabulaRNNConfig
-from .utils.basemodel import BaseModel
-
-
-class TabulaRNN(BaseModel):
-
- def __init__(
- self,
- feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
- num_classes=1,
- config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(), # noqa: B008
- **kwargs,
- ):
- super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["feature_information"])
-
- self.returns_ensemble = False
-
- self.rnn = ConvRNN(config)
-
- self.embedding_layer = EmbeddingLayer(
- *feature_information,
- config=config,
- )
-
- self.tabular_head = MLPhead(
- input_dim=self.hparams.dim_feedforward,
- config=config,
- output_dim=num_classes,
- )
-
- self.linear = nn.Linear(
- self.hparams.d_model,
- self.hparams.dim_feedforward,
- )
-
- temp_config = replace(config, d_model=config.dim_feedforward)
- self.norm_f = get_normalization_layer(temp_config)
-
- # pooling
- n_inputs = [len(info) for info in feature_information]
- self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
-
- def forward(self, *data):
- """Defines the forward pass of the model.
-
- Parameters
- ----------
- num_features : Tensor
- Tensor containing the numerical features.
- cat_features : Tensor
- Tensor containing the categorical features.
-
- Returns
- -------
- Tensor
- The output predictions of the model.
- """
-
- x = self.embedding_layer(*data)
- # RNN forward pass
- out, _ = self.rnn(x)
- z = self.linear(torch.mean(x, dim=1))
-
- x = self.pool_sequence(out)
- x = x + z
- if self.norm_f is not None:
- x = self.norm_f(x)
- preds = self.tabular_head(x)
-
- return preds
+from dataclasses import replace
+import torch
+import torch.nn as nn
+
+from ..arch_utils.get_norm_fn import get_normalization_layer
+from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
+from ..arch_utils.mlp_utils import MLPhead
+from ..arch_utils.rnn_utils import ConvRNN
+from ..configs.tabularnn_config import DefaultTabulaRNNConfig
+from .utils.basemodel import BaseModel
+
+
+class TabulaRNN(BaseModel):
+
+ def __init__(
+ self,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
+ num_classes=1,
+ config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(), # noqa: B008
+ **kwargs,
+ ):
+ super().__init__(config=config, **kwargs)
+ self.save_hyperparameters(ignore=["feature_information"])
+
+ self.returns_ensemble = False
+
+ self.rnn = ConvRNN(config)
+
+ self.embedding_layer = EmbeddingLayer(
+ *feature_information,
+ config=config,
+ )
+
+ self.tabular_head = MLPhead(
+ input_dim=self.hparams.dim_feedforward,
+ config=config,
+ output_dim=num_classes,
+ )
+
+ self.linear = nn.Linear(
+ self.hparams.d_model,
+ self.hparams.dim_feedforward,
+ )
+
+ temp_config = replace(config, d_model=config.dim_feedforward)
+ self.norm_f = get_normalization_layer(temp_config)
+
+ # pooling
+ n_inputs = [len(info) for info in feature_information]
+ self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
+
+ def forward(self, *data):
+ """Defines the forward pass of the model.
+
+ Parameters
+ ----------
+ num_features : Tensor
+ Tensor containing the numerical features.
+ cat_features : Tensor
+ Tensor containing the categorical features.
+
+ Returns
+ -------
+ Tensor
+ The output predictions of the model.
+ """
+
+ x = self.embedding_layer(*data)
+ # RNN forward pass
+ out, _ = self.rnn(x)
+ z = self.linear(torch.mean(x, dim=1))
+
+ x = self.pool_sequence(out)
+ x = x + z
+ if self.norm_f is not None:
+ x = self.norm_f(x)
+ preds = self.tabular_head(x)
+
+ return preds
diff --git a/deeptabular/base_models/tangos.py b/deeptab/base_models/tangos.py
similarity index 100%
rename from deeptabular/base_models/tangos.py
rename to deeptab/base_models/tangos.py
diff --git a/deeptabular/base_models/trompt.py b/deeptab/base_models/trompt.py
similarity index 100%
rename from deeptabular/base_models/trompt.py
rename to deeptab/base_models/trompt.py
diff --git a/deeptabular/base_models/utils/__init__.py b/deeptab/base_models/utils/__init__.py
similarity index 100%
rename from deeptabular/base_models/utils/__init__.py
rename to deeptab/base_models/utils/__init__.py
diff --git a/deeptabular/base_models/utils/basemodel.py b/deeptab/base_models/utils/basemodel.py
similarity index 100%
rename from deeptabular/base_models/utils/basemodel.py
rename to deeptab/base_models/utils/basemodel.py
diff --git a/deeptabular/base_models/utils/lightning_wrapper.py b/deeptab/base_models/utils/lightning_wrapper.py
similarity index 100%
rename from deeptabular/base_models/utils/lightning_wrapper.py
rename to deeptab/base_models/utils/lightning_wrapper.py
diff --git a/deeptabular/base_models/utils/pretraining.py b/deeptab/base_models/utils/pretraining.py
similarity index 100%
rename from deeptabular/base_models/utils/pretraining.py
rename to deeptab/base_models/utils/pretraining.py
diff --git a/deeptabular/configs/__init__.py b/deeptab/configs/__init__.py
similarity index 97%
rename from deeptabular/configs/__init__.py
rename to deeptab/configs/__init__.py
index 9e803ac2..4f397c1f 100644
--- a/deeptabular/configs/__init__.py
+++ b/deeptab/configs/__init__.py
@@ -1,39 +1,39 @@
-from .fttransformer_config import DefaultFTTransformerConfig
-from .mambatab_config import DefaultMambaTabConfig
-from .mambattention_config import DefaultMambAttentionConfig
-from .mambular_config import DefaultMambularConfig
-from .mlp_config import DefaultMLPConfig
-from .ndtf_config import DefaultNDTFConfig
-from .node_config import DefaultNODEConfig
-from .resnet_config import DefaultResNetConfig
-from .saint_config import DefaultSAINTConfig
-from .tabm_config import DefaultTabMConfig
-from .tabtransformer_config import DefaultTabTransformerConfig
-from .tabularnn_config import DefaultTabulaRNNConfig
-from .autoint_config import DefaultAutoIntConfig
-from .trompt_config import DefaultTromptConfig
-from .base_config import BaseConfig
-from .enode_config import DefaultENODEConfig
-from .tangos_config import DefaultTangosConfig
-from .modernnca_config import DefaultModernNCAConfig
-
-__all__ = [
- "DefaultModernNCAConfig",
- "DefaultTangosConfig",
- "DefaultENODEConfig",
- "DefaultTromptConfig",
- "DefaultAutoIntConfig",
- "DefaultFTTransformerConfig",
- "DefaultMLPConfig",
- "DefaultMambAttentionConfig",
- "DefaultMambaTabConfig",
- "DefaultMambularConfig",
- "DefaultNDTFConfig",
- "DefaultNODEConfig",
- "DefaultResNetConfig",
- "DefaultSAINTConfig",
- "DefaultTabMConfig",
- "DefaultTabTransformerConfig",
- "DefaultTabulaRNNConfig",
- "BaseConfig",
-]
+from .fttransformer_config import DefaultFTTransformerConfig
+from .mambatab_config import DefaultMambaTabConfig
+from .mambattention_config import DefaultMambAttentionConfig
+from .mambular_config import DefaultMambularConfig
+from .mlp_config import DefaultMLPConfig
+from .ndtf_config import DefaultNDTFConfig
+from .node_config import DefaultNODEConfig
+from .resnet_config import DefaultResNetConfig
+from .saint_config import DefaultSAINTConfig
+from .tabm_config import DefaultTabMConfig
+from .tabtransformer_config import DefaultTabTransformerConfig
+from .tabularnn_config import DefaultTabulaRNNConfig
+from .autoint_config import DefaultAutoIntConfig
+from .trompt_config import DefaultTromptConfig
+from .base_config import BaseConfig
+from .enode_config import DefaultENODEConfig
+from .tangos_config import DefaultTangosConfig
+from .modernnca_config import DefaultModernNCAConfig
+
+__all__ = [
+ "DefaultModernNCAConfig",
+ "DefaultTangosConfig",
+ "DefaultENODEConfig",
+ "DefaultTromptConfig",
+ "DefaultAutoIntConfig",
+ "DefaultFTTransformerConfig",
+ "DefaultMLPConfig",
+ "DefaultMambAttentionConfig",
+ "DefaultMambaTabConfig",
+ "DefaultMambularConfig",
+ "DefaultNDTFConfig",
+ "DefaultNODEConfig",
+ "DefaultResNetConfig",
+ "DefaultSAINTConfig",
+ "DefaultTabMConfig",
+ "DefaultTabTransformerConfig",
+ "DefaultTabulaRNNConfig",
+ "BaseConfig",
+]
diff --git a/deeptabular/configs/autoint_config.py b/deeptab/configs/autoint_config.py
similarity index 100%
rename from deeptabular/configs/autoint_config.py
rename to deeptab/configs/autoint_config.py
diff --git a/deeptabular/configs/base_config.py b/deeptab/configs/base_config.py
similarity index 100%
rename from deeptabular/configs/base_config.py
rename to deeptab/configs/base_config.py
diff --git a/deeptabular/configs/enode_config.py b/deeptab/configs/enode_config.py
similarity index 100%
rename from deeptabular/configs/enode_config.py
rename to deeptab/configs/enode_config.py
diff --git a/deeptabular/configs/fttransformer_config.py b/deeptab/configs/fttransformer_config.py
similarity index 100%
rename from deeptabular/configs/fttransformer_config.py
rename to deeptab/configs/fttransformer_config.py
diff --git a/deeptabular/configs/mambatab_config.py b/deeptab/configs/mambatab_config.py
similarity index 100%
rename from deeptabular/configs/mambatab_config.py
rename to deeptab/configs/mambatab_config.py
diff --git a/deeptabular/configs/mambattention_config.py b/deeptab/configs/mambattention_config.py
similarity index 97%
rename from deeptabular/configs/mambattention_config.py
rename to deeptab/configs/mambattention_config.py
index 49e596e5..22dd319f 100644
--- a/deeptabular/configs/mambattention_config.py
+++ b/deeptab/configs/mambattention_config.py
@@ -1,126 +1,126 @@
-from collections.abc import Callable
-from dataclasses import dataclass, field
-import torch.nn as nn
-from .base_config import BaseConfig
-
-
-@dataclass
-class DefaultMambAttentionConfig(BaseConfig):
- """Configuration class for the Default Mambular Attention model with predefined hyperparameters.
-
- Parameters
- ----------
- d_model : int, default=64
- Dimensionality of the model.
- n_layers : int, default=4
- Number of layers in the model.
- expand_factor : int, default=2
- Expansion factor for the feed-forward layers.
- n_heads : int, default=8
- Number of attention heads in the model.
- last_layer : str, default="attn"
- Type of the last layer (e.g., 'attn').
- n_mamba_per_attention : int, default=1
- Number of Mamba blocks per attention layer.
- bias : bool, default=False
- Whether to use bias in the linear layers.
- d_conv : int, default=4
- Dimensionality of the convolutional layers.
- conv_bias : bool, default=True
- Whether to use bias in the convolutional layers.
- dropout : float, default=0.0
- Dropout rate for regularization.
- attn_dropout : float, default=0.2
- Dropout rate for the attention mechanism.
- dt_rank : str, default="auto"
- Rank of the decision tree.
- d_state : int, default=128
- Dimensionality of the state in recurrent layers.
- dt_scale : float, default=1.0
- Scaling factor for the decision tree.
- dt_init : str, default="random"
- Initialization method for the decision tree.
- dt_max : float, default=0.1
- Maximum value for decision tree initialization.
- dt_min : float, default=1e-04
- Minimum value for decision tree initialization.
- dt_init_floor : float, default=1e-04
- Floor value for decision tree initialization.
- norm : str, default="LayerNorm"
- Type of normalization used in the model.
- activation : callable, default=nn.SiLU()
- Activation function for the model.
- head_layer_sizes : list, default=()
- Sizes of the fully connected layers in the model's head.
- head_dropout : float, default=0.5
- Dropout rate for the head layers.
- head_skip_layers : bool, default=False
- Whether to use skip connections in the head layers.
- head_activation : callable, default=nn.SELU()
- Activation function for the head layers.
- head_use_batch_norm : bool, default=False
- Whether to use batch normalization in the head layers.
- pooling_method : str, default="avg"
- Pooling method to be used ('avg', 'max', etc.).
- bidirectional : bool, default=False
- Whether to process input sequences bidirectionally.
- use_learnable_interaction : bool, default=False
- Whether to use learnable feature interactions before passing through Mamba blocks.
- use_cls : bool, default=False
- Whether to append a CLS token for sequence pooling.
- shuffle_embeddings : bool, default=False
- Whether to shuffle embeddings before passing to Mamba layers.
- cat_encoding : str, default="int"
- Encoding method for categorical features ('int', 'one-hot', etc.).
- AD_weight_decay : bool, default=True
- Whether weight decay is applied to A-D matrices.
- BC_layer_norm : bool, default=False
- Whether to apply layer normalization to B-C matrices.
- use_pscan : bool, default=False
- Whether to use PSCAN for the state-space model.
- n_attention_layers : int, default=1
- Number of attention layers in the model.
- """
-
- # Architecture Parameters
- d_model: int = 64
- n_layers: int = 4
- expand_factor: int = 2
- n_heads: int = 8
- last_layer: str = "attn"
- n_mamba_per_attention: int = 1
- bias: bool = False
- d_conv: int = 4
- conv_bias: bool = True
- dropout: float = 0.0
- attn_dropout: float = 0.2
- dt_rank: str = "auto"
- d_state: int = 128
- dt_scale: float = 1.0
- dt_init: str = "random"
- dt_max: float = 0.1
- dt_min: float = 1e-04
- dt_init_floor: float = 1e-04
- norm: str = "LayerNorm"
- activation: Callable = nn.SiLU() # noqa: RUF009
-
- # Head Parameters
- head_layer_sizes: list = field(default_factory=list)
- head_dropout: float = 0.5
- head_skip_layers: bool = False
- head_activation: Callable = nn.SELU() # noqa: RUF009
- head_use_batch_norm: bool = False
-
- # Pooling and Categorical Encoding
- pooling_method: str = "avg"
- bidirectional: bool = False
- use_learnable_interaction: bool = False
- use_cls: bool = False
- shuffle_embeddings: bool = False
- cat_encoding: str = "int"
-
- # Additional Features
- AD_weight_decay: bool = True
- BC_layer_norm: bool = False
- use_pscan: bool = False
- n_attention_layers: int = 1
+from collections.abc import Callable
+from dataclasses import dataclass, field
+import torch.nn as nn
+from .base_config import BaseConfig
+
+
+@dataclass
+class DefaultMambAttentionConfig(BaseConfig):
+ """Configuration class for the Default Mambular Attention model with predefined hyperparameters.
+
+ Parameters
+ ----------
+ d_model : int, default=64
+ Dimensionality of the model.
+ n_layers : int, default=4
+ Number of layers in the model.
+ expand_factor : int, default=2
+ Expansion factor for the feed-forward layers.
+ n_heads : int, default=8
+ Number of attention heads in the model.
+ last_layer : str, default="attn"
+ Type of the last layer (e.g., 'attn').
+ n_mamba_per_attention : int, default=1
+ Number of Mamba blocks per attention layer.
+ bias : bool, default=False
+ Whether to use bias in the linear layers.
+ d_conv : int, default=4
+ Dimensionality of the convolutional layers.
+ conv_bias : bool, default=True
+ Whether to use bias in the convolutional layers.
+ dropout : float, default=0.0
+ Dropout rate for regularization.
+ attn_dropout : float, default=0.2
+ Dropout rate for the attention mechanism.
+ dt_rank : str, default="auto"
+ Rank of the decision tree.
+ d_state : int, default=128
+ Dimensionality of the state in recurrent layers.
+ dt_scale : float, default=1.0
+ Scaling factor for the decision tree.
+ dt_init : str, default="random"
+ Initialization method for the decision tree.
+ dt_max : float, default=0.1
+ Maximum value for decision tree initialization.
+ dt_min : float, default=1e-04
+ Minimum value for decision tree initialization.
+ dt_init_floor : float, default=1e-04
+ Floor value for decision tree initialization.
+ norm : str, default="LayerNorm"
+ Type of normalization used in the model.
+ activation : callable, default=nn.SiLU()
+ Activation function for the model.
+ head_layer_sizes : list, default=()
+ Sizes of the fully connected layers in the model's head.
+ head_dropout : float, default=0.5
+ Dropout rate for the head layers.
+ head_skip_layers : bool, default=False
+ Whether to use skip connections in the head layers.
+ head_activation : callable, default=nn.SELU()
+ Activation function for the head layers.
+ head_use_batch_norm : bool, default=False
+ Whether to use batch normalization in the head layers.
+ pooling_method : str, default="avg"
+ Pooling method to be used ('avg', 'max', etc.).
+ bidirectional : bool, default=False
+ Whether to process input sequences bidirectionally.
+ use_learnable_interaction : bool, default=False
+ Whether to use learnable feature interactions before passing through Mamba blocks.
+ use_cls : bool, default=False
+ Whether to append a CLS token for sequence pooling.
+ shuffle_embeddings : bool, default=False
+ Whether to shuffle embeddings before passing to Mamba layers.
+ cat_encoding : str, default="int"
+ Encoding method for categorical features ('int', 'one-hot', etc.).
+ AD_weight_decay : bool, default=True
+ Whether weight decay is applied to A-D matrices.
+ BC_layer_norm : bool, default=False
+ Whether to apply layer normalization to B-C matrices.
+ use_pscan : bool, default=False
+ Whether to use PSCAN for the state-space model.
+ n_attention_layers : int, default=1
+ Number of attention layers in the model.
+ """
+
+ # Architecture Parameters
+ d_model: int = 64
+ n_layers: int = 4
+ expand_factor: int = 2
+ n_heads: int = 8
+ last_layer: str = "attn"
+ n_mamba_per_attention: int = 1
+ bias: bool = False
+ d_conv: int = 4
+ conv_bias: bool = True
+ dropout: float = 0.0
+ attn_dropout: float = 0.2
+ dt_rank: str = "auto"
+ d_state: int = 128
+ dt_scale: float = 1.0
+ dt_init: str = "random"
+ dt_max: float = 0.1
+ dt_min: float = 1e-04
+ dt_init_floor: float = 1e-04
+ norm: str = "LayerNorm"
+ activation: Callable = nn.SiLU() # noqa: RUF009
+
+ # Head Parameters
+ head_layer_sizes: list = field(default_factory=list)
+ head_dropout: float = 0.5
+ head_skip_layers: bool = False
+ head_activation: Callable = nn.SELU() # noqa: RUF009
+ head_use_batch_norm: bool = False
+
+ # Pooling and Categorical Encoding
+ pooling_method: str = "avg"
+ bidirectional: bool = False
+ use_learnable_interaction: bool = False
+ use_cls: bool = False
+ shuffle_embeddings: bool = False
+ cat_encoding: str = "int"
+
+ # Additional Features
+ AD_weight_decay: bool = True
+ BC_layer_norm: bool = False
+ use_pscan: bool = False
+ n_attention_layers: int = 1
diff --git a/deeptabular/configs/mambular_config.py b/deeptab/configs/mambular_config.py
similarity index 100%
rename from deeptabular/configs/mambular_config.py
rename to deeptab/configs/mambular_config.py
diff --git a/deeptabular/configs/mlp_config.py b/deeptab/configs/mlp_config.py
similarity index 100%
rename from deeptabular/configs/mlp_config.py
rename to deeptab/configs/mlp_config.py
diff --git a/deeptabular/configs/modernnca_config.py b/deeptab/configs/modernnca_config.py
similarity index 100%
rename from deeptabular/configs/modernnca_config.py
rename to deeptab/configs/modernnca_config.py
diff --git a/deeptabular/configs/ndtf_config.py b/deeptab/configs/ndtf_config.py
similarity index 100%
rename from deeptabular/configs/ndtf_config.py
rename to deeptab/configs/ndtf_config.py
diff --git a/deeptabular/configs/node_config.py b/deeptab/configs/node_config.py
similarity index 100%
rename from deeptabular/configs/node_config.py
rename to deeptab/configs/node_config.py
diff --git a/deeptabular/configs/resnet_config.py b/deeptab/configs/resnet_config.py
similarity index 100%
rename from deeptabular/configs/resnet_config.py
rename to deeptab/configs/resnet_config.py
diff --git a/deeptabular/configs/saint_config.py b/deeptab/configs/saint_config.py
similarity index 100%
rename from deeptabular/configs/saint_config.py
rename to deeptab/configs/saint_config.py
diff --git a/deeptabular/configs/tabm_config.py b/deeptab/configs/tabm_config.py
similarity index 100%
rename from deeptabular/configs/tabm_config.py
rename to deeptab/configs/tabm_config.py
diff --git a/deeptabular/configs/tabr_config.py b/deeptab/configs/tabr_config.py
similarity index 96%
rename from deeptabular/configs/tabr_config.py
rename to deeptab/configs/tabr_config.py
index cde73c1c..d1ec9799 100644
--- a/deeptabular/configs/tabr_config.py
+++ b/deeptab/configs/tabr_config.py
@@ -1,38 +1,38 @@
-from collections.abc import Callable
-from dataclasses import dataclass, field
-from .base_config import BaseConfig
-import torch.nn as nn
-
-@dataclass
-class DefaultTabRConfig(BaseConfig):
- """Configuration class for the default TabR model with predefined hyperparameters.
- Parameters
- ----------
- """
-
- # Optimizer Parameters
- lr: float = 0.0003121273641315169
- weight_decay: float = 1.2260352006404615e-06
- lr_patience =10
- lr_factor: float = 0.1 # Factor for LR scheduler
-
- # Architecture Parameters
- d_main: int = 256
- context_dropout: float =0.38920071545944357
- d_multiplier : int = 2
- encoder_n_blocks : int=0
- predictor_n_blocks: int=1
- mixer_normalization: str ="auto"
- dropout0: float =0.38852797479169876
- dropout1: float=0.0
- normalization: str = "LayerNorm"
- activation:Callable = nn.ReLU()
- memory_efficient: bool = False
- candidate_encoding_batch_size:int = 0
- context_size:int=96
-
- # Embedding Parameters
- embedding_type: str = "plr"
- plr_lite: bool = True
- n_frequencies: int = 75
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from .base_config import BaseConfig
+import torch.nn as nn
+
+@dataclass
+class DefaultTabRConfig(BaseConfig):
+ """Configuration class for the default TabR model with predefined hyperparameters.
+ Parameters
+ ----------
+ """
+
+ # Optimizer Parameters
+ lr: float = 0.0003121273641315169
+ weight_decay: float = 1.2260352006404615e-06
+ lr_patience =10
+ lr_factor: float = 0.1 # Factor for LR scheduler
+
+ # Architecture Parameters
+ d_main: int = 256
+ context_dropout: float =0.38920071545944357
+ d_multiplier : int = 2
+ encoder_n_blocks : int=0
+ predictor_n_blocks: int=1
+ mixer_normalization: str ="auto"
+ dropout0: float =0.38852797479169876
+ dropout1: float=0.0
+ normalization: str = "LayerNorm"
+ activation:Callable = nn.ReLU()
+ memory_efficient: bool = False
+ candidate_encoding_batch_size:int = 0
+ context_size:int=96
+
+ # Embedding Parameters
+ embedding_type: str = "plr"
+ plr_lite: bool = True
+ n_frequencies: int = 75
frequencies_init_scale: float = 0.045
\ No newline at end of file
diff --git a/deeptabular/configs/tabtransformer_config.py b/deeptab/configs/tabtransformer_config.py
similarity index 100%
rename from deeptabular/configs/tabtransformer_config.py
rename to deeptab/configs/tabtransformer_config.py
diff --git a/deeptabular/configs/tabularnn_config.py b/deeptab/configs/tabularnn_config.py
similarity index 97%
rename from deeptabular/configs/tabularnn_config.py
rename to deeptab/configs/tabularnn_config.py
index 639e6ba3..99a74e07 100644
--- a/deeptabular/configs/tabularnn_config.py
+++ b/deeptab/configs/tabularnn_config.py
@@ -1,84 +1,84 @@
-from collections.abc import Callable
-from dataclasses import dataclass, field
-import torch.nn as nn
-from .base_config import BaseConfig
-
-
-@dataclass
-class DefaultTabulaRNNConfig(BaseConfig):
- """Configuration class for the TabulaRNN model with predefined hyperparameters.
-
- Parameters
- ----------
- model_type : str, default="RNN"
- Type of model, one of "RNN", "LSTM", "GRU", "mLSTM", "sLSTM".
- n_layers : int, default=4
- Number of layers in the RNN.
- rnn_dropout : float, default=0.2
- Dropout rate for the RNN layers.
- d_model : int, default=128
- Dimensionality of embeddings or model representations.
- norm : str, default="RMSNorm"
- Normalization method to be used.
- activation : callable, default=nn.SELU()
- Activation function for the RNN layers.
- residuals : bool, default=False
- Whether to include residual connections in the RNN.
- head_layer_sizes : list, default=()
- Sizes of the layers in the head of the model.
- head_dropout : float, default=0.5
- Dropout rate for the head layers.
- head_skip_layers : bool, default=False
- Whether to skip layers in the head.
- head_activation : callable, default=nn.SELU()
- Activation function for the head layers.
- head_use_batch_norm : bool, default=False
- Whether to use batch normalization in the head layers.
- pooling_method : str, default="avg"
- Pooling method to be used ('avg', 'cls', etc.).
- norm_first : bool, default=False
- Whether to apply normalization before other operations in each block.
- layer_norm_eps : float, default=1e-05
- Epsilon value for layer normalization.
- bias : bool, default=True
- Whether to use bias in the linear layers.
- rnn_activation : str, default="relu"
- Activation function for the RNN layers.
- dim_feedforward : int, default=256
- Size of the feedforward network.
- d_conv : int, default=4
- Size of the convolutional layer for embedding features.
- dilation : int, default=1
- Dilation factor for the convolution.
- conv_bias : bool, default=True
- Whether to use bias in the convolutional layers.
- """
-
- # Architecture params
- model_type: str = "RNN"
- d_model: int = 128
- n_layers: int = 4
- rnn_dropout: float = 0.2
- norm: str = "RMSNorm"
- activation: Callable = nn.SELU() # noqa: RUF009
- residuals: bool = False
-
- # Head params
- head_layer_sizes: list = field(default_factory=list)
- head_dropout: float = 0.5
- head_skip_layers: bool = False
- head_activation: Callable = nn.SELU() # noqa: RUF009
- head_use_batch_norm: bool = False
-
- # Pooling and normalization
- pooling_method: str = "avg"
- norm_first: bool = False
- layer_norm_eps: float = 1e-05
-
- # Additional params
- bias: bool = True
- rnn_activation: str = "relu"
- dim_feedforward: int = 256
- d_conv: int = 4
- dilation: int = 1
- conv_bias: bool = True
+from collections.abc import Callable
+from dataclasses import dataclass, field
+import torch.nn as nn
+from .base_config import BaseConfig
+
+
+@dataclass
+class DefaultTabulaRNNConfig(BaseConfig):
+ """Configuration class for the TabulaRNN model with predefined hyperparameters.
+
+ Parameters
+ ----------
+ model_type : str, default="RNN"
+ Type of model, one of "RNN", "LSTM", "GRU", "mLSTM", "sLSTM".
+ n_layers : int, default=4
+ Number of layers in the RNN.
+ rnn_dropout : float, default=0.2
+ Dropout rate for the RNN layers.
+ d_model : int, default=128
+ Dimensionality of embeddings or model representations.
+ norm : str, default="RMSNorm"
+ Normalization method to be used.
+ activation : callable, default=nn.SELU()
+ Activation function for the RNN layers.
+ residuals : bool, default=False
+ Whether to include residual connections in the RNN.
+ head_layer_sizes : list, default=()
+ Sizes of the layers in the head of the model.
+ head_dropout : float, default=0.5
+ Dropout rate for the head layers.
+ head_skip_layers : bool, default=False
+ Whether to skip layers in the head.
+ head_activation : callable, default=nn.SELU()
+ Activation function for the head layers.
+ head_use_batch_norm : bool, default=False
+ Whether to use batch normalization in the head layers.
+ pooling_method : str, default="avg"
+ Pooling method to be used ('avg', 'cls', etc.).
+ norm_first : bool, default=False
+ Whether to apply normalization before other operations in each block.
+ layer_norm_eps : float, default=1e-05
+ Epsilon value for layer normalization.
+ bias : bool, default=True
+ Whether to use bias in the linear layers.
+ rnn_activation : str, default="relu"
+ Activation function for the RNN layers.
+ dim_feedforward : int, default=256
+ Size of the feedforward network.
+ d_conv : int, default=4
+ Size of the convolutional layer for embedding features.
+ dilation : int, default=1
+ Dilation factor for the convolution.
+ conv_bias : bool, default=True
+ Whether to use bias in the convolutional layers.
+ """
+
+ # Architecture params
+ model_type: str = "RNN"
+ d_model: int = 128
+ n_layers: int = 4
+ rnn_dropout: float = 0.2
+ norm: str = "RMSNorm"
+ activation: Callable = nn.SELU() # noqa: RUF009
+ residuals: bool = False
+
+ # Head params
+ head_layer_sizes: list = field(default_factory=list)
+ head_dropout: float = 0.5
+ head_skip_layers: bool = False
+ head_activation: Callable = nn.SELU() # noqa: RUF009
+ head_use_batch_norm: bool = False
+
+ # Pooling and normalization
+ pooling_method: str = "avg"
+ norm_first: bool = False
+ layer_norm_eps: float = 1e-05
+
+ # Additional params
+ bias: bool = True
+ rnn_activation: str = "relu"
+ dim_feedforward: int = 256
+ d_conv: int = 4
+ dilation: int = 1
+ conv_bias: bool = True
diff --git a/deeptabular/configs/tangos_config.py b/deeptab/configs/tangos_config.py
similarity index 100%
rename from deeptabular/configs/tangos_config.py
rename to deeptab/configs/tangos_config.py
diff --git a/deeptabular/configs/trompt_config.py b/deeptab/configs/trompt_config.py
similarity index 100%
rename from deeptabular/configs/trompt_config.py
rename to deeptab/configs/trompt_config.py
diff --git a/deeptabular/data_utils/__init__.py b/deeptab/data_utils/__init__.py
similarity index 100%
rename from deeptabular/data_utils/__init__.py
rename to deeptab/data_utils/__init__.py
diff --git a/deeptabular/data_utils/datamodule.py b/deeptab/data_utils/datamodule.py
similarity index 100%
rename from deeptabular/data_utils/datamodule.py
rename to deeptab/data_utils/datamodule.py
diff --git a/deeptabular/data_utils/dataset.py b/deeptab/data_utils/dataset.py
similarity index 100%
rename from deeptabular/data_utils/dataset.py
rename to deeptab/data_utils/dataset.py
diff --git a/deeptabular/models/__init__.py b/deeptab/models/__init__.py
similarity index 100%
rename from deeptabular/models/__init__.py
rename to deeptab/models/__init__.py
diff --git a/deeptabular/models/autoint.py b/deeptab/models/autoint.py
similarity index 92%
rename from deeptabular/models/autoint.py
rename to deeptab/models/autoint.py
index 1b01c941..777674dd 100644
--- a/deeptabular/models/autoint.py
+++ b/deeptab/models/autoint.py
@@ -15,7 +15,7 @@ class and uses the AutoInt model with the default AutoInt
configuration.
""",
examples="""
- >>> from deeptabular.models import AutoIntRegressor
+ >>> from deeptab.models import AutoIntRegressor
>>> model = AutoIntRegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -33,7 +33,7 @@ class AutoIntClassifier(SklearnBaseClassifier):
"""AutoInt Classifier. This class extends the SklearnBaseClassifier class
and uses the AutoInt model with the default AutoInt configuration.""",
examples="""
- >>> from deeptabular.models import AutoIntClassifier
+ >>> from deeptab.models import AutoIntClassifier
>>> model = AutoIntClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -52,7 +52,7 @@ class AutoIntLSS(SklearnBaseLSS):
This class extends the SklearnBaseLSS class and uses the
AutoInt model with the default AutoInt configuration.""",
examples="""
- >>> from deeptabular.models import AutoIntLSS
+ >>> from deeptab.models import AutoIntLSS
>>> model = AutoIntLSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family="normal")
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/enode.py b/deeptab/models/enode.py
similarity index 93%
rename from deeptabular/models/enode.py
rename to deeptab/models/enode.py
index 307f0892..1bada823 100644
--- a/deeptabular/models/enode.py
+++ b/deeptab/models/enode.py
@@ -14,7 +14,7 @@ class ENODERegressor(SklearnBaseRegressor):
with the default ENODE configuration.
""",
examples="""
- >>> from deeptabular.models import ENODERegressor
+ >>> from deeptab.models import ENODERegressor
>>> model = ENODERegressor()
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -35,7 +35,7 @@ class ENODEClassifier(SklearnBaseClassifier):
with the default ENODE configuration.
""",
examples="""
- >>> from deeptabular.models import ENODEClassifier
+ >>> from deeptab.models import ENODEClassifier
>>> model = ENODEClassifier()
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -56,7 +56,7 @@ class ENODELSS(SklearnBaseLSS):
with the default ENODE configuration.
""",
examples="""
- >>> from deeptabular.models import ENODELSS
+ >>> from deeptab.models import ENODELSS
>>> model = ENODELSS()
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/fttransformer.py b/deeptab/models/fttransformer.py
similarity index 92%
rename from deeptabular/models/fttransformer.py
rename to deeptab/models/fttransformer.py
index 8d31dcf6..a8c4eb90 100644
--- a/deeptabular/models/fttransformer.py
+++ b/deeptab/models/fttransformer.py
@@ -15,7 +15,7 @@ class and uses the FTTransformer model with the default FTTransformer
configuration.
""",
examples="""
- >>> from deeptabular.models import FTTransformerRegressor
+ >>> from deeptab.models import FTTransformerRegressor
>>> model = FTTransformerRegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -35,7 +35,7 @@ class FTTransformerClassifier(SklearnBaseClassifier):
"""FTTransformer Classifier. This class extends the SklearnBaseClassifier class
and uses the FTTransformer model with the default FTTransformer configuration.""",
examples="""
- >>> from deeptabular.models import FTTransformerClassifier
+ >>> from deeptab.models import FTTransformerClassifier
>>> model = FTTransformerClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -56,7 +56,7 @@ class FTTransformerLSS(SklearnBaseLSS):
This class extends the SklearnBaseLSS class and uses the
FTTransformer model with the default FTTransformer configuration.""",
examples="""
- >>> from deeptabular.models import FTTransformerLSS
+ >>> from deeptab.models import FTTransformerLSS
>>> model = FTTransformerLSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family="normal")
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/mambatab.py b/deeptab/models/mambatab.py
similarity index 100%
rename from deeptabular/models/mambatab.py
rename to deeptab/models/mambatab.py
diff --git a/deeptabular/models/mambattention.py b/deeptab/models/mambattention.py
similarity index 95%
rename from deeptabular/models/mambattention.py
rename to deeptab/models/mambattention.py
index dfbb12b6..0ff9c06e 100644
--- a/deeptabular/models/mambattention.py
+++ b/deeptab/models/mambattention.py
@@ -1,72 +1,72 @@
-from ..base_models.mambattn import MambAttention
-from ..configs.mambattention_config import DefaultMambAttentionConfig
-from ..utils.docstring_generator import generate_docstring
-from .utils.sklearn_base_classifier import SklearnBaseClassifier
-from .utils.sklearn_base_lss import SklearnBaseLSS
-from .utils.sklearn_base_regressor import SklearnBaseRegressor
-
-
-class MambAttentionRegressor(SklearnBaseRegressor):
- __doc__ = generate_docstring(
- DefaultMambAttentionConfig,
- model_description="""
- MambAttention regressor. This class extends the SklearnBaseRegressor class and uses the MambAttention model
- with the default MambAttention configuration.
- """,
- examples="""
- >>> from deeptabular.models import MambAttentionRegressor
- >>> model = MambAttentionRegressor(d_model=64, n_layers=8)
- >>> model.fit(X_train, y_train)
- >>> preds = model.predict(X_test)
- >>> model.evaluate(X_test, y_test)
- """,
- )
-
- def __init__(self, **kwargs):
- super().__init__(
- model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
- )
-
-
-class MambAttentionClassifier(SklearnBaseClassifier):
- __doc__ = generate_docstring(
- DefaultMambAttentionConfig,
- model_description="""
- MambAttention classifier. This class extends the SklearnBaseClassifier class and uses the MambAttention model
- with the default MambAttention configuration.
- """,
- examples="""
- >>> from MambAttention.models import MambAttentionClassifier
- >>> model = MambAttentionClassifier(d_model=64, n_layers=8)
- >>> model.fit(X_train, y_train)
- >>> preds = model.predict(X_test)
- >>> model.evaluate(X_test, y_test)
- """,
- )
-
- def __init__(self, **kwargs):
- super().__init__(
- model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
- )
-
-
-class MambAttentionLSS(SklearnBaseLSS):
- __doc__ = generate_docstring(
- DefaultMambAttentionConfig,
- model_description="""
- MambAttention LSS for distributional regression. This class extends the SklearnBaseLSS class and uses the MambAttention model
- with the default MambAttention configuration.
- """,
- examples="""
- >>> from MambAttention.models import MambAttentionLSS
- >>> model = MambAttentionLSS(d_model=64, n_layers=8)
- >>> model.fit(X_train, y_train, family='normal')
- >>> preds = model.predict(X_test)
- >>> model.evaluate(X_test, y_test)
- """,
- )
-
- def __init__(self, **kwargs):
- super().__init__(
- model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
- )
+from ..base_models.mambattn import MambAttention
+from ..configs.mambattention_config import DefaultMambAttentionConfig
+from ..utils.docstring_generator import generate_docstring
+from .utils.sklearn_base_classifier import SklearnBaseClassifier
+from .utils.sklearn_base_lss import SklearnBaseLSS
+from .utils.sklearn_base_regressor import SklearnBaseRegressor
+
+
+class MambAttentionRegressor(SklearnBaseRegressor):
+ __doc__ = generate_docstring(
+ DefaultMambAttentionConfig,
+ model_description="""
+ MambAttention regressor. This class extends the SklearnBaseRegressor class and uses the MambAttention model
+ with the default MambAttention configuration.
+ """,
+ examples="""
+ >>> from deeptab.models import MambAttentionRegressor
+ >>> model = MambAttentionRegressor(d_model=64, n_layers=8)
+ >>> model.fit(X_train, y_train)
+ >>> preds = model.predict(X_test)
+ >>> model.evaluate(X_test, y_test)
+ """,
+ )
+
+ def __init__(self, **kwargs):
+ super().__init__(
+ model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
+ )
+
+
+class MambAttentionClassifier(SklearnBaseClassifier):
+ __doc__ = generate_docstring(
+ DefaultMambAttentionConfig,
+ model_description="""
+ MambAttention classifier. This class extends the SklearnBaseClassifier class and uses the MambAttention model
+ with the default MambAttention configuration.
+ """,
+ examples="""
+ >>> from MambAttention.models import MambAttentionClassifier
+ >>> model = MambAttentionClassifier(d_model=64, n_layers=8)
+ >>> model.fit(X_train, y_train)
+ >>> preds = model.predict(X_test)
+ >>> model.evaluate(X_test, y_test)
+ """,
+ )
+
+ def __init__(self, **kwargs):
+ super().__init__(
+ model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
+ )
+
+
+class MambAttentionLSS(SklearnBaseLSS):
+ __doc__ = generate_docstring(
+ DefaultMambAttentionConfig,
+ model_description="""
+ MambAttention LSS for distributional regression. This class extends the SklearnBaseLSS class and uses the MambAttention model
+ with the default MambAttention configuration.
+ """,
+ examples="""
+ >>> from MambAttention.models import MambAttentionLSS
+ >>> model = MambAttentionLSS(d_model=64, n_layers=8)
+ >>> model.fit(X_train, y_train, family='normal')
+ >>> preds = model.predict(X_test)
+ >>> model.evaluate(X_test, y_test)
+ """,
+ )
+
+ def __init__(self, **kwargs):
+ super().__init__(
+ model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
+ )
diff --git a/deeptabular/models/mambular.py b/deeptab/models/mambular.py
similarity index 92%
rename from deeptabular/models/mambular.py
rename to deeptab/models/mambular.py
index ee41c07c..f51255be 100644
--- a/deeptabular/models/mambular.py
+++ b/deeptab/models/mambular.py
@@ -14,7 +14,7 @@ class MambularRegressor(SklearnBaseRegressor):
with the default Mambular configuration.
""",
examples="""
- >>> from deeptabular.models import MambularRegressor
+ >>> from deeptab.models import MambularRegressor
>>> model = MambularRegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -34,7 +34,7 @@ class MambularClassifier(SklearnBaseClassifier):
with the default Mambular configuration.
""",
examples="""
- >>> from deeptabular.models import MambularClassifier
+ >>> from deeptab.models import MambularClassifier
>>> model = MambularClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -54,7 +54,7 @@ class MambularLSS(SklearnBaseLSS):
with the default Mambular configuration.
""",
examples="""
- >>> from deeptabular.models import MambularLSS
+ >>> from deeptab.models import MambularLSS
>>> model = MambularLSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/mlp.py b/deeptab/models/mlp.py
similarity index 93%
rename from deeptabular/models/mlp.py
rename to deeptab/models/mlp.py
index a69a511a..197fd841 100644
--- a/deeptabular/models/mlp.py
+++ b/deeptab/models/mlp.py
@@ -14,7 +14,7 @@ class MLPRegressor(SklearnBaseRegressor):
with the default MLP configuration.
""",
examples="""
- >>> from deeptabular.models import MLPRegressor
+ >>> from deeptab.models import MLPRegressor
>>> model = MLPRegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -34,7 +34,7 @@ class MLPClassifier(SklearnBaseClassifier):
with the default MLP configuration.
""",
examples="""
- >>> from deeptabular.models import MLPClassifier
+ >>> from deeptab.models import MLPClassifier
>>> model = MLPClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -54,7 +54,7 @@ class MLPLSS(SklearnBaseLSS):
with the default MLP configuration.
""",
examples="""
- >>> from deeptabular.models import MLPLSS
+ >>> from deeptab.models import MLPLSS
>>> model = MLPLSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/modern_nca.py b/deeptab/models/modern_nca.py
similarity index 92%
rename from deeptabular/models/modern_nca.py
rename to deeptab/models/modern_nca.py
index dc6765f5..4b784798 100644
--- a/deeptabular/models/modern_nca.py
+++ b/deeptab/models/modern_nca.py
@@ -14,7 +14,7 @@ class ModernNCARegressor(SklearnBaseRegressor):
with the default ModernNCA configuration.
""",
examples="""
- >>> from deeptabular.models import ModernNCARegressor
+ >>> from deeptab.models import ModernNCARegressor
>>> model = ModernNCARegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -34,7 +34,7 @@ class ModernNCAClassifier(SklearnBaseClassifier):
with the default ModernNCA configuration.
""",
examples="""
- >>> from deeptabular.models import ModernNCAClassifier
+ >>> from deeptab.models import ModernNCAClassifier
>>> model = ModernNCAClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -54,7 +54,7 @@ class ModernNCALSS(SklearnBaseLSS):
with the default ModernNCA configuration.
""",
examples="""
- >>> from deeptabular.models import ModernNCALSS
+ >>> from deeptab.models import ModernNCALSS
>>> model = ModernNCALSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/ndtf.py b/deeptab/models/ndtf.py
similarity index 100%
rename from deeptabular/models/ndtf.py
rename to deeptab/models/ndtf.py
diff --git a/deeptabular/models/node.py b/deeptab/models/node.py
similarity index 93%
rename from deeptabular/models/node.py
rename to deeptab/models/node.py
index cd818c32..7275385c 100644
--- a/deeptabular/models/node.py
+++ b/deeptab/models/node.py
@@ -14,7 +14,7 @@ class NODERegressor(SklearnBaseRegressor):
with the default NODE configuration.
""",
examples="""
- >>> from deeptabular.models import NODERegressor
+ >>> from deeptab.models import NODERegressor
>>> model = NODERegressor()
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -35,7 +35,7 @@ class NODEClassifier(SklearnBaseClassifier):
with the default NODE configuration.
""",
examples="""
- >>> from deeptabular.models import NODEClassifier
+ >>> from deeptab.models import NODEClassifier
>>> model = NODEClassifier()
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -56,7 +56,7 @@ class NODELSS(SklearnBaseLSS):
with the default NODE configuration.
""",
examples="""
- >>> from deeptabular.models import NODELSS
+ >>> from deeptab.models import NODELSS
>>> model = NODELSS()
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/resnet.py b/deeptab/models/resnet.py
similarity index 92%
rename from deeptabular/models/resnet.py
rename to deeptab/models/resnet.py
index ed1bb2aa..8d4ecf1a 100644
--- a/deeptabular/models/resnet.py
+++ b/deeptab/models/resnet.py
@@ -14,7 +14,7 @@ class ResNetRegressor(SklearnBaseRegressor):
with the default ResNet configuration.
""",
examples="""
- >>> from deeptabular.models import ResNetRegressor
+ >>> from deeptab.models import ResNetRegressor
>>> model = ResNetRegressor()
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -34,7 +34,7 @@ class ResNetClassifier(SklearnBaseClassifier):
with the default ResNet configuration.
""",
examples="""
- >>> from deeptabular.models import ResNetClassifier
+ >>> from deeptab.models import ResNetClassifier
>>> model = ResNetClassifier()
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -54,7 +54,7 @@ class ResNetLSS(SklearnBaseLSS):
with the default ResNet configuration.
""",
examples="""
- >>> from deeptabular.models import ResNetLSS
+ >>> from deeptab.models import ResNetLSS
>>> model = ResNetLSS()
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/saint.py b/deeptab/models/saint.py
similarity index 92%
rename from deeptabular/models/saint.py
rename to deeptab/models/saint.py
index 44663bdc..f62ab52d 100644
--- a/deeptabular/models/saint.py
+++ b/deeptab/models/saint.py
@@ -15,7 +15,7 @@ class and uses the SAINT model with the default SAINT
configuration.
""",
examples="""
- >>> from deeptabular.models import SAINTRegressor
+ >>> from deeptab.models import SAINTRegressor
>>> model = SAINTRegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -33,7 +33,7 @@ class SAINTClassifier(SklearnBaseClassifier):
"""SAINT Classifier. This class extends the SklearnBaseClassifier class
and uses the SAINT model with the default SAINT configuration.""",
examples="""
- >>> from deeptabular.models import SAINTClassifier
+ >>> from deeptab.models import SAINTClassifier
>>> model = SAINTClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -52,7 +52,7 @@ class SAINTLSS(SklearnBaseLSS):
This class extends the SklearnBaseLSS class and uses the
SAINT model with the default SAINT configuration.""",
examples="""
- >>> from deeptabular.models import SAINTLSS
+ >>> from deeptab.models import SAINTLSS
>>> model = SAINTLSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family="normal")
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/tabm.py b/deeptab/models/tabm.py
similarity index 93%
rename from deeptabular/models/tabm.py
rename to deeptab/models/tabm.py
index da5737b1..a64d46dd 100644
--- a/deeptabular/models/tabm.py
+++ b/deeptab/models/tabm.py
@@ -14,7 +14,7 @@ class TabMRegressor(SklearnBaseRegressor):
with the default TabM configuration.
""",
examples="""
- >>> from deeptabular.models import TabMRegressor
+ >>> from deeptab.models import TabMRegressor
>>> model = TabMRegressor(ensemble_size=32, model_type='full')
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -34,7 +34,7 @@ class TabMClassifier(SklearnBaseClassifier):
with the default TabM configuration.
""",
examples="""
- >>> from deeptabular.models import TabMClassifier
+ >>> from deeptab.models import TabMClassifier
>>> model = TabMClassifier(ensemble_size=32, model_type='full')
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -54,7 +54,7 @@ class TabMLSS(SklearnBaseLSS):
with the default TabM configuration.
""",
examples="""
- >>> from deeptabular.models import TabMLSS
+ >>> from deeptab.models import TabMLSS
>>> model = TabMLSS(ensemble_size=32, model_type='full')
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/tabr.py b/deeptab/models/tabr.py
similarity index 90%
rename from deeptabular/models/tabr.py
rename to deeptab/models/tabr.py
index 9172de72..76c4fba0 100644
--- a/deeptabular/models/tabr.py
+++ b/deeptab/models/tabr.py
@@ -1,63 +1,63 @@
-from ..base_models.tabr import TabR
-from ..configs.tabr_config import DefaultTabRConfig
-from ..utils.docstring_generator import generate_docstring
-from .utils.sklearn_base_classifier import SklearnBaseClassifier
-from .utils.sklearn_base_lss import SklearnBaseLSS
-from .utils.sklearn_base_regressor import SklearnBaseRegressor
-
-
-class TabRRegressor(SklearnBaseRegressor):
- __doc__ = generate_docstring(
- DefaultTabRConfig,
- model_description="""
- TabR regressor. This class extends the SklearnBaseRegressor class and uses the TabR model
- with the default TabR configuration.
- """,
- examples="""
- >>> from deeptabular.models import TabRRegressor
- >>> model = TabRRegressor()
- >>> model.fit(X_train, y_train)
- >>> preds = model.predict(X_test)
- >>> model.evaluate(X_test, y_test)
- """,
- )
- def __init__(self, **kwargs):
- super().__init__(model=TabR, config=DefaultTabRConfig, **kwargs)
-
-
-class TabRClassifier(SklearnBaseClassifier):
- __doc__ = generate_docstring(
- DefaultTabRConfig,
- model_description="""
- TabR classifier. This class extends the SklearnBaseClassifier class and uses the TabR model
- with the default TabR configuration.
- """,
- examples="""
- >>> from deeptabular.models import TabRClassifier
- >>> model = TabRClassifier()
- >>> model.fit(X_train, y_train)
- >>> preds = model.predict(X_test)
- >>> model.evaluate(X_test, y_test)
- """,
- )
- def __init__(self, **kwargs):
- super().__init__(model=TabR, config=DefaultTabRConfig, **kwargs)
-
-
-class TabRLSS(SklearnBaseLSS):
- __doc__ = generate_docstring(
- DefaultTabRConfig,
- model_description="""
- TabR regressor. This class extends the SklearnBaseLSS class and uses the TabR model
- with the default TabR configuration.
- """,
- examples="""
- >>> from deeptabular.models import TabRLSS
- >>> model = TabRLSS(d_model=64, family='normal')
- >>> model.fit(X_train, y_train)
- >>> preds = model.predict(X_test)
- >>> model.evaluate(X_test, y_test)
- """,
- )
- def __init__(self, **kwargs):
- super().__init__(model=TabR, config=DefaultTabRConfig, **kwargs)
+from ..base_models.tabr import TabR
+from ..configs.tabr_config import DefaultTabRConfig
+from ..utils.docstring_generator import generate_docstring
+from .utils.sklearn_base_classifier import SklearnBaseClassifier
+from .utils.sklearn_base_lss import SklearnBaseLSS
+from .utils.sklearn_base_regressor import SklearnBaseRegressor
+
+
+class TabRRegressor(SklearnBaseRegressor):
+ __doc__ = generate_docstring(
+ DefaultTabRConfig,
+ model_description="""
+ TabR regressor. This class extends the SklearnBaseRegressor class and uses the TabR model
+ with the default TabR configuration.
+ """,
+ examples="""
+ >>> from deeptab.models import TabRRegressor
+ >>> model = TabRRegressor()
+ >>> model.fit(X_train, y_train)
+ >>> preds = model.predict(X_test)
+ >>> model.evaluate(X_test, y_test)
+ """,
+ )
+ def __init__(self, **kwargs):
+ super().__init__(model=TabR, config=DefaultTabRConfig, **kwargs)
+
+
+class TabRClassifier(SklearnBaseClassifier):
+ __doc__ = generate_docstring(
+ DefaultTabRConfig,
+ model_description="""
+ TabR classifier. This class extends the SklearnBaseClassifier class and uses the TabR model
+ with the default TabR configuration.
+ """,
+ examples="""
+ >>> from deeptab.models import TabRClassifier
+ >>> model = TabRClassifier()
+ >>> model.fit(X_train, y_train)
+ >>> preds = model.predict(X_test)
+ >>> model.evaluate(X_test, y_test)
+ """,
+ )
+ def __init__(self, **kwargs):
+ super().__init__(model=TabR, config=DefaultTabRConfig, **kwargs)
+
+
+class TabRLSS(SklearnBaseLSS):
+ __doc__ = generate_docstring(
+ DefaultTabRConfig,
+ model_description="""
+ TabR regressor. This class extends the SklearnBaseLSS class and uses the TabR model
+ with the default TabR configuration.
+ """,
+ examples="""
+ >>> from deeptab.models import TabRLSS
+ >>> model = TabRLSS(d_model=64, family='normal')
+ >>> model.fit(X_train, y_train)
+ >>> preds = model.predict(X_test)
+ >>> model.evaluate(X_test, y_test)
+ """,
+ )
+ def __init__(self, **kwargs):
+ super().__init__(model=TabR, config=DefaultTabRConfig, **kwargs)
diff --git a/deeptabular/models/tabtransformer.py b/deeptab/models/tabtransformer.py
similarity index 92%
rename from deeptabular/models/tabtransformer.py
rename to deeptab/models/tabtransformer.py
index 15ddcb6d..9c01dcf8 100644
--- a/deeptabular/models/tabtransformer.py
+++ b/deeptab/models/tabtransformer.py
@@ -14,7 +14,7 @@ class TabTransformerRegressor(SklearnBaseRegressor):
with the default TabTransformer configuration.
""",
examples="""
- >>> from deeptabular.models import TabTransformerRegressor
+ >>> from deeptab.models import TabTransformerRegressor
>>> model = TabTransformerRegressor()
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -36,7 +36,7 @@ class TabTransformerClassifier(SklearnBaseClassifier):
with the default TabTransformer configuration.
""",
examples="""
- >>> from deeptabular.models import TabTransformerClassifier
+ >>> from deeptab.models import TabTransformerClassifier
>>> model = TabTransformerClassifier()
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -58,7 +58,7 @@ class TabTransformerLSS(SklearnBaseLSS):
with the default TabTransformer configuration.
""",
examples="""
- >>> from deeptabular.models import TabTransformerLSS
+ >>> from deeptab.models import TabTransformerLSS
>>> model = TabTransformerLSS()
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/tabularnn.py b/deeptab/models/tabularnn.py
similarity index 95%
rename from deeptabular/models/tabularnn.py
rename to deeptab/models/tabularnn.py
index d0303832..a14910aa 100644
--- a/deeptabular/models/tabularnn.py
+++ b/deeptab/models/tabularnn.py
@@ -1,133 +1,133 @@
-from ..base_models.tabularnn import TabulaRNN
-from ..configs.tabularnn_config import DefaultTabulaRNNConfig
-from ..utils.docstring_generator import generate_docstring
-from .utils.sklearn_base_classifier import SklearnBaseClassifier
-from .utils.sklearn_base_lss import SklearnBaseLSS
-from .utils.sklearn_base_regressor import SklearnBaseRegressor
-
-
-class TabulaRNNRegressor(SklearnBaseRegressor):
- __doc__ = generate_docstring(
- DefaultTabulaRNNConfig,
- model_description="""
- TabulaRNN regressor. This class extends the SklearnBaseRegressor
- class and uses the TabulaRNN model with the default TabulaRNN
- configuration.
- """,
- examples="""
- >>> from deeptabular.models import TabulaRNNRegressor
- >>> model = TabulaRNNRegressor(d_model=64)
- >>> model.fit(X_train, y_train)
- >>> preds = model.predict(X_test)
- >>> model.evaluate(X_test, y_test)
- """,
- )
-
- def __init__(self, **kwargs):
- super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
-
-
-class TabulaRNNClassifier(SklearnBaseClassifier):
- __doc__ = generate_docstring(
- DefaultTabulaRNNConfig,
- model_description="""
- TabulaRNN classifier. This class extends the SklearnBaseClassifier
- class and uses the TabulaRNN model with the default TabulaRNN
- configuration.
- """,
- examples="""
- >>> from deeptabular.models import TabulaRNNClassifier
- >>> model = TabulaRNNClassifier(d_model=64)
- >>> model.fit(X_train, y_train)
- >>> preds = model.predict(X_test)
- >>> model.evaluate(X_test, y_test)
- """,
- )
-
- def __init__(self, **kwargs):
- super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
-
-
-class TabulaRNNLSS(SklearnBaseLSS):
- """RNN LSS. This class extends the SklearnBaseLSS class and uses the TabulaRNN model with the default TabulaRNN
- configuration.
-
- The accepted arguments to the TabulaRNNLSS class include both the attributes in the DefaultTabulaRNNConfig dataclass
- and the parameters for the Preprocessor class.
-
- Parameters
- ----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- model_type : str, default="RNN"
- type of model, one of "RNN", "LSTM", "GRU"
- family : str, default=None
- Distributional family to be used for the model.
- lr_patience : int, default=10
- Number of epochs with no improvement after which learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 penalty) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
- d_model : int, default=64
- Dimensionality of the model.
- n_layers : int, default=8
- Number of layers in the transformer.
- norm : str, default="RMSNorm"
- Normalization method to be used.
- activation : callable, default=nn.SELU()
- Activation function for the transformer.
- embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
- head_layer_sizes : list, default=(128, 64, 32)
- Sizes of the layers in the head of the model.
- head_dropout : float, default=0.5
- Dropout rate for the head layers.
- head_skip_layers : bool, default=False
- Whether to skip layers in the head.
- head_activation : callable, default=nn.SELU()
- Activation function for the head layers.
- head_use_batch_norm : bool, default=False
- Whether to use batch normalization in the head layers.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding.
- pooling_method : str, default="cls"
- Pooling method to be used ('cls', 'avg', etc.).
- norm_first : bool, default=False
- Whether to apply normalization before other operations in each transformer block.
- bias : bool, default=True
- Whether to use bias in the linear layers.
- rnn_activation : callable, default=nn.SELU()
- Activation function for the transformer layers.
- bidirectional : bool, default=False.
- Whether to process data bidirectionally
- cat_encoding : str, default="int"
- Encoding method for categorical features.
- n_bins : int, default=50
- The number of bins to use for numerical feature binning. This parameter is relevant
- only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
- numerical_preprocessing : str, default="ple"
- The preprocessing strategy for numerical features. Valid options are
- 'binning', 'one_hot', 'standardization', and 'normalization'.
- use_decision_tree_bins : bool, default=False
- If True, uses decision tree regression/classification to determine
- optimal bin edges for numerical feature binning. This parameter is
- relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
- binning_strategy : str, default="uniform"
- Defines the strategy for binning numerical features. Options include 'uniform',
- 'quantile', or other sklearn-compatible strategies.
- cat_cutoff : float or int, default=0.03
- Indicates the cutoff after which integer values are treated as categorical.
- If float, it's treated as a percentage. If int, it's the maximum number of
- unique values for a column to be considered categorical.
- treat_all_integers_as_numerical : bool, default=False
- If True, all integer columns will be treated as numerical, regardless
- of their unique value count or proportion.
- degree : int, default=3
- The degree of the polynomial features to be used in preprocessing.
- knots : int, default=12
- The number of knots to be used in spline transformations.
- """
-
- def __init__(self, **kwargs):
- super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
+from ..base_models.tabularnn import TabulaRNN
+from ..configs.tabularnn_config import DefaultTabulaRNNConfig
+from ..utils.docstring_generator import generate_docstring
+from .utils.sklearn_base_classifier import SklearnBaseClassifier
+from .utils.sklearn_base_lss import SklearnBaseLSS
+from .utils.sklearn_base_regressor import SklearnBaseRegressor
+
+
+class TabulaRNNRegressor(SklearnBaseRegressor):
+ __doc__ = generate_docstring(
+ DefaultTabulaRNNConfig,
+ model_description="""
+ TabulaRNN regressor. This class extends the SklearnBaseRegressor
+ class and uses the TabulaRNN model with the default TabulaRNN
+ configuration.
+ """,
+ examples="""
+ >>> from deeptab.models import TabulaRNNRegressor
+ >>> model = TabulaRNNRegressor(d_model=64)
+ >>> model.fit(X_train, y_train)
+ >>> preds = model.predict(X_test)
+ >>> model.evaluate(X_test, y_test)
+ """,
+ )
+
+ def __init__(self, **kwargs):
+ super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
+
+
+class TabulaRNNClassifier(SklearnBaseClassifier):
+ __doc__ = generate_docstring(
+ DefaultTabulaRNNConfig,
+ model_description="""
+ TabulaRNN classifier. This class extends the SklearnBaseClassifier
+ class and uses the TabulaRNN model with the default TabulaRNN
+ configuration.
+ """,
+ examples="""
+ >>> from deeptab.models import TabulaRNNClassifier
+ >>> model = TabulaRNNClassifier(d_model=64)
+ >>> model.fit(X_train, y_train)
+ >>> preds = model.predict(X_test)
+ >>> model.evaluate(X_test, y_test)
+ """,
+ )
+
+ def __init__(self, **kwargs):
+ super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
+
+
+class TabulaRNNLSS(SklearnBaseLSS):
+ """RNN LSS. This class extends the SklearnBaseLSS class and uses the TabulaRNN model with the default TabulaRNN
+ configuration.
+
+ The accepted arguments to the TabulaRNNLSS class include both the attributes in the DefaultTabulaRNNConfig dataclass
+ and the parameters for the Preprocessor class.
+
+ Parameters
+ ----------
+ lr : float, default=1e-04
+ Learning rate for the optimizer.
+ model_type : str, default="RNN"
+ type of model, one of "RNN", "LSTM", "GRU"
+ family : str, default=None
+ Distributional family to be used for the model.
+ lr_patience : int, default=10
+ Number of epochs with no improvement after which learning rate will be reduced.
+ weight_decay : float, default=1e-06
+ Weight decay (L2 penalty) for the optimizer.
+ lr_factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ d_model : int, default=64
+ Dimensionality of the model.
+ n_layers : int, default=8
+ Number of layers in the transformer.
+ norm : str, default="RMSNorm"
+ Normalization method to be used.
+ activation : callable, default=nn.SELU()
+ Activation function for the transformer.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for numerical embeddings.
+ head_layer_sizes : list, default=(128, 64, 32)
+ Sizes of the layers in the head of the model.
+ head_dropout : float, default=0.5
+ Dropout rate for the head layers.
+ head_skip_layers : bool, default=False
+ Whether to skip layers in the head.
+ head_activation : callable, default=nn.SELU()
+ Activation function for the head layers.
+ head_use_batch_norm : bool, default=False
+ Whether to use batch normalization in the head layers.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ pooling_method : str, default="cls"
+ Pooling method to be used ('cls', 'avg', etc.).
+ norm_first : bool, default=False
+ Whether to apply normalization before other operations in each transformer block.
+ bias : bool, default=True
+ Whether to use bias in the linear layers.
+ rnn_activation : callable, default=nn.SELU()
+ Activation function for the transformer layers.
+ bidirectional : bool, default=False.
+ Whether to process data bidirectionally
+ cat_encoding : str, default="int"
+ Encoding method for categorical features.
+ n_bins : int, default=50
+ The number of bins to use for numerical feature binning. This parameter is relevant
+ only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ numerical_preprocessing : str, default="ple"
+ The preprocessing strategy for numerical features. Valid options are
+ 'binning', 'one_hot', 'standardization', and 'normalization'.
+ use_decision_tree_bins : bool, default=False
+ If True, uses decision tree regression/classification to determine
+ optimal bin edges for numerical feature binning. This parameter is
+ relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ binning_strategy : str, default="uniform"
+ Defines the strategy for binning numerical features. Options include 'uniform',
+ 'quantile', or other sklearn-compatible strategies.
+ cat_cutoff : float or int, default=0.03
+ Indicates the cutoff after which integer values are treated as categorical.
+ If float, it's treated as a percentage. If int, it's the maximum number of
+ unique values for a column to be considered categorical.
+ treat_all_integers_as_numerical : bool, default=False
+ If True, all integer columns will be treated as numerical, regardless
+ of their unique value count or proportion.
+ degree : int, default=3
+ The degree of the polynomial features to be used in preprocessing.
+ knots : int, default=12
+ The number of knots to be used in spline transformations.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
diff --git a/deeptabular/models/tangos.py b/deeptab/models/tangos.py
similarity index 92%
rename from deeptabular/models/tangos.py
rename to deeptab/models/tangos.py
index 26f2dce2..abbd437a 100644
--- a/deeptabular/models/tangos.py
+++ b/deeptab/models/tangos.py
@@ -14,7 +14,7 @@ class TangosRegressor(SklearnBaseRegressor):
with the default Tangos configuration.
""",
examples="""
- >>> from deeptabular.models import TangosRegressor
+ >>> from deeptab.models import TangosRegressor
>>> model = TangosRegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -34,7 +34,7 @@ class TangosClassifier(SklearnBaseClassifier):
with the default Tangos configuration.
""",
examples="""
- >>> from deeptabular.models import TangosClassifier
+ >>> from deeptab.models import TangosClassifier
>>> model = TangosClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -54,7 +54,7 @@ class TangosLSS(SklearnBaseLSS):
with the default Tangos configuration.
""",
examples="""
- >>> from deeptabular.models import TangosLSS
+ >>> from deeptab.models import TangosLSS
>>> model = TangosLSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family='normal')
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/trompt.py b/deeptab/models/trompt.py
similarity index 92%
rename from deeptabular/models/trompt.py
rename to deeptab/models/trompt.py
index cd96a6a8..d827a996 100644
--- a/deeptabular/models/trompt.py
+++ b/deeptab/models/trompt.py
@@ -15,7 +15,7 @@ class and uses the Trompt model with the default Trompt
configuration.
""",
examples="""
- >>> from deeptabular.models import TromptRegressor
+ >>> from deeptab.models import TromptRegressor
>>> model = TromptRegressor(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -33,7 +33,7 @@ class TromptClassifier(SklearnBaseClassifier):
"""Trompt Classifier. This class extends the SklearnBaseClassifier class
and uses the Trompt model with the default Trompt configuration.""",
examples="""
- >>> from deeptabular.models import TromptClassifier
+ >>> from deeptab.models import TromptClassifier
>>> model = TromptClassifier(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
@@ -52,7 +52,7 @@ class TromptLSS(SklearnBaseLSS):
This class extends the SklearnBaseLSS class and uses the
Trompt model with the default Trompt configuration.""",
examples="""
- >>> from deeptabular.models import TromptLSS
+ >>> from deeptab.models import TromptLSS
>>> model = TromptLSS(d_model=64, n_layers=8)
>>> model.fit(X_train, y_train, family="normal")
>>> preds = model.predict(X_test)
diff --git a/deeptabular/models/utils/__init__.py b/deeptab/models/utils/__init__.py
similarity index 100%
rename from deeptabular/models/utils/__init__.py
rename to deeptab/models/utils/__init__.py
diff --git a/deeptabular/models/utils/sklearn_base_classifier.py b/deeptab/models/utils/sklearn_base_classifier.py
similarity index 100%
rename from deeptabular/models/utils/sklearn_base_classifier.py
rename to deeptab/models/utils/sklearn_base_classifier.py
diff --git a/deeptabular/models/utils/sklearn_base_lss.py b/deeptab/models/utils/sklearn_base_lss.py
similarity index 100%
rename from deeptabular/models/utils/sklearn_base_lss.py
rename to deeptab/models/utils/sklearn_base_lss.py
diff --git a/deeptabular/models/utils/sklearn_base_regressor.py b/deeptab/models/utils/sklearn_base_regressor.py
similarity index 100%
rename from deeptabular/models/utils/sklearn_base_regressor.py
rename to deeptab/models/utils/sklearn_base_regressor.py
diff --git a/deeptabular/models/utils/sklearn_parent.py b/deeptab/models/utils/sklearn_parent.py
similarity index 100%
rename from deeptabular/models/utils/sklearn_parent.py
rename to deeptab/models/utils/sklearn_parent.py
diff --git a/deeptabular/utils/__init__.py b/deeptab/utils/__init__.py
similarity index 100%
rename from deeptabular/utils/__init__.py
rename to deeptab/utils/__init__.py
diff --git a/deeptabular/utils/config_mapper.py b/deeptab/utils/config_mapper.py
similarity index 97%
rename from deeptabular/utils/config_mapper.py
rename to deeptab/utils/config_mapper.py
index 19609e44..42df2426 100644
--- a/deeptabular/utils/config_mapper.py
+++ b/deeptab/utils/config_mapper.py
@@ -1,141 +1,141 @@
-import torch.nn as nn
-from skopt.space import Categorical, Integer, Real
-
-from ..arch_utils.transformer_utils import ReGLU
-
-
-def round_to_nearest_16(x):
- """Rounds the value to the nearest multiple of 16."""
- return int(round(x / 16) * 16)
-
-
-def get_search_space(
- config,
- fixed_params={
- "pooling_method": "avg",
- "head_skip_layers": False,
- "head_layer_size_length": 0,
- "cat_encoding": "int",
- "head_skip_layer": False,
- "use_cls": False,
- },
- custom_search_space=None,
-):
- """Given a model configuration, return the hyperparameter search space based on the config attributes.
-
- Parameters
- ----------
- config : dataclass
- The configuration object for the model.
- fixed_params : dict, optional
- Dictionary of fixed parameters and their values. Defaults to
- {"pooling_method": "avg", "head_skip_layers": False, "head_layer_size_length": 0}.
- custom_search_space : dict, optional
- Dictionary defining custom search spaces for parameters.
- Overrides the default `search_space_mapping` for the specified parameters.
-
- Returns
- -------
- param_names : list
- A list of parameter names to be optimized.
- param_space : list
- A list of hyperparameter ranges for Bayesian optimization.
- """
-
- # Handle the custom search space
- if custom_search_space is None:
- custom_search_space = {}
-
- # Base search space mapping
- search_space_mapping = {
- # Learning rate-related parameters
- "lr": Real(1e-6, 1e-2, prior="log-uniform"),
- "lr_patience": Integer(5, 20),
- "lr_factor": Real(0.1, 0.5),
- # Model architecture parameters
- "n_layers": Integer(1, 8),
- "d_model": Categorical([32, 64, 128, 256, 512, 1024]),
- "dropout": Real(0.0, 0.5),
- "expand_factor": Integer(1, 4),
- "d_state": Categorical([32, 64, 128, 256]),
- "ff_dropout": Real(0.0, 0.5),
- "rnn_dropout": Real(0.0, 0.5),
- "attn_dropout": Real(0.0, 0.5),
- "n_heads": Categorical([2, 4, 8]),
- "transformer_dim_feedforward": Integer(16, 512),
- # Convolution-related parameters
- "conv_bias": Categorical([True, False]),
- # Normalization and regularization
- "norm": Categorical(["LayerNorm", "RMSNorm"]),
- "weight_decay": Real(1e-8, 1e-2, prior="log-uniform"),
- "layer_norm_eps": Real(1e-7, 1e-4),
- "head_dropout": Real(0.0, 0.5),
- "bias": Categorical([True, False]),
- "norm_first": Categorical([True, False]),
- # Pooling, activation, and head layer settings
- "pooling_method": Categorical(["avg", "max", "cls", "sum"]),
- "activation": Categorical(["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU", "SiLU"]),
- "embedding_activation": Categorical(["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU"]),
- "rnn_activation": Categorical(["relu", "tanh"]),
- "transformer_activation": Categorical(["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU", "ReGLU"]),
- "head_skip_layers": Categorical([True, False]),
- "head_use_batch_norm": Categorical([True, False]),
- # Sequence-related settings
- "bidirectional": Categorical([True, False]),
- "use_learnable_interaction": Categorical([True, False]),
- "use_cls": Categorical([True, False]),
- # Feature encoding
- "cat_encoding": Categorical(["int", "one-hot"]),
- }
-
- # Apply custom search space overrides
- search_space_mapping.update(custom_search_space)
-
- param_names = []
- param_space = []
-
- # Iterate through config fields
- for field in config.__dataclass_fields__:
- if field in fixed_params:
- # Fix the parameter value directly in the config
- setattr(config, field, fixed_params[field])
- continue # Skip optimization for this parameter
-
- if field in search_space_mapping:
- # Add to search space if not fixed
- param_names.append(field)
- param_space.append(search_space_mapping[field])
-
- # Handle dynamic head_layer_sizes based on head_layer_size_length
- if "head_layer_sizes" in config.__dataclass_fields__:
- head_layer_size_length = fixed_params.get("head_layer_size_length", 0)
-
- # If no layers are desired, set head_layer_sizes to []
- if head_layer_size_length == 0:
- config.head_layer_sizes = []
- else:
- # Optimize the number of head layers
- max_head_layers = 5
- param_names.append("head_layer_size_length")
- param_space.append(Integer(1, max_head_layers))
-
- # Optimize individual layer sizes
- layer_size_min, layer_size_max = 16, 512
- for i in range(max_head_layers):
- layer_key = f"head_layer_size_{i+1}"
- param_names.append(layer_key)
- param_space.append(Integer(layer_size_min, layer_size_max))
-
- return param_names, param_space
-
-
-activation_mapper = {
- "ReLU": nn.ReLU(),
- "Tanh": nn.Tanh(),
- "SiLU": nn.SiLU(),
- "LeakyReLU": nn.LeakyReLU(),
- "Identity": nn.Identity(),
- "Linear": nn.Identity(),
- "SELU": nn.SELU(),
- "ReGLU": ReGLU(),
-}
+import torch.nn as nn
+from skopt.space import Categorical, Integer, Real
+
+from ..arch_utils.transformer_utils import ReGLU
+
+
+def round_to_nearest_16(x):
+ """Rounds the value to the nearest multiple of 16."""
+ return int(round(x / 16) * 16)
+
+
+def get_search_space(
+ config,
+ fixed_params={
+ "pooling_method": "avg",
+ "head_skip_layers": False,
+ "head_layer_size_length": 0,
+ "cat_encoding": "int",
+ "head_skip_layer": False,
+ "use_cls": False,
+ },
+ custom_search_space=None,
+):
+ """Given a model configuration, return the hyperparameter search space based on the config attributes.
+
+ Parameters
+ ----------
+ config : dataclass
+ The configuration object for the model.
+ fixed_params : dict, optional
+ Dictionary of fixed parameters and their values. Defaults to
+ {"pooling_method": "avg", "head_skip_layers": False, "head_layer_size_length": 0}.
+ custom_search_space : dict, optional
+ Dictionary defining custom search spaces for parameters.
+ Overrides the default `search_space_mapping` for the specified parameters.
+
+ Returns
+ -------
+ param_names : list
+ A list of parameter names to be optimized.
+ param_space : list
+ A list of hyperparameter ranges for Bayesian optimization.
+ """
+
+ # Handle the custom search space
+ if custom_search_space is None:
+ custom_search_space = {}
+
+ # Base search space mapping
+ search_space_mapping = {
+ # Learning rate-related parameters
+ "lr": Real(1e-6, 1e-2, prior="log-uniform"),
+ "lr_patience": Integer(5, 20),
+ "lr_factor": Real(0.1, 0.5),
+ # Model architecture parameters
+ "n_layers": Integer(1, 8),
+ "d_model": Categorical([32, 64, 128, 256, 512, 1024]),
+ "dropout": Real(0.0, 0.5),
+ "expand_factor": Integer(1, 4),
+ "d_state": Categorical([32, 64, 128, 256]),
+ "ff_dropout": Real(0.0, 0.5),
+ "rnn_dropout": Real(0.0, 0.5),
+ "attn_dropout": Real(0.0, 0.5),
+ "n_heads": Categorical([2, 4, 8]),
+ "transformer_dim_feedforward": Integer(16, 512),
+ # Convolution-related parameters
+ "conv_bias": Categorical([True, False]),
+ # Normalization and regularization
+ "norm": Categorical(["LayerNorm", "RMSNorm"]),
+ "weight_decay": Real(1e-8, 1e-2, prior="log-uniform"),
+ "layer_norm_eps": Real(1e-7, 1e-4),
+ "head_dropout": Real(0.0, 0.5),
+ "bias": Categorical([True, False]),
+ "norm_first": Categorical([True, False]),
+ # Pooling, activation, and head layer settings
+ "pooling_method": Categorical(["avg", "max", "cls", "sum"]),
+ "activation": Categorical(["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU", "SiLU"]),
+ "embedding_activation": Categorical(["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU"]),
+ "rnn_activation": Categorical(["relu", "tanh"]),
+ "transformer_activation": Categorical(["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU", "ReGLU"]),
+ "head_skip_layers": Categorical([True, False]),
+ "head_use_batch_norm": Categorical([True, False]),
+ # Sequence-related settings
+ "bidirectional": Categorical([True, False]),
+ "use_learnable_interaction": Categorical([True, False]),
+ "use_cls": Categorical([True, False]),
+ # Feature encoding
+ "cat_encoding": Categorical(["int", "one-hot"]),
+ }
+
+ # Apply custom search space overrides
+ search_space_mapping.update(custom_search_space)
+
+ param_names = []
+ param_space = []
+
+ # Iterate through config fields
+ for field in config.__dataclass_fields__:
+ if field in fixed_params:
+ # Fix the parameter value directly in the config
+ setattr(config, field, fixed_params[field])
+ continue # Skip optimization for this parameter
+
+ if field in search_space_mapping:
+ # Add to search space if not fixed
+ param_names.append(field)
+ param_space.append(search_space_mapping[field])
+
+ # Handle dynamic head_layer_sizes based on head_layer_size_length
+ if "head_layer_sizes" in config.__dataclass_fields__:
+ head_layer_size_length = fixed_params.get("head_layer_size_length", 0)
+
+ # If no layers are desired, set head_layer_sizes to []
+ if head_layer_size_length == 0:
+ config.head_layer_sizes = []
+ else:
+ # Optimize the number of head layers
+ max_head_layers = 5
+ param_names.append("head_layer_size_length")
+ param_space.append(Integer(1, max_head_layers))
+
+ # Optimize individual layer sizes
+ layer_size_min, layer_size_max = 16, 512
+ for i in range(max_head_layers):
+ layer_key = f"head_layer_size_{i+1}"
+ param_names.append(layer_key)
+ param_space.append(Integer(layer_size_min, layer_size_max))
+
+ return param_names, param_space
+
+
+activation_mapper = {
+ "ReLU": nn.ReLU(),
+ "Tanh": nn.Tanh(),
+ "SiLU": nn.SiLU(),
+ "LeakyReLU": nn.LeakyReLU(),
+ "Identity": nn.Identity(),
+ "Linear": nn.Identity(),
+ "SELU": nn.SELU(),
+ "ReGLU": ReGLU(),
+}
diff --git a/deeptabular/utils/distributional_metrics.py b/deeptab/utils/distributional_metrics.py
similarity index 100%
rename from deeptabular/utils/distributional_metrics.py
rename to deeptab/utils/distributional_metrics.py
diff --git a/deeptabular/utils/distributions.py b/deeptab/utils/distributions.py
similarity index 97%
rename from deeptabular/utils/distributions.py
rename to deeptab/utils/distributions.py
index 75395ce9..e79291b6 100644
--- a/deeptabular/utils/distributions.py
+++ b/deeptab/utils/distributions.py
@@ -1,670 +1,670 @@
-import numpy as np
-import torch
-import torch.distributions as dist
-
-
-class BaseDistribution(torch.nn.Module):
- """
- The base class for various statistical distributions, providing a common interface and utilities.
-
- This class defines the basic structure and methods that are inherited by specific distribution
- classes, allowing for the implementation of custom distributions with specific parameter transformations
- and loss computations.
-
- Attributes
- ----------
- _name (str): The name of the distribution.
- param_names (list of str): A list of names for the parameters of the distribution.
- param_count (int): The number of parameters for the distribution.
- predefined_transforms (dict): A dictionary of predefined transformation functions for parameters.
-
- Parameters
- ----------
- name (str): The name of the distribution.
- param_names (list of str): A list of names for the parameters of the distribution.
- """
-
- def __init__(self, name, param_names):
- super().__init__()
-
- self._name = name
- self.param_names = param_names
- self.param_count = len(param_names)
- # Predefined transformation functions accessible to all subclasses
- self.predefined_transforms = {
- "positive": torch.nn.functional.softplus,
- "none": lambda x: x,
- "square": lambda x: x**2,
- "exp": torch.exp,
- "sqrt": torch.sqrt,
- "probabilities": lambda x: torch.softmax(x, dim=-1),
- # Adding a small constant for numerical stability
- "log": lambda x: torch.log(x + 1e-6),
- }
-
- @property
- def name(self):
- return self._name
-
- @property
- def parameter_count(self):
- return self.param_count
-
- def get_transform(self, transform_name):
- """
- Retrieve a transformation function by name, or return the function if it's custom.
- """
- if callable(transform_name):
- # Custom transformation function provided
- return transform_name
- # Default to 'none'
- return self.predefined_transforms.get(transform_name, lambda x: x)
-
- def compute_loss(self, predictions, y_true):
- """
- Computes the loss (e.g., negative log likelihood) for the distribution given
- predictions and true values.
-
- This method must be implemented by subclasses.
-
- Parameters
- ----------
- predictions (torch.Tensor): The predicted parameters of the distribution.
- y_true (torch.Tensor): The true values.
-
- Raises
- ------
- NotImplementedError: If the subclass does not implement this method.
- """
- raise NotImplementedError("Subclasses must implement this method.")
-
- def evaluate_nll(self, y_true, y_pred):
- """
- Evaluates the negative log likelihood (NLL) for given true values and predictions.
-
- Parameters
- ----------
- y_true (array-like): The true values.
- y_pred (array-like): The predicted values.
-
- Returns
- -------
- dict: A dictionary containing the NLL value.
- """
-
- # Convert numpy arrays to torch tensors
- y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
- y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
-
- # Compute NLL using the provided loss function
- nll_loss_tensor = self.compute_loss(y_pred_tensor, y_true_tensor)
-
- # Convert the NLL loss tensor back to a numpy array and return
- return {
- "NLL": nll_loss_tensor.detach().numpy(),
- }
-
- def forward(self, predictions):
- """
- Apply the appropriate transformations to the predicted parameters.
-
- Parameters:
- predictions (torch.Tensor): The predicted parameters of the distribution.
-
- Returns:
- torch.Tensor: A tensor with transformed parameters.
- """
- transformed_params = []
- for idx, param_name in enumerate(self.param_names):
- transform_func = self.get_transform(
- getattr(self, f"{param_name}_transform", "none")
- )
- transformed_params.append(
- transform_func(predictions[:, idx]).unsqueeze( # type: ignore
- 1
- ) # type: ignore
- )
- return torch.cat(transformed_params, dim=1)
-
-
-class NormalDistribution(BaseDistribution):
- """
- Represents a Normal (Gaussian) distribution with parameters for mean and variance,
- including functionality for transforming these parameters and computing the loss.
-
- Inherits from BaseDistribution.
-
- Parameters
- ----------
- name (str): The name of the distribution. Defaults to "Normal".
- mean_transform (str or callable): The transformation for the mean parameter.
- Defaults to "none".
- var_transform (str or callable): The transformation for the variance parameter.
- Defaults to "positive".
- """
-
- def __init__(self, name="Normal", mean_transform="none", var_transform="positive"):
- param_names = [
- "mean",
- "variance",
- ]
- super().__init__(name, param_names)
-
- self.mean_transform = self.get_transform(mean_transform)
- self.variance_transform = self.get_transform(var_transform)
-
- def compute_loss(self, predictions, y_true):
- mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
- variance = self.variance_transform(
- predictions[:, self.param_names.index("variance")]
- )
-
- normal_dist = dist.Normal(mean, variance)
-
- nll = -normal_dist.log_prob(y_true).mean()
- return nll
-
- def evaluate_nll(self, y_true, y_pred):
- metrics = super().evaluate_nll(y_true, y_pred)
-
- # Convert numpy arrays to torch tensors
- y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
- y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
-
- mse_loss = torch.nn.functional.mse_loss(
- y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]
- )
- rmse = np.sqrt(mse_loss.detach().numpy())
- mae = (
- torch.nn.functional.l1_loss(
- y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]
- )
- .detach()
- .numpy()
- )
-
- metrics["mse"] = mse_loss.detach().numpy()
- metrics["mae"] = mae
- metrics["rmse"] = rmse
-
- # Convert the NLL loss tensor back to a numpy array and return
- return metrics
-
-
-class PoissonDistribution(BaseDistribution):
- """
- Represents a Poisson distribution, typically used for modeling count data or the number of events
- occurring within a fixed interval of time or space. This class extends the BaseDistribution and
- includes parameter transformation and loss computation specific to the Poisson distribution.
-
- Parameters
- ----------
- name (str): The name of the distribution, defaulted to "Poisson".
- rate_transform (str or callable): Transformation to apply to the rate parameter
- to ensure it remains positive.
- """
-
- def __init__(self, name="Poisson", rate_transform="positive"):
- # Specify parameter name for Poisson distribution
- param_names = ["rate"]
- super().__init__(name, param_names)
- # Retrieve transformation function for rate
- self.rate_transform = self.get_transform(rate_transform)
-
- def compute_loss(self, predictions, y_true):
- rate = self.rate_transform(predictions[:, self.param_names.index("rate")])
-
- # Define the Poisson distribution with the transformed parameter
- poisson_dist = dist.Poisson(rate)
-
- # Compute the negative log-likelihood
- nll = -poisson_dist.log_prob(y_true).mean()
- return nll
-
- def evaluate_nll(self, y_true, y_pred):
- metrics = super().evaluate_nll(y_true, y_pred)
-
- # Convert numpy arrays to torch tensors
- y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
- y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
- rate = self.rate_transform(y_pred_tensor[:, self.param_names.index("rate")])
-
- mse_loss = torch.nn.functional.mse_loss(y_true_tensor, rate) # type: ignore
- rmse = np.sqrt(mse_loss.detach().numpy())
- mae = (
- torch.nn.functional.l1_loss(y_true_tensor, rate) # type: ignore
- .detach()
- .numpy() # type: ignore
- ) # type: ignore
- poisson_deviance = 2 * torch.sum(
- y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate)
- )
-
- metrics["mse"] = mse_loss.detach().numpy()
- metrics["mae"] = mae
- metrics["rmse"] = rmse
- metrics["poisson_deviance"] = poisson_deviance.detach().numpy()
-
- # Convert the NLL loss tensor back to a numpy array and return
- return metrics
-
-
-class InverseGammaDistribution(BaseDistribution):
- """
- Represents an Inverse Gamma distribution, often used as a prior distribution in Bayesian statistics,
- especially for scale parameters in other distributions. This class extends BaseDistribution and includes
- parameter transformation and loss computation specific to the Inverse Gamma distribution.
-
- Parameters
- ----------
- name (str): The name of the distribution, defaulted to "InverseGamma".
- shape_transform (str or callable): Transformation for the shape parameter to
- ensure it remains positive.
- scale_transform (str or callable): Transformation for the scale parameter to
- ensure it remains positive.
- """
-
- def __init__(
- self,
- name="InverseGamma",
- shape_transform="positive",
- scale_transform="positive",
- ):
- param_names = [
- "shape",
- "scale",
- ]
- super().__init__(name, param_names)
-
- self.shape_transform = self.get_transform(shape_transform)
- self.scale_transform = self.get_transform(scale_transform)
-
- def compute_loss(self, predictions, y_true):
- shape = self.shape_transform(predictions[:, self.param_names.index("shape")])
- scale = self.scale_transform(predictions[:, self.param_names.index("scale")])
-
- inverse_gamma_dist = dist.InverseGamma(shape, scale)
- # Compute the negative log-likelihood
- nll = -inverse_gamma_dist.log_prob(y_true).mean()
- return nll
-
-
-class BetaDistribution(BaseDistribution):
- """
- Represents a Beta distribution, a continuous distribution defined on the interval [0, 1], commonly used
- in Bayesian statistics for modeling probabilities. This class extends BaseDistribution and includes parameter
- transformation and loss computation specific to the Beta distribution.
-
- Parameters
- ----------
- name (str): The name of the distribution, defaulted to "Beta".
- shape_transform (str or callable): Transformation for the alpha (shape) parameter to ensure
- it remains positive.
- scale_transform (str or callable): Transformation for the beta (scale) parameter to ensure
- it remains positive.
- """
-
- def __init__(
- self,
- name="Beta",
- shape_transform="positive",
- scale_transform="positive",
- ):
- param_names = [
- "alpha",
- "beta",
- ]
- super().__init__(name, param_names)
-
- self.alpha_transform = self.get_transform(shape_transform)
- self.beta_transform = self.get_transform(scale_transform)
-
- def compute_loss(self, predictions, y_true):
- alpha = self.alpha_transform(predictions[:, self.param_names.index("alpha")])
- beta = self.beta_transform(predictions[:, self.param_names.index("beta")])
-
- beta_dist = dist.Beta(alpha, beta)
- # Compute the negative log-likelihood
- nll = -beta_dist.log_prob(y_true).mean()
- return nll
-
-
-class DirichletDistribution(BaseDistribution):
- """
- Represents a Dirichlet distribution, a multivariate generalization of the Beta distribution. It is commonly
- used in Bayesian statistics for modeling multinomial distribution probabilities. This class extends
- BaseDistribution and includes parameter transformation and loss computation
- specific to the Dirichlet distribution.
-
- Parameters
- ----------
- name (str): The name of the distribution, defaulted to "Dirichlet".
- concentration_transform (str or callable): Transformation to apply to
- concentration parameters to ensure they remain positive.
- """
-
- def __init__(self, name="Dirichlet", concentration_transform="positive"):
- # For Dirichlet, param_names could be dynamically set based on the dimensionality of alpha
- # For simplicity, we're not specifying individual names for each concentration parameter
- param_names = ["concentration"] # This is a simplification
- super().__init__(name, param_names)
- # Retrieve transformation function for concentration parameters
- self.concentration_transform = self.get_transform(concentration_transform)
-
- def compute_loss(self, predictions, y_true):
- # Apply the transformation to ensure all concentration parameters are positive
- # Assuming predictions is a 2D tensor where each row is a set of concentration parameters
- # for a Dirichlet distribution
- concentration = self.concentration_transform(predictions)
-
- dirichlet_dist = dist.Dirichlet(concentration)
-
- nll = -dirichlet_dist.log_prob(y_true).mean()
- return nll
-
-
-class GammaDistribution(BaseDistribution):
- """
- Represents a Gamma distribution, a two-parameter family of continuous probability distributions. It's
- widely used in various fields of science for modeling a wide range of phenomena. This class extends
- BaseDistribution and includes parameter transformation and loss computation specific to
- the Gamma distribution.
-
- Parameters
- ----------
- name (str): The name of the distribution, defaulted to "Gamma".
- shape_transform (str or callable): Transformation for the shape parameter to ensure it remains positive.
- rate_transform (str or callable): Transformation for the rate parameter to ensure it remains positive.
- """
-
- def __init__(
- self, name="Gamma", shape_transform="positive", rate_transform="positive"
- ):
- param_names = ["shape", "rate"]
- super().__init__(name, param_names)
-
- self.shape_transform = self.get_transform(shape_transform)
- self.rate_transform = self.get_transform(rate_transform)
-
- def compute_loss(self, predictions, y_true):
- shape = self.shape_transform(predictions[:, self.param_names.index("shape")])
- rate = self.rate_transform(predictions[:, self.param_names.index("rate")])
-
- # Define the Gamma distribution with the transformed parameters
- gamma_dist = dist.Gamma(shape, rate)
-
- # Compute the negative log-likelihood
- nll = -gamma_dist.log_prob(y_true).mean()
- return nll
-
-
-class StudentTDistribution(BaseDistribution):
- """
- Represents a Student's t-distribution, a family of continuous probability distributions that arise when
- estimating the mean of a normally distributed population in situations where the sample size is small.
- This class extends BaseDistribution and includes parameter transformation and loss computation specific
- to the Student's t-distribution.
-
- Parameters
- ----------
- name (str): The name of the distribution, defaulted to "StudentT".
- df_transform (str or callable): Transformation for the degrees of freedom parameter
- to ensure it remains positive.
- loc_transform (str or callable): Transformation for the location parameter.
- scale_transform (str or callable): Transformation for the scale parameter
- to ensure it remains positive.
- """
-
- def __init__(
- self,
- name="StudentT",
- df_transform="positive",
- loc_transform="none",
- scale_transform="positive",
- ):
- param_names = ["df", "loc", "scale"]
- super().__init__(name, param_names)
-
- self.df_transform = self.get_transform(df_transform)
- self.loc_transform = self.get_transform(loc_transform)
- self.scale_transform = self.get_transform(scale_transform)
-
- def compute_loss(self, predictions, y_true):
- df = self.df_transform(predictions[:, self.param_names.index("df")])
- loc = self.loc_transform(predictions[:, self.param_names.index("loc")])
- scale = self.scale_transform(predictions[:, self.param_names.index("scale")])
-
- student_t_dist = dist.StudentT(df, loc, scale) # type: ignore
-
- nll = -student_t_dist.log_prob(y_true).mean()
- return nll
-
- def evaluate_nll(self, y_true, y_pred):
- metrics = super().evaluate_nll(y_true, y_pred)
-
- # Convert numpy arrays to torch tensors
- y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
- y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
-
- mse_loss = torch.nn.functional.mse_loss(
- y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]
- )
- rmse = np.sqrt(mse_loss.detach().numpy())
- mae = (
- torch.nn.functional.l1_loss(
- y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]
- )
- .detach()
- .numpy()
- )
-
- metrics["mse"] = mse_loss.detach().numpy()
- metrics["mae"] = mae
- metrics["rmse"] = rmse
-
- # Convert the NLL loss tensor back to a numpy array and return
- return metrics
-
-
-class NegativeBinomialDistribution(BaseDistribution):
- """
- Represents a Negative Binomial distribution, often used for count data and modeling the number
- of failures before a specified number of successes occurs in a series of Bernoulli trials.
- This class extends BaseDistribution and includes parameter transformation and loss computation
- specific to the Negative Binomial distribution.
-
- Parameters
- ----------
- name (str): The name of the distribution, defaulted to "NegativeBinomial".
- mean_transform (str or callable): Transformation for the mean parameter to ensure it remains positive.
- dispersion_transform (str or callable): Transformation for the dispersion parameter to
- ensure it remains positive.
- """
-
- def __init__(
- self,
- name="NegativeBinomial",
- mean_transform="positive",
- dispersion_transform="positive",
- ):
- param_names = ["mean", "dispersion"]
- super().__init__(name, param_names)
-
- self.mean_transform = self.get_transform(mean_transform)
- self.dispersion_transform = self.get_transform(dispersion_transform)
-
- def compute_loss(self, predictions, y_true):
- # Apply transformations to ensure mean and dispersion parameters are positive
- mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
- dispersion = self.dispersion_transform(
- predictions[:, self.param_names.index("dispersion")]
- )
-
- # Calculate the probability (p) and number of successes (r) from mean and dispersion
- # These calculations follow from the mean and variance of the negative binomial distribution
- # where variance = mean + mean^2 / dispersion
- r = torch.tensor(1.0) / dispersion
- p = r / (r + mean)
-
- # Define the Negative Binomial distribution with the transformed parameters
- negative_binomial_dist = dist.NegativeBinomial(total_count=r, probs=p)
-
- # Compute the negative log-likelihood
- nll = -negative_binomial_dist.log_prob(y_true).mean()
- return nll
-
-
-class CategoricalDistribution(BaseDistribution):
- """
- Represents a Categorical distribution, a discrete distribution that describes the possible results of a
- random variable that can take on one of K possible categories, with the probability of each category
- separately specified. This class extends BaseDistribution and includes parameter transformation and loss
- computation specific to the Categorical distribution.
-
- Parameters
- ----------
- name (str): The name of the distribution, defaulted to "Categorical".
- prob_transform (str or callable): Transformation for the probabilities to ensure
- they remain valid (i.e., non-negative and sum to 1).
- """
-
- def __init__(self, name="Categorical", prob_transform="probabilities"):
- # Specify parameter name for Poisson distribution
- param_names = ["probs"]
- super().__init__(name, param_names)
- # Retrieve transformation function for rate
- self.probs_transform = self.get_transform(prob_transform)
-
- def compute_loss(self, predictions, y_true):
- probs = self.probs_transform(predictions)
-
- # Define the Poisson distribution with the transformed parameter
- cat_dist = dist.Categorical(probs=probs)
-
- # Compute the negative log-likelihood
- nll = -cat_dist.log_prob(y_true).mean()
- return nll
-
-
-class Quantile(BaseDistribution):
- """
- Quantile Regression Loss class.
-
- This class computes the quantile loss (also known as pinball loss) for a set of quantiles.
- It is used to handle quantile regression tasks where we aim to predict a given quantile of the target distribution.
-
- Parameters
- ----------
- name : str, optional
- The name of the distribution, by default "Quantile".
- quantiles : list of float, optional
- A list of quantiles to be used for computing the loss, by default [0.25, 0.5, 0.75].
-
- Attributes
- ----------
- quantiles : list of float
- List of quantiles for which the pinball loss is computed.
-
- Methods
- -------
- compute_loss(predictions, y_true)
- Computes the quantile regression loss between the predictions and true values.
- """
-
- def __init__(self, name="Quantile", quantiles=[0.25, 0.5, 0.75]):
- # Use string representations of quantiles
- param_names = [f"q_{q}" for q in quantiles]
- super().__init__(name, param_names)
- self.quantiles = quantiles
-
- def compute_loss(self, predictions, y_true):
- if y_true.requires_grad:
- raise ValueError("y_true should not require gradients")
- if predictions.size(0) != y_true.size(0):
- raise ValueError("Batch size of predictions and y_true must match")
-
- losses = []
- for i, q in enumerate(self.quantiles):
- # Calculate errors for each quantile
- errors = y_true - predictions[:, i]
- # Compute the pinball loss
- quantile_loss = torch.max((q - 1) * errors, q * errors)
- losses.append(quantile_loss)
-
- # Sum losses across quantiles and compute mean
- loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1))
- return loss
-
-
-class JohnsonSuDistribution(BaseDistribution):
- """
- Represents a Johnson's SU distribution with parameters for skewness, shape, location, and scale.
-
- Parameters
- ----------
- name (str): The name of the distribution. Defaults to "JohnsonSu".
- skew_transform (str or callable): The transformation for the skewness parameter. Defaults to "none".
- shape_transform (str or callable): The transformation for the shape parameter. Defaults to "positive".
- loc_transform (str or callable): The transformation for the location parameter. Defaults to "none".
- scale_transform (str or callable): The transformation for the scale parameter. Defaults to "positive".
- """
-
- def __init__(
- self,
- name="JohnsonSu",
- skew_transform="none",
- shape_transform="positive",
- loc_transform="none",
- scale_transform="positive",
- ):
- param_names = ["skew", "shape", "location", "scale"]
- super().__init__(name, param_names)
-
- self.skew_transform = self.get_transform(skew_transform)
- self.shape_transform = self.get_transform(shape_transform)
- self.loc_transform = self.get_transform(loc_transform)
- self.scale_transform = self.get_transform(scale_transform)
-
- def log_prob(self, x, skew, shape, loc, scale):
- """
- Compute the log probability density of the Johnson's SU distribution.
- """
- z = skew + shape * torch.asinh((x - loc) / scale)
- log_pdf = (
- torch.log(shape / (scale * np.sqrt(2 * np.pi)))
- - 0.5 * z**2
- - 0.5 * torch.log(1 + ((x - loc) / scale) ** 2)
- )
- return log_pdf
-
- def compute_loss(self, predictions, y_true):
- skew = self.skew_transform(predictions[:, self.param_names.index("skew")])
- shape = self.shape_transform(predictions[:, self.param_names.index("shape")])
- loc = self.loc_transform(predictions[:, self.param_names.index("location")])
- scale = self.scale_transform(predictions[:, self.param_names.index("scale")])
-
- log_probs = self.log_prob(y_true, skew, shape, loc, scale)
- nll = -log_probs.mean()
- return nll
-
- def evaluate_nll(self, y_true, y_pred):
- metrics = super().evaluate_nll(y_true, y_pred)
-
- y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
- y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
-
- mse_loss = torch.nn.functional.mse_loss(
- y_true_tensor, y_pred_tensor[:, self.param_names.index("location")]
- )
- rmse = np.sqrt(mse_loss.detach().numpy())
- mae = (
- torch.nn.functional.l1_loss(
- y_true_tensor, y_pred_tensor[:, self.param_names.index("location")]
- )
- .detach()
- .numpy()
- )
-
- metrics.update({"mse": mse_loss.detach().numpy(), "mae": mae, "rmse": rmse})
-
- return metrics
+import numpy as np
+import torch
+import torch.distributions as dist
+
+
+class BaseDistribution(torch.nn.Module):
+ """
+ The base class for various statistical distributions, providing a common interface and utilities.
+
+ This class defines the basic structure and methods that are inherited by specific distribution
+ classes, allowing for the implementation of custom distributions with specific parameter transformations
+ and loss computations.
+
+ Attributes
+ ----------
+ _name (str): The name of the distribution.
+ param_names (list of str): A list of names for the parameters of the distribution.
+ param_count (int): The number of parameters for the distribution.
+ predefined_transforms (dict): A dictionary of predefined transformation functions for parameters.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution.
+ param_names (list of str): A list of names for the parameters of the distribution.
+ """
+
+ def __init__(self, name, param_names):
+ super().__init__()
+
+ self._name = name
+ self.param_names = param_names
+ self.param_count = len(param_names)
+ # Predefined transformation functions accessible to all subclasses
+ self.predefined_transforms = {
+ "positive": torch.nn.functional.softplus,
+ "none": lambda x: x,
+ "square": lambda x: x**2,
+ "exp": torch.exp,
+ "sqrt": torch.sqrt,
+ "probabilities": lambda x: torch.softmax(x, dim=-1),
+ # Adding a small constant for numerical stability
+ "log": lambda x: torch.log(x + 1e-6),
+ }
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def parameter_count(self):
+ return self.param_count
+
+ def get_transform(self, transform_name):
+ """
+ Retrieve a transformation function by name, or return the function if it's custom.
+ """
+ if callable(transform_name):
+ # Custom transformation function provided
+ return transform_name
+ # Default to 'none'
+ return self.predefined_transforms.get(transform_name, lambda x: x)
+
+ def compute_loss(self, predictions, y_true):
+ """
+ Computes the loss (e.g., negative log likelihood) for the distribution given
+ predictions and true values.
+
+ This method must be implemented by subclasses.
+
+ Parameters
+ ----------
+ predictions (torch.Tensor): The predicted parameters of the distribution.
+ y_true (torch.Tensor): The true values.
+
+ Raises
+ ------
+ NotImplementedError: If the subclass does not implement this method.
+ """
+ raise NotImplementedError("Subclasses must implement this method.")
+
+ def evaluate_nll(self, y_true, y_pred):
+ """
+ Evaluates the negative log likelihood (NLL) for given true values and predictions.
+
+ Parameters
+ ----------
+ y_true (array-like): The true values.
+ y_pred (array-like): The predicted values.
+
+ Returns
+ -------
+ dict: A dictionary containing the NLL value.
+ """
+
+ # Convert numpy arrays to torch tensors
+ y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
+ y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
+
+ # Compute NLL using the provided loss function
+ nll_loss_tensor = self.compute_loss(y_pred_tensor, y_true_tensor)
+
+ # Convert the NLL loss tensor back to a numpy array and return
+ return {
+ "NLL": nll_loss_tensor.detach().numpy(),
+ }
+
+ def forward(self, predictions):
+ """
+ Apply the appropriate transformations to the predicted parameters.
+
+ Parameters:
+ predictions (torch.Tensor): The predicted parameters of the distribution.
+
+ Returns:
+ torch.Tensor: A tensor with transformed parameters.
+ """
+ transformed_params = []
+ for idx, param_name in enumerate(self.param_names):
+ transform_func = self.get_transform(
+ getattr(self, f"{param_name}_transform", "none")
+ )
+ transformed_params.append(
+ transform_func(predictions[:, idx]).unsqueeze( # type: ignore
+ 1
+ ) # type: ignore
+ )
+ return torch.cat(transformed_params, dim=1)
+
+
+class NormalDistribution(BaseDistribution):
+ """
+ Represents a Normal (Gaussian) distribution with parameters for mean and variance,
+ including functionality for transforming these parameters and computing the loss.
+
+ Inherits from BaseDistribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution. Defaults to "Normal".
+ mean_transform (str or callable): The transformation for the mean parameter.
+ Defaults to "none".
+ var_transform (str or callable): The transformation for the variance parameter.
+ Defaults to "positive".
+ """
+
+ def __init__(self, name="Normal", mean_transform="none", var_transform="positive"):
+ param_names = [
+ "mean",
+ "variance",
+ ]
+ super().__init__(name, param_names)
+
+ self.mean_transform = self.get_transform(mean_transform)
+ self.variance_transform = self.get_transform(var_transform)
+
+ def compute_loss(self, predictions, y_true):
+ mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
+ variance = self.variance_transform(
+ predictions[:, self.param_names.index("variance")]
+ )
+
+ normal_dist = dist.Normal(mean, variance)
+
+ nll = -normal_dist.log_prob(y_true).mean()
+ return nll
+
+ def evaluate_nll(self, y_true, y_pred):
+ metrics = super().evaluate_nll(y_true, y_pred)
+
+ # Convert numpy arrays to torch tensors
+ y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
+ y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
+
+ mse_loss = torch.nn.functional.mse_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]
+ )
+ rmse = np.sqrt(mse_loss.detach().numpy())
+ mae = (
+ torch.nn.functional.l1_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]
+ )
+ .detach()
+ .numpy()
+ )
+
+ metrics["mse"] = mse_loss.detach().numpy()
+ metrics["mae"] = mae
+ metrics["rmse"] = rmse
+
+ # Convert the NLL loss tensor back to a numpy array and return
+ return metrics
+
+
+class PoissonDistribution(BaseDistribution):
+ """
+ Represents a Poisson distribution, typically used for modeling count data or the number of events
+ occurring within a fixed interval of time or space. This class extends the BaseDistribution and
+ includes parameter transformation and loss computation specific to the Poisson distribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution, defaulted to "Poisson".
+ rate_transform (str or callable): Transformation to apply to the rate parameter
+ to ensure it remains positive.
+ """
+
+ def __init__(self, name="Poisson", rate_transform="positive"):
+ # Specify parameter name for Poisson distribution
+ param_names = ["rate"]
+ super().__init__(name, param_names)
+ # Retrieve transformation function for rate
+ self.rate_transform = self.get_transform(rate_transform)
+
+ def compute_loss(self, predictions, y_true):
+ rate = self.rate_transform(predictions[:, self.param_names.index("rate")])
+
+ # Define the Poisson distribution with the transformed parameter
+ poisson_dist = dist.Poisson(rate)
+
+ # Compute the negative log-likelihood
+ nll = -poisson_dist.log_prob(y_true).mean()
+ return nll
+
+ def evaluate_nll(self, y_true, y_pred):
+ metrics = super().evaluate_nll(y_true, y_pred)
+
+ # Convert numpy arrays to torch tensors
+ y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
+ y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
+ rate = self.rate_transform(y_pred_tensor[:, self.param_names.index("rate")])
+
+ mse_loss = torch.nn.functional.mse_loss(y_true_tensor, rate) # type: ignore
+ rmse = np.sqrt(mse_loss.detach().numpy())
+ mae = (
+ torch.nn.functional.l1_loss(y_true_tensor, rate) # type: ignore
+ .detach()
+ .numpy() # type: ignore
+ ) # type: ignore
+ poisson_deviance = 2 * torch.sum(
+ y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate)
+ )
+
+ metrics["mse"] = mse_loss.detach().numpy()
+ metrics["mae"] = mae
+ metrics["rmse"] = rmse
+ metrics["poisson_deviance"] = poisson_deviance.detach().numpy()
+
+ # Convert the NLL loss tensor back to a numpy array and return
+ return metrics
+
+
+class InverseGammaDistribution(BaseDistribution):
+ """
+ Represents an Inverse Gamma distribution, often used as a prior distribution in Bayesian statistics,
+ especially for scale parameters in other distributions. This class extends BaseDistribution and includes
+ parameter transformation and loss computation specific to the Inverse Gamma distribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution, defaulted to "InverseGamma".
+ shape_transform (str or callable): Transformation for the shape parameter to
+ ensure it remains positive.
+ scale_transform (str or callable): Transformation for the scale parameter to
+ ensure it remains positive.
+ """
+
+ def __init__(
+ self,
+ name="InverseGamma",
+ shape_transform="positive",
+ scale_transform="positive",
+ ):
+ param_names = [
+ "shape",
+ "scale",
+ ]
+ super().__init__(name, param_names)
+
+ self.shape_transform = self.get_transform(shape_transform)
+ self.scale_transform = self.get_transform(scale_transform)
+
+ def compute_loss(self, predictions, y_true):
+ shape = self.shape_transform(predictions[:, self.param_names.index("shape")])
+ scale = self.scale_transform(predictions[:, self.param_names.index("scale")])
+
+ inverse_gamma_dist = dist.InverseGamma(shape, scale)
+ # Compute the negative log-likelihood
+ nll = -inverse_gamma_dist.log_prob(y_true).mean()
+ return nll
+
+
+class BetaDistribution(BaseDistribution):
+ """
+ Represents a Beta distribution, a continuous distribution defined on the interval [0, 1], commonly used
+ in Bayesian statistics for modeling probabilities. This class extends BaseDistribution and includes parameter
+ transformation and loss computation specific to the Beta distribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution, defaulted to "Beta".
+ shape_transform (str or callable): Transformation for the alpha (shape) parameter to ensure
+ it remains positive.
+ scale_transform (str or callable): Transformation for the beta (scale) parameter to ensure
+ it remains positive.
+ """
+
+ def __init__(
+ self,
+ name="Beta",
+ shape_transform="positive",
+ scale_transform="positive",
+ ):
+ param_names = [
+ "alpha",
+ "beta",
+ ]
+ super().__init__(name, param_names)
+
+ self.alpha_transform = self.get_transform(shape_transform)
+ self.beta_transform = self.get_transform(scale_transform)
+
+ def compute_loss(self, predictions, y_true):
+ alpha = self.alpha_transform(predictions[:, self.param_names.index("alpha")])
+ beta = self.beta_transform(predictions[:, self.param_names.index("beta")])
+
+ beta_dist = dist.Beta(alpha, beta)
+ # Compute the negative log-likelihood
+ nll = -beta_dist.log_prob(y_true).mean()
+ return nll
+
+
+class DirichletDistribution(BaseDistribution):
+ """
+ Represents a Dirichlet distribution, a multivariate generalization of the Beta distribution. It is commonly
+ used in Bayesian statistics for modeling multinomial distribution probabilities. This class extends
+ BaseDistribution and includes parameter transformation and loss computation
+ specific to the Dirichlet distribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution, defaulted to "Dirichlet".
+ concentration_transform (str or callable): Transformation to apply to
+ concentration parameters to ensure they remain positive.
+ """
+
+ def __init__(self, name="Dirichlet", concentration_transform="positive"):
+ # For Dirichlet, param_names could be dynamically set based on the dimensionality of alpha
+ # For simplicity, we're not specifying individual names for each concentration parameter
+ param_names = ["concentration"] # This is a simplification
+ super().__init__(name, param_names)
+ # Retrieve transformation function for concentration parameters
+ self.concentration_transform = self.get_transform(concentration_transform)
+
+ def compute_loss(self, predictions, y_true):
+ # Apply the transformation to ensure all concentration parameters are positive
+ # Assuming predictions is a 2D tensor where each row is a set of concentration parameters
+ # for a Dirichlet distribution
+ concentration = self.concentration_transform(predictions)
+
+ dirichlet_dist = dist.Dirichlet(concentration)
+
+ nll = -dirichlet_dist.log_prob(y_true).mean()
+ return nll
+
+
+class GammaDistribution(BaseDistribution):
+ """
+ Represents a Gamma distribution, a two-parameter family of continuous probability distributions. It's
+ widely used in various fields of science for modeling a wide range of phenomena. This class extends
+ BaseDistribution and includes parameter transformation and loss computation specific to
+ the Gamma distribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution, defaulted to "Gamma".
+ shape_transform (str or callable): Transformation for the shape parameter to ensure it remains positive.
+ rate_transform (str or callable): Transformation for the rate parameter to ensure it remains positive.
+ """
+
+ def __init__(
+ self, name="Gamma", shape_transform="positive", rate_transform="positive"
+ ):
+ param_names = ["shape", "rate"]
+ super().__init__(name, param_names)
+
+ self.shape_transform = self.get_transform(shape_transform)
+ self.rate_transform = self.get_transform(rate_transform)
+
+ def compute_loss(self, predictions, y_true):
+ shape = self.shape_transform(predictions[:, self.param_names.index("shape")])
+ rate = self.rate_transform(predictions[:, self.param_names.index("rate")])
+
+ # Define the Gamma distribution with the transformed parameters
+ gamma_dist = dist.Gamma(shape, rate)
+
+ # Compute the negative log-likelihood
+ nll = -gamma_dist.log_prob(y_true).mean()
+ return nll
+
+
+class StudentTDistribution(BaseDistribution):
+ """
+ Represents a Student's t-distribution, a family of continuous probability distributions that arise when
+ estimating the mean of a normally distributed population in situations where the sample size is small.
+ This class extends BaseDistribution and includes parameter transformation and loss computation specific
+ to the Student's t-distribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution, defaulted to "StudentT".
+ df_transform (str or callable): Transformation for the degrees of freedom parameter
+ to ensure it remains positive.
+ loc_transform (str or callable): Transformation for the location parameter.
+ scale_transform (str or callable): Transformation for the scale parameter
+ to ensure it remains positive.
+ """
+
+ def __init__(
+ self,
+ name="StudentT",
+ df_transform="positive",
+ loc_transform="none",
+ scale_transform="positive",
+ ):
+ param_names = ["df", "loc", "scale"]
+ super().__init__(name, param_names)
+
+ self.df_transform = self.get_transform(df_transform)
+ self.loc_transform = self.get_transform(loc_transform)
+ self.scale_transform = self.get_transform(scale_transform)
+
+ def compute_loss(self, predictions, y_true):
+ df = self.df_transform(predictions[:, self.param_names.index("df")])
+ loc = self.loc_transform(predictions[:, self.param_names.index("loc")])
+ scale = self.scale_transform(predictions[:, self.param_names.index("scale")])
+
+ student_t_dist = dist.StudentT(df, loc, scale) # type: ignore
+
+ nll = -student_t_dist.log_prob(y_true).mean()
+ return nll
+
+ def evaluate_nll(self, y_true, y_pred):
+ metrics = super().evaluate_nll(y_true, y_pred)
+
+ # Convert numpy arrays to torch tensors
+ y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
+ y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
+
+ mse_loss = torch.nn.functional.mse_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]
+ )
+ rmse = np.sqrt(mse_loss.detach().numpy())
+ mae = (
+ torch.nn.functional.l1_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]
+ )
+ .detach()
+ .numpy()
+ )
+
+ metrics["mse"] = mse_loss.detach().numpy()
+ metrics["mae"] = mae
+ metrics["rmse"] = rmse
+
+ # Convert the NLL loss tensor back to a numpy array and return
+ return metrics
+
+
+class NegativeBinomialDistribution(BaseDistribution):
+ """
+ Represents a Negative Binomial distribution, often used for count data and modeling the number
+ of failures before a specified number of successes occurs in a series of Bernoulli trials.
+ This class extends BaseDistribution and includes parameter transformation and loss computation
+ specific to the Negative Binomial distribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution, defaulted to "NegativeBinomial".
+ mean_transform (str or callable): Transformation for the mean parameter to ensure it remains positive.
+ dispersion_transform (str or callable): Transformation for the dispersion parameter to
+ ensure it remains positive.
+ """
+
+ def __init__(
+ self,
+ name="NegativeBinomial",
+ mean_transform="positive",
+ dispersion_transform="positive",
+ ):
+ param_names = ["mean", "dispersion"]
+ super().__init__(name, param_names)
+
+ self.mean_transform = self.get_transform(mean_transform)
+ self.dispersion_transform = self.get_transform(dispersion_transform)
+
+ def compute_loss(self, predictions, y_true):
+ # Apply transformations to ensure mean and dispersion parameters are positive
+ mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
+ dispersion = self.dispersion_transform(
+ predictions[:, self.param_names.index("dispersion")]
+ )
+
+ # Calculate the probability (p) and number of successes (r) from mean and dispersion
+ # These calculations follow from the mean and variance of the negative binomial distribution
+ # where variance = mean + mean^2 / dispersion
+ r = torch.tensor(1.0) / dispersion
+ p = r / (r + mean)
+
+ # Define the Negative Binomial distribution with the transformed parameters
+ negative_binomial_dist = dist.NegativeBinomial(total_count=r, probs=p)
+
+ # Compute the negative log-likelihood
+ nll = -negative_binomial_dist.log_prob(y_true).mean()
+ return nll
+
+
+class CategoricalDistribution(BaseDistribution):
+ """
+ Represents a Categorical distribution, a discrete distribution that describes the possible results of a
+ random variable that can take on one of K possible categories, with the probability of each category
+ separately specified. This class extends BaseDistribution and includes parameter transformation and loss
+ computation specific to the Categorical distribution.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution, defaulted to "Categorical".
+ prob_transform (str or callable): Transformation for the probabilities to ensure
+ they remain valid (i.e., non-negative and sum to 1).
+ """
+
+ def __init__(self, name="Categorical", prob_transform="probabilities"):
+ # Specify parameter name for Poisson distribution
+ param_names = ["probs"]
+ super().__init__(name, param_names)
+ # Retrieve transformation function for rate
+ self.probs_transform = self.get_transform(prob_transform)
+
+ def compute_loss(self, predictions, y_true):
+ probs = self.probs_transform(predictions)
+
+ # Define the Poisson distribution with the transformed parameter
+ cat_dist = dist.Categorical(probs=probs)
+
+ # Compute the negative log-likelihood
+ nll = -cat_dist.log_prob(y_true).mean()
+ return nll
+
+
+class Quantile(BaseDistribution):
+ """
+ Quantile Regression Loss class.
+
+ This class computes the quantile loss (also known as pinball loss) for a set of quantiles.
+ It is used to handle quantile regression tasks where we aim to predict a given quantile of the target distribution.
+
+ Parameters
+ ----------
+ name : str, optional
+ The name of the distribution, by default "Quantile".
+ quantiles : list of float, optional
+ A list of quantiles to be used for computing the loss, by default [0.25, 0.5, 0.75].
+
+ Attributes
+ ----------
+ quantiles : list of float
+ List of quantiles for which the pinball loss is computed.
+
+ Methods
+ -------
+ compute_loss(predictions, y_true)
+ Computes the quantile regression loss between the predictions and true values.
+ """
+
+ def __init__(self, name="Quantile", quantiles=[0.25, 0.5, 0.75]):
+ # Use string representations of quantiles
+ param_names = [f"q_{q}" for q in quantiles]
+ super().__init__(name, param_names)
+ self.quantiles = quantiles
+
+ def compute_loss(self, predictions, y_true):
+ if y_true.requires_grad:
+ raise ValueError("y_true should not require gradients")
+ if predictions.size(0) != y_true.size(0):
+ raise ValueError("Batch size of predictions and y_true must match")
+
+ losses = []
+ for i, q in enumerate(self.quantiles):
+ # Calculate errors for each quantile
+ errors = y_true - predictions[:, i]
+ # Compute the pinball loss
+ quantile_loss = torch.max((q - 1) * errors, q * errors)
+ losses.append(quantile_loss)
+
+ # Sum losses across quantiles and compute mean
+ loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1))
+ return loss
+
+
+class JohnsonSuDistribution(BaseDistribution):
+ """
+ Represents a Johnson's SU distribution with parameters for skewness, shape, location, and scale.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution. Defaults to "JohnsonSu".
+ skew_transform (str or callable): The transformation for the skewness parameter. Defaults to "none".
+ shape_transform (str or callable): The transformation for the shape parameter. Defaults to "positive".
+ loc_transform (str or callable): The transformation for the location parameter. Defaults to "none".
+ scale_transform (str or callable): The transformation for the scale parameter. Defaults to "positive".
+ """
+
+ def __init__(
+ self,
+ name="JohnsonSu",
+ skew_transform="none",
+ shape_transform="positive",
+ loc_transform="none",
+ scale_transform="positive",
+ ):
+ param_names = ["skew", "shape", "location", "scale"]
+ super().__init__(name, param_names)
+
+ self.skew_transform = self.get_transform(skew_transform)
+ self.shape_transform = self.get_transform(shape_transform)
+ self.loc_transform = self.get_transform(loc_transform)
+ self.scale_transform = self.get_transform(scale_transform)
+
+ def log_prob(self, x, skew, shape, loc, scale):
+ """
+ Compute the log probability density of the Johnson's SU distribution.
+ """
+ z = skew + shape * torch.asinh((x - loc) / scale)
+ log_pdf = (
+ torch.log(shape / (scale * np.sqrt(2 * np.pi)))
+ - 0.5 * z**2
+ - 0.5 * torch.log(1 + ((x - loc) / scale) ** 2)
+ )
+ return log_pdf
+
+ def compute_loss(self, predictions, y_true):
+ skew = self.skew_transform(predictions[:, self.param_names.index("skew")])
+ shape = self.shape_transform(predictions[:, self.param_names.index("shape")])
+ loc = self.loc_transform(predictions[:, self.param_names.index("location")])
+ scale = self.scale_transform(predictions[:, self.param_names.index("scale")])
+
+ log_probs = self.log_prob(y_true, skew, shape, loc, scale)
+ nll = -log_probs.mean()
+ return nll
+
+ def evaluate_nll(self, y_true, y_pred):
+ metrics = super().evaluate_nll(y_true, y_pred)
+
+ y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
+ y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
+
+ mse_loss = torch.nn.functional.mse_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("location")]
+ )
+ rmse = np.sqrt(mse_loss.detach().numpy())
+ mae = (
+ torch.nn.functional.l1_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("location")]
+ )
+ .detach()
+ .numpy()
+ )
+
+ metrics.update({"mse": mse_loss.detach().numpy(), "mae": mae, "rmse": rmse})
+
+ return metrics
diff --git a/deeptabular/utils/docstring_generator.py b/deeptab/utils/docstring_generator.py
similarity index 100%
rename from deeptabular/utils/docstring_generator.py
rename to deeptab/utils/docstring_generator.py
diff --git a/deeptabular/utils/get_feature_dimensions.py b/deeptab/utils/get_feature_dimensions.py
similarity index 100%
rename from deeptabular/utils/get_feature_dimensions.py
rename to deeptab/utils/get_feature_dimensions.py
diff --git a/docs/api/base_models/BaseModels.rst b/docs/api/base_models/BaseModels.rst
index abfaae57..ccd41d80 100644
--- a/docs/api/base_models/BaseModels.rst
+++ b/docs/api/base_models/BaseModels.rst
@@ -1,35 +1,35 @@
-deeptabular.base_models
+deeptab.base_models
=======================
-.. autoclass:: deeptabular.base_models.BaseModel
+.. autoclass:: deeptab.base_models.BaseModel
:members:
:no-inherited-members:
-.. autoclass:: deeptabular.base_models.TaskModel
+.. autoclass:: deeptab.base_models.TaskModel
:members:
:no-inherited-members:
-.. autoclass:: deeptabular.base_models.Mambular
+.. autoclass:: deeptab.base_models.Mambular
:members:
:no-inherited-members:
-.. autoclass:: deeptabular.base_models.MLP
+.. autoclass:: deeptab.base_models.MLP
:members:
:no-inherited-members:
-.. autoclass:: deeptabular.base_models.ResNet
+.. autoclass:: deeptab.base_models.ResNet
:members:
:no-inherited-members:
-.. autoclass:: deeptabular.base_models.FTTransformer
+.. autoclass:: deeptab.base_models.FTTransformer
:members:
:no-inherited-members:
-.. autoclass:: deeptabular.base_models.TabTransformer
+.. autoclass:: deeptab.base_models.TabTransformer
:members:
:no-inherited-members:
-.. autoclass:: deeptabular.base_models.TabulaRNN
+.. autoclass:: deeptab.base_models.TabulaRNN
:members:
:no-inherited-members:
diff --git a/docs/api/configs/index.rst b/docs/api/configs/index.rst
index afc38374..2f5af1c1 100644
--- a/docs/api/configs/index.rst
+++ b/docs/api/configs/index.rst
@@ -1,11 +1,11 @@
.. -*- mode: rst -*-
-.. currentmodule:: deeptabular.configs
+.. currentmodule:: deeptab.configs
Configurations
==============
-This module provides default configurations for DeepTabular models. Each configuration is implemented as a dataclass, offering a structured way to define model-specific hyperparameters.
+This module provides default configurations for deeptab models. Each configuration is implemented as a dataclass, offering a structured way to define model-specific hyperparameters.
Mambular
--------
diff --git a/docs/api/data_utils/Datautils.rst b/docs/api/data_utils/Datautils.rst
index 3ad6410b..c434a4cc 100644
--- a/docs/api/data_utils/Datautils.rst
+++ b/docs/api/data_utils/Datautils.rst
@@ -1,8 +1,8 @@
-deeptabular.data_utils
+deeptab.data_utils
======================
-.. autoclass:: deeptabular.data_utils.MambularDataset
+.. autoclass:: deeptab.data_utils.MambularDataset
:members:
-.. autoclass:: deeptabular.data_utils.MambularDataModule
+.. autoclass:: deeptab.data_utils.MambularDataModule
:members:
diff --git a/docs/api/data_utils/index.rst b/docs/api/data_utils/index.rst
index e617bc11..4edf8b67 100644
--- a/docs/api/data_utils/index.rst
+++ b/docs/api/data_utils/index.rst
@@ -1,6 +1,6 @@
.. -*- mode: rst -*-
-.. currentmodule:: deeptabular.data_utils
+.. currentmodule:: deeptab.data_utils
Data Utils
==========
diff --git a/docs/conf.py b/docs/conf.py
index 2484be0d..8a21f912 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -11,13 +11,13 @@
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("../"))
-sys.path.insert(1, os.path.dirname(os.path.abspath("../")) + os.sep + "deeptabular")
+sys.path.insert(1, os.path.dirname(os.path.abspath("../")) + os.sep + "deeptab")
-project = "deeptabular"
+project = "deeptab"
project_copyright = "2024, BASF SE"
author = "Anton Frederik Thielmann, Manish Kumar, Christoph Weisser, Benjamin Saefken, Soheila Samiee"
-VERSION_PATH = "../deeptabular/__version__.py"
+VERSION_PATH = "../deeptab/__version__.py"
with open(VERSION_PATH) as f:
lines = f.readlines()
for line in lines:
diff --git a/docs/examples/classification.md b/docs/examples/classification.md
index 3ff09b2c..3ab7fea2 100644
--- a/docs/examples/classification.md
+++ b/docs/examples/classification.md
@@ -1,12 +1,12 @@
# Classification
-This example demonstrates how use Classification module from the `deeptabular` package.
+This example demonstrates how use Classification module from the `deeptab` package.
```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
-from deeptabular.models import MambularClassifier
+from deeptab.models import MambularClassifier
# Set random seed for reproducibility
np.random.seed(0)
```
diff --git a/docs/examples/distributional.rst b/docs/examples/distributional.rst
index 5dd25b2e..b9050eb1 100644
--- a/docs/examples/distributional.rst
+++ b/docs/examples/distributional.rst
@@ -3,7 +3,7 @@
Distributional
==============
-This example demonstrates how use Distributional from the `deeptabular` package.
+This example demonstrates how use Distributional from the `deeptab` package.
.. literalinclude:: ../../examples/example_distributional.py
diff --git a/docs/examples/regression.rst b/docs/examples/regression.rst
index 1dcdc367..e8e6dd98 100644
--- a/docs/examples/regression.rst
+++ b/docs/examples/regression.rst
@@ -3,7 +3,7 @@
Regression
==========
-This example demonstrates how use Regression module from the `deeptabular` package.
+This example demonstrates how use Regression module from the `deeptab` package.
.. literalinclude:: ../../examples/example_regression.py
diff --git a/docs/homepage.md b/docs/homepage.md
index b9a3114c..43201091 100644
--- a/docs/homepage.md
+++ b/docs/homepage.md
@@ -1,13 +1,13 @@
-# DeepTabular: Tabular Deep Learning Made Simple
+# deeptab: Tabular Deep Learning Made Simple
-DeepTabular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
+deeptab is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
# 🏃 Quickstart
-Similar to any sklearn model, DeepTabular models can be fit as easy as this:
+Similar to any sklearn model, deeptab models can be fit as easy as this:
```python
-from deeptabular.models import MambularClassifier
+from deeptab.models import MambularClassifier
# Initialize and fit your model
model = MambularClassifier()
@@ -16,7 +16,7 @@ model.fit(X, y, max_epochs=150, lr=1e-04)
```
# 📖 Introduction
-DeepTabular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, DeepTabular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using DeepTabular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
+deeptab is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, deeptab models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using deeptab models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
# 🤖 Models
@@ -44,13 +44,13 @@ Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier` or `
# 📚 Documentation
-You can find the DeepTabular API documentation [here](https://deeptabular.readthedocs.io/en/latest/).
+You can find the deeptab API documentation [here](https://deeptab.readthedocs.io/en/latest/).
# 🛠️ Installation
-Install DeepTabular using pip:
+Install deeptab using pip:
```sh
-pip install deeptabular
+pip install deeptab
```
If you want to use the original mamba and mamba2 implementations, additionally install mamba-ssm via:
@@ -70,7 +70,7 @@ pip install mamba-ssm
Preprocessing
-DeepTabular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data.
+deeptab simplifies data preprocessing with a range of tools designed for easy transformation of tabular data.
Data Type Detection and Transformation
@@ -90,10 +90,10 @@ DeepTabular simplifies data preprocessing with a range of tools designed for eas
Fit a Model
-Fitting a model in deeptabular is as simple as it gets. All models in deeptabular are sklearn BaseEstimators. Thus, the `fit` method is implemented for all of them. Additionally, this allows for using all other sklearn inherent methods such as their built in hyperparameter optimization tools.
+Fitting a model in deeptab is as simple as it gets. All models in deeptab are sklearn BaseEstimators. Thus, the `fit` method is implemented for all of them. Additionally, this allows for using all other sklearn inherent methods such as their built in hyperparameter optimization tools.
```python
-from deeptabular.models import MambularClassifier
+from deeptab.models import MambularClassifier
# Initialize and fit your model
model = MambularClassifier(
d_model=64,
@@ -171,12 +171,12 @@ Or use the built-in bayesian hpo simply by running:
best_params = model.optimize_hparams(X, y)
```
-This automatically sets the search space based on the default config from ``deeptabular.configs``. See the documentation for all params with regard to ``optimize_hparams()``. However, the preprocessor arguments are fixed and cannot be optimized here.
+This automatically sets the search space based on the default config from ``deeptab.configs``. See the documentation for all params with regard to ``optimize_hparams()``. However, the preprocessor arguments are fixed and cannot be optimized here.
⚖️ Distributional Regression with MambularLSS
-MambularLSS allows you to model the full distribution of a response variable, not just its mean. This is crucial when understanding variability, skewness, or kurtosis is important. All DeepTabular models are available as distributional models.
+MambularLSS allows you to model the full distribution of a response variable, not just its mean. This is crucial when understanding variability, skewness, or kurtosis is important. All deeptab models are available as distributional models.
Key Features of MambularLSS:
@@ -203,10 +203,10 @@ These distribution classes make MambularLSS versatile in modeling various data t
Getting Started with MambularLSS:
-To integrate distributional regression into your workflow with `MambularLSS`, start by initializing the model with your desired configuration, similar to other DeepTabular models:
+To integrate distributional regression into your workflow with `MambularLSS`, start by initializing the model with your desired configuration, similar to other deeptab models:
```python
-from deeptabular.models import MambularLSS
+from deeptab.models import MambularLSS
# Initialize the MambularLSS model
model = MambularLSS(
@@ -231,11 +231,11 @@ model.fit(
# 💻 Implement Your Own Model
-DeepTabular allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from DeepTabular's `BaseModel`. Each DeepTabular model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs.
+deeptab allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from deeptab's `BaseModel`. Each deeptab model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs.
-One of the key advantages of using DeepTabular is that the inputs to the forward passes are lists of tensors. While this might be unconventional, it is highly beneficial for models that treat different data types differently. For example, the TabTransformer model leverages this feature to handle categorical and numerical data separately, applying different transformations and processing steps to each type of data.
+One of the key advantages of using deeptab is that the inputs to the forward passes are lists of tensors. While this might be unconventional, it is highly beneficial for models that treat different data types differently. For example, the TabTransformer model leverages this feature to handle categorical and numerical data separately, applying different transformations and processing steps to each type of data.
-Here's how you can implement a custom model with DeepTabular:
+Here's how you can implement a custom model with deeptab:
1. **First, define your config:**
The configuration class allows you to specify hyperparameters and other settings for your model. This can be done using a simple dataclass.
@@ -255,8 +255,8 @@ Here's how you can implement a custom model with DeepTabular:
Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass.
```python
- from deeptabular.base_models import BaseModel
- from deeptabular.utils.get_feature_dimensions import get_feature_dimensions
+ from deeptab.base_models import BaseModel
+ from deeptab.utils.get_feature_dimensions import get_feature_dimensions
import torch
import torch.nn
@@ -285,11 +285,11 @@ Here's how you can implement a custom model with DeepTabular:
return output
```
-3. **Leverage the DeepTabular API:**
- You can build a regression, classification, or distributional regression model that can leverage all of DeepTabular's built-in methods by using the following:
+3. **Leverage the deeptab API:**
+ You can build a regression, classification, or distributional regression model that can leverage all of deeptab's built-in methods by using the following:
```python
- from deeptabular.models import SklearnBaseRegressor
+ from deeptab.models import SklearnBaseRegressor
class MyRegressor(SklearnBaseRegressor):
def __init__(self, **kwargs):
@@ -297,7 +297,7 @@ Here's how you can implement a custom model with DeepTabular:
```
4. **Train and evaluate your model:**
- You can now fit, evaluate, and predict with your custom model just like with any other DeepTabular model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively.
+ You can now fit, evaluate, and predict with your custom model just like with any other deeptab model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively.
```python
regressor = MyRegressor(numerical_preprocessing="ple")
@@ -306,15 +306,15 @@ Here's how you can implement a custom model with DeepTabular:
# Custom Training
-If you prefer to setup custom training, preprocessing and evaluation, you can simply use the `deeptabular.base_models`.
+If you prefer to setup custom training, preprocessing and evaluation, you can simply use the `deeptab.base_models`.
Just be careful that all basemodels expect lists of features as inputs. More precisely as list for numerical features and a list for categorical features. A custom training loop, with random data could look like this.
```python
import torch
import torch.nn as nn
import torch.optim as optim
-from deeptabular.base_models import Mambular
-from deeptabular.configs import DefaultMambularConfig
+from deeptab.base_models import Mambular
+from deeptab.configs import DefaultMambularConfig
# Dummy data and configuration
cat_feature_info = {
diff --git a/docs/installation.md b/docs/installation.md
index 5cdc67b7..43647db3 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,11 +1,11 @@
## Installation
-Please follow the steps below for installing `deeptabular`.
+Please follow the steps below for installing `deeptab`.
### Install from the source:
```bash
-cd DeepTabular
+cd deeptab
poetry install
```
@@ -16,7 +16,7 @@ Note: Make sure you in the same directory where `pyproject.toml` file resides.
The package is available on PyPi. You can install it using the following command:
```bash
-pip install -U deeptabular
+pip install -U deeptab
```
-PyPi link: [deeptabular](https://pypi.org/project/deeptabular/)
+PyPi link: [deeptab](https://pypi.org/project/deeptab/)
diff --git a/docs/release.md b/docs/release.md
index 2ea0df8c..07a8f0ea 100644
--- a/docs/release.md
+++ b/docs/release.md
@@ -1,19 +1,19 @@
# Build and release
-The document outlines the steps to build and release the `deeptabular` package. At this point, it is assumed that the development and testing of the package have been completed successfully.
+The document outlines the steps to build and release the `deeptab` package. At this point, it is assumed that the development and testing of the package have been completed successfully.
## 1. Test documentation
It is expected from the contributor to update the documentation as an when required along side the change in source code. Please use the below process to test the documentation:
```sh
-cd DeepTabular/docs/
+cd deeptab/docs/
make doctest
```
Fix any docstring related issue, then proceed with next steps.
## 2. Version update
-The package version is mantained in `deeptabular/__version__.py` and `pyproject.toml` file. Increment the version according to the changes such as patch, minor, major or all.
+The package version is mantained in `deeptab/__version__.py` and `pyproject.toml` file. Increment the version according to the changes such as patch, minor, major or all.
- The version number should be in the format `major.minor.patch`. For example, `1.0.1`.
@@ -24,7 +24,7 @@ The package version is mantained in `deeptabular/__version__.py` and `pyproject.
- Create a pull request from your `feature` branch to the `develop` branch.
- Once the pull request is approved and merged to develop. The maintainer will test the package and documentation. If everything is fine, the maintainer will proceed further to merge the changed to `master` and `release` branch.
-- Ideally content of `master` and `release` branch should be same. The `release` branch is used to publish the package to PyPi while `master` branch is used to publish the documentation to readthedocs and can be accesseed at [deeptabular.readthedocs.io](https://deeptabular.readthedocs.io/en/latest/).
+- Ideally content of `master` and `release` branch should be same. The `release` branch is used to publish the package to PyPi while `master` branch is used to publish the documentation to readthedocs and can be accesseed at [deeptab.readthedocs.io](https://deeptab.readthedocs.io/en/latest/).
## 4. Publish package to PyPi
diff --git a/examples/example_classification.py b/examples/example_classification.py
index e07f7f94..e69a6ac0 100644
--- a/examples/example_classification.py
+++ b/examples/example_classification.py
@@ -2,7 +2,7 @@
import pandas as pd
from sklearn.model_selection import train_test_split
-from deeptabular.models import MambularClassifier
+from deeptab.models import MambularClassifier
# Set random seed for reproducibility
np.random.seed(0)
diff --git a/examples/example_distributional.py b/examples/example_distributional.py
index a58e79c2..e3e226f5 100644
--- a/examples/example_distributional.py
+++ b/examples/example_distributional.py
@@ -3,7 +3,7 @@
import pandas as pd
from sklearn.model_selection import train_test_split
-from deeptabular.models import MambularLSS
+from deeptab.models import MambularLSS
# Set random seed for reproducibility
np.random.seed(0)
diff --git a/examples/example_regression.py b/examples/example_regression.py
index e3786387..49b951df 100644
--- a/examples/example_regression.py
+++ b/examples/example_regression.py
@@ -3,7 +3,7 @@
import pandas as pd
from sklearn.model_selection import train_test_split
-from deeptabular.models import MambularRegressor
+from deeptab.models import MambularRegressor
# Set random seed for reproducibility
np.random.seed(0)
diff --git a/pyproject.toml b/pyproject.toml
index ea99383d..487441c1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,13 +1,13 @@
[tool.poetry]
-name = "deeptabular"
+name = "deeptab"
version = "1.6.0"
-description = "A python package for tabular deep learning with mamba blocks."
+description = "A python package for tabular deep learning."
authors = ["Anton Thielmann", "Manish Kumar", "Christoph Weisser"]
readme = "README.md"
-packages = [{ include = "deeptabular" }]
+packages = [{ include = "deeptab" }]
[build-system]
requires = ["poetry-core"]
@@ -38,10 +38,10 @@ docformatter = "^1.4"
[tool.poetry.urls]
-homepage = "https://github.com/OpenTabular/DeepTabular"
-documentation = "https://deeptabular.readthedocs.io/"
-repository = "https://github.com/OpenTabular/DeepTabular"
-package = "https://pypi.org/project/deeptabular/"
+homepage = "https://github.com/OpenTabular/deeptab"
+documentation = "https://deeptab.readthedocs.io/"
+repository = "https://github.com/OpenTabular/deeptab"
+package = "https://pypi.org/project/deeptab/"
# code quality tools
@@ -57,7 +57,7 @@ venv = ".venv"
[tool.ruff]
line-length = 120
target-version = "py310"
-exclude = ["*.ipynb", "deeptabular/arch_utils/mamba_utils.mamba_orginal.py"]
+exclude = ["*.ipynb", "deeptab/arch_utils/mamba_utils.mamba_orginal.py"]
[tool.ruff.lint]
select = [
diff --git a/tests/test_base.py b/tests/test_base.py
index ba248d55..3af042d2 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -3,16 +3,16 @@
import torch
import os
import importlib
-from deeptabular.base_models.utils import BaseModel
+from deeptab.base_models.utils import BaseModel
# Paths for models and configs
-MODEL_MODULE_PATH = "deeptabular.base_models"
-CONFIG_MODULE_PATH = "deeptabular.configs"
+MODEL_MODULE_PATH = "deeptab.base_models"
+CONFIG_MODULE_PATH = "deeptab.configs"
EXCLUDED_CLASSES = {"TabR"}
# Discover all models
model_classes = []
-for filename in os.listdir(os.path.dirname(__file__) + "/../deeptabular/base_models"):
+for filename in os.listdir(os.path.dirname(__file__) + "/../deeptab/base_models"):
if filename.endswith(".py") and filename not in [
"__init__.py",
"basemodel.py",
diff --git a/tests/test_configs.py b/tests/test_configs.py
index 9fb557c1..6b936724 100644
--- a/tests/test_configs.py
+++ b/tests/test_configs.py
@@ -4,13 +4,13 @@
import os
import dataclasses
import typing
-from deeptabular.configs.base_config import BaseConfig # Ensure correct path
+from deeptab.configs.base_config import BaseConfig # Ensure correct path
-CONFIG_MODULE_PATH = "deeptabular.configs"
+CONFIG_MODULE_PATH = "deeptab.configs"
config_classes = []
-# Discover all config classes in deeptabular/configs/
-for filename in os.listdir(os.path.dirname(__file__) + "/../deeptabular/configs"):
+# Discover all config classes in deeptab/configs/
+for filename in os.listdir(os.path.dirname(__file__) + "/../deeptab/configs"):
if (
filename.endswith(".py")
and filename != "base_config.py"