Merge pull request #455 from MannLabs/fix-update-of-first-classifier
Fix updating of the `TwoStepClassifier.first_classifier` and add logging
anna-charlotte authored Jan 30, 2025
2 parents b0af1aa + fd23254 commit 56a6d42
Showing 1 changed file with 85 additions and 50 deletions.
135 changes: 85 additions & 50 deletions alphadia/fdrx/models/two_step_classifier.py
@@ -1,5 +1,6 @@
"""Implements the Two Step Classifier for use within the Alphadia framework."""

+import copy
import logging

import numpy as np
@@ -53,6 +54,12 @@ def __init__( # noqa: PLR0913 Too many arguments in function definition (> 5)
self._max_iterations = max_iterations
self._train_on_top_n = train_on_top_n

+        logger.info(
+            f"Initialized TwoStepClassifier with "
+            f"first_classifier: {first_classifier.__class__.__name__}, "
+            f"second_classifier: {second_classifier.__class__.__name__}"
+        )

def fit_predict(
self,
df: pd.DataFrame,
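For context, constructing the classifier might look like the sketch below. The classifier objects and all keyword names except max_iterations and train_on_top_n (both visible in the assignments above) are illustrative assumptions, not taken from this diff.

    # Hypothetical usage sketch -- classifier objects and most parameter
    # names are assumptions; any objects exposing fit / predict_proba in
    # the style of alphadia's Classifier interface should fit here.
    two_step = TwoStepClassifier(
        first_classifier=linear_clf,   # assumption: a fast, simple model
        second_classifier=neural_clf,  # assumption: a slower, stronger model
        max_iterations=5,
        train_on_top_n=1,
    )
    # On construction, __init__ now logs both classifier class names at INFO level.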
@@ -83,52 +90,90 @@ def fit_predict(
DataFrame containing predictions and q-values
"""
logger.info("=== Starting training of TwoStepClassifier ===")

df = self._preprocess_data(df, x_cols)
best_result = None
-        best_precursor_count = -1

+        # tracking precursors identified at fdr cutoffs `self.first_fdr_cutoff` and `self.second_fdr_cutoff`
+        previous_target_count_after_first_clf = -1
+        previous_target_count_after_second_clf = -1

for i in range(self._max_iterations):
logger.info(f"Starting iteration {i + 1} / {self._max_iterations}.")

# extract preselction using first classifier if it is fitted
if self.first_classifier.fitted and i > 0:
-                df_train, df_predict = self._apply_filtering_with_first_classifier(
+                df_train = self._apply_filtering_with_first_classifier(
df, x_cols, group_columns
)
+                df_predict = df_train  # using the same df for training and predicting, unlike in the following else block
+                logger.info(
+                    f"Application of first classifier at fdr={self.first_fdr_cutoff} results in "
+                    f"{len(df_train):,} samples ({get_target_count(df_train):,} precursors)"
+                )
+
+                previous_target_count_after_first_clf = get_target_count(df_train)
self.second_classifier.epochs = 50
else:
logger.debug("First classifier not fitted yet. Proceeding without it.")
df_train = df[df["rank"] < self._train_on_top_n]
df_predict = df

self.second_classifier.epochs = 10

-            predictions = self._train_and_apply_second_classifier(
+            # train and apply second classifier
+            df_after_second_clf = self._train_and_apply_second_classifier(
df_train, df_predict, x_cols, y_col, group_columns
)

# Filter results and check for improvement
-            df_filtered = filter_by_qval(predictions, self.second_fdr_cutoff)
-            current_target_count = len(df_filtered[df_filtered["decoy"] == 0])
+            df_filtered = filter_by_qval(df_after_second_clf, self.second_fdr_cutoff)
+            current_target_count = get_target_count(df_filtered)

-            if current_target_count < best_precursor_count:
+            if current_target_count < previous_target_count_after_second_clf:
logger.info(
f"Stopping training after iteration {i}, "
f"due to decreased target count ({current_target_count} < {best_precursor_count})"
f"Training stopped on iteration {i + 1}. Decrease in precursor count from "
f"{previous_target_count_after_second_clf:,} to {current_target_count:,}."
)
return best_result

-            best_precursor_count = current_target_count
-            best_result = predictions
+            previous_target_count_after_second_clf = current_target_count
+            best_result = df_after_second_clf  # TODO: Remove (to save memory) if multiple iterations are dropped.

+            logger.info(
+                f"Application of second classifier at fdr={self.second_fdr_cutoff} results in "
+                f"{current_target_count:,} precursors."
+            )
+
-            # Update first classifier if enough confident predictions
+            # update first classifier if enough confident predictions
if current_target_count > self._min_precursors_for_update:
-                self._update_first_classifier(
-                    df_filtered, df, x_cols, y_col, group_columns
+                target_count_after_first_clf, new_classifier = (
+                    self._fit_and_eval_first_classifier(
+                        df_filtered, df, x_cols, y_col, group_columns
+                    )
)
+                if target_count_after_first_clf > previous_target_count_after_first_clf:
+                    logger.debug(
+                        f"Update of first classifier initiated: previous version had {previous_target_count_after_first_clf:,} "
+                        f"precursors, current version has {target_count_after_first_clf:,} precursors."
+                    )
+                    self.first_classifier = new_classifier
+                    previous_target_count_after_first_clf = target_count_after_first_clf
+
+                else:
+                    logger.debug(
+                        f"Update of first classifier skipped: previous version had {previous_target_count_after_first_clf:,} "
+                        f"precursors, current version has {target_count_after_first_clf:,} precursors."
+                    )
else:
logger.info(
f"Stopping fitting after {i+1} / {self._max_iterations} iterations due to insufficient detected precursors to update the first classifier."
f"=== Insufficient precursors detected; ending after {i + 1} iterations ==="
)
break
else:
logger.info(
f"Stopping fitting after reaching the maximum number of iterations: {self._max_iterations} / {self._max_iterations}."
f"=== Stopping fitting after reaching the maximum number of iterations: "
f"{self._max_iterations} / {self._max_iterations} ==="
)

return best_result
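Taken together, the reworked loop tracks two counters (targets after the first and after the second classifier), stops early when the precursor count drops, and only promotes a retrained first classifier when it beats its predecessor. A minimal call-site sketch; all column names below are illustrative assumptions, not taken from this diff:

    # Hypothetical call-site sketch -- column names are assumptions.
    result_df = two_step.fit_predict(
        psm_df,                                    # PSM-level DataFrame (assumption)
        x_cols=["score", "rt_error", "mz_error"],  # feature columns (assumption)
        y_col="decoy",                             # label column (assumption)
        group_columns=["precursor_idx"],           # grouping key (assumption)
    )
    # result_df is the best iteration's predictions with q-values attached.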
@@ -140,16 +185,14 @@ def _preprocess_data(self, df: pd.DataFrame, x_cols: list[str]) -> pd.DataFrame:

def _apply_filtering_with_first_classifier(
self, df: pd.DataFrame, x_cols: list[str], group_columns: list[str]
-    ) -> tuple[pd.DataFrame, pd.DataFrame]:
+    ) -> pd.DataFrame:
"""Apply first classifier to filter data for the training of the second classifier."""
df["proba"] = self.first_classifier.predict_proba(df[x_cols].to_numpy())[:, 1]

-        filtered_df = compute_and_filter_q_values(
+        return compute_and_filter_q_values(
df, self.first_fdr_cutoff, group_columns, remove_decoys=False
)

-        return filtered_df, filtered_df

def _train_and_apply_second_classifier(
self,
train_df: pd.DataFrame,
@@ -169,50 +212,37 @@ def _train_and_apply_second_classifier(

return compute_q_values(predict_df, group_columns)

-    def _update_first_classifier(
+    def _fit_and_eval_first_classifier(
self,
subset_df: pd.DataFrame,
full_df: pd.DataFrame,
x_cols: list[str],
y_col: str,
group_columns: list[str],
-    ) -> None:
-        """Update first classifier by finding and using target/decoy pairs.
+    ) -> tuple[int, Classifier]:
+        """Fits a copy of the first classifier on a given subset and applies it to the full dataset.
First extracts the corresponding target/decoy partners from the full dataset
for each entry in the subset, then uses these pairs to retrain the classifier.
+        Returns the number of targets found and the trained classifier.
"""
-        df = get_target_decoy_partners(subset_df, full_df)
+        df_train = get_target_decoy_partners(subset_df, full_df)
+        x_train = df_train[x_cols].to_numpy()
+        y_train = df_train[y_col].to_numpy()

-        x = df[x_cols].to_numpy()
-        y = df[y_col].to_numpy()
-
-        previous_n_precursors = -1
-
-        if self.first_classifier.fitted:
-            df["proba"] = self.first_classifier.predict_proba(x)[:, 1]
-            df_targets = compute_and_filter_q_values(
-                df, self.first_fdr_cutoff, group_columns
-            )
-            previous_n_precursors = len(df_targets)
-            previous_state_dict = self.first_classifier.to_state_dict()
+        x_all = full_df[x_cols].to_numpy()
+        reduced_df = full_df[[*group_columns, "decoy"]]

-        self.first_classifier.fit(x, y)
+        logger.info(f"Fitting first classifier on {len(df_train):,} samples.")
+        new_classifier = copy.deepcopy(self.first_classifier)
+        new_classifier.fit(x_train, y_train)

df["proba"] = self.first_classifier.predict_proba(x)[:, 1]
logger.info(f"Applying first classifier to {len(x_all):,} samples.")
reduced_df["proba"] = new_classifier.predict_proba(x_all)[:, 1]
df_targets = compute_and_filter_q_values(
-            df, self.first_fdr_cutoff, group_columns
+            reduced_df, self.first_fdr_cutoff, group_columns
)
-        current_n_precursors = len(df_targets)
+        n_targets = get_target_count(df_targets)

-        if previous_n_precursors > current_n_precursors:
-            logger.info(
-                f"Reverted the first classifier back to the previous version "
-                f"(prev: {previous_n_precursors}, curr: {current_n_precursors})"
-            )
-            self.first_classifier.from_state_dict(previous_state_dict)
-        else:
-            logger.info("Fitted the second classifier")
+        return n_targets, new_classifier

@property
def fitted(self) -> bool:
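The fix in _fit_and_eval_first_classifier above hinges on a copy-then-swap pattern: the candidate is trained on a deepcopy, scored on the full dataset, and only assigned back to self.first_classifier in fit_predict when it identifies more targets. A self-contained sketch of the same idea; the names here are illustrative, not from this diff:

    import copy

    def maybe_promote(current, train_fn, score_fn):
        # Train a candidate on a copy so the live model survives a regression.
        candidate = copy.deepcopy(current)
        train_fn(candidate)
        # Promote only on strict improvement; otherwise keep the current model.
        return candidate if score_fn(candidate) > score_fn(current) else current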
@@ -252,6 +282,11 @@ def from_state_dict(self, state_dict: dict) -> None:
self._train_on_top_n = state_dict["train_on_top_n"]


+def get_target_count(df: pd.DataFrame) -> int:
+    """Counts the number of target (non-decoy) entries in a DataFrame."""
+    return len(df[df["decoy"] == 0])


def compute_q_values(
df: pd.DataFrame, group_columns: list[str] | None = None
) -> pd.DataFrame:
