Skip to content

Commit

Permalink
fix: handle none connectionstr
Browse files Browse the repository at this point in the history
  • Loading branch information
lorr1 committed Jul 26, 2023
1 parent d941019 commit 7b7c0c9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
26 changes: 14 additions & 12 deletions manifest/clients/azureopenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,32 +42,34 @@ def connect(
connection_str: connection string.
client_args: client arguments.
"""
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key, self.host = None, None
if connection_str:
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
if self.api_key is None:
raise ValueError(
"AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
"variable or pass through `client_connection`."
)
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
self.host = self.host.rstrip("/")
if self.host is None:
raise ValueError(
"Azure Service URL not set "
"(e.g. https://openai-azure-service.openai.azure.com/)."
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.host = self.host.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAI_ENGINES:
Expand Down
26 changes: 14 additions & 12 deletions manifest/clients/azureopenai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,34 @@ def connect(
connection_str: connection string.
client_args: client arguments.
"""
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key, self.host = None, None
if connection_str:
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
if self.api_key is None:
raise ValueError(
"AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
"variable or pass through `client_connection`."
)
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
self.host = self.host.rstrip("/")
if self.host is None:
raise ValueError(
"Azure Service URL not set "
"(e.g. https://openai-azure-service.openai.azure.com/)."
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.host = self.host.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAICHAT_ENGINES:
Expand Down

0 comments on commit 7b7c0c9

Please sign in to comment.