-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathedit_distance.py
149 lines (117 loc) · 4.28 KB
/
edit_distance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python
"""
===================================
Dynamic programming - Edit distance
===================================
Definition:
Given two strings, a and b, the edit distance d(a, b) is the minimum total
weight of a series of edit operations that transforms a into b.
Insertion: insert a char into a
Deletion: delete a char from a
Substitution: replace a char in a with a char from b
Suppose we have two strings: a = a1, a2, ..., an and b = b1, b2, ..., bm,
Wdel: Weight for deletion
Wins: Weight for insertion
Wsub: Weight for substitution
d(0, 0) = 0
d(i, 0) = ∑ Wdel(ak), k = [1, i], where 1 <= i <= n
d(0, j) = ∑ Wins(bk), k = [1, j], where 1 <= j <= m
For 1 <= i <= n, 1 <= j <= m:
    d(i, j) = d(i - 1, j - 1)                          if a(i) = b(j)
    d(i, j) = min | d(i - 1, j) + Wdel(ai)
                  | d(i, j - 1) + Wins(bj)             if a(i) <> b(j)
                  | d(i - 1, j - 1) + Wsub(ai, bj)
Refs:
[1] https://en.wikipedia.org/wiki/Edit_distance
[2] https://en.wikipedia.org/wiki/Wagner%E2%80%93Fischer_algorithm
"""
import sys
import time
import argparse
import functools
def weight_insertion(ch):
    """
    Return the cost charged for inserting the character `ch`.

    The cost is flat: every character is priced the same.

    :param str ch: The character being inserted
    :return int: Always 1
    """
    unit_cost = 1
    return unit_cost
def weight_deletion(ch):
    """
    Return the cost charged for deleting the character `ch`.

    The cost is flat: every character is priced the same.

    :param str ch: The character being deleted
    :return int: Always 1
    """
    unit_cost = 1
    return unit_cost
def weight_substitution(ch1, ch2):
    """
    Return the cost charged for substituting the character `ch1` with `ch2`.

    The cost is flat: it does not depend on either character.

    :param str ch1: The character being replaced
    :param str ch2: The replacement character
    :return int: Always 2
    """
    flat_cost = 2
    return flat_cost
@functools.lru_cache(maxsize=20)
def edit_distance(seq1, seq2):
    """
    Calculate the weighted edit distance between string `seq1` and `seq2`.

    The table of prefix distances is filled row by row (Wagner-Fischer), so
    the work is proportional to len(seq1) * len(seq2) and no recursion is
    involved. When the last characters of two prefixes are equal the
    diagonal step is free; otherwise it costs `weight_substitution`.

    NOTE(review): an unmatched char of `seq1` is charged `weight_insertion`
    and an unmatched char of `seq2` is charged `weight_deletion`, i.e. the
    edit script is oriented from `seq2` towards `seq1`. As long as the
    weights are symmetric the distance is the same in either direction.

    :param str seq1: String 1
    :param str seq2: String 2
    :return int: The edit distance
    """
    if seq1 == seq2:
        return 0

    # Row for the empty prefix of seq1: every char of seq2[:j] is unmatched.
    prev_row = [0]
    for ch2 in seq2:
        prev_row.append(prev_row[-1] + weight_deletion(ch2))

    for ch1 in seq1:
        drop_ch1 = weight_insertion(ch1)
        # Column 0: seq2 prefix is empty, so every char of seq1 is unmatched.
        row = [prev_row[0] + drop_ch1]
        for j, ch2 in enumerate(seq2, start=1):
            # Diagonal step: free on a match, a substitution otherwise.
            diagonal = prev_row[j - 1]
            if ch1 != ch2:
                diagonal += weight_substitution(ch1, ch2)
            row.append(min(
                drop_ch1 + prev_row[j],
                weight_deletion(ch2) + row[j - 1],
                diagonal,
            ))
        prev_row = row
    return prev_row[-1]
