Skip to content

Commit

Permalink
Update to include the generative service. (#105)
Browse files Browse the repository at this point in the history
* Start code for generative service.
This is working with v1beta.

Change-Id: I96d2d4f773db8e4be621b40697016d1acbe24903

* Embedding content functionality for v1beta.

Change-Id: I1bf2379b33a607b0bb2cf77066073166a6fe9f95

* Add py3.9 support, Fix roles, UsageMetadata, and more.

Change-Id: I5fa91f191eb5eebb0fd160325b102bfb48ef8e27

* Add async support. Fix some types. Add count_tokens.

Change-Id: Ic03d9caa9996843c0ac1438f52c8b10a08fd6563

* Docstrings

Change-Id: Ia2adca04a2f65bfd4ea40eec0cc18c4c98703b9f

* Add missing async type to export list.

Change-Id: I22e0aba1a59997fef6263105413b5626f1cc5a51

* Add async tests, kwargs, format.

Change-Id: Ia2c3260efe58f48183458e997ba560cc07f4b442

* docs

Change-Id: I45d2f405fb058f25fc99c30c79221f6461ebe945

* debug tests

* Add GenerationConfig at the top level

* test

* replace __init__.py

* remove -e

* drop notebook tests for now

* Update version.

* format + pytype

* Fix pytype.

* Fix tests + pytype

---------

Co-authored-by: Shilpa Kancharla <snkancharla@google.com>
Co-authored-by: Mark McDonald <macd@google.com>
  • Loading branch information
3 people authored Dec 12, 2023
1 parent 5b0b406 commit 0988543
Show file tree
Hide file tree
Showing 25 changed files with 3,008 additions and 142 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/test_pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ jobs:
- name: Run tests
run: |
python --version
pip install -q -e .[dev]
python -m unittest discover --pattern '*test*.py'
pip install .[dev]
python -m unittest
test3_10:
name: Test Py3.10
runs-on: ubuntu-latest
Expand All @@ -36,8 +36,8 @@ jobs:
- name: Run tests
run: |
python --version
pip install -q -e .[dev]
python -m unittest discover --pattern '*test*.py'
pip install -q .[dev]
python -m unittest
test3_9:
name: Test Py3.9
runs-on: ubuntu-latest
Expand All @@ -49,8 +49,8 @@ jobs:
- name: Run tests
run: |
python --version
pip install -q -e .[dev]
python -m unittest discover --pattern '*test*.py'
pip install .[dev]
python -m unittest
pytype3_10:
name: pytype 3.10
runs-on: ubuntu-latest
Expand All @@ -62,7 +62,7 @@ jobs:
- name: Run pytype
run: |
python --version
pip install -q -e .[dev]
pip install .[dev]
pip install -q gspread ipython
pytype
format:
Expand All @@ -76,7 +76,7 @@ jobs:
- name: Check format
run: |
python --version
pip install -q -e .
pip install -q .
pip install -q black
black . --check
16 changes: 12 additions & 4 deletions google/generativeai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@
Use the `palm.chat` function to have a discussion with a model:
```
response = palm.chat(messages=["Hello."])
print(response.last) # 'Hello! What can I help you with?'
response.reply("Can you tell me a joke?")
chat = palm.chat(messages=["Hello."])
print(chat.last) # 'Hello! What can I help you with?'
chat = chat.reply("Can you tell me a joke?")
print(chat.last) # 'Why did the chicken cross the road?'
```
## Models
Expand All @@ -68,13 +69,20 @@
"""
from __future__ import annotations

from google.generativeai import types
from google.generativeai import version

from google.generativeai import types
from google.generativeai.types import GenerationConfig


from google.generativeai.discuss import chat
from google.generativeai.discuss import chat_async
from google.generativeai.discuss import count_message_tokens

from google.generativeai.embedding import embed_content

from google.generativeai.generative_models import GenerativeModel

from google.generativeai.text import generate_text
from google.generativeai.text import generate_embeddings
from google.generativeai.text import count_text_tokens
Expand Down
112 changes: 54 additions & 58 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
Expand All @@ -27,7 +13,12 @@
from google.api_core import gapic_v1
from google.api_core import operations_v1

from google.generativeai import version
try:
from google.generativeai import version

__version__ = version.__version__
except ImportError:
__version__ = "0.0.0"

USER_AGENT = "genai-py"

Expand All @@ -36,11 +27,10 @@
class _ClientManager:
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
default_metadata: Sequence[tuple[str, str]] = ()

discuss_client: glm.DiscussServiceClient | None = None
discuss_async_client: glm.DiscussServiceAsyncClient | None = None
model_client: glm.ModelServiceClient | None = None
text_client: glm.TextServiceClient | None = None
operations_client = None
clients: dict[str, Any] = dataclasses.field(default_factory=dict)

def configure(
self,
Expand All @@ -54,7 +44,7 @@ def configure(
# We could accept a dict since all the `Transport` classes take the same args,
# but that seems rare. Users that need it can just switch to the low level API.
transport: str | None = None,
client_options: client_options_lib.ClientOptions | dict | None = None,
client_options: client_options_lib.ClientOptions | dict[str, Any] | None = None,
client_info: gapic_v1.client_info.ClientInfo | None = None,
default_metadata: Sequence[tuple[str, str]] = (),
) -> None:
Expand Down Expand Up @@ -93,7 +83,7 @@ def configure(

client_options.api_key = api_key

user_agent = f"{USER_AGENT}/{version.__version__}"
user_agent = f"{USER_AGENT}/{__version__}"
if client_info:
# Be respectful of any existing agent setting.
if client_info.user_agent:
Expand All @@ -114,12 +104,16 @@ def configure(

self.client_config = client_config
self.default_metadata = default_metadata
self.discuss_client = None
self.text_client = None
self.model_client = None
self.operations_client = None

def make_client(self, cls):
self.clients = {}

def make_client(self, name):
if name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")

# Attempt to configure using defaults.
if not self.client_config:
configure()
Expand Down Expand Up @@ -157,35 +151,25 @@ def call(*args, metadata=(), **kwargs):

return client

def get_default_discuss_client(self) -> glm.DiscussServiceClient:
if self.discuss_client is None:
self.discuss_client = self.make_client(glm.DiscussServiceClient)
return self.discuss_client

def get_default_text_client(self) -> glm.TextServiceClient:
if self.text_client is None:
self.text_client = self.make_client(glm.TextServiceClient)
return self.text_client

def get_default_discuss_async_client(self) -> glm.DiscussServiceAsyncClient:
if self.discuss_async_client is None:
self.discuss_async_client = self.make_client(glm.DiscussServiceAsyncClient)
return self.discuss_async_client
def get_default_client(self, name):
name = name.lower()
if name == "operations":
return self.get_default_operations_client()

def get_default_model_client(self) -> glm.ModelServiceClient:
if self.model_client is None:
self.model_client = self.make_client(glm.ModelServiceClient)
return self.model_client
client = self.clients.get(name)
if client is None:
client = self.make_client(name)
self.clients[name] = client
return client

def get_default_operations_client(self) -> operations_v1.OperationsClient:
if self.operations_client is None:
self.model_client = get_default_model_client()
self.operations_client = self.model_client._transport.operations_client

return self.operations_client

client = self.clients.get("operations", None)
if client is None:
model_client = self.get_default_client("Model")
client = model_client._transport.operations_client
self.clients["operations"] = client

_client_manager = _ClientManager()
return client


def configure(
Expand Down Expand Up @@ -230,21 +214,33 @@ def configure(
)


_client_manager = _ClientManager()
_client_manager.configure()


def get_default_discuss_client() -> glm.DiscussServiceClient:
return _client_manager.get_default_discuss_client()
return _client_manager.get_default_client("discuss")


def get_default_text_client() -> glm.TextServiceClient:
return _client_manager.get_default_text_client()
def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
return _client_manager.get_default_client("discuss_async")


def get_default_operations_client() -> operations_v1.OperationsClient:
return _client_manager.get_default_operations_client()
def get_default_generative_client() -> glm.GenerativeServiceClient:
return _client_manager.get_default_client("generative")


def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
return _client_manager.get_default_discuss_async_client()
def get_default_generative_async_client() -> glm.GenerativeServiceAsyncClient:
return _client_manager.get_default_client("generative_async")


def get_default_text_client() -> glm.TextServiceClient:
return _client_manager.get_default_client("text")


def get_default_operations_client() -> operations_v1.OperationsClient:
return _client_manager.get_default_client("operations")


def get_default_model_client() -> glm.ModelServiceAsyncClient:
return _client_manager.get_default_model_client()
return _client_manager.get_default_client("model")
20 changes: 5 additions & 15 deletions google/generativeai/discuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,6 @@ def _make_generate_message_request(
)


def set_doc(doc):
"""A decorator to set the docstring of a function."""

def inner(f):
f.__doc__ = doc
return f

return inner


DEFAULT_DISCUSS_MODEL = "models/chat-bison-001"


Expand Down Expand Up @@ -411,7 +401,7 @@ def chat(
return _generate_response(client=client, request=request)


@set_doc(chat.__doc__)
@string_utils.set_doc(chat.__doc__)
async def chat_async(
*,
model: model_types.AnyModelNameOptions | None = "models/chat-bison-001",
Expand Down Expand Up @@ -447,7 +437,7 @@ async def chat_async(


@string_utils.prettyprint
@set_doc(discuss_types.ChatResponse.__doc__)
@string_utils.set_doc(discuss_types.ChatResponse.__doc__)
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
class ChatResponse(discuss_types.ChatResponse):
_client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False)
Expand All @@ -457,7 +447,7 @@ def __init__(self, **kwargs):
setattr(self, key, value)

@property
@set_doc(discuss_types.ChatResponse.last.__doc__)
@string_utils.set_doc(discuss_types.ChatResponse.last.__doc__)
def last(self) -> str | None:
if self.messages[-1]:
return self.messages[-1]["content"]
Expand All @@ -470,7 +460,7 @@ def last(self, message: discuss_types.MessageOptions):
message = type(message).to_dict(message)
self.messages[-1] = message

@set_doc(discuss_types.ChatResponse.reply.__doc__)
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResponse:
if isinstance(self._client, glm.DiscussServiceAsyncClient):
raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
Expand All @@ -489,7 +479,7 @@ def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResp
request = _make_generate_message_request(**request)
return _generate_response(request=request, client=self._client)

@set_doc(discuss_types.ChatResponse.reply.__doc__)
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
async def reply_async(
self, message: discuss_types.MessageOptions
) -> discuss_types.ChatResponse:
Expand Down
Loading

1 comment on commit 0988543

@armandRobled
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.