forked from aertslab/arboreto
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
508 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,3 +99,4 @@ ENV/ | |
|
||
# mypy | ||
.mypy_cache/ | ||
canopy_walk_whitsundays002 copy.jpg |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,12 @@ | ||
# arboretum | ||
A scalable framework for gene regulatory network inference using tree-based ensemble regressors. | ||
 | ||
|
||
ARBORETUM: a scalable framework for gene regulatory network inference using tree-based ensemble regressors. | ||
|
||
## Introduction | ||
|
||
## Getting Started | ||
|
||
## License | ||
|
||
## References | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
""" | ||
Core functional building blocks, composed in a Dask graph for distributed computation. | ||
""" | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor, ExtraTreesRegressor | ||
from dask import delayed | ||
from dask.dataframe import from_delayed | ||
from dask.dataframe.utils import make_meta | ||
|
||
# Default random seed used by the functions in this module that accept a `seed` argument.
THE_DEMON_SEED = 666

# Maps an upper-case regressor type code to the scikit-learn class that implements it.
SKL_REGRESSOR_FACTORY = dict(
    RF=RandomForestRegressor,
    ET=ExtraTreesRegressor,
    GBM=GradientBoostingRegressor,
)
|
||
# Keyword-argument presets, keyed by upper-case regressor type code.
DEFAULT_KWARGS = {
    'RF': dict(n_jobs=1, n_estimators=1000, max_features='sqrt'),
    'ET': dict(n_jobs=1, n_estimators=1000, max_features='sqrt'),
    'GBM': dict(learning_rate=0.001, n_estimators=500, max_features=0.1),
    'XGB': dict(),  # no presets defined
}
|
||
|
||
def __is_skl_regressor(regressor_type):
    """
    Tell whether a regressor type code names one of the scikit-learn regressors.

    :param regressor_type: the regressor type to check. Case insensitive.
    :return: True if the upper-cased type is a key of SKL_REGRESSOR_FACTORY, False otherwise.
    """
    normalized_type = regressor_type.upper()

    return normalized_type in SKL_REGRESSOR_FACTORY
|
||
|
||
def __is_xgboost_regressor(regressor_type):
    """
    Tell whether a regressor type code names the xgboost regressor.

    :param regressor_type: the regressor type to check. Case insensitive.
    :return: True if the upper-cased type equals 'XGB', False otherwise.
    """
    normalized_type = regressor_type.upper()

    return 'XGB' == normalized_type
|
||
|
||
def train_model(regressor_type,
                regressor_kwargs,
                tf_matrix,
                target_gene_expression,
                seed=THE_DEMON_SEED):
    """
    Fit a regressor that predicts the target gene expression from the transcription factor matrix.

    :param regressor_type: one of ['RF', 'ET', 'GBM', 'XGB']. Case insensitive.
    :param regressor_kwargs: a dict of key-value pairs that configures the regressor.
    :param tf_matrix: the predictor matrix (transcription factor matrix) as a numpy array.
    :param target_gene_expression: the target (y) gene expression to predict in function of the tf_matrix (X).
    :param seed: (optional) random seed for the regressors.
    :return: a trained regression model.
    :raises ValueError: if the regressor type is 'XGB' (not yet supported) or unknown.
    """

    # X and y must describe the same number of observations.
    assert tf_matrix.shape[0] == len(target_gene_expression)

    if __is_skl_regressor(regressor_type):
        # The factory keys are upper case. Normalize before the lookup so that
        # lower/mixed case types accepted by the check above do not raise a KeyError.
        regressor_class = SKL_REGRESSOR_FACTORY[regressor_type.upper()]

        regressor = regressor_class(random_state=seed, **regressor_kwargs)
        regressor.fit(tf_matrix, target_gene_expression)

        return regressor

    if __is_xgboost_regressor(regressor_type):
        raise ValueError('XGB regressor not yet supported')  # TODO

    raise ValueError('Unsupported regressor type: {0}'.format(regressor_type))
