Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VertexAI provider #11

Merged
merged 18 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ Before you can run the app, you'll need to store some environment variables.
export SLACK_BOT_TOKEN=<your-bot-token>
export SLACK_APP_TOKEN=<your-app-token>
export OPENAI_API_KEY=<your-api-key>
export ANTHROPIC_API_KEY=<your-api-key>
export ANTHROPIC_API_KEY=<your-api-key>
# For Vertex AI, follow the quickstart to set up a project that can run models: https://cloud.google.com/python/docs/reference/aiplatform/latest/index.html#quick-start
# If this is deployed into google cloud (app engine, cloud run etc) then this step is not needed.
export VERTEX_AI_ENABLED=true
gcloud auth application-default login
calvingiles marked this conversation as resolved.
Show resolved Hide resolved
```

### Setup Your Local Project
Expand Down Expand Up @@ -128,4 +132,4 @@ Navigate to **OAuth & Permissions** in your app configuration and click **Add a

```
https://3cb89939.ngrok.io/slack/oauth_redirect
```
```
9 changes: 8 additions & 1 deletion ai/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .anthropic import AnthropicAPI
from .openai import OpenAI_API
from .vertexai import VertexAPI
from ..ai_constants import DEFAULT_SYSTEM_CONTENT
from state_store.get_user_state import get_user_state
from typing import Optional, List
Expand All @@ -21,14 +22,20 @@


def get_available_providers():
    """Return the merged model catalog of every registered provider.

    Each provider's ``get_models()`` returns ``{model_id: metadata}``;
    providers that are disabled (e.g. Vertex AI without
    ``VERTEX_AI_ENABLED``) contribute an empty dict. Later providers win
    on duplicate model ids, matching dict-merge semantics.
    """
    available: dict = {}
    for provider in (AnthropicAPI(), OpenAI_API(), VertexAPI()):
        available.update(provider.get_models())
    return available


def _get_provider(provider_name: str):
    """Return a provider API instance for *provider_name* (case-insensitive).

    Args:
        provider_name: One of ``"openai"``, ``"anthropic"`` or ``"vertexai"``,
            in any letter case.

    Raises:
        ValueError: If the name matches no known provider.
    """
    # Normalize once instead of calling .lower() in every branch.
    name = provider_name.lower()
    if name == "openai":
        return OpenAI_API()
    if name == "anthropic":
        return AnthropicAPI()
    if name == "vertexai":
        return VertexAPI()
    raise ValueError(f"Unknown provider: {provider_name}")

Expand Down
9 changes: 2 additions & 7 deletions ai/providers/base_provider.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
# A base class for API providers, defining the interface and common properties for subclasses.


class BaseAPIProvider:
    """Interface that every AI API provider must implement.

    Subclasses expose a model catalog, let callers select one model, and
    generate a chat response for a prompt plus system instructions.
    """

    def set_model(self, model_name: str):
        """Select *model_name* as the model used by ``generate_response``."""
        raise NotImplementedError("Subclass must implement set_model")

    def get_models(self) -> dict:
        """Return ``{model_id: metadata}`` for the models this provider offers."""
        # Message fixed to name the actual method (was "get_model").
        raise NotImplementedError("Subclass must implement get_models")

    def generate_response(self, prompt: str, system_content: str) -> str:
        """Return the model's text response to *prompt* given *system_content*."""
        raise NotImplementedError("Subclass must implement generate_response")
111 changes: 111 additions & 0 deletions ai/providers/vertexai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import google.api_core.exceptions

from .base_provider import BaseAPIProvider
import os
import logging
import vertexai.generative_models

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


class VertexAPI(BaseAPIProvider):
    """API provider for Google Vertex AI Gemini models.

    Enabled only when the ``VERTEX_AI_ENABLED`` environment variable is a
    truthy string; ``GCP_PROJECT`` and ``GCP_LOCATION`` configure
    ``vertexai.init``. Authentication relies on Application Default
    Credentials (``gcloud auth application-default login`` locally, or the
    ambient service account when deployed on Google Cloud).
    """

    VERTEX_AI_PROVIDER = "VertexAI"
    # model id -> display metadata; "max_tokens" is the output-token cap
    # passed to the model as max_output_tokens.
    MODELS = {
        "gemini-1.5-flash-001": {
            "name": "Gemini 1.5 Flash 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
        "gemini-1.5-flash-002": {
            "name": "Gemini 1.5 Flash 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
        "gemini-1.5-pro-002": {
            "name": "Gemini 1.5 Pro 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
        "gemini-1.5-pro-001": {
            "name": "Gemini 1.5 Pro 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
        "gemini-1.0-pro-002": {
            "name": "Gemini 1.0 Pro 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
        "gemini-1.0-pro-001": {
            "name": "Gemini 1.0 Pro 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
        "gemini-flash-experimental": {
            "name": "Gemini Flash Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
        "gemini-pro-experimental": {
            "name": "Gemini Pro Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
        "gemini-experimental": {
            # Fixed: display name previously duplicated "Gemini Pro Experimental".
            "name": "Gemini Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
        },
    }

    def __init__(self):
        # Accept the common truthy spellings; anything else leaves the
        # provider disabled and get_models() empty.
        self.enabled = os.environ.get("VERTEX_AI_ENABLED", "").lower() in ("1", "true", "t", "yes", "y")
        # NOTE(review): init runs even when disabled — presumably harmless
        # without credentials, but confirm against the SDK's behavior.
        vertexai.init(project=os.environ.get("GCP_PROJECT"), location=os.environ.get("GCP_LOCATION"))

    def set_model(self, model_name: str):
        """Select *model_name* for subsequent generate_response calls."""
        if model_name not in self.MODELS.keys():
            raise ValueError("Invalid model")
        self.current_model = model_name

    def get_models(self) -> dict:
        """Return the model catalog, or an empty dict when the provider is disabled."""
        if self.enabled:
            return self.MODELS
        else:
            return {}

    def generate_response(self, prompt: str, system_content: str) -> str:
        """Generate a response for *prompt* using the currently selected model.

        The model client is rebuilt on each call because the system
        instruction is bound at construction time.

        Raises:
            google.api_core.exceptions.GoogleAPIError subclasses on API
            failures (logged before re-raising).
        """
        try:
            self.client = vertexai.generative_models.GenerativeModel(
                model_name=self.current_model,
                system_instruction=system_content,
                generation_config={
                    "max_output_tokens": self.MODELS[self.current_model]["max_tokens"],
                },
            )
            response = self.client.generate_content(
                contents=prompt,
            )
            # A candidate's content may arrive split across multiple parts.
            return "".join(part.text for part in response.candidates[0].content.parts)

        except google.api_core.exceptions.Unauthorized as e:
            logger.error(f"Client is not Authorized. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.Forbidden as e:
            logger.error(f"Client Forbidden. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.TooManyRequests as e:
            logger.error(f"A 429 status code was received. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.ClientError as e:
            logger.error(f"Client error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.ServerError as e:
            logger.error(f"Server error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.GoogleAPICallError as e:
            logger.error(f"Error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.GoogleAPIError as e:
            # Fixed: was a placeholder-less f-string that dropped the error.
            logger.error("Unknown error: %s", e)
            raise e
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ black==24.8.0
slack-cli-hooks==0.0.2
openai==1.37.1
anthropic==0.32.0
google-cloud-aiplatform==1.67.1