Skip to content

Commit

Permalink
Add test for backend functions (Azure-Samples#644)
Browse files Browse the repository at this point in the history
* Add test for batch processing function

* Remove comment

* Try to remove sys path modifications

* Fix imports

* Fix sys path append statements by making paths absolute

* Remove extra sys.path.append that isnt necessary

* Apply suggestions from code review

Co-authored-by: Chinedum Echeta <60179183+cecheta@users.noreply.github.com>

* Rename test file, add a test for no url in request

* Make sure ConfigHelper etc are mocked out

* Mock out config helper

* Improve sys.path.append

* Mock EnvHelper as well to speed up test execution time

* Utilise get_user_function() method to enable tests

* re-add env sample

* Init log level from os environ

---------

Co-authored-by: Chinedum Echeta <60179183+cecheta@users.noreply.github.com>
  • Loading branch information
tanya-borisova and cecheta authored Apr 15, 2024
1 parent 857a441 commit ade0392
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 25 deletions.
1 change: 0 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,3 @@ AZURE_SPEECH_SERVICE_REGION=
AZURE_AUTH_TYPE=keys
USE_KEY_VAULT=true
AZURE_KEY_VAULT_ENDPOINT=
LOGLEVEL=INFO
9 changes: 3 additions & 6 deletions code/backend/batch/AddURLEmbeddings.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
import os
import logging
import traceback
import azure.functions as func
import sys

from utilities.helpers.EnvHelper import EnvHelper
from utilities.helpers.DocumentProcessorHelper import DocumentProcessor
from utilities.helpers.ConfigHelper import ConfigHelper

sys.path.append("..")

bp_add_url_embeddings = func.Blueprint()
env_helper: EnvHelper = EnvHelper()

logger = logging.getLogger(__name__)
logger.setLevel(env_helper.LOGLEVEL)
logger.setLevel(level=os.environ.get("LOGLEVEL", "INFO").upper())


@bp_add_url_embeddings.route(route="AddURLEmbeddings")
def add_url_embeddings(req: func.HttpRequest) -> func.HttpResponse:
logger.info("Python HTTP trigger function processed a request.")

# Get Url from request
url = req.params.get("url")
if not url:
Expand Down
14 changes: 8 additions & 6 deletions code/backend/batch/BatchPushResults.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import os
import logging
import json
import azure.functions as func
from urllib.parse import urlparse
import sys

from utilities.helpers.EnvHelper import EnvHelper
from utilities.helpers.AzureBlobStorageHelper import AzureBlobStorageClient
from utilities.helpers.DocumentProcessorHelper import DocumentProcessor
from utilities.helpers.ConfigHelper import ConfigHelper

sys.path.append("..")

bp_batch_push_results = func.Blueprint()
env_helper: EnvHelper = EnvHelper()

logger = logging.getLogger(__name__)
logger.setLevel(env_helper.LOGLEVEL)
logger.setLevel(level=os.environ.get("LOGLEVEL", "INFO").upper())


def _get_file_name_from_message(msg: func.QueueMessage) -> str:
Expand All @@ -32,10 +28,15 @@ def _get_file_name_from_message(msg: func.QueueMessage) -> str:
arg_name="msg", queue_name="doc-processing", connection="AzureWebJobsStorage"
)
def batch_push_results(msg: func.QueueMessage) -> None:
do_batch_push_results(msg)


def do_batch_push_results(msg: func.QueueMessage) -> None:
logger.info(
"Python queue trigger function processed a queue item: %s",
msg.get_body().decode("utf-8"),
)

document_processor = DocumentProcessor()
blob_client = AzureBlobStorageClient()
# Get the file name from the message
Expand All @@ -44,6 +45,7 @@ def batch_push_results(msg: func.QueueMessage) -> None:
file_sas = blob_client.get_blob_sas(file_name)
# Get file extension's processors
file_extension = file_name.split(".")[-1]

