
Commit 7f94aed

Merge branch 'main' into do_patch
2 parents: 6a112a6 + 243d8f6

2 files changed: +68 -21 lines changed

src/otg/method/l2g/model.py

Lines changed: 67 additions & 21 deletions
@@ -14,6 +14,7 @@
 from pyspark.ml.feature import StringIndexer, VectorAssembler
 from pyspark.ml.tuning import ParamGridBuilder
 from wandb.wandb_run import Run
+from xgboost.spark.core import SparkXGBClassifierModel

 from otg.dataset.l2g_feature_matrix import L2GFeatureMatrix
 from otg.method.l2g.evaluator import WandbEvaluator
@@ -31,6 +32,7 @@ class LocusToGeneModel:
     estimator: Any = None
     pipeline: Pipeline = Pipeline(stages=[])
     model: PipelineModel | None = None
+    wandb_l2g_project_name: str = "otg_l2g"

     def __post_init__(self: LocusToGeneModel) -> None:
         """Post init that adds the model to the ML pipeline."""
@@ -98,29 +100,39 @@ def features_vector_assembler(features_cols: list[str]) -> VectorAssembler:
             .setOutputCol("features")
         )

-    @staticmethod
     def log_to_wandb(
+        self: LocusToGeneModel,
         results: DataFrame,
-        binary_evaluator: BinaryClassificationEvaluator,
-        multi_evaluator: MulticlassClassificationEvaluator,
+        training_data: L2GFeatureMatrix,
+        evaluators: list[
+            BinaryClassificationEvaluator | MulticlassClassificationEvaluator
+        ],
         wandb_run: Run,
     ) -> None:
-        """Perform evaluation of the model by applying it to a test set and tracking the results with W&B.
+        """Log evaluation results and feature importance to W&B.

         Args:
             results (DataFrame): Dataframe containing the predictions
-            binary_evaluator (BinaryClassificationEvaluator): Binary evaluator
-            multi_evaluator (MulticlassClassificationEvaluator): Multiclass evaluator
+            training_data (L2GFeatureMatrix): Training data used for the model. If provided, the table and the number of positive and negative labels will be logged to W&B
+            evaluators (list[BinaryClassificationEvaluator | MulticlassClassificationEvaluator]): List of Spark ML evaluators to use for evaluation
             wandb_run (Run): W&B run to log the results to
         """
-        binary_wandb_evaluator = WandbEvaluator(
-            spark_ml_evaluator=binary_evaluator, wandb_run=wandb_run
-        )
-        binary_wandb_evaluator.evaluate(results)
-        multi_wandb_evaluator = WandbEvaluator(
-            spark_ml_evaluator=multi_evaluator, wandb_run=wandb_run
-        )
-        multi_wandb_evaluator.evaluate(results)
+        ## Track evaluation metrics
+        for evaluator in evaluators:
+            wandb_evaluator = WandbEvaluator(
+                spark_ml_evaluator=evaluator, wandb_run=wandb_run
+            )
+            wandb_evaluator.evaluate(results)
+        ## Track feature importance
+        wandb_run.log({"importances": self.get_feature_importance()})
+        ## Track training set metadata
+        gs_counts_dict = {
+            "goldStandard" + row["goldStandardSet"].capitalize(): row["count"]
+            for row in training_data.df.groupBy("goldStandardSet").count().collect()
+        }
+        wandb_run.log(gs_counts_dict)
+        training_table = wandb.Table(dataframe=training_data.df.toPandas())
+        wandb_run.log({"trainingSet": wandb.Table(dataframe=training_table)})

     @classmethod
     def load_from_disk(
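
For reference, a minimal sketch of how the reworked log_to_wandb could be invoked after this change (not part of the commit; `l2g_model`, `predictions`, `train_matrix` and the evaluator settings are illustrative assumptions):

# Hypothetical caller; assumes `l2g_model` is a fitted LocusToGeneModel,
# `predictions` a DataFrame with rawPrediction/prediction/label columns,
# and `train_matrix` the L2GFeatureMatrix used for training.
import wandb
from pyspark.ml.evaluation import (
    BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator,
)

run = wandb.init(project="otg_l2g", name="example-run")
l2g_model.log_to_wandb(
    results=predictions,
    training_data=train_matrix,
    evaluators=[
        BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol="label"),
        MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="label"),
    ],
    wandb_run=run,
)
run.finish()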
@@ -189,13 +201,15 @@ def evaluate(
         results: DataFrame,
         hyperparameters: dict[str, Any],
         wandb_run_name: str | None,
+        training_data: L2GFeatureMatrix | None = None,
     ) -> None:
         """Perform evaluation of the model predictions for the test set and track the results with W&B.

         Args:
             results (DataFrame): Dataframe containing the predictions
             hyperparameters (dict[str, Any]): Hyperparameters used for the model
             wandb_run_name (str | None): Descriptive name for the run to be tracked with W&B
+            training_data (L2GFeatureMatrix | None): Training data used for the model. If provided, the ratio of positive to negative labels will be logged to W&B
         """
         binary_evaluator = BinaryClassificationEvaluator(
             rawPredictionCol="rawPrediction", labelCol="label"
@@ -226,20 +240,52 @@ def evaluate(
             multi_evaluator.evaluate(results, {multi_evaluator.metricName: "f1"}),
         )

-        if wandb_run_name:
+        if wandb_run_name and training_data:
             print("Logging to W&B...")
             run = wandb.init(
-                project="otg_l2g", config=hyperparameters, name=wandb_run_name
+                project=self.wandb_l2g_project_name,
+                config=hyperparameters,
+                name=wandb_run_name,
             )
             if isinstance(run, Run):
-                LocusToGeneModel.log_to_wandb(
-                    results, binary_evaluator, multi_evaluator, run
+                self.log_to_wandb(
+                    results, training_data, [binary_evaluator, multi_evaluator], run
                 )
                 run.finish()

-    def plot_importance(self: LocusToGeneModel) -> None:
-        """Plot the feature importance of the model."""
-        # xgb_plot_importance(self)  # FIXME: What is the attribute that stores the model?
+    @property
+    def feature_name_map(self: LocusToGeneModel) -> dict[str, str]:
+        """Return a dictionary mapping encoded feature names to the original names.
+
+        Returns:
+            dict[str, str]: Feature name map of the model
+
+        Raises:
+            ValueError: If the model has not been fitted yet
+        """
+        if not self.model:
+            raise ValueError("Model not fitted yet. `fit()` has to be called first.")
+        elif isinstance(self.model.stages[1], VectorAssembler):
+            feature_names = self.model.stages[1].getInputCols()
+        return {f"f{i}": feature_name for i, feature_name in enumerate(feature_names)}
+
+    def get_feature_importance(self: LocusToGeneModel) -> dict[str, float]:
+        """Return dictionary with relative importances of every feature in the model. Feature names are encoded and have to be mapped back to their original names.
+
+        Returns:
+            dict[str, float]: Dictionary mapping feature names to their importance
+
+        Raises:
+            ValueError: If the model has not been fitted yet or is not an XGBoost model
+        """
+        if not self.model or not isinstance(
+            self.model.stages[-1], SparkXGBClassifierModel
+        ):
+            raise ValueError(
+                f"Model type {type(self.model)} not supported for feature importance."
+            )
+        importance_map = self.model.stages[-1].get_feature_importances()
+        return {self.feature_name_map[k]: v for k, v in importance_map.items()}

     def fit(
         self: LocusToGeneModel,
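
The two new helpers work together: the VectorAssembler stage exposes the input column order, feature_name_map translates XGBoost's internal f0, f1, … feature keys back to those column names, and get_feature_importance returns the importance scores keyed by the original names. A brief usage sketch (illustrative, not part of the commit; assumes a fitted model whose final stage is the SparkXGBClassifierModel):

# Hypothetical usage after calling fit() on an XGBoost-based pipeline.
importances = l2g_model.get_feature_importance()   # e.g. {"featureA": 12.0, ...}
# Sort features by importance for quick inspection; the scores are whatever
# SparkXGBClassifierModel.get_feature_importances() reports.
for feature, score in sorted(importances.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{feature}: {score}")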

src/otg/method/l2g/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def train(
             results=model.predict(test),
             hyperparameters=hyperparams,
             wandb_run_name=wandb_run_name,
+            training_data=train,
         )
         if model_path:
             l2g_model.save(model_path)
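
With both changes in place, W&B logging in evaluate is gated on having both a run name and the training data; otherwise the method only computes the evaluator metrics and creates no W&B run. An illustrative pair of calls from the trainer side (the run name below is a placeholder):

# Logs to W&B: run name and training data are both provided.
l2g_model.evaluate(
    results=model.predict(test),
    hyperparameters=hyperparams,
    wandb_run_name="l2g-example-run",   # hypothetical run name
    training_data=train,
)

# No W&B run is created: the guard requires both a run name and training data.
l2g_model.evaluate(
    results=model.predict(test),
    hyperparameters=hyperparams,
    wandb_run_name=None,
)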