|
||
|
||
def to_tf_matrix(expression_matrix,
                 gene_names,
                 tf_names):
    """
    Select the transcription factor columns from the expression matrix.

    :param expression_matrix: numpy matrix. Rows are observations and columns are genes.
    :param gene_names: a list of gene names. Each entry corresponds to the expression_matrix column with same index.
    :param tf_names: a list of transcription factor names. Should be a subset of gene_names.
    :return: a numpy matrix representing the predictor matrix for the regressions. Its columns are the
             expression_matrix columns whose gene name occurs in tf_names, in gene_names order.
    """

    assert expression_matrix.shape[1] == len(gene_names)

    # Set membership keeps the scan over gene_names linear.
    tf_name_set = set(tf_names)
    tf_indices = [index for index, gene in enumerate(gene_names) if gene in tf_name_set]

    # Keep all rows, select the TF columns. (The former `[: tf_indices]` was a
    # slice with a list as stop value, which raises a TypeError.)
    return expression_matrix[:, tf_indices]
|
||
|
||
def to_links_df(regressor_type,
                trained_model,
                tf_names,
                target_gene_name):
    """
    Turn the feature importances of a trained model into a DataFrame of regulatory links.

    :param regressor_type: one of ['RF', 'ET', 'GBM', 'XGB']. Case insensitive.
    :param trained_model: the trained model from which to extract the feature importances.
    :param tf_names: the list of names corresponding to the columns of the tf_matrix used to train the model.
    :param target_gene_name: the name of the target gene.
    :return: a Pandas DataFrame with columns 'TF', 'target' and 'importance' representing inferred regulatory
             links and their connection strength. Only links with importance > 0 are kept, sorted by
             descending importance.
    :raises ValueError: if the regressor type is 'XGB' (not yet supported) or unknown.
    """

    if __is_skl_regressor(regressor_type):
        links_df = pd.DataFrame({'TF': tf_names,
                                 'importance': trained_model.feature_importances_})
        links_df['target'] = target_gene_name

        # Drop transcription factors that did not contribute to the model.
        has_importance = links_df.importance > 0

        return links_df[has_importance].sort_values(by='importance', ascending=False)

    if __is_xgboost_regressor(regressor_type):
        raise ValueError('XGB regressor not yet supported')  # TODO

    raise ValueError('Unsupported regressor type: ' + regressor_type)
|
||
|
||
def clean(tf_matrix,
          tf_names,
          target_gene_name):
    """
    Remove the target gene from the predictors, if it is one of the transcription factors.

    :param tf_matrix: numpy array. The full transcription factor matrix.
    :param tf_names: the full list of transcription factor names, corresponding to the tf_matrix columns.
    :param target_gene_name: the target gene to remove from the tf_matrix and tf_names.
    :return: a tuple of (matrix, names) equal to the specified ones minus the target_gene_name if the target
             happens to be one of the transcription factors. If not, the specified tf_matrix is returned as is,
             together with a copy of tf_names.
    """

    if target_gene_name in tf_names:
        target_column = tf_names.index(target_gene_name)
        clean_tf_matrix = np.delete(tf_matrix, target_column, 1)
    else:
        clean_tf_matrix = tf_matrix

    clean_tf_names = [name for name in tf_names if name != target_gene_name]

    assert clean_tf_matrix.shape[1] == len(clean_tf_names)  # sanity check

    return clean_tf_matrix, clean_tf_names
