Skip to content

Commit ba95d04

Browse files
committed
refactor: set ID specification in fixtures with expression to avoid changing nullability status
1 parent 78bda18 commit ba95d04

File tree

1 file changed: tests/gentropy/conftest.py (+73 additions, −25 deletions)

tests/gentropy/conftest.py

Lines changed: 73 additions & 25 deletions
Original file line number · Diff line number · Diff line change
@@ -76,8 +76,14 @@ def mock_colocalisation(spark: SparkSession) -> Colocalisation:
7676
randomSeedMethod="hash_fieldname",
7777
)
7878
.withSchema(coloc_schema)
79-
.withColumnSpec("leftStudyLocusId", minValue=1, maxValue=400)
80-
.withColumnSpec("rightStudyLocusId", minValue=1, maxValue=400)
79+
.withColumnSpec(
80+
"leftStudyLocusId",
81+
expr="cast(id as string)",
82+
)
83+
.withColumnSpec(
84+
"rightStudyLocusId",
85+
expr="cast(id as string)",
86+
)
8187
.withColumnSpec("h0", percentNulls=0.1)
8288
.withColumnSpec("h1", percentNulls=0.1)
8389
.withColumnSpec("h2", percentNulls=0.1)
@@ -105,7 +111,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
105111
randomSeedMethod="hash_fieldname",
106112
)
107113
.withSchema(si_schema)
108-
.withColumnSpec("studyId", minValue=1, maxValue=400)
114+
.withColumnSpec(
115+
"studyId",
116+
expr="cast(id as string)",
117+
)
109118
.withColumnSpec(
110119
"traitFromSourceMappedIds",
111120
expr="array(cast(rand() AS string))",
@@ -126,7 +135,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
126135
expr='array(named_struct("sampleSize", cast(rand() as string), "ancestry", cast(rand() as string)))',
127136
percentNulls=0.1,
128137
)
129-
.withColumnSpec("geneId", minValue=1, maxValue=400, percentNulls=0.1)
138+
.withColumnSpec(
139+
"geneId",
140+
expr="cast(id as string)",
141+
)
130142
.withColumnSpec("pubmedId", percentNulls=0.1)
131143
.withColumnSpec("publicationFirstAuthor", percentNulls=0.1)
132144
.withColumnSpec("publicationDate", percentNulls=0.1)
@@ -175,18 +187,15 @@ def mock_study_locus_overlap(spark: SparkSession) -> StudyLocusOverlap:
175187
.withSchema(overlap_schema)
176188
.withColumnSpec(
177189
"leftStudyLocusId",
178-
minValue=1,
179-
maxValue=400,
190+
expr="cast(id as string)",
180191
)
181192
.withColumnSpec(
182193
"rightStudyLocusId",
183-
minValue=1,
184-
maxValue=400,
194+
expr="cast(id as string)",
185195
)
186196
.withColumnSpec(
187197
"tagVariantId",
188-
minValue=1,
189-
maxValue=400,
198+
expr="cast(id as string)",
190199
)
191200
.withColumnSpec(
192201
"rightStudyType", percentNulls=0.0, values=StudyIndex.VALID_TYPES
@@ -211,7 +220,10 @@ def mock_study_locus_data(spark: SparkSession) -> DataFrame:
211220
randomSeedMethod="hash_fieldname",
212221
)
213222
.withSchema(sl_schema)
214-
.withColumnSpec("variantId", minValue=1, maxValue=400)
223+
.withColumnSpec(
224+
"variantId",
225+
expr="cast(id as string)",
226+
)
215227
.withColumnSpec("chromosome", percentNulls=0.1)
216228
.withColumnSpec("position", minValue=100, percentNulls=0.1)
217229
.withColumnSpec("beta", percentNulls=0.1)
@@ -288,7 +300,10 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
288300
randomSeedMethod="hash_fieldname",
289301
)
290302
.withSchema(vi_schema)
291-
.withColumnSpec("variantId", minValue=1, maxValue=400)
303+
.withColumnSpec(
304+
"variantId",
305+
expr="cast(id as string)",
306+
)
292307
.withColumnSpec("mostSevereConsequenceId", percentNulls=0.1)
293308
# Nested column handling workaround
294309
# https://github.com/databrickslabs/dbldatagen/issues/135
@@ -382,8 +397,14 @@ def mock_summary_statistics_data(spark: SparkSession) -> DataFrame:
382397
name="summaryStats",
383398
)
384399
.withSchema(ss_schema)
385-
.withColumnSpec("studyId", minValue=1, maxValue=400)
386-
.withColumnSpec("variantId", minValue=1, maxValue=400)
400+
.withColumnSpec(
401+
"studyId",
402+
expr="cast(id as string)",
403+
)
404+
.withColumnSpec(
405+
"variantId",
406+
expr="cast(id as string)",
407+
)
387408
# Allowing missingness in effect allele frequency and enforce upper limit:
388409
.withColumnSpec(
389410
"effectAlleleFrequencyFromSource", percentNulls=0.1, maxValue=1.0
@@ -418,7 +439,10 @@ def mock_ld_index(spark: SparkSession) -> LDIndex:
418439
randomSeedMethod="hash_fieldname",
419440
)
420441
.withSchema(ld_schema)
421-
.withColumn("variantId", minValue=1, maxValue=400)
442+
.withColumnSpec(
443+
"variantId",
444+
expr="cast(id as string)",
445+
)
422446
.withColumnSpec(
423447
"ldSet",
424448
expr="array(named_struct('tagVariantId', cast(floor(rand() * 400) + 1 as string), 'rValues', array(named_struct('population', cast(rand() as string), 'r', cast(rand() as double)))))",
@@ -555,12 +579,15 @@ def mock_gene_index(spark: SparkSession) -> GeneIndex:
555579
data_spec = (
556580
dg.DataGenerator(
557581
spark,
558-
rows=400,
582+
rows=30,
559583
partitions=4,
560584
randomSeedMethod="hash_fieldname",
561585
)
562586
.withSchema(gi_schema)
563-
.withColumnSpec("geneId", minValue=1, maxValue=400)
587+
.withColumnSpec(
588+
"geneId",
589+
expr="cast(id as string)",
590+
)
564591
.withColumnSpec("approvedSymbol", percentNulls=0.1)
565592
.withColumnSpec(
566593
"biotype", percentNulls=0.1, values=["protein_coding", "lncRNA"]
@@ -570,7 +597,7 @@ def mock_gene_index(spark: SparkSession) -> GeneIndex:
570597
.withColumnSpec("start", percentNulls=0.1)
571598
.withColumnSpec("end", percentNulls=0.1)
572599
.withColumnSpec("strand", percentNulls=0.1, values=[1, -1])
573-
)
600+
).build()
574601

575602
return GeneIndex(_df=data_spec.build(), _schema=gi_schema)
576603

@@ -591,7 +618,10 @@ def mock_biosample_index(spark: SparkSession) -> BiosampleIndex:
591618
randomSeedMethod="hash_fieldname",
592619
)
593620
.withSchema(bi_schema)
594-
.withColumnSpec("biosampleId", minValue=1, maxValue=400)
621+
.withColumnSpec(
622+
"biosampleId",
623+
expr="cast(id as string)",
624+
)
595625
.withColumnSpec("biosampleName", percentNulls=0.1)
596626
.withColumnSpec("description", percentNulls=0.1)
597627
.withColumnSpec("xrefs", expr=array_expression, percentNulls=0.1)
@@ -652,10 +682,22 @@ def mock_l2g_gold_standard(spark: SparkSession) -> L2GGoldStandard:
652682
)
653683
.withSchema(schema)
654684
.withColumnSpec("studyLocusId", minValue=1, maxValue=400)
655-
.withColumnSpec("variantId", minValue=1, maxValue=400)
656-
.withColumnSpec("studyId", minValue=1, maxValue=400)
657-
.withColumnSpec("geneId", minValue=1, maxValue=400)
658-
.withColumnSpec("traitFromSourceMappedId", minValue=1, maxValue=400)
685+
.withColumnSpec(
686+
"studyLocusId",
687+
expr="cast(id as string)",
688+
)
689+
.withColumnSpec(
690+
"variantId",
691+
expr="cast(id as string)",
692+
)
693+
.withColumnSpec(
694+
"geneId",
695+
expr="cast(id as string)",
696+
)
697+
.withColumnSpec(
698+
"traitFromSourceMappedId",
699+
expr="cast(id as string)",
700+
)
659701
.withColumnSpec(
660702
"goldStandardSet",
661703
values=[
@@ -677,8 +719,14 @@ def mock_l2g_predictions(spark: SparkSession) -> L2GPrediction:
677719
spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
678720
)
679721
.withSchema(schema)
680-
.withColumnSpec("studyId", minValue=1, maxValue=400)
681-
.withColumnSpec("geneId", minValue=1, maxValue=400)
722+
.withColumnSpec(
723+
"studyId",
724+
expr="cast(id as string)",
725+
)
726+
.withColumnSpec(
727+
"geneId",
728+
expr="cast(id as string)",
729+
)
682730
)
683731

684732
return L2GPrediction(_df=data_spec.build(), _schema=schema)

0 commit comments

Comments
 (0)