Skip to content

Commit

Permalink
Merge pull request #83 from MannLabs/quant-window
Browse files Browse the repository at this point in the history
FEAT improved fixed center quant, speed improvments
  • Loading branch information
GeorgWa authored Jan 24, 2024
2 parents b3ff48d + fc9963c commit 7a885d4
Show file tree
Hide file tree
Showing 13 changed files with 251 additions and 74 deletions.
20 changes: 11 additions & 9 deletions alphadia/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,20 +217,22 @@ def parse_raw_path_list(args: argparse.Namespace, config: dict) -> list:
raw_path_list += [os.path.join(directory, f) for f in os.listdir(directory)]

# filter files based on regex
#pattern = re.compile()
print('args.regex', args.regex)
# pattern = re.compile()

print("args.regex", args.regex)

for path in raw_path_list:
print('path', path)
print('os.path.basename(path)', os.path.basename(path))
#print('re.search(args.regex, os.path.basename(path))', re.search(args.regex, os.path.basename(path)))
#print('re.search(args.regex, os.path.basename(path)) is not None', re.search(args.regex, os.path.basename(path)) is not None)
#print('')
print("path", path)
print("os.path.basename(path)", os.path.basename(path))
# print('re.search(args.regex, os.path.basename(path))', re.search(args.regex, os.path.basename(path)))
# print('re.search(args.regex, os.path.basename(path)) is not None', re.search(args.regex, os.path.basename(path)) is not None)
# print('')