|
||
|
||
def infer_links(regressor_type,
                regressor_kwargs,
                tf_matrix,
                tf_names,
                target_gene_name,
                target_gene_expression,
                seed=THE_DEMON_SEED):
    """
    Top-level function. Removes the target gene from the predictors, trains a model on the remaining
    transcription factors and extracts the regulatory links from it.

    :param regressor_type: one of ['RF', 'ET', 'GBM', 'XGB']. Case insensitive.
    :param regressor_kwargs: dict of key-value pairs that configures the regressor.
    :param tf_matrix: numpy matrix. The feature matrix X to use for the regression.
    :param tf_names: list of transcription factor names corresponding to the columns of the tf_matrix used to
                     train the model.
    :param target_gene_name: the name of the target gene to infer the regulatory links for.
    :param target_gene_expression: the expression profile of the target gene. Numpy array.
    :param seed: (optional) random seed for the regressors.
    :return: a Pandas DataFrame with columns 'TF', 'target' and 'importance' representing inferred regulatory
             links and their connection strength.
    """

    predictors, predictor_names = clean(tf_matrix, tf_names, target_gene_name)

    trained_model = train_model(regressor_type,
                                regressor_kwargs,
                                predictors,
                                target_gene_expression,
                                seed)

    links_df = to_links_df(regressor_type, trained_model, predictor_names, target_gene_name)

    return links_df
|
||
|
||
def __target_gene_indices(gene_names,
                          target_genes):
    """
    Resolve the target_genes specification into column indices of the expression matrix.

    :param gene_names: list of gene names.
    :param target_genes: either int (the top n), 'all' (case insensitive), or a collection
                         (list, tuple, set or frozenset) that is a subset of gene_names.
    :return: the (column) indices of the target genes in the expression_matrix, in ascending order.
    :raises ValueError: if target_genes is none of the supported forms.
    """

    # Check the type before calling str methods: the previous unconditional
    # `target_genes.upper()` raised an AttributeError for int and list inputs.
    if isinstance(target_genes, str):
        if target_genes.upper() == 'ALL':
            return list(range(len(gene_names)))

    # bool is a subclass of int but is not a meaningful "top n".
    elif isinstance(target_genes, int) and not isinstance(target_genes, bool):
        top_n = target_genes  # rename for clarity
        assert top_n > 0
        assert top_n <= len(gene_names)

        return list(range(top_n))

    elif isinstance(target_genes, (list, tuple, set, frozenset)):
        wanted_genes = set(target_genes)

        return [index for index, gene in enumerate(gene_names) if gene in wanted_genes]

    # str.format also copes with non-string values, unlike string concatenation.
    raise ValueError('Unable to interpret target_genes: {0}'.format(target_genes))
|
||
|
||
def create_graph(expression_matrix,
                 gene_names,
                 tf_names,
                 regressor_type,
                 regressor_kwargs,
                 target_genes='all',
                 limit=100000,
                 seed=THE_DEMON_SEED):
    """
    Main API function. Create a Dask computation graph.

    :param expression_matrix: numpy matrix. Rows are observations and columns are genes.
    :param gene_names: list of gene names. Each entry corresponds to the expression_matrix column with same index.
    :param tf_names: list of transcription factor names. Should have a non-empty intersection with gene_names.
    :param regressor_type: one of ['RF', 'ET', 'GBM', 'XGB']. Case insensitive.
    :param regressor_kwargs: dict of key-value pairs that configures the regressor.
    :param target_genes: either int, 'all' or a collection that is a subset of gene_names.
    :param limit: int or None. Default 100k. The number of top regulatory links to return.
    :param seed: (optional) random seed for the regressors.
    :return: a dask computation graph instance: a distributed DataFrame with columns
             'TF', 'target' and 'importance'.
    """

    tf_matrix = to_tf_matrix(expression_matrix, gene_names, tf_names)

    # The tf_matrix columns follow gene_names order and only contain the TFs that occur in gene_names.
    # Derive the column labels the same way, instead of passing tf_names verbatim, so that every
    # importance is attributed to the right transcription factor.
    tf_name_set = set(tf_names)
    tf_matrix_names = [gene for gene in gene_names if gene in tf_name_set]

    delayed_tf_matrix = delayed(tf_matrix)
    delayed_tf_names = delayed(tf_matrix_names)

    delayed_link_dfs = []  # collection of delayed link DataFrames

    for target_gene_index in __target_gene_indices(gene_names, target_genes):
        target_gene_name = gene_names[target_gene_index]
        target_gene_expression = expression_matrix[:, target_gene_index]

        delayed_link_df = delayed(infer_links)(
            regressor_type, regressor_kwargs,
            delayed_tf_matrix, delayed_tf_names,
            target_gene_name, target_gene_expression,
            seed)

        delayed_link_dfs.append(delayed_link_df)

    # Provide the schema of the delayed DataFrames. The column order mirrors the frames built by
    # to_links_df ('TF' and 'importance' first, 'target' appended); the final selection below
    # puts the columns in their documented order.
    link_df_meta = make_meta({'TF': str, 'importance': float, 'target': str})

    # gather the regulatory link DataFrames into one distributed DataFrame
    all_links_df = from_delayed(delayed_link_dfs, meta=link_df_meta)

    # optionally limit the number of resulting regulatory links
    if isinstance(limit, int):
        result = all_links_df.nlargest(limit, columns=['importance'])
    else:
        result = all_links_df

    # Select columns with a list: `result['TF', 'target', 'importance']` would look up a
    # single column keyed by a tuple and raise a KeyError.
    return result[['TF', 'target', 'importance']]
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
""" | ||
Utility functions for reading DREAM5 data. | ||
""" | ||
|
||
import numpy as np | ||
|
||
|
||
def load_expression_matrix(path, delimiter='\t'):
    """
    Read the numeric part of an expression data file, skipping the header line.

    :param path: the path of the dream challenge expression data file.
    :param delimiter: the delimiter used in the file.
    :return: a numpy matrix.
    """

    expression_matrix = np.genfromtxt(path, skip_header=1, delimiter=delimiter)

    return expression_matrix
|
||
|
||
def load_gene_names(path, delimiter='\t'):
    """
    Read the gene names from the first line of an expression data file.

    :param path: the path of the dream challenge expression data file.
    :param delimiter: the delimiter used in the file.
    :return: a list of gene names, stripped of surrounding whitespace.
    """

    with open(path) as expression_file:
        header_line = expression_file.readline()

    return [field.strip() for field in header_line.split(delimiter)]
|
||
|
||
def load_tf_names(path, gene_names):
    """
    Read the transcription factor names from a file, keeping only those present in the expression data.

    :param path: the path of the transcription factor list file. One name per line.
    :param gene_names: the list of gene names in the expression data.
    :return: a list of transcription factor names read from the file, intersected with the gene_names list.
             The order (and any duplicates) of the file is preserved.
    """

    # Set membership avoids a linear scan of gene_names for every line in the file.
    known_genes = set(gene_names)

    with open(path) as tf_file:
        # Iterate the file lazily rather than materializing all lines first.
        tfs_in_file = [line.strip() for line in tf_file]

    return [tf for tf in tfs_in_file if tf in known_genes]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# GENIE3 results on DREAM5, run with dreamtools | ||
|
||
``` | ||
Overall Score 4.027898e+01 | ||
AUPR (pvalue) 4.129532e+01 | ||
AUROC (pvalue) 3.926265e+01 | ||
Net1 AUPR 2.909863e-01 | ||
Net3 AUPR 9.302545e-02 | ||
Net4 AUPR 2.065091e-02 | ||
Net1 AUROC 8.148355e-01 | ||
Net3 AUROC 6.170165e-01 | ||
Net4 AUROC 5.176411e-01 | ||
Net1 p-aupr 1.596591e-104 | ||
Net3 p-aupr 5.149347e-20 | ||
Net4 p-aupr 1.581591e-01 | ||
Net1 p-auroc 3.060302e-106 | ||
Net3 p-auroc 5.003620e-11 | ||
Net4 p-auroc 1.064168e-02 | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# DREAM5 : scikit-learn extra-trees (ET)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.