forked from framazan/bwtandem
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbwt.py
More file actions
4374 lines (3648 loc) · 178 KB
/
bwt.py
File metadata and controls
4374 lines (3648 loc) · 178 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
"""
Advanced BWT-based Tandem Repeat Finder for Genomics
Implements three-tier approach:
1. Short tandem repeats (1-9bp) with FM-index
2. Medium/long repeats (10-1000bp) with LCP arrays
3. Very long repeats (kb+) with long read evidence
"""
import numpy as np
from typing import List, Tuple, Dict, Iterator, Optional, Set, Union
import argparse
from dataclasses import dataclass
from multiprocessing import Pool, cpu_count
import time
import re
import math
from collections import Counter
def _natural_sort_key(value: str):
"""Return a tuple usable for natural sorting (e.g., chr2 before chr10)."""
if value is None:
return ()
parts = re.split(r'(\d+)', str(value))
key_parts: List[Tuple[int, object]] = []
for part in parts:
if not part:
continue
if part.isdigit():
key_parts.append((0, int(part)))
else:
key_parts.append((1, part.lower()))
return tuple(key_parts)
# Optional: JIT acceleration with numba when available.
# The two helpers below exist in a JIT-compiled flavor (when numba imports
# cleanly) and a NumPy fallback flavor with identical signatures/results.
HAVE_NUMBA = False
try:
    import numba as _nb  # type: ignore
    HAVE_NUMBA = True
except Exception:
    # numba absent or failed to initialize; fall back to NumPy versions.
    _nb = None  # type: ignore

if HAVE_NUMBA:
    @_nb.njit(cache=True)
    def _count_equal_range(arr: np.ndarray, start: int, end: int, code: int) -> int:  # type: ignore
        # Count occurrences of `code` in arr[start:end] (JIT-compiled linear scan).
        c = 0
        for i in range(start, end):
            if arr[i] == code:
                c += 1
        return c

    @_nb.njit(cache=True)
    def _kasai_lcp_uint8(text_codes: np.ndarray, sa: np.ndarray) -> np.ndarray:  # type: ignore
        # Kasai's algorithm: build the LCP array from the text (as uint8 codes)
        # and its suffix array in O(n). lcp[r] = longest common prefix between
        # suffix sa[r] and suffix sa[r-1].
        n = text_codes.size
        lcp = np.zeros(n, dtype=np.int32)
        rank = np.zeros(n, dtype=np.int32)
        for i in range(n):
            rank[sa[i]] = i
        h = 0  # Kasai invariant: LCP can drop by at most 1 between text-order suffixes
        for i in range(n):
            r = rank[i]
            if r > 0:
                j = sa[r - 1]
                while i + h < n and j + h < n and text_codes[i + h] == text_codes[j + h]:
                    h += 1
                lcp[r] = h
                if h > 0:
                    h -= 1
        return lcp
else:
    def _count_equal_range(arr: np.ndarray, start: int, end: int, code: int) -> int:
        # Pure-python/numpy fallback
        return int(np.count_nonzero(arr[start:end] == code))

    def _kasai_lcp_uint8(text_codes: np.ndarray, sa: np.ndarray) -> np.ndarray:
        # Fallback non-jitted Kasai (same algorithm as the JIT version above).
        n = text_codes.size
        lcp = np.zeros(n, dtype=np.int32)
        rank = np.zeros(n, dtype=np.int32)
        for i in range(n):
            rank[sa[i]] = i
        h = 0
        for i in range(n):
            r = rank[i]
            if r > 0:
                j = sa[r - 1]
                while i + h < n and j + h < n and text_codes[i + h] == text_codes[j + h]:
                    h += 1
                lcp[r] = h
                if h > 0:
                    h -= 1
        return lcp
