diff --git a/.vscode/settings.json b/.vscode/settings.json index bb92a125e..4efc47061 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -21,4 +21,5 @@ "python.testing.cwd": "${workspaceFolder}/code", "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, + "pylint.path" : [ "${interpreter}", "-m", "pylint" ] } diff --git a/code/backend/batch/batch_push_results.py b/code/backend/batch/batch_push_results.py index 18859f573..4058b96ef 100644 --- a/code/backend/batch/batch_push_results.py +++ b/code/backend/batch/batch_push_results.py @@ -1,20 +1,20 @@ import os import logging import json -import azure.functions as func from urllib.parse import urlparse +import azure.functions as func from utilities.helpers.azure_blob_storage_client import AzureBlobStorageClient from utilities.helpers.env_helper import EnvHelper from utilities.helpers.embedders.embedder_factory import EmbedderFactory +from utilities.search.search import Search bp_batch_push_results = func.Blueprint() logger = logging.getLogger(__name__) logger.setLevel(level=os.environ.get("LOGLEVEL", "INFO").upper()) -def _get_file_name_from_message(msg: func.QueueMessage) -> str: - message_body = json.loads(msg.get_body().decode("utf-8")) +def _get_file_name_from_message(message_body) -> str: return message_body.get( "filename", "/".join( @@ -27,21 +27,37 @@ def _get_file_name_from_message(msg: func.QueueMessage) -> str: arg_name="msg", queue_name="doc-processing", connection="AzureWebJobsStorage" ) def batch_push_results(msg: func.QueueMessage) -> None: - do_batch_push_results(msg) + message_body = json.loads(msg.get_body().decode("utf-8")) + logger.debug("Process Document Event queue function triggered: %s", message_body) + + event_type = message_body.get("eventType", "") + # We handle "" in this scenario for backwards compatibility + # This function is primarily triggered by an Event Grid queue message from the blob storage + # However, it can also be triggered using a legacy schema from BatchStartProcessing + if event_type in ("", "Microsoft.Storage.BlobCreated"): + _process_document_created_event(message_body) + + elif event_type == "Microsoft.Storage.BlobDeleted": + _process_document_deleted_event(message_body) + + else: + raise NotImplementedError(f"Unknown event type received: {event_type}") -def do_batch_push_results(msg: func.QueueMessage) -> None: +def _process_document_created_event(message_body) -> None: env_helper: EnvHelper = EnvHelper() - logger.info( - "Python queue trigger function processed a queue item: %s", - msg.get_body().decode("utf-8"), - ) blob_client = AzureBlobStorageClient() - # Get the file name from the message - file_name = _get_file_name_from_message(msg) - # Generate the SAS URL for the file + file_name = _get_file_name_from_message(message_body) file_sas = blob_client.get_blob_sas(file_name) - # Process the file + embedder = EmbedderFactory.create(env_helper) embedder.embed_file(file_sas, file_name) + + +def _process_document_deleted_event(message_body) -> None: + env_helper: EnvHelper = EnvHelper() + search_handler = Search.get_search_handler(env_helper) + + blob_url = message_body.get("data", {}).get("url", "") + search_handler.delete_by_source(f"{blob_url}_SAS_TOKEN_PLACEHOLDER_") diff --git a/code/backend/batch/function_app.py b/code/backend/batch/function_app.py index 5eca04a1e..b2756c751 100644 --- a/code/backend/batch/function_app.py +++ b/code/backend/batch/function_app.py @@ -10,7 +10,8 @@ logging.captureWarnings(True) # Raising the azure log level to WARN as it is too verbose - https://github.com/Azure/azure-sdk-for-python/issues/9422 logging.getLogger("azure").setLevel(os.environ.get("LOGLEVEL_AZURE", "WARN").upper()) -configure_azure_monitor() +if os.getenv("APPLICATIONINSIGHTS_ENABLED", "false").lower() == "true": + configure_azure_monitor() app = func.FunctionApp( http_auth_level=func.AuthLevel.FUNCTION diff --git a/code/backend/batch/utilities/search/search_handler_base.py b/code/backend/batch/utilities/search/search_handler_base.py index 4a937d7b9..5e3443e5c 100644 --- a/code/backend/batch/utilities/search/search_handler_base.py +++ b/code/backend/batch/utilities/search/search_handler_base.py @@ -36,13 +36,35 @@ def get_files(self): pass @abstractmethod - def output_results(self, results, id_field): + def output_results(self, results): pass @abstractmethod - def delete_files(self, files, id_field): + def delete_files(self, files): pass @abstractmethod def query_search(self, question) -> list[SourceDocument]: pass + + def delete_by_source(self, source) -> None: + if source is None: + return + + documents = self._get_documents_by_source(source) + if documents is None: + return + + files_to_delete = self.output_results(documents) + self.delete_files(files_to_delete) + + def _get_documents_by_source(self, source): + if source is None: + return None + + return self.search_client.search( + "*", + select="id, title", + include_total_count=True, + filter=f"source eq '{source}'", + ) diff --git a/code/tests/test_batch_push_results.py b/code/tests/test_batch_push_results.py index b7c39c267..5350d901d 100644 --- a/code/tests/test_batch_push_results.py +++ b/code/tests/test_batch_push_results.py @@ -1,3 +1,4 @@ +import json import sys import os import pytest @@ -15,17 +16,22 @@ @pytest.fixture(autouse=True) def get_processor_handler_mock(): - with patch("backend.batch.batch_push_results.EmbedderFactory.create") as mock: - processor_handler = mock.return_value - yield processor_handler + with patch( + "backend.batch.batch_push_results.EmbedderFactory.create" + ) as mock_create_embedder, patch( + "backend.batch.batch_push_results.Search.get_search_handler" + ) as mock_get_search_handler: + processor_handler_create = mock_create_embedder.return_value + processor_handler_get_search_handler = mock_get_search_handler.return_value + yield processor_handler_create, processor_handler_get_search_handler def test_get_file_name_from_message(): mock_queue_message = QueueMessage( body='{"message": "test message", "filename": "test_filename.md"}' ) - - file_name = _get_file_name_from_message(mock_queue_message) + message_body = json.loads(mock_queue_message.get_body().decode("utf-8")) + file_name = _get_file_name_from_message(message_body) assert file_name == "test_filename.md" @@ -34,25 +40,95 @@ def test_get_file_name_from_message_no_filename(): mock_queue_message = QueueMessage( body='{"data": { "url": "test/test/test_filename.md"} }' ) - - file_name = _get_file_name_from_message(mock_queue_message) + message_body = json.loads(mock_queue_message.get_body().decode("utf-8")) + file_name = _get_file_name_from_message(message_body) assert file_name == "test_filename.md" +def test_batch_push_results_with_unhandled_event_type(): + mock_queue_message = QueueMessage( + body='{"eventType": "Microsoft.Storage.BlobUpdated"}' + ) + + with pytest.raises(NotImplementedError): + batch_push_results.build().get_user_function()(mock_queue_message) + + +@patch("backend.batch.batch_push_results._process_document_created_event") +def test_batch_push_results_with_blob_created_event( + mock_process_document_created_event, +): + mock_queue_message = QueueMessage( + body='{"eventType": "Microsoft.Storage.BlobCreated", "filename": "test/test/test_filename.md"}' + ) + + batch_push_results.build().get_user_function()(mock_queue_message) + + expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8")) + mock_process_document_created_event.assert_called_once_with(expected_message_body) + + +@patch("backend.batch.batch_push_results._process_document_created_event") +def test_batch_push_results_with_no_event(mock_process_document_created_event): + mock_queue_message = QueueMessage( + body='{"data": { "url": "test/test/test_filename.md"} }' + ) + + batch_push_results.build().get_user_function()(mock_queue_message) + + expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8")) + mock_process_document_created_event.assert_called_once_with(expected_message_body) + + +@patch("backend.batch.batch_push_results._process_document_deleted_event") +def test_batch_push_results_with_blob_deleted_event( + mock_process_document_deleted_event, +): + mock_queue_message = QueueMessage( + body='{"eventType": "Microsoft.Storage.BlobDeleted", "filename": "test/test/test_filename.md"}' + ) + + batch_push_results.build().get_user_function()(mock_queue_message) + + expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8")) + mock_process_document_deleted_event.assert_called_once_with(expected_message_body) + + @patch("backend.batch.batch_push_results.EnvHelper") @patch("backend.batch.batch_push_results.AzureBlobStorageClient") -def test_batch_push_results( - mock_azure_blob_storage_client, mock_env_helper, get_processor_handler_mock +def test_batch_push_results_with_blob_created_event_uses_embedder( + mock_azure_blob_storage_client, + mock_env_helper, + get_processor_handler_mock, ): + mock_create_embedder, mock_get_search_handler = get_processor_handler_mock + mock_queue_message = QueueMessage( - body='{"message": "test message", "filename": "test/test/test_filename.md"}' + body='{"eventType": "Microsoft.Storage.BlobCreated", "filename": "test/test/test_filename.md"}' ) mock_blob_client_instance = mock_azure_blob_storage_client.return_value mock_blob_client_instance.get_blob_sas.return_value = "test_blob_sas" batch_push_results.build().get_user_function()(mock_queue_message) - get_processor_handler_mock.embed_file.assert_called_once_with( + mock_create_embedder.embed_file.assert_called_once_with( "test_blob_sas", "test/test/test_filename.md" ) + + +@patch("backend.batch.batch_push_results.EnvHelper") +def test_batch_push_results_with_blob_deleted_event_uses_search_to_delete_with_sas_appended( + mock_env_helper, + get_processor_handler_mock, +): + mock_create_embedder, mock_get_search_handler = get_processor_handler_mock + + mock_queue_message = QueueMessage( + body='{"eventType": "Microsoft.Storage.BlobDeleted", "data": { "url": "https://test.test/test/test_filename.pdf"}}' + ) + + batch_push_results.build().get_user_function()(mock_queue_message) + mock_get_search_handler.delete_by_source.assert_called_once_with( + "https://test.test/test/test_filename.pdf_SAS_TOKEN_PLACEHOLDER_" + )