Skip to content

Commit

Permalink
update example multiclass
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Sep 28, 2024
1 parent 86c3047 commit f439bac
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/multiclass_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,6 @@
# Evaluate models
print("\nClassification Report for Standard:")
print(classification_report(y_test, y_pred_standard_label))

print("\nClassification Report for Imbalanced:")
print(classification_report(y_test, y_pred_focal_label))
42 changes: 42 additions & 0 deletions examples/multiclass_sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

import imlightgbm as imlgb

# Generate dataset
X, y = make_classification(
n_samples=5000,
n_features=10,
n_classes=3,
n_informative=5,
weights=[0.05, 0.15, 0.8],
flip_y=0,
random_state=42,
)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)

# Initialize the ImbalancedLGBMClassifier using binary focal loss
clf = imlgb.ImbalancedLGBMClassifier(
objective="multiclass_focal", # multiclass_weighted
gamma=2.0, # alpha with multiclass_weighted
num_class=3,
learning_rate=0.05,
num_leaves=31,
)

# Train the classifier on the training data
clf.fit(X=X_train, y=y_train)

# Make predictions on the test data
y_pred_focal = clf.predict(X_test)


# Evaluate the model performance using accuracy, log loss, and ROC AUC
# Evaluate models
print("\nClassification Report:")
print(classification_report(y_test, y_pred_focal))

0 comments on commit f439bac

Please sign in to comment.