Skip to content

Commit

Permalink
Improved hungarian conflict resolution algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
luxaritas committed Feb 24, 2025
1 parent b84d1e8 commit 3e8d74f
Showing 1 changed file with 70 additions and 13 deletions.
83 changes: 70 additions & 13 deletions src/arnie/pk_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,29 +137,86 @@ def _hungarian(bpp, exp=1, sigmoid_slope_factor=None, prob_to_0_threshold_prior=
bpp = _sigmoid(bpp, slope_factor=sigmoid_slope_factor)

# should think about order of above functions and possibly normalize again here

# run hungarian algorithm to find base pairs
_, row_pairs = linear_sum_assignment(-bpp)
# Hungarian/linear sum assignment operates on a bipartite graph such that each row is assigned to
# exactly one column and each column is assigned to exactly one row, however our case is not
# bipartite. That means some chosen assignments could conflict with others, either creating
# a "chain" (eg [(0,5), (5,10)]) or cycle (eg [(0,5), (5,10), (10, 0)]). We resolve these
# conflicts by solving for the maximum weight independent set. (Note that if we have
# two assignments like [(0,5) and (5,0)] we only need to deduplicate, hence the usage of set).
bp_assignments = set(
tuple(sorted((col, row)))
for col, row in enumerate(row_pairs)
if bpp_orig[col, row] > theta and col != row
)
bp_list = []
conf = {}
for col, row in enumerate(row_pairs):
# if bpp_orig[col, row] != bpp[col, row]:
# print(col, row, bpp_orig[col, row], bpp[col, row])
if bpp_orig[col, row] > theta and col < row:
p = max(conf.get(col,0), conf.get(row,0))
if p != 0:
raise ValueError('conflicting pairs')
else:
conf[col] = 1
conf[row] = 1
bp_list.append([col, row])
while len(bp_assignments):
bps = [bp_assignments.pop()]

# # Start building a chain to the "left"
check_nt = bps[0][0]
while conflict := next((bp for bp in bp_assignments if check_nt in bp), None):
bps.insert(0, conflict)
bp_assignments.remove(conflict)
check_nt = next((nt for nt in conflict if nt != check_nt), None)
# And to the "right"
check_nt = bps[-1][1]
while conflict := next((bp for bp in bp_assignments if check_nt in bp), None):
bps.append(conflict)
bp_assignments.remove(conflict)
check_nt = next((nt for nt in conflict if nt != check_nt), None)
if len(bps) == 1:
bp_list.extend(bps)
elif len(bps) > 2 and bps[0][0] in bps[-1] or bps[0][1] in bps[-1]:
# We have a cycle. We need to try both excluding the first element and excluding
# the last element (only one or the other, or neither, can be present since they conflict)
(bp_list_a,prob_a) = _max_weight_independent_set(bps[1:], bpp_orig)
(bp_list_b,prob_b) = _max_weight_independent_set(bps[:-1], bpp_orig)
if prob_a > prob_b:
bp_list.extend(bp_list_a)
else:
bp_list.extend(bp_list_b)
else:
(bp_list_,_) = _max_weight_independent_set(bps, bpp_orig)
bp_list.extend(bp_list_)

bp_list = [list(bp) for bp in bp_list]
bp_list = _check_bp_list(bp_list)
structure = convert_bp_list_to_dotbracket(bp_list, bpp.shape[0])
structure = post_process_struct(structure, allowed_buldge_len, min_len_helix)
bp_list = convert_dotbracket_to_bp_list(structure, allow_pseudoknots=True)

return structure, bp_list

def _max_weight_independent_set(pairs, probs):
max_sets = []
for bp in pairs:
bp_prob = probs[bp[0], bp[1]]

if len(max_sets) == 0:
max_sets.append({'prob': bp_prob, 'bps': [bp]})
elif len(max_sets) == 1:
if max_sets[0]['prob'] > bp_prob:
max_sets.append(max_sets[0])
elif bp_prob > max_sets[0]['prob']:
max_sets.append({'prob': bp_prob, 'bps': [bp]})
elif abs(max_sets[0]['bps'][0][0] - max_sets[0]['bps'][0][1]) <= abs(bp[0] - bp[1]):
max_sets.append(max_sets[0])
else:
max_sets.append({'prob': bp_prob, 'bps': [bp]})
else:
if max_sets[-1]['prob'] > max_sets[-2]['prob'] + bp_prob:
max_sets.append(max_sets[-1])
elif max_sets[-2]['prob'] + bp_prob > max_sets[-1]['prob']:
max_sets.append({'prob': max_sets[-2]['prob'] + bp_prob, 'bps': [*max_sets[-2]['bps'], bp]})
elif abs(max_sets[-1]['bps'][0][0] - max_sets[-1]['bps'][0][1]) <= abs(bp[0] - bp[1]):
max_sets.append(max_sets[-1])
else:
max_sets.append({'prob': max_sets[-2]['prob'] + bp_prob, 'bps': [*max_sets[-2]['bps'], bp]})

return (max_sets[-1]['bps'], max_sets[-1]['prob'])

def _sigmoid(x, slope_factor=0.5):
# normalize to [-1, 1]
Expand Down

0 comments on commit 3e8d74f

Please sign in to comment.