Commit d2f6092

refactored imports
1 parent 2666b6f commit d2f6092

File tree

3 files changed: 39 additions & 42 deletions


docs/paper/sectflow.rst

Lines changed: 37 additions & 40 deletions
@@ -15,7 +15,7 @@ Data preparation

    import numpy as np
    from traffic.core import Traffic
-
+
    t = (
        # trajectories during the opening hours of the sector
        Traffic.from_file("data/LFBBPT_flights_2017.pkl")
@@ -42,14 +42,14 @@ the following `github repository <https://github.com/lbasora/sectflow>`_

.. code:: python

-    from traffic.core.projection import Lambert93
-
+    from cartes.crs import Lambert93
+
    # pip install git+https://github.com/lbasora/sectflow
    from sectflow.clustering import TrajClust
-
+
    features = ["x", "y", "latitude", "longitude", "altitude", "log_altitude"]
    clustering = TrajClust(features)
-
+
    # use the clustering API from traffic
    t_cluster = t.clustering(
        nb_samples=2, features=features, projection=Lambert93(), clustering=clustering
@@ -59,7 +59,7 @@ the following `github repository <https://github.com/lbasora/sectflow>`_

    # Color distribution by cluster
    from itertools import cycle, islice
-
+
    n_clusters = 1 + t_cluster.data.cluster.max()
    color_cycle = cycle(
        "#fbbb35 #004cb9 #4cc700 #a50016 #510420 #01bcf5 #999999 #e60085 #ffa9c5".split()
@@ -69,29 +69,29 @@ the following `github repository <https://github.com/lbasora/sectflow>`_

    import matplotlib.pyplot as plt
    from random import sample
-
+
    from traffic.data import airways, aixm_airspaces
    from traffic.drawing.markers import rotate_marker, atc_tower, aircraft
-
+
    with plt.style.context("traffic"):
        fig, ax = plt.subplots(1, figsize=(15, 10), subplot_kw=dict(projection=Lambert93()))
-
+
        aixm_airspaces["LFBBPT"].plot(
            ax, linewidth=3, linestyle="dashed", color="steelblue"
        )
        for name in "UN460 UN869 UM728".split():
            airways[name].plot(ax, linestyle="dashed", color="#aaaaaa")
-
+
        # do not plot outliers
        for cluster in range(n_clusters):
-
+
            current_cluster = t_cluster.query(f"cluster == {cluster}")
-
+
            # plot the centroid of each cluster
            centroid = current_cluster.centroid(50, projection=Lambert93())
            centroid.plot(ax, color=colors[cluster], alpha=0.9, linewidth=3)
            centroid_mark = centroid.at_ratio(0.45)
-
+
            # little aircraft
            centroid_mark.plot(
                ax,
@@ -100,22 +100,22 @@ the following `github repository <https://github.com/lbasora/sectflow>`_
                s=500,
                text_kw=dict(s=""),  # no text associated
            )
-
+
            # plot some sample flights from each cluster
            sample_size = min(20, len(current_cluster))
            for flight_id in sample(current_cluster.flight_ids, sample_size):
                current_cluster[flight_id].plot(
                    ax, color=colors[cluster], alpha=0.1, linewidth=2
                )
-
+
            # TODO improve this: extent with buffer
            ax.set_extent(
                tuple(
                    x - 0.5 + (0 if i % 2 == 0 else 1)
                    for i, x in enumerate(aixm_airspaces["LFBBPT"].extent)
                )
            )
-
+
    # Equivalent of Fig. 5

@@ -134,13 +134,13 @@ The anomaly detection method is based on a stacked autoencoder
    import torch
    from torch import nn, optim, from_numpy, rand
    from torch.autograd import Variable
-
+
    from sklearn.preprocessing import minmax_scale
    from tqdm.autonotebook import tqdm
-
-
+
+
    # Stacked autoencoder
-
+
    class Autoencoder(nn.Module):
        def __init__(self):
            super().__init__()
@@ -150,54 +150,54 @@ The anomaly detection method is based on a stacked autoencoder
            self.decoder = nn.Sequential(
                nn.Linear(12, 24), nn.ReLU(), nn.Linear(24, 50), nn.Sigmoid()
            )
-
+
        def forward(self, x, **kwargs):
            x = x + (rand(50).cuda() - 0.5) * 1e-3  # add some noise
            x = self.encoder(x)
            x = self.decoder(x)
            return x
-
+
    # Regularisation term introduced in IV.B.2
-
+
    def regularisation_term(X, n):
        samples = torch.linspace(0, X.max(), 100, requires_grad=True)
        mean = samples.mean()
        return torch.relu(
            (torch.histc(X) / n * 100 - 1 / mean * torch.exp(-samples / mean))
        ).mean()
-
+
    # ML part
-
+
    def anomalies(t: Traffic, cluster_id: int, lambda_r: float, nb_it: int = 10000):
-
+
        t_id = t.query(f"cluster=={cluster_id}")
-
+
        flight_ids = list(f.flight_id for f in t_id)
        n = len(flight_ids)
        X = minmax_scale(np.vstack(f.data.track[:50] for f in t_id))
-
+
        model = Autoencoder().cuda()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
-
+
        for epoch in tqdm(range(nb_it), leave=False):
-
+
            v = Variable(from_numpy(X.astype(np.float32))).cuda()
-
+
            output = model(v)
            distance = nn.MSELoss(reduction="none")(output, v).sum(1).sqrt()
-
+
            loss = criterion(output, v)
            # regularisation
            loss = (
                lambda_r * regularisation_term(distance.cpu().detach(), n)
                + criterion(output, v).cpu()
            )
-
+
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
-
+
        output = model(v)
        return (
            (nn.MSELoss(reduction="none")(output, v).sum(1)).sqrt().cpu().detach().numpy()
@@ -215,16 +215,16 @@ parameter helps reducing this trend.
.. code:: python

    from scipy.stats import expon
-
+
    # Equivalent of Fig. 4
-
+
    with plt.style.context("traffic"):
        fig, ax = plt.subplots(1, figsize=(10, 7))
        hst = ax.hist(output, bins=50, density=True)
        mean = output.mean()
        x = np.arange(0, output.max(), 1e-2)
        e = expon.pdf(x, 0, output.mean())
-
+
        ax.plot(x, e, color="#e77074")
        ax.fill_between(x, e, zorder=-2, color="#e77074", alpha=0.5)
        ax.axvline(
@@ -234,6 +234,3 @@ parameter helps reducing this trend.

.. image:: images/sectflow_distribution.png
    :align: center
-
-
-
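
The substantive change in this file is the new import location for the Lambert93 projection, now provided by `cartes.crs` instead of `traffic.core.projection`; the remaining hunks only touch blank lines. For readers running the example against an older traffic release, a minimal compatibility sketch could look like the following (the fallback import is an assumption about pre-cartes versions, not part of this commit):

.. code:: python

    # Minimal sketch: prefer the new cartes location for the projection and
    # fall back to the legacy traffic module when cartes is not installed.
    # Assumption: both modules expose an equivalent Lambert93 class.
    try:
        from cartes.crs import Lambert93
    except ImportError:
        from traffic.core.projection import Lambert93  # older traffic releases

    projection = Lambert93()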

docs/tutorial/generation.rst

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ from the north.

    import matplotlib.pyplot as plt
    from traffic.data.datasets import landing_zurich_2019
-    from traffic.core.projection import EuroPP
+    from cartes.crs import EuroPP

    t = (
        landing_zurich_2019
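
The same refactoring applies here: EuroPP now comes from `cartes.crs`. For context, the projection object is what gets handed to matplotlib; a short sketch of how this import is typically used with the Zurich landing dataset follows (the figure setup and the per-flight plotting loop are illustrative assumptions, not part of this diff):

.. code:: python

    import matplotlib.pyplot as plt

    from cartes.crs import EuroPP  # new import location
    from traffic.data.datasets import landing_zurich_2019

    # EuroPP() is a cartopy-compatible CRS: passing it as the subplot
    # projection draws the trajectories on a projected map of Europe.
    with plt.style.context("traffic"):
        fig, ax = plt.subplots(subplot_kw=dict(projection=EuroPP()))
        # Traffic objects iterate over Flight objects; Flight.plot draws
        # each trajectory on the projected axis.
        for flight in landing_zurich_2019:
            flight.plot(ax, alpha=0.1)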

src/traffic/core/traffic.py

Lines changed: 1 addition & 1 deletion
@@ -1492,7 +1492,7 @@ def clustering(

        Example usage:

-        >>> from traffic.core.projection import EuroPP
+        >>> from cartes.crs import EuroPP
        >>> from sklearn.cluster import DBSCAN
        >>> from sklearn.preprocessing import StandardScaler
        >>>
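
The docstring hunk stops right after the imports; a sketch of how these pieces typically combine in the clustering call is given below (parameter values are illustrative, and the `transform=` keyword for the scaler is an assumption based on the surrounding docstring, not visible in this hunk):

.. code:: python

    from cartes.crs import EuroPP
    from sklearn.cluster import DBSCAN
    from sklearn.preprocessing import StandardScaler

    # Hypothetical call: `t` is a Traffic object; each trajectory is
    # resampled to nb_samples points, projected to x/y with EuroPP,
    # scaled, then clustered with DBSCAN (mirrors the sectflow example above).
    t_cluster = t.clustering(
        nb_samples=15,
        features=["x", "y"],
        projection=EuroPP(),
        clustering=DBSCAN(eps=1.5, min_samples=10),
        transform=StandardScaler(),
    )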
