Skip to content

Commit 43f047a

Browse files
fix(l2g_predictions): annotate based on list of features + filter out missing annotation (#925)
* fix(prediction): do not annotate all features from matrix * fix(prediction): filter out features with 0 * chore: pre-commit auto fixes [...]
1 parent a02f9c1 commit 43f047a

File tree

2 files changed

+15
-24
lines changed

2 files changed

+15
-24
lines changed

src/gentropy/dataset/l2g_prediction.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,13 @@ def to_disease_target_evidence(
129129
)
130130

131131
def add_locus_to_gene_features(
132-
self: L2GPrediction, feature_matrix: L2GFeatureMatrix
132+
self: L2GPrediction, feature_matrix: L2GFeatureMatrix, features_list: list[str]
133133
) -> L2GPrediction:
134-
"""Add features to the L2G predictions.
134+
"""Add features used to extract the L2G predictions.
135135
136136
Args:
137137
feature_matrix (L2GFeatureMatrix): Feature matrix dataset
138+
features_list (list[str]): List of features used in the model
138139
139140
Returns:
140141
L2GPrediction: L2G predictions with additional features
@@ -143,38 +144,26 @@ def add_locus_to_gene_features(
143144
if "locusToGeneFeatures" in self.df.columns:
144145
self.df = self.df.drop("locusToGeneFeatures")
145146

146-
# Columns identifying a studyLocus/gene pair
147-
prediction_id_columns = ["studyLocusId", "geneId"]
148-
149-
# L2G matrix columns to build the map:
150-
columns_to_map = [
151-
column
152-
for column in feature_matrix._df.columns
153-
if column not in prediction_id_columns
154-
]
155-
156147
# Aggregating all features into a single map column:
157148
aggregated_features = (
158149
feature_matrix._df.withColumn(
159150
"locusToGeneFeatures",
160151
f.create_map(
161152
*sum(
162-
[
163-
(f.lit(colname), f.col(colname))
164-
for colname in columns_to_map
165-
],
153+
((f.lit(feature), f.col(feature)) for feature in features_list),
166154
(),
167155
)
168156
),
169157
)
170-
# from the freshly created map, we filter out the null values
171158
.withColumn(
172159
"locusToGeneFeatures",
173-
f.expr("map_filter(locusToGeneFeatures, (k, v) -> v is not null)"),
160+
f.expr("map_filter(locusToGeneFeatures, (k, v) -> v != 0)"),
174161
)
175-
.drop(*columns_to_map)
162+
.drop(*features_list)
176163
)
177164
return L2GPrediction(
178-
_df=self.df.join(aggregated_features, on=prediction_id_columns, how="left"),
165+
_df=self.df.join(
166+
aggregated_features, on=["studyLocusId", "geneId"], how="left"
167+
),
179168
_schema=self.get_schema(),
180169
)

src/gentropy/l2g.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pyspark.sql.functions as f
99
from sklearn.ensemble import GradientBoostingClassifier
10-
from wandb import login as wandb_login
10+
from wandb.sdk.wandb_login import login as wandb_login
1111

1212
from gentropy.common.schemas import compare_struct_schemas
1313
from gentropy.common.session import Session
@@ -285,9 +285,11 @@ def run_predict(self) -> None:
285285
)
286286
predictions.filter(
287287
f.col("score") >= self.l2g_threshold
288-
).add_locus_to_gene_features(self.feature_matrix).df.coalesce(
289-
self.session.output_partitions
290-
).write.mode(self.session.write_mode).parquet(self.predictions_path)
288+
).add_locus_to_gene_features(
289+
self.feature_matrix, self.features_list
290+
).df.coalesce(self.session.output_partitions).write.mode(
291+
self.session.write_mode
292+
).parquet(self.predictions_path)
291293
self.session.logger.info("L2G predictions saved successfully.")
292294

293295
def run_train(self) -> None:

0 commit comments

Comments
 (0)