Skip to content

Commit 6eabfb5

Browse files
authored
Merge pull request #515 from GIScience/refactor-celery-workflow
refactor(celery-workflow): avoid doing clip twice
2 parents d99378f + 77da6c6 commit 6eabfb5

19 files changed

+88955
-166319
lines changed

sketch_map_tool/database/client_flask.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
from uuid import UUID
32

43
import psycopg2
@@ -34,27 +33,7 @@ def close_connection(e=None):
3433
db_conn.close()
3534

3635

37-
def _insert_id_map(uuid: str, map_: dict):
38-
create_query = """
39-
CREATE TABLE IF NOT EXISTS uuid_map(
40-
uuid uuid PRIMARY KEY,
41-
map json NOT NULL
42-
)
43-
"""
44-
insert_query = "INSERT INTO uuid_map(uuid, map) VALUES (%s, %s)"
45-
db_conn = open_connection()
46-
with db_conn.cursor() as curs:
47-
curs.execute(create_query)
48-
curs.execute(insert_query, [uuid, json.dumps(map_)])
49-
50-
51-
def _delete_id_map(uuid: str):
52-
query = "DELETE FROM uuid_map WHERE uuid = %s"
53-
db_conn = open_connection()
54-
with db_conn.cursor() as curs:
55-
curs.execute(query, [uuid])
56-
57-
36+
# TODO: Legacy support: Delete this function after PR 515 has been deployed for 1 day
5837
def _select_id_map(uuid) -> dict:
5938
query = "SELECT map FROM uuid_map WHERE uuid = %s"
6039
db_conn = open_connection()
@@ -69,6 +48,7 @@ def _select_id_map(uuid) -> dict:
6948
)
7049

7150

51+
# TODO: Legacy support: Delete this function after PR 515 has been deployed for 1 day
7252
def get_async_result_id(request_uuid: str, request_type: REQUEST_TYPES) -> str:
7353
"""Get the Celery Async Result IDs for a request."""
7454
map_ = _select_id_map(request_uuid)
@@ -84,11 +64,6 @@ def get_async_result_id(request_uuid: str, request_type: REQUEST_TYPES) -> str:
8464
) from error
8565

8666

87-
def set_async_result_ids(request_uuid, map_: dict[REQUEST_TYPES, str]):
88-
"""Set the Celery Result IDs for a request."""
89-
_insert_id_map(request_uuid, map_)
90-
91-
9267
def insert_files(
9368
files, consent: bool
9469
) -> tuple[list[int], list[str], list[str], list[Bbox], list[Layer]]:
@@ -215,11 +190,3 @@ def select_map_frame(uuid: UUID) -> tuple[bytes, str, str]:
215190
),
216191
{"UUID": uuid},
217192
)
218-
219-
220-
def delete_map_frame(uuid: UUID):
221-
"""Delete map frame of the associated UUID from the database."""
222-
query = "DELETE FROM map_frame WHERE uuid = %s"
223-
db_conn = open_connection()
224-
with db_conn.cursor() as curs:
225-
curs.execute(query, [str(uuid)])

sketch_map_tool/helpers.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,17 @@ def merge(fcs: list[FeatureCollection]) -> FeatureCollection:
5757

5858

5959
def zip_(results: list[tuple[str, str, BytesIO]]) -> BytesIO:
60-
"""ZIP the results of the Celery group of `georeference_sketch_map` tasks."""
60+
"""ZIP the raster results of the Celery group of `upload_processing` tasks."""
6161
buffer = BytesIO()
62-
raw = set([r[1].replace("<br />", "\n") for r in results])
63-
attributions = BytesIO("\n".join(raw).encode())
62+
attributions = []
6463
with ZipFile(buffer, "a") as zip_file:
65-
for file_name, _, file in results:
64+
for file_name, attribution, file in results:
6665
stem = Path(file_name).stem
6766
name = Path(stem).with_suffix(".geotiff")
6867
zip_file.writestr(str(name), file.read())
69-
zip_file.writestr("attributions.txt", attributions.read())
68+
attributions.append(attribution.replace("<br />", "\n"))
69+
file = BytesIO("\n".join(set(attributions)).encode())
70+
zip_file.writestr("attributions.txt", file.read())
7071
buffer.seek(0)
7172
return buffer
7273

sketch_map_tool/routes.py

Lines changed: 72 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from io import BytesIO
33
from pathlib import Path
4-
from uuid import UUID, uuid4
4+
from uuid import UUID
55

66
import geojson
77
from celery import chord, group
@@ -15,6 +15,7 @@
1515
send_from_directory,
1616
url_for,
1717
)
18+
from psycopg2.errors import UndefinedTable
1819
from werkzeug import Response
1920

