Add support for Azure OpenAI
liukidar committed Dec 4, 2024
1 parent 7e65ad1 commit 9645f52
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions fast_graphrag/_llm/_llm_openai.py
@@ -3,11 +3,11 @@
 import asyncio
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Any, List, Optional, Tuple, Type, cast
+from typing import Any, List, Literal, Optional, Tuple, Type, cast

 import instructor
 import numpy as np
-from openai import APIConnectionError, AsyncOpenAI, RateLimitError
+from openai import APIConnectionError, AsyncAzureOpenAI, AsyncOpenAI, RateLimitError
 from pydantic import BaseModel
 from tenacity import (
     AsyncRetrying,
@@ -32,12 +32,21 @@ class OpenAILLMService(BaseLLMService):

     model: Optional[str] = field(default="gpt-4o-mini")
     mode: instructor.Mode = field(default=instructor.Mode.JSON)
+    client: Literal["openai", "azure"] = field(default="openai")

     def __post_init__(self):
+        if self.client == "azure":
+            assert self.base_url is not None, "Azure OpenAI requires a base url."
+            self.llm_async_client = instructor.from_openai(
+                AsyncAzureOpenAI(base_url=self.base_url, api_key=self.api_key, timeout=TIMEOUT_SECONDS), mode=self.mode
+            )
+        elif self.client == "openai":
+            self.llm_async_client = instructor.from_openai(
+                AsyncOpenAI(base_url=self.base_url, api_key=self.api_key, timeout=TIMEOUT_SECONDS), mode=self.mode
+            )
+        else:
+            raise ValueError("Invalid client type. Must be 'openai' or 'azure'")
         logger.debug("Initialized OpenAILLMService with patched OpenAI client.")
-        self.llm_async_client: instructor.AsyncInstructor = instructor.from_openai(
-            AsyncOpenAI(base_url=self.base_url, api_key=self.api_key, timeout=TIMEOUT_SECONDS), mode=self.mode
-        )

     @throttle_async_func_call(max_concurrent=256, stagger_time=0.001, waiting_time=0.001)
     async def send_message(
@@ -113,11 +122,16 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
     embedding_dim: int = field(default=1536)
     max_elements_per_request: int = field(default=32)
     model: Optional[str] = field(default="text-embedding-3-small")
+    client: Literal["openai", "azure"] = field(default="openai")

     def __post_init__(self):
-        self.embedding_async_client: AsyncOpenAI = AsyncOpenAI(
-            base_url=self.base_url, api_key=self.api_key, timeout=TIMEOUT_SECONDS
-        )
+        if self.client == "azure":
+            assert self.base_url is not None, "Azure OpenAI requires a base url."
+            self.embedding_async_client = AsyncAzureOpenAI(base_url=self.base_url, api_key=self.api_key)
+        elif self.client == "openai":
+            self.embedding_async_client = AsyncOpenAI(base_url=self.base_url, api_key=self.api_key)
+        else:
+            raise ValueError("Invalid client type. Must be 'openai' or 'azure'")
         logger.debug("Initialized OpenAIEmbeddingService with OpenAI client.")

     async def encode(self, texts: list[str], model: Optional[str] = None) -> np.ndarray[Any, np.dtype[np.float32]]:
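For orientation, here is a minimal usage sketch of the new client switch. It is an illustration, not part of the commit: the base_url and api_key keyword arguments are assumed to be dataclass fields inherited from the base service classes (the diff only shows them referenced as self.base_url and self.api_key), and the environment variable names are placeholders.

# Illustrative sketch only: assumes base_url / api_key are constructor fields on
# the service dataclasses and that the Azure endpoint is provided via the env.
import os

from fast_graphrag._llm._llm_openai import OpenAIEmbeddingService, OpenAILLMService

# Default behaviour is unchanged: client="openai" talks to the standard OpenAI API.
llm = OpenAILLMService(model="gpt-4o-mini", api_key=os.environ["OPENAI_API_KEY"])

# Azure OpenAI: client="azure" requires a base_url, otherwise __post_init__
# fails its assertion before any request is made.
azure_llm = OpenAILLMService(
    client="azure",
    model="gpt-4o-mini",  # on Azure this is typically the deployment name
    base_url=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
)

azure_embedding = OpenAIEmbeddingService(
    client="azure",
    model="text-embedding-3-small",
    base_url=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
)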
