feat(backend,prepro): use etag to reduce database load #2768

Open
wants to merge 3 commits into base: main
@@ -134,6 +134,7 @@ open class SubmissionController(
),
],
)
@ApiResponse(responseCode = "304", description = "Not Modified")
Contributor:

Let's maybe add a test for this - we can basically just copy the test we had for get-released-data?

@ApiResponse(responseCode = "422", description = EXTRACT_UNPROCESSED_DATA_ERROR_RESPONSE)
@PostMapping("/extract-unprocessed-data", produces = [MediaType.APPLICATION_NDJSON_VALUE])
fun extractUnprocessedData(
@@ -143,6 +144,7 @@ open class SubmissionController(
message = "You can extract at max $MAX_EXTRACTED_SEQUENCE_ENTRIES sequence entries at once.",
) numberOfSequenceEntries: Int,
@RequestParam pipelineVersion: Long,
@RequestHeader(value = HttpHeaders.IF_NONE_MATCH, required = false) ifNoneMatch: String?,
): ResponseEntity<StreamingResponseBody> {
val currentProcessingPipelineVersion = submissionDatabaseService.getCurrentProcessingPipelineVersion()
if (pipelineVersion < currentProcessingPipelineVersion) {
@@ -152,8 +154,12 @@
)
}

val lastDatabaseWriteETag = releasedDataModel.getLastDatabaseWriteETag()
if (ifNoneMatch == lastDatabaseWriteETag) return ResponseEntity.status(HttpStatus.NOT_MODIFIED).build()

val headers = HttpHeaders()
headers.contentType = MediaType.parseMediaType(MediaType.APPLICATION_NDJSON_VALUE)
headers.eTag = lastDatabaseWriteETag
val streamBody = streamTransactioned {
submissionDatabaseService.streamUnprocessedSubmissions(numberOfSequenceEntries, organism, pipelineVersion)
}
@@ -146,6 +146,8 @@ and roll back the whole transaction.
const val GET_RELEASED_DATA_DESCRIPTION = """
Get released data as a stream of NDJSON.
This returns all accession versions that have the status 'APPROVED_FOR_RELEASE'.
Optionally, submit the ETag received in a previous request via the If-None-Match header
to retrieve the released data only if the database has changed since the last request.
"""

const val GET_RELEASED_DATA_RESPONSE_DESCRIPTION = """
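
As a rough sketch of the conditional-request flow this description documents, a client can cache the ETag between polls and send it back via If-None-Match. The backend URL, endpoint path, and use of the requests library below are assumptions for illustration, not part of this PR:

import requests

BACKEND = "https://backend.example.org"  # hypothetical backend URL

def fetch_released_data(cached_etag: str | None) -> tuple[str | None, str | None]:
    """Return (etag, ndjson); ndjson is None when the backend answers 304 Not Modified."""
    headers = {"If-None-Match": cached_etag} if cached_etag else {}
    response = requests.get(f"{BACKEND}/get-released-data", headers=headers, timeout=10)
    if response.status_code == 304:
        return cached_etag, None  # database unchanged since the last request
    response.raise_for_status()
    return response.headers.get("ETag"), response.text  # new tag plus the NDJSON body

A caller would hold on to the returned tag for the next poll and drop it (set it to None) when it wants to force a full refresh, as the preprocessing loops in this PR do once per hour.
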
5 changes: 2 additions & 3 deletions kubernetes/loculus/silo_import_job.sh
@@ -85,7 +85,6 @@ download_data() {
http_status_code=$(curl -o "$new_input_data_path" --fail-with-body "$released_data_endpoint" -H "If-None-Match: $last_etag" -D "$new_input_header_path" -w "%{http_code}")
exit_code=$?
set -e

echo "Release data request returned with http status code: $http_status_code"
if [ "$http_status_code" -eq 304 ]; then
echo "State in Loculus backend has not changed: HTTP 304 Not Modified."
@@ -109,8 +108,8 @@ download_data() {
expected_record_count=$(grep -i '^x-total-records:' "$new_input_header_path" | awk '{print $2}' | tr -d '[:space:]')
echo "Response should contain a total of : $expected_record_count records"

# jq validates each individual json object, to catch truncated lines
true_record_count=$(zstd -d -c "$new_input_data_path" | jq -c . | wc -l | tr -d '[:space:]')
# jq validates each individual json object, to catch truncated lines
true_record_count=$(zstd -d -c "$new_input_data_path" | jq -c . | wc -l | tr -d '[:space:]')
echo "Response contained a total of : $true_record_count records"

if [ "$true_record_count" -ne "$expected_record_count" ]; then
58 changes: 32 additions & 26 deletions preprocessing/dummy/main.py
@@ -87,21 +87,28 @@ class Sequence:
)


def fetch_unprocessed_sequences(n: int) -> List[Sequence]:
def fetch_unprocessed_sequences(etag: str | None, n: int) -> tuple[str | None, List[Sequence]]:
url = backendHost + "/extract-unprocessed-data"
params = {"numberOfSequenceEntries": n, "pipelineVersion": pipeline_version}
headers = {"Authorization": "Bearer " + get_jwt()}
headers = {
"Authorization": "Bearer " + get_jwt(),
**({"If-None-Match": etag} if etag else {}),
}
response = requests.post(url, data=params, headers=headers)
if not response.ok:
if response.status_code == 422:
logging.debug("{}. Sleeping for a while.".format(response.text))
match response.status_code:
case 200:
return response.headers.get("ETag"), parse_ndjson(response.text)
case 304:
return etag, []
case 422:
logging.debug(f"{response.text}. Sleeping for a while.")
time.sleep(60 * 10)
return []
raise Exception(
"Fetching unprocessed data failed. Status code: {}".format(response.status_code),
response.text,
)
return parse_ndjson(response.text)
return None, []
case _:
raise Exception(
f"Fetching unprocessed data failed. Status code: {response.status_code}",
response.text,
)


def parse_ndjson(ndjson_data: str) -> List[Sequence]:
@@ -181,7 +188,7 @@ def submit_processed_sequences(processed: List[Sequence]):
response = requests.post(url, data=ndjson_string, headers=headers)
if not response.ok:
raise Exception(
"Submitting processed data failed. Status code: {}".format(response.status_code),
f"Submitting processed data failed. Status code: {response.status_code}",
response.text,
)

@@ -196,45 +203,44 @@ def get_jwt():
}
response = requests.post(url, data=data)
if not response.ok:
raise Exception(
"Fetching JWT failed. Status code: {}".format(response.status_code), response.text
)
raise Exception(f"Fetching JWT failed. Status code: {response.status_code}", response.text)
return response.json()["access_token"]


def main():
total_processed = 0
locally_processed = 0
etag = None
last_force_refresh = time.time()

if watch_mode:
logging.debug("Started in watch mode - waiting 10 seconds before fetching data.")
time.sleep(10)

if args.maxSequences and args.maxSequences < 100:
sequences_to_fetch = args.maxSequences
else:
sequences_to_fetch = 100
sequences_to_fetch = args.maxSequences if args.maxSequences and args.maxSequences < 100 else 100

while True:
unprocessed = fetch_unprocessed_sequences(sequences_to_fetch)
if last_force_refresh + 3600 < time.time():
etag = None
last_force_refresh = time.time()

etag, unprocessed = fetch_unprocessed_sequences(etag, sequences_to_fetch)
if len(unprocessed) == 0:
if watch_mode:
logging.debug(
"Processed {} sequences. Sleeping for 10 seconds.".format(locally_processed)
)
logging.debug(f"Processed {locally_processed} sequences. Sleeping for 10 seconds.")
time.sleep(2)
locally_processed = 0
continue
else:
break
break
etag = None
processed = process(unprocessed)
submit_processed_sequences(processed)
total_processed += len(processed)
locally_processed += len(processed)

if args.maxSequences and total_processed >= args.maxSequences:
break
logging.debug("Total processed sequences: {}".format(total_processed))
logging.debug(f"Total processed sequences: {total_processed}")


if __name__ == "__main__":
3 changes: 2 additions & 1 deletion preprocessing/nextclade/environment.yml
@@ -2,11 +2,12 @@ name: loculus-nextclade
channels:
- conda-forge
- bioconda
- nodefaults
dependencies:
- python=3.12
- biopython=1.83
- dpath=2.1
- nextclade=3.5
- nextclade=3.8
- pip=24.0
- PyYAML=6.0
- pyjwt=2.8
58 changes: 46 additions & 12 deletions preprocessing/nextclade/src/loculus_preprocessing/backend.py
@@ -16,6 +16,8 @@
from .config import Config
from .datatypes import (
ProcessedEntry,
UnprocessedData,
UnprocessedEntry,
)


@@ -66,24 +68,56 @@ def get_jwt(config: Config) -> str:
raise Exception(error_msg)


def fetch_unprocessed_sequences(n: int, config: Config) -> str:
def parse_ndjson(ndjson_data: str) -> Sequence[UnprocessedEntry]:
entries = []
for json_str in ndjson_data.split("\n"):
if len(json_str) == 0:
continue
# Loculus currently cannot handle non-breaking spaces.
json_str_processed = json_str.replace("\N{NO-BREAK SPACE}", " ")
json_object = json.loads(json_str_processed)
unprocessed_data = UnprocessedData(
submitter=json_object["submitter"],
metadata=json_object["data"]["metadata"],
unalignedNucleotideSequences=json_object["data"]["unalignedNucleotideSequences"],
)
entry = UnprocessedEntry(
accessionVersion=f"{json_object['accession']}.{
json_object['version']}",
data=unprocessed_data,
)
entries.append(entry)
return entries


def fetch_unprocessed_sequences(
etag: str | None, config: Config
) -> tuple[str | None, Sequence[UnprocessedEntry] | None]:
n = config.batch_size
url = config.backend_host.rstrip("/") + "/extract-unprocessed-data"
logging.debug(f"Fetching {n} unprocessed sequences from {url}")
params = {"numberOfSequenceEntries": n, "pipelineVersion": config.pipeline_version}
headers = {"Authorization": "Bearer " + get_jwt(config)}
headers = {
"Authorization": "Bearer " + get_jwt(config),
**({"If-None-Match": etag} if etag else {}),
}
logging.debug(f"Requesting data with ETag: {etag}")
response = requests.post(url, data=params, headers=headers, timeout=10)
if not response.ok:
if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY:
match response.status_code:
case HTTPStatus.NOT_MODIFIED:
return etag, None
case HTTPStatus.OK:
return response.headers["ETag"], parse_ndjson(response.text)
case HTTPStatus.UNPROCESSABLE_ENTITY:
logging.debug(f"{response.text}.\nSleeping for a while.")
time.sleep(60 * 1)
return ""
msg = f"Fetching unprocessed data failed. Status code: {
response.status_code}"
raise Exception(
msg,
response.text,
)
return response.text
return None, None
case _:
msg = f"Fetching unprocessed data failed. Status code: {response.status_code}"
raise Exception(
msg,
response.text,
)


def submit_processed_sequences(
@@ -1,7 +1,7 @@
# ruff: noqa: N815
from dataclasses import dataclass, field
from enum import StrEnum, unique
from typing import List, Tuple, Any
from typing import Any

AccessionVersion = str
GeneName = str
Expand Down Expand Up @@ -37,7 +37,7 @@ def __hash__(self):

@dataclass(frozen=True)
class ProcessingAnnotation:
source: Tuple[AnnotationSource, ...]
source: tuple[AnnotationSource, ...]
message: str

def __post_init__(self):
44 changes: 17 additions & 27 deletions preprocessing/nextclade/src/loculus_preprocessing/prepro.py
@@ -52,28 +52,6 @@
# Functions related to reading and writing files


def parse_ndjson(ndjson_data: str) -> Sequence[UnprocessedEntry]:
entries = []
for json_str in ndjson_data.split("\n"):
if len(json_str) == 0:
continue
# Loculus currently cannot handle non-breaking spaces.
json_str_processed = json_str.replace("\N{NO-BREAK SPACE}", " ")
json_object = json.loads(json_str_processed)
unprocessed_data = UnprocessedData(
submitter=json_object["submitter"],
metadata=json_object["data"]["metadata"],
unalignedNucleotideSequences=json_object["data"]["unalignedNucleotideSequences"],
)
entry = UnprocessedEntry(
accessionVersion=f"{json_object['accession']}.{
json_object['version']}",
data=unprocessed_data,
)
entries.append(entry)
return entries


def parse_nextclade_tsv(
amino_acid_insertions: defaultdict[
AccessionVersion, defaultdict[GeneName, list[AminoAcidInsertion]]
@@ -725,17 +703,29 @@ def run(config: Config) -> None:
if config.nextclade_dataset_name:
download_nextclade_dataset(dataset_dir, config)
total_processed = 0
etag = None
last_force_refresh = time.time()
while True:
logging.debug("Fetching unprocessed sequences")
unprocessed = parse_ndjson(fetch_unprocessed_sequences(config.batch_size, config))
if len(unprocessed) == 0:
# Reset etag every hour just in case
if last_force_refresh + 3600 < time.time():
etag = None
last_force_refresh = time.time()
etag, unprocessed = fetch_unprocessed_sequences(etag, config)
if not unprocessed:
# sleep 1 sec and try again
logging.debug("No unprocessed sequences found. Sleeping for 1 second.")
time.sleep(1)
continue
# Process the sequences, get result as dictionary
processed = process_all(unprocessed, dataset_dir, config)
# Submit the result
# Don't reuse the etag if we just got data: preprocessing only asks for 100 sequences to process at a time, so there might be more
etag = None
try:
processed = process_all(unprocessed, dataset_dir, config)
except Exception as e:
logging.exception(
f"Processing failed. Traceback : {e}. Unprocessed data: {unprocessed}"
)
continue
try:
submit_processed_sequences(processed, dataset_dir, config)
except RuntimeError as e:
@@ -423,7 +423,9 @@ def identity(
warnings.append(
ProcessingAnnotation(
source=[
AnnotationSource(name=output_field, type=AnnotationSourceType.METADATA)
AnnotationSource(
name=output_field, type=AnnotationSourceType.METADATA
)
],
message=f"Invalid boolean value: {input_datum}. Defaulting to null.",
)