len_before = len(raw_path_list)
raw_path_list = [
f for f in raw_path_list if re.search(args.regex, os.path.basename(f)) is not None
f
for f in raw_path_list
if re.search(args.regex, os.path.basename(f)) is not None
]
len_after = len(raw_path_list)
print(f"Removed {len_before - len_after} of {len_before} files.")
Expand Down
98 changes: 98 additions & 0 deletions alphadia/data/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,101 @@ def get_dense(
idx += 1

return dense_output, precursor_idx_list

def get_dense_intensity(
self,
frame_limits,
scan_limits,
mz_query_list,
mass_tolerance,
quadrupole_mz,
absolute_masses=False,
custom_cycle=None,
):
"""
Get a dense representation of the data for a given set of parameters.
Parameters
----------
frame_limits : np.ndarray, shape = (1,2,)
array of frame indices
scan_limits : np.ndarray, shape = (1,2,)
array of scan indices
mz_query_list : np.ndarray, shape = (n_tof_slices,)
array of query m/z values
mass_tolerance : float
mass tolerance in ppm
quadrupole_mz : np.ndarray, shape = (1,2,)
array of quadrupole m/z values
absolute_masses : bool, default = False
if True, the first slice of the dense output will contain the absolute m/z values instead of the mass error
custom_cycle : np.ndarray, shape = (1, n_precursor, 1, 2), default = None
custom cycle quadrupole mask, for example after calibration
Returns
-------
np.ndarray, shape = (1, n_tof_slices, n_precursor_indices, 2, n_precursor_cycles)
"""

# (n_tof_slices, 2) array of start, stop mz for each slice
mz_query_slices = utils.mass_range(mz_query_list, mass_tolerance)
n_tof_slices = len(mz_query_slices)

cycle_length = self.cycle.shape[1]

# (n_precursors) array of precursor indices, the precursor index refers to each scan within the cycle
precursor_idx_list = calculate_valid_scans(quadrupole_mz, self.cycle)
n_precursor_indices = len(precursor_idx_list)

precursor_cycle_start = frame_limits[0, 0] // cycle_length
precursor_cycle_stop = frame_limits[0, 1] // cycle_length
precursor_cycle_len = precursor_cycle_stop - precursor_cycle_start

dense_output = np.zeros(
(1, n_tof_slices, n_precursor_indices, 2, precursor_cycle_len),
dtype=np.float32,
)

for i, cycle_idx in enumerate(
range(precursor_cycle_start, precursor_cycle_stop)
):
for j, precursor_idx in enumerate(precursor_idx_list):
scan_idx = precursor_idx + cycle_idx * cycle_length

peak_start_idx = self.peak_start_idx_list[scan_idx]
peak_stop_idx = self.peak_stop_idx_list[scan_idx]

idx = peak_start_idx

for k, (mz_query_start, mz_query_stop) in enumerate(mz_query_slices):
rel_idx = np.searchsorted(
self.mz_values[idx:peak_stop_idx], mz_query_start, "left"
)

idx += rel_idx

while idx < peak_stop_idx and self.mz_values[idx] <= mz_query_stop:
accumulated_intensity = dense_output[0, k, j, 0, i]
# accumulated_dim1 = dense_output[1, k, j, 0, i]

new_intensity = self.intensity_values[idx]

dense_output[0, k, j, 0, i] = (
accumulated_intensity + new_intensity
)
dense_output[0, k, j, 1, i] = (
accumulated_intensity + new_intensity
)

idx += 1

return dense_output, precursor_idx_list
63 changes: 48 additions & 15 deletions alphadia/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,10 +667,13 @@ def location_features(
@nb.njit()
def fragment_features(
dense_fragments: np.ndarray,
fragments_frame_profile: np.ndarray,
frame_rt: np.ndarray,
observation_importance: np.ndarray,
template: np.ndarray,
fragments: np.ndarray,
feature_array: nb_float32_array,
quant_window: nb.uint32 = 3,
):
fragment_feature_dict = nb.typed.Dict.empty(
key_type=nb.types.unicode_type, value_type=float_array
Expand Down Expand Up @@ -706,26 +709,53 @@ def fragment_features(
n_fragments, -1
)

# most intense observation across all observations
best_observation = np.argmax(observation_importance)
# (n_fragments, n_frames)
best_profile = fragments_frame_profile[:, best_observation]

# center the profile around the expected frame center
center = best_profile.shape[1] // 2
# (n_fragments, quant_window * 2 + 1)
best_profile = best_profile[:, center - quant_window : center + quant_window + 1]

# (quant_window * 2 + 1)
frame_rt_quant = frame_rt[center - quant_window : center + quant_window + 1]
quant_durarion = frame_rt_quant[-1] - frame_rt_quant[0]

# (quant_window * 2)
delta_rt = frame_rt_quant[1:] - frame_rt_quant[:-1]

# (n_fragments)
fragment_area = np.sum(
(best_profile[:, 1:] + best_profile[:, :-1]) * delta_rt.reshape(1, -1) * 0.5,
axis=-1,
)
fragment_area_norm = fragment_area / quant_durarion

observed_fragment_intensity = np.sum(best_profile, axis=-1)

# create fragment masks for filtering
fragment_profiles = np.sum(dense_fragments[0], axis=-1)
# (n_fragments, n_observations)
sum_fragment_intensity = np.sum(np.sum(dense_fragments[0], axis=-1), axis=-1)
sum_fragment_intensity = np.sum(fragment_profiles, axis=-1)

# create fragment intensity mask
fragment_intensity_mask_2d = sum_fragment_intensity > 0
fragment_intensity_weights_2d = (
fragment_intensity_mask_2d * observation_importance_reshaped
)
# fragment_intensity_mask_2d = sum_fragment_intensity > 0
# fragment_intensity_weights_2d = (
# fragment_intensity_mask_2d * observation_importance_reshaped
# )

# (n_fragments, n_observations)
# normalize rows to 1
fragment_intensity_weights_2d = fragment_intensity_weights_2d / (
np.sum(fragment_intensity_weights_2d, axis=-1).reshape(-1, 1) + 1e-20
)
# fragment_intensity_weights_2d = fragment_intensity_weights_2d / (
# np.sum(fragment_intensity_weights_2d, axis=-1).reshape(-1, 1) + 1e-20
# )

# (n_fragments)
observed_fragment_intensity = weighted_mean_a1(
sum_fragment_intensity, fragment_intensity_weights_2d
)
# observed_fragment_intensity = weighted_mean_a1(
# sum_fragment_intensity, fragment_intensity_weights_2d
# )

# (n_observations)
sum_template_intensity = np.sum(np.sum(template, axis=-1), axis=-1)
Expand Down Expand Up @@ -769,9 +799,9 @@ def fragment_features(
)

if np.sum(fragment_height_mask_1d) > 0.0:
feature_array[18] = np.corrcoef(
observed_fragment_intensity, fragment_intensity_norm
)[0, 1]
feature_array[18] = np.corrcoef(fragment_area_norm, fragment_intensity_norm)[
0, 1
]

if np.sum(observed_fragment_height) > 0.0:
feature_array[19] = np.corrcoef(
Expand Down Expand Up @@ -823,7 +853,7 @@ def fragment_features(
observed_fragment_mz_mean,
mass_error,
observed_fragment_height,
observed_fragment_intensity,
fragment_area_norm,
)


Expand Down Expand Up @@ -897,6 +927,9 @@ def profile_features(
feature_array,
):
n_observations = len(observation_importance)
# most intense observation across all observations
best_observation = np.argmax(observation_importance)

fragment_idx_sorted = np.argsort(fragment_intensity)[::-1]

# ============= FRAGMENT RT CORRELATIONS =============
Expand Down
55 changes: 38 additions & 17 deletions alphadia/hybridselection.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,14 +764,24 @@ def process(self, jit_data, fragment_container, config, kernel, debug):
self.set_status(101, "No precursor masses after grouping")
return

# shape = (2, n_fragments, n_observations, n_scans, n_frames), dtype = np.float32
_dense_precursors, _ = jit_data.get_dense(
frame_limits,
scan_limits,
precursor_mz,
config.precursor_mz_tolerance,
np.array([[-1.0, -1.0]], dtype=np.float32),
)
if jit_data.has_mobility:
# shape = (2, n_fragments, n_observations, n_scans, n_frames), dtype = np.float32
_dense_precursors, _ = jit_data.get_dense(
frame_limits,
scan_limits,
precursor_mz,
config.precursor_mz_tolerance,
np.array([[-1.0, -1.0]], dtype=np.float32),
)
else:
# shape = (2, n_fragments, n_observations, n_scans, n_frames), dtype = np.float32
_dense_precursors, _ = jit_data.get_dense_intensity(
frame_limits,
scan_limits,
precursor_mz,
config.precursor_mz_tolerance,
np.array([[-1.0, -1.0]], dtype=np.float32),
)
dense_precursors = _dense_precursors.sum(axis=2)

# FLAG: needed for debugging
Expand All @@ -784,15 +794,26 @@ def process(self, jit_data, fragment_container, config, kernel, debug):
self.set_status(102, "Unexpected quadrupole_mz.shape")
return

# shape = (2, n_fragments, n_observations, n_scans, n_frames), dtype = np.float32
_dense_fragments, _ = jit_data.get_dense(
frame_limits,
scan_limits,
fragment_mz,
config.fragment_mz_tolerance,
quadrupole_mz,
custom_cycle=jit_data.cycle,
)
if jit_data.has_mobility:
# shape = (2, n_fragments, n_observations, n_scans, n_frames), dtype = np.float32
_dense_fragments, _ = jit_data.get_dense(
frame_limits,
scan_limits,
fragment_mz,
config.fragment_mz_tolerance,
quadrupole_mz,
custom_cycle=jit_data.cycle,
)
else:
# shape = (2, n_fragments, n_observations, n_scans, n_frames), dtype = np.float32
_dense_fragments, _ = jit_data.get_dense_intensity(
frame_limits,
scan_limits,
fragment_mz,
config.fragment_mz_tolerance,
quadrupole_mz,
custom_cycle=jit_data.cycle,
)
dense_fragments = _dense_fragments.sum(axis=2)

# FLAG: needed for debugging
Expand Down
2 changes: 1 addition & 1 deletion alphadia/outputtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def lfq(
return protein_df


def prepare_df(df, psm_df, column="height"):
def prepare_df(df, psm_df, column="intensity"):
df = df[df["precursor_idx"].isin(psm_df["precursor_idx"])].copy()
df["ion"] = utils.ion_hash(
df["precursor_idx"].values,
Expand Down
22 changes: 22 additions & 0 deletions alphadia/plexscoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class CandidateConfigJIT:
top_k_fragments: nb.uint32
top_k_isotopes: nb.uint32
reference_channel: nb.int16
quant_window: nb.uint32

precursor_mz_tolerance: nb.float32
fragment_mz_tolerance: nb.float32
Expand All @@ -179,6 +180,7 @@ def __init__(
top_k_fragments: nb.uint32,
top_k_isotopes: nb.uint32,
reference_channel: nb.int16,
quant_window: nb.uint32,
precursor_mz_tolerance: nb.float32,
fragment_mz_tolerance: nb.float32,
) -> None:
Expand All @@ -194,6 +196,7 @@ def __init__(
self.top_k_fragments = top_k_fragments
self.top_k_isotopes = top_k_isotopes
self.reference_channel = reference_channel
self.quant_window = quant_window

self.precursor_mz_tolerance = precursor_mz_tolerance
self.fragment_mz_tolerance = fragment_mz_tolerance
Expand All @@ -213,6 +216,7 @@ def __init__(self):
self.top_k_fragments = 12
self.top_k_isotopes = 4
self.reference_channel = -1
self.quant_window = 3
self.precursor_mz_tolerance = 15
self.fragment_mz_tolerance = 15

Expand Down Expand Up @@ -283,6 +287,17 @@ def reference_channel(self) -> int:
def reference_channel(self, value):
self._reference_channel = value

@property
def quant_window(self) -> int:
"""The quantification window size in cycles.
the area will be calculated from `scan_center - quant_window` to `scan_center + quant_window`.
Default: `quant_window = 3`"""
return self._quant_window

@quant_window.setter
def quant_window(self, value):
self._quant_window = value

@property
def precursor_mz_tolerance(self) -> float:
"""The precursor m/z tolerance in ppm.
Expand Down Expand Up @@ -626,6 +641,10 @@ def process(
features.frame_profile_2d(dense_fragments[0])
)

cycle_len = jit_data.cycle.shape[1]

frame_rt = jit_data.rt_values[self.frame_start : self.frame_stop : cycle_len]

# (n_observations, n_frames)
template_frame_profile = features.or_envelope_1d(
features.frame_profile_1d(template)
Expand Down Expand Up @@ -680,10 +699,13 @@ def process(
# (n_valid_fragments)
mz_observed, mass_error, height, intensity = features.fragment_features(
dense_fragments,
fragments_frame_profile,
frame_rt,
observation_importance,
template,
fragments,
feature_array,
quant_window=config.quant_window,
)

# store fragment features if requested
Expand Down
Loading

0 comments on commit 7a885d4

Please sign in to comment.