Skip to content

Commit

Permalink
feat: Generate embeddings for images (#892)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Dougal <adamdougal@microsoft.com>
Co-authored-by: Adam Dougal <adamdougal@users.noreply.github.com>
Co-authored-by: Arpit Gaur <arpitgaur@microsoft.com>
  • Loading branch information
4 people authored May 15, 2024
1 parent 3b4c2e3 commit a96bde6
Show file tree
Hide file tree
Showing 28 changed files with 730 additions and 80 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import logging
from typing import List
from urllib.parse import urljoin
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

import requests
from requests import Response

from .env_helper import EnvHelper

logger = logging.getLogger(__name__)


class AzureComputerVisionClient:
    """Client for the Azure Computer Vision ``vectorizeImage`` REST operation.

    Requests are authenticated with a subscription key when
    ``env_helper.is_auth_type_keys()`` is true, otherwise with a bearer token
    obtained through ``DefaultAzureCredential``.
    """

    __TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default"
    __VECTORIZE_IMAGE_PATH = "computervision/retrieval:vectorizeImage"
    __RESPONSE_VECTOR_KEY = "vector"

    def __init__(self, env_helper: EnvHelper) -> None:
        """Copy the Computer Vision settings from ``env_helper``.

        No credential is created and no network call is made here.
        """
        self.host = env_helper.AZURE_COMPUTER_VISION_ENDPOINT
        self.timeout = env_helper.AZURE_COMPUTER_VISION_TIMEOUT
        self.key = env_helper.AZURE_COMPUTER_VISION_KEY
        self.use_keys = env_helper.is_auth_type_keys()
        self.api_version = env_helper.AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_API_VERSION
        self.model_version = (
            env_helper.AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_MODEL_VERSION
        )
        # Created on first use and then reused, so the credential and token
        # provider are not rebuilt for every request.
        self.__token_provider = None

    def vectorize_image(self, image_url: str) -> List[float]:
        """Return the embedding vector the service computes for ``image_url``.

        Raises:
            Exception: if the request fails, the status code is not 200, the
                body is not valid JSON, or the body contains no vector.
        """
        # NOTE(review): image_url is logged verbatim; if callers pass URLs that
        # carry access tokens, consider redacting the query string.
        logger.info(
            "Making call to computer vision to vectorize image: %s", image_url
        )
        response = self.__make_request(image_url)
        self.__validate_response(response)

        response_json = self.__get_json_body(response)
        return self.__get_vectors(response_json)

    def __build_auth_headers(self) -> dict:
        """Return the authentication header for the configured auth type."""
        if self.use_keys:
            return {"Ocp-Apim-Subscription-Key": self.key}

        if self.__token_provider is None:
            self.__token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), self.__TOKEN_SCOPE
            )
        return {"Authorization": "Bearer " + self.__token_provider()}

    def __make_request(self, image_url: str) -> Response:
        """POST ``image_url`` to the vectorizeImage endpoint.

        Any failure while authenticating or sending is re-raised as a generic
        ``Exception`` chained to the original error.
        """
        try:
            headers = self.__build_auth_headers()
            return requests.post(
                url=urljoin(self.host, self.__VECTORIZE_IMAGE_PATH),
                params={
                    "api-version": self.api_version,
                    "model-version": self.model_version,
                },
                json={"url": image_url},
                headers=headers,
                timeout=self.timeout,
            )
        except Exception as e:
            raise Exception(f"Call to vectorize image failed: {image_url}") from e

    def __validate_response(self, response: Response):
        """Raise if the service answered with anything other than HTTP 200."""
        if response.status_code != 200:
            raise Exception(
                f"Call to vectorize image failed with status: {response.status_code} body: {response.text}"
            )

    def __get_json_body(self, response: Response) -> dict:
        """Decode the response body as JSON, raising if it cannot be parsed."""
        try:
            return response.json()
        except Exception as e:
            raise Exception(
                f"Call to vectorize image returned malformed response body: {response.text}",
            ) from e

    def __get_vectors(self, response_json: dict) -> List[float]:
        """Extract the vector from the decoded body.

        A body that is valid JSON but not an object (for example ``null`` or a
        list) is reported as "no vector" rather than failing with ``TypeError``.
        """
        if (
            isinstance(response_json, dict)
            and self.__RESPONSE_VECTOR_KEY in response_json
        ):
            return response_json[self.__RESPONSE_VECTOR_KEY]

        raise Exception(
            f"Call to vectorize image returned no vector: {response_json}"
        )
28 changes: 24 additions & 4 deletions code/backend/batch/utilities/helpers/config/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# Blob container and blob name under which the active configuration is saved.
CONFIG_CONTAINER_NAME = "config"
CONFIG_FILE_NAME = "active.json"
# File extensions for which advanced image processing may be enabled;
# validate_config rejects the option for any other document type.
ADVANCED_IMAGE_PROCESSING_FILE_TYPES = ["jpeg", "jpg", "png", "tiff", "bmp"]
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -54,8 +55,8 @@ def __init__(self, config: dict):
else None
)

def get_available_document_types(self):
document_types = [
def get_available_document_types(self) -> list[str]:
document_types = {
"txt",
"pdf",
"url",
Expand All @@ -65,12 +66,15 @@ def get_available_document_types(self):
"jpg",
"png",
"docx",
]
}
if self.env_helper.USE_ADVANCED_IMAGE_PROCESSING:
document_types.extend(["tiff", "bmp"])
document_types.update(ADVANCED_IMAGE_PROCESSING_FILE_TYPES)

return sorted(document_types)

def get_advanced_image_processing_image_types(self) -> list[str]:
    """Return the file extensions that support advanced image processing.

    A fresh list is returned on every call so that a caller mutating the
    result cannot alter the shared module-level constant.
    """
    return list(ADVANCED_IMAGE_PROCESSING_FILE_TYPES)

def get_available_chunking_strategies(self):
    """Return the ``value`` of each ``ChunkingStrategy`` member, in iteration order."""
    return [strategy.value for strategy in ChunkingStrategy]

Expand Down Expand Up @@ -180,13 +184,29 @@ def get_active_config_or_default():

@staticmethod
def save_config_as_active(config):
    """Validate ``config`` and upload it as the active configuration blob.

    The config is serialized as indented JSON and written to
    ``CONFIG_FILE_NAME`` in the ``CONFIG_CONTAINER_NAME`` container.

    Raises:
        Exception: propagated from ``ConfigHelper.validate_config`` when the
            config is invalid; nothing is uploaded in that case.
    """
    ConfigHelper.validate_config(config)
    blob_client = AzureBlobStorageClient(container_name=CONFIG_CONTAINER_NAME)
    # The upload's return value is not needed, so the client variable is no
    # longer rebound to it.
    blob_client.upload_file(
        json.dumps(config, indent=2),
        CONFIG_FILE_NAME,
        content_type="application/json",
    )

@staticmethod
def validate_config(config: dict):
for document_processor in config.get("document_processors"):
document_type = document_processor.get("document_type")
unsupported_advanced_image_processing_file_type = (
document_type not in ADVANCED_IMAGE_PROCESSING_FILE_TYPES
)
if (
document_processor.get("use_advanced_image_processing")
and unsupported_advanced_image_processing_file_type
):
raise Exception(
f"Advanced image processing has been enabled for document type {document_type}, but only {ADVANCED_IMAGE_PROCESSING_FILE_TYPES} file types are supported."
)

@staticmethod
def get_default_config():
if ConfigHelper._default_config is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ def create(env_helper: EnvHelper):
if env_helper.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION:
return IntegratedVectorizationEmbedder(env_helper)
else:
return PushEmbedder(AzureBlobStorageClient())
return PushEmbedder(AzureBlobStorageClient(), env_helper)
35 changes: 26 additions & 9 deletions code/backend/batch/utilities/helpers/embedders/push_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import List

from ...helpers.llm_helper import LLMHelper
from ...helpers.env_helper import EnvHelper
from ..azure_computer_vision_client import AzureComputerVisionClient

from ..azure_blob_storage_client import AzureBlobStorageClient

Expand All @@ -19,30 +21,48 @@


class PushEmbedder(EmbedderBase):
def __init__(self, blob_client: AzureBlobStorageClient):
def __init__(self, blob_client: AzureBlobStorageClient, env_helper: EnvHelper):
self.llm_helper = LLMHelper()
self.azure_search_helper = AzureSearchHelper()
self.azure_computer_vision_client = AzureComputerVisionClient(env_helper)
self.document_loading = DocumentLoading()
self.document_chunking = DocumentChunking()
self.blob_client = blob_client
config = ConfigHelper.get_active_config_or_default()
self.config = ConfigHelper.get_active_config_or_default()
self.embedding_configs = {}
for processor in config.document_processors:
for processor in self.config.document_processors:
ext = processor.document_type.lower()
self.embedding_configs[ext] = processor

def embed_file(self, source_url: str, file_name: str):
file_extension = file_name.split(".")[-1]
embedding_config = self.embedding_configs.get(file_extension)
self.__embed(source_url=source_url, embedding_config=embedding_config)
self.__embed(
source_url=source_url,
file_extension=file_extension,
embedding_config=embedding_config,
)
if file_extension != "url":
self.blob_client.upsert_blob_metadata(
file_name, {"embeddings_added": "true"}
)

def __embed(self, source_url: str, embedding_config: EmbeddingConfig):
def __embed(
self, source_url: str, file_extension: str, embedding_config: EmbeddingConfig
):
documents_to_upload: List[SourceDocument] = []
if not embedding_config.use_advanced_image_processing:
if (
embedding_config.use_advanced_image_processing
and file_extension
in self.config.get_advanced_image_processing_image_types()
):
logger.warning("Advanced image processing is not supported yet")
image_vectors = self.azure_computer_vision_client.vectorize_image(
source_url
)
logger.info("Image vectors: " + str(image_vectors))
# Coming soon, storing the image embeddings in Azure Search
else:
documents: List[SourceDocument] = self.document_loading.load(
source_url, embedding_config.loading
)
Expand All @@ -59,9 +79,6 @@ def __embed(self, source_url: str, embedding_config: EmbeddingConfig):
if not all([r.succeeded for r in response]):
raise Exception(response)

else:
logger.warning("Advanced image processing is not supported yet")

def _convert_to_search_document(self, document: SourceDocument):
embedded_content = self.llm_helper.generate_embeddings(document.content)
metadata = {
Expand Down
19 changes: 19 additions & 0 deletions code/backend/batch/utilities/helpers/env_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,26 @@ def __load_config(self, **kwargs) -> None:
self.USE_ADVANCED_IMAGE_PROCESSING = self.get_env_var_bool(
"USE_ADVANCED_IMAGE_PROCESSING", "False"
)
self.AZURE_COMPUTER_VISION_ENDPOINT = os.getenv(
"AZURE_COMPUTER_VISION_ENDPOINT"
)
self.AZURE_COMPUTER_VISION_TIMEOUT = self.get_env_var_float(
"AZURE_COMPUTER_VISION_TIMEOUT", 30
)
self.AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_API_VERSION = os.getenv(
"AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_API_VERSION", "2024-02-01"
)
self.AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_MODEL_VERSION = os.getenv(
"AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_MODEL_VERSION", "2023-04-15"
)

# Initialize Azure keys based on authentication type and environment settings.
# When AZURE_AUTH_TYPE is "rbac", azure keys are None or an empty string.
if self.AZURE_AUTH_TYPE == "rbac":
self.AZURE_SEARCH_KEY = None
self.AZURE_OPENAI_API_KEY = ""
self.AZURE_SPEECH_KEY = None
self.AZURE_COMPUTER_VISION_KEY = None
else:
self.AZURE_SEARCH_KEY = self.secretHelper.get_secret("AZURE_SEARCH_KEY")
self.AZURE_OPENAI_API_KEY = self.secretHelper.get_secret(
Expand All @@ -126,6 +139,9 @@ def __load_config(self, **kwargs) -> None:
self.AZURE_SPEECH_KEY = self.secretHelper.get_secret(
"AZURE_SPEECH_SERVICE_KEY"
)
self.AZURE_COMPUTER_VISION_KEY = self.secretHelper.get_secret(
"AZURE_COMPUTER_VISION_KEY"
)

# Set env for Azure OpenAI
self.AZURE_OPENAI_ENDPOINT = os.environ.get(
Expand Down Expand Up @@ -221,6 +237,9 @@ def get_env_var_bool(self, var_name: str, default: str = "True") -> bool:
def get_env_var_array(self, var_name: str, default: str = ""):
    """Split the environment variable ``var_name`` on commas.

    ``default`` is split instead when the variable is unset.
    """
    raw_value = os.getenv(var_name, default)
    return raw_value.split(",")

def get_env_var_float(self, var_name: str, default: float) -> float:
    """Read the environment variable ``var_name`` as a float.

    Args:
        var_name: name of the environment variable.
        default: value used when the variable is unset or blank.

    Raises:
        ValueError: if the variable is set to a non-numeric value; the
            message names the variable and the offending value.
    """
    raw_value = os.getenv(var_name)
    # Treat a blank setting the same as an unset one instead of failing on
    # float("").
    if raw_value is None or not raw_value.strip():
        return float(default)

    try:
        return float(raw_value)
    except ValueError as e:
        raise ValueError(
            f"Environment variable {var_name} must be a number, got: {raw_value!r}"
        ) from e

def is_auth_type_keys(self):
    """Report whether ``AZURE_AUTH_TYPE`` equals ``"keys"``."""
    auth_type = self.AZURE_AUTH_TYPE
    return auth_type == "keys"

Expand Down
35 changes: 35 additions & 0 deletions code/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import ssl

import pytest
import trustme


@pytest.fixture(scope="session")
def ca():
    """
    Session-scoped certificate authority, needed to run the http mock server with SSL.
    https://pytest-httpserver.readthedocs.io/en/latest/howto.html#running-an-https-server
    """
    certificate_authority = trustme.CA()
    return certificate_authority


@pytest.fixture(scope="session")
def httpserver_ssl_context(ca):
    """
    Server-side SSL context, needed to run the http mock server with SSL.
    It is configured with a "localhost" certificate issued by the ``ca`` fixture.
    https://pytest-httpserver.readthedocs.io/en/latest/howto.html#running-an-https-server
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ca.issue_cert("localhost").configure_cert(server_context)
    return server_context


@pytest.fixture(scope="session")
def httpclient_ssl_context(ca):
    """
    Client-side SSL context, needed to talk to the http mock server over SSL.
    It trusts the certificate of the ``ca`` fixture.
    https://pytest-httpserver.readthedocs.io/en/latest/howto.html#running-an-https-server
    """
    # The CA certificate is written to a temporary file only for as long as
    # it takes to build the context.
    with ca.cert_pem.tempfile() as ca_bundle_path:
        client_context = ssl.create_default_context(cafile=ca_bundle_path)
    return client_context
5 changes: 5 additions & 0 deletions code/tests/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Blob container and file name of the active configuration.
AZURE_STORAGE_CONFIG_CONTAINER_NAME = "config"
AZURE_STORAGE_CONFIG_FILE_NAME = "active.json"

# Request path and HTTP method of the Computer Vision vectorizeImage call
# (presumably used to set up mock-server expectations in tests; confirm against callers).
COMPUTER_VISION_VECTORIZE_IMAGE_PATH = "/computervision/retrieval:vectorizeImage"
COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD = "POST"
1 change: 1 addition & 0 deletions code/tests/functional/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class AppConfig:
),
"AZURE_BLOB_ACCOUNT_NAME": "some-blob-account-name",
"AZURE_BLOB_CONTAINER_NAME": "some-blob-container-name",
"AZURE_COMPUTER_VISION_KEY": "some-computer-vision-key",
"AZURE_CONTENT_SAFETY_ENDPOINT": "some-content-safety-endpoint",
"AZURE_CONTENT_SAFETY_KEY": "some-content-safety-key",
"AZURE_FORM_RECOGNIZER_ENDPOINT": "some-form-recognizer-endpoint",
Expand Down
Loading

0 comments on commit a96bde6

Please sign in to comment.