@@ -129,12 +129,13 @@ def to_disease_target_evidence(
129
129
)
130
130
131
131
def add_locus_to_gene_features (
132
- self : L2GPrediction , feature_matrix : L2GFeatureMatrix
132
+ self : L2GPrediction , feature_matrix : L2GFeatureMatrix , features_list : list [ str ]
133
133
) -> L2GPrediction :
134
- """Add features to the L2G predictions.
134
+ """Add features used to extract the L2G predictions.
135
135
136
136
Args:
137
137
feature_matrix (L2GFeatureMatrix): Feature matrix dataset
138
+ features_list (list[str]): List of features used in the model
138
139
139
140
Returns:
140
141
L2GPrediction: L2G predictions with additional features
@@ -143,38 +144,26 @@ def add_locus_to_gene_features(
143
144
if "locusToGeneFeatures" in self .df .columns :
144
145
self .df = self .df .drop ("locusToGeneFeatures" )
145
146
146
- # Columns identifying a studyLocus/gene pair
147
- prediction_id_columns = ["studyLocusId" , "geneId" ]
148
-
149
- # L2G matrix columns to build the map:
150
- columns_to_map = [
151
- column
152
- for column in feature_matrix ._df .columns
153
- if column not in prediction_id_columns
154
- ]
155
-
156
147
# Aggregating all features into a single map column:
157
148
aggregated_features = (
158
149
feature_matrix ._df .withColumn (
159
150
"locusToGeneFeatures" ,
160
151
f .create_map (
161
152
* sum (
162
- [
163
- (f .lit (colname ), f .col (colname ))
164
- for colname in columns_to_map
165
- ],
153
+ ((f .lit (feature ), f .col (feature )) for feature in features_list ),
166
154
(),
167
155
)
168
156
),
169
157
)
170
- # from the freshly created map, we filter out the null values
171
158
.withColumn (
172
159
"locusToGeneFeatures" ,
173
- f .expr ("map_filter(locusToGeneFeatures, (k, v) -> v is not null )" ),
160
+ f .expr ("map_filter(locusToGeneFeatures, (k, v) -> v != 0 )" ),
174
161
)
175
- .drop (* columns_to_map )
162
+ .drop (* features_list )
176
163
)
177
164
return L2GPrediction (
178
- _df = self .df .join (aggregated_features , on = prediction_id_columns , how = "left" ),
165
+ _df = self .df .join (
166
+ aggregated_features , on = ["studyLocusId" , "geneId" ], how = "left"
167
+ ),
179
168
_schema = self .get_schema (),
180
169
)
0 commit comments