diff --git a/backend/src/main/kotlin/org/loculus/backend/controller/SubmissionController.kt b/backend/src/main/kotlin/org/loculus/backend/controller/SubmissionController.kt index 979422d86..55ab97b55 100644 --- a/backend/src/main/kotlin/org/loculus/backend/controller/SubmissionController.kt +++ b/backend/src/main/kotlin/org/loculus/backend/controller/SubmissionController.kt @@ -349,6 +349,13 @@ open class SubmissionController( @ApiResponse( responseCode = "200", description = GET_ORIGINAL_METADATA_RESPONSE_DESCRIPTION, + headers = [ + Header( + name = "x-total-records", + description = "The total number of records sent in responseBody", + schema = Schema(type = "integer"), + ), + ], ) @ApiResponse( responseCode = "423", @@ -369,16 +376,29 @@ open class SubmissionController( @HiddenParam authenticatedUser: AuthenticatedUser, @RequestParam compression: CompressionFormat?, ): ResponseEntity { + val stillProcessing = submitModel.checkIfStillProcessingSubmittedData() + if (stillProcessing) { + return ResponseEntity.status(HttpStatus.LOCKED).build() + } + val headers = HttpHeaders() headers.contentType = MediaType.parseMediaType(MediaType.APPLICATION_NDJSON_VALUE) if (compression != null) { headers.add(HttpHeaders.CONTENT_ENCODING, compression.compressionName) } - val stillProcessing = submitModel.checkIfStillProcessingSubmittedData() - if (stillProcessing) { - return ResponseEntity.status(HttpStatus.LOCKED).build() - } + val totalRecords = submissionDatabaseService.countOriginalMetadata( + authenticatedUser, + organism, + groupIdsFilter?.takeIf { it.isNotEmpty() }, + statusesFilter?.takeIf { it.isNotEmpty() }, + ) + headers.add("x-total-records", totalRecords.toString()) + // TODO(https://github.com/loculus-project/loculus/issues/2778) + // There's a possibility that the totalRecords change between the count and the actual query + // this is not too bad, if the client ends up with a few more records than expected + // We just need to make sure the etag used is from before the count + // Alternatively, we could read once to file while counting and then stream the file val streamBody = streamTransactioned(compression) { submissionDatabaseService.streamOriginalMetadata( diff --git a/backend/src/main/kotlin/org/loculus/backend/service/submission/SubmissionDatabaseService.kt b/backend/src/main/kotlin/org/loculus/backend/service/submission/SubmissionDatabaseService.kt index cbbe342bc..7c2041b6d 100644 --- a/backend/src/main/kotlin/org/loculus/backend/service/submission/SubmissionDatabaseService.kt +++ b/backend/src/main/kotlin/org/loculus/backend/service/submission/SubmissionDatabaseService.kt @@ -950,13 +950,12 @@ open class SubmissionDatabaseService( ) } - fun streamOriginalMetadata( + private fun originalMetadataFilter( authenticatedUser: AuthenticatedUser, organism: Organism, groupIdsFilter: List?, statusesFilter: List?, - fields: List?, - ): Sequence { + ): Op { val organismCondition = SequenceEntriesView.organismIs(organism) val groupCondition = getGroupCondition(groupIdsFilter, authenticatedUser) val statusCondition = if (statusesFilter != null) { @@ -966,6 +965,33 @@ open class SubmissionDatabaseService( } val conditions = organismCondition and groupCondition and statusCondition + return conditions + } + + fun countOriginalMetadata( + authenticatedUser: AuthenticatedUser, + organism: Organism, + groupIdsFilter: List?, + statusesFilter: List?, + ): Long = SequenceEntriesView + .selectAll() + .where( + originalMetadataFilter( + authenticatedUser, + organism, + groupIdsFilter, + statusesFilter, + ), + ) + .count() + + fun streamOriginalMetadata( + authenticatedUser: AuthenticatedUser, + organism: Organism, + groupIdsFilter: List?, + statusesFilter: List?, + fields: List?, + ): Sequence { val originalMetadata = SequenceEntriesView.originalDataColumn .extract>("metadata") .alias("original_metadata") @@ -976,7 +1002,14 @@ open class SubmissionDatabaseService( SequenceEntriesView.accessionColumn, SequenceEntriesView.versionColumn, ) - .where(conditions) + .where( + originalMetadataFilter( + authenticatedUser, + organism, + groupIdsFilter, + statusesFilter, + ), + ) .fetchSize(streamBatchSize) .asSequence() .map { diff --git a/backend/src/test/kotlin/org/loculus/backend/controller/submission/GetOriginalMetadataEndpointTest.kt b/backend/src/test/kotlin/org/loculus/backend/controller/submission/GetOriginalMetadataEndpointTest.kt index d6b6f4ad5..e1c71e0d9 100644 --- a/backend/src/test/kotlin/org/loculus/backend/controller/submission/GetOriginalMetadataEndpointTest.kt +++ b/backend/src/test/kotlin/org/loculus/backend/controller/submission/GetOriginalMetadataEndpointTest.kt @@ -52,8 +52,10 @@ class GetOriginalMetadataEndpointTest( @Test fun `GIVEN no sequence entries in database THEN returns empty response`() { val response = submissionControllerClient.getOriginalMetadata() - val responseBody = response.expectNdjsonAndGetContent() + + response.andExpect(status().isOk) + .andExpect(header().string("x-total-records", `is`("0"))) assertThat(responseBody, `is`(emptyList())) } @@ -63,6 +65,9 @@ class GetOriginalMetadataEndpointTest( val response = submissionControllerClient.getOriginalMetadata() val responseBody = response.expectNdjsonAndGetContent() + + response.andExpect(status().isOk) + .andExpect(header().string("x-total-records", `is`(DefaultFiles.NUMBER_OF_SEQUENCES.toString()))) assertThat(responseBody.size, `is`(DefaultFiles.NUMBER_OF_SEQUENCES)) } @@ -150,6 +155,8 @@ class GetOriginalMetadataEndpointTest( groupIdsFilter = listOf(g0), statusesFilter = listOf(Status.APPROVED_FOR_RELEASE), ) + response.andExpect(status().isOk) + .andExpect(header().string("x-total-records", `is`(expectedAccessionVersions.count().toString()))) val responseBody = response.expectNdjsonAndGetContent() assertThat(responseBody, hasSize(expected.size)) diff --git a/ingest/scripts/call_loculus.py b/ingest/scripts/call_loculus.py index e4a68d191..694f41a9f 100644 --- a/ingest/scripts/call_loculus.py +++ b/ingest/scripts/call_loculus.py @@ -312,20 +312,29 @@ def get_submitted(config: Config): "statusesFilter": [], } - logger.info("Getting previously submitted sequences") + while True: + logger.info("Getting previously submitted sequences") - response = make_request(HTTPMethod.GET, url, config, params=params) + response = make_request(HTTPMethod.GET, url, config, params=params) + expected_record_count = int(response.headers["x-total-records"]) - entries: list[dict[str, Any]] = [] - try: - entries = list(jsonlines.Reader(response.iter_lines()).iter()) - except jsonlines.Error as err: - response_summary = response.text - max_error_length = 100 - if len(response_summary) > max_error_length: - response_summary = response_summary[:50] + "\n[..]\n" + response_summary[-50:] - logger.error(f"Error decoding JSON from /get-original-metadata: {response_summary}") - raise ValueError from err + entries: list[dict[str, Any]] = [] + try: + entries = list(jsonlines.Reader(response.iter_lines()).iter()) + except jsonlines.Error as err: + response_summary = response.text + max_error_length = 100 + if len(response_summary) > max_error_length: + response_summary = response_summary[:50] + "\n[..]\n" + response_summary[-50:] + logger.error(f"Error decoding JSON from /get-original-metadata: {response_summary}") + raise ValueError from err + + if len(entries) == expected_record_count: + f"Got {len(entries)} records as expected" + break + logger.error(f"Got incomplete original metadata stream: expected {len(entries)}" + f"records but got {expected_record_count}. Retrying after 60 seconds.") + sleep(60) # Initialize the dictionary to store results submitted_dict: dict[str, dict[str, str | list]] = {}