Skip to content

Commit

Permalink
Add support for Azure for vectordb embedding + LLM Model Shield
Browse files Browse the repository at this point in the history
  • Loading branch information
cmpxchg16 committed Mar 31, 2024
1 parent a3ad618 commit dc6fe89
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from uuid import UUID

import httpx
from openai import OpenAI
from openai import OpenAI, AzureOpenAI
from pydantic import Json

from vibraniumdome_shields.settings_loader import settings
Expand All @@ -27,12 +27,20 @@ class CaptainsShield(VibraniumShield):

def __init__(self, openai_api_key):
    """Initialize the shield and its chat-completions client.

    When the ``OPENAI_API_TYPE`` environment variable is ``"azure"``, an
    ``AzureOpenAI`` client is built from the ``AZURE_OPENAI_VERSION``,
    ``AZURE_OPENAI_ENDPOINT`` and ``AZURE_OPENAI_KEY`` environment
    variables and *openai_api_key* is ignored. Otherwise a plain
    ``OpenAI`` client is built from *openai_api_key*.

    :param openai_api_key: API key for the non-Azure OpenAI client;
        required only when not running against Azure.
    :raises ValueError: if not in Azure mode and no API key was given.
    """
    super().__init__(self._shield_name)
    # Shared retry/timeout policy for both client flavors; hoisted so the
    # two branches cannot drift apart.
    request_timeout = httpx.Timeout(60.0, read=10.0, write=10.0, connect=2.0)
    if os.getenv("OPENAI_API_TYPE") == "azure":
        # Azure credentials come from the environment, not the argument,
        # so the openai_api_key check is deliberately skipped here.
        self._openai_client = AzureOpenAI(
            api_version=os.environ.get("AZURE_OPENAI_VERSION"),
            azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
            api_key=os.environ.get("AZURE_OPENAI_KEY"),
            max_retries=3,
            timeout=request_timeout,
        )
    else:
        if not openai_api_key:
            raise ValueError("LLMShield missed openai_api_key")
        self._openai_client = OpenAI(
            api_key=openai_api_key,
            max_retries=3,
            timeout=request_timeout,
        )

@captains_shield_seconds_histogram.time()
def deflect(self, llm_interaction: LLMInteraction, shield_policy_config: dict, scan_id: UUID, policy: dict) -> List[ShieldDeflectionResult]:
Expand All @@ -49,7 +57,7 @@ def deflect(self, llm_interaction: LLMInteraction, shield_policy_config: dict, s
}

if os.getenv("OPENAI_API_TYPE") == "azure":
params["engine"] = os.getenv("OPENAI_API_DEPLOYMENT")
params["model"] = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
else:
params["model"] = shield_policy_config.get("model", settings.get("openai.openai_model", "gpt-3.5-turbo"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from datasets import load_dataset
from langchain.docstore.document import Document
from langchain.embeddings import OpenAIEmbeddings

# from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings
from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from vibraniumdome_shields.utils import load_vibranium_home, uuid4_str

Expand All @@ -22,10 +21,18 @@ def __init__(self, vector_db_dir, index_name, embedding_model_name):
self.vector_store_dir = os.path.join(vector_db_dir, "vibranium-vector-store")
self._index_name = index_name
self._vector_store = FAISS
self._embeddings = OpenAIEmbeddings(
chunk_size=16 if os.getenv("OPENAI_API_TYPE") == "azure" else 1000,
model=embedding_model_name
) # 1000 is the default also in OpenAIEmbeddings, and 16 in Azure limit

# Note: requires those env vars
####################################################################################
# os.environ["AZURE_OPENAI_API_KEY"] = "..."
# os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<your-endpoint>.openai.azure.com/"

if os.getenv("OPENAI_API_TYPE") == "azure":
self._embeddings = AzureOpenAIEmbeddings(
azure_deployment=os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT"),
openai_api_version=os.environ.get("AZURE_OPENAI_EMBEDDING_VERSION"),)
else:
self._embeddings = OpenAIEmbeddings(model=embedding_model_name)

if os.path.exists(self.vector_store_file_path):
self._vector_store = self._vector_store.load_local(folder_path=self.vector_store_dir, embeddings=self._embeddings, index_name=self._index_name)
Expand Down

0 comments on commit dc6fe89

Please sign in to comment.