Commit

Merge branch 'main' into arpit/feature/implement-chat-endpoint-switch
gaurarpit committed May 16, 2024
2 parents 5b4221e + f27b68e commit 2eecd7b
Showing 5 changed files with 143 additions and 27 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -21,4 +21,5 @@
     "python.testing.cwd": "${workspaceFolder}/code",
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true,
+    "pylint.path" : [ "${interpreter}", "-m", "pylint" ]
 }
42 changes: 29 additions & 13 deletions code/backend/batch/batch_push_results.py
@@ -1,20 +1,20 @@
 import os
 import logging
 import json
-import azure.functions as func
 from urllib.parse import urlparse
+import azure.functions as func
 
 from utilities.helpers.azure_blob_storage_client import AzureBlobStorageClient
 from utilities.helpers.env_helper import EnvHelper
 from utilities.helpers.embedders.embedder_factory import EmbedderFactory
+from utilities.search.search import Search
 
 bp_batch_push_results = func.Blueprint()
 logger = logging.getLogger(__name__)
 logger.setLevel(level=os.environ.get("LOGLEVEL", "INFO").upper())
 
 
-def _get_file_name_from_message(msg: func.QueueMessage) -> str:
-    message_body = json.loads(msg.get_body().decode("utf-8"))
+def _get_file_name_from_message(message_body) -> str:
     return message_body.get(
         "filename",
         "/".join(
@@ -27,21 +27,37 @@ def _get_file_name_from_message(msg: func.QueueMessage) -> str:
     arg_name="msg", queue_name="doc-processing", connection="AzureWebJobsStorage"
 )
 def batch_push_results(msg: func.QueueMessage) -> None:
-    do_batch_push_results(msg)
+    message_body = json.loads(msg.get_body().decode("utf-8"))
+    logger.debug("Process Document Event queue function triggered: %s", message_body)
+
+    event_type = message_body.get("eventType", "")
+    # We handle "" in this scenario for backwards compatibility
+    # This function is primarily triggered by an Event Grid queue message from the blob storage
+    # However, it can also be triggered using a legacy schema from BatchStartProcessing
+    if event_type in ("", "Microsoft.Storage.BlobCreated"):
+        _process_document_created_event(message_body)
+
+    elif event_type == "Microsoft.Storage.BlobDeleted":
+        _process_document_deleted_event(message_body)
+
+    else:
+        raise NotImplementedError(f"Unknown event type received: {event_type}")
 
 
-def do_batch_push_results(msg: func.QueueMessage) -> None:
+def _process_document_created_event(message_body) -> None:
     env_helper: EnvHelper = EnvHelper()
-    logger.info(
-        "Python queue trigger function processed a queue item: %s",
-        msg.get_body().decode("utf-8"),
-    )
 
     blob_client = AzureBlobStorageClient()
-    # Get the file name from the message
-    file_name = _get_file_name_from_message(msg)
-    # Generate the SAS URL for the file
+    file_name = _get_file_name_from_message(message_body)
     file_sas = blob_client.get_blob_sas(file_name)
-    # Process the file
+
     embedder = EmbedderFactory.create(env_helper)
     embedder.embed_file(file_sas, file_name)
+
+
+def _process_document_deleted_event(message_body) -> None:
+    env_helper: EnvHelper = EnvHelper()
+    search_handler = Search.get_search_handler(env_helper)
+
+    blob_url = message_body.get("data", {}).get("url", "")
+    search_handler.delete_by_source(f"{blob_url}_SAS_TOKEN_PLACEHOLDER_")
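
Note (not part of the diff): the queue trigger above now routes on eventType. A minimal sketch of the three message shapes it distinguishes, mirroring the routing logic in batch_push_results; the file names and account URL below are made up for illustration:

    import json

    # Example queue message bodies; "docs/guide.md" and the account URL are hypothetical.
    blob_created = '{"eventType": "Microsoft.Storage.BlobCreated", "filename": "docs/guide.md"}'
    legacy_batch = '{"filename": "docs/guide.md"}'  # no eventType -> handled as created, for backwards compatibility
    blob_deleted = '{"eventType": "Microsoft.Storage.BlobDeleted", "data": {"url": "https://account.blob.core.windows.net/docs/guide.md"}}'

    for raw in (blob_created, legacy_batch, blob_deleted):
        body = json.loads(raw)
        event_type = body.get("eventType", "")
        if event_type in ("", "Microsoft.Storage.BlobCreated"):
            print("-> _process_document_created_event:", body.get("filename"))
        elif event_type == "Microsoft.Storage.BlobDeleted":
            print("-> _process_document_deleted_event:", body["data"]["url"])
        else:
            raise NotImplementedError(f"Unknown event type received: {event_type}")
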
3 changes: 2 additions & 1 deletion code/backend/batch/function_app.py
@@ -10,7 +10,8 @@
 logging.captureWarnings(True)
 # Raising the azure log level to WARN as it is too verbose - https://github.com/Azure/azure-sdk-for-python/issues/9422
 logging.getLogger("azure").setLevel(os.environ.get("LOGLEVEL_AZURE", "WARN").upper())
-configure_azure_monitor()
+if os.getenv("APPLICATIONINSIGHTS_ENABLED", "false").lower() == "true":
+    configure_azure_monitor()
 
 app = func.FunctionApp(
     http_auth_level=func.AuthLevel.FUNCTION
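
Note (illustrative, not part of this commit): with the new guard, telemetry setup is opt-in. A quick way to see the behaviour, assuming the standard pattern of reading the flag from the process environment:

    import os

    # Hypothetical check of the guard added in function_app.py:
    # configure_azure_monitor() runs only when the setting is "true" (case-insensitive).
    for flag in ("true", "True", "false", None):
        if flag is not None:
            os.environ["APPLICATIONINSIGHTS_ENABLED"] = flag
        else:
            os.environ.pop("APPLICATIONINSIGHTS_ENABLED", None)
        enabled = os.getenv("APPLICATIONINSIGHTS_ENABLED", "false").lower() == "true"
        print(flag, "->", "configure_azure_monitor() called" if enabled else "telemetry disabled")
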
26 changes: 24 additions & 2 deletions code/backend/batch/utilities/search/search_handler_base.py
@@ -36,13 +36,35 @@ def get_files(self):
         pass
 
     @abstractmethod
-    def output_results(self, results, id_field):
+    def output_results(self, results):
         pass
 
     @abstractmethod
-    def delete_files(self, files, id_field):
+    def delete_files(self, files):
         pass
 
     @abstractmethod
     def query_search(self, question) -> list[SourceDocument]:
         pass
+
+    def delete_by_source(self, source) -> None:
+        if source is None:
+            return
+
+        documents = self._get_documents_by_source(source)
+        if documents is None:
+            return
+
+        files_to_delete = self.output_results(documents)
+        self.delete_files(files_to_delete)
+
+    def _get_documents_by_source(self, source):
+        if source is None:
+            return None
+
+        return self.search_client.search(
+            "*",
+            select="id, title",
+            include_total_count=True,
+            filter=f"source eq '{source}'",
+        )
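
Note (illustrative, not part of the diff): delete_by_source is a template method — it composes the abstract output_results/delete_files hooks with the concrete _get_documents_by_source query. A minimal, self-contained sketch of how a handler shaped like a SearchHandlerBase subclass would exercise it; DummySearchClient and InMemoryHandler are hypothetical stand-ins, not classes from this repo:

    class DummySearchClient:
        """Fake search client understanding only filters of the form: source eq '<value>'."""

        def __init__(self, docs):
            self.docs = docs

        def search(self, query, select=None, include_total_count=False, filter=None):
            source = filter.split("'")[1] if filter else None
            return [doc for doc in self.docs if doc["source"] == source]


    class InMemoryHandler:
        """Stand-in for a concrete SearchHandlerBase subclass."""

        def __init__(self, search_client):
            self.search_client = search_client

        # The two hooks a concrete handler must implement:
        def output_results(self, results):
            return [{"id": doc["id"], "title": doc["title"]} for doc in results]

        def delete_files(self, files):
            print("deleting:", [f["id"] for f in files])

        # Same shape as the new base-class methods above:
        def delete_by_source(self, source):
            if source is None:
                return
            documents = self._get_documents_by_source(source)
            if documents is None:
                return
            self.delete_files(self.output_results(documents))

        def _get_documents_by_source(self, source):
            if source is None:
                return None
            return self.search_client.search(
                "*",
                select="id, title",
                include_total_count=True,
                filter=f"source eq '{source}'",
            )


    source_url = "https://account.blob.core.windows.net/docs/guide.md_SAS_TOKEN_PLACEHOLDER_"
    handler = InMemoryHandler(DummySearchClient([{"id": "1", "title": "guide.md", "source": source_url}]))
    handler.delete_by_source(source_url)  # -> deleting: ['1']
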
98 changes: 87 additions & 11 deletions code/tests/test_batch_push_results.py
@@ -1,3 +1,4 @@
+import json
 import sys
 import os
 import pytest
@@ -15,17 +16,22 @@
 
 @pytest.fixture(autouse=True)
 def get_processor_handler_mock():
-    with patch("backend.batch.batch_push_results.EmbedderFactory.create") as mock:
-        processor_handler = mock.return_value
-        yield processor_handler
+    with patch(
+        "backend.batch.batch_push_results.EmbedderFactory.create"
+    ) as mock_create_embedder, patch(
+        "backend.batch.batch_push_results.Search.get_search_handler"
+    ) as mock_get_search_handler:
+        processor_handler_create = mock_create_embedder.return_value
+        processor_handler_get_search_handler = mock_get_search_handler.return_value
+        yield processor_handler_create, processor_handler_get_search_handler
 
 
 def test_get_file_name_from_message():
     mock_queue_message = QueueMessage(
         body='{"message": "test message", "filename": "test_filename.md"}'
     )
 
-    file_name = _get_file_name_from_message(mock_queue_message)
+    message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+    file_name = _get_file_name_from_message(message_body)
 
     assert file_name == "test_filename.md"

@@ -34,25 +40,95 @@ def test_get_file_name_from_message_no_filename():
     mock_queue_message = QueueMessage(
         body='{"data": { "url": "test/test/test_filename.md"} }'
     )
 
-    file_name = _get_file_name_from_message(mock_queue_message)
+    message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+    file_name = _get_file_name_from_message(message_body)
 
     assert file_name == "test_filename.md"
 
 
+def test_batch_push_results_with_unhandled_event_type():
+    mock_queue_message = QueueMessage(
+        body='{"eventType": "Microsoft.Storage.BlobUpdated"}'
+    )
+
+    with pytest.raises(NotImplementedError):
+        batch_push_results.build().get_user_function()(mock_queue_message)
+
+
+@patch("backend.batch.batch_push_results._process_document_created_event")
+def test_batch_push_results_with_blob_created_event(
+    mock_process_document_created_event,
+):
+    mock_queue_message = QueueMessage(
+        body='{"eventType": "Microsoft.Storage.BlobCreated", "filename": "test/test/test_filename.md"}'
+    )
+
+    batch_push_results.build().get_user_function()(mock_queue_message)
+
+    expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+    mock_process_document_created_event.assert_called_once_with(expected_message_body)
+
+
+@patch("backend.batch.batch_push_results._process_document_created_event")
+def test_batch_push_results_with_no_event(mock_process_document_created_event):
+    mock_queue_message = QueueMessage(
+        body='{"data": { "url": "test/test/test_filename.md"} }'
+    )
+
+    batch_push_results.build().get_user_function()(mock_queue_message)
+
+    expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+    mock_process_document_created_event.assert_called_once_with(expected_message_body)
+
+
+@patch("backend.batch.batch_push_results._process_document_deleted_event")
+def test_batch_push_results_with_blob_deleted_event(
+    mock_process_document_deleted_event,
+):
+    mock_queue_message = QueueMessage(
+        body='{"eventType": "Microsoft.Storage.BlobDeleted", "filename": "test/test/test_filename.md"}'
+    )
+
+    batch_push_results.build().get_user_function()(mock_queue_message)
+
+    expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+    mock_process_document_deleted_event.assert_called_once_with(expected_message_body)
 
 
@patch("backend.batch.batch_push_results.EnvHelper")
@patch("backend.batch.batch_push_results.AzureBlobStorageClient")
def test_batch_push_results(
mock_azure_blob_storage_client, mock_env_helper, get_processor_handler_mock
def test_batch_push_results_with_blob_created_event_uses_embedder(
mock_azure_blob_storage_client,
mock_env_helper,
get_processor_handler_mock,
):
mock_create_embedder, mock_get_search_handler = get_processor_handler_mock

mock_queue_message = QueueMessage(
body='{"message": "test message", "filename": "test/test/test_filename.md"}'
body='{"eventType": "Microsoft.Storage.BlobCreated", "filename": "test/test/test_filename.md"}'
)

mock_blob_client_instance = mock_azure_blob_storage_client.return_value
mock_blob_client_instance.get_blob_sas.return_value = "test_blob_sas"

batch_push_results.build().get_user_function()(mock_queue_message)
get_processor_handler_mock.embed_file.assert_called_once_with(
mock_create_embedder.embed_file.assert_called_once_with(
"test_blob_sas", "test/test/test_filename.md"
)


@patch("backend.batch.batch_push_results.EnvHelper")
def test_batch_push_results_with_blob_deleted_event_uses_search_to_delete_with_sas_appended(
mock_env_helper,
get_processor_handler_mock,
):
mock_create_embedder, mock_get_search_handler = get_processor_handler_mock

mock_queue_message = QueueMessage(
body='{"eventType": "Microsoft.Storage.BlobDeleted", "data": { "url": "https://test.test/test/test_filename.pdf"}}'
)

batch_push_results.build().get_user_function()(mock_queue_message)
mock_get_search_handler.delete_by_source.assert_called_once_with(
"https://test.test/test/test_filename.pdf_SAS_TOKEN_PLACEHOLDER_"
)
