-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsearch.py
100 lines (73 loc) · 3.33 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gmplot, time
import pandas as pd
from ast import literal_eval
from distances import DTW, LCS
from drawTrajectories import drawTrajectory
def _split_coords(trajectory):
    """Return parallel (latitudes, longitudes) lists for a sequence of points.

    Points are indexed as elem[2] = latitude and elem[1] = longitude, the
    convention used throughout this module; elem[0] is not read here.
    """
    latitudes = [point[2] for point in trajectory]
    longitudes = [point[1] for point in trajectory]
    return latitudes, longitudes


def findClosestNeighbors(trainSet, testSet, method, k=5):
    """Brute-force the k closest train trajectories for every test query.

    Each query trajectory in ``testSet`` is compared against every train
    trajectory with ``DTW`` or ``LCS``. The query is drawn first. Then each
    of its k best matches is drawn via ``drawTrajectory`` to HTML files
    under ``static/``, and a short report is printed for each neighbor.

    Args:
        trainSet: DataFrame with 'tripId', 'journeyPatternId' and
            'Trajectory' columns.
        testSet: DataFrame with a 'Trajectory' column. Its index labels are
            used in the output file names.
        method: 'DTW' (a smaller distance is closer) or 'LCS' (more
            matching points is closer).
        k: number of neighbors reported per query (default 5). Fewer are
            reported when trainSet has fewer than k rows.

    Raises:
        ValueError: if ``method`` is neither 'DTW' nor 'LCS'.
    """
    if method not in ('DTW', 'LCS'):
        raise ValueError("method must be 'DTW' or 'LCS', got %r" % (method,))

    # Extract the train columns once, rather than once per neighbor per query.
    tripIdList = trainSet['tripId'].tolist()
    jupIdList = trainSet['journeyPatternId'].tolist()
    trajectList = trainSet['Trajectory'].tolist()

    for testIndex, testTraj in zip(testSet.index, testSet['Trajectory']):
        # Each candidate is (score, tripId, pos, lcsPath); a smaller score is closer.
        # Keeping pos avoids a tripId lookup that misfires on duplicate ids.
        dists = []
        for pos, (trainTraj, tripId) in enumerate(zip(trajectList, tripIdList)):
            if method == 'DTW':
                score = DTW(testTraj, trainTraj)
                lcsPath = None
            else:
                ret = LCS(testTraj, trainTraj)  # keep both, the value of LCS and the path
                # Negate so that an ascending sort puts the most matches first.
                score = -ret[0]
                lcsPath = ret[1]
            dists.append((score, tripId, pos, lcsPath))
        # Order by score, breaking ties by tripId; LCS paths are never compared.
        dists.sort(key=lambda entry: entry[:3])

        # draw testTrajectory
        latitudes, longitudes = _split_coords(testTraj)
        path = 'static/query2_' + method + '_' + str(testIndex) + '.html'
        drawTrajectory(latitudes, longitudes, path)

        # operate on the k closest neighbors (fewer if the train set is smaller)
        for i, (score, tripId, pos, lcsPath) in enumerate(dists[:k]):
            latitudes, longitudes = _split_coords(trajectList[pos])
            if method == 'LCS':
                # Draw the neighbor's full route together with its matched (LCS) points.
                redLat, redLong = _split_coords(lcsPath)
                latitudes = [latitudes, redLat]
                longitudes = [longitudes, redLong]
            # Single-string print() calls give the same output on Python 2 and 3.
            print('')
            path = 'static/query2_' + method + '_' + str(testIndex) + '_' + str(i) + '.html'
            drawTrajectory(latitudes, longitudes, path, method == 'LCS')
            print('Neighbor %d' % (i + 1))
            print('JP_ID: %s' % (jupIdList[pos],))
            if method == 'DTW':
                print('DTW: ' + str(score) + 'km')
            else:
                print('#Matching Points: %s' % (-score,))
            print('----------------------------------')
def main(dist='LCS'):
    """Run the closest-neighbor search for one distance method and time it.

    Loads the training set and the test set that belongs to ``dist`` (the
    DTW and LCS queries use different test files). It then runs
    findClosestNeighbors and prints the elapsed wall-clock time.

    Args:
        dist: 'DTW' or 'LCS'. The default 'LCS' is the value that was
            previously hard-coded.

    Raises:
        ValueError: if ``dist`` has no associated test set. Previously this
            surfaced later as a NameError on ``testSet``.
    """
    testFiles = {
        'DTW': 'datasets/test_set_a1.csv',
        'LCS': 'datasets/test_set_a2.csv',
    }
    if dist not in testFiles:
        raise ValueError("dist must be 'DTW' or 'LCS', got %r" % (dist,))
    # Parse the Trajectory column's text as a Python literal when loading.
    converters = {'Trajectory': literal_eval}
    trainSet = pd.read_csv('datasets/train_set.csv', converters=converters)
    testSet = pd.read_csv(testFiles[dist], converters=converters)
    startTime = time.time()
    findClosestNeighbors(trainSet, testSet, dist)
    elapsedTime = time.time() - startTime
    # Single-string print() gives the same output on Python 2 and 3.
    print('Time: ' + str(round(elapsedTime, 2)) + 'sec')
if __name__ == '__main__':
    # Script entry point: run the search only when executed directly,
    # never as a side effect of importing this module.
    main()