class BWTCore:
    """Core BWT construction and FM-index operations.

    Builds a suffix array, the BWT, checkpointed occurrence counts (for
    rank queries) and a sampled suffix array over `text`, plus an exact
    fixed-length k-mer hash for O(1) lookups of 8-mers.
    """
    # Base encoding for bit-masking (bcftools-inspired): 2 bits per base.
    # NOTE: 'N' deliberately aliases to 'A' so N-containing k-mers still hash.
    BASE_TO_BITS = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 0}  # 2 bits per base
    BITS_TO_BASE = {0: 'A', 1: 'C', 2: 'G', 3: 'T'}

    def __init__(self, text: str, sa_sample_rate: int = 32, occ_sample_rate: int = 128):
        """
        Initialize BWT with FM-index.
        Args:
            text: Input text (should end with a single '$' sentinel not present elsewhere)
            sa_sample_rate: Sample suffixes whose text position is a multiple of
                this rate (bounds the LF-walk in _get_suffix_position)
            occ_sample_rate: Occurrence checkpoints every nth position to reduce memory
        """
        self.text: str = text
        self.n = len(text)
        self.sa_sample_rate = sa_sample_rate
        self.occ_sample_rate = occ_sample_rate
        # Assumes ASCII text; multi-byte UTF-8 would desync byte/char indexing.
        self.text_arr = np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
        # Build k-mer hash for fast lookups (performance optimization)
        self._build_kmer_hash()
        # Build suffix array and BWT (memory-efficient)
        self.suffix_array = self._build_suffix_array()
        self.bwt_arr = self._build_bwt_array()
        self.alphabet = sorted(set(text))
        self.char_to_code = {c: ord(c) for c in self.alphabet}
        self.code_to_char = {ord(c): c for c in self.alphabet}
        self.char_counts, self.char_totals = self._build_char_counts()
        self.char_counts_code = {ord(k): v for k, v in self.char_counts.items()}
        self.char_totals_code = {ord(k): v for k, v in self.char_totals.items()}
        self.occ_checkpoints = self._build_occurrence_checkpoints()
        self.sampled_sa = self._sample_suffix_array()

    def _build_kmer_hash(self, k: int = 8):
        """Build hash table for k-mer positions (bcftools-inspired optimization).

        Uses bit-masking for fast k-mer encoding (2 bits per base).

        Fix: the previous rolling window was not reset when a character outside
        BASE_TO_BITS was seen, so windows spanning an invalid character were
        recorded under a stale/corrupt hash. We now track a streak of
        consecutive valid bases and only record a position once the streak
        covers a full k-mer. Results are unchanged for pure-ACGTN text.
        """
        self.kmer_k = k  # remembered so get_kmer_positions matches the index
        self.kmer_hash = {}  # hash -> list of positions
        if self.n < k:
            return
        mask = (1 << (2 * k)) - 1  # k bases x 2 bits
        w = 0
        streak = 0  # consecutive valid bases ending at current position
        for i in range(self.n):
            bits = self.BASE_TO_BITS.get(self.text[i].upper())
            if bits is None:
                # Invalid character: any window crossing it is meaningless.
                w = 0
                streak = 0
                continue
            w = ((w << 2) | bits) & mask
            streak += 1
            if streak >= k:
                self.kmer_hash.setdefault(w, []).append(i - k + 1)

    def get_kmer_positions(self, kmer: str) -> List[int]:
        """Get positions of k-mer using hash table (O(1) lookup).

        Args:
            kmer: k-mer sequence (must be valid DNA bases)
        Returns:
            List of positions where k-mer occurs

        Fix: only k-mers of exactly the indexed length can use the hash table
        (the table stores fixed-length encodings); any other length previously
        produced garbage lookups for short k-mers and now falls back to the
        exact FM-index search.
        """
        k = getattr(self, 'kmer_k', 8)
        if len(kmer) != k or not self.kmer_hash:
            # Fall back to FM-index for non-indexed lengths (or empty index)
            return self.locate_positions(kmer)
        # Encode k-mer to hash
        w = 0
        for base in kmer.upper():
            if base not in self.BASE_TO_BITS:
                return []
            w = (w << 2) | self.BASE_TO_BITS[base]
        return self.kmer_hash.get(w, [])

    def clear(self):
        """Release heavy memory structures to let GC reclaim memory."""
        # Replace large attributes with minimal stubs
        self.text = ""
        self.text_arr = np.array([], dtype=np.uint8)
        self.bwt_arr = np.array([], dtype=np.uint8)
        self.suffix_array = np.array([], dtype=np.int32)
        self.sampled_sa = {}
        self.occ_checkpoints = {}
        self.char_counts = {}
        self.char_totals = {}
        self.alphabet = []
        self.char_to_code = {}
        self.code_to_char = {}
        self.char_counts_code = {}
        self.char_totals_code = {}
        self.kmer_hash = {}  # fix: previously leaked after clear()

    def _build_suffix_array(self) -> np.ndarray:
        """Build suffix array, preferring pydivsufsort (C backend) with a NumPy fallback.

        Fallback uses prefix-doubling with NumPy lexsort (significantly faster than
        Python list.sort + lambdas). Complexity ~O(n log n) sorts.
        """
        # Prefer fast C implementation when available
        s = self.text
        try:
            import pydivsufsort  # type: ignore
            sa_list = pydivsufsort.divsufsort(s)
            return np.array(sa_list, dtype=np.int32)
        except Exception:
            # Silently fall back to NumPy implementation.
            # (fix: `except (ImportError, Exception)` was redundant — Exception
            # already subsumes ImportError)
            pass
        n = self.n
        if n == 0:
            return np.array([], dtype=np.int32)
        # Initial rank from character codes
        codes = self.text_arr.astype(np.int32, copy=False)
        # Compress codes to 0..sigma-1 for stability
        uniq_codes, inv = np.unique(codes, return_inverse=True)
        rank = inv.astype(np.int32, copy=False)
        sa = np.arange(n, dtype=np.int32)
        k = 1
        tmp_rank = np.empty(n, dtype=np.int32)
        idx = np.arange(n, dtype=np.int32)
        while k < n:
            # secondary key is rank[i+k] else -1 (suffix shorter than k sorts first)
            ipk = idx + k
            # Use safe indexing: clip ipk to valid range, then apply condition
            ipk_safe = np.clip(ipk, 0, n - 1)
            key2 = np.where(ipk < n, rank[ipk_safe], -1)
            # Sort by (rank[i], key2[i]) using lexsort with primary key last
            sa = np.lexsort((key2, rank))
            # Compute new ranks
            r_sa = rank[sa]
            k2_sa = key2[sa]
            # mark rank boundaries between adjacent sorted suffixes
            change = np.empty(n, dtype=np.int32)
            change[0] = 0
            change[1:] = (r_sa[1:] != r_sa[:-1]) | (k2_sa[1:] != k2_sa[:-1])
            new_rank_ordered = np.cumsum(change, dtype=np.int32)
            # remap to original index order
            tmp_rank[sa] = new_rank_ordered
            rank, tmp_rank = tmp_rank, rank
            if rank[sa[-1]] == n - 1:
                break  # all ranks distinct: order is final
            k <<= 1
        return sa.astype(np.int32, copy=False)

    def _build_bwt_array(self) -> np.ndarray:
        """Build BWT from suffix array as uint8 NumPy array (ASCII codes)."""
        if self.n == 0:
            return np.array([], dtype=np.uint8)
        sa = self.suffix_array.astype(np.int64, copy=False)
        # BWT[i] = text[(sa[i] - 1) % n]: character preceding each suffix
        prev_idx = (sa - 1) % self.n
        # Gather from numeric text array
        return self.text_arr[prev_idx]

    def _build_char_counts(self) -> Tuple[Dict[str, int], Dict[str, int]]:
        """Count character frequencies and compute cumulative counts C[char].

        Returns:
            (counts, totals): counts[c] = number of characters strictly smaller
            than c (the FM-index C array); totals[c] = occurrences of c.
        """
        totals: Dict[str, int] = {c: 0 for c in self.alphabet}
        for ch in self.text:
            totals[ch] += 1
        counts: Dict[str, int] = {}
        cumulative = 0
        for char in self.alphabet:
            counts[char] = cumulative
            cumulative += totals[char]
        return counts, totals

    def _build_occurrence_checkpoints(self) -> Dict[int, np.ndarray]:
        """Build checkpointed occurrence counts for efficient rank queries with low memory.

        Returns a mapping from ASCII code -> np.ndarray of counts at positions m*k
        (prefix length), with cp[0] = 0. If the last block is partial, a final
        checkpoint with the total count at n is appended to mirror previous behavior.
        """
        bwt = self.bwt_arr
        n = bwt.size
        k = int(self.occ_sample_rate)
        if n == 0:
            return {}
        checkpoints: Dict[int, np.ndarray] = {}
        # Precompute full cumsum once per distinct code as we have small alphabets
        distinct_codes = np.unique(bwt)
        # indices where boundaries end (1-based length m*k corresponds to index m*k-1)
        block_ends = np.arange(k - 1, n, k, dtype=np.int64)
        for code in distinct_codes.tolist():
            mask = (bwt == code)
            csum = np.cumsum(mask, dtype=np.int32)
            # cp[0]=0, then take counts at each block end
            cp_list = [0]
            if block_ends.size:
                cp_list.extend(csum[block_ends].tolist())
            # Optionally append final count for partial block remainder
            if n % k != 0:
                cp_list.append(int(csum[-1]))
            checkpoints[int(code)] = np.asarray(cp_list, dtype=np.int32)
        # Ensure every alphabet character has a checkpoint array (even if absent)
        for c in self.alphabet:
            code = ord(c)
            if code not in checkpoints:
                # Build an all-zeros checkpoint array of same length as others
                any_cp = next(iter(checkpoints.values())) if checkpoints else np.array([0], dtype=np.int32)
                checkpoints[code] = np.zeros_like(any_cp)
        return checkpoints

    def _sample_suffix_array(self) -> Dict[int, int]:
        """Sample SA entries whose *text position* is a multiple of the rate.

        Fix: the previous version sampled every nth *SA index*, which gives no
        guarantee that the LF-mapping walk in `_get_suffix_position` ever
        reaches a sampled entry (possible infinite loop). Sampling by text
        position (the standard FM-index scheme) bounds the walk to fewer than
        `sa_sample_rate` steps, because each LF step decreases the text
        position by one (mod n).
        """
        sampled: Dict[int, int] = {}
        rate = self.sa_sample_rate
        for i in range(self.n):
            pos = int(self.suffix_array[i])
            if pos % rate == 0:
                sampled[i] = pos
        return sampled

    def rank(self, char: Union[str, int], pos: int) -> int:
        """Count occurrences of `char` in bwt[0:pos]. Vectorized with checkpoints.

        Args:
            char: character (str) or ASCII code (int) to count
            pos: count occurrences in bwt[0:pos] (pos can be 0..n)
        """
        if pos <= 0:
            return 0
        if pos > self.n:
            pos = self.n
        code = ord(char) if isinstance(char, str) else int(char)
        cp = self.occ_checkpoints.get(code)
        if cp is None:
            return 0  # character never occurs in the BWT
        k = int(self.occ_sample_rate)
        cp_idx = pos // k
        cp_pos = cp_idx * k
        base = int(cp[cp_idx])
        # Fast remainder scan (Numba-accelerated if available)
        if pos > cp_pos:
            base += int(_count_equal_range(self.bwt_arr, cp_pos, pos, code))
        return base

    def backward_search(self, pattern: str) -> Tuple[int, int]:
        """
        Find suffix array interval for pattern using backward search.
        Returns:
            (start, end) inclusive interval in suffix array, or (-1, -1) if not found
        """
        if not pattern:
            return (0, self.n - 1)
        # Initialize with character range
        char = pattern[-1]
        if char not in self.char_counts:
            return (-1, -1)
        # sp inclusive, ep inclusive
        sp = self.char_counts[char]
        ep = sp + self.char_totals[char] - 1
        # Process pattern right to left
        for i in range(len(pattern) - 2, -1, -1):
            char = pattern[i]
            if char not in self.char_counts:
                return (-1, -1)
            sp = self.char_counts[char] + self.rank(char, sp)
            ep = self.char_counts[char] + self.rank(char, ep + 1) - 1
            if sp > ep:
                return (-1, -1)
        return (sp, ep)

    def count_occurrences(self, pattern: str) -> int:
        """Count pattern occurrences in text."""
        sp, ep = self.backward_search(pattern)
        if sp == -1:
            return 0
        return ep - sp + 1

    def locate_positions(self, pattern: str) -> List[int]:
        """
        Locate all positions of pattern in text, sorted ascending.
        Reads the (fully retained) suffix array directly, which is much
        faster than LF-walking each SA entry.
        """
        sp, ep = self.backward_search(pattern)
        if sp == -1:
            return []
        positions = self.suffix_array[sp:ep + 1].tolist()
        positions.sort()
        return positions

    def _get_suffix_position(self, sa_index: int) -> int:
        """Recover original text position from SA index using sampling.

        Walks the LF mapping (each step moves to the suffix starting one text
        position earlier) until a sampled entry is found; with position-based
        sampling (see _sample_suffix_array) this terminates within
        sa_sample_rate steps.
        """
        if sa_index in self.sampled_sa:
            return self.sampled_sa[sa_index]
        # Walk using LF mapping until we hit a sampled position
        steps = 0
        current_idx = sa_index
        while current_idx not in self.sampled_sa:
            code = int(self.bwt_arr[current_idx])
            current_idx = self.char_counts_code[code] + self.rank(code, current_idx)
            steps += 1
        return (self.sampled_sa[current_idx] + steps) % self.n
