Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: replace aiohttp.ClientSession with AlloyDBAdminAsyncClient #416

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.cloud.alloydb.connector.instance import RefreshAheadCache
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
from google.cloud.alloydb.connector.utils import generate_keys
import traceback

if TYPE_CHECKING:
from google.auth.credentials import Credentials
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
alloydb_api_endpoint: str = "alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
Expand Down Expand Up @@ -181,6 +182,7 @@ async def connect(
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
print(f"RISHABH DEBUG: exception = {traceback.print_exc()}")
# with an error from AlloyDB API call or IP type, invalidate the
# cache and re-raise the error
await self._remove_cached(instance_uri)
Expand Down
102 changes: 35 additions & 67 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
import logging
from typing import Optional, TYPE_CHECKING

import aiohttp
from cryptography import x509
from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud import alloydb_v1beta
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
from google.cloud.alloydb.connector.version import __version__ as version
from google.protobuf import duration_pb2

if TYPE_CHECKING:
from google.auth.credentials import Credentials
Expand Down Expand Up @@ -55,7 +58,7 @@ def __init__(
alloydb_api_endpoint: str,
quota_project: Optional[str],
credentials: Credentials,
client: Optional[aiohttp.ClientSession] = None,
client: Optional[alloydb_v1beta.AlloyDBAdminAsyncClient] = None,
driver: Optional[str] = None,
user_agent: Optional[str] = None,
) -> None:
Expand All @@ -72,27 +75,27 @@ def __init__(
A credentials object created from the google-auth Python library.
Must have the AlloyDB Admin scopes. For more info check out
https://google-auth.readthedocs.io/en/latest/.
client (aiohttp.ClientSession): Async client used to make requests to
AlloyDB APIs.
client (alloydb_v1beta.AlloyDBAdminAsyncClient): Async client used to
make requests to AlloyDB APIs.
Optional, defaults to None and creates new client.
driver (str): Database driver to be used by the client.
"""
user_agent = _format_user_agent(driver, user_agent)
headers = {
"x-goog-api-client": user_agent,
"User-Agent": user_agent,
"Content-Type": "application/json",
}
if quota_project:
headers["x-goog-user-project"] = quota_project

self._client = client if client else aiohttp.ClientSession(headers=headers)
self._client = client if client else alloydb_v1beta.AlloyDBAdminAsyncClient(
credentials=credentials,
client_options=ClientOptions(
api_endpoint=alloydb_api_endpoint,
quota_project_id=quota_project,
),
client_info=ClientInfo(
user_agent=user_agent,
),
)
self._credentials = credentials
self._alloydb_api_endpoint = alloydb_api_endpoint
# asyncpg does not currently support using metadata exchange
# only use metadata exchange for pg8000 driver
self._use_metadata = True if driver == "pg8000" else False
self._user_agent = user_agent

async def _get_metadata(
self,
Expand All @@ -118,35 +121,19 @@ async def _get_metadata(
Returns:
dict: IP addresses of the AlloyDB instance.
"""
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}
parent = f"projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}"

url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}/connectionInfo"

resp = await self._client.get(url, headers=headers)
# try to get response json for better error message
try:
resp_dict = await resp.json()
if resp.status >= 400:
# if detailed error message is in json response, use as error message
message = resp_dict.get("error", {}).get("message")
if message:
resp.reason = message
# skip, raise_for_status will catch all errors in finally block
except Exception:
pass
finally:
resp.raise_for_status()
req = alloydb_v1beta.GetConnectionInfoRequest(parent=parent)
resp = await self._client.get_connection_info(request=req)

# Remove trailing period from PSC DNS name.
psc_dns = resp_dict.get("pscDnsName")
psc_dns = resp.psc_dns_name
if psc_dns:
psc_dns = psc_dns.rstrip(".")

return {
"PRIVATE": resp_dict.get("ipAddress"),
"PUBLIC": resp_dict.get("publicIpAddress"),
"PRIVATE": resp.ip_address,
"PUBLIC": resp.public_ip_address,
"PSC": psc_dns,
}