processors = list(
filter(
lambda x: x.document_type.lower() == file_extension.lower(),
Expand Down
9 changes: 3 additions & 6 deletions code/backend/batch/BatchStartProcessing.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import os
import logging
import json
import azure.functions as func
import sys
from utilities.helpers.EnvHelper import EnvHelper

from utilities.helpers.AzureBlobStorageHelper import (
AzureBlobStorageClient,
create_queue_client,
)

sys.path.append("..")
bp_batch_start_processing = func.Blueprint()
env_helper: EnvHelper = EnvHelper()

logger = logging.getLogger(__name__)
logger.setLevel(env_helper.LOGLEVEL)
logger.setLevel(level=os.environ.get("LOGLEVEL", "INFO").upper())


@bp_batch_start_processing.route(route="BatchStartProcessing")
Expand Down
15 changes: 9 additions & 6 deletions code/backend/batch/GetConversationResponse.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
import os
import azure.functions as func
import logging
import json
import sys

from utilities.helpers.EnvHelper import EnvHelper
from utilities.helpers.OrchestratorHelper import Orchestrator
from utilities.helpers.ConfigHelper import ConfigHelper

sys.path.append("..")

bp_get_conversation_response = func.Blueprint()
env_helper: EnvHelper = EnvHelper()

logger = logging.getLogger(__name__)
logger.setLevel(env_helper.LOGLEVEL)
logger.setLevel(level=os.environ.get("LOGLEVEL", "INFO").upper())


@bp_get_conversation_response.route(route="GetConversationResponse")
def get_conversation_response(req: func.HttpRequest) -> func.HttpResponse:
return do_get_conversation_response(req)


def do_get_conversation_response(req: func.HttpRequest) -> func.HttpResponse:
logger.info("Python HTTP trigger function processed a request.")

message_orchestrator = Orchestrator()
env_helper: EnvHelper = EnvHelper()

try:
req_body = req.get_json()
Expand All @@ -38,7 +42,6 @@ def get_conversation_response(req: func.HttpRequest) -> func.HttpResponse:
user_assistant_messages[i + 1]["content"],
)
)
from utilities.helpers.ConfigHelper import ConfigHelper

