diff --git a/mambular/__version__.py b/mambular/__version__.py index fcfe670f..43b979e7 100644 --- a/mambular/__version__.py +++ b/mambular/__version__.py @@ -17,5 +17,4 @@ # The following line *must* be the last in the module, exactly as formatted: -__version__ = "1.3.1" - +__version__ = "1.3.2" diff --git a/mambular/models/utils/sklearn_base_classifier.py b/mambular/models/utils/sklearn_base_classifier.py index 1c021a26..a92d906d 100644 --- a/mambular/models/utils/sklearn_base_classifier.py +++ b/mambular/models/utils/sklearn_base_classifier.py @@ -4,6 +4,7 @@ import torch from sklearn.metrics import accuracy_score, log_loss from .sklearn_parent import SklearnBase +import numpy as np class SklearnBaseClassifier(SklearnBase): @@ -85,6 +86,8 @@ def build_model( The built classifier. """ + num_classes = len(np.unique(y)) + return super()._build_model( X, y, @@ -94,6 +97,7 @@ def build_model( y_val=y_val, embeddings=embeddings, embeddings_val=embeddings_val, + num_classes=num_classes, random_state=random_state, batch_size=batch_size, shuffle=shuffle, @@ -190,6 +194,7 @@ def fit( The fitted classifier. """ + num_classes = len(np.unique(y)) return super().fit( X=X, y=y, @@ -215,6 +220,7 @@ def fit( train_metrics=train_metrics, val_metrics=val_metrics, rebuild=rebuild, + num_classes=num_classes, **trainer_kwargs, ) diff --git a/mambular/models/utils/sklearn_base_regressor.py b/mambular/models/utils/sklearn_base_regressor.py index 963bb3a9..426ff5fa 100644 --- a/mambular/models/utils/sklearn_base_regressor.py +++ b/mambular/models/utils/sklearn_base_regressor.py @@ -93,6 +93,7 @@ def build_model( y_val=y_val, embeddings=embeddings, embeddings_val=embeddings_val, + num_classes=1, random_state=random_state, batch_size=batch_size, shuffle=shuffle, @@ -198,6 +199,7 @@ def fit( y_val=y_val, embeddings=embeddings, embeddings_val=embeddings_val, + num_classes=1, max_epochs=max_epochs, random_state=random_state, batch_size=batch_size, diff --git a/mambular/models/utils/sklearn_parent.py b/mambular/models/utils/sklearn_parent.py index 812db286..797a8259 100644 --- a/mambular/models/utils/sklearn_parent.py +++ b/mambular/models/utils/sklearn_parent.py @@ -120,6 +120,7 @@ def _build_model( y_val=None, embeddings=None, embeddings_val=None, + num_classes: int = None, random_state: int = 101, batch_size: int = 128, shuffle: bool = True, @@ -223,6 +224,7 @@ def _build_model( weight_decay=( weight_decay if weight_decay is not None else self.config.weight_decay ), + num_classes=num_classes, train_metrics=train_metrics, val_metrics=val_metrics, optimizer_type=self.optimizer_type, @@ -273,6 +275,7 @@ def fit( y_val=None, embeddings=None, embeddings_val=None, + num_classes: int = None, max_epochs: int = 100, random_state: int = 101, batch_size: int = 128, @@ -357,6 +360,7 @@ def fit( y_val=y_val, embeddings=embeddings, embeddings_val=embeddings_val, + num_classes=num_classes, random_state=random_state, batch_size=batch_size, shuffle=shuffle, diff --git a/pyproject.toml b/pyproject.toml index 99186db9..d1fdb8a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "mambular" -version = "1.3.1" +version = "1.3.2" description = "A python package for tabular deep learning with mamba blocks." authors = ["Anton Thielmann", "Manish Kumar", "Christoph Weisser"]