Add VertexAI provider #11

Merged · 18 commits · Sep 28, 2024
22 changes: 18 additions & 4 deletions README.md
@@ -17,6 +17,7 @@ Before getting started, make sure you have a development workspace where you have

#### Prerequisites
* To use the OpenAI and Anthropic models, you must have an account with sufficient credits.
* To use the Vertex models, you must have [a Google Cloud project](https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstarts/quickstart-multimodal#expandable-1) with sufficient credits.

#### Create a Slack App
1. Open [https://api.slack.com/apps/new](https://api.slack.com/apps/new) and choose "From an app manifest"
@@ -36,9 +37,22 @@ Before you can run the app, you'll need to store some environment variables.

```zsh
export SLACK_BOT_TOKEN=<your-bot-token>
export SLACK_APP_TOKEN=<your-app-token>
export OPENAI_API_KEY=<your-api-key>
export ANTHROPIC_API_KEY=<your-api-key>
```

##### Google Cloud Vertex AI Setup

To use Google Cloud Vertex AI, [follow this quick start](https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstarts/quickstart-multimodal#expandable-1) to create a project for sending requests to the Gemini API, then gather [Application Default Credentials](https://cloud.google.com/docs/authentication/provide-credentials-adc) using the strategy that matches your development environment.
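For local development, a common way to gather these credentials is through the gcloud CLI (a sketch assuming the CLI is installed and signed in to the right account):

```zsh
# Create Application Default Credentials for your user account
gcloud auth application-default login

# Optional sanity check: the active project should match VERTEX_AI_PROJECT_ID below
gcloud config get-value project
```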

Once your project and credentials are configured, export environment variables to select from Gemini models:

```zsh
export VERTEX_AI_PROJECT_ID=<your-project-id>
export VERTEX_AI_LOCATION=<location-to-deploy-model>
```

The project's location is listed under **Region** on the [Vertex AI](https://console.cloud.google.com/vertex-ai) dashboard, along with more details about the available Gemini models.

### Setup Your Local Project
```zsh
# Clone this project onto your machine
```

@@ -89,10 +103,10 @@ Every incoming request is routed to a "listener". Inside this directory, we group

<a name="byo-llm"></a>
#### `ai/providers`
This module contains classes for communicating with different API providers, such as [Anthropic](https://www.anthropic.com/), [OpenAI](https://openai.com/), and [Vertex AI](https://cloud.google.com/vertex-ai). To add your own LLM, create a new class for it using `base_provider.py` as an example, then update `ai/providers/__init__.py` to include and utilize your new class for API communication (see the sketch below).

* `__init__.py`:
This file contains utility functions for handling responses from the provider APIs and retrieving available providers.
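For illustration, a minimal new provider might look like the sketch below. The class name, model entry, and SDK call are hypothetical placeholders; the interface mirrors `base_provider.py` as shown later in this diff:

```python
# ai/providers/my_llm.py — hypothetical example provider
from .base_provider import BaseAPIProvider


class MyLLMAPI(BaseAPIProvider):
    MODELS = {
        "my-model-v1": {"name": "My Model v1", "provider": "MyLLM", "max_tokens": 4096},
    }

    def set_model(self, model_name: str):
        if model_name not in self.MODELS:
            raise ValueError("Invalid model")
        self.current_model = model_name

    def get_models(self) -> dict:
        # Returning {} when credentials are missing hides the provider
        # from the model selector, as VertexAPI does below.
        return self.MODELS

    def generate_response(self, prompt: str, system_content: str) -> str:
        # Call the provider's SDK here and return the generated text.
        raise NotImplementedError
```

Wiring it in then amounts to adding an import, an entry in `get_available_providers`, and an `elif` branch in `_get_provider`, all in `ai/providers/__init__.py` as shown below.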

### `/state_store`

@@ -128,4 +142,4 @@ Navigate to **OAuth & Permissions** in your app configuration and click **Add a

```
https://3cb89939.ngrok.io/slack/oauth_redirect
```
23 changes: 16 additions & 7 deletions ai/providers/__init__.py
@@ -1,8 +1,11 @@
from typing import List, Optional

from state_store.get_user_state import get_user_state

from ..ai_constants import DEFAULT_SYSTEM_CONTENT
from .anthropic import AnthropicAPI
from .openai import OpenAI_API
from .vertexai import VertexAPI

"""
New AI providers must be added below.
@@ -21,14 +24,20 @@


def get_available_providers():
    return {
        **AnthropicAPI().get_models(),
        **OpenAI_API().get_models(),
        **VertexAPI().get_models(),
    }
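# Note: the merged mapping is keyed by model name; a provider without its
# credentials configured (e.g. VertexAPI when VERTEX_AI_PROJECT_ID is unset)
# returns {} from get_models() and so contributes no entries here.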


def _get_provider(provider_name: str):
    if provider_name.lower() == "anthropic":
        return AnthropicAPI()
    elif provider_name.lower() == "openai":
        return OpenAI_API()
    elif provider_name.lower() == "vertexai":
        return VertexAPI()
    else:
        raise ValueError(f"Unknown provider: {provider_name}")

11 changes: 3 additions & 8 deletions ai/providers/base_provider.py
@@ -1,17 +1,12 @@
# A base class for API providers, defining the interface and common properties for subclasses.


class BaseAPIProvider(object):
    def set_model(self, model_name: str):
        raise NotImplementedError("Subclass must implement set_model")

    def get_models(self) -> dict:
        raise NotImplementedError("Subclass must implement get_models")

    def generate_response(self, prompt: str, system_content: str) -> str:
        raise NotImplementedError("Subclass must implement generate_response")
131 changes: 131 additions & 0 deletions ai/providers/vertexai.py
@@ -0,0 +1,131 @@
import logging
import os

import google.api_core.exceptions
import vertexai.generative_models

from .base_provider import BaseAPIProvider

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


class VertexAPI(BaseAPIProvider):
    VERTEX_AI_PROVIDER = "VertexAI"
    MODELS = {
        "gemini-1.5-flash-001": {
            "name": "Gemini 1.5 Flash 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.5-flash-002": {
            "name": "Gemini 1.5 Flash 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.5-pro-002": {
            "name": "Gemini 1.5 Pro 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.5-pro-001": {
            "name": "Gemini 1.5 Pro 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.0-pro-002": {
            "name": "Gemini 1.0 Pro 002",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-1.0-pro-001": {
            "name": "Gemini 1.0 Pro 001",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": False,
[Review comment — Member]: This is a super nice approach and I'm glad it can be known within the models of this class 🙏
        },
        "gemini-flash-experimental": {
            "name": "Gemini Flash Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-pro-experimental": {
            "name": "Gemini Pro Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
        "gemini-experimental": {
            "name": "Gemini Experimental",
            "provider": VERTEX_AI_PROVIDER,
            "max_tokens": 8192,
            "system_instruction_supported": True,
        },
    }

    def __init__(self):
        self.enabled = bool(os.environ.get("VERTEX_AI_PROJECT_ID", ""))
        if self.enabled:
            vertexai.init(
                project=os.environ.get("VERTEX_AI_PROJECT_ID"),
                location=os.environ.get("VERTEX_AI_LOCATION"),
            )
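            # vertexai.init() authenticates via Application Default Credentials
            # found in the environment; no API key is passed explicitly.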

    def set_model(self, model_name: str):
        if model_name not in self.MODELS.keys():
            raise ValueError("Invalid model")
        self.current_model = model_name

    def get_models(self) -> dict:
        if self.enabled:
            return self.MODELS
        else:
            return {}

    def generate_response(self, prompt: str, system_content: str) -> str:
        system_instruction = None
        if self.MODELS[self.current_model]["system_instruction_supported"]:
            system_instruction = system_content
        else:
            prompt = system_content + "\n" + prompt
[Review comment — Member, on lines +92 to +96]: Updated this to remove **kwargs since it wasn't so clear which arguments were being updated before generation and I was finding this error with pyright:

    Argument of type "str" cannot be assigned to parameter "tool_config" of type "ToolConfig | None" in function "__init__"
      Type "str" is not assignable to type "ToolConfig | None"
        "str" is not assignable to "ToolConfig"
        "str" is not assignable to "None"

I found that these changes continue to work fine for both cases, and I'm hoping the set variables before model setup helps make future updates clear 🔭

        try:
            self.client = vertexai.generative_models.GenerativeModel(
                model_name=self.current_model,
                generation_config={
                    "max_output_tokens": self.MODELS[self.current_model]["max_tokens"],
                },
                system_instruction=system_instruction,
            )
            response = self.client.generate_content(
                contents=prompt,
            )
            return "".join(part.text for part in response.candidates[0].content.parts)

        except google.api_core.exceptions.Unauthorized as e:
            logger.error(f"Client is not Authorized. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.Forbidden as e:
            logger.error(f"Client Forbidden. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.TooManyRequests as e:
            logger.error(f"Too many requests. {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.ClientError as e:
            logger.error(f"Client error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.ServerError as e:
            logger.error(f"Server error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.GoogleAPICallError as e:
            logger.error(f"Error: {e.reason}, {e.message}")
            raise e
        except google.api_core.exceptions.GoogleAPIError as e:
            logger.error(f"Unknown error. {e}")
            raise e
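For a quick smoke test of the new provider outside Slack, something like the following sketch works, assuming `VERTEX_AI_PROJECT_ID` and `VERTEX_AI_LOCATION` are exported and ADC is configured (the model choice here is illustrative):

```python
# Hypothetical local check, run from the project root
from ai.providers.vertexai import VertexAPI

api = VertexAPI()
if api.get_models():  # empty dict when VERTEX_AI_PROJECT_ID is unset
    api.set_model("gemini-1.5-flash-002")
    print(api.generate_response("Say hello.", "You are a helpful assistant."))
```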
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ black==24.8.0
slack-cli-hooks==0.0.2
openai==1.47.1
anthropic==0.34.2
google-cloud-aiplatform==1.67.1
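After pulling these changes, the new dependency installs with the rest of the requirements (assuming the usual pip workflow):

```zsh
pip install -r requirements.txt
```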