Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade keyvault secrets library #479

Merged
merged 6 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions code/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def prepare_body_headers_with_data(request):
],
}

chatgpt_url = f"https://{env_helper.AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{env_helper.AZURE_OPENAI_MODEL}"
chatgpt_url = f"{env_helper.AZURE_OPENAI_ENDPOINT}openai/deployments/{env_helper.AZURE_OPENAI_MODEL}"
if env_helper.is_chat_model():
chatgpt_url += "/chat/completions?api-version=2023-12-01-preview"
else:
Expand Down Expand Up @@ -166,7 +166,7 @@ def stream_with_data(body, headers, endpoint):

def conversation_with_data(request):
body, headers = prepare_body_headers_with_data(request)
endpoint = f"https://{env_helper.AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{env_helper.AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={env_helper.AZURE_OPENAI_API_VERSION}"
endpoint = f"{env_helper.AZURE_OPENAI_ENDPOINT}openai/deployments/{env_helper.AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={env_helper.AZURE_OPENAI_API_VERSION}"

if not env_helper.SHOULD_STREAM:
r = requests.post(endpoint, headers=headers, json=body)
Expand Down Expand Up @@ -202,16 +202,15 @@ def stream_without_data(response):


def conversation_without_data(request):
azure_endpoint = f"https://{env_helper.AZURE_OPENAI_RESOURCE}.openai.azure.com/"
if env_helper.AZURE_AUTH_TYPE == "rbac":
openai_client = AzureOpenAI(
azure_endpoint=azure_endpoint,
azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT,
api_version=env_helper.AZURE_OPENAI_API_VERSION,
azure_ad_token_provider=env_helper.AZURE_TOKEN_PROVIDER,
)
else:
openai_client = AzureOpenAI(
azure_endpoint=azure_endpoint,
azure_endpoint=env_helper.AZURE_OPENAI_ENDPOINT,
api_version=env_helper.AZURE_OPENAI_API_VERSION,
api_key=env_helper.AZURE_OPENAI_API_KEY,
)
Expand Down
95 changes: 59 additions & 36 deletions code/backend/batch/utilities/helpers/EnvHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class EnvHelper:
def __init__(self, **kwargs) -> None:
load_dotenv()

# Wrapper for Azure Key Vault
self.secretHelper = SecretHelper()

# Azure Search
self.AZURE_SEARCH_SERVICE = os.getenv("AZURE_SEARCH_SERVICE", "")
self.AZURE_SEARCH_INDEX = os.getenv("AZURE_SEARCH_INDEX", "")
Expand Down Expand Up @@ -76,33 +79,20 @@ def __init__(self, **kwargs) -> None:
self.AZURE_TOKEN_PROVIDER = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
self.USE_KEY_VAULT = os.getenv("USE_KEY_VAULT", "").lower() == "true"
# Initialize Azure keys based on authentication type and environment settings.
# When AZURE_AUTH_TYPE is "rbac", azure keys are None or an empty string.
# When USE_KEY_VAULT environment variable is set, keys are securely fetched from Azure Key Vault using DefaultAzureCredential.
# Otherwise, keys are obtained from environment variables.
if self.AZURE_AUTH_TYPE == "rbac":
self.AZURE_SEARCH_KEY = None
self.AZURE_OPENAI_API_KEY = ""
self.AZURE_SPEECH_KEY = None
elif self.USE_KEY_VAULT:
credential = DefaultAzureCredential()
self.secret_client = SecretClient(
os.environ.get("AZURE_KEY_VAULT_ENDPOINT"), credential
)
self.AZURE_SEARCH_KEY = self.secret_client.get_secret(
os.environ.get("AZURE_SEARCH_KEY")
).value
self.AZURE_OPENAI_API_KEY = self.secret_client.get_secret(
os.environ.get("AZURE_OPENAI_API_KEY", "")
).value # langchain expects this.
self.AZURE_SPEECH_KEY = self.secret_client.get_secret(
os.environ.get("AZURE_SPEECH_SERVICE_KEY")
).value
else:
self.AZURE_SEARCH_KEY = os.environ.get("AZURE_SEARCH_KEY")
self.AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
self.AZURE_SPEECH_KEY = os.environ.get("AZURE_SPEECH_SERVICE_KEY")
self.AZURE_SEARCH_KEY = self.secretHelper.get_secret("AZURE_SEARCH_KEY")
self.AZURE_OPENAI_API_KEY = self.secretHelper.get_secret(
"AZURE_OPENAI_API_KEY"
)
self.AZURE_SPEECH_KEY = self.secretHelper.get_secret(
"AZURE_SPEECH_SERVICE_KEY"
)

# Set env for Azure OpenAI
self.AZURE_OPENAI_ENDPOINT = os.environ.get(
Expand All @@ -126,22 +116,16 @@ def __init__(self, **kwargs) -> None:
)
# Azure Blob Storage
self.AZURE_BLOB_ACCOUNT_NAME = os.getenv("AZURE_BLOB_ACCOUNT_NAME", "")
self.AZURE_BLOB_ACCOUNT_KEY = (
self.secret_client.get_secret(os.getenv("AZURE_BLOB_ACCOUNT_KEY", "")).value
if self.USE_KEY_VAULT
else os.getenv("AZURE_BLOB_ACCOUNT_KEY", "")
self.AZURE_BLOB_ACCOUNT_KEY = self.secretHelper.get_secret(
"AZURE_BLOB_ACCOUNT_KEY"
)
self.AZURE_BLOB_CONTAINER_NAME = os.getenv("AZURE_BLOB_CONTAINER_NAME", "")
# Azure Form Recognizer
self.AZURE_FORM_RECOGNIZER_ENDPOINT = os.getenv(
"AZURE_FORM_RECOGNIZER_ENDPOINT", ""
)
self.AZURE_FORM_RECOGNIZER_KEY = (
self.secret_client.get_secret(
os.getenv("AZURE_FORM_RECOGNIZER_KEY", "")
).value
if self.USE_KEY_VAULT
else os.getenv("AZURE_FORM_RECOGNIZER_KEY", "")
self.AZURE_FORM_RECOGNIZER_KEY = self.secretHelper.get_secret(
"AZURE_FORM_RECOGNIZER_KEY"
)
# Azure App Insights
self.APPINSIGHTS_CONNECTION_STRING = os.getenv(
Expand All @@ -156,12 +140,8 @@ def __init__(self, **kwargs) -> None:
and "api.cognitive.microsoft.com" not in self.AZURE_CONTENT_SAFETY_ENDPOINT
):
self.AZURE_CONTENT_SAFETY_ENDPOINT = self.AZURE_FORM_RECOGNIZER_ENDPOINT
self.AZURE_CONTENT_SAFETY_KEY = (
self.secret_client.get_secret(
os.getenv("AZURE_CONTENT_SAFETY_KEY", "")
).value
if self.USE_KEY_VAULT
else os.getenv("AZURE_CONTENT_SAFETY_KEY", "")
self.AZURE_CONTENT_SAFETY_KEY = self.secretHelper.get_secret(
"AZURE_CONTENT_SAFETY_KEY"
)
# Orchestration Settings
self.ORCHESTRATION_STRATEGY = os.getenv(
Expand Down Expand Up @@ -189,3 +169,46 @@ def check_env():
for attr, value in EnvHelper().__dict__.items():
if value == "":
logging.warning(f"{attr} is not set in the environment variables.")


class SecretHelper:
ross-p-smith marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self) -> None:
"""
Initializes an instance of the SecretHelper class.

The constructor sets the USE_KEY_VAULT attribute based on the value of the USE_KEY_VAULT environment variable.
If USE_KEY_VAULT is set to "true" (case-insensitive), it initializes a SecretClient object using the
AZURE_KEY_VAULT_ENDPOINT environment variable and the DefaultAzureCredential.

Args:
None

Returns:
None
"""
self.USE_KEY_VAULT = os.getenv("USE_KEY_VAULT", "").lower() == "true"
self.secret_client = None
if self.USE_KEY_VAULT:
self.secret_client = SecretClient(
os.environ.get("AZURE_KEY_VAULT_ENDPOINT"), DefaultAzureCredential()
)

def get_secret(self, secret_name: str) -> str:
"""
Retrieves the value of a secret from the environment variables or Azure Key Vault.

Args:
secret_name (str): The name of the secret or "".

Returns:
str: The value of the secret.

Raises:
None

"""
return (
self.secret_client.get_secret(os.getenv(secret_name, "")).value
if self.USE_KEY_VAULT
else os.getenv(secret_name, "")
)
2 changes: 1 addition & 1 deletion code/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_returns_correct_config(self):

assert response.status_code == 200
assert response.json == {
"azureSpeechKey": None,
"azureSpeechKey": "",
"azureSpeechRegion": None,
"AZURE_OPENAI_ENDPOINT": "https://.openai.azure.com/",
}
Expand Down
6 changes: 3 additions & 3 deletions code/tests/utilities/test_LangChainAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from backend.batch.utilities.common.Answer import Answer


class Testing_LangChainAgent(LangChainAgent):
class LangChainAgentNoInit(LangChainAgent):
def __init__(self) -> None:
self.content_safety_checker = MagicMock()
self.question_answer_tool = MagicMock()
Expand All @@ -14,7 +14,7 @@ def __init__(self) -> None:
def test_run_tool_returns_answer_json():
# Given
user_message = "Hello"
agent = Testing_LangChainAgent()
agent = LangChainAgentNoInit()
answer = Answer(
question=user_message,
answer="Hello, how can I help you?",
Expand All @@ -40,7 +40,7 @@ def test_run_tool_returns_answer_json():
def test_run_text_processing_tool_returns_answer_json():
# Given
user_message = "Hello"
agent = Testing_LangChainAgent()
agent = LangChainAgentNoInit()
answer = Answer(
question=user_message,
answer="Hello, how can I help you?",
Expand Down
48 changes: 48 additions & 0 deletions code/tests/utilities/test_SecretHelper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest.mock import MagicMock, patch
from pytest import MonkeyPatch
from backend.batch.utilities.helpers.EnvHelper import SecretHelper


def test_get_secret_returns_value_from_environment_variables(monkeypatch: MonkeyPatch):
# given
secret_name = "MY_SECRET"
expected_value = "my_secret_value"
monkeypatch.setenv(secret_name, expected_value)
secret_helper = SecretHelper()

# when
actual_value = secret_helper.get_secret(secret_name)

# then
assert actual_value == expected_value


@patch("backend.batch.utilities.helpers.EnvHelper.SecretClient")
def test_get_secret_returns_value_from_secret_client_when_use_key_vault_is_true(
secret_client: MagicMock, monkeypatch: MonkeyPatch
):
# given
secret_name = "MY_SECRET"
expected_value = "my_secret_value"
monkeypatch.setenv("USE_KEY_VAULT", "true")
secret_client.return_value.get_secret.return_value.value = expected_value
secret_helper = SecretHelper()

# when
actual_value = secret_helper.get_secret(secret_name)

# then
assert actual_value == expected_value


def test_get_secret_returns_empty_string_when_secret_name_is_empty():
# given
secret_name = ""
expected_value = ""
secret_helper = SecretHelper()

# when
actual_value = secret_helper.get_secret(secret_name)

# then
assert actual_value == expected_value
36 changes: 18 additions & 18 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ azure-ai-formrecognizer = "3.3.2"
azure-storage-blob = "12.19.1"
azure-identity = "1.15.0"
flask = "3.0.2"
openai = "1.14.1"
openai = "1.14.2"
langchain = "0.1.12"
langchain-community = "0.0.28"
langchain-openai = "0.0.7"
Expand All @@ -33,7 +33,7 @@ azure-search-documents = "11.4.0"
opencensus-ext-azure = "1.1.13"
azure-ai-contentsafety = "1.0.0"
python-docx = "1.1.0"
azure-keyvault-secrets = "4.4.*"
azure-keyvault-secrets = "4.8.0"
pandas = "2.2.1"

[tool.poetry.group.dev.dependencies]
Expand Down