Skip to content

Commit 34e75c1

Browse files
author
hiruna534
committed
initial move table refine
1 parent 822f852 commit 34e75c1

File tree

4 files changed

+340
-440
lines changed

4 files changed

+340
-440
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
bokeh==3.1.1
2-
numpy
2+
numpy<2
33
pyslow5
44
pyfaidx
55
pyfastx

src/move_refine.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
from src import readpaf_local
2+
import argparse
3+
import numpy as np
4+
import os
5+
6+
fout2 = open("delete", "w")
7+
8+
9+
def get_closest_signal_indices(Q, R):
10+
signal_dict = {}
11+
12+
n = len(Q)
13+
14+
for i in range(n):
15+
left_bound = Q[i - 1] if i > 0 else float('-inf') # Previous move or negative infinity
16+
right_bound = Q[i + 1] if i < n - 1 else float('inf') # Next move or positive infinity
17+
18+
# Collect indices of relevant R[j] values for the current Q[i]
19+
relevant_indices = []
20+
# for j, r in enumerate(R):
21+
# if left_bound <= r <= right_bound:
22+
# relevant_indices.append(j)
23+
24+
# If no relevant indices are found, add the nearest left and right segmentations
25+
if not relevant_indices:
26+
# Find nearest left segmentation
27+
left_nearest = max((j for j, r in enumerate(R) if r < Q[i]), default=None)
28+
# Find nearest right segmentation
29+
right_nearest = min((j for j, r in enumerate(R) if r > Q[i]), default=None)
30+
31+
# Find overlap segmentation
32+
center = min((j for j, r in enumerate(R) if r == Q[i]), default=None)
33+
34+
if left_nearest is not None:
35+
relevant_indices.append(left_nearest)
36+
if right_nearest is not None:
37+
relevant_indices.append(right_nearest)
38+
if center is not None:
39+
relevant_indices.append(center)
40+
41+
# Add to the dictionary
42+
signal_dict[i] = relevant_indices
43+
print(signal_dict)
44+
return signal_dict
45+
46+
def print_dp_matrix(dp_matrix,i):
47+
"""
48+
Prints a 2D matrix in a readable format.
49+
50+
Args:
51+
dp_matrix (list of lists): The 2D DP matrix to print.
52+
"""
53+
for row in dp_matrix[i:i+1]:
54+
print(" ".join(f"{val:6}" for val in row))
55+
56+
def refine_moves(Q, R):
57+
58+
q_i_dict = get_closest_signal_indices(Q=Q, R=R)
59+
# print(q_i_dict)
60+
N, M = len(Q), len(R)
61+
62+
print("N:{} M:{}".format(N,M))
63+
64+
# Initialize DP matrix with infinity
65+
D = np.full((N, M), float('inf'))
66+
D[0][0] = 0 # start of the signal
67+
print("101")
68+
69+
# Fill DP matrix
70+
for i in range(1, N):
71+
# for j in range(1, M+1):
72+
j_indices = q_i_dict[i]
73+
# print(i)
74+
# print("{}:{}".format(i,j_indices))
75+
for j in j_indices:
76+
if j == 0:
77+
continue
78+
cost = abs(Q[i] - R[j]) # Cost function: absolute difference
79+
D[i][j] = cost + min(
80+
D[i-1][j], # Vertical step
81+
D[i][j-1], # Horizontal step
82+
D[i-1][j-1] # Diagonal step
83+
)
84+
fout2.write("D[{}][{}] = {}\n".format(i,j,D[i][j]))
85+
print("102")
86+
# print_dp_matrix(D,N+1)
87+
# print(min(D[N-1]))
88+
# D[N][M] = min(D[N-1])
89+
90+
# Backtracking to find the optimal alignment
91+
i, j = N-1, M-1
92+
alignment = []
93+
alignment.append((i,j,Q[i], R[j])) # Append the aligned pair
94+
prev_j = M
95+
while i > 0 and j > 0:
96+
if i == j:
97+
i -= 1
98+
j -= 1
99+
alignment.append((i,j,Q[i], R[j])) # Append the aligned pair
100+
prev_j = j
101+
elif D[i][j] == D[i-1][j-1] + abs(Q[i] - R[j]):
102+
i -= 1
103+
j -= 1
104+
alignment.append((i,j,Q[i], R[j])) # Append the aligned pair
105+
prev_j = j
106+
elif D[i][j] == D[i-1][j] + abs(Q[i] - R[j]) and prev_j != j:
107+
i -= 1
108+
alignment.append((i,j,Q[i], R[j])) # Append the aligned pair
109+
prev_j = j
110+
elif D[i][j] == D[i-1][j] + abs(Q[i] - R[j]) and prev_j == j:
111+
print("here???")
112+
i -= 1
113+
j -= 1
114+
alignment.append((i,j,Q[i], R[j])) # Append the aligned pair
115+
prev_j = j
116+
else:
117+
j -= 1
118+
alignment.reverse() # Reverse the path to get the correct order
119+
120+
print(len(alignment))
121+
print(alignment)
122+
123+
return [j for _,_,_, j in alignment]
124+
125+
def get_refined_moves_from_alignment(alignment, Q, first_occurrence=True):
126+
# Dictionary to store the first or last R[j] match for each Q[i]
127+
matches = {}
128+
129+
# Iterate over the alignment and update the dictionary based on the flag
130+
for q, r in alignment:
131+
if first_occurrence:
132+
# Store the first occurrence of R[j] for each Q[i]
133+
if q not in matches:
134+
matches[q] = r
135+
else:
136+
# Store the last occurrence of R[j] for each Q[i]
137+
matches[q] = r # This will overwrite with the last match
138+
139+
# Construct the refined moves array using the selected matches
140+
refined_moves = [matches[q] for q in Q]
141+
142+
return refined_moves
143+
144+
145+
def calculate_similarity(Q, refined_moves):
146+
if len(Q) != len(refined_moves):
147+
raise ValueError("The lengths of Q and refined_moves must be the same!")
148+
149+
# Calculate the mean absolute difference
150+
differences = [abs(q - r) for q, r in zip(Q, refined_moves)]
151+
mad = sum(differences) / len(differences)
152+
153+
return mad
154+
155+
def make_ss_string(moves, start_index):
156+
ss_string = "ss:Z:"
157+
prev_index = start_index
158+
for move in moves[1:]:
159+
ss_string += "{},".format(move-prev_index)
160+
prev_index = move
161+
return ss_string
162+
163+
def write_to_file(fout, paf_record, ss_string, end_index):
164+
fout.write("{}\t".format(paf_record.query_name))
165+
fout.write("{}\t".format(end_index-paf_record.query_start))
166+
fout.write("{}\t".format(paf_record.query_start))
167+
fout.write("{}\t".format(end_index))
168+
fout.write("{}\t".format(paf_record.strand))
169+
fout.write("{}\t".format(paf_record.target_name))
170+
fout.write("{}\t".format(paf_record.target_length))
171+
fout.write("{}\t".format(paf_record.target_start))
172+
fout.write("{}\t".format(paf_record.target_end))
173+
fout.write("{}\t".format(paf_record.residue_matches))
174+
fout.write("{}\t".format(paf_record.alignment_block_length))
175+
fout.write("{}\t".format(paf_record.mapping_quality))
176+
fout.write("{}\t".format(ss_string))
177+
178+
def refine_moves_greedy(Q,R):
179+
alignment = []
180+
j = 0 # Pointer for R
181+
182+
for i, q in enumerate(Q):
183+
# Advance j to find the two R[j] values surrounding Q[i]
184+
while j < len(R) - 1 and R[j+1] <= q:
185+
j += 1
186+
187+
# Determine the closest R[j] or R[j+1] to Q[i]
188+
if j < len(R) - 1:
189+
left_dist = abs(q - R[j])
190+
right_dist = abs(q - R[j + 1])
191+
if right_dist < left_dist:
192+
j += 1
193+
194+
# Align Q[i] to the closest R[j]
195+
alignment.append((i, j, Q[i], R[j]))
196+
j += 1 # Move to the next R[j] to ensure no reuse
197+
198+
print(len(alignment))
199+
print(alignment)
200+
201+
return [j for _,_,_, j in alignment]
202+
203+
204+
def run(args):
205+
fout = open(args.output, "w")
206+
reads = []
207+
sigann_dict = {}
208+
if args.sigann:
209+
print(f'Signal point annotation file: {args.sigann}')
210+
with open(args.sigann, 'r') as file:
211+
for line in file:
212+
line = line.strip()
213+
if line: # Skip empty lines
214+
read_id, array_str = line.split(' ', 1) # Split only on the first space
215+
array = [int(value) for value in array_str.split(',') if value] # Parse the array
216+
sigann_dict[read_id] = array
217+
218+
moves_dict = {}
219+
with open(args.alignment, "r") as handle:
220+
for paf_record in readpaf_local.parse_paf(handle):
221+
read_id = paf_record.query_name
222+
start_index = paf_record.query_start
223+
moves_string = paf_record.tags['ss'][2]
224+
moves = [int(value) for value in moves_string.split(',') if value]
225+
moves_array = []
226+
prev_bound = start_index
227+
for move in moves:
228+
prev_bound = prev_bound + move
229+
moves_array.append(prev_bound)
230+
moves_dict[read_id] = moves_array
231+
reads.append(read_id)
232+
233+
print(read_id)
234+
moves = moves_dict[read_id]
235+
sigann = sigann_dict[read_id]
236+
moves = [0] + moves
237+
sigann = [0] + sigann
238+
if moves[-1] != sigann[-1]:
239+
sigann[-1] = moves[-1]
240+
# print(moves)
241+
# print(sigann)
242+
print("len sigann: {} len moves: {}".format(len(sigann), len(moves)))
243+
244+
refined_moves = refine_moves_greedy(Q=moves, R=sigann)
245+
246+
# refined_moves = refine_moves(Q=moves, R=sigann)
247+
248+
print(calculate_similarity(Q=moves, refined_moves=refined_moves))
249+
250+
ss_string = make_ss_string(refined_moves, start_index)
251+
252+
write_to_file(fout, paf_record, ss_string=ss_string, end_index=refined_moves[-1])
253+
254+
fout.close()
255+
fout2.close()
256+
257+
258+
def argparser():
259+
parser = argparse.ArgumentParser(
260+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
261+
add_help=False
262+
)
263+
264+
parser.add_argument('-r', '--read_id', required=False, type=str, default="", help="plot the read with read_id")
265+
parser.add_argument('-l', '--read_list', required=False, type=str, default="", help="a file with read_ids to plot")
266+
parser.add_argument('-a', '--alignment', required=True, type=str, default="", help="read-signal alignment in PAF")
267+
parser.add_argument('--sigann', required=False, help="file with signal point annotations (0-based)")
268+
parser.add_argument('-o', '--output', required=True, type=str, default="", help="output file")
269+
return parser
270+
271+
272+
if __name__ == "__main__":
273+
parser = argparser()
274+
args = parser.parse_args()
275+
try:
276+
run(args)
277+
except Exception as e:
278+
print(str(e))
279+
exit(1)
280+
281+

0 commit comments

Comments
 (0)