Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: replace aiohttp.ClientSession with AlloyDBAdminAsyncClient #416

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.cloud.alloydb.connector.instance import RefreshAheadCache
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
from google.cloud.alloydb.connector.utils import generate_keys
import traceback

if TYPE_CHECKING:
from google.auth.credentials import Credentials
Expand Down Expand Up @@ -181,6 +182,7 @@ async def connect(
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
print(f"RISHABH DEBUG: exception = {traceback.print_exc()}")
# with an error from AlloyDB API call or IP type, invalidate the
# cache and re-raise the error
await self._remove_cached(instance_uri)
Expand Down
128 changes: 64 additions & 64 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
import logging
from typing import Optional, TYPE_CHECKING

import aiohttp
from cryptography import x509
from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud import alloydb_v1beta
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
from google.cloud.alloydb.connector.version import __version__ as version
from google.protobuf import duration_pb2

if TYPE_CHECKING:
from google.auth.credentials import Credentials
Expand Down Expand Up @@ -55,7 +58,7 @@ def __init__(
alloydb_api_endpoint: str,
quota_project: Optional[str],
credentials: Credentials,
client: Optional[aiohttp.ClientSession] = None,
client: Optional[alloydb_v1beta.AlloyDBAdminAsyncClient] = None,
driver: Optional[str] = None,
user_agent: Optional[str] = None,
) -> None:
Expand All @@ -72,21 +75,23 @@ def __init__(
A credentials object created from the google-auth Python library.
Must have the AlloyDB Admin scopes. For more info check out
https://google-auth.readthedocs.io/en/latest/.
client (aiohttp.ClientSession): Async client used to make requests to
AlloyDB APIs.
client (alloydb_v1.AlloyDBAdminAsyncClient): Async client used to
make requests to AlloyDB APIs.
Optional, defaults to None and creates new client.
driver (str): Database driver to be used by the client.
"""
user_agent = _format_user_agent(driver, user_agent)
headers = {
"x-goog-api-client": user_agent,
"User-Agent": user_agent,
"Content-Type": "application/json",
}
if quota_project:
headers["x-goog-user-project"] = quota_project

self._client = client if client else aiohttp.ClientSession(headers=headers)
self._client = client if client else alloydb_v1beta.AlloyDBAdminAsyncClient(
credentials=credentials,
client_options=ClientOptions(
api_endpoint=alloydb_api_endpoint,
quota_project_id=quota_project,
),
client_info=ClientInfo(
user_agent=user_agent,
),
)
self._credentials = credentials
self._alloydb_api_endpoint = alloydb_api_endpoint
# asyncpg does not currently support using metadata exchange
Expand Down Expand Up @@ -118,35 +123,33 @@ async def _get_metadata(
Returns:
dict: IP addresses of the AlloyDB instance.
"""
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}

url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}/connectionInfo"

resp = await self._client.get(url, headers=headers)
# try to get response json for better error message
try:
resp_dict = await resp.json()
if resp.status >= 400:
# if detailed error message is in json response, use as error message
message = resp_dict.get("error", {}).get("message")
if message:
resp.reason = message
# skip, raise_for_status will catch all errors in finally block
except Exception:
pass
finally:
resp.raise_for_status()
parent = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}"

req = alloydb_v1beta.GetConnectionInfoRequest(parent=parent)
resp = await self._client.get_connection_info(request=req)
resp = await resp
# # try to get response json for better error message
# try:
# resp_dict = await resp.json()
# if resp.status >= 400:
# # if detailed error message is in json response, use as error message
# message = resp_dict.get("error", {}).get("message")
# if message:
# resp.reason = message
# # skip, raise_for_status will catch all errors in finally block
# except Exception:
# pass
# finally:
# resp.raise_for_status()

# Remove trailing period from PSC DNS name.
psc_dns = resp_dict.get("pscDnsName")
psc_dns = resp.psc_dns_name
if psc_dns:
psc_dns = psc_dns.rstrip(".")

