@@ -76,8 +76,14 @@ def mock_colocalisation(spark: SparkSession) -> Colocalisation:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(coloc_schema)
-        .withColumnSpec("leftStudyLocusId", minValue=1, maxValue=400)
-        .withColumnSpec("rightStudyLocusId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "leftStudyLocusId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "rightStudyLocusId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("h0", percentNulls=0.1)
         .withColumnSpec("h1", percentNulls=0.1)
         .withColumnSpec("h2", percentNulls=0.1)
@@ -105,7 +111,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(si_schema)
-        .withColumnSpec("studyId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "studyId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec(
             "traitFromSourceMappedIds",
             expr="array(cast(rand() AS string))",
@@ -126,7 +135,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
             expr='array(named_struct("sampleSize", cast(rand() as string), "ancestry", cast(rand() as string)))',
             percentNulls=0.1,
         )
-        .withColumnSpec("geneId", minValue=1, maxValue=400, percentNulls=0.1)
+        .withColumnSpec(
+            "geneId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("pubmedId", percentNulls=0.1)
         .withColumnSpec("publicationFirstAuthor", percentNulls=0.1)
         .withColumnSpec("publicationDate", percentNulls=0.1)
@@ -175,18 +187,15 @@ def mock_study_locus_overlap(spark: SparkSession) -> StudyLocusOverlap:
         .withSchema(overlap_schema)
         .withColumnSpec(
             "leftStudyLocusId",
-            minValue=1,
-            maxValue=400,
+            expr="cast(id as string)",
         )
         .withColumnSpec(
             "rightStudyLocusId",
-            minValue=1,
-            maxValue=400,
+            expr="cast(id as string)",
         )
         .withColumnSpec(
             "tagVariantId",
-            minValue=1,
-            maxValue=400,
+            expr="cast(id as string)",
         )
         .withColumnSpec(
             "rightStudyType", percentNulls=0.0, values=StudyIndex.VALID_TYPES
@@ -211,7 +220,10 @@ def mock_study_locus_data(spark: SparkSession) -> DataFrame:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(sl_schema)
-        .withColumnSpec("variantId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("chromosome", percentNulls=0.1)
         .withColumnSpec("position", minValue=100, percentNulls=0.1)
         .withColumnSpec("beta", percentNulls=0.1)
@@ -288,7 +300,10 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(vi_schema)
-        .withColumnSpec("variantId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("mostSevereConsequenceId", percentNulls=0.1)
         # Nested column handling workaround
         # https://github.com/databrickslabs/dbldatagen/issues/135
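The two comment lines above point at why several of these mocks generate whole nested values through SQL: dbldatagen cannot target fields inside a struct or array with `withColumnSpec` (databrickslabs/dbldatagen#135), so the entire nested value is produced by one `expr` on the parent column. A self-contained sketch of that workaround, with hypothetical column names not taken from `vi_schema`:

```python
import dbldatagen as dg
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, StringType, StructField, StructType

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Hypothetical schema with a nested array-of-struct column for illustration.
nested_schema = StructType(
    [
        StructField("variantId", StringType()),
        StructField(
            "scores",
            ArrayType(
                StructType(
                    [
                        StructField("method", StringType()),
                        StructField("score", DoubleType()),
                    ]
                )
            ),
        ),
    ]
)

nested_df = (
    dg.DataGenerator(spark, rows=5, partitions=1)
    .withSchema(nested_schema)
    .withColumnSpec("variantId", expr="cast(id as string)")
    # One SQL expression builds the full nested value; per-field specs are not possible.
    .withColumnSpec("scores", expr="array(named_struct('method', 'm1', 'score', rand()))")
    .build()
)
```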
@@ -382,8 +397,14 @@ def mock_summary_statistics_data(spark: SparkSession) -> DataFrame:
             name="summaryStats",
         )
         .withSchema(ss_schema)
-        .withColumnSpec("studyId", minValue=1, maxValue=400)
-        .withColumnSpec("variantId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "studyId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
         # Allowing missingness in effect allele frequency and enforce upper limit:
         .withColumnSpec(
             "effectAlleleFrequencyFromSource", percentNulls=0.1, maxValue=1.0
@@ -418,7 +439,10 @@ def mock_ld_index(spark: SparkSession) -> LDIndex:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(ld_schema)
-        .withColumn("variantId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec(
             "ldSet",
             expr="array(named_struct('tagVariantId', cast(floor(rand() * 400) + 1 as string), 'rValues', array(named_struct('population', cast(rand() as string), 'r', cast(rand() as double)))))",
@@ -555,12 +579,15 @@ def mock_gene_index(spark: SparkSession) -> GeneIndex:
     data_spec = (
         dg.DataGenerator(
             spark,
-            rows=400,
+            rows=30,
             partitions=4,
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(gi_schema)
-        .withColumnSpec("geneId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "geneId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("approvedSymbol", percentNulls=0.1)
         .withColumnSpec(
             "biotype", percentNulls=0.1, values=["protein_coding", "lncRNA"]
@@ -570,7 +597,7 @@ def mock_gene_index(spark: SparkSession) -> GeneIndex:
         .withColumnSpec("start", percentNulls=0.1)
         .withColumnSpec("end", percentNulls=0.1)
         .withColumnSpec("strand", percentNulls=0.1, values=[1, -1])
-    )
+    ).build()
 
-    return GeneIndex(_df=data_spec.build(), _schema=gi_schema)
+    return GeneIndex(_df=data_spec, _schema=gi_schema)
 
@@ -591,7 +618,10 @@ def mock_biosample_index(spark: SparkSession) -> BiosampleIndex:
             randomSeedMethod="hash_fieldname",
         )
         .withSchema(bi_schema)
-        .withColumnSpec("biosampleId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "biosampleId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec("biosampleName", percentNulls=0.1)
         .withColumnSpec("description", percentNulls=0.1)
         .withColumnSpec("xrefs", expr=array_expression, percentNulls=0.1)
@@ -652,10 +682,21 @@ def mock_l2g_gold_standard(spark: SparkSession) -> L2GGoldStandard:
         )
         .withSchema(schema)
-        .withColumnSpec("studyLocusId", minValue=1, maxValue=400)
-        .withColumnSpec("variantId", minValue=1, maxValue=400)
-        .withColumnSpec("studyId", minValue=1, maxValue=400)
-        .withColumnSpec("geneId", minValue=1, maxValue=400)
-        .withColumnSpec("traitFromSourceMappedId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "studyLocusId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "variantId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "geneId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "traitFromSourceMappedId",
+            expr="cast(id as string)",
+        )
         .withColumnSpec(
             "goldStandardSet",
             values=[
@@ -677,8 +718,14 @@ def mock_l2g_predictions(spark: SparkSession) -> L2GPrediction:
             spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
         )
         .withSchema(schema)
-        .withColumnSpec("studyId", minValue=1, maxValue=400)
-        .withColumnSpec("geneId", minValue=1, maxValue=400)
+        .withColumnSpec(
+            "studyId",
+            expr="cast(id as string)",
+        )
+        .withColumnSpec(
+            "geneId",
+            expr="cast(id as string)",
+        )
     )
 
     return L2GPrediction(_df=data_spec.build(), _schema=schema)
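A hypothetical consumer of these fixtures (not part of this commit), assuming the import path `gentropy.dataset.l2g_prediction` and that the Dataset wrapper exposes the underlying DataFrame as `.df`; it shows what the string-ID change buys, namely dtype assertions and string joins that line up with the schemas:

```python
from gentropy.dataset.l2g_prediction import L2GPrediction


def test_l2g_prediction_ids_are_strings(mock_l2g_predictions: L2GPrediction) -> None:
    # Both ID columns should now materialise as strings, matching the schema.
    dtypes = dict(mock_l2g_predictions.df.dtypes)
    assert dtypes["studyId"] == "string"
    assert dtypes["geneId"] == "string"
```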