@dataclass
class TandemRepeat:
    """Represents a tandem repeat finding.

    Coordinates are 0-based half-open internally (length = end - start);
    serializers that require 1-based coordinates (to_strfinder) convert on
    output.
    """
    chrom: str    # chromosome / sequence name
    start: int    # 0-based start coordinate
    end: int      # end coordinate (exclusive)
    motif: str    # repeat unit as originally detected
    copies: float  # copy number (may be fractional)
    length: int   # repeat length in bases
    tier: int     # detection tier (1-3; see module docstring)
    confidence: float = 1.0
    consensus_motif: Optional[str] = None  # Consensus motif from all copies
    mismatch_rate: float = 0.0  # Overall mismatch rate across all copies
    max_mismatches_per_copy: int = 0  # Maximum mismatches in any single copy
    n_copies_evaluated: int = 0  # Number of copies used in consensus
    strand: str = "+"  # Strand information
    # TRF-compatible fields
    percent_matches: float = 0.0  # Percent matches (100 - mismatch_rate*100)
    percent_indels: float = 0.0  # Percent indels (we use 0 for Hamming-based)
    score: int = 0  # Alignment score (calculated from matches/mismatches)
    composition: Optional[Dict[str, float]] = None  # A, C, G, T percentages
    entropy: float = 0.0  # Shannon entropy (0-2 bits)
    actual_sequence: Optional[str] = None  # The actual repeat sequence from genome
    variations: Optional[List[str]] = None  # Per-copy variation annotations

    def to_bed(self) -> str:
        """Convert to BED format (BED6+ with custom tier/mismatch/strand columns)."""
        # Prefer the consensus motif when one was computed.
        cons = self.consensus_motif or self.motif
        return f"{self.chrom}\t{self.start}\t{self.end}\t{cons}\t{self.copies:.1f}\t{self.tier}\t{self.mismatch_rate:.3f}\t{self.strand}"

    def to_vcf_info(self) -> str:
        """Convert to VCF INFO field (semicolon-joined KEY=VALUE pairs)."""
        cons = self.consensus_motif or self.motif
        info_parts = [
            f"MOTIF={self.motif}",
            f"CONS_MOTIF={cons}",
            f"COPIES={self.copies:.1f}",
            f"TIER={self.tier}",
            f"CONF={self.confidence:.2f}",
            f"MM_RATE={self.mismatch_rate:.3f}",
            f"MAX_MM_PER_COPY={self.max_mismatches_per_copy}",
            f"N_COPIES_EVAL={self.n_copies_evaluated}",
            f"STRAND={self.strand}"
        ]
        return ";".join(info_parts)

    def to_trf_table(self) -> str:
        """Convert to TRF table format (tab-delimited).

        Format: Indices Period CopyNumber ConsensusSize PercentMatches PercentIndels Score A C G T Entropy
        """
        cons = self.consensus_motif or self.motif
        period = len(cons)
        consensus_size = len(cons)
        # Get composition; default to a uniform 25% split when unknown
        comp = self.composition or {'A': 25.0, 'C': 25.0, 'G': 25.0, 'T': 25.0}
        indices = f"{self.start}--{self.end}"
        return (f"{indices}\t{period}\t{self.copies:.1f}\t{consensus_size}\t"
                f"{self.percent_matches:.0f}\t{self.percent_indels:.0f}\t{self.score}\t"
                f"{comp['A']:.0f}\t{comp['C']:.0f}\t{comp['G']:.0f}\t{comp['T']:.0f}\t"
                f"{self.entropy:.2f}")

    def to_trf_dat(self) -> str:
        """Convert to TRF DAT format (space-delimited, includes consensus and sequence).

        Format: Start End Period CopyNumber ConsensusSize PercentMatches PercentIndels Score
        A C G T Entropy ConsensusPattern Sequence
        """
        cons = self.consensus_motif or self.motif
        period = len(cons)
        consensus_size = len(cons)
        # Get composition; default to a uniform 25% split when unknown
        comp = self.composition or {'A': 25.0, 'C': 25.0, 'G': 25.0, 'T': 25.0}
        # Get actual sequence (or reconstruct by repeating the consensus)
        sequence = self.actual_sequence or (cons * int(self.copies))
        return (f"{self.start} {self.end} {period} {self.copies:.1f} {consensus_size} "
                f"{self.percent_matches:.0f} {self.percent_indels:.0f} {self.score} "
                f"{comp['A']:.0f} {comp['C']:.0f} {comp['G']:.0f} {comp['T']:.0f} "
                f"{self.entropy:.2f} {cons} {sequence}")

    def to_strfinder(self, marker_name: Optional[str] = None,
                     flanking_left: str = "", flanking_right: str = "") -> str:
        """Convert to STRfinder-compatible CSV format (includes variation summary).

        Follows the STRfinder format specification:
        STR_marker, STR_position, STR_motif, STR_genotype_structure, STR_genotype,
        STR_core_seq, Allele_coverage, Alleles_ratio, Reads_Distribution, STR_depth, Full_seq, Variations

        Args:
            marker_name: optional marker label; defaults to "STR_<chrom>_<start>"
            flanking_left / flanking_right: optional flanking sequence included in Full_seq
        """
        # Check if this is a compound repeat (attributes set externally —
        # guarded with hasattr because they are not dataclass fields)
        is_compound = hasattr(self, 'is_compound') and self.is_compound and hasattr(self, 'compound_partner')
        if is_compound:
            partner = self.compound_partner
            cons1 = self.consensus_motif or self.motif
            cons2 = partner.consensus_motif or partner.motif
            # Compound repeat formatting: this repeat immediately followed by partner
            marker = marker_name or f"STR_{self.chrom}_{self.start}"
            position = f"{self.chrom}:{self.start + 1}-{partner.end}"
            str_motif = f"[{cons1}]n+[{cons2}]n"
            copies1 = int(round(self.copies))
            copies2 = int(round(partner.copies))
            motif_len1 = len(cons1)
            motif_len2 = len(cons2)
            genotype_struct = f"{motif_len1}[{cons1}]{copies1};{motif_len2}[{cons2}]{copies2},0"
            genotype = f"{copies1}/{copies2}"
            core_seq1 = self.actual_sequence or (cons1 * copies1)
            core_seq2 = partner.actual_sequence or (cons2 * copies2)
            core_seq = core_seq1 + core_seq2
            allele_coverage = "100%"
            alleles_ratio = "-"
            reads_dist = f"{copies1}:{copies2}"
            str_depth = str(copies1 + copies2)
            if flanking_left or flanking_right:
                full_seq = flanking_left + core_seq + flanking_right
            else:
                full_seq = core_seq
            variation_str = "-"
            return (f"{marker}\t{position}\t{str_motif}\t{genotype_struct}\t{genotype}\t"
                    f"{core_seq}\t{allele_coverage}\t{alleles_ratio}\t{reads_dist}\t"
                    f"{str_depth}\t{full_seq}\t{variation_str}")
        # Regular (non-compound) repeat handling
        cons = self.consensus_motif or self.motif
        # STR_marker - use provided name or generate from position
        marker = marker_name or f"STR_{self.chrom}_{self.start}"
        # STR_position - chr:start-end format (1-BASED COORDINATES)
        # Convert from 0-based internal to 1-based output
        position = f"{self.chrom}:{self.start + 1}-{self.end}"
        # STR_motif - [MOTIF]n format
        str_motif = f"[{cons}]n"
        # STR_genotype_structure - format as motif_length[MOTIF]copies,truncated
        # Calculate truncated bases (remainder after complete copies)
        motif_len = len(cons)
        total_length = self.end - self.start
        # epsilon guards against float copies like 6.9999999 flooring to 6
        complete_copies = int(math.floor(self.copies + 1e-6))
        complete_length = motif_len * complete_copies
        truncated = total_length - complete_length
        genotype_struct = f"{motif_len}[{cons}]{complete_copies},{truncated}"
        # STR_genotype - repeat number(s); integral copies print without decimals
        if abs(self.copies - round(self.copies)) < 1e-6:
            genotype = str(int(round(self.copies)))
        else:
            genotype = f"{self.copies:.2f}".rstrip('0').rstrip('.')
        # STR_core_seq - the actual core sequence
        # Use actual sequence if available, otherwise reconstruct
        if self.actual_sequence:
            core_seq_full = self.actual_sequence
        else:
            core_seq_full = cons * int(self.copies)
        # Truncate long sequences with ellipsis notation
        if len(core_seq_full) > 150:
            # Show first ~70 chars + " ... (xN)" where N is the number of copies
            truncate_len = 70
            core_seq = f"{core_seq_full[:truncate_len]}... (x{complete_copies})"
        else:
            core_seq = core_seq_full
        # Allele_coverage - percentage (use percent_matches if available, else confidence)
        if hasattr(self, 'percent_matches') and self.percent_matches is not None:
            allele_coverage = f"{self.percent_matches:.0f}%"
        else:
            allele_coverage = f"{self.confidence * 100:.0f}%"
        # Alleles_ratio - for diploid; use "-" for haploid
        alleles_ratio = "-"
        # Reads_Distribution - simplified format showing copy numbers
        # (copy_number:read_count; a single pair rather than a full histogram)
        reads_dist = f"{complete_copies}:{self.n_copies_evaluated}"
        # STR_depth - use n_copies_evaluated as proxy
        str_depth = str(self.n_copies_evaluated)
        # Variation summary (list variants differing from consensus)
        variation_str = ";".join(self.variations) if self.variations else "-"
        # Full_seq - flanking + CORE + flanking (simple concatenation)
        # Use full sequence (not truncated) for Full_seq, but truncate if too long
        if flanking_left or flanking_right:
            full_seq_complete = flanking_left + core_seq_full + flanking_right
        else:
            full_seq_complete = core_seq_full
        # Truncate Full_seq if extremely long (keep reasonable size)
        if len(full_seq_complete) > 500:
            full_seq = f"{full_seq_complete[:250]}...{full_seq_complete[-200:]}"
        else:
            full_seq = full_seq_complete
        return (f"{marker}\t{position}\t{str_motif}\t{genotype_struct}\t{genotype}\t"
                f"{core_seq}\t{allele_coverage}\t{alleles_ratio}\t{reads_dist}\t"
                f"{str_depth}\t{full_seq}\t{variation_str}")
