Add VertexAI provider #11
Changes from all commits
`base_provider.py` (the shared provider interface, updated alongside the new provider):

```diff
@@ -1,17 +1,12 @@
 # A base class for API providers, defining the interface and common properties for subclasses.
 
 
-class BaseAPIProvider:
-    MODELS = {}
-
-    def __init__(self):
-        self.api_key: str
-
+class BaseAPIProvider(object):
     def set_model(self, model_name: str):
         raise NotImplementedError("Subclass must implement set_model")
 
     def get_models(self) -> dict:
-        raise NotImplementedError("Subclass must implement get_model")
+        raise NotImplementedError("Subclass must implement get_models")
 
-    def generate_response(prompt: str) -> str:
+    def generate_response(self, prompt: str, system_content: str) -> str:
         raise NotImplementedError("Subclass must implement generate_response")
```
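For context, here is a rough sketch of how a caller might exercise the updated interface. The Slack-side calling code is not part of this diff, so the function and import path below are hypothetical:

```python
# Hypothetical caller (not from this PR): picks an available model from a
# provider and asks it for a reply, using the new generate_response signature.
from ai.providers.base_provider import BaseAPIProvider  # assumed import path


def respond(provider: BaseAPIProvider, user_text: str) -> str:
    models = provider.get_models()  # an empty dict means the provider is not configured
    if not models:
        return "This provider is not configured."
    provider.set_model(next(iter(models)))  # choose any advertised model
    return provider.generate_response(
        prompt=user_text,
        system_content="You are a helpful assistant in a Slack workspace.",
    )
```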
New file (131 lines added): the `VertexAPI` provider.

```python
import logging
import os

import google.api_core.exceptions
import vertexai.generative_models

from .base_provider import BaseAPIProvider

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


class VertexAPI(BaseAPIProvider):
    VERTEX_AI_PROVIDER = "VertexAI"
    MODELS = {
        "gemini-1.5-flash-001": {
            "name": "Gemini 1.5 Flash 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.5-flash-002": {
            "name": "Gemini 1.5 Flash 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.5-pro-002": {
            "name": "Gemini 1.5 Pro 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.5-pro-001": {
            "name": "Gemini 1.5 Pro 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.0-pro-002": {
            "name": "Gemini 1.0 Pro 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.0-pro-001": {
            "name": "Gemini 1.0 Pro 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": False,
        },
        "gemini-flash-experimental": {
            "name": "Gemini Flash Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-pro-experimental": {
            "name": "Gemini Pro Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-experimental": {
            "name": "Gemini Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
    }

    def __init__(self):
        self.enabled = bool(os.environ.get("VERTEX_AI_PROJECT_ID", ""))
        if self.enabled:
            vertexai.init(
                project=os.environ.get("VERTEX_AI_PROJECT_ID"),
                location=os.environ.get("VERTEX_AI_LOCATION"),
            )

    def set_model(self, model_name: str):
        if model_name not in self.MODELS.keys():
            raise ValueError("Invalid model")
        self.current_model = model_name

    def get_models(self) -> dict:
        if self.enabled:
            return self.MODELS
        else:
            return {}

    def generate_response(self, prompt: str, system_content: str) -> str:
        system_instruction = None
        if self.MODELS[self.current_model]["system_instruction_supported"]:
            system_instruction = system_content
        else:
            prompt = system_content + "\n" + prompt
```
Comment on lines +92 to +96:

Updated this to remove the following type error:

    Argument of type "str" cannot be assigned to parameter "tool_config" of type "ToolConfig | None" in function "__init__"
      Type "str" is not assignable to type "ToolConfig | None"
        "str" is not assignable to "ToolConfig"
        "str" is not assignable to "None"

I found that these changes continue to work fine for both cases, and I'm hoping that setting these variables before model setup helps make future updates clear 🔭
The remainder of `generate_response`:

```python
        try:
            self.client = vertexai.generative_models.GenerativeModel(
                model_name=self.current_model,
                generation_config={
                    "max_output_tokens": self.MODELS[self.current_model]["max_tokens"],
                },
                system_instruction=system_instruction,
            )
            response = self.client.generate_content(
                contents=prompt,
            )
            return "".join(part.text for part in response.candidates[0].content.parts)

        except google.api_core.exceptions.Unauthorized as e:
            logger.error(f"Client is not Authorized. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.Forbidden as e:
            logger.error(f"Client Forbidden. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.TooManyRequests as e:
            logger.error(f"Too many requests. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.ClientError as e:
            logger.error(f"Client error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.ServerError as e:
            logger.error(f"Server error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.GoogleAPICallError as e:
            logger.error(f"Error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.GoogleAPIError as e:
            logger.error(f"Unknown error. {e}")
            raise e
```
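To try the provider locally, something along these lines should work. The project ID, location, and import path are placeholders, and it assumes Application Default Credentials are already configured (e.g. via `gcloud auth application-default login`):

```python
import os

# Hypothetical local check of the new provider (values below are placeholders).
os.environ["VERTEX_AI_PROJECT_ID"] = "my-gcp-project"
os.environ["VERTEX_AI_LOCATION"] = "us-central1"

from ai.providers.vertexai import VertexAPI  # assumed path; adjust to where this module lives

provider = VertexAPI()
print(list(provider.get_models()))  # [] would mean VERTEX_AI_PROJECT_ID is unset

provider.set_model("gemini-1.5-flash-002")
print(
    provider.generate_response(
        prompt="Say hello to the channel.",
        system_content="You are a friendly Slack bot.",
    )
)
```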
Requirements (adds the Vertex AI SDK pin):

```diff
@@ -5,3 +5,4 @@ black==24.8.0
 slack-cli-hooks==0.0.2
 openai==1.47.1
 anthropic==0.34.2
+google-cloud-aiplatform==1.67.1
```
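As a quick sanity check of the new pin (hypothetical, not part of the PR), the imports the provider relies on can be verified after installing the updated requirements:

```python
# Hypothetical smoke check: the pinned google-cloud-aiplatform release should
# expose the SDK surface the new provider uses.
import google.api_core.exceptions  # noqa: F401  (exception types caught by the provider)
import vertexai
import vertexai.generative_models

assert hasattr(vertexai, "init")
assert hasattr(vertexai.generative_models, "GenerativeModel")
print("Vertex AI SDK imports OK")
```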
Review comment: This is a super nice approach, and I'm glad it can be known within the models of this class 🙏