def edit_distance_with_solution(cache):
    """
    Create an edit-distance function that also records how each result
    was reached.

    The returned function takes `(seq1, seq2)` and returns their edit
    distance. As a side effect it writes one entry into `cache` for every
    pair of prefixes it evaluates, keyed by '<seq1>:<seq2>':

    * `[]` when the two strings are equal;
    * a list of `{'op': 'I', 'ch': c}`, one per char of `seq1`, when `seq2`
      is empty;
    * a list of `{'op': 'D', 'ch': c}`, one per char of `seq2`, when `seq1`
      is empty;
    * otherwise a single dict `{'key': ..., 'op': ..., 'ch': ...}` where
      'key' names the cache entry of the sub-problem the step builds on and
      'op' is 'I' (last char of `seq1` unmatched), 'D' (last char of `seq2`
      unmatched), 'S' (last chars differ, substitution) or 'M' (last chars
      match, no cost).

    NOTE(review): keys are ambiguous when a string itself contains ':'.
    NOTE(review): 'I' is attached to chars of `seq1` and 'D' to chars of
    `seq2`, i.e. the recorded script is oriented from `seq2` towards `seq1`.

    The computation is recursive, so very long inputs can hit the
    interpreter's recursion limit.

    :param dict cache: Mapping that receives the recorded steps
    :return: Function `(seq1, seq2) -> int` computing the edit distance
    """
    # The memo is unbounded on purpose: a small LRU evicts sub-problems before
    # they are reused, and `cache` already keeps one entry per sub-problem.
    @functools.lru_cache(maxsize=None)
    def wrapper(seq1, seq2):
        key = '{}:{}'.format(seq1, seq2)
        if seq1 == seq2:
            cache[key] = []
            return 0
        if len(seq2) == 0:
            cache[key] = [{'op': 'I', 'ch': ch} for ch in seq1]
            return sum(weight_insertion(ch) for ch in seq1)
        if len(seq1) == 0:
            cache[key] = [{'op': 'D', 'ch': ch} for ch in seq2]
            return sum(weight_deletion(ch) for ch in seq2)

        head1, last1 = seq1[:-1], seq1[-1]
        head2, last2 = seq2[:-1], seq2[-1]

        # Diagonal step: compare the LAST characters; a match is free.
        if last1 == last2:
            diag_op, diag_weight = 'M', 0
        else:
            diag_op, diag_weight = 'S', weight_substitution(last1, last2)

        candidates = (
            (
                {'key': '{}:{}'.format(head1, seq2), 'op': 'I', 'ch': last1},
                weight_insertion(last1) + wrapper(head1, seq2),
            ),
            (
                {'key': '{}:{}'.format(seq1, head2), 'op': 'D', 'ch': last2},
                weight_deletion(last2) + wrapper(seq1, head2),
            ),
            (
                {'key': '{}:{}'.format(head1, head2), 'op': diag_op, 'ch': last1},
                diag_weight + wrapper(head1, head2),
            ),
        )
        # On ties min() keeps the first candidate, i.e. 'I' before 'D' before
        # the diagonal step.
        op, distance = min(candidates, key=lambda item: item[1])
        cache[key] = op
        return distance
    return wrapper
# def build_solution(cache):
# def wrapper(seq):
# if seq
if __name__ == '__main__':
    # Command-line entry point: takes the two strings as positional
    # arguments and reports their edit distance twice -- once from the plain
    # function and once from the variant that records its steps.
    cli = argparse.ArgumentParser()
    for arg_name, arg_help in (('seq1', 'string1'), ('seq2', 'string2')):
        cli.add_argument(arg_name, type=str, help=arg_help)
    args = cli.parse_args(sys.argv[1:])

    report = 'Min edit distance between "{}" and "{}" is {}'

    dist = edit_distance(args.seq1, args.seq2)
    print(report.format(args.seq1, args.seq2, dist))

    # Same computation, with the chosen step of every sub-problem collected
    # in `cache` and dumped afterwards.
    cache = {}
    solver = edit_distance_with_solution(cache)
    dist = solver(args.seq1, args.seq2)
    print(report.format(args.seq1, args.seq2, dist))
    print(cache)