@dataclass
class AlignmentResult:
    """Outcome of aligning a single repeat copy against the consensus motif template."""
    consumed: int          # number of window characters this copy consumed
    unit_sequence: str     # the copy's sequence as it appears in the window
    mismatch_count: int    # substitutions relative to the motif
    insertion_length: int  # total inserted bases
    deletion_length: int   # total deleted bases
    operations: List[Tuple]  # ('sub', pos, ref, alt) | ('ins', pos, seq) | ('del', pos, length)
    observed_bases: List[Tuple[int, str]]  # (motif_index, base) observations for consensus tally
    edit_distance: int     # total edit cost chosen by the aligner

    @property
    def error_count(self) -> int:
        """Total errors: substitutions plus every inserted and deleted base."""
        total = self.mismatch_count
        total += self.insertion_length
        total += self.deletion_length
        return total
@dataclass
class RepeatAlignmentSummary:
    """Aggregate result for aligning a tandem repeat block."""
    consensus: str             # consensus motif across copies
    motif_len: int             # length of the consensus motif
    copies: int                # number of copies aligned
    consumed_length: int       # total bases consumed from the input sequence
    mismatch_rate: float       # overall mismatch rate (fraction — presumably 0..1; confirm in producer)
    max_errors_per_copy: int   # worst per-copy error count (subs + indel bases)
    variations: List[str]      # per-copy variation annotations
    copy_sequences: List[str]  # sequence of each aligned copy
    total_insertions: int      # inserted bases summed over all copies
    total_deletions: int       # deleted bases summed over all copies
    error_counts: List[int]    # per-copy error counts
