14
14
from pyspark .ml .feature import StringIndexer , VectorAssembler
15
15
from pyspark .ml .tuning import ParamGridBuilder
16
16
from wandb .wandb_run import Run
17
+ from xgboost .spark .core import SparkXGBClassifierModel
17
18
18
19
from otg .dataset .l2g_feature_matrix import L2GFeatureMatrix
19
20
from otg .method .l2g .evaluator import WandbEvaluator
@@ -31,6 +32,7 @@ class LocusToGeneModel:
31
32
estimator : Any = None
32
33
pipeline : Pipeline = Pipeline (stages = [])
33
34
model : PipelineModel | None = None
35
+ wandb_l2g_project_name : str = "otg_l2g"
34
36
35
37
def __post_init__ (self : LocusToGeneModel ) -> None :
36
38
"""Post init that adds the model to the ML pipeline."""
@@ -98,29 +100,39 @@ def features_vector_assembler(features_cols: list[str]) -> VectorAssembler:
98
100
.setOutputCol ("features" )
99
101
)
100
102
101
- @staticmethod
102
103
def log_to_wandb(
    self: LocusToGeneModel,
    results: DataFrame,
    training_data: L2GFeatureMatrix,
    evaluators: list[
        BinaryClassificationEvaluator | MulticlassClassificationEvaluator
    ],
    wandb_run: Run,
) -> None:
    """Log evaluation results, feature importance and training set metadata to W&B.

    Args:
        results (DataFrame): Dataframe containing the predictions
        training_data (L2GFeatureMatrix): Training data used for the model. The table and the number of rows per gold standard set are logged to W&B
        evaluators (list[BinaryClassificationEvaluator | MulticlassClassificationEvaluator]): List of Spark ML evaluators to use for evaluation
        wandb_run (Run): W&B run to log the results to
    """
    # Track evaluation metrics: every Spark ML evaluator is wrapped in a
    # WandbEvaluator bound to the same run and applied to the predictions.
    for evaluator in evaluators:
        wandb_evaluator = WandbEvaluator(
            spark_ml_evaluator=evaluator, wandb_run=wandb_run
        )
        wandb_evaluator.evaluate(results)

    # Track feature importance
    wandb_run.log({"importances": self.get_feature_importance()})

    # Track training set metadata: one counter per gold standard set,
    # keyed as "goldStandard<Set>".
    gs_counts_dict = {
        "goldStandard" + row["goldStandardSet"].capitalize(): row["count"]
        for row in training_data.df.groupBy("goldStandardSet").count().collect()
    }
    wandb_run.log(gs_counts_dict)

    # `wandb.Table(dataframe=...)` expects a pandas DataFrame, so the table is
    # built once and logged as is instead of being wrapped in a second Table.
    training_table = wandb.Table(dataframe=training_data.df.toPandas())
    wandb_run.log({"trainingSet": training_table})
124
136
125
137
@classmethod
126
138
def load_from_disk (
@@ -189,13 +201,15 @@ def evaluate(
189
201
results : DataFrame ,
190
202
hyperparameters : dict [str , Any ],
191
203
wandb_run_name : str | None ,
204
+ training_data : L2GFeatureMatrix | None = None ,
192
205
) -> None :
193
206
"""Perform evaluation of the model predictions for the test set and track the results with W&B.
194
207
195
208
Args:
196
209
results (DataFrame): Dataframe containing the predictions
197
210
hyperparameters (dict[str, Any]): Hyperparameters used for the model
198
211
wandb_run_name (str | None): Descriptive name for the run to be tracked with W&B
212
+ training_data (L2GFeatureMatrix | None): Training data used for the model. If provided, the ratio of positive to negative labels will be logged to W&B
199
213
"""
200
214
binary_evaluator = BinaryClassificationEvaluator (
201
215
rawPredictionCol = "rawPrediction" , labelCol = "label"
@@ -226,20 +240,52 @@ def evaluate(
226
240
multi_evaluator .evaluate (results , {multi_evaluator .metricName : "f1" }),
227
241
)
228
242
229
- if wandb_run_name :
243
+ if wandb_run_name and training_data :
230
244
print ("Logging to W&B..." )
231
245
run = wandb .init (
232
- project = "otg_l2g" , config = hyperparameters , name = wandb_run_name
246
+ project = self .wandb_l2g_project_name ,
247
+ config = hyperparameters ,
248
+ name = wandb_run_name ,
233
249
)
234
250
if isinstance (run , Run ):
235
- LocusToGeneModel .log_to_wandb (
236
- results , binary_evaluator , multi_evaluator , run
251
+ self .log_to_wandb (
252
+ results , training_data , [ binary_evaluator , multi_evaluator ] , run
237
253
)
238
254
run .finish ()
239
255
240
- def plot_importance (self : LocusToGeneModel ) -> None :
241
- """Plot the feature importance of the model."""
242
- # xgb_plot_importance(self) # FIXME: What is the attribute that stores the model?
256
@property
def feature_name_map(self: LocusToGeneModel) -> dict[str, str]:
    """Return a dictionary mapping encoded feature names to the original names.

    The encoded names are "f0", "f1", ... following the order of the input
    columns of the vector assembler found at the second stage of the fitted
    pipeline.

    Returns:
        dict[str, str]: Feature name map of the model

    Raises:
        ValueError: If the model has not been fitted yet or the second stage of the pipeline is not a VectorAssembler
    """
    if not self.model:
        raise ValueError("Model not fitted yet. `fit()` has to be called first.")

    stages = self.model.stages
    # Guard against short pipelines so the failure is a documented ValueError
    # rather than an IndexError.
    assembler = stages[1] if len(stages) > 1 else None
    if not isinstance(assembler, VectorAssembler):
        # Fail loudly instead of implicitly returning None, which would break
        # the declared `dict[str, str]` contract for callers.
        raise ValueError(
            f"Stage of type {type(assembler)} is not a VectorAssembler; feature names cannot be resolved."
        )

    return {
        f"f{i}": feature_name
        for i, feature_name in enumerate(assembler.getInputCols())
    }
271
+
272
def get_feature_importance(self: LocusToGeneModel) -> dict[str, float]:
    """Return dictionary with relative importances of every feature in the model.

    The fitted XGBoost stage reports importances under encoded feature names,
    which are mapped back to their original names through `feature_name_map`.

    Returns:
        dict[str, float]: Dictionary mapping feature names to their importance

    Raises:
        ValueError: If the model has not been fitted yet or is not an XGBoost model
    """
    if not self.model:
        raise ValueError("Model not fitted yet. `fit()` has to be called first.")

    classifier = self.model.stages[-1]
    if not isinstance(classifier, SparkXGBClassifierModel):
        # Report the type of the offending stage; the pipeline model itself is
        # always a PipelineModel and would make the message uninformative.
        raise ValueError(
            f"Model type {type(classifier)} not supported for feature importance."
        )

    importance_map = classifier.get_feature_importances()
    name_map = self.feature_name_map  # resolved once, not once per feature
    return {name_map[k]: v for k, v in importance_map.items()}
243
289
244
290
def fit (
245
291
self : LocusToGeneModel ,
0 commit comments