return {
"PRIVATE": resp_dict.get("ipAddress"),
"PUBLIC": resp_dict.get("publicIpAddress"),
"PRIVATE": resp.ip_address,
"PUBLIC": resp.public_ip_address,
"PSC": psc_dns,
}

Expand Down Expand Up @@ -175,34 +178,32 @@ async def _get_client_certificate(
tuple[str, list[str]]: tuple containing the CA certificate
and certificate chain for the AlloyDB instance.
"""
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}

url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"

data = {
"publicKey": pub_key,
"certDuration": "3600s",
"useMetadataExchange": self._use_metadata,
}

resp = await self._client.post(url, headers=headers, json=data)
# try to get response json for better error message
try:
resp_dict = await resp.json()
if resp.status >= 400:
# if detailed error message is in json response, use as error message
message = resp_dict.get("error", {}).get("message")
if message:
resp.reason = message
# skip, raise_for_status will catch all errors in finally block
except Exception:
pass
finally:
resp.raise_for_status()

return (resp_dict["caCert"], resp_dict["pemCertificateChain"])
parent = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}"
dur = duration_pb2.Duration()
dur.seconds = 3600
req = alloydb_v1beta.GenerateClientCertificateRequest(
parent=parent,
cert_duration=dur,
public_key=pub_key,
use_metadata_exchange=self._use_metadata,
)
resp = await self._client.generate_client_certificate(request=req)
resp = await resp
# # try to get response json for better error message
# try:
# resp_dict = await resp.json()
# if resp.status >= 400:
# # if detailed error message is in json response, use as error message
# message = resp_dict.get("error", {}).get("message")
# if message:
# resp.reason = message
# # skip, raise_for_status will catch all errors in finally block
# except Exception:
# pass
# finally:
# resp.raise_for_status()

return (resp.ca_cert, resp.pem_certificate_chain)

async def get_connection_info(
self,
Expand Down Expand Up @@ -271,5 +272,4 @@ async def get_connection_info(
async def close(self) -> None:
"""Close AlloyDBClient gracefully."""
logger.debug("Waiting for connector's http client to close")
await self._client.close()
logger.debug("Closed connector's http client")
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def cover(session):
def default(session, path):
# Install all test dependencies, then install this package in-place.
session.install("-r", "requirements-test.txt")
session.install("-e", ".")
session.install(".")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting the following error if using session.install("-e", "."):

ImportError while loading conftest '/usr/local/google/home/rhatgadkar/alloydb-python-connector/tests/unit/conftest.py'.
tests/unit/conftest.py:22: in <module>
    from mocks import FakeAlloyDBClient
tests/unit/mocks.py:33: in <module>
    from google.cloud.alloydb.connector.connection_info import ConnectionInfo
E   ModuleNotFoundError: No module named 'google.cloud.alloydb.connector'

Maybe there’s a conflict with the google-cloud-alloydb-connector package when installing the google-cloud-alloydb package. But when changing to session.install("."), this error doesn’t occur anymore.

session.install("-r", "requirements.txt")
# Run pytest with coverage.
session.run(
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
"requests",
"google-auth",
"protobuf",
"google-cloud-alloydb",
"google-api-core",
]

package_root = os.path.abspath(os.path.dirname(__file__))
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import asyncio
from typing import Union

from aiohttp import ClientResponseError
from mock import patch
from mocks import FakeAlloyDBClient
from mocks import FakeConnectionInfo
from mocks import FakeCredentials
import pytest

from google.api_core.exceptions import RetryError
from google.cloud.alloydb.connector import AsyncConnector
from google.cloud.alloydb.connector import IPTypes
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
Expand Down Expand Up @@ -309,7 +309,7 @@ async def test_Connector_remove_cached_bad_instance(
"""
instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance"
async with AsyncConnector(credentials=credentials) as connector:
with pytest.raises(ClientResponseError):
with pytest.raises(RetryError):
await connector.connect(instance_uri, "asyncpg")
assert instance_uri not in connector._cache

Expand Down
Loading
Loading