diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index 0cd98b0e6..7a51a5c31 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -1,6 +1,5 @@
-FROM --platform=linux/amd64 mcr.microsoft.com/devcontainers/python:1-3.11-bullseye
-# We need to force the container to be amd so that it works on a Mac. Without this the functions extension doesn't install.
+FROM mcr.microsoft.com/devcontainers/python:3.11
# install git
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
- && apt-get -y install --no-install-recommends git libgtk2.0-0 libgtk-3-0 libgbm-dev libnotify-dev libnss3 libxss1 libasound2 libxtst6 xauth xvfb
\ No newline at end of file
+ && apt-get -y install --no-install-recommends git libgtk2.0-0 libgtk-3-0 libgbm-dev libnotify-dev libnss3 libxss1 libasound2 libxtst6 xauth xvfb
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 1512f6b06..f1ac51150 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -7,9 +7,7 @@
"ghcr.io/devcontainers/features/azure-cli:1": {},
"ghcr.io/devcontainers/features/docker-outside-of-docker:1": {},
"ghcr.io/devcontainers/features/node:1": {},
- "ghcr.io/jlaundry/devcontainer-features/azure-functions-core-tools:1": {
- "version": "4.0.5530"
- },
+ "ghcr.io/jlaundry/devcontainer-features/azure-functions-core-tools:1": {},
"ghcr.io/azure/azure-dev/azd:latest": {},
"ghcr.io/rchaganti/vsc-devcontainer-features/azurebicep:1.0.5": {}
},
@@ -28,6 +26,7 @@
"ms-python.python",
"ms-python.black-formatter",
"ms-python.vscode-pylance",
+ "ms-python.pylint",
"ms-toolsai.jupyter",
"ms-vscode.vscode-node-azure-pack",
"TeamsDevApp.ms-teams-vscode-extension",
diff --git a/.vscode/settings.json b/.vscode/settings.json
index bb92a125e..4efc47061 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -21,4 +21,5 @@
"python.testing.cwd": "${workspaceFolder}/code",
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
+ "pylint.path" : [ "${interpreter}", "-m", "pylint" ]
}
diff --git a/code/app.py b/code/app.py
index ba3043c12..8df53875a 100644
--- a/code/app.py
+++ b/code/app.py
@@ -1,3 +1,7 @@
+"""
+This module contains the entry point for the application.
+"""
+
import os
import logging
from azure.monitor.opentelemetry import configure_azure_monitor
@@ -5,7 +9,8 @@
logging.captureWarnings(True)
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").upper())
-# Raising the azure log level to WARN as it is too verbose - https://github.com/Azure/azure-sdk-for-python/issues/9422
+# Raising the azure log level to WARN as it is too verbose -
+# https://github.com/Azure/azure-sdk-for-python/issues/9422
logging.getLogger("azure").setLevel(os.environ.get("LOGLEVEL_AZURE", "WARN").upper())
# We cannot use EnvHelper here as Application Insights should be configured first
# for instrumentation to work correctly
@@ -13,6 +18,7 @@
configure_azure_monitor()
HTTPXClientInstrumentor().instrument() # httpx is used by openai
+# pylint: disable=wrong-import-position
from create_app import create_app # noqa: E402
app = create_app()
diff --git a/code/backend/Admin.py b/code/backend/Admin.py
index 3b566aae0..efa444865 100644
--- a/code/backend/Admin.py
+++ b/code/backend/Admin.py
@@ -1,14 +1,19 @@
-import streamlit as st
+"""
+This module contains the code for the Admin app of the Chat with your data Solution Accelerator.
+"""
+
import os
import logging
import sys
+import streamlit as st
from azure.monitor.opentelemetry import configure_azure_monitor
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
logging.captureWarnings(True)
logging.basicConfig(level=os.getenv("LOGLEVEL", "INFO").upper())
-# Raising the azure log level to WARN as it is too verbose - https://github.com/Azure/azure-sdk-for-python/issues/9422
+# Raising the azure log level to WARN as it is too verbose
+# https://github.com/Azure/azure-sdk-for-python/issues/9422
logging.getLogger("azure").setLevel(os.environ.get("LOGLEVEL_AZURE", "WARN").upper())
# We cannot use EnvHelper here as Application Insights needs to be configured first
# for instrumentation to work correctly
@@ -26,14 +31,14 @@
menu_items=None,
)
-mod_page_style = """
+MOD_PAGE_STYLE = """
"""
-st.markdown(mod_page_style, unsafe_allow_html=True)
+st.markdown(MOD_PAGE_STYLE, unsafe_allow_html=True)
col1, col2, col3 = st.columns([1, 2, 1])
diff --git a/code/backend/batch/batch_push_results.py b/code/backend/batch/batch_push_results.py
index 18859f573..4058b96ef 100644
--- a/code/backend/batch/batch_push_results.py
+++ b/code/backend/batch/batch_push_results.py
@@ -1,20 +1,20 @@
import os
import logging
import json
-import azure.functions as func
from urllib.parse import urlparse
+import azure.functions as func
from utilities.helpers.azure_blob_storage_client import AzureBlobStorageClient
from utilities.helpers.env_helper import EnvHelper
from utilities.helpers.embedders.embedder_factory import EmbedderFactory
+from utilities.search.search import Search
bp_batch_push_results = func.Blueprint()
logger = logging.getLogger(__name__)
logger.setLevel(level=os.environ.get("LOGLEVEL", "INFO").upper())
-def _get_file_name_from_message(msg: func.QueueMessage) -> str:
- message_body = json.loads(msg.get_body().decode("utf-8"))
+def _get_file_name_from_message(message_body) -> str:
return message_body.get(
"filename",
"/".join(
@@ -27,21 +27,37 @@ def _get_file_name_from_message(msg: func.QueueMessage) -> str:
arg_name="msg", queue_name="doc-processing", connection="AzureWebJobsStorage"
)
def batch_push_results(msg: func.QueueMessage) -> None:
- do_batch_push_results(msg)
+ message_body = json.loads(msg.get_body().decode("utf-8"))
+ logger.debug("Process Document Event queue function triggered: %s", message_body)
+
+ event_type = message_body.get("eventType", "")
+ # We handle "" in this scenario for backwards compatibility
+ # This function is primarily triggered by an Event Grid queue message from the blob storage
+ # However, it can also be triggered using a legacy schema from BatchStartProcessing
+ if event_type in ("", "Microsoft.Storage.BlobCreated"):
+ _process_document_created_event(message_body)
+
+ elif event_type == "Microsoft.Storage.BlobDeleted":
+ _process_document_deleted_event(message_body)
+
+ else:
+ raise NotImplementedError(f"Unknown event type received: {event_type}")
-def do_batch_push_results(msg: func.QueueMessage) -> None:
+def _process_document_created_event(message_body) -> None:
env_helper: EnvHelper = EnvHelper()
- logger.info(
- "Python queue trigger function processed a queue item: %s",
- msg.get_body().decode("utf-8"),
- )
blob_client = AzureBlobStorageClient()
- # Get the file name from the message
- file_name = _get_file_name_from_message(msg)
- # Generate the SAS URL for the file
+ file_name = _get_file_name_from_message(message_body)
file_sas = blob_client.get_blob_sas(file_name)
- # Process the file
+
embedder = EmbedderFactory.create(env_helper)
embedder.embed_file(file_sas, file_name)
+
+
+def _process_document_deleted_event(message_body) -> None:
+ env_helper: EnvHelper = EnvHelper()
+ search_handler = Search.get_search_handler(env_helper)
+
+ blob_url = message_body.get("data", {}).get("url", "")
+ search_handler.delete_by_source(f"{blob_url}_SAS_TOKEN_PLACEHOLDER_")
diff --git a/code/backend/batch/function_app.py b/code/backend/batch/function_app.py
index 5eca04a1e..b2756c751 100644
--- a/code/backend/batch/function_app.py
+++ b/code/backend/batch/function_app.py
@@ -10,7 +10,8 @@
logging.captureWarnings(True)
# Raising the azure log level to WARN as it is too verbose - https://github.com/Azure/azure-sdk-for-python/issues/9422
logging.getLogger("azure").setLevel(os.environ.get("LOGLEVEL_AZURE", "WARN").upper())
-configure_azure_monitor()
+if os.getenv("APPLICATIONINSIGHTS_ENABLED", "false").lower() == "true":
+ configure_azure_monitor()
app = func.FunctionApp(
http_auth_level=func.AuthLevel.FUNCTION
diff --git a/code/backend/batch/utilities/helpers/embedders/push_embedder.py b/code/backend/batch/utilities/helpers/embedders/push_embedder.py
index 7ab2ac29d..58ba2f682 100644
--- a/code/backend/batch/utilities/helpers/embedders/push_embedder.py
+++ b/code/backend/batch/utilities/helpers/embedders/push_embedder.py
@@ -24,6 +24,7 @@
class PushEmbedder(EmbedderBase):
def __init__(self, blob_client: AzureBlobStorageClient, env_helper: EnvHelper):
+ self.env_helper = env_helper
self.llm_helper = LLMHelper()
self.azure_search_helper = AzureSearchHelper()
self.azure_computer_vision_client = AzureComputerVisionClient(env_helper)
@@ -59,13 +60,15 @@ def __embed(
in self.config.get_advanced_image_processing_image_types()
):
logger.warning("Advanced image processing is not supported yet")
- image_vectors = self.azure_computer_vision_client.vectorize_image(
- source_url
- )
- logger.info("Image vectors: " + str(image_vectors))
+ caption = self.__generate_image_caption(source_url)
+ caption_vector = self.llm_helper.generate_embeddings(caption)
+
+ image_vector = self.azure_computer_vision_client.vectorize_image(source_url)
documents_to_upload.append(
- self.__create_image_document(source_url, image_vectors)
+ self.__create_image_document(
+ source_url, image_vector, caption, caption_vector
+ )
)
else:
documents: List[SourceDocument] = self.document_loading.load(
@@ -85,6 +88,32 @@ def __embed(
logger.error("Failed to upload documents to search index")
raise Exception(response)
+ def __generate_image_caption(self, source_url):
+ model = self.env_helper.AZURE_OPENAI_VISION_MODEL
+ caption_system_message = """You are an assistant that generates rich descriptions of images.
+You need to be accurate in the information you extract and detailed in the descriptons you generate.
+Do not abbreviate anything and do not shorten sentances. Explain the image completely.
+If you are provided with an image of a flow chart, describe the flow chart in detail.
+If the image is mostly text, use OCR to extract the text as it is displayed in the image."""
+
+ messages = [
+ {"role": "system", "content": caption_system_message},
+ {
+ "role": "user",
+ "content": [
+ {
+ "text": "Describe this image in detail. Limit the response to 500 words.",
+ "type": "text",
+ },
+ {"image_url": source_url, "type": "image_url"},
+ ],
+ },
+ ]
+
+ response = self.llm_helper.get_chat_completion(messages, model)
+ caption = response.choices[0].message.content
+ return caption
+
def __convert_to_search_document(self, document: SourceDocument):
embedded_content = self.llm_helper.generate_embeddings(document.content)
metadata = {
@@ -111,7 +140,13 @@ def __generate_document_id(self, source_url: str) -> str:
hash_key = hashlib.sha1(f"{source_url}_1".encode("utf-8")).hexdigest()
return f"doc_{hash_key}"
- def __create_image_document(self, source_url: str, image_vectors: List[float]):
+ def __create_image_document(
+ self,
+ source_url: str,
+ image_vector: List[float],
+ content: str,
+ content_vector: List[float],
+ ):
parsed_url = urlparse(source_url)
file_url = parsed_url.scheme + "://" + parsed_url.netloc + parsed_url.path
@@ -127,9 +162,9 @@ def __create_image_document(self, source_url: str, image_vectors: List[float]):
return {
"id": document_id,
- "content": "",
- "content_vector": [],
- "image_vector": image_vectors,
+ "content": content,
+ "content_vector": content_vector,
+ "image_vector": image_vector,
"metadata": json.dumps(
{
"id": document_id,
diff --git a/code/backend/batch/utilities/helpers/env_helper.py b/code/backend/batch/utilities/helpers/env_helper.py
index 138ecd890..b860a79c2 100644
--- a/code/backend/batch/utilities/helpers/env_helper.py
+++ b/code/backend/batch/utilities/helpers/env_helper.py
@@ -86,6 +86,7 @@ def __load_config(self, **kwargs) -> None:
self.AZURE_OPENAI_MODEL_NAME = os.getenv(
"AZURE_OPENAI_MODEL_NAME", "gpt-35-turbo"
)
+ self.AZURE_OPENAI_VISION_MODEL = os.getenv("AZURE_OPENAI_VISION_MODEL", "gpt-4")
self.AZURE_OPENAI_TEMPERATURE = os.getenv("AZURE_OPENAI_TEMPERATURE", "0")
self.AZURE_OPENAI_TOP_P = os.getenv("AZURE_OPENAI_TOP_P", "1.0")
self.AZURE_OPENAI_MAX_TOKENS = os.getenv("AZURE_OPENAI_MAX_TOKENS", "1000")
diff --git a/code/backend/batch/utilities/helpers/llm_helper.py b/code/backend/batch/utilities/helpers/llm_helper.py
index 8c6084033..bbbe83e52 100644
--- a/code/backend/batch/utilities/helpers/llm_helper.py
+++ b/code/backend/batch/utilities/helpers/llm_helper.py
@@ -117,9 +117,9 @@ def get_chat_completion_with_functions(
function_call=function_call,
)
- def get_chat_completion(self, messages: list[dict]):
+ def get_chat_completion(self, messages: list[dict], model: str | None = None):
return self.openai_client.chat.completions.create(
- model=self.llm_model,
+ model=model or self.llm_model,
messages=messages,
)
diff --git a/code/backend/batch/utilities/search/search_handler_base.py b/code/backend/batch/utilities/search/search_handler_base.py
index 4a937d7b9..5e3443e5c 100644
--- a/code/backend/batch/utilities/search/search_handler_base.py
+++ b/code/backend/batch/utilities/search/search_handler_base.py
@@ -36,13 +36,35 @@ def get_files(self):
pass
@abstractmethod
- def output_results(self, results, id_field):
+ def output_results(self, results):
pass
@abstractmethod
- def delete_files(self, files, id_field):
+ def delete_files(self, files):
pass
@abstractmethod
def query_search(self, question) -> list[SourceDocument]:
pass
+
+ def delete_by_source(self, source) -> None:
+ if source is None:
+ return
+
+ documents = self._get_documents_by_source(source)
+ if documents is None:
+ return
+
+ files_to_delete = self.output_results(documents)
+ self.delete_files(files_to_delete)
+
+ def _get_documents_by_source(self, source):
+ if source is None:
+ return None
+
+ return self.search_client.search(
+ "*",
+ select="id, title",
+ include_total_count=True,
+ filter=f"source eq '{source}'",
+ )
diff --git a/code/create_app.py b/code/create_app.py
index ed593ffc2..26fe70ee3 100644
--- a/code/create_app.py
+++ b/code/create_app.py
@@ -1,14 +1,18 @@
+"""
+This module creates a Flask app that serves the web interface for the chatbot.
+"""
+
+import functools
import json
import logging
+import mimetypes
from os import path
+import sys
import requests
from openai import AzureOpenAI, Stream
from openai.types.chat import ChatCompletionChunk
-import mimetypes
from flask import Flask, Response, request, Request, jsonify
from dotenv import load_dotenv
-import sys
-import functools
from backend.batch.utilities.helpers.env_helper import EnvHelper
from backend.batch.utilities.helpers.orchestrator_helper import Orchestrator
from backend.batch.utilities.helpers.config.config_helper import ConfigHelper
@@ -19,6 +23,7 @@
def stream_with_data(response: Stream[ChatCompletionChunk]):
+ """This function streams the response from Azure OpenAI with data."""
response_obj = {
"id": "",
"model": "",
@@ -69,7 +74,8 @@ def stream_with_data(response: Stream[ChatCompletionChunk]):
yield json.dumps(response_obj, ensure_ascii=False) + "\n"
-def conversation_with_data(request: Request, env_helper: EnvHelper):
+def conversation_with_data(conversation: Request, env_helper: EnvHelper):
+ """This function streams the response from Azure OpenAI with data."""
if env_helper.is_auth_type_keys():
openai_client = AzureOpenAI(
azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT,
@@ -83,9 +89,10 @@ def conversation_with_data(request: Request, env_helper: EnvHelper):
azure_ad_token_provider=env_helper.AZURE_TOKEN_PROVIDER,
)
- messages = request.json["messages"]
+ messages = conversation.json["messages"]
- # Azure OpenAI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means deployment name.
+ # Azure OpenAI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means
+ # deployment name.
response = openai_client.chat.completions.create(
model=env_helper.AZURE_OPENAI_MODEL,
messages=messages,
@@ -180,45 +187,51 @@ def conversation_with_data(request: Request, env_helper: EnvHelper):
}
return response_obj
- else:
- return Response(
- stream_with_data(response),
- mimetype="application/json-lines",
- )
+
+ return Response(
+ stream_with_data(response),
+ mimetype="application/json-lines",
+ )
def stream_without_data(response: Stream[ChatCompletionChunk]):
- responseText = ""
+ """This function streams the response from Azure OpenAI without data."""
+ response_text = ""
for line in response:
if not line.choices:
continue
- deltaText = line.choices[0].delta.content
+ delta_text = line.choices[0].delta.content
- if deltaText is None:
+ if delta_text is None:
return
- responseText += deltaText
+ response_text += delta_text
response_obj = {
"id": line.id,
"model": line.model,
"created": line.created,
"object": line.object,
- "choices": [{"messages": [{"role": "assistant", "content": responseText}]}],
+ "choices": [
+ {"messages": [{"role": "assistant", "content": response_text}]}
+ ],
}
yield json.dumps(response_obj, ensure_ascii=False) + "\n"
def get_message_orchestrator():
+ """This function gets the message orchestrator."""
return Orchestrator()
def get_orchestrator_config():
+ """This function gets the orchestrator configuration."""
return ConfigHelper.get_active_config_or_default().orchestrator
-def conversation_without_data(request: Request, env_helper: EnvHelper):
+def conversation_without_data(conversation: Request, env_helper: EnvHelper):
+ """This function streams the response from Azure OpenAI without data."""
if env_helper.AZURE_AUTH_TYPE == "rbac":
openai_client = AzureOpenAI(
azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT,
@@ -232,13 +245,14 @@ def conversation_without_data(request: Request, env_helper: EnvHelper):
api_key=env_helper.AZURE_OPENAI_API_KEY,
)
- request_messages = request.json["messages"]
+ request_messages = conversation.json["messages"]
messages = [{"role": "system", "content": env_helper.AZURE_OPENAI_SYSTEM_MESSAGE}]
for message in request_messages:
messages.append({"role": message["role"], "content": message["content"]})
- # Azure Open AI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means deployment name.
+ # Azure Open AI takes the deployment name as the model name, "AZURE_OPENAI_MODEL" means
+ # deployment name.
response = openai_client.chat.completions.create(
model=env_helper.AZURE_OPENAI_MODEL,
messages=messages,
@@ -271,10 +285,8 @@ def conversation_without_data(request: Request, env_helper: EnvHelper):
],
}
return jsonify(response_obj), 200
- else:
- return Response(
- stream_without_data(response), mimetype="application/json-lines"
- )
+
+ return Response(stream_without_data(response), mimetype="application/json-lines")
@functools.cache
@@ -296,6 +308,7 @@ def get_speech_key(env_helper: EnvHelper):
def create_app():
+ """This function creates the Flask app."""
# Fixing MIME types for static files under Windows
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
@@ -313,8 +326,8 @@ def create_app():
@app.route("/", defaults={"path": "index.html"})
@app.route("/")
- def static_file(path):
- return app.send_static_file(path)
+ def static_file(file_path):
+ return app.send_static_file(file_path)
@app.route("/api/health", methods=["GET"])
def health():
@@ -328,9 +341,9 @@ def conversation_azure_byod():
else:
return conversation_without_data(request, env_helper)
except Exception as e:
- errorMessage = str(e)
+ error_message = str(e)
logger.exception(
- f"Exception in /api/conversation/azure_byod | {errorMessage}"
+ "Exception in /api/conversation/azure_byod | %s", error_message
)
return (
jsonify(
@@ -373,8 +386,10 @@ async def conversation_custom():
return jsonify(response_obj), 200
except Exception as e:
- errorMessage = str(e)
- logger.exception(f"Exception in /api/conversation/custom | {errorMessage}")
+ error_message = str(e)
+ logger.exception(
+ "Exception in /api/conversation/custom | %s", error_message
+ )
return (
jsonify(
{
@@ -386,6 +401,7 @@ async def conversation_custom():
@app.route("/api/speech", methods=["GET"])
def speech_config():
+ """Get the speech config for Azure Speech."""
try:
speech_key = env_helper.AZURE_SPEECH_KEY or get_speech_key(env_helper)
@@ -394,6 +410,7 @@ def speech_config():
headers={
"Ocp-Apim-Subscription-Key": speech_key,
},
+ timeout=5,
)
if response.status_code == 200:
@@ -403,10 +420,10 @@ def speech_config():
"languages": env_helper.AZURE_SPEECH_RECOGNIZER_LANGUAGES,
}
- logger.error(f"Failed to get speech config: {response.text}")
+ logger.error("Failed to get speech config: %s", response.text)
return {"error": "Failed to get speech config"}, response.status_code
except Exception as e:
- logger.exception(f"Exception in /api/speech | {str(e)}")
+ logger.exception("Exception in /api/speech | %s", str(e))
return {"error": "Failed to get speech config"}, 500
diff --git a/code/tests/functional/app_config.py b/code/tests/functional/app_config.py
index b1c841c14..18837a4da 100644
--- a/code/tests/functional/app_config.py
+++ b/code/tests/functional/app_config.py
@@ -28,6 +28,7 @@ class AppConfig:
"AZURE_OPENAI_MAX_TOKENS": "1000",
"AZURE_OPENAI_MODEL": "some-openai-model",
"AZURE_OPENAI_MODEL_NAME": "some-openai-model-name",
+ "AZURE_OPENAI_VISION_MODEL": "some-openai-vision-model",
"AZURE_OPENAI_RESOURCE": "some-openai-resource",
"AZURE_OPENAI_STREAM": "True",
"AZURE_OPENAI_STOP_SEQUENCE": "",
diff --git a/code/tests/functional/conftest.py b/code/tests/functional/conftest.py
index 6e5e6408f..8f76a14e4 100644
--- a/code/tests/functional/conftest.py
+++ b/code/tests/functional/conftest.py
@@ -162,6 +162,33 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
}
)
+ httpserver.expect_request(
+ f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions",
+ method="POST",
+ ).respond_with_json(
+ {
+ "id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9",
+ "object": "chat.completion",
+ "created": 1679072642,
+ "model": app_config.get("AZURE_OPENAI_VISION_MODEL"),
+ "usage": {
+ "prompt_tokens": 58,
+ "completion_tokens": 68,
+ "total_tokens": 126,
+ },
+ "choices": [
+ {
+ "message": {
+ "role": "assistant",
+ "content": "This is a caption for the image",
+ },
+ "finish_reason": "stop",
+ "index": 0,
+ }
+ ],
+ }
+ )
+
httpserver.expect_request(
f"/indexes('{app_config.get('AZURE_SEARCH_CONVERSATIONS_LOG_INDEX')}')/docs/search.index",
method="POST",
diff --git a/code/tests/functional/tests/functions/test_advanced_image_processing.py b/code/tests/functional/tests/functions/test_advanced_image_processing.py
index 300ec4a7e..fd41d6a3d 100644
--- a/code/tests/functional/tests/functions/test_advanced_image_processing.py
+++ b/code/tests/functional/tests/functions/test_advanced_image_processing.py
@@ -103,6 +103,9 @@ def test_image_passed_to_computer_vision_to_generate_image_embeddings(
RequestMatcher(
path=COMPUTER_VISION_VECTORIZE_IMAGE_PATH,
method=COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD,
+ json={
+ "url": ANY,
+ },
query_string="api-version=2024-02-01&model-version=2023-04-15",
headers={
"Content-Type": "application/json",
@@ -115,7 +118,87 @@ def test_image_passed_to_computer_vision_to_generate_image_embeddings(
)[0]
assert request.get_json()["url"].startswith(
- f"{app_config.get('AZURE_COMPUTER_VISION_ENDPOINT')}{app_config.get('AZURE_BLOB_CONTAINER_NAME')}/{FILE_NAME}"
+ f"{app_config.get('AZURE_STORAGE_ACCOUNT_ENDPOINT')}{app_config.get('AZURE_BLOB_CONTAINER_NAME')}/{FILE_NAME}"
+ )
+
+
+def test_image_passed_to_llm_to_generate_caption(
+ message: QueueMessage, httpserver: HTTPServer, app_config: AppConfig
+):
+ # when
+ batch_push_results.build().get_user_function()(message)
+
+ # then
+ request = verify_request_made(
+ mock_httpserver=httpserver,
+ request_matcher=RequestMatcher(
+ path=f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions",
+ method="POST",
+ json={
+ "messages": [
+ {
+ "role": "system",
+ "content": """You are an assistant that generates rich descriptions of images.
+You need to be accurate in the information you extract and detailed in the descriptons you generate.
+Do not abbreviate anything and do not shorten sentances. Explain the image completely.
+If you are provided with an image of a flow chart, describe the flow chart in detail.
+If the image is mostly text, use OCR to extract the text as it is displayed in the image.""",
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "text": "Describe this image in detail. Limit the response to 500 words.",
+ "type": "text",
+ },
+ {"image_url": ANY, "type": "image_url"},
+ ],
+ },
+ ],
+ "model": app_config.get("AZURE_OPENAI_VISION_MODEL"),
+ },
+ headers={
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {app_config.get('AZURE_OPENAI_API_KEY')}",
+ "Api-Key": app_config.get("AZURE_OPENAI_API_KEY"),
+ },
+ query_string="api-version=2024-02-01",
+ times=1,
+ ),
+ )[0]
+
+ assert request.get_json()["messages"][1]["content"][1]["image_url"].startswith(
+ f"{app_config.get('AZURE_STORAGE_ACCOUNT_ENDPOINT')}{app_config.get('AZURE_BLOB_CONTAINER_NAME')}/{FILE_NAME}"
+ )
+
+
+def test_embeddings_generated_for_caption(
+ message: QueueMessage, httpserver: HTTPServer, app_config: AppConfig
+):
+ # when
+ batch_push_results.build().get_user_function()(message)
+
+ # then
+ verify_request_made(
+ mock_httpserver=httpserver,
+ request_matcher=RequestMatcher(
+ path=f"/openai/deployments/{app_config.get('AZURE_OPENAI_EMBEDDING_MODEL')}/embeddings",
+ method="POST",
+ json={
+ "input": ["This is a caption for the image"],
+ "model": app_config.get("AZURE_OPENAI_EMBEDDING_MODEL"),
+ "encoding_format": "base64",
+ },
+ headers={
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {app_config.get('AZURE_OPENAI_API_KEY')}",
+ "Api-Key": app_config.get("AZURE_OPENAI_API_KEY"),
+ },
+ query_string="api-version=2024-02-01",
+ times=1,
+ ),
)
@@ -343,8 +426,11 @@ def test_makes_correct_call_to_store_documents_in_search_index(
"value": [
{
"id": expected_id,
- "content": "",
- "content_vector": [],
+ "content": "This is a caption for the image",
+ "content_vector": [
+ 0.018990106880664825,
+ -0.0073809814639389515,
+ ],
"image_vector": [1.0, 2.0, 3.0],
"metadata": json.dumps(
{
diff --git a/code/tests/test_app.py b/code/tests/test_app.py
index d7ddf8e6e..858579258 100644
--- a/code/tests/test_app.py
+++ b/code/tests/test_app.py
@@ -1,7 +1,11 @@
+"""
+This module tests the entry point for the application.
+"""
+
import os
-from flask.testing import FlaskClient
-import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
+import pytest
+from flask.testing import FlaskClient
from create_app import create_app
AZURE_SPEECH_KEY = "mock-speech-key"
@@ -34,11 +38,13 @@
@pytest.fixture
def client():
+ """Create a test client for the app."""
return create_app().test_client()
@pytest.fixture(autouse=True)
def env_helper_mock():
+ """Mock the environment variables for the tests."""
with patch("create_app.EnvHelper") as mock:
env_helper = mock.return_value
@@ -82,6 +88,7 @@ class TestSpeechToken:
def test_returns_speech_token_using_keys(
self, requests: MagicMock, client: FlaskClient
):
+ """Test that the speech token is returned correctly when using keys."""
# given
mock_response: MagicMock = requests.post.return_value
mock_response.text = "speech-token"
@@ -103,6 +110,7 @@ def test_returns_speech_token_using_keys(
headers={
"Ocp-Apim-Subscription-Key": AZURE_SPEECH_KEY,
},
+ timeout=5,
)
@patch("create_app.CognitiveServicesManagementClient")
@@ -114,6 +122,7 @@ def test_returns_speech_token_using_rbac(
env_helper_mock: MagicMock,
client: FlaskClient,
):
+ """Test that the speech token is returned correctly when using RBAC."""
# given
env_helper_mock.AZURE_SPEECH_KEY = None
@@ -144,12 +153,14 @@ def test_returns_speech_token_using_rbac(
headers={
"Ocp-Apim-Subscription-Key": "mock-key1",
},
+ timeout=5,
)
@patch("create_app.requests")
def test_error_when_cannot_retrieve_speech_token(
self, requests: MagicMock, client: FlaskClient
):
+ """Test that an error is returned when the speech token cannot be retrieved."""
# given
mock_response: MagicMock = requests.post.return_value
mock_response.text = "error"
@@ -166,6 +177,7 @@ def test_error_when_cannot_retrieve_speech_token(
def test_error_when_unexpected_error_occurs(
self, requests: MagicMock, client: FlaskClient
):
+ """Test that an error is returned when an unexpected error occurs."""
# given
requests.post.side_effect = Exception("An error occurred")
@@ -177,7 +189,10 @@ def test_error_when_unexpected_error_occurs(
class TestConfig:
+ """Test the config endpoint."""
+
def test_health(self, client):
+ """Test that the health endpoint returns OK."""
response = client.get("/api/health")
assert response.status_code == 200
@@ -185,7 +200,10 @@ def test_health(self, client):
class TestConversationCustom:
+ """Test the custom conversation endpoint."""
+
def setup_method(self):
+ """Set up the test data."""
self.orchestrator_config = {"strategy": "langchain"}
self.messages = [
{
@@ -209,13 +227,14 @@ def setup_method(self):
@patch(
"backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
)
- def test_converstation_custom_returns_correct_response(
+ def test_conversation_custom_returns_correct_response(
self,
get_active_config_or_default_mock,
get_message_orchestrator_mock,
env_helper_mock,
client,
):
+ """Test that the custom conversation endpoint returns the correct response."""
# given
get_active_config_or_default_mock.return_value.orchestrator.return_value = (
self.orchestrator_config
@@ -246,12 +265,13 @@ def test_converstation_custom_returns_correct_response(
@patch("create_app.get_message_orchestrator")
@patch("create_app.get_orchestrator_config")
- def test_converstation_custom_calls_message_orchestrator_correctly(
+ def test_conversation_custom_calls_message_orchestrator_correctly(
self,
get_orchestrator_config_mock,
get_message_orchestrator_mock,
client,
):
+ """Test that the custom conversation endpoint calls the message orchestrator correctly."""
# given
get_orchestrator_config_mock.return_value = self.orchestrator_config
@@ -277,9 +297,10 @@ def test_converstation_custom_calls_message_orchestrator_correctly(
)
@patch("create_app.get_orchestrator_config")
- def test_converstation_custom_returns_error_resonse_on_exception(
+ def test_conversaation_custom_returns_error_response_on_exception(
self, get_orchestrator_config_mock, client
):
+ """Test that an error response is returned when an exception occurs."""
# given
get_orchestrator_config_mock.side_effect = Exception("An error occurred")
@@ -298,7 +319,7 @@ def test_converstation_custom_returns_error_resonse_on_exception(
@patch("create_app.get_message_orchestrator")
@patch("create_app.get_orchestrator_config")
- def test_converstation_custom_allows_multiple_messages_from_user(
+ def test_conversation_custom_allows_multiple_messages_from_user(
self, get_orchestrator_config_mock, get_message_orchestrator_mock, client
):
"""This can happen if there was an error getting a response from the assistant for the previous user message."""
@@ -342,6 +363,7 @@ def test_converstation_custom_allows_multiple_messages_from_user(
class TestConversationAzureByod:
def setup_method(self):
+ """Set up the test data."""
self.body = {
"conversation_id": "123",
"messages": [
@@ -438,9 +460,10 @@ def setup_method(self):
]
@patch("create_app.AzureOpenAI")
- def test_converstation_azure_byod_returns_correct_response_when_streaming_with_data_keys(
+ def test_conversation_azure_byod_returns_correct_response_when_streaming_with_data_keys(
self, azure_openai_mock: MagicMock, client: FlaskClient
):
+ """Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
openai_client_mock = azure_openai_mock.return_value
openai_client_mock.chat.completions.create.return_value = (
@@ -515,12 +538,13 @@ def test_converstation_azure_byod_returns_correct_response_when_streaming_with_d
)
@patch("create_app.AzureOpenAI")
- def test_converstation_azure_byod_returns_correct_response_when_streaming_with_data_rbac(
+ def test_conversation_azure_byod_returns_correct_response_when_streaming_with_data_rbac(
self,
azure_openai_mock: MagicMock,
env_helper_mock: MagicMock,
client: FlaskClient,
):
+ """Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
env_helper_mock.is_auth_type_keys.return_value = False
openai_client_mock = azure_openai_mock.return_value
@@ -563,12 +587,13 @@ def test_converstation_azure_byod_returns_correct_response_when_streaming_with_d
}
@patch("create_app.AzureOpenAI")
- def test_converstation_azure_byod_returns_correct_response_when_not_streaming_with_data(
+ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_with_data(
self,
azure_openai_mock: MagicMock,
env_helper_mock: MagicMock,
client: FlaskClient,
):
+ """Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
env_helper_mock.SHOULD_STREAM = False
@@ -608,9 +633,10 @@ def test_converstation_azure_byod_returns_correct_response_when_not_streaming_wi
}
@patch("create_app.conversation_with_data")
- def test_converstation_azure_byod_returns_500_when_exception_occurs(
+ def test_conversation_azure_byod_returns_500_when_exception_occurs(
self, conversation_with_data_mock, client
):
+ """Test that an error response is returned when an exception occurs."""
# given
conversation_with_data_mock.side_effect = Exception("Test exception")
@@ -628,9 +654,10 @@ def test_converstation_azure_byod_returns_500_when_exception_occurs(
}
@patch("create_app.AzureOpenAI")
- def test_converstation_azure_byod_returns_correct_response_when_not_streaming_without_data_keys(
+ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_without_data_keys(
self, azure_openai_mock, env_helper_mock, client
):
+ """Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
env_helper_mock.should_use_data.return_value = False
env_helper_mock.SHOULD_STREAM = False
@@ -691,9 +718,10 @@ def test_converstation_azure_byod_returns_correct_response_when_not_streaming_wi
)
@patch("create_app.AzureOpenAI")
- def test_converstation_azure_byod_returns_correct_response_when_not_streaming_without_data_rbac(
+ def test_conversation_azure_byod_returns_correct_response_when_not_streaming_without_data_rbac(
self, azure_openai_mock, env_helper_mock, client
):
+ """Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
env_helper_mock.should_use_data.return_value = False
env_helper_mock.SHOULD_STREAM = False
@@ -756,9 +784,10 @@ def test_converstation_azure_byod_returns_correct_response_when_not_streaming_wi
)
@patch("create_app.AzureOpenAI")
- def test_converstation_azure_byod_returns_correct_response_when_streaming_without_data(
+ def test_conversation_azure_byod_returns_correct_response_when_streaming_without_data(
self, azure_openai_mock, env_helper_mock, client
):
+ """Test that the Azure BYOD conversation endpoint returns the correct response."""
# given
env_helper_mock.should_use_data.return_value = False
@@ -792,9 +821,10 @@ def test_converstation_azure_byod_returns_correct_response_when_streaming_withou
)
@patch("create_app.AzureOpenAI")
- def test_converstation_azure_byod_uses_semantic_config(
+ def test_conversation_azure_byod_uses_semantic_config(
self, azure_openai_mock: MagicMock, client: FlaskClient
):
+ """Test that the Azure BYOD conversation endpoint uses the semantic configuration."""
# given
openai_client_mock = azure_openai_mock.return_value
openai_client_mock.chat.completions.create.return_value = (
diff --git a/code/tests/test_batch_push_results.py b/code/tests/test_batch_push_results.py
index b7c39c267..5350d901d 100644
--- a/code/tests/test_batch_push_results.py
+++ b/code/tests/test_batch_push_results.py
@@ -1,3 +1,4 @@
+import json
import sys
import os
import pytest
@@ -15,17 +16,22 @@
@pytest.fixture(autouse=True)
def get_processor_handler_mock():
- with patch("backend.batch.batch_push_results.EmbedderFactory.create") as mock:
- processor_handler = mock.return_value
- yield processor_handler
+ with patch(
+ "backend.batch.batch_push_results.EmbedderFactory.create"
+ ) as mock_create_embedder, patch(
+ "backend.batch.batch_push_results.Search.get_search_handler"
+ ) as mock_get_search_handler:
+ processor_handler_create = mock_create_embedder.return_value
+ processor_handler_get_search_handler = mock_get_search_handler.return_value
+ yield processor_handler_create, processor_handler_get_search_handler
def test_get_file_name_from_message():
mock_queue_message = QueueMessage(
body='{"message": "test message", "filename": "test_filename.md"}'
)
-
- file_name = _get_file_name_from_message(mock_queue_message)
+ message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+ file_name = _get_file_name_from_message(message_body)
assert file_name == "test_filename.md"
@@ -34,25 +40,95 @@ def test_get_file_name_from_message_no_filename():
mock_queue_message = QueueMessage(
body='{"data": { "url": "test/test/test_filename.md"} }'
)
-
- file_name = _get_file_name_from_message(mock_queue_message)
+ message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+ file_name = _get_file_name_from_message(message_body)
assert file_name == "test_filename.md"
+def test_batch_push_results_with_unhandled_event_type():
+ mock_queue_message = QueueMessage(
+ body='{"eventType": "Microsoft.Storage.BlobUpdated"}'
+ )
+
+ with pytest.raises(NotImplementedError):
+ batch_push_results.build().get_user_function()(mock_queue_message)
+
+
+@patch("backend.batch.batch_push_results._process_document_created_event")
+def test_batch_push_results_with_blob_created_event(
+ mock_process_document_created_event,
+):
+ mock_queue_message = QueueMessage(
+ body='{"eventType": "Microsoft.Storage.BlobCreated", "filename": "test/test/test_filename.md"}'
+ )
+
+ batch_push_results.build().get_user_function()(mock_queue_message)
+
+ expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+ mock_process_document_created_event.assert_called_once_with(expected_message_body)
+
+
+@patch("backend.batch.batch_push_results._process_document_created_event")
+def test_batch_push_results_with_no_event(mock_process_document_created_event):
+ mock_queue_message = QueueMessage(
+ body='{"data": { "url": "test/test/test_filename.md"} }'
+ )
+
+ batch_push_results.build().get_user_function()(mock_queue_message)
+
+ expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+ mock_process_document_created_event.assert_called_once_with(expected_message_body)
+
+
+@patch("backend.batch.batch_push_results._process_document_deleted_event")
+def test_batch_push_results_with_blob_deleted_event(
+ mock_process_document_deleted_event,
+):
+ mock_queue_message = QueueMessage(
+ body='{"eventType": "Microsoft.Storage.BlobDeleted", "filename": "test/test/test_filename.md"}'
+ )
+
+ batch_push_results.build().get_user_function()(mock_queue_message)
+
+ expected_message_body = json.loads(mock_queue_message.get_body().decode("utf-8"))
+ mock_process_document_deleted_event.assert_called_once_with(expected_message_body)
+
+
@patch("backend.batch.batch_push_results.EnvHelper")
@patch("backend.batch.batch_push_results.AzureBlobStorageClient")
-def test_batch_push_results(
- mock_azure_blob_storage_client, mock_env_helper, get_processor_handler_mock
+def test_batch_push_results_with_blob_created_event_uses_embedder(
+ mock_azure_blob_storage_client,
+ mock_env_helper,
+ get_processor_handler_mock,
):
+ mock_create_embedder, mock_get_search_handler = get_processor_handler_mock
+
mock_queue_message = QueueMessage(
- body='{"message": "test message", "filename": "test/test/test_filename.md"}'
+ body='{"eventType": "Microsoft.Storage.BlobCreated", "filename": "test/test/test_filename.md"}'
)
mock_blob_client_instance = mock_azure_blob_storage_client.return_value
mock_blob_client_instance.get_blob_sas.return_value = "test_blob_sas"
batch_push_results.build().get_user_function()(mock_queue_message)
- get_processor_handler_mock.embed_file.assert_called_once_with(
+ mock_create_embedder.embed_file.assert_called_once_with(
"test_blob_sas", "test/test/test_filename.md"
)
+
+
+@patch("backend.batch.batch_push_results.EnvHelper")
+def test_batch_push_results_with_blob_deleted_event_uses_search_to_delete_with_sas_appended(
+ mock_env_helper,
+ get_processor_handler_mock,
+):
+ mock_create_embedder, mock_get_search_handler = get_processor_handler_mock
+
+ mock_queue_message = QueueMessage(
+ body='{"eventType": "Microsoft.Storage.BlobDeleted", "data": { "url": "https://test.test/test/test_filename.pdf"}}'
+ )
+
+ batch_push_results.build().get_user_function()(mock_queue_message)
+ mock_get_search_handler.delete_by_source.assert_called_once_with(
+ "https://test.test/test/test_filename.pdf_SAS_TOKEN_PLACEHOLDER_"
+ )
diff --git a/code/tests/utilities/helpers/test_push_embedder.py b/code/tests/utilities/helpers/test_push_embedder.py
index 48f5a7b0a..fa5434067 100644
--- a/code/tests/utilities/helpers/test_push_embedder.py
+++ b/code/tests/utilities/helpers/test_push_embedder.py
@@ -22,8 +22,13 @@ def llm_helper_mock():
llm_helper.get_embedding_model.return_value.embed_query.return_value = [
0
] * 1536
+ mock_completion = llm_helper.get_chat_completion.return_value
+ choice = MagicMock()
+ choice.message.content = "This is a caption for an image"
+ mock_completion.choices = [choice]
+
llm_helper.generate_embeddings.return_value = [123]
- yield mock
+ yield llm_helper
@pytest.fixture(autouse=True)
@@ -129,7 +134,46 @@ def test_embed_file_advanced_image_processing_vectorizes_image(
)
+def test_embed_file_advanced_image_processing_uses_vision_model_for_captioning(
+ llm_helper_mock,
+):
+ # given
+ env_helper_mock = MagicMock()
+ env_helper_mock.AZURE_OPENAI_VISION_MODEL = "gpt-4"
+ push_embedder = PushEmbedder(MagicMock(), env_helper_mock)
+ source_url = "http://localhost:8080/some-file-name.jpg"
+
+ # when
+ push_embedder.embed_file(source_url, "some-file-name.jpg")
+
+ # then
+ llm_helper_mock.get_chat_completion.assert_called_once_with(
+ [
+ {
+ "role": "system",
+ "content": """You are an assistant that generates rich descriptions of images.
+You need to be accurate in the information you extract and detailed in the descriptons you generate.
+Do not abbreviate anything and do not shorten sentances. Explain the image completely.
+If you are provided with an image of a flow chart, describe the flow chart in detail.
+If the image is mostly text, use OCR to extract the text as it is displayed in the image.""",
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "text": "Describe this image in detail. Limit the response to 500 words.",
+ "type": "text",
+ },
+ {"image_url": source_url, "type": "image_url"},
+ ],
+ },
+ ],
+ env_helper_mock.AZURE_OPENAI_VISION_MODEL,
+ )
+
+
def test_embed_file_advanced_image_processing_stores_embeddings_in_search_index(
+ llm_helper_mock,
azure_computer_vision_mock,
azure_search_helper_mock: MagicMock,
):
@@ -153,12 +197,16 @@ def test_embed_file_advanced_image_processing_stores_embeddings_in_search_index(
hash_key = hashlib.sha1(f"{host_path}_1".encode("utf-8")).hexdigest()
expected_id = f"doc_{hash_key}"
+ llm_helper_mock.generate_embeddings.assert_called_once_with(
+ "This is a caption for an image"
+ )
+
azure_search_helper_mock.return_value.get_search_client.return_value.upload_documents.assert_called_once_with(
[
{
"id": expected_id,
- "content": "",
- "content_vector": [],
+ "content": "This is a caption for an image",
+ "content_vector": [123],
"image_vector": image_embeddings,
"metadata": json.dumps(
{
@@ -265,7 +313,7 @@ def test_embed_file_generates_embeddings_for_documents(llm_helper_mock):
)
# then
- llm_helper_mock.return_value.generate_embeddings.assert_has_calls(
+ llm_helper_mock.generate_embeddings.assert_has_calls(
[call("some content"), call("some other content")]
)
@@ -291,7 +339,7 @@ def test_embed_file_stores_documents_in_search_index(
{
"id": expected_chunked_documents[0].id,
"content": expected_chunked_documents[0].content,
- "content_vector": llm_helper_mock.return_value.generate_embeddings.return_value,
+ "content_vector": llm_helper_mock.generate_embeddings.return_value,
"metadata": json.dumps(
{
"id": expected_chunked_documents[0].id,
@@ -311,7 +359,7 @@ def test_embed_file_stores_documents_in_search_index(
{
"id": expected_chunked_documents[1].id,
"content": expected_chunked_documents[1].content,
- "content_vector": llm_helper_mock.return_value.generate_embeddings.return_value,
+ "content_vector": llm_helper_mock.generate_embeddings.return_value,
"metadata": json.dumps(
{
"id": expected_chunked_documents[1].id,
@@ -338,10 +386,8 @@ def test_embed_file_raises_exception_on_failure(
# given
push_embedder = PushEmbedder(MagicMock(), MagicMock())
- successful_indexing_result = MagicMock()
- successful_indexing_result.succeeded = True
- failed_indexing_result = MagicMock()
- failed_indexing_result.succeeded = False
+ successful_indexing_result = MagicMock(succeeded=True)
+ failed_indexing_result = MagicMock(succeeded=False)
azure_search_helper_mock.return_value.get_search_client.return_value.upload_documents.return_value = [
successful_indexing_result,
failed_indexing_result,
diff --git a/infra/main.bicep b/infra/main.bicep
index a0c3b6597..461fb0816 100644
--- a/infra/main.bicep
+++ b/infra/main.bicep
@@ -108,7 +108,7 @@ param azureOpenAIModelCapacity int = 30
param useAdvancedImageProcessing bool = false
@description('Azure OpenAI Vision Model Deployment Name')
-param azureOpenAIVisionModel string = 'gpt-4-vision'
+param azureOpenAIVisionModel string = 'gpt-4'
@description('Azure OpenAI Vision Model Name')
param azureOpenAIVisionModelName string = 'gpt-4'
diff --git a/infra/main.bicepparam b/infra/main.bicepparam
index 2aaec96f4..e19c2656e 100644
--- a/infra/main.bicepparam
+++ b/infra/main.bicepparam
@@ -25,7 +25,7 @@ param azureOpenAIModelName = readEnvironmentVariable('AZURE_OPENAI_MODEL_NAME',
param azureOpenAIModelVersion = readEnvironmentVariable('AZURE_OPENAI_MODEL_VERSION', '0613')
param azureOpenAIModelCapacity = int(readEnvironmentVariable('AZURE_OPENAI_MODEL_CAPACITY', '30'))
param useAdvancedImageProcessing = bool(readEnvironmentVariable('USE_ADVANCED_IMAGE_PROCESSING', 'false'))
-param azureOpenAIVisionModel = readEnvironmentVariable('AZURE_OPENAI_VISION_MODEL', 'gpt-4-vision')
+param azureOpenAIVisionModel = readEnvironmentVariable('AZURE_OPENAI_VISION_MODEL', 'gpt-4')
param azureOpenAIVisionModelName = readEnvironmentVariable('AZURE_OPENAI_VISION_MODEL_NAME', 'gpt-4')
param azureOpenAIVisionModelVersion = readEnvironmentVariable('AZURE_OPENAI_VISION_MODEL_VERSION', 'vision-preview')
param azureOpenAIVisionModelCapacity = int(readEnvironmentVariable('AZURE_OPENAI_VISION_MODEL_CAPACITY', '10'))
diff --git a/infra/main.json b/infra/main.json
index fec8cca5a..9404b20cc 100644
--- a/infra/main.json
+++ b/infra/main.json
@@ -5,7 +5,7 @@
"_generator": {
"name": "bicep",
"version": "0.27.1.19265",
- "templateHash": "13373198886203455254"
+ "templateHash": "9021391279672164541"
}
},
"parameters": {
@@ -224,7 +224,7 @@
},
"azureOpenAIVisionModel": {
"type": "string",
- "defaultValue": "gpt-4-vision",
+ "defaultValue": "gpt-4",
"metadata": {
"description": "Azure OpenAI Vision Model Deployment Name"
}