diff --git a/src/gentropy/susie_finemapper.py b/src/gentropy/susie_finemapper.py index 24635934e..41e1ddcfe 100644 --- a/src/gentropy/susie_finemapper.py +++ b/src/gentropy/susie_finemapper.py @@ -734,3 +734,124 @@ def susie_finemapper_one_studylocus_row_v2_dev( ) return out + + @staticmethod + def susie_finemapper_one_studylocus_row_v3_dev_ss_gathered( + session: Session, + study_locus_row: Row, + study_index: StudyIndex, + radius: int = 1_000_000, + max_causal_snps: int = 10, + susie_est_tausq: bool = False, + run_carma: bool = False, + run_sumstat_imputation: bool = False, + carma_time_limit: int = 600, + imputed_r2_threshold: float = 0.8, + ld_score_threshold: float = 4, + sum_pips: float = 0.99, + ) -> dict[str, Any]: + """Susie fine-mapper function that uses study-locus row with collected locus, chromosome and position as inputs. + + Args: + session (Session): Spark session + study_locus_row (Row): StudyLocus row with collected locus + study_index (StudyIndex): StudyIndex object + radius (int): Radius in base-pairs of window for fine-mapping + max_causal_snps (int): maximum number of causal variants + susie_est_tausq (bool): estimate tau squared, default is False + run_carma (bool): run CARMA, default is False + run_sumstat_imputation (bool): run summary statistics imputation, default is False + carma_time_limit (int): CARMA time limit, default is 600 seconds + imputed_r2_threshold (float): imputed R2 threshold, default is 0.8 + ld_score_threshold (float): LD score threshold ofr imputation, default is 4 + sum_pips (float): the expected sum of posterior probabilities in the locus, default is 0.99 (99% credible set) + + Returns: + dict[str, Any]: dictionary with study locus, number of GWAS variants, number of LD variants, number of variants after merge, number of outliers, number of imputed variants, number of variants to fine-map + """ + # PLEASE DO NOT REMOVE THIS LINE + pd.DataFrame.iteritems = pd.DataFrame.items + + chromosome = study_locus_row["chromosome"] + position = study_locus_row["position"] + studyId = study_locus_row["studyId"] + + study_index_df = study_index._df + study_index_df = study_index_df.filter(f.col("studyId") == studyId) + major_population = study_index_df.select( + "studyId", + f.array_max(f.col("ldPopulationStructure")) + .getItem("ldPopulation") + .alias("majorPopulation"), + ).collect()[0]["majorPopulation"] + + region = ( + chromosome + + ":" + + str(int(position - radius)) + + "-" + + str(int(position + radius)) + ) + + schema = StudyLocus.get_schema() + gwas_df = session.spark.createDataFrame([study_locus_row], schema=schema) + exploded_df = gwas_df.select(f.explode("locus").alias("locus")) + + result_df = exploded_df.select( + "locus.variantId", "locus.beta", "locus.standardError" + ) + gwas_df = ( + result_df.withColumn("z", f.col("beta") / f.col("standardError")) + .withColumn( + "chromosome", f.split(f.col("variantId"), "_")[0].cast("string") + ) + .withColumn("position", f.split(f.col("variantId"), "_")[1].cast("int")) + .filter(f.col("chromosome") == chromosome) + .filter(f.col("position") >= position - radius) + .filter(f.col("position") <= position + radius) + .filter(f.col("z").isNotNull()) + ) + + ld_index = ( + GnomADLDMatrix() + .get_locus_index( + study_locus_row=study_locus_row, + radius=radius, + major_population=major_population, + ) + .withColumn( + "variantId", + f.concat( + f.lit(chromosome), + f.lit("_"), + f.col("`locus.position`"), + f.lit("_"), + f.col("alleles").getItem(0), + f.lit("_"), + f.col("alleles").getItem(1), + ).cast("string"), + ) + ) + + gnomad_ld = GnomADLDMatrix.get_numpy_matrix( + ld_index, gnomad_ancestry=major_population + ) + + out = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes( + GWAS_df=gwas_df, + ld_index=ld_index, + gnomad_ld=gnomad_ld, + L=max_causal_snps, + session=session, + studyId=studyId, + region=region, + susie_est_tausq=susie_est_tausq, + run_carma=run_carma, + run_sumstat_imputation=run_sumstat_imputation, + carma_time_limit=carma_time_limit, + imputed_r2_threshold=imputed_r2_threshold, + ld_score_threshold=ld_score_threshold, + sum_pips=sum_pips, + ) + + return out