From 357add295d7cc39b138deb7f309422d58b97cd13 Mon Sep 17 00:00:00 2001 From: Omer Weissbrod Date: Thu, 30 May 2024 00:41:30 +0300 Subject: [PATCH] update polypred to take a weighted sum of PRS instead of a weighted sum of betas, which loses accuracy for some reason we don't understand involving plink --- polypred.py | 275 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 172 insertions(+), 103 deletions(-) diff --git a/polypred.py b/polypred.py index a80837b..cfe680f 100644 --- a/polypred.py +++ b/polypred.py @@ -105,6 +105,8 @@ def compute_prs_for_file(args, plink_cmd += ' --bfile %s --score %s sum'%(plink_file_prefix, betas_file) else: raise ValueError('neither --bed nor --pgen specified') + if args.center: + plink_cmd += ' center' if ranges_file is not None: scores_file = os.path.join(temp_dir, next(tempfile._get_candidate_names())) df_betas[['SNP_bim', 'score']].drop_duplicates('SNP_bim').to_csv(scores_file, sep='\t', header=False, index=False) @@ -175,6 +177,7 @@ def load_betas_files(betas_file, verbose=True): #rename columns if needed df_betas.rename(columns={'sid':'SNP', 'nt1':'A1', 'nt2':'A2', 'BETA_MEAN':'BETA', 'ldpred_inf_beta':'BETA', 'chrom':'CHR', 'Chrom':'CHR', 'pos':'BP'}, inplace=True, errors='ignore') + if not is_numeric_dtype(df_betas['CHR']): if df_betas['CHR'].str.startswith('chrom_').all(): df_betas['CHR'] = df_betas['CHR'].str[6:].astype(np.int64) @@ -263,124 +266,191 @@ def computs_prs_all_files(args, betas_file, disable_jackknife=False, keep_file=N -def estimate_mixing_weights(args): +def compute_prs(args): + + #if we need to perform predictions, make sure the mixweights file is found + if args.predict and args.betas.count(',') > 0: + mixweights_file = args.mixweights_prefix +'.mixweights' + if not os.path.exists(mixweights_file): + raise ValueError('mixweights file %s not found'%(mixweights_file)) - #read phenotypes - df_pheno = pd.read_csv(args.pheno, names=['FID', 'IID', 'PHENO'], index_col='IID', delim_whitespace=True) - - #make sure that we didn't include a header line - try: - float(df_pheno['PHENO'].iloc[0]) - except: - df_pheno = df_pheno.iloc[1:] - df_pheno['PHENO'] = df_pheno['PHENO'].astype(np.float64) - if np.any(df_pheno.index.duplicated()): - raise ValueError('duplicate ids found in %s'%(args.pheno)) #compute a PRS for each beta file beta_files = args.betas.split(',') - df_prs_sum_list = [] + df_prs_list = [] for betas_file in beta_files: - df_prs_sum = computs_prs_all_files(args, betas_file, disable_jackknife=True, keep_file=args.pheno) - df_prs_sum_list.append(df_prs_sum[['SCORESUM']]) - for df_prs_sum in df_prs_sum_list: - assert np.all(df_prs_sum.index == df_prs_sum_list[0].index) - df_prs_sum_all = pd.concat(df_prs_sum_list, axis=1) - - #sync df_pheno and df_prs_sum_all - df_prs_sum_all.index = df_prs_sum_all.index.astype(str) - df_pheno.index = df_pheno.index.astype(str) - index_shared = df_prs_sum_all.index.intersection(df_pheno.index) - assert len(index_shared)>0 - if len(index_shared) < df_prs_sum_all.shape[0]: - df_prs_sum_all = df_prs_sum_all.loc[index_shared] - if df_pheno.shape[0] != df_prs_sum_all.shape[0] or np.any(df_prs_sum_all.index != df_pheno.index): - df_pheno = df_pheno.loc[df_prs_sum_all.index] + df_prs = computs_prs_all_files(args, betas_file, disable_jackknife=not args.predict, keep_file=args.pheno) + df_prs_list.append(df_prs) + for df_prs in df_prs_list: + assert np.all(df_prs.index == df_prs_list[0].index) + df_prs_all = pd.concat(df_prs_list, axis=1) + + + #compute mixing weights if needed + if args.estimate_mixweights: + + #read phenotypes + df_pheno = pd.read_csv(args.pheno, names=['FID', 'IID', 'PHENO'], index_col='IID', delim_whitespace=True) + + #make sure that we didn't include a header line + try: + float(df_pheno['PHENO'].iloc[0]) + except: + df_pheno = df_pheno.iloc[1:] + df_pheno['PHENO'] = df_pheno['PHENO'].astype(np.float64) + if np.any(df_pheno.index.duplicated()): + raise ValueError('duplicate ids found in %s'%(args.pheno)) + + #sync df_pheno and df_prs_all + df_prs_all.index = df_prs_all.index.astype(str) + df_pheno.index = df_pheno.index.astype(str) + index_shared = df_prs_all.index.intersection(df_pheno.index) + assert len(index_shared)>0 + if len(index_shared) < df_prs_all.shape[0]: + df_prs_all = df_prs_all.loc[index_shared] + if df_pheno.shape[0] != df_prs_all.shape[0] or np.any(df_prs_all.index != df_pheno.index): + df_pheno = df_pheno.loc[df_prs_all.index] + + #extract just the SCORESUM columns + df_prs_sum_all = df_prs_all['SCORESUM'].copy() - #flip PRS that are negatively correlated with the phenotype - is_flipped = np.zeros(df_prs_sum_all.shape[1], dtype=bool) - linreg_univariate = LinearRegression() - for c_i in range(df_prs_sum_all.shape[1]): - linreg_univariate.fit(df_prs_sum_all.iloc[:, [c_i]], df_pheno['PHENO']) - is_flipped[c_i] = linreg_univariate.coef_[0] < 0 - df_prs_sum_all.loc[:, is_flipped] *= -1 - - #compute mixing weights - linreg = LinearRegression(positive = not args.allow_neg_mixweights) - linreg.fit(df_prs_sum_all, df_pheno['PHENO']) - mix_weights, intercept = linreg.coef_, linreg.intercept_ - r2_score = metrics.r2_score(df_pheno['PHENO'], linreg.predict(df_prs_sum_all)) - logging.info('In-sample R2: %0.3f'%(r2_score)) - - #create and print df_coef, and save it to disk - df_coef = pd.Series(mix_weights, index=beta_files) - df_coef.loc['intercept'] = intercept - mix_weights_file = args.output_prefix+'.mixweights' - df_coef.to_frame(name='mix_weight').to_csv(mix_weights_file, sep='\t') - logging.info('Writing mixing weights to %s'%(mix_weights_file)) - - #compute weighted betas - df_betas_weighted = None - for is_flipped_beta, betas_file, mix_weight in zip(is_flipped, beta_files, mix_weights): - df_betas = load_betas_files(betas_file) - df_betas = df_betas[['SNP', 'CHR', 'BP', 'A1', 'A2', 'BETA']] - df_betas['BETA'] *= mix_weight - if is_flipped_beta: df_betas['BETA'] = -df_betas['BETA'] - if df_betas_weighted is None: - df_betas_weighted = df_betas - continue + #flip PRS that are negatively correlated with the phenotype + is_flipped = np.zeros(df_prs_sum_all.shape[1], dtype=bool) + linreg_univariate = LinearRegression() + for c_i in range(df_prs_sum_all.shape[1]): + linreg_univariate.fit(df_prs_sum_all.iloc[:, [c_i]], df_pheno['PHENO']) + is_flipped[c_i] = linreg_univariate.coef_[0] < 0 + df_prs_sum_all.loc[:, is_flipped] *= -1 + + #estimate mixing weights + linreg = LinearRegression(positive = not args.allow_neg_mixweights) + linreg.fit(df_prs_sum_all, df_pheno['PHENO']) + mix_weights, intercept = linreg.coef_, linreg.intercept_ + r2_score = metrics.r2_score(df_pheno['PHENO'], linreg.predict(df_prs_sum_all)) + logging.info('In-sample R2: %0.3f'%(r2_score)) - index_shared = df_betas.index.intersection(df_betas_weighted.index) - df_betas['BETA2'] = df_betas['BETA'] - df_new = df_betas_weighted.loc[index_shared].merge(df_betas.loc[index_shared, ['BETA2']], left_index=True, right_index=True) - df_new['BETA'] += df_new['BETA2'] - del df_new['BETA2'] - del df_betas['BETA2'] - df_list = [df_new, df_betas.loc[~df_betas.index.isin(index_shared)], df_betas_weighted.loc[~df_betas_weighted.index.isin(index_shared)]] - df_betas_weighted = pd.concat(df_list, axis=0) - df_betas_weighted.sort_values(['CHR', 'BP', 'A1'], inplace=True) - - #save output to file - df_betas_weighted[['SNP', 'CHR', 'BP', 'A1', 'A2', 'BETA']].to_csv(args.output_prefix+'.betas', sep='\t', index=False, float_format='%0.6e') - logging.info('Saving weighted betas to %s'%(args.output_prefix+'.betas')) + #create and print df_coef, and save it to disk + df_coef = pd.Series(mix_weights, index=beta_files) + df_coef.loc[is_flipped] *= -1 + df_coef.loc['intercept'] = intercept + mix_weights_file = args.output_prefix+'.mixweights' + df_coef.to_frame(name='mix_weight').to_csv(mix_weights_file, sep='\t') + logging.info('Writing mixing weights to %s'%(mix_weights_file)) + + #flip the PRS back + df_prs_sum_all.loc[:, is_flipped] *= -1 - -def compute_prs(args): + #perform predictions + if args.predict: + + #extract just the SCORESUM columns + df_prs_sum_all = df_prs_all['SCORESUM'] + + #just take the PRS if there's only a single beta + if args.betas.count(',') == 0: + assert (df_prs_all.columns=='SCORESUM').sum() == 1 + s_combined_prs = df_prs_sum_all + + #if there's more than one beta, take the linear combination + else: + mixweights_file = args.mixweights_prefix +'.mixweights' + s_mixweights = pd.read_csv(mixweights_file, delim_whitespace=True, squeeze=True) + if np.any(s_mixweights.index[:-1] != args.betas.split(',')): + raise ValueError('The provided betas file do not match the mix weights file') + assert s_mixweights.index[-1] == 'intercept' + s_combined_prs = df_prs_sum_all.dot(s_mixweights.iloc[:-1].values) + s_mixweights.loc['intercept'] + + #save the PRS to disk + df_prs_sum = s_combined_prs.reset_index(drop=False) + df_prs_sum.columns = ['IID', 'PRS'] + df_prs_sum['FID'] = df_prs_sum['IID'] + df_prs_sum = df_prs_sum[['FID', 'IID', 'PRS']] + df_prs_sum.to_csv(args.output_prefix+'.prs', sep='\t', index=False, float_format='%0.5f') + + #handle jackknife + set_jk_columns = set([c for c in df_prs_all.columns if '.jk' in c]) + df_prs_sum_jk = pd.DataFrame(index=df_prs_all.index, columns=set_jk_columns) + if df_prs_sum_jk.shape[1] > 1: + for jk_column in set_jk_columns: + if args.betas.count(',') == 0: + assert (df_prs_all.columns==jk_column).sum() == 1 + df_prs_sum_jk[jk_column] = df_prs_all[jk_column] + else: + #import ipdb; ipdb.set_trace() + df_prs_sum_jk[jk_column] = df_prs_all[jk_column].dot(s_mixweights.iloc[:-1].values) + s_mixweights.loc['intercept'] + + df_prs_sum_jk.reset_index().to_csv(args.output_prefix+'.prs_jk', sep='\t', index=False, float_format='%0.5f') + + logging.info('Saving PRS to %s'%(args.output_prefix+'.prs')) + + + + + # #compute weighted betas + # df_betas_weighted = None + # for is_flipped_beta, betas_file, mix_weight in zip(is_flipped, beta_files, mix_weights): + # df_betas = load_betas_files(betas_file) + # df_betas = df_betas[['SNP', 'CHR', 'BP', 'A1', 'A2', 'BETA']] + # df_betas['BETA'] *= mix_weight + # if is_flipped_beta: df_betas['BETA'] = -df_betas['BETA'] + # if df_betas_weighted is None: + # df_betas_weighted = df_betas + # continue + + # index_shared = df_betas.index.intersection(df_betas_weighted.index) + # df_betas['BETA2'] = df_betas['BETA'] + # df_new = df_betas_weighted.loc[index_shared].merge(df_betas.loc[index_shared, ['BETA2']], left_index=True, right_index=True) + # df_new['BETA'] += df_new['BETA2'] + # del df_new['BETA2'] + # del df_betas['BETA2'] + # df_list = [df_new, df_betas.loc[~df_betas.index.isin(index_shared)], df_betas_weighted.loc[~df_betas_weighted.index.isin(index_shared)]] + # df_betas_weighted = pd.concat(df_list, axis=0) + # df_betas_weighted.sort_values(['CHR', 'BP', 'A1'], inplace=True) + + # #save weighted betas to file + # df_betas_weighted[['SNP', 'CHR', 'BP', 'A1', 'A2', 'BETA']].to_csv(args.output_prefix+'.betas', sep='\t', index=False, float_format='%0.6e') + # logging.info('Saving weighted betas to %s'%(args.output_prefix+'.betas')) - if args.betas.count(',') > 0: - raise ValueError('--predict can only be used with a single betas file') - df_prs_sum = computs_prs_all_files(args, args.betas, disable_jackknife=False, keep_file=args.keep) - df_prs_sum.reset_index(inplace=True, drop=False) - df_prs_sum.columns = df_prs_sum.columns.str.replace('SCORESUM', 'PRS') - df_prs_sum_main = df_prs_sum[['FID', 'IID', 'PRS']] - df_prs_sum_jk = df_prs_sum[['FID', 'IID'] + [c for c in df_prs_sum.columns if c.startswith('PRS.')]] - - df_prs_sum_main.to_csv(args.output_prefix+'.prs', sep='\t', index=False, float_format='%0.5f') - if df_prs_sum_jk.shape[1]>1: - df_prs_sum_jk.to_csv(args.output_prefix+'.prs_jk', sep='\t', index=False, float_format='%0.5f') - logging.info('Saving PRS to %s'%(args.output_prefix+'.prs')) + +# def compute_prs(args): + + # if args.betas.count(',') > 0: + # raise ValueError('--predict can only be used with a single betas file') + # df_prs_sum = computs_prs_all_files(args, args.betas, disable_jackknife=False, keep_file=args.keep) + # df_prs_sum.reset_index(inplace=True, drop=False) + # df_prs_sum.columns = df_prs_sum.columns.str.replace('SCORESUM', 'PRS') + # df_prs_sum_main = df_prs_sum[['FID', 'IID', 'PRS']] + # df_prs_sum_jk = df_prs_sum[['FID', 'IID'] + [c for c in df_prs_sum.columns if c.startswith('PRS.')]] + + # df_prs_sum_main.to_csv(args.output_prefix+'.prs', sep='\t', index=False, float_format='%0.5f') + # if df_prs_sum_jk.shape[1]>1: + # df_prs_sum_jk.to_csv(args.output_prefix+'.prs_jk', sep='\t', index=False, float_format='%0.5f') + # logging.info('Saving PRS to %s'%(args.output_prefix+'.prs')) def check_args(args): - if int(args.predict) + int(args.combine_betas) != 1: - raise ValueError('you must specify either --predict or --combine-betas (but not both)') + if int(args.predict) + int(args.estimate_mixweights) != 1: + raise ValueError('you must specify either --predict or --estimate-mixweights (but not both)') if args.plink_exe is None and args.plink2_exe is None: raise ValueError('you must specify either --plink-exe or --plink2-exe') if args.plink_exe is not None and not os.path.exists(args.plink_exe): raise ValueError('%s not found'%(args.plink_exe)) if args.plink2_exe is not None and not os.path.exists(args.plink2_exe): raise ValueError('%s not found'%(args.plink2_exe)) - if args.combine_betas: + if args.estimate_mixweights: if args.keep is not None: - raise ValueError('you cannot provide both --combine-betas and --keep') + raise ValueError('you cannot provide both --estimate-mixweights and --keep') if args.pheno is None: - raise ValueError('you must provide --pheno if you specify --combine-betas') + raise ValueError('you must provide --pheno if you specify --estimate-mixweights') if args.betas.count(',')==0: - raise ValueError('you must provide multiple files in --betas if you specify --combine-betas') + raise ValueError('you must provide multiple files in --betas if you specify --estimate-mixweights') + if args.predict: + if args.mixweights_prefix is None and args.betas.count(',') > 0: + raise ValueError('you must provide --mixweights-prefix together with --predict if you have more than one beta file') if args.num_jk<0: raise ValueError('--num-jk must be >=0') if args.pheno is not None and args.predict: - raise ValueError('--pheno can only be used with --combine-betas') + raise ValueError('--pheno can only be used with --estimate-mixweights') if len(list(args.files)) == 0: raise ValueError('no input files specified') @@ -391,9 +461,10 @@ def check_args(args): parser = argparse.ArgumentParser() parser.add_argument('--betas', required=True, help='files with SNP effect sizes (comma separated). A1 is the effect allele.') + parser.add_argument('--mixweights-prefix', help='Prefix of files with mixing weights (required if you use --predict with more than one betas file') parser.add_argument('--output-prefix', required=True, help='Prefix of output file') - parser.add_argument('--combine-betas', default=False, action='store_true', help='If specified, PolyPred will estimate mixing weights') + parser.add_argument('--estimate-mixweights', default=False, action='store_true', help='If specified, PolyPred will estimate mixing weights') parser.add_argument('--allow-neg-mixweights', default=False, action='store_true', help='If specified, PolyPred will not enforce non-negative mixing weights') parser.add_argument('--predict', default=False, action='store_true', help='If specified, PolyPred will compute PRS') parser.add_argument('--pheno', default=None, help='Phenotype file (required for estimating mixing weights)') @@ -403,6 +474,7 @@ def check_args(args): parser.add_argument('--extract', default=None, help='A text file with rsids of SNPs to use (one per line)') parser.add_argument('--keep', default=None, help='A text file with ids of individuals to use (two columns per line, each containing FID,IID)') parser.add_argument('--num-jk', type=int, default=200, help='number of genomic jackknife blocks') + parser.add_argument('--center', default=False, action='store_true', help='If specified, the PRS will be centered') parser.add_argument('--memory', type=int, default=2, help='Maximum memory usage (in GB)') parser.add_argument('--threads', type=int, default=1, help='Number of CPU threads') @@ -420,7 +492,9 @@ def check_args(args): #check that the output directory exists if len(os.path.dirname(args.output_prefix))>0 and not os.path.exists(os.path.dirname(args.output_prefix)): - raise ValueError('output directory %s doesn\'t exist'%(os.path.dirname(args.output_prefix))) + raise ValueError('output directory %s doesn\'t exist'%(os.path.dirname(args.output_prefix))) + + #configure logger configure_logger(args.output_prefix) @@ -428,13 +502,8 @@ def check_args(args): #check arguments check_args(args) - #estimate mixing weights if needed - if args.combine_betas: - estimate_mixing_weights(args) - - #compute PRS if needed - if args.predict: - compute_prs(args) + #Estimate mixiwing weights and/or compute PRS + compute_prs(args) print()