Skip to content

Commit

Permalink
Add stability of edges as feature selection
Browse files Browse the repository at this point in the history
  • Loading branch information
NilsWinter committed Dec 9, 2024
1 parent 547da24 commit e3d4fa3
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 65 deletions.
95 changes: 47 additions & 48 deletions cpm/cpm_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,73 +10,52 @@
import pandas as pd
from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit, KFold

from cpm.logging import setup_logging
from cpm.models import LinearCPMModel
from cpm.edge_selection import UnivariateEdgeSelection, PThreshold
from cpm.utils import (
score_regression_models, regression_metrics,
train_test_split, vector_to_upper_triangular_matrix
train_test_split, vector_to_upper_triangular_matrix, check_data
)
from cpm.fold import compute_inner_fold
from cpm.models import NetworkDict, ModelDict


def setup_logging(log_file: str = "analysis_log.txt"):
# Console handler: logs all levels (DEBUG and above) to the console
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(SimpleFormatter())

# File handler: logs only INFO level logs to the file
file_handler = logging.FileHandler(log_file, mode='w')
file_handler.setLevel(logging.INFO)
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
file_handler.setFormatter(SimpleFormatter())

# Create a logger and set the base level to DEBUG so both handlers can operate independently
logger = logging.getLogger()
logger.setLevel(logging.DEBUG) # This ensures all messages are passed to handlers
logger.addHandler(console_handler)
logger.addHandler(file_handler)


class SimpleFormatter(logging.Formatter):
def format(self, record):
log_fmt = "%(message)s"
formatter = logging.Formatter(log_fmt)
return formatter.format(record)


class CPMRegression:
"""
This class handles the process of performing CPM Regression with cross-validation and permutation testing.
"""
def __init__(self,
results_directory: str,
cv: Union[BaseCrossValidator, BaseShuffleSplit] = KFold(n_splits=10, shuffle=True, random_state=42),
cv_edge_selection: Union[BaseCrossValidator, BaseShuffleSplit] = None,
inner_cv: Union[BaseCrossValidator, BaseShuffleSplit] = None,
edge_selection: UnivariateEdgeSelection = UnivariateEdgeSelection(
edge_statistic=['pearson'],
edge_selection=[PThreshold(threshold=[0.05], correction=[None])]
),
add_edge_filter: bool = True,
select_stable_edges: bool = True,
stability_threshold: float = 0.8,
impute_missing_values: bool = True,
n_permutations: int = 0,
atlas_labels: str = None):
"""
Initialize the CPMRegression object.
:param results_directory: Directory to save results.
:param cv: Outer cross-validation strategy.
:param cv_edge_selection: Inner cross-validation strategy for edge selection.
:param inner_cv: Inner cross-validation strategy for edge selection.
:param edge_selection: Method for edge selection.
:param add_edge_filter: Whether to add an edge filter.
:param impute_missing_values: Whether to impute missing values.
:param n_permutations: Number of permutations to run for permutation testing.
:param atlas_labels: CSV file containing atlas and regions labels.
"""
self.results_directory = results_directory
self.cv = cv
self.inner_cv = cv_edge_selection
self.inner_cv = inner_cv
self.edge_selection = edge_selection
self.add_edge_filter = add_edge_filter
self.select_stable_edges = select_stable_edges
self.stability_threshold = stability_threshold
self.impute_missing_values = impute_missing_values
self.n_permutations = n_permutations
self.atlas_labels = atlas_labels

Expand All @@ -102,7 +81,9 @@ def _log_analysis_details(self):
self.logger.info(f"Outer CV strategy: {self.cv}")
self.logger.info(f"Inner CV strategy: {self.inner_cv}")
self.logger.info(f"Edge selection method: {self.edge_selection}")
self.logger.info(f"Add Edge Filter: {'Yes' if self.add_edge_filter else 'No'}")
self.logger.info(f"Select stable edges: {'Yes' if self.select_stable_edges else 'No'}")
self.logger.info(f"Stability threshold: {self.stability_threshold}")
self.logger.info(f"Impute Missing Values: {'Yes' if self.impute_missing_values else 'No'}")
self.logger.info(f"Number of Permutations: {self.n_permutations}")
self.logger.info("="*50)