Expand Down Expand Up @@ -175,34 +162,17 @@ async def _get_client_certificate(
tuple[str, list[str]]: tuple containing the CA certificate
and certificate chain for the AlloyDB instance.
"""
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}

url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"

data = {
"publicKey": pub_key,
"certDuration": "3600s",
"useMetadataExchange": self._use_metadata,
}

resp = await self._client.post(url, headers=headers, json=data)
# try to get response json for better error message
try:
resp_dict = await resp.json()
if resp.status >= 400:
# if detailed error message is in json response, use as error message
message = resp_dict.get("error", {}).get("message")
if message:
resp.reason = message
# skip, raise_for_status will catch all errors in finally block
except Exception:
pass
finally:
resp.raise_for_status()

return (resp_dict["caCert"], resp_dict["pemCertificateChain"])
parent = f"projects/{project}/locations/{region}/clusters/{cluster}"
dur = duration_pb2.Duration()
dur.seconds = 3600
req = alloydb_v1beta.GenerateClientCertificateRequest(
parent=parent,
cert_duration=dur,
public_key=pub_key,
use_metadata_exchange=self._use_metadata,
)
resp = await self._client.generate_client_certificate(request=req)
return (resp.ca_cert, resp.pem_certificate_chain)

async def get_connection_info(
self,
Expand Down Expand Up @@ -270,6 +240,4 @@ async def get_connection_info(

async def close(self) -> None:
"""Close AlloyDBClient gracefully."""
logger.debug("Waiting for connector's http client to close")
await self._client.close()
logger.debug("Closed connector's http client")
logger.debug("Closed AlloyDBClient")
4 changes: 2 additions & 2 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Connector:
billing purposes.
Defaults to None, picking up project from environment.
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
Expand All @@ -77,7 +77,7 @@ def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
alloydb_api_endpoint: str = "alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def cover(session):
def default(session, path):
# Install all test dependencies, then install this package in-place.
session.install("-r", "requirements-test.txt")
session.install("-e", ".")
session.install(".")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting the following error if using session.install("-e", "."):

ImportError while loading conftest '/usr/local/google/home/rhatgadkar/alloydb-python-connector/tests/unit/conftest.py'.
tests/unit/conftest.py:22: in <module>
    from mocks import FakeAlloyDBClient
tests/unit/mocks.py:33: in <module>
    from google.cloud.alloydb.connector.connection_info import ConnectionInfo
E   ModuleNotFoundError: No module named 'google.cloud.alloydb.connector'

Maybe there’s a conflict with the google-cloud-alloydb-connector package when installing the google-cloud-alloydb package. But when changing to session.install("."), this error doesn’t occur anymore.

session.install("-r", "requirements.txt")
# Run pytest with coverage.
session.run(
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
"requests",
"google-auth",
"protobuf",
"google-cloud-alloydb",
"google-api-core",
]

package_root = os.path.abspath(os.path.dirname(__file__))
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud import alloydb_v1beta
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

Expand Down Expand Up @@ -378,3 +379,34 @@ async def force_refresh(self) -> None:

async def close(self) -> None:
self._close_called = True


class FakeAlloyDBAdminAsyncClient:
async def get_connection_info(self, request: alloydb_v1beta.GetConnectionInfoRequest) -> alloydb_v1beta.types.resources.ConnectionInfo:
ci = alloydb_v1beta.types.resources.ConnectionInfo()
ci.ip_address = "10.0.0.1"
ci.public_ip_address = "127.0.0.1"
ci.instance_uid = "123456789"
ci.psc_dns_name = "x.y.alloydb.goog"

parent = request.parent
instance = parent.split("/")[-1]
if instance == "test-instance":
ci.public_ip_address = ""
ci.psc_dns_name = ""
return ci
elif instance == "public-instance":
ci.psc_dns_name = ""
return ci
else:
ci.ip_address = ""
ci.public_ip_address = ""
return ci

async def generate_client_certificate(self, request: alloydb_v1beta.GenerateClientCertificateRequest) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse:
ccr = alloydb_v1beta.types.service.GenerateClientCertificateResponse()
ccr.ca_cert = "This is the CA cert"
ccr.pem_certificate_chain.append("This is the client cert")
ccr.pem_certificate_chain.append("This is the intermediate cert")
ccr.pem_certificate_chain.append("This is the root cert")
return ccr
6 changes: 3 additions & 3 deletions tests/unit/test_async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
import asyncio
from typing import Union

from aiohttp import ClientResponseError
from mock import patch
from mocks import FakeAlloyDBClient
from mocks import FakeConnectionInfo
from mocks import FakeCredentials
import pytest

from google.api_core.exceptions import RetryError
from google.cloud.alloydb.connector import AsyncConnector
from google.cloud.alloydb.connector import IPTypes
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
from google.cloud.alloydb.connector.instance import RefreshAheadCache

ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com"
ALLOYDB_API_ENDPOINT = "alloydb.googleapis.com"


@pytest.mark.asyncio
Expand Down Expand Up @@ -309,7 +309,7 @@ async def test_Connector_remove_cached_bad_instance(
"""
instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance"
async with AsyncConnector(credentials=credentials) as connector:
with pytest.raises(ClientResponseError):
with pytest.raises(RetryError):
await connector.connect(instance_uri, "asyncpg")
assert instance_uri not in connector._cache

Expand Down
Loading
Loading