From b84d1e81afbb1020b9c167e8610e759c07d5410d Mon Sep 17 00:00:00 2001 From: Jonathan Romano Date: Mon, 24 Feb 2025 11:25:21 -0500 Subject: [PATCH 1/3] Remove misplaced row-wise normalization --- src/arnie/pk_predictors.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/arnie/pk_predictors.py b/src/arnie/pk_predictors.py index 3b8f52b..80a971a 100644 --- a/src/arnie/pk_predictors.py +++ b/src/arnie/pk_predictors.py @@ -137,11 +137,7 @@ def _hungarian(bpp, exp=1, sigmoid_slope_factor=None, prob_to_0_threshold_prior= bpp = _sigmoid(bpp, slope_factor=sigmoid_slope_factor) # should think about order of above functions and possibly normalize again here - # (normalize again we shall...) - if add_p_unpaired: - row_sums = bpp.sum(axis=1) - bpp = bpp / row_sums[:, np.newaxis] - + # run hungarian algorithm to find base pairs _, row_pairs = linear_sum_assignment(-bpp) bp_list = [] From 3e8d74f1fe5d8d4d8480a2453d21a5fb8f1995ec Mon Sep 17 00:00:00 2001 From: Jonathan Romano Date: Mon, 24 Feb 2025 18:13:16 -0500 Subject: [PATCH 2/3] Improved hungarian conflict resolution algorithm --- src/arnie/pk_predictors.py | 83 ++++++++++++++++++++++++++++++++------ 1 file changed, 70 insertions(+), 13 deletions(-) diff --git a/src/arnie/pk_predictors.py b/src/arnie/pk_predictors.py index 80a971a..c4af35a 100644 --- a/src/arnie/pk_predictors.py +++ b/src/arnie/pk_predictors.py @@ -137,29 +137,86 @@ def _hungarian(bpp, exp=1, sigmoid_slope_factor=None, prob_to_0_threshold_prior= bpp = _sigmoid(bpp, slope_factor=sigmoid_slope_factor) # should think about order of above functions and possibly normalize again here - + # run hungarian algorithm to find base pairs _, row_pairs = linear_sum_assignment(-bpp) + # Hungarian/linear sum assignment operates on a bipartite graph such that each row is assigned to + # exactly one column and each column is assigned to exactly one row, however our case is not + # bipartite. That means some chosen assignments could conflict with others, either creating + # a "chain" (eg [(0,5), (5,10)]) or cycle (eg [(0,5), (5,10), (10, 0)]). We resolve these + # conflicts by solving for the maximum weight independent set. (Note that if we have + # two assignments like [(0,5) and (5,0)] we only need to deduplicate, hence the usage of set). + bp_assignments = set( + tuple(sorted((col, row))) + for col, row in enumerate(row_pairs) + if bpp_orig[col, row] > theta and col != row + ) bp_list = [] - conf = {} - for col, row in enumerate(row_pairs): - # if bpp_orig[col, row] != bpp[col, row]: - # print(col, row, bpp_orig[col, row], bpp[col, row]) - if bpp_orig[col, row] > theta and col < row: - p = max(conf.get(col,0), conf.get(row,0)) - if p != 0: - raise ValueError('conflicting pairs') - else: - conf[col] = 1 - conf[row] = 1 - bp_list.append([col, row]) + while len(bp_assignments): + bps = [bp_assignments.pop()] + + # # Start building a chain to the "left" + check_nt = bps[0][0] + while conflict := next((bp for bp in bp_assignments if check_nt in bp), None): + bps.insert(0, conflict) + bp_assignments.remove(conflict) + check_nt = next((nt for nt in conflict if nt != check_nt), None) + # And to the "right" + check_nt = bps[-1][1] + while conflict := next((bp for bp in bp_assignments if check_nt in bp), None): + bps.append(conflict) + bp_assignments.remove(conflict) + check_nt = next((nt for nt in conflict if nt != check_nt), None) + if len(bps) == 1: + bp_list.extend(bps) + elif len(bps) > 2 and bps[0][0] in bps[-1] or bps[0][1] in bps[-1]: + # We have a cycle. We need to try both excluding the first element and excluding + # the last element (only one or the other, or neither, can be present since they conflict) + (bp_list_a,prob_a) = _max_weight_independent_set(bps[1:], bpp_orig) + (bp_list_b,prob_b) = _max_weight_independent_set(bps[:-1], bpp_orig) + if prob_a > prob_b: + bp_list.extend(bp_list_a) + else: + bp_list.extend(bp_list_b) + else: + (bp_list_,_) = _max_weight_independent_set(bps, bpp_orig) + bp_list.extend(bp_list_) + bp_list = [list(bp) for bp in bp_list] + bp_list = _check_bp_list(bp_list) structure = convert_bp_list_to_dotbracket(bp_list, bpp.shape[0]) structure = post_process_struct(structure, allowed_buldge_len, min_len_helix) bp_list = convert_dotbracket_to_bp_list(structure, allow_pseudoknots=True) return structure, bp_list +def _max_weight_independent_set(pairs, probs): + max_sets = [] + for bp in pairs: + bp_prob = probs[bp[0], bp[1]] + + if len(max_sets) == 0: + max_sets.append({'prob': bp_prob, 'bps': [bp]}) + elif len(max_sets) == 1: + if max_sets[0]['prob'] > bp_prob: + max_sets.append(max_sets[0]) + elif bp_prob > max_sets[0]['prob']: + max_sets.append({'prob': bp_prob, 'bps': [bp]}) + elif abs(max_sets[0]['bps'][0][0] - max_sets[0]['bps'][0][1]) <= abs(bp[0] - bp[1]): + max_sets.append(max_sets[0]) + else: + max_sets.append({'prob': bp_prob, 'bps': [bp]}) + else: + if max_sets[-1]['prob'] > max_sets[-2]['prob'] + bp_prob: + max_sets.append(max_sets[-1]) + elif max_sets[-2]['prob'] + bp_prob > max_sets[-1]['prob']: + max_sets.append({'prob': max_sets[-2]['prob'] + bp_prob, 'bps': [*max_sets[-2]['bps'], bp]}) + elif abs(max_sets[-1]['bps'][0][0] - max_sets[-1]['bps'][0][1]) <= abs(bp[0] - bp[1]): + max_sets.append(max_sets[-1]) + else: + max_sets.append({'prob': max_sets[-2]['prob'] + bp_prob, 'bps': [*max_sets[-2]['bps'], bp]}) + + return (max_sets[-1]['bps'], max_sets[-1]['prob']) def _sigmoid(x, slope_factor=0.5): # normalize to [-1, 1] From 5161628b8f8fba0f3107f9a6060adf743b29ccab Mon Sep 17 00:00:00 2001 From: Jonathan Romano Date: Mon, 24 Feb 2025 18:17:55 -0500 Subject: [PATCH 3/3] Spacing tweak --- src/arnie/pk_predictors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/arnie/pk_predictors.py b/src/arnie/pk_predictors.py index c4af35a..4cd8fb1 100644 --- a/src/arnie/pk_predictors.py +++ b/src/arnie/pk_predictors.py @@ -167,6 +167,7 @@ def _hungarian(bpp, exp=1, sigmoid_slope_factor=None, prob_to_0_threshold_prior= bps.append(conflict) bp_assignments.remove(conflict) check_nt = next((nt for nt in conflict if nt != check_nt), None) + if len(bps) == 1: bp_list.extend(bps) elif len(bps) > 2 and bps[0][0] in bps[-1] or bps[0][1] in bps[-1]: @@ -215,7 +216,7 @@ def _max_weight_independent_set(pairs, probs): max_sets.append(max_sets[-1]) else: max_sets.append({'prob': max_sets[-2]['prob'] + bp_prob, 'bps': [*max_sets[-2]['bps'], bp]}) - + return (max_sets[-1]['bps'], max_sets[-1]['prob']) def _sigmoid(x, slope_factor=0.5):