Expand All @@ -122,7 +103,9 @@ def save_configuration(self, config_filename: str):
'cv': self.cv,
'inner_cv': self.inner_cv,
'edge_selection': self.edge_selection,
'add_edge_filter': self.add_edge_filter,
'select_stable_edges': self.select_stable_edges,
'stability_threshold': self.stability_threshold,
'impute_missing_values': self.impute_missing_values,
'n_permutations': self.n_permutations
}
with open(config_path, 'wb') as file:
Expand All @@ -143,7 +126,9 @@ def load_configuration(self, results_directory: str, config_filename: str):
self.cv = loaded_config['cv']
self.inner_cv = loaded_config['inner_cv']
self.edge_selection = loaded_config['edge_selection']
self.add_edge_filter = loaded_config['add_edge_filter']
self.select_stable_edges = loaded_config['select_stable_edges']
self.stability_threshold = loaded_config['stability_threshold']
self.impute_missing_values = loaded_config['impute_missing_values']
self.n_permutations = loaded_config['n_permutations']
self.logger.info(f"Configuration loaded from {config_filename}")
self.logger.info(f"Results directory set to: {self.results_directory}")
Expand All @@ -161,13 +146,19 @@ def estimate(self,
"""
self.logger.info(f"Starting estimation with {self.n_permutations} permutations.")

# check data and convert to numpy
X, y, covariates = check_data(X, y, covariates, impute_missings=self.impute_missing_values)

# check missing data
# ToDo

# Estimate models on actual data
self._estimate(X=X, y=np.squeeze(y), covariates=covariates, perm_run=0)
self._estimate(X=X, y=y, covariates=covariates, perm_run=0)
self.logger.info("=" * 50)

# Estimate models on permuted data
for perm_id in range(1, self.n_permutations + 1):
self._estimate(X=X, y=np.squeeze(y), covariates=covariates, perm_run=perm_id)
self._estimate(X=X, y=y, covariates=covariates, perm_run=perm_id)

self._calculate_permutation_results()
self.logger.info("Estimation completed.")
Expand Down Expand Up @@ -198,22 +189,28 @@ def _estimate(self,
if not perm_run:
self.logger.debug(f"Running fold {outer_fold + 1}/{self.cv.get_n_splits()}")

train_test_data = train_test_split(train, test, X, y, covariates)
X_train, X_test, y_train, y_test, cov_train, cov_test = train_test_data
X_train, X_test, y_train, y_test, cov_train, cov_test = train_test_split(train, test, X, y, covariates)

if self.inner_cv:
best_params = self._run_inner_folds(X_train, y_train, cov_train, outer_fold, perm_run)
best_params, stability_edges = self._run_inner_folds(X_train, y_train, cov_train, outer_fold, perm_run)
if not perm_run:
self.logger.info(f"Best hyperparameters: {best_params}")
else:
if len(self.edge_selection.param_grid) > 1:
raise RuntimeError("Multiple hyperparameter configurations but no inner cv defined. "
"Please provide only one hyperparameter configuration or an inner cv.")
if self.select_stable_edges:
raise RuntimeError("Stable edges can only be selected when using an inner cv.")
best_params = self.edge_selection.param_grid[0]

# Use best parameters to estimate performance on outer fold test set
self.edge_selection.set_params(**best_params)
edges = self.edge_selection.fit_transform(X=X_train, y=y_train, covariates=cov_train)
if self.select_stable_edges:
edges = {'positive': np.where(stability_edges['positive'] > self.stability_threshold)[0],
'negative': np.where(stability_edges['negative'] > self.stability_threshold)[0]}
else:
self.edge_selection.set_params(**best_params)
edges = self.edge_selection.fit_transform(X=X_train, y=y_train, covariates=cov_train)

cv_edges['positive'][outer_fold, edges['positive']] = 1
cv_edges['negative'][outer_fold, edges['negative']] = 1

Expand All @@ -233,7 +230,7 @@ def _estimate(self,
self._calculate_edge_stability(cv_edges, current_results_directory)

if not perm_run:
self.logger.info(agg_results.to_string())
self.logger.info(agg_results.round(4).to_string())
self._save_predictions(cv_predictions, current_results_directory)

def _run_inner_folds(self, X, y, covariates, fold, perm_run):
Expand All @@ -250,10 +247,11 @@ def _run_inner_folds(self, X, y, covariates, fold, perm_run):
fold_dir = os.path.join(self.results_directory, "folds", f'{fold}')
os.makedirs(fold_dir, exist_ok=True)
inner_cv_results = []

stable_edges = []
for param_id, param in enumerate(self.edge_selection.param_grid):
inner_cv_results.append(
compute_inner_fold(X, y, covariates, self.inner_cv, self.edge_selection, param, param_id))
res, edges = compute_inner_fold(X, y, covariates, self.inner_cv, self.edge_selection, param, param_id)
inner_cv_results.append(res)
stable_edges.append(edges)

inner_cv_results = pd.concat(inner_cv_results)
inner_cv_results = self._calculate_model_increments(cv_results=inner_cv_results, metrics=regression_metrics)
Expand All @@ -266,7 +264,8 @@ def _run_inner_folds(self, X, y, covariates, fold, perm_run):

best_params_ids = agg_results['mean_absolute_error'].groupby(['network', 'model'])['mean'].idxmin()
best_params = inner_cv_results.loc[(0, best_params_ids.loc[('both', 'full')][1], 'both', 'full'), 'params']
return best_params
stable_edges_best_param = stable_edges[best_params_ids.loc[('both', 'full')][1]]
return best_params, stable_edges_best_param

def _calculate_final_cv_results(self, cv_results: pd.DataFrame, results_directory: str):
"""
Expand Down
17 changes: 9 additions & 8 deletions cpm/edge_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,11 @@ def __init__(self, k: Union[int, list] = None):
class UnivariateEdgeSelection(BaseEstimator):
def __init__(self,
edge_statistic: Union[str, list] = 'pearson',
edge_selection: list = None):
edge_selection: list = None,
t_test_filter: bool = True):
self.edge_statistic = edge_statistic
self.edge_selection = edge_selection
self.t_test_filter = t_test_filter
self.param_grid = self._generate_config_grid()