messages = message_orchestrator.handle_message(
user_message=user_message,
Expand Down
Empty file added code/backend/batch/__init__.py
Empty file.
55 changes: 55 additions & 0 deletions code/tests/test_AddURLEmbeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import sys
import os
from unittest.mock import patch
import azure.functions as func


sys.path.append(os.path.join(os.path.dirname(sys.path[0]), "backend", "batch"))

from backend.batch.AddURLEmbeddings import add_url_embeddings # noqa: E402


@patch("backend.batch.AddURLEmbeddings.ConfigHelper")
@patch("backend.batch.AddURLEmbeddings.DocumentProcessor")
def test_add_url_embeddings_when_url_set_in_body(_, __):
fake_request = func.HttpRequest(
method="POST",
url="",
body=b'{"url": "https://example.com"}',
headers={"Content-Type": "application/json"},
)

response = add_url_embeddings.build().get_user_function()(fake_request)

assert response.status_code == 200


@patch("backend.batch.AddURLEmbeddings.ConfigHelper")
@patch("backend.batch.AddURLEmbeddings.DocumentProcessor")
def test_add_url_embeddings_when_url_set_in_param(_, __):
fake_request = func.HttpRequest(
method="POST",
url="",
body=b"",
headers={"Content-Type": "application/json"},
params={"url": "https://example.com"},
)

response = add_url_embeddings.build().get_user_function()(fake_request)

assert response.status_code == 200


@patch("backend.batch.AddURLEmbeddings.ConfigHelper")
@patch("backend.batch.AddURLEmbeddings.DocumentProcessor")
def test_add_url_embeddings_returns_400_when_url_not_set(_, __):
fake_request = func.HttpRequest(
method="POST",
url="",
body=b"",
params={},
)

response = add_url_embeddings.build().get_user_function()(fake_request)

assert response.status_code == 400
68 changes: 68 additions & 0 deletions code/tests/test_BatchPushResults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import sys
import os
from unittest.mock import patch, Mock
from azure.functions import QueueMessage


sys.path.append(os.path.join(os.path.dirname(sys.path[0]), "backend", "batch"))

from backend.batch.BatchPushResults import ( # noqa: E402
batch_push_results,
_get_file_name_from_message,
)


def test_get_file_name_from_message():
mock_queue_message = QueueMessage(
body='{"message": "test message", "filename": "test_filename.md"}'
)

file_name = _get_file_name_from_message(mock_queue_message)

assert file_name == "test_filename.md"


def test_get_file_name_from_message_no_filename():
mock_queue_message = QueueMessage(
body='{"data": { "url": "test/test/test_filename.md"} }'
)

file_name = _get_file_name_from_message(mock_queue_message)

assert file_name == "test_filename.md"


@patch("backend.batch.BatchPushResults.ConfigHelper")
@patch("backend.batch.BatchPushResults.AzureBlobStorageClient")
@patch("backend.batch.BatchPushResults.DocumentProcessor")
def test_batch_push_results(
mock_document_processor,
mock_azure_blob_storage_client,
mock_config_helper,
):
mock_queue_message = QueueMessage(
body='{"message": "test message", "filename": "test/test/test_filename.md"}'
)

mock_blob_client_instance = mock_azure_blob_storage_client.return_value
mock_blob_client_instance.get_blob_sas.return_value = "test_blob_sas"

mock_document_processor_instance = mock_document_processor.return_value

md_processor = Mock()
md_processor.document_type.lower.return_value = "md"
txt_processor = Mock()
txt_processor.document_type.lower.return_value = "txt"
mock_processors = [md_processor, txt_processor]
mock_config_helper.get_active_config_or_default.return_value.document_processors = (
mock_processors
)

batch_push_results.build().get_user_function()(mock_queue_message)

mock_document_processor_instance.process.assert_called_once_with(
source_url="test_blob_sas", processors=[md_processor]
)
mock_blob_client_instance.upsert_blob_metadata.assert_called_once_with(
"test/test/test_filename.md", {"embeddings_added": "true"}
)
60 changes: 60 additions & 0 deletions code/tests/test_BatchStartProcessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import sys
import os
from unittest.mock import patch, Mock

sys.path.append(os.path.join(os.path.dirname(sys.path[0]), "backend", "batch"))

from backend.batch.BatchStartProcessing import batch_start_processing # noqa: E402


@patch("backend.batch.BatchStartProcessing.create_queue_client")
@patch("backend.batch.BatchStartProcessing.AzureBlobStorageClient")
def test_batch_start_processing_processes_all(
mock_blob_storage_client, mock_create_queue_client
):
mock_http_request = Mock()
mock_http_request.params = dict()
mock_http_request.params["process_all"] = "true"

mock_queue_client = Mock()
mock_create_queue_client.return_value = mock_queue_client

mock_blob_storage_client.return_value.get_all_files.return_value = [
{"filename": "file_name_one", "embeddings_added": False}
]

response = batch_start_processing.build().get_user_function()(mock_http_request)

assert response.status_code == 200

mock_queue_client.send_message.assert_called_once_with(
b'{"filename": "file_name_one"}',
)


@patch("backend.batch.BatchStartProcessing.create_queue_client")
@patch("backend.batch.BatchStartProcessing.AzureBlobStorageClient")
def test_batch_start_processing_filters_filter_no_embeddings(
mock_blob_storage_client, mock_create_queue_client
):
mock_http_request = Mock()
mock_http_request.params = dict()
mock_http_request.params["process_all"] = "false"

mock_queue_client = Mock()
mock_create_queue_client.return_value = mock_queue_client

mock_blob_storage_client.return_value.get_all_files.return_value = [
{
"filename": "file_name_one",
"embeddings_added": True, # will get filtered out
},
{"filename": "file_name_two", "embeddings_added": False},
]
response = batch_start_processing.build().get_user_function()(mock_http_request)

assert response.status_code == 200

mock_queue_client.send_message.assert_called_once_with(
b'{"filename": "file_name_two"}',
)
49 changes: 49 additions & 0 deletions code/tests/test_GetConversationResponse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys
import os
from unittest.mock import patch, Mock, ANY
import json

sys.path.append(os.path.join(os.path.dirname(sys.path[0]), "backend", "batch"))

from backend.batch.GetConversationResponse import ( # noqa: E402
get_conversation_response,
)


@patch("backend.batch.GetConversationResponse.ConfigHelper")
@patch("backend.batch.GetConversationResponse.Orchestrator")
def test_get_conversation_response(mock_create_message_orchestrator, _):
mock_http_request = Mock()
request_json = {
"messages": [
{"content": "Do I have meetings today?", "role": "user"},
{"content": "It is sunny today", "role": "assistant"},
{"content": "What is the weather like today?", "role": "user"},
],
"conversation_id": "13245",
}
mock_http_request.get_json.return_value = request_json

mock_message_orchestrator = Mock()
mock_message_orchestrator.handle_message.return_value = [
"You don't have any meetings today"
]

mock_create_message_orchestrator.return_value = mock_message_orchestrator

response = get_conversation_response.build().get_user_function()(mock_http_request)

assert response.status_code == 200

mock_message_orchestrator.handle_message.assert_called_once_with(
user_message="What is the weather like today?",
chat_history=[("Do I have meetings today?", "It is sunny today")],
conversation_id="13245",
orchestrator=ANY,
)

response_json = json.loads(response.get_body())
assert response_json["id"] == "response.id"
assert response_json["choices"] == [
{"messages": ["You don't have any meetings today"]}
]

0 comments on commit ade0392

Please sign in to comment.