From f439bac315ef0bc2efa5b744c4c4d7f35e9a6945 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sat, 28 Sep 2024 09:14:14 +0900 Subject: [PATCH] update example multiclass --- examples/multiclass_engine.py | 1 + examples/multiclass_sklearn.py | 42 ++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 examples/multiclass_sklearn.py diff --git a/examples/multiclass_engine.py b/examples/multiclass_engine.py index 2c90601..02c74b3 100644 --- a/examples/multiclass_engine.py +++ b/examples/multiclass_engine.py @@ -77,5 +77,6 @@ # Evaluate models print("\nClassification Report for Standard:") print(classification_report(y_test, y_pred_standard_label)) + print("\nClassification Report for Imbalanced:") print(classification_report(y_test, y_pred_focal_label)) diff --git a/examples/multiclass_sklearn.py b/examples/multiclass_sklearn.py new file mode 100644 index 0000000..9a37ea9 --- /dev/null +++ b/examples/multiclass_sklearn.py @@ -0,0 +1,42 @@ +from sklearn.datasets import make_classification +from sklearn.metrics import classification_report +from sklearn.model_selection import train_test_split + +import imlightgbm as imlgb + +# Generate dataset +X, y = make_classification( + n_samples=5000, + n_features=10, + n_classes=3, + n_informative=5, + weights=[0.05, 0.15, 0.8], + flip_y=0, + random_state=42, +) + +# Split the data into training and testing sets +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 +) + +# Initialize the ImbalancedLGBMClassifier using binary focal loss +clf = imlgb.ImbalancedLGBMClassifier( + objective="multiclass_focal", # multiclass_weighted + gamma=2.0, # alpha with multiclass_weighted + num_class=3, + learning_rate=0.05, + num_leaves=31, +) + +# Train the classifier on the training data +clf.fit(X=X_train, y=y_train) + +# Make predictions on the test data +y_pred_focal = clf.predict(X_test) + + +# Evaluate the model performance using accuracy, log loss, and ROC AUC +# Evaluate models +print("\nClassification Report:") +print(classification_report(y_test, y_pred_focal))