def _generate_config_grid(self):
Expand All @@ -178,19 +180,18 @@ def _generate_config_grid(self):
return ParameterGrid(grid_elements)

def fit_transform(self, X, y=None, covariates=None):
t_test_filter = False
if t_test_filter:
if self.t_test_filter:
_, p_values = one_sample_t_test(X, 0)
valid_edges = p_values < 0.05
else:
valid_edges = np.bool(np.ones(X.shape[1]))

#r_edges, p_edges = np.zeros(X.shape[1]), np.ones(X.shape[1])
#r_edges_masked, p_edges_masked = self.compute_edge_statistics(X=X[:, valid_edges], y=y, covariates=covariates)
#r_edges[valid_edges] = r_edges_masked
#p_edges[valid_edges] = p_edges_masked
r_edges, p_edges = np.zeros(X.shape[1]), np.ones(X.shape[1])
r_edges_masked, p_edges_masked = self.compute_edge_statistics(X=X[:, valid_edges], y=y, covariates=covariates)
r_edges[valid_edges] = r_edges_masked
p_edges[valid_edges] = p_edges_masked

r_edges, p_edges = self.compute_edge_statistics(X=X, y=y, covariates=covariates)
#r_edges, p_edges = self.compute_edge_statistics(X=X, y=y, covariates=covariates)

edges = self.edge_selection.select(r=r_edges, p=p_edges)
return edges
Expand Down
17 changes: 13 additions & 4 deletions cpm/fold.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pandas as pd

from cpm.models import LinearCPMModel, NetworkDict, ModelDict
Expand All @@ -6,17 +7,25 @@

def compute_inner_fold(X, y, covariates, cv, edge_selection, param, param_id):
cv_results = pd.DataFrame()
n_folds = cv.get_n_splits()
n_features = X.shape[1]
cv_edges = {'positive': np.zeros((n_folds, n_features)), 'negative': np.zeros((n_folds, n_features))}
edge_selection.set_params(**param)
for fold_id, (nested_train, nested_test) in enumerate(cv.split(X, y)):
(X_train, X_test, y_train,
y_test, cov_train, cov_test) = train_test_split(nested_train, nested_test, X, y, covariates)

res = compute_fold(X_train, X_test, y_train, y_test, cov_train, cov_test, edge_selection, param, param_id, fold_id)
res, edges = compute_fold(X_train, X_test, y_train, y_test, cov_train, cov_test, edge_selection, param, param_id, fold_id)
cv_results = pd.concat([cv_results, pd.DataFrame(res)], ignore_index=True)

cv_edges['positive'][fold_id, edges['positive']] = 1
cv_edges['negative'][fold_id, edges['negative']] = 1
cv_results.set_index(['fold', 'param_id', 'network', 'model'], inplace=True)
cv_results.sort_index(inplace=True)
return cv_results

stability_edges = {'positive': np.sum(cv_edges['positive'], axis=0) / cv_edges['positive'].shape[0],
'negative': np.sum(cv_edges['negative'], axis=0) / cv_edges['negative'].shape[0]}

return cv_results, stability_edges


def compute_fold(X_train, X_test, y_train, y_test, cov_train, cov_test, edge_selection, param, param_id, fold_id):
Expand All @@ -34,4 +43,4 @@ def compute_fold(X_train, X_test, y_train, y_test, cov_train, cov_test, edge_sel
res['param_id'] = param_id
res['params'] = [param]
cv_results = pd.concat([cv_results, pd.DataFrame(res, index=[0])], ignore_index=True)
return cv_results
return cv_results, edges
33 changes: 33 additions & 0 deletions cpm/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging


def setup_logging(log_file: str = "analysis_log.txt"):
# Get the root logger
logger = logging.getLogger()

# Check if handlers already exist and remove them to avoid duplication
if logger.hasHandlers():
logger.handlers.clear()

# Console handler: logs all levels (DEBUG and above) to the console
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(SimpleFormatter())

# File handler: logs only INFO level logs to the file
file_handler = logging.FileHandler(log_file, mode='w')
file_handler.setLevel(logging.INFO)
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
file_handler.setFormatter(SimpleFormatter())

# Create a logger and set the base level to DEBUG so both handlers can operate independently
logger.setLevel(logging.DEBUG) # This ensures all messages are passed to handlers
logger.addHandler(console_handler)
logger.addHandler(file_handler)


class SimpleFormatter(logging.Formatter):
def format(self, record):
log_fmt = "%(message)s"
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
8 changes: 4 additions & 4 deletions cpm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def fit(self, X, y, covariates):
self.models_residuals[network] = LinearRegression().fit(covariates, connectome[network])
residuals[network] = connectome[network] - self.models_residuals[network].predict(covariates)

residuals['both'] = residuals['positive'] - residuals['negative']
connectome['both'] = connectome['positive'] - connectome['negative']
residuals['both'] = np.hstack((residuals['positive'], residuals['negative']))
connectome['both'] = np.hstack((connectome['positive'], connectome['negative']))

for network in NetworkDict().keys():
self.models['connectome'][network] = LinearRegression().fit(connectome[network], y)
Expand All @@ -56,8 +56,8 @@ def predict(self, X, covariates):
connectome[network] = np.sum(X[:, self.edges[network]], axis=1).reshape(-1, 1)
residuals[network] = connectome[network] - self.models_residuals[network].predict(covariates)

residuals['both'] = residuals['positive'] - residuals['negative']
connectome['both'] = connectome['positive'] - connectome['negative']
residuals['both'] = np.hstack((residuals['positive'], residuals['negative']))
connectome['both'] = np.hstack((connectome['positive'], connectome['negative']))

predictions = ModelDict()
for network in ['positive', 'negative', 'both']:
Expand Down
25 changes: 24 additions & 1 deletion cpm/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import numpy as np

from sklearn.metrics import (mean_squared_error, mean_absolute_error, explained_variance_score)
from sklearn.utils import check_X_y
from scipy.stats import pearsonr, spearmanr
from scipy.stats import ConstantInputWarning, NearConstantInputWarning
import matplotlib.pyplot as plt
import warnings

import logging


logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore", category=ConstantInputWarning)
warnings.filterwarnings("ignore", category=NearConstantInputWarning)

Expand Down Expand Up @@ -107,4 +113,21 @@ def get_colors_from_colormap(n_colors, colormap_name='tab10'):
"""
cmap = plt.get_cmap(colormap_name)
colors = [cmap(i / (n_colors - 1)) for i in range(n_colors)]
return colors
return colors


def check_data(X, y, covariates, impute_missings: bool = False):
logger.info("Checking data...")
if impute_missings:
try:
X, y = check_X_y(X, y, force_all_finite='allow-nan', allow_nd=True, y_numeric=True)
except ValueError as e:
logger.info("y contains NaN values. Only missing values in X and covariates can be imputed.")
raise e
else:
try:
X, y = check_X_y(X, y, force_all_finite=True, allow_nd=True, y_numeric=True)
except ValueError as e:
logger.info("Your input contains NaN values. Fix NaNs or use impute_missing_values=True.")
raise e
return X, y, covariates
Loading

0 comments on commit e3d4fa3

Please sign in to comment.