class MotifUtils:
"""Utilities for canonical motif handling."""
@staticmethod
def get_canonical_motif(motif: str) -> str:
"""Get lexicographically smallest rotation of motif."""
if not motif:
return motif
rotations = [motif[i:] + motif[:i] for i in range(len(motif))]
return min(rotations)
@staticmethod
def reverse_complement(seq: str) -> str:
"""Get reverse complement of DNA sequence."""
complement_map = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N'}
return ''.join(complement_map.get(b, b) for b in reversed(seq))
@staticmethod
def get_canonical_motif_stranded(motif: str) -> Tuple[str, str]:
"""Get canonical motif considering both strands.
Returns:
(canonical_motif, strand) where strand is '+' or '-'
"""
if not motif:
return motif, '+'
# Get all rotations of forward strand
forward_rotations = [motif[i:] + motif[:i] for i in range(len(motif))]
forward_canonical = min(forward_rotations)
# Get all rotations of reverse complement
rc = MotifUtils.reverse_complement(motif)
rc_rotations = [rc[i:] + rc[:i] for i in range(len(rc))]
rc_canonical = min(rc_rotations)
# Return lexicographically smallest
if forward_canonical <= rc_canonical:
return forward_canonical, '+'
else:
return rc_canonical, '-'
@staticmethod
def is_primitive_motif(motif: str) -> bool:
"""Check if motif is not a repetition of a shorter motif."""
n = len(motif)
for i in range(1, n):
if n % i == 0:
period = motif[:i]
if period * (n // i) == motif:
return False
return True
@staticmethod
def calculate_entropy(seq: str) -> float:
"""Calculate Shannon entropy of sequence (bits per base)."""
if not seq:
return 0.0
from collections import Counter
counts = Counter(seq)
n = len(seq)
entropy = 0.0
for count in counts.values():
if count > 0:
p = count / n
entropy -= p * np.log2(p)
return entropy
@staticmethod
def is_transition(base1: str, base2: str) -> bool:
"""Check if a base change is a transition (A↔G or C↔T).
Transitions are more common than transversions in biology.
Purines: A, G (transitions within purines: A↔G)
Pyrimidines: C, T (transitions within pyrimidines: C↔T)
"""
if base1 == base2:
return True # No change
transitions = {
('A', 'G'), ('G', 'A'), # Purine transitions
('C', 'T'), ('T', 'C') # Pyrimidine transitions
}
return (base1, base2) in transitions
@staticmethod
def hamming_distance(s1: str, s2: str) -> int:
"""Calculate Hamming distance between two strings of equal length."""
if len(s1) != len(s2):
return max(len(s1), len(s2))
return sum(c1 != c2 for c1, c2 in zip(s1, s2))
@staticmethod
def hamming_distance_array(arr1: np.ndarray, arr2: np.ndarray) -> int:
"""Calculate Hamming distance between two uint8 arrays."""
if arr1.size != arr2.size:
return max(arr1.size, arr2.size)
return int(np.count_nonzero(arr1 != arr2))
@staticmethod
def count_transversions_array(arr1: np.ndarray, arr2: np.ndarray) -> int:
"""Count transversions (non-transition mismatches) between two uint8 arrays.
Returns number of transversion changes (A↔C, A↔T, G↔C, G↔T).
"""
if arr1.size != arr2.size:
return max(arr1.size, arr2.size)
# ASCII codes: A=65, C=67, G=71, T=84
transversions = 0
for i in range(arr1.size):
b1, b2 = arr1[i], arr2[i]
if b1 != b2:
# Convert to characters for transition check
c1 = chr(b1) if 65 <= b1 <= 84 else 'N'
c2 = chr(b2) if 65 <= b2 <= 84 else 'N'
if not MotifUtils.is_transition(c1, c2):
transversions += 1
return transversions
@staticmethod
def edit_distance(a: str, b: str) -> int:
"""Compute Levenshtein edit distance between two short strings."""
la, lb = len(a), len(b)
if la == 0:
return lb
if lb == 0:
return la
prev = list(range(lb + 1))
curr = [0] * (lb + 1)
for i in range(1, la + 1):
curr[0] = i
ai = a[i - 1]
for j in range(1, lb + 1):
cost = 0 if ai == b[j - 1] else 1
curr[j] = min(
prev[j] + 1, # deletion
curr[j - 1] + 1, # insertion
prev[j - 1] + cost # substitution
)
prev, curr = curr, prev
return prev[lb]
    @staticmethod
    def _align_unit_to_window(motif: str, window: str, max_indel: int,
                              mismatch_tolerance: int) -> Optional[AlignmentResult]:
        """Align motif to a window allowing mismatches and small indels.

        Banded global alignment of the full motif against a *prefix* of the
        window: the motif must be consumed entirely, while the window prefix
        length may vary within [m - max_indel, m + max_indel].

        Args:
            motif: template unit (aligned in full).
            window: sequence the unit is matched against (only a prefix is consumed).
            max_indel: cap on total inserted AND total deleted bases (each, clamped >= 0).
            mismatch_tolerance: cap on substitutions (clamped >= 0).
        Returns:
            AlignmentResult for the best-cost prefix, or None when the window
            is empty, the band excludes all end points, or tolerances are
            exceeded. NOTE: a zero-length best prefix (best_j == 0) is also
            rejected.
        """
        m = len(motif)
        n = len(window)
        if m == 0 or n == 0:
            return None
        max_indel = max(0, max_indel)
        mismatch_tolerance = max(0, mismatch_tolerance)
        # Allowed window-prefix lengths for a full-motif alignment
        lower = max(0, m - max_indel)
        upper = min(n, m + max_indel)
        if lower > upper:
            return None
        # --- Phase 1: banded edit-distance DP with traceback pointers ---
        inf = m + n + 10  # sentinel cost larger than any real alignment
        dp = [[inf] * (n + 1) for _ in range(m + 1)]
        ptr = [[''] * (n + 1) for _ in range(m + 1)]
        dp[0][0] = 0
        for j in range(1, n + 1):
            dp[0][j] = j
            ptr[0][j] = 'I'
        for i in range(1, m + 1):
            dp[i][0] = i
            ptr[i][0] = 'D'
        # Band width: cells with |i - j| > max_indel + 2 are never filled
        band_extra = max_indel + 2
        for i in range(1, m + 1):
            j_min = max(1, i - band_extra)
            j_max = min(n, i + band_extra)
            for j in range(j_min, j_max + 1):
                sub_cost = dp[i - 1][j - 1] + (motif[i - 1] != window[j - 1])
                del_cost = dp[i - 1][j] + 1
                ins_cost = dp[i][j - 1] + 1
                # Ties prefer match/substitution, then deletion, then insertion
                best_cost = sub_cost
                best_ptr = 'M' if motif[i - 1] == window[j - 1] else 'S'
                if del_cost < best_cost:
                    best_cost = del_cost
                    best_ptr = 'D'
                if ins_cost < best_cost:
                    best_cost = ins_cost
                    best_ptr = 'I'
                dp[i][j] = best_cost
                ptr[i][j] = best_ptr
        # --- Phase 2: pick the cheapest admissible end column (prefix length) ---
        best_j = -1
        best_cost = inf
        for j in range(lower, upper + 1):
            cost = dp[m][j]
            if cost < best_cost:
                best_cost = cost
                best_j = j
        if best_j <= 0 or best_cost >= inf:
            return None
        # --- Phase 3: traceback into parallel gapped strings ---
        aligned_ref = []
        aligned_query = []
        i, j = m, best_j
        while i > 0 or j > 0:
            op = ptr[i][j]
            if op in ('M', 'S'):
                aligned_ref.append(motif[i - 1])
                aligned_query.append(window[j - 1])
                i -= 1
                j -= 1
            elif op == 'D':
                aligned_ref.append(motif[i - 1])
                aligned_query.append('-')
                i -= 1
            elif op == 'I':
                aligned_ref.append('-')
                aligned_query.append(window[j - 1])
                j -= 1
            else:  # Should only occur at origin
                break
        aligned_ref.reverse()
        aligned_query.reverse()
        # --- Phase 4: compress gap runs into ins/del/sub operations ---
        operations: List[Tuple] = []
        observed_bases: List[Tuple[int, str]] = []
        mismatch_count = 0
        insertion_len = 0
        deletion_len = 0
        ref_pos = 0  # 1-based position in the motif after consuming a ref char
        pending_ins: List[str] = []   # run of inserted bases awaiting flush
        pending_ins_pos = 0
        pending_del_len = 0           # run of deleted ref bases awaiting flush
        pending_del_pos = 0
        for r, q in zip(aligned_ref, aligned_query):
            if r == '-':
                # Insertion relative to the motif: accumulate the run
                if not pending_ins:
                    pending_ins_pos = ref_pos
                pending_ins.append(q)
                continue
            if pending_ins:
                # Ref char resumes: flush the pending insertion run
                ins_seq = ''.join(pending_ins)
                operations.append(('ins', pending_ins_pos, ins_seq))
                insertion_len += len(ins_seq)
                pending_ins = []
                pending_ins_pos = 0
            ref_pos += 1
            if q == '-':
                # Deletion of a motif base: accumulate the run
                if pending_del_len == 0:
                    pending_del_pos = ref_pos
                pending_del_len += 1
                continue
            if pending_del_len:
                # Query char resumes: flush the pending deletion run
                operations.append(('del', pending_del_pos, pending_del_len))
                deletion_len += pending_del_len
                pending_del_len = 0
            # Aligned base observed at motif index ref_pos - 1 (0-based)
            observed_bases.append((ref_pos - 1, q))
            if r != q:
                operations.append(('sub', ref_pos, r, q))
                mismatch_count += 1
        # Flush trailing runs that reached the end of the alignment
        if pending_ins:
            ins_seq = ''.join(pending_ins)
            operations.append(('ins', pending_ins_pos, ins_seq))
            insertion_len += len(ins_seq)
        if pending_del_len:
            operations.append(('del', pending_del_pos, pending_del_len))
            deletion_len += pending_del_len
        # --- Phase 5: enforce tolerances ---
        if mismatch_count > mismatch_tolerance:
            return None
        if insertion_len > max_indel or deletion_len > max_indel:
            return None
        return AlignmentResult(
            consumed=best_j,
            unit_sequence=window[:best_j],
            mismatch_count=mismatch_count,
            insertion_length=insertion_len,
            deletion_length=deletion_len,
            operations=operations,
            observed_bases=observed_bases,
            edit_distance=best_cost
        )
@staticmethod
def _consensus_from_counts(counts: List[Counter], fallback: str) -> str:
"""Build consensus string from per-position base counts."""
consensus = []
for idx, counter in enumerate(counts):
if counter:
base, _ = counter.most_common(1)[0]
consensus.append(base)
else:
consensus.append(fallback[idx] if idx < len(fallback) else 'N')
return ''.join(consensus)
@staticmethod
def align_repeat_region(sequence: str, start: int, end: int, motif_template: str,
mismatch_fraction: float = 0.1,
max_indel: Optional[int] = None,