diff --git a/.happy/terraform/modules/sfn/main.tf b/.happy/terraform/modules/sfn/main.tf index 21120f185e117..ef6c4d2516e35 100644 --- a/.happy/terraform/modules/sfn/main.tf +++ b/.happy/terraform/modules/sfn/main.tf @@ -1,4 +1,3 @@ -# Same file as https://github.com/chanzuckerberg/single-cell-infra/blob/main/.happy/terraform/modules/sfn/main.tf # This is used for environment (dev, staging, prod) deployments locals { timeout = 86400 # 24 hours @@ -154,7 +153,7 @@ resource "aws_sfn_state_machine" "state_machine" { "Validate": { "Type": "Task", "Resource": "arn:aws:states:::batch:submitJob.sync", - "Next": "CxgSeuratParallel", + "Next": "Cxg", "Parameters": { "JobDefinition.$": "$.batch.JobDefinitionName", "JobName": "validate", @@ -194,106 +193,48 @@ resource "aws_sfn_state_machine" "state_machine" { } ] }, - "CxgSeuratParallel": { - "Type": "Parallel", + "Cxg": { + "Type": "Task", "Next": "HandleSuccess", - "Branches": [ - { - "StartAt": "Cxg", - "States": { - "Cxg": { - "Type": "Task", - "End": true, - "Resource": "arn:aws:states:::batch:submitJob.sync", - "Parameters": { - "JobDefinition.$": "$.batch.JobDefinitionName", - "JobName": "cxg", - "JobQueue.$": "$.job_queue", - "ContainerOverrides": { - "Environment": [ - { - "Name": "DATASET_VERSION_ID", - "Value.$": "$.dataset_version_id" - }, - { - "Name": "STEP_NAME", - "Value": "cxg" - } - ] - } - }, - "Retry": [ { - "ErrorEquals": ["AWS.Batch.TooManyRequestsException", "Batch.BatchException", "Batch.AWSBatchException"], - "IntervalSeconds": 2, - "MaxAttempts": 7, - "BackoffRate": 5 - } ], - "Catch": [ - { - "ErrorEquals": [ - "States.ALL" - ], - "Next": "CatchCxgFailure", - "ResultPath": "$.error" - } - ], - "ResultPath": null, - "TimeoutSeconds": 360000 + "Resource": "arn:aws:states:::batch:submitJob.sync", + "Parameters": { + "JobDefinition.$": "$.batch.JobDefinitionName", + "JobName": "cxg", + "JobQueue.$": "$.job_queue", + "ContainerOverrides": { + "Environment": [ + { + "Name": "DATASET_VERSION_ID", + "Value.$": "$.dataset_version_id" }, - "CatchCxgFailure": { - "Type": "Pass", - "End": true + { + "Name": "STEP_NAME", + "Value": "cxg" } - } - }, + ] + } + }, + "Retry": [ { + "ErrorEquals": ["AWS.Batch.TooManyRequestsException", "Batch.BatchException", "Batch.AWSBatchException"], + "IntervalSeconds": 2, + "MaxAttempts": 7, + "BackoffRate": 5 + } ], + "Catch": [ { - "StartAt": "Seurat", - "States": { - "Seurat": { - "Type": "Task", - "End": true, - "Resource": "arn:aws:states:::batch:submitJob.sync", - "Parameters": { - "JobDefinition.$": "$.batch.JobDefinitionName", - "JobName": "seurat", - "JobQueue.$": "$.job_queue", - "ContainerOverrides": { - "Environment": [ - { - "Name": "DATASET_VERSION_ID", - "Value.$": "$.dataset_version_id" - }, - { - "Name": "STEP_NAME", - "Value": "seurat" - } - ] - } - }, - "Retry": [ { - "ErrorEquals": ["AWS.Batch.TooManyRequestsException", "Batch.BatchException", "Batch.AWSBatchException"], - "IntervalSeconds": 2, - "MaxAttempts": 7, - "BackoffRate": 5 - } ], - "Catch": [ - { - "ErrorEquals": [ - "States.ALL" - ], - "Next": "CatchSeuratFailure", - "ResultPath": "$.error" - } - ], - "TimeoutSeconds": ${local.timeout} - }, - "CatchSeuratFailure": { - "Type": "Pass", - "End": true - } - } + "ErrorEquals": [ + "States.ALL" + ], + "Next": "CatchCxgFailure", + "ResultPath": "$.error" } - ] + ], + "ResultPath": null, + "TimeoutSeconds": 360000 + }, + "CatchCxgFailure": { + "Type": "Pass", + "End": true }, "HandleSuccess": { "Type": "Task", @@ -301,8 +242,7 @@ resource "aws_sfn_state_machine" "state_machine" { 
"Resource": "${var.lambda_success_handler}", "Parameters": { "execution_id.$": "$$.Execution.Id", - "cxg_job.$": "$[0]", - "seurat_job.$": "$[1]" + "cxg_job.$": "$" }, "Retry": [ { "ErrorEquals": ["Lambda.AWSLambdaException"], @@ -351,7 +291,7 @@ resource "aws_sfn_state_machine" "state_machine" { "Type": "Task", "Next": "CheckForErrors", "Parameters": { - "JobDefinition.$": "$[0].batch.JobDefinitionName" + "JobDefinition.$": "$.batch.JobDefinitionName" }, "Resource": "arn:aws:states:::aws-sdk:batch:deregisterJobDefinition", "Retry": [ { @@ -403,42 +343,6 @@ resource "aws_sfn_state_machine" "state_machine" { EOF } -resource "aws_sfn_state_machine" "state_machine_seurat" { - name = "dp-${var.deployment_stage}-${var.custom_stack_name}-seurat-sfn" - role_arn = var.role_arn - - definition = < None: - super().__init__() - self.business_logic = business_logic - self.uri_provider = uri_provider - self.s3_provider = s3_provider - - def process(self, dataset_version_id: DatasetVersionId, artifact_bucket: str, datasets_bucket: str): - """ - 1. Download the labeled dataset from the artifact bucket - 2. Convert it to Seurat format - 3. Upload the Seurat file to the artifact bucket - :param dataset_version_id: - :param artifact_bucket: - :param datasets_bucket: - :return: - """ - - # If the validator previously marked the dataset as rds_status.SKIPPED, do not start the Seurat processing - dataset = self.business_logic.get_dataset_version(dataset_version_id) - - if dataset is None: - raise Exception("Dataset not found") # TODO: maybe improve - - if dataset.status.rds_status == DatasetConversionStatus.SKIPPED: - self.logger.info("Skipping Seurat conversion") - return - - # Download h5ad locally - labeled_h5ad_filename = "local.h5ad" - key_prefix = self.get_key_prefix(dataset_version_id.id) - object_key = f"{key_prefix}/{labeled_h5ad_filename}" - self.download_from_s3(artifact_bucket, object_key, labeled_h5ad_filename) - - # Convert the citation from h5ad to RDS - adata = anndata.read_h5ad(labeled_h5ad_filename) - if "citation" in adata.uns: - adata.uns["citation"] = rds_citation_from_h5ad(adata.uns["citation"]) - - # enforce for canonical - logger.info("enforce canonical format in X") - enforce_canonical_format(adata) - if adata.raw: - logger.info("enforce canonical format in raw.X") - enforce_canonical_format(adata.raw) - - adata.write_h5ad(labeled_h5ad_filename) - - # Use Seurat to convert to RDS - seurat_filename = self.convert_file( - self.make_seurat, - labeled_h5ad_filename, - "Failed to convert dataset to Seurat format.", - dataset_version_id, - DatasetStatusKey.RDS, - ) - - self.create_artifact( - seurat_filename, - DatasetArtifactType.RDS, - key_prefix, - dataset_version_id, - artifact_bucket, - DatasetStatusKey.RDS, - datasets_bucket=datasets_bucket, - ) - - @logit - def make_seurat(self, local_filename, *args, **kwargs): - """ - Create a Seurat rds file from the AnnData file. 
-        """
-        try:
-            completed_process = subprocess.run(
-                ["Rscript", "-e", "\"installed.packages()[, c('Package', 'Version')]\""], capture_output=True
-            )
-            logger.debug({"stdout": completed_process.stdout, "args": completed_process.args})
-
-            subprocess.run(
-                [
-                    "Rscript",
-                    os.path.join(os.path.abspath(os.path.dirname(__file__)), "make_seurat.R"),
-                    local_filename,
-                ],
-                capture_output=True,
-                check=True,
-            )
-        except subprocess.CalledProcessError as ex:
-            msg = f"Seurat conversion failed: {ex.output} {ex.stderr}"
-            self.logger.exception(msg)
-            raise RuntimeError(msg) from ex
-
-        return local_filename.replace(".h5ad", ".rds")
diff --git a/backend/layers/processing/process_validate.py b/backend/layers/processing/process_validate.py
index dd3987470e156..e372f34e7e460 100644
--- a/backend/layers/processing/process_validate.py
+++ b/backend/layers/processing/process_validate.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional
 
 import numpy
 import scanpy
@@ -34,7 +34,7 @@ class ProcessValidate(ProcessingLogic):
     3. Save and upload a labeled copy of the original artifact (local.h5ad)
     5. Persist the dataset metadata on the database
     6. Determine if a Seurat conversion is possible (it is not if the X matrix has more than 2**32-1 nonzero values)
-    If this step completes successfully, ProcessCxg and ProcessSeurat will start in parallel.
+    If this step completes successfully, ProcessCxg will start.
     If this step fails, the handle_failures lambda will be invoked.
     """
 
@@ -56,13 +56,13 @@ def __init__(
     @logit
     def validate_h5ad_file_and_add_labels(
         self, collection_version_id: CollectionVersionId, dataset_version_id: DatasetVersionId, local_filename: str
-    ) -> Tuple[str, bool]:
+    ) -> str:
         """
         Validates and labels the specified dataset file and updates the processing status in the database
         :param dataset_version_id: version ID of the dataset to update
         :param collection_version_id: version ID of the collection dataset is being uploaded to
         :param local_filename: file name of the dataset to validate and label
-        :return: file name of labeled dataset, boolean indicating if seurat conversion is possible
+        :return: file name of labeled dataset
         """
 
         # TODO: use a provider here
@@ -72,9 +72,7 @@ def validate_h5ad_file_and_add_labels(
         output_filename = CorporaConstants.LABELED_H5AD_ARTIFACT_FILENAME
 
         try:
-            is_valid, errors, can_convert_to_seurat = self.schema_validator.validate_and_save_labels(
-                local_filename, output_filename
-            )
+            is_valid, errors, _ = self.schema_validator.validate_and_save_labels(local_filename, output_filename)
         except Exception as e:
             self.logger.exception("validation failed")
             raise ValidationFailed([str(e)]) from None
@@ -89,7 +87,7 @@ def validate_h5ad_file_and_add_labels(
             self.update_processing_status(
                 dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.VALID
             )
-        return output_filename, can_convert_to_seurat
+        return output_filename
 
     def populate_dataset_citation(
         self, collection_version_id: CollectionVersionId, dataset_version_id: DatasetVersionId, adata_path: str
@@ -239,16 +237,15 @@ def process(
         self.download_from_s3(artifact_bucket, object_key, original_h5ad_artifact_file_name)
 
         # Validate and label the dataset
-        file_with_labels, can_convert_to_seurat = self.validate_h5ad_file_and_add_labels(
+        file_with_labels = self.validate_h5ad_file_and_add_labels(
            collection_version_id, dataset_version_id, original_h5ad_artifact_file_name
        )
 
        # Process metadata
        metadata = 
self.extract_metadata(file_with_labels) self.business_logic.set_dataset_metadata(dataset_version_id, metadata) - if not can_convert_to_seurat: - self.update_processing_status(dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.SKIPPED) - self.logger.info(f"Skipping Seurat conversion for dataset {dataset_version_id}") + # Skip seurat conversion + self.update_processing_status(dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.SKIPPED) # Upload the labeled dataset to the artifact bucket self.create_artifact( diff --git a/backend/layers/processing/schema_migration.py b/backend/layers/processing/schema_migration.py index 8742b5382efa3..2ac25b0c1f17e 100644 --- a/backend/layers/processing/schema_migration.py +++ b/backend/layers/processing/schema_migration.py @@ -15,11 +15,7 @@ CollectionVersion, CollectionVersionId, DatasetArtifactType, - DatasetConversionStatus, DatasetProcessingStatus, - DatasetStatusKey, - DatasetUploadStatus, - DatasetValidationStatus, DatasetVersionId, ) from backend.layers.processing import logger @@ -232,25 +228,6 @@ def log_errors_and_cleanup(self, collection_version_id: str) -> list: key_prefix = self.get_key_prefix(previous_dataset_version_id) object_keys_to_delete.append(f"{key_prefix}/migrated.h5ad") if dataset.status.processing_status != DatasetProcessingStatus.SUCCESS: - # If only rds failure, set rds status to skipped + processing status to successful and do not rollback - if ( - dataset.status.rds_status == DatasetConversionStatus.FAILED - and dataset.status.upload_status == DatasetUploadStatus.UPLOADED - and dataset.status.validation_status == DatasetValidationStatus.VALID - and dataset.status.cxg_status == DatasetConversionStatus.UPLOADED - and dataset.status.h5ad_status == DatasetConversionStatus.UPLOADED - ): - self.business_logic.update_dataset_version_status( - dataset.version_id, - DatasetStatusKey.RDS, - DatasetConversionStatus.SKIPPED, - ) - self.business_logic.update_dataset_version_status( - dataset.version_id, - DatasetStatusKey.PROCESSING, - DatasetProcessingStatus.SUCCESS, - ) - continue error = { "message": dataset.status.validation_message, "dataset_status": dataset.status.to_dict(), diff --git a/backend/layers/processing/upload_success/app.py b/backend/layers/processing/upload_success/app.py index 6896fd5cfa4a1..7948f6b2fd578 100644 --- a/backend/layers/processing/upload_success/app.py +++ b/backend/layers/processing/upload_success/app.py @@ -19,13 +19,11 @@ def success_handler(events: dict, context) -> None: :param context: Lambda's context object :return: """ - cxg_job, seurat_job = events["cxg_job"], events["seurat_job"] - cxg_job["execution_id"], seurat_job["execution_id"] = events["execution_id"], events["execution_id"] + cxg_job = events["cxg_job"] + cxg_job["execution_id"] = events["execution_id"] if cxg_job.get("error"): handle_failure(cxg_job, context) - elif seurat_job.get("error"): - handle_failure(seurat_job, context, delete_artifacts=False) else: business_logic.update_dataset_version_status( DatasetVersionId(cxg_job["dataset_version_id"]), diff --git a/scripts/cxg_admin.py b/scripts/cxg_admin.py index 073a0a98d993b..37ba3b70222e8 100755 --- a/scripts/cxg_admin.py +++ b/scripts/cxg_admin.py @@ -238,19 +238,7 @@ def refresh_preprint_doi(ctx): updates.refresh_preprint_doi(ctx) -# Commands to reprocess dataset artifacts (seurat or cxg) - - -@cli.command() -@click.argument("dataset_id") -@click.pass_context -def reprocess_seurat(ctx: click.Context, dataset_id: str) -> None: - """ - Reconverts the specified 
dataset to Seurat format in place. - :param ctx: command context - :param dataset_id: ID of dataset to reconvert to Seurat format - """ - reprocess_datafile.reprocess_seurat(ctx, dataset_id) +# Commands to reprocess cxg dataset artifacts @cli.command() diff --git a/scripts/cxg_admin_scripts/reprocess_datafile.py b/scripts/cxg_admin_scripts/reprocess_datafile.py index 8c0aed6039ef7..efbecf7f36269 100644 --- a/scripts/cxg_admin_scripts/reprocess_datafile.py +++ b/scripts/cxg_admin_scripts/reprocess_datafile.py @@ -1,12 +1,8 @@ -import json import logging import os import sys -from time import time import boto3 -import click -from click import Context pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "...")) # noqa sys.path.insert(0, pkg_root) # noqa @@ -33,38 +29,3 @@ def get_happy_stack_name(deployment) -> str: def cxg_remaster(ctx): """Cxg remaster v2""" pass - - -def reprocess_seurat(ctx: Context, dataset_id: str) -> None: - """ - Reconverts the specified dataset to Seurat format in place. - :param ctx: command context - :param dataset_id: ID of dataset to reconvert to Seurat format - """ - - deployment = ctx.obj["deployment"] - - click.confirm( - f"Are you sure you want to run this script? " - f"It will reconvert and replace the dataset {dataset_id} to Seurat in the {deployment} environment.", - abort=True, - ) - - aws_account_id = get_aws_account_id() - deployment = ctx.obj["deployment"] - happy_stack_name = get_happy_stack_name(deployment) - - payload = {"dataset_id": dataset_id} - - client = boto3.client("stepfunctions") - response = client.start_execution( - stateMachineArn=f"arn:aws:states:us-west-2:{aws_account_id}:stateMachine:dp-{happy_stack_name}-seurat-sfn", - name=f"{dataset_id}-{int(time())}", - input=json.dumps(payload), - ) - - click.echo( - f"Step function executing: " - f"https://us-west-2.console.aws.amazon.com/states/home?region=us-west-2#/executions/details/" - f"{response['executionArn']}" - ) diff --git a/tests/memory/processing/test_process_seurat.py b/tests/memory/processing/test_process_seurat.py deleted file mode 100644 index 0535a38d06096..0000000000000 --- a/tests/memory/processing/test_process_seurat.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -This script is used to test the ProcessCxg class. 
-""" - -import shutil -import tempfile - -from backend.layers.common.entities import DatasetVersionId -from backend.layers.processing.process_seurat import ProcessSeurat -from tests.unit.backend.fixtures.environment_setup import fixture_file_path - -if __name__ == "__main__": - file_name = "labeled_visium.h5ad" - dataset_version_id = DatasetVersionId("test_dataset_id") - with tempfile.TemporaryDirectory() as tmpdirname: - temp_file = "/".join([tmpdirname, file_name]) - shutil.copy(fixture_file_path(file_name), temp_file) - process = ProcessSeurat(None, None, None) - process.make_seurat(temp_file, dataset_version_id) diff --git a/tests/unit/backend/layers/utils/test_aws.py b/tests/unit/backend/layers/utils/test_aws.py index 7d5db8d295b59..d2dda6d418c0b 100644 --- a/tests/unit/backend/layers/utils/test_aws.py +++ b/tests/unit/backend/layers/utils/test_aws.py @@ -15,12 +15,10 @@ def setUp(self) -> None: super().setUp() self.tmp_dir = tempfile.mkdtemp() self.h5ad_filename = pathlib.Path(self.tmp_dir, "test.h5ad") - self.seurat_filename = pathlib.Path(self.tmp_dir, "test.rds") self.cxg_filename = pathlib.Path(self.tmp_dir, "test.cxg") self.h5ad_filename.touch() self.cxg_filename.touch() - self.seurat_filename.touch() # Mock S3 service if we don't have a mock api already running if os.getenv("BOTO_ENDPOINT_URL"): diff --git a/tests/unit/processing/test_dataset_metadata_update.py b/tests/unit/processing/test_dataset_metadata_update.py index cf791205c5323..a86b2b5cf0a81 100644 --- a/tests/unit/processing/test_dataset_metadata_update.py +++ b/tests/unit/processing/test_dataset_metadata_update.py @@ -1,7 +1,5 @@ import json -import os import tempfile -from shutil import copy2 from unittest.mock import Mock, patch import pytest @@ -27,12 +25,10 @@ from backend.layers.processing.exceptions import ProcessingFailed from backend.layers.processing.utils.cxg_generation_utils import convert_dictionary_to_cxg_group from backend.layers.thirdparty.s3_provider_mock import MockS3Provider -from tests.unit.backend.fixtures.environment_setup import fixture_file_path from tests.unit.backend.layers.common.base_test import DatasetArtifactUpdate, DatasetStatusUpdate from tests.unit.processing.base_processing_test import BaseProcessingTest base = importr("base") -seurat = importr("SeuratObject") def mock_process(target, args=()): @@ -86,7 +82,6 @@ def test_update_metadata(self, mock_worker_factory, *args): # skip raw_h5ad update since no updated fields are expected fields in raw H5AD mock_worker.update_raw_h5ad.assert_not_called() mock_worker.update_h5ad.assert_called_once() - mock_worker.update_rds.assert_called_once() mock_worker.update_cxg.assert_called_once() # check that collection version maps to dataset version with updated metadata @@ -99,6 +94,9 @@ def test_update_metadata(self, mock_worker_factory, *args): assert new_dataset_version.status.upload_status == DatasetUploadStatus.UPLOADED assert new_dataset_version.status.processing_status == DatasetProcessingStatus.SUCCESS + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED + assert self.updater.s3_provider.uri_exists(f"s3://artifact_bucket/{new_dataset_version_id}/raw.h5ad") @patch("backend.common.utils.dl_sources.uri.downloader") @@ -135,7 +133,6 @@ def test_update_metadata__rds_skipped(self, mock_worker_factory, *args): mock_worker.update_raw_h5ad.assert_not_called() mock_worker.update_h5ad.assert_called_once() - mock_worker.update_rds.assert_not_called() mock_worker.update_cxg.assert_called_once() # 
check that collection version maps to dataset version with updated metadata @@ -148,6 +145,9 @@ def test_update_metadata__rds_skipped(self, mock_worker_factory, *args): assert new_dataset_version.status.upload_status == DatasetUploadStatus.UPLOADED assert new_dataset_version.status.processing_status == DatasetProcessingStatus.SUCCESS + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED + assert self.updater.s3_provider.uri_exists(f"s3://artifact_bucket/{new_dataset_version_id}/raw.h5ad") @patch("backend.common.utils.dl_sources.uri.downloader") @@ -179,13 +179,15 @@ def test_update_metadata__raw_h5ad_updated(self, mock_worker_factory, *args): mock_worker.update_raw_h5ad.assert_called_once() mock_worker.update_h5ad.assert_called_once() - mock_worker.update_rds.assert_called_once() mock_worker.update_cxg.assert_called_once() # check that collection version maps to dataset version with updated metadata collection_version = self.business_logic.get_collection_version(collection_version_id) new_dataset_version = collection_version.datasets[0] + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED + assert new_dataset_version.status.processing_status == DatasetProcessingStatus.SUCCESS def test_update_metadata__current_dataset_version_bad_processing_status(self, *args): @@ -284,40 +286,8 @@ def test_update_metadata__missing_labeled_h5ad(self, *args): assert new_dataset_version.status.h5ad_status == DatasetConversionStatus.FAILED assert new_dataset_version.status.processing_status == DatasetProcessingStatus.FAILURE - @patch("backend.common.utils.dl_sources.uri.downloader") - @patch("scanpy.read_h5ad") - @patch("backend.layers.processing.dataset_metadata_update.S3Provider", Mock(side_effect=MockS3Provider)) - @patch("backend.layers.processing.dataset_metadata_update.DatabaseProvider", Mock(side_effect=DatabaseProviderMock)) - @patch("backend.layers.processing.dataset_metadata_update.DatasetMetadataUpdater") - def test_update_metadata__missing_rds(self, *args): - current_dataset_version = self.generate_dataset( - artifacts=[ - DatasetArtifactUpdate(DatasetArtifactType.RAW_H5AD, "s3://fake.bucket/raw.h5ad"), - DatasetArtifactUpdate(DatasetArtifactType.H5AD, "s3://fake.bucket/local.h5ad"), - DatasetArtifactUpdate(DatasetArtifactType.CXG, "s3://fake.bucket/local.cxg"), - ], - statuses=[ - DatasetStatusUpdate(status_key=DatasetStatusKey.PROCESSING, status=DatasetProcessingStatus.SUCCESS), - DatasetStatusUpdate(status_key=DatasetStatusKey.RDS, status=DatasetConversionStatus.CONVERTED), - ], - ) - collection_version_id = CollectionVersionId(current_dataset_version.collection_version_id) - current_dataset_version_id = DatasetVersionId(current_dataset_version.dataset_version_id) - new_dataset_version_id, _ = self.business_logic.ingest_dataset( - collection_version_id=collection_version_id, - url=None, - file_size=0, - current_dataset_version_id=current_dataset_version_id, - start_step_function=False, - ) - - with pytest.raises(ProcessingFailed): - self.updater.update_metadata(current_dataset_version_id, new_dataset_version_id, None) - - new_dataset_version = self.business_logic.get_dataset_version(new_dataset_version_id) - - assert new_dataset_version.status.rds_status == DatasetConversionStatus.FAILED - assert new_dataset_version.status.processing_status == DatasetProcessingStatus.FAILURE + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED 
@patch("backend.common.utils.dl_sources.uri.downloader") @patch("scanpy.read_h5ad") @@ -354,6 +324,9 @@ def test_update_metadata__missing_cxg(self, *args): assert new_dataset_version.status.cxg_status == DatasetConversionStatus.FAILED assert new_dataset_version.status.processing_status == DatasetProcessingStatus.FAILURE + # RDS should be skipped + assert new_dataset_version.status.rds_status == DatasetConversionStatus.SKIPPED + @patch("backend.common.utils.dl_sources.uri.downloader") @patch("scanpy.read_h5ad") @patch("backend.layers.processing.dataset_metadata_update.DatasetMetadataUpdater") @@ -599,44 +572,6 @@ def test_update_cxg__with_spatial_deepzoom_assets(self): f"s3://{self.updater.spatial_deep_zoom_dir}/{new_dataset_version_id.id}" ) - @patch("backend.layers.processing.dataset_metadata_update.os.remove") - def test_update_rds(self, *args): - with tempfile.TemporaryDirectory() as tempdir: - temp_path = os.path.join(tempdir, "test.rds") - copy2(fixture_file_path("test.rds"), temp_path) - self.updater.download_from_source_uri = Mock(return_value=temp_path) - - collection_version = self.generate_unpublished_collection(add_datasets=1) - current_dataset_version = collection_version.datasets[0] - new_dataset_version_id, _ = self.business_logic.ingest_dataset( - collection_version_id=collection_version.version_id, - url=None, - file_size=0, - current_dataset_version_id=current_dataset_version.version_id, - start_step_function=False, - ) - key_prefix = new_dataset_version_id.id - metadata_update_dict = DatasetArtifactMetadataUpdate(title="New Dataset Title") - - self.updater.update_rds(None, key_prefix, new_dataset_version_id, metadata_update_dict) - - # check Seurat object metadata is updated - seurat_object = base.readRDS(temp_path) - assert seurat.Misc(object=seurat_object, slot="title")[0] == "New Dataset Title" - # schema_version should stay the same as base fixture after update of other metadata - assert seurat.Misc(object=seurat_object, slot="schema_version")[0] == "3.1.0" - - # check new artifacts are uploaded in expected uris - assert self.updater.s3_provider.uri_exists(f"s3://artifact_bucket/{new_dataset_version_id.id}/test.rds") - assert self.updater.s3_provider.uri_exists(f"s3://datasets_bucket/{new_dataset_version_id.id}.rds") - - # check artifacts + status updated in DB - new_dataset_version = self.business_logic.get_dataset_version(new_dataset_version_id) - artifacts = [(artifact.uri, artifact.type) for artifact in new_dataset_version.artifacts] - assert (f"s3://artifact_bucket/{new_dataset_version_id.id}/test.rds", DatasetArtifactType.RDS) in artifacts - - assert new_dataset_version.status.rds_status == DatasetConversionStatus.CONVERTED - class TestValidArtifactStatuses(BaseProcessingTest): def setUp(self): diff --git a/tests/unit/processing/test_processing.py b/tests/unit/processing/test_processing.py index d733a8680713f..b49b79c0cbc4d 100644 --- a/tests/unit/processing/test_processing.py +++ b/tests/unit/processing/test_processing.py @@ -8,38 +8,10 @@ ) from backend.layers.processing.process import ProcessMain from backend.layers.processing.process_cxg import ProcessCxg -from backend.layers.processing.process_seurat import ProcessSeurat from tests.unit.processing.base_processing_test import BaseProcessingTest class ProcessingTest(BaseProcessingTest): - @patch("anndata.read_h5ad") - @patch("backend.layers.processing.process_seurat.ProcessSeurat.make_seurat") - def test_process_seurat_success(self, mock_seurat, mock_anndata_read_h5ad): - collection = 
self.generate_unpublished_collection() - dataset_version_id, dataset_id = self.business_logic.ingest_dataset( - collection.version_id, "nothing", None, None - ) - - mock_anndata = MagicMock(uns=dict(), n_obs=1000, n_vars=1000) - mock_anndata_read_h5ad.return_value = mock_anndata - - mock_seurat.return_value = "local.rds" - ps = ProcessSeurat(self.business_logic, self.uri_provider, self.s3_provider) - ps.process(dataset_version_id, "fake_bucket_name", "fake_datasets_bucket") - - status = self.business_logic.get_dataset_status(dataset_version_id) - self.assertEqual(status.rds_status, DatasetConversionStatus.UPLOADED) - - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.rds")) - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.rds")) - - artifacts = list(self.business_logic.get_dataset_artifacts(dataset_version_id)) - self.assertEqual(1, len(artifacts)) - artifact = artifacts[0] - artifact.type = "RDS" - artifact.uri = f"s3://fake_bucket_name/{dataset_version_id.id}/local.rds" - def test_process_cxg_success(self): collection = self.generate_unpublished_collection() dataset_version_id, dataset_id = self.business_logic.ingest_dataset( @@ -96,12 +68,8 @@ def test_reprocess_cxg_success(self): @patch("scanpy.read_h5ad") @patch("anndata.read_h5ad") @patch("backend.layers.processing.process_validate.ProcessValidate.extract_metadata") - @patch("backend.layers.processing.process_seurat.ProcessSeurat.make_seurat") @patch("backend.layers.processing.process_cxg.ProcessCxg.make_cxg") - def test_process_all( - self, mock_cxg, mock_seurat, mock_h5ad, mock_anndata_read_h5ad, mock_scanpy_read_h5ad, mock_sfn_provider - ): - mock_seurat.return_value = "local.rds" + def test_process_all(self, mock_cxg, mock_h5ad, mock_anndata_read_h5ad, mock_scanpy_read_h5ad, mock_sfn_provider): mock_cxg.return_value = "local.cxg" # Mock anndata object @@ -116,7 +84,7 @@ def test_process_all( ) pm = ProcessMain(self.business_logic, self.uri_provider, self.s3_provider, self.schema_validator) - for step_name in ["download", "validate", "cxg", "seurat"]: + for step_name in ["download", "validate", "cxg"]: assert pm.process( collection.version_id, dataset_version_id, @@ -130,13 +98,13 @@ def test_process_all( self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/raw.h5ad")) self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.h5ad")) self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.h5ad")) - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.rds")) - self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.rds")) + self.assertFalse(self.s3_provider.uri_exists(f"s3://fake_bucket_name/{dataset_version_id.id}/local.rds")) + self.assertFalse(self.s3_provider.uri_exists(f"s3://fake_datasets_bucket/{dataset_version_id.id}.rds")) self.assertTrue(self.s3_provider.uri_exists(f"s3://fake_cxg_bucket/{dataset_version_id.id}.cxg/")) status = self.business_logic.get_dataset_status(dataset_version_id) self.assertEqual(status.cxg_status, DatasetConversionStatus.UPLOADED) - self.assertEqual(status.rds_status, DatasetConversionStatus.UPLOADED) + self.assertEqual(status.rds_status, DatasetConversionStatus.SKIPPED) self.assertEqual(status.h5ad_status, DatasetConversionStatus.UPLOADED) self.assertEqual(status.validation_status, 
DatasetValidationStatus.VALID) self.assertEqual(status.upload_status, DatasetUploadStatus.UPLOADED) @@ -144,4 +112,4 @@ def test_process_all( # TODO: DatasetProcessingStatus.SUCCESS is set by a lambda that also needs to be modified. It should belong here artifacts = list(self.business_logic.get_dataset_artifacts(dataset_version_id)) - self.assertEqual(4, len(artifacts)) + self.assertEqual(3, len(artifacts))
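
# --- Illustrative sketch (not part of the diff above) -----------------------
# A minimal Python sketch of the event the HandleSuccess lambda receives once
# the parallel Seurat branch is removed: the state machine passes the single
# Batch result through ("cxg_job.$": "$"), and success_handler reads only
# events["cxg_job"]. Key names mirror this diff; the values below are
# hypothetical and used only for illustration.

def route_success_event(events: dict) -> str:
    """Return the path the simplified success handler would take."""
    cxg_job = events["cxg_job"]                        # single job payload ("cxg_job.$": "$")
    cxg_job["execution_id"] = events["execution_id"]   # previously set on both cxg_job and seurat_job
    if cxg_job.get("error"):
        return "handle_failure"                        # the seurat_job error branch no longer exists
    return "update_dataset_version_status"


if __name__ == "__main__":
    example_event = {  # hypothetical payload for illustration
        "execution_id": "arn:aws:states:us-west-2:000000000000:execution:example",
        "cxg_job": {"dataset_version_id": "example-dataset-version-id"},
    }
    assert route_success_event(example_event) == "update_dataset_version_status"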