2021
from sketch_map_tool import celery_app, config, definitions, tasks
@@ -28,12 +29,11 @@
2829
UploadLimitsExceededError,
2930
UUIDNotFoundError,
3031
)
31-
from sketch_map_tool.helpers import extract_errors, merge, to_array, zip_
32+
from sketch_map_tool.helpers import N_, extract_errors, merge, to_array, zip_
3233
from sketch_map_tool.models import Bbox, Layer, PaperFormat, Size
3334
from sketch_map_tool.tasks import (
3435
cleanup_blobs,
35-
digitize_sketches,
36-
georeference_sketch_map,
36+
upload_processing,
3737
)
3838
from sketch_map_tool.validators import (
3939
validate_type,
@@ -114,9 +114,7 @@ def create_results_post(lang="en") -> Response:
114114
"""Create the sketch map"""
115115
# Request parameters
116116
bbox_raw = json.loads(request.form["bbox"])
117-
bbox_wgs84_raw = json.loads(request.form["bboxWGS84"])
118117
bbox = Bbox(*bbox_raw)
119-
bbox_wgs84 = Bbox(*bbox_wgs84_raw)
120118
format_raw = request.form["format"]
121119
format_: PaperFormat = getattr(definitions, format_raw.upper())
122120
orientation = request.form["orientation"]
@@ -125,30 +123,33 @@ def create_results_post(lang="en") -> Response:
125123
scale = float(request.form["scale"])
126124
layer = Layer(request.form["layer"].replace(":", "-").replace("_", "-").lower())
127125

128-
# feature flag for enabling aruco markers
126+
# Feature flag for enabling aruco markers
129127
if request.args.get("aruco") is None:
130128
aruco = False
131129
else:
132130
aruco = True
133131

134-
# Unique id for current request
135-
uuid = str(uuid4())
136-
137132
# Tasks
138133
task_sketch_map = tasks.generate_sketch_map.apply_async(
139-
args=(uuid, bbox, format_, orientation, size, scale, layer, aruco)
140-
)
141-
task_quality_report = tasks.generate_quality_report.apply_async(
142-
args=tuple([bbox_wgs84])
134+
args=(bbox, format_, orientation, size, scale, layer, aruco)
143135
)
136+
return redirect(url_for("create_results_get", lang=lang, uuid=task_sketch_map.id))
144137

145-
# Map of request type to multiple Async Result IDs
146-
map_ = {
147-
"sketch-map": str(task_sketch_map.id),
148-
"quality-report": str(task_quality_report.id),
149-
}
150-
db_client_flask.set_async_result_ids(uuid, map_)
151-
return redirect(url_for("create_results_get", lang=lang, uuid=uuid))
138+
139+
def get_async_result_id(uuid: str, type_: REQUEST_TYPES):
140+
"""Get Celery Async or Group Result UUID for given request UUID.
141+
142+
Try to get Celery UUID for given request from datastore.
143+
If no Celery UUID has been found the request UUID is the same as the Celery UUID.
144+
145+
This function exists only for legacy support.
146+
"""
147+
# TODO: Legacy support: Delete this function after PR 515 has been deployed
148+
# for 1 day
149+
try:
150+
return db_client_flask.get_async_result_id(uuid, type_)
151+
except (UUIDNotFoundError, UndefinedTable):
152+
return uuid
152153

153154

154155
@app.get("/create/results")
@@ -160,8 +161,8 @@ def create_results_get(lang="en", uuid: str | None = None) -> Response | str:
160161
return redirect(url_for("create", lang=lang))
161162
validate_uuid(uuid)
162163
# Check if celery tasks for UUID exists
163-
_ = db_client_flask.get_async_result_id(uuid, "sketch-map")
164-
_ = db_client_flask.get_async_result_id(uuid, "quality-report")
164+
id_ = get_async_result_id(uuid, "sketch-map")
165+
_ = get_async_result(id_, "sketch-map")
165166
return render_template("create-results.html", lang=lang)
166167

167168

@@ -205,11 +206,10 @@ def digitize_results_post(lang="en") -> Response:
205206
bboxes_[uuid] = bbox
206207
layers_[uuid] = layer
207208

208-
tasks_vector = []
209-
tasks_raster = []
209+
tasks = []
210210
for file_id, file_name, uuid in zip(file_ids, file_names, uuids):
211-
tasks_vector.append(
212-
digitize_sketches.signature(
211+
tasks.append(
212+
upload_processing.signature(
213213
(
214214
file_id,
215215
file_name,
@@ -219,40 +219,24 @@ def digitize_results_post(lang="en") -> Response:
219219
)
220220
)
221221
)
222-
tasks_raster.append(
223-
georeference_sketch_map.signature(
224-
(
225-
file_id,
226-
file_name,
227-
map_frames[uuid],
228-
layers_[uuid],
229-
bboxes_[uuid],
230-
)
231-
)
232-
)
233-
async_result_raster = group(tasks_raster).apply_async()
234-
c = chord(
235-
group(tasks_vector),
222+
chord_ = chord(
223+
group(tasks),
236224
cleanup_blobs.signature(
237225
kwargs={"file_ids": list(set(file_ids))},
238226
immutable=True,
239227
),
240228
).apply_async()
241-
async_result_vector = c.parent
229+
async_group_result = chord_.parent
242230

243231
# group results have to be saved for them to be able to be restored later
244-
async_result_raster.save()
245-
async_result_vector.save()
246-
247-
# Unique id for current request
248-
uuid = str(uuid4())
249-
# Mapping of request id to multiple tasks id's
250-
map_ = {
251-
"raster-results": str(async_result_raster.id),
252-
"vector-results": str(async_result_vector.id),
253-
}
254-
db_client_flask.set_async_result_ids(uuid, map_)
255-
return redirect(url_for("digitize_results_get", lang=lang, uuid=uuid))
232+
async_group_result.save()
233+
return redirect(
234+
url_for(
235+
"digitize_results_get",
236+
lang=lang,
237+
uuid=async_group_result.id,
238+
)
239+
)
256240

257241

258242
@app.get("/digitize/results")
@@ -266,19 +250,32 @@ def digitize_results_get(lang="en", uuid: str | None = None) -> Response | str:
266250
return render_template("digitize-results.html", lang=lang)
267251

268252

253+
def get_async_result(uuid: str, type_: REQUEST_TYPES) -> AsyncResult | GroupResult:
254+
"""Get Celery `AsyncResult` or restore `GroupResult` for given Celery UUID."""
255+
if type_ in ("sketch-map", "quality-report"):
256+
async_result = celery_app.AsyncResult(uuid)
257+
elif type_ in ("vector-results", "raster-results"):
258+
async_result = celery_app.GroupResult.restore(uuid)
259+
else:
260+
raise TypeError()
261+
262+
if async_result is None:
263+
raise UUIDNotFoundError(
264+
N_("There are no tasks for UUID {UUID}"),
265+
{"UUID": uuid},
266+
)
267+
else:
268+
return async_result
269+
270+
269271
@app.get("/api/status/<uuid>/<type_>")
270272
@app.get("/<lang>/api/status/<uuid>/<type_>")
271273
def status(uuid: str, type_: REQUEST_TYPES, lang="en") -> Response:
272274
validate_uuid(uuid)
273275
validate_type(type_)
274276

275-
id_ = db_client_flask.get_async_result_id(uuid, type_)
276-
277-
# due to legacy support it is not possible to check only `type_`
278-
# (in the past every Celery result was of type `AsyncResult`)
279-
async_result = celery_app.GroupResult.restore(id_)
280-
if async_result is None:
281-
async_result = celery_app.AsyncResult(id_)
277+
id_ = get_async_result_id(uuid, type_)
278+
async_result = get_async_result(id_, type_)
282279

283280
href = ""
284281
info = ""
@@ -336,18 +333,18 @@ def download(uuid: str, type_: REQUEST_TYPES, lang="en") -> Response:
336333
validate_uuid(uuid)
337334
validate_type(type_)
338335

339-
id_ = db_client_flask.get_async_result_id(uuid, type_)
336+
id_ = get_async_result_id(uuid, type_)
337+
async_result = get_async_result(id_, type_)
340338

341-
# due to legacy support it is not possible to check only `type_`
342-
# (in the past every Celery result was of type `AsyncResult`)
343-
async_result = celery_app.GroupResult.restore(id_)
344-
if async_result is None:
345-
async_result = celery_app.AsyncResult(id_)
346-
if not async_result.ready() or async_result.failed():
339+
# Abort if result not ready or failed.
340+
# No nice error message here because user should first check /api/status.
341+
if isinstance(async_result, GroupResult):
342+
if not async_result.ready() or all([r.failed() for r in async_result.results]):
347343
abort(500)
348344
else:
349-
if not async_result.ready() or all([r.failed() for r in async_result.results]):
345+
if not async_result.ready() or async_result.failed():
350346
abort(500)
347+
351348
match type_:
352349
case "quality-report":
353350
mimetype = "application/pdf"
@@ -361,16 +358,19 @@ def download(uuid: str, type_: REQUEST_TYPES, lang="en") -> Response:
361358
mimetype = "application/zip"
362359
download_name = type_ + ".zip"
363360
if isinstance(async_result, GroupResult):
364-
file: BytesIO = zip_(async_result.get(propagate=False))
361+
results = async_result.get(propagate=False)
362+
raster_results = [r[:-1] for r in results]
363+
file: BytesIO = zip_(raster_results)
365364
else:
366365
# support legacy results
367366
file: BytesIO = async_result.get()
368367
case "vector-results":
369368
mimetype = "application/geo+json"
370369
download_name = type_ + ".geojson"
371370
if isinstance(async_result, GroupResult):
372-
result: list = async_result.get(propagate=False)
373-
raw = geojson.dumps(merge(result))
371+
results = async_result.get(propagate=False)
372+
vector_results = [r[-1] for r in results]
373+
raw = geojson.dumps(merge(vector_results))
374374
file: BytesIO = BytesIO(raw.encode("utf-8"))
375375
else:
376376
# support legacy results

0 commit comments

Comments
 (0)