Skip to content

Commit

Permalink
Upgrade keyvault secrets library (Azure-Samples#479)
Browse files Browse the repository at this point in the history
* Upgrade keyvault secrets library

* Add comment in test

* typo!

* Mock SecretClient (Azure-Samples#483)

Co-authored-by: Chinedum Echeta <60179183+cecheta@users.noreply.github.com>

---------

Co-authored-by: Chinedum Echeta <60179183+cecheta@users.noreply.github.com>
  • Loading branch information
ross-p-smith and cecheta authored Mar 20, 2024
1 parent 4b8c2e3 commit 2ddba25
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 65 deletions.
9 changes: 4 additions & 5 deletions code/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def prepare_body_headers_with_data(request):
],
}

chatgpt_url = f"https://{env_helper.AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{env_helper.AZURE_OPENAI_MODEL}"
chatgpt_url = f"{env_helper.AZURE_OPENAI_ENDPOINT}openai/deployments/{env_helper.AZURE_OPENAI_MODEL}"
if env_helper.is_chat_model():
chatgpt_url += "/chat/completions?api-version=2023-12-01-preview"
else:
Expand Down Expand Up @@ -166,7 +166,7 @@ def stream_with_data(body, headers, endpoint):

def conversation_with_data(request):
body, headers = prepare_body_headers_with_data(request)
endpoint = f"https://{env_helper.AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{env_helper.AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={env_helper.AZURE_OPENAI_API_VERSION}"
endpoint = f"{env_helper.AZURE_OPENAI_ENDPOINT}openai/deployments/{env_helper.AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={env_helper.AZURE_OPENAI_API_VERSION}"

if not env_helper.SHOULD_STREAM:
r = requests.post(endpoint, headers=headers, json=body)
Expand Down Expand Up @@ -202,16 +202,15 @@ def stream_without_data(response):


def conversation_without_data(request):
azure_endpoint = f"https://{env_helper.AZURE_OPENAI_RESOURCE}.openai.azure.com/"
if env_helper.AZURE_AUTH_TYPE == "rbac":
openai_client = AzureOpenAI(
azure_endpoint=azure_endpoint,
azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT,
api_version=env_helper.AZURE_OPENAI_API_VERSION,
azure_ad_token_provider=env_helper.AZURE_TOKEN_PROVIDER,
)
else:
openai_client = AzureOpenAI(
azure_endpoint=azure_endpoint,
azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT,
api_version=env_helper.AZURE_OPENAI_API_VERSION,
api_key=env_helper.AZURE_OPENAI_API_KEY,
)
Expand Down
95 changes: 59 additions & 36 deletions code/backend/batch/utilities/helpers/EnvHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class EnvHelper:
def __init__(self, **kwargs) -> None:
load_dotenv()

# Wrapper for Azure Key Vault
self.secretHelper = SecretHelper()

# Azure Search
self.AZURE_SEARCH_SERVICE = os.getenv("AZURE_SEARCH_SERVICE", "")
self.AZURE_SEARCH_INDEX = os.getenv("AZURE_SEARCH_INDEX", "")
Expand Down Expand Up @@ -76,33 +79,20 @@ def __init__(self, **kwargs) -> None:
self.AZURE_TOKEN_PROVIDER = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
self.USE_KEY_VAULT = os.getenv("USE_KEY_VAULT", "").lower() == "true"
# Initialize Azure keys based on authentication type and environment settings.
# When AZURE_AUTH_TYPE is "rbac", azure keys are None or an empty string.
# When USE_KEY_VAULT environment variable is set, keys are securely fetched from Azure Key Vault using DefaultAzureCredential.
# Otherwise, keys are obtained from environment variables.
if self.AZURE_AUTH_TYPE == "rbac":
self.AZURE_SEARCH_KEY = None
self.AZURE_OPENAI_API_KEY = ""
self.AZURE_SPEECH_KEY = None
elif self.USE_KEY_VAULT:
credential = DefaultAzureCredential()
self.secret_client = SecretClient(
os.environ.get("AZURE_KEY_VAULT_ENDPOINT"), credential
)
self.AZURE_SEARCH_KEY = self.secret_client.get_secret(
os.environ.get("AZURE_SEARCH_KEY")
).value
self.AZURE_OPENAI_API_KEY = self.secret_client.get_secret(
os.environ.get("AZURE_OPENAI_API_KEY", "")
).value # langchain expects this.
self.AZURE_SPEECH_KEY = self.secret_client.get_secret(
os.environ.get("AZURE_SPEECH_SERVICE_KEY")
).value
else:
self.AZURE_SEARCH_KEY = os.environ.get("AZURE_SEARCH_KEY")
self.AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
self.AZURE_SPEECH_KEY = os.environ.get("AZURE_SPEECH_SERVICE_KEY")
self.AZURE_SEARCH_KEY = self.secretHelper.get_secret("AZURE_SEARCH_KEY")
self.AZURE_OPENAI_API_KEY = self.secretHelper.get_secret(
"AZURE_OPENAI_API_KEY"
)
self.AZURE_SPEECH_KEY = self.secretHelper.get_secret(
"AZURE_SPEECH_SERVICE_KEY"
)

# Set env for Azure OpenAI
self.AZURE_OPENAI_ENDPOINT = os.environ.get(
Expand All @@ -126,22 +116,16 @@ def __init__(self, **kwargs) -> None:
)
# Azure Blob Storage
self.AZURE_BLOB_ACCOUNT_NAME = os.getenv("AZURE_BLOB_ACCOUNT_NAME", "")
self.AZURE_BLOB_ACCOUNT_KEY = (
self.secret_client.get_secret(os.getenv("AZURE_BLOB_ACCOUNT_KEY", "")).value
if self.USE_KEY_VAULT
else os.getenv("AZURE_BLOB_ACCOUNT_KEY", "")
self.AZURE_BLOB_ACCOUNT_KEY = self.secretHelper.get_secret(
"AZURE_BLOB_ACCOUNT_KEY"
)
self.AZURE_BLOB_CONTAINER_NAME = os.getenv("AZURE_BLOB_CONTAINER_NAME", "")
# Azure Form Recognizer
self.AZURE_FORM_RECOGNIZER_ENDPOINT = os.getenv(
"AZURE_FORM_RECOGNIZER_ENDPOINT", ""
)
self.AZURE_FORM_RECOGNIZER_KEY = (
self.secret_client.get_secret(
os.getenv("AZURE_FORM_RECOGNIZER_KEY", "")
).value
if self.USE_KEY_VAULT
else os.getenv("AZURE_FORM_RECOGNIZER_KEY", "")
self.AZURE_FORM_RECOGNIZER_KEY = self.secretHelper.get_secret(
"AZURE_FORM_RECOGNIZER_KEY"
)
# Azure App Insights
self.APPINSIGHTS_CONNECTION_STRING = os.getenv(
Expand All @@ -156,12 +140,8 @@ def __init__(self, **kwargs) -> None:
and "api.cognitive.microsoft.com" not in self.AZURE_CONTENT_SAFETY_ENDPOINT
):
self.AZURE_CONTENT_SAFETY_ENDPOINT = self.AZURE_FORM_RECOGNIZER_ENDPOINT
self.AZURE_CONTENT_SAFETY_KEY = (
self.secret_client.get_secret(
os.getenv("AZURE_CONTENT_SAFETY_KEY", "")
).value
if self.USE_KEY_VAULT
else os.getenv("AZURE_CONTENT_SAFETY_KEY", "")
self.AZURE_CONTENT_SAFETY_KEY = self.secretHelper.get_secret(
"AZURE_CONTENT_SAFETY_KEY"
)
# Orchestration Settings
self.ORCHESTRATION_STRATEGY = os.getenv(
Expand Down Expand Up @@ -189,3 +169,46 @@ def check_env():
for attr, value in EnvHelper().__dict__.items():
if value == "":
logging.warning(f"{attr} is not set in the environment variables.")


class SecretHelper:
def __init__(self) -> None:
"""
Initializes an instance of the SecretHelper class.
The constructor sets the USE_KEY_VAULT attribute based on the value of the USE_KEY_VAULT environment variable.
If USE_KEY_VAULT is set to "true" (case-insensitive), it initializes a SecretClient object using the
AZURE_KEY_VAULT_ENDPOINT environment variable and the DefaultAzureCredential.
Args:
None
Returns:
None
"""
self.USE_KEY_VAULT = os.getenv("USE_KEY_VAULT", "").lower() == "true"
self.secret_client = None
if self.USE_KEY_VAULT:
self.secret_client = SecretClient(
os.environ.get("AZURE_KEY_VAULT_ENDPOINT"), DefaultAzureCredential()
)

def get_secret(self, secret_name: str) -> str:
"""
Retrieves the value of a secret from the environment variables or Azure Key Vault.
Args:
secret_name (str): The name of the secret or "".
Returns:
str: The value of the secret.
Raises:
None
"""
return (
self.secret_client.get_secret(os.getenv(secret_name, "")).value
if self.USE_KEY_VAULT
else os.getenv(secret_name, "")
)
2 changes: 1 addition & 1 deletion code/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_returns_correct_config(self):

assert response.status_code == 200
assert response.json == {
"azureSpeechKey": None,
"azureSpeechKey": "",
"azureSpeechRegion": None,
"AZURE_OPENAI_ENDPOINT": "https://.openai.azure.com/",
}
Expand Down
6 changes: 3 additions & 3 deletions code/tests/utilities/test_LangChainAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from backend.batch.utilities.common.Answer import Answer


class Testing_LangChainAgent(LangChainAgent):
class LangChainAgentNoInit(LangChainAgent):
def __init__(self) -> None:
self.content_safety_checker = MagicMock()
self.question_answer_tool = MagicMock()
Expand All @@ -14,7 +14,7 @@ def __init__(self) -> None:
def test_run_tool_returns_answer_json():
# Given
user_message = "Hello"
agent = Testing_LangChainAgent()
agent = LangChainAgentNoInit()
answer = Answer(
question=user_message,
answer="Hello, how can I help you?",
Expand All @@ -40,7 +40,7 @@ def test_run_tool_returns_answer_json():
def test_run_text_processing_tool_returns_answer_json():
# Given
user_message = "Hello"
agent = Testing_LangChainAgent()
agent = LangChainAgentNoInit()
answer = Answer(
question=user_message,
answer="Hello, how can I help you?",
Expand Down
48 changes: 48 additions & 0 deletions code/tests/utilities/test_SecretHelper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest.mock import MagicMock, patch
from pytest import MonkeyPatch
from backend.batch.utilities.helpers.EnvHelper import SecretHelper


def test_get_secret_returns_value_from_environment_variables(monkeypatch: MonkeyPatch):
# given
secret_name = "MY_SECRET"
expected_value = "my_secret_value"
monkeypatch.setenv(secret_name, expected_value)
secret_helper = SecretHelper()

# when
actual_value = secret_helper.get_secret(secret_name)

# then
assert actual_value == expected_value


@patch("backend.batch.utilities.helpers.EnvHelper.SecretClient")
def test_get_secret_returns_value_from_secret_client_when_use_key_vault_is_true(
secret_client: MagicMock, monkeypatch: MonkeyPatch
):
# given
secret_name = "MY_SECRET"
expected_value = "my_secret_value"
monkeypatch.setenv("USE_KEY_VAULT", "true")
secret_client.return_value.get_secret.return_value.value = expected_value
secret_helper = SecretHelper()

# when
actual_value = secret_helper.get_secret(secret_name)

# then
assert actual_value == expected_value


def test_get_secret_returns_empty_string_when_secret_name_is_empty():
# given
secret_name = ""
expected_value = ""
secret_helper = SecretHelper()

# when
actual_value = secret_helper.get_secret(secret_name)

# then
assert actual_value == expected_value
36 changes: 18 additions & 18 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ azure-ai-formrecognizer = "3.3.2"
azure-storage-blob = "12.19.1"
azure-identity = "1.15.0"
flask = "3.0.2"
openai = "1.14.1"
openai = "1.14.2"
langchain = "0.1.12"
langchain-community = "0.0.28"
langchain-openai = "0.0.7"
Expand All @@ -33,7 +33,7 @@ azure-search-documents = "11.4.0"
opencensus-ext-azure = "1.1.13"
azure-ai-contentsafety = "1.0.0"
python-docx = "1.1.0"
azure-keyvault-secrets = "4.4.*"
azure-keyvault-secrets = "4.8.0"
pandas = "2.2.1"

[tool.poetry.group.dev.dependencies]
Expand Down

0 comments on commit 2ddba25

Please sign in to comment.