Skip to content

Commit

Permalink
Merge pull request #59 from BatsResearch/google-gemini-integration
Browse files Browse the repository at this point in the history
Integrating Google Gemini APIs
  • Loading branch information
dotpyu authored Dec 28, 2023
2 parents 40551fe + 3f64a7c commit 46376c5
Show file tree
Hide file tree
Showing 51 changed files with 565 additions and 102 deletions.
11 changes: 8 additions & 3 deletions alfred/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
"anthropic",
"cohere",
"ai21",
"google",
"torch",
"dummy",
], f"Invalid model type: {self.model_type}"
Expand Down Expand Up @@ -186,6 +187,10 @@ def __init__(
from ..fm.ai21 import AI21Model

self.model = AI21Model(self.model, **kwargs)
elif self.model_type == "google":
from ..fm.google import GoogleModel

self.model = GoogleModel(self.model, **kwargs)
elif self.model_type == "dummy":
from ..fm.dummy import DummyModel

Expand Down Expand Up @@ -422,12 +427,12 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any):
:param log_save_path: The file to save the chat logs.
:type log_save_path: Optional[str]
"""
if self.model_type in ["openai", "anthropic"]:
if self.model_type in ["openai", "anthropic", "google"]:
self.model.chat(log_save_path=log_save_path, **kwargs)
else:
logger.error(
"Chat APIs are only supported for Anthropic and OpenAI models."
"Chat APIs are only supported for Anthropic, Google Gemini and OpenAI models."
)
raise NotImplementedError(
"Currently Chat are only supported for Anthropic and OpenAI models."
"Currently Chat are only supported for Anthropic, Google Gemini and OpenAI models."
)
42 changes: 23 additions & 19 deletions alfred/fm/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .model import APIAccessFoundationModel
from .response import CompletionResponse
from .utils import colorize_str
from .utils import colorize_str, type_print

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,12 +38,12 @@ class AnthropicModel(APIAccessFoundationModel):
"""

def _anthropic_query(
self,
query: Union[str, List],
temperature: float = 0.0,
max_tokens: int = 3,
model: str = "claude-instant-1",
**kwargs: Any,
self,
query: Union[str, List],
temperature: float = 0.0,
max_tokens: int = 3,
model: str = "claude-instant-1",
**kwargs: Any,
) -> str:
"""
Run a single query through the foundation model
Expand Down Expand Up @@ -85,7 +85,7 @@ def _anthropic_query(
return response["completion"]

def __init__(
self, model_string: str = "claude-instant-1", api_key: Optional[str] = None
self, model_string: str = "claude-instant-1", api_key: Optional[str] = None
):
"""
Initialize the Anthropic API wrapper.
Expand All @@ -100,7 +100,7 @@ def __init__(
:type api_key: Optional[str]
"""
assert (
model_string in ANTHROPIC_MODELS
model_string in ANTHROPIC_MODELS
), f"Model {model_string} not found. Please choose from {ANTHROPIC_MODELS}"

if "ANTHROPIC_API_KEY" in os.environ:
Expand All @@ -124,9 +124,9 @@ def __init__(
super().__init__(model_string, {"api_key": api_key})

def _generate_batch(
self,
batch_instance: List[str],
**kwargs,
self,
batch_instance: List[str],
**kwargs,
) -> List[CompletionResponse]:
"""
Generate completions for a batch of prompts using the anthropic API.
Expand Down Expand Up @@ -161,9 +161,13 @@ def _feedback(feedback: str, no_newline=False, override=False):
if override:
print("\r", end="")
print(
colorize_str("Chat AI: ", "GREEN") + feedback,
end="\n" if not no_newline else "",
colorize_str("Chat AI: ", "GREEN"),
end="",
)
type_print(feedback)
print("",
end="\n" if not no_newline else "",
)

model = kwargs.get("model", self.model_string)
c_title = colorize_str("Alfred's Anthropic Chat", "BLUE")
Expand Down Expand Up @@ -201,11 +205,11 @@ def _feedback(feedback: str, no_newline=False, override=False):
message_log.append({"role": "user", "content": query})
response = []
for resp in self._anthropic_query(
query,
chat=True,
model=model,
temperature=temperature,
max_tokens=max_tokens,
query,
chat=True,
model=model,
temperature=temperature,
max_tokens=max_tokens,
):
if resp["stop_reason"] in ["stop", "stop_sequence"]:
break
Expand Down
Loading

0 comments on commit 46376c5

Please sign in to comment.