-
Notifications
You must be signed in to change notification settings - Fork 469
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Generate embeddings for images (#892)
Co-authored-by: Adam Dougal <adamdougal@microsoft.com> Co-authored-by: Adam Dougal <adamdougal@users.noreply.github.com> Co-authored-by: Arpit Gaur <arpitgaur@microsoft.com>
- Loading branch information
1 parent
3b4c2e3
commit a96bde6
Showing
28 changed files
with
730 additions
and
80 deletions.
There are no files selected for viewing
82 changes: 82 additions & 0 deletions
82
code/backend/batch/utilities/helpers/azure_computer_vision_client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import logging | ||
from typing import List | ||
from urllib.parse import urljoin | ||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider | ||
|
||
import requests | ||
from requests import Response | ||
|
||
from .env_helper import EnvHelper | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AzureComputerVisionClient:
    """Client for the Azure Computer Vision ``vectorizeImage`` REST endpoint.

    Requests are authenticated with the subscription key when
    ``env_helper.is_auth_type_keys()`` is true, otherwise with a bearer token
    obtained through ``DefaultAzureCredential``.
    """

    __TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default"
    __VECTORIZE_IMAGE_PATH = "computervision/retrieval:vectorizeImage"
    __RESPONSE_VECTOR_KEY = "vector"

    def __init__(self, env_helper: "EnvHelper") -> None:
        """Copy the Computer Vision settings from ``env_helper``."""
        self.host = env_helper.AZURE_COMPUTER_VISION_ENDPOINT
        self.timeout = env_helper.AZURE_COMPUTER_VISION_TIMEOUT
        self.key = env_helper.AZURE_COMPUTER_VISION_KEY
        self.use_keys = env_helper.is_auth_type_keys()
        self.api_version = env_helper.AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_API_VERSION
        self.model_version = (
            env_helper.AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_MODEL_VERSION
        )
        # Created lazily on the first token-authenticated request and then
        # reused, so the credential can serve later calls from its own cache.
        self.__token_provider = None

    def vectorize_image(self, image_url: str) -> List[float]:
        """Return the embedding vector for the image at ``image_url``.

        Raises:
            Exception: if the request fails, the service answers with a
                non-200 status, the body is not valid JSON, or the body
                contains no ``vector`` entry.
        """
        logger.info(
            "Making call to computer vision to vectorize image: %s", image_url
        )
        response = self.__make_request(image_url)
        self.__validate_response(response)

        response_json = self.__get_json_body(response)
        return self.__get_vectors(response_json)

    def __build_auth_headers(self) -> dict:
        """Build the authentication header for the configured auth type."""
        if self.use_keys:
            return {"Ocp-Apim-Subscription-Key": self.key}

        if self.__token_provider is None:
            self.__token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), self.__TOKEN_SCOPE
            )
        return {"Authorization": "Bearer " + self.__token_provider()}

    def __make_request(self, image_url: str) -> "Response":
        """POST the image URL to the vectorize endpoint.

        Any failure while authenticating or sending the request is re-raised
        as a generic ``Exception`` chained to the original error.
        """
        try:
            headers = self.__build_auth_headers()
            return requests.post(
                url=urljoin(self.host, self.__VECTORIZE_IMAGE_PATH),
                params={
                    "api-version": self.api_version,
                    "model-version": self.model_version,
                },
                json={"url": image_url},
                headers=headers,
                timeout=self.timeout,
            )
        except Exception as e:
            raise Exception(f"Call to vectorize image failed: {image_url}") from e

    def __validate_response(self, response: "Response"):
        """Raise if the service did not answer with HTTP 200."""
        if response.status_code != 200:
            raise Exception(
                f"Call to vectorize image failed with status: {response.status_code} body: {response.text}"
            )

    def __get_json_body(self, response: "Response") -> dict:
        """Decode the response body as JSON, raising if it is malformed."""
        try:
            return response.json()
        except Exception as e:
            raise Exception(
                f"Call to vectorize image returned malformed response body: {response.text}",
            ) from e

    def __get_vectors(self, response_json: dict) -> List[float]:
        """Extract the ``vector`` entry from the decoded response body.

        A body that is valid JSON but not an object (for example ``null`` or
        a bare string) is reported the same way as a missing key, instead of
        leaking a ``TypeError`` from the membership test or the lookup.
        """
        if (
            isinstance(response_json, dict)
            and self.__RESPONSE_VECTOR_KEY in response_json
        ):
            return response_json[self.__RESPONSE_VECTOR_KEY]

        raise Exception(
            f"Call to vectorize image returned no vector: {response_json}"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import ssl | ||
|
||
import pytest | ||
import trustme | ||
|
||
|
||
@pytest.fixture(scope="session")
def ca():
    """
    Session-wide ``trustme`` certificate authority.

    The SSL context fixtures take this CA as input so the http mock server
    can run with SSL.
    https://pytest-httpserver.readthedocs.io/en/latest/howto.html#running-an-https-server
    """
    authority = trustme.CA()
    return authority
|
||
|
||
@pytest.fixture(scope="session")
def httpserver_ssl_context(ca):
    """
    Server-side SSL context for the http mock server.

    A certificate for ``localhost`` is issued by the ``ca`` fixture and
    configured on a TLS server context.
    https://pytest-httpserver.readthedocs.io/en/latest/howto.html#running-an-https-server
    """
    server_cert = ca.issue_cert("localhost")
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_cert.configure_cert(server_context)
    return server_context
|
||
|
||
@pytest.fixture(scope="session")
def httpclient_ssl_context(ca):
    """
    Client-side SSL context that loads the ``ca`` fixture's certificate.

    The CA certificate is written to a temporary file, which only needs to
    exist while the default context is being created from it.
    https://pytest-httpserver.readthedocs.io/en/latest/howto.html#running-an-https-server
    """
    with ca.cert_pem.tempfile() as ca_pem_path:
        client_context = ssl.create_default_context(cafile=ca_pem_path)
        return client_context
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Azure Storage location of the configuration: container name and file name.
AZURE_STORAGE_CONFIG_CONTAINER_NAME = "config"
AZURE_STORAGE_CONFIG_FILE_NAME = "active.json"

# Path and HTTP method of the Computer Vision vectorizeImage request.
COMPUTER_VISION_VECTORIZE_IMAGE_PATH = "/computervision/retrieval:vectorizeImage"
COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD = "POST"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.