Description
Describe the bug
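Passing an explicit validation set (X_val, y_val) as NumPy arrays to MambaTabClassifier.fit raises a ValueError. The check "if X_val:" in SklearnBaseClassifier.fit evaluates the truth value of the whole array, which NumPy rejects for arrays with more than one element.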
To Reproduce
from mambular.models import MambaTabClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2)
print(X.shape)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25)
print(X_tr.shape, X_val.shape)
model = MambaTabClassifier()
model.fit(X_tr, y_tr, X_val=X_val, y_val=y_val)
Expected behavior
The model trains on exactly the training and validation sets I provide, without splitting the data passed in.
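A minimal sketch of the expected dispatch, assuming fit should use the caller's validation set when one is given and only fall back to splitting X/y by val_size otherwise. The helper resolve_validation and its defaults (val_size=0.2, random_state=101) are hypothetical, not mambular's actual code; only the parameter names val_size, X_val, y_val and random_state come from the fit signature in the traceback.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


def resolve_validation(X, y, X_val=None, y_val=None, val_size=0.2, random_state=101):
    """Hypothetical helper returning (X_train, X_val, y_train, y_val)."""
    X = X if isinstance(X, pd.DataFrame) else pd.DataFrame(X)
    y = np.asarray(y)
    # Both halves of the validation set must be supplied together.
    if (X_val is None) != (y_val is None):
        raise ValueError("X_val and y_val must be passed together")
    if X_val is not None:
        # Explicit validation set: keep all of X/y for training, no split.
        X_val = X_val if isinstance(X_val, pd.DataFrame) else pd.DataFrame(X_val)
        return X, X_val, y, np.asarray(y_val)
    # No validation set given: carve one out of the training data.
    return train_test_split(X, y, test_size=val_size, random_state=random_state)


# Usage mirroring the reproduction above.
X, y = make_classification(n_samples=1000, n_features=10, random_state=0)
X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.25, random_state=0)
Xa, Xb, ya, yb = resolve_validation(X_tr, y_tr, X_val=X_va, y_val=y_va)
print(Xa.shape, Xb.shape)  # (750, 10) (250, 10)
```

Testing "is not None" rather than the container's truth value lets arrays and DataFrames of any size through, while still detecting when the argument was omitted.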
Screenshots
ValueError Traceback (most recent call last)
Cell In[84], line 12
9 print( X_tr.shape, X_val.shape )
11 model = MambaTabClassifier()
---> 12 model.fit(X_tr,y_tr, X_val=X_val, y_val=y_val)
File /opt/conda/lib/python3.10/site-packages/mambular/models/sklearn_base_classifier.py:324, in SklearnBaseClassifier.fit(self, X, y, val_size, X_val, y_val, max_epochs, random_state, batch_size, shuffle, patience, monitor, mode, lr, lr_patience, factor, weight_decay, checkpoint_path, dataloader_kwargs, rebuild, **trainer_kwargs)
322 if isinstance(y, pd.Series):
323 y = y.values
--> 324 if X_val:
325 if not isinstance(X_val, pd.DataFrame):
326 X_val = pd.DataFrame(X_val)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
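The failure can be reproduced without mambular. The sketch below, under the assumption that line 324 is the only problem, shows that "if X_val:" raises for both a NumPy array and a pandas DataFrame (so converting the inputs to DataFrames before calling fit does not avoid it), while "X_val is not None" works for both. The function names raises_on_truthiness and has_validation are hypothetical, and replacing the check with "is not None" is a suggested fix rather than a confirmed patch.

```python
import numpy as np
import pandas as pd


def raises_on_truthiness(obj):
    """Return True if `if obj:` raises ValueError, as on line 324."""
    try:
        if obj:
            pass
    except ValueError:
        return True
    return False


def has_validation(X_val):
    # Suggested check: only ask whether the argument was supplied.
    return X_val is not None


X_val = np.random.rand(250, 10)
for candidate in (X_val, pd.DataFrame(X_val)):
    print(type(candidate).__name__,
          "raises:", raises_on_truthiness(candidate),
          "has_validation:", has_validation(candidate))
print("None has_validation:", has_validation(None))
```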
Desktop:
Kaggle notebook
Additional context
The same error occurs for the other Mambular models that use the sklearn wrapper.