Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rhatgadkar-goog committed Jan 29, 2025
1 parent 51638a2 commit 44d084b
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 147 deletions.
2 changes: 1 addition & 1 deletion google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
alloydb_api_endpoint: str = "alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
Expand Down
36 changes: 3 additions & 33 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
A credentials object created from the google-auth Python library.
Must have the AlloyDB Admin scopes. For more info check out
https://google-auth.readthedocs.io/en/latest/.
client (alloydb_v1.AlloyDBAdminAsyncClient): Async client used to
client (alloydb_v1beta.AlloyDBAdminAsyncClient): Async client used to
make requests to AlloyDB APIs.
Optional, defaults to None and creates new client.
driver (str): Database driver to be used by the client.
Expand All @@ -85,19 +85,17 @@ def __init__(
self._client = client if client else alloydb_v1beta.AlloyDBAdminAsyncClient(
credentials=credentials,
client_options=ClientOptions(
api_endpoint="alloydb.googleapis.com",
api_endpoint=alloydb_api_endpoint,
quota_project_id=quota_project,
),
client_info=ClientInfo(
user_agent=user_agent,
),
)
self._credentials = credentials
self._alloydb_api_endpoint = alloydb_api_endpoint
# asyncpg does not currently support using metadata exchange
# only use metadata exchange for pg8000 driver
self._use_metadata = True if driver == "pg8000" else False
self._user_agent = user_agent

async def _get_metadata(
self,
Expand Down Expand Up @@ -127,19 +125,6 @@ async def _get_metadata(

req = alloydb_v1beta.GetConnectionInfoRequest(parent=parent)
resp = await self._client.get_connection_info(request=req)
# # try to get response json for better error message
# try:
# resp_dict = await resp.json()
# if resp.status >= 400:
# # if detailed error message is in json response, use as error message
# message = resp_dict.get("error", {}).get("message")
# if message:
# resp.reason = message
# # skip, raise_for_status will catch all errors in finally block
# except Exception:
# pass
# finally:
# resp.raise_for_status()

# Remove trailing period from PSC DNS name.
psc_dns = resp.psc_dns_name
Expand Down Expand Up @@ -187,20 +172,6 @@ async def _get_client_certificate(
use_metadata_exchange=self._use_metadata,
)
resp = await self._client.generate_client_certificate(request=req)
# # try to get response json for better error message
# try:
# resp_dict = await resp.json()
# if resp.status >= 400:
# # if detailed error message is in json response, use as error message
# message = resp_dict.get("error", {}).get("message")
# if message:
# resp.reason = message
# # skip, raise_for_status will catch all errors in finally block
# except Exception:
# pass
# finally:
# resp.raise_for_status()

return (resp.ca_cert, resp.pem_certificate_chain)

async def get_connection_info(
Expand Down Expand Up @@ -269,5 +240,4 @@ async def get_connection_info(

async def close(self) -> None:
"""Close AlloyDBClient gracefully."""
logger.debug("Waiting for connector's http client to close")
logger.debug("Closed connector's http client")
logger.debug("Closed AlloyDBClient")
4 changes: 2 additions & 2 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Connector:
billing purposes.
Defaults to None, picking up project from environment.
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
Expand All @@ -77,7 +77,7 @@ def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
alloydb_api_endpoint: str = "alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud import alloydb_v1beta
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

Expand Down Expand Up @@ -378,3 +379,34 @@ async def force_refresh(self) -> None:

async def close(self) -> None:
self._close_called = True


class FakeAlloyDBAdminAsyncClient:
async def get_connection_info(self, request: alloydb_v1beta.GetConnectionInfoRequest) -> alloydb_v1beta.types.resources.ConnectionInfo:
ci = alloydb_v1beta.types.resources.ConnectionInfo()
ci.ip_address = "10.0.0.1"
ci.public_ip_address = "127.0.0.1"
ci.instance_uid = "123456789"
ci.psc_dns_name = "x.y.alloydb.goog"

parent = request.parent
instance = parent.split("/")[-1]
if instance == "test-instance":
ci.public_ip_address = ""
ci.psc_dns_name = ""
return ci
elif instance == "public-instance":
ci.psc_dns_name = ""
return ci
else:
ci.ip_address = ""
ci.public_ip_address = ""
return ci

async def generate_client_certificate(self, request: alloydb_v1beta.GenerateClientCertificateRequest) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse:
ccr = alloydb_v1beta.types.service.GenerateClientCertificateResponse()
ccr.ca_cert = "This is the CA cert"
ccr.pem_certificate_chain.append("This is the client cert")
ccr.pem_certificate_chain.append("This is the intermediate cert")
ccr.pem_certificate_chain.append("This is the root cert")
return ccr
2 changes: 1 addition & 1 deletion tests/unit/test_async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
from google.cloud.alloydb.connector.instance import RefreshAheadCache

ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com"
ALLOYDB_API_ENDPOINT = "alloydb.googleapis.com"


@pytest.mark.asyncio
Expand Down
131 changes: 24 additions & 107 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,73 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Any, Optional

from aiohttp import ClientResponseError
from aiohttp import web
from aioresponses import aioresponses
from mocks import FakeCredentials
from mocks import FakeAlloyDBAdminAsyncClient, FakeCredentials
import pytest

from google.api_core.exceptions import RetryError
from google.cloud import alloydb_v1beta
from google.cloud.alloydb.connector.client import AlloyDBClient
from google.cloud.alloydb.connector.utils import generate_keys
from google.cloud.alloydb.connector.version import __version__ as version


async def connectionInfo(request: Any) -> alloydb_v1beta.types.resources.ConnectionInfo:
ci = alloydb_v1beta.types.resources.ConnectionInfo()
ci.ip_address = "10.0.0.1"
ci.instance_uid = "123456789"
return ci


async def connectionInfoPublicIP(request: Any) -> alloydb_v1beta.types.resources.ConnectionInfo:
ci = alloydb_v1beta.types.resources.ConnectionInfo()
ci.ip_address = "10.0.0.1"
ci.public_ip_address = "127.0.0.1"
ci.instance_uid = "123456789"
return ci


async def connectionInfoPsc(request: Any) -> alloydb_v1beta.types.resources.ConnectionInfo:
ci = alloydb_v1beta.types.resources.ConnectionInfo()
ci.psc_dns_name = "x.y.alloydb.goog"
ci.instance_uid = "123456789"
return ci


async def generateClientCertificate(request: Any) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse:
ccr = alloydb_v1beta.types.service.GenerateClientCertificateResponse()
ccr.ca_cert = "This is the CA cert"
ccr.pem_certificate_chain.append("This is the client cert")
ccr.pem_certificate_chain.append("This is the intermediate cert")
ccr.pem_certificate_chain.append("This is the root cert")
return ccr


class MockAlloyDBAdminAsyncClient:
async def get_connection_info(self, request: alloydb_v1beta.GetConnectionInfoRequest) -> alloydb_v1beta.types.resources.ConnectionInfo:
parent = request.parent
instance = parent.split("/")[-1]
if instance == "test-instance":
return connectionInfo(request)
elif instance == "public-instance":
return connectionInfoPublicIP(request)
else:
return connectionInfoPsc(request)

async def generate_client_certificate(self, request: alloydb_v1beta.GenerateClientCertificateRequest) -> web.Response:
return generateClientCertificate(request)


@pytest.mark.asyncio
async def test__get_metadata(credentials: FakeCredentials) -> None:
"""
Test _get_metadata returns successfully.
"""
test_client = AlloyDBClient("", "", credentials, MockAlloyDBAdminAsyncClient())
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
ip_addrs = await test_client._get_metadata(
"test-project",
"test-region",
Expand All @@ -99,7 +52,7 @@ async def test__get_metadata_with_public_ip(
"""
Test _get_metadata returns successfully with Public IP.
"""
test_client = AlloyDBClient("", "", credentials, MockAlloyDBAdminAsyncClient())
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
ip_addrs = await test_client._get_metadata(
"test-project",
"test-region",
Expand All @@ -120,7 +73,7 @@ async def test__get_metadata_with_psc(
"""
Test _get_metadata returns successfully with PSC DNS name.
"""
test_client = AlloyDBClient("", "", credentials, MockAlloyDBAdminAsyncClient())
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
ip_addrs = await test_client._get_metadata(
"test-project",
"test-region",
Expand All @@ -140,34 +93,14 @@ async def test__get_metadata_error(
"""
Test that AlloyDB API error messages are raised for _get_metadata.
"""
# mock AlloyDB API calls with exceptions
client = AlloyDBClient(
alloydb_api_endpoint="https://alloydb.googleapis.com",
alloydb_api_endpoint="alloydb.googleapis.com",
quota_project=None,
credentials=credentials,
)
get_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance/connectionInfo"
resp_body = {
"error": {
"code": 403,
"message": "AlloyDB API has not been used in project 123456789 before or it is disabled",
}
}
with aioresponses() as mocked:
mocked.get(
get_url,
status=403,
payload=resp_body,
repeat=True,
)
with pytest.raises(ClientResponseError) as exc_info:
await client._get_metadata(
"my-project", "my-region", "my-cluster", "my-instance"
)
assert exc_info.value.status == 403
assert (
exc_info.value.message
== "AlloyDB API has not been used in project 123456789 before or it is disabled"
with pytest.raises(RetryError) as exc_info:
await client._get_metadata(
"my-project", "my-region", "my-cluster", "my-instance"
)
await client.close()

Expand All @@ -179,7 +112,7 @@ async def test__get_client_certificate(
"""
Test _get_client_certificate returns successfully.
"""
test_client = AlloyDBClient("", "", credentials, MockAlloyDBAdminAsyncClient())
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
keys = await generate_keys()
certs = await test_client._get_client_certificate(
"test-project", "test-region", "test-cluster", keys[1]
Expand All @@ -197,32 +130,16 @@ async def test__get_client_certificate_error(
"""
Test that AlloyDB API error messages are raised for _get_client_certificate.
"""
# mock AlloyDB API calls with exceptions
client = AlloyDBClient(
alloydb_api_endpoint="https://alloydb.googleapis.com",
alloydb_api_endpoint="alloydb.googleapis.com",
quota_project=None,
credentials=credentials,
)
post_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster:generateClientCertificate"
resp_body = {
"error": {
"code": 404,
"message": "The AlloyDB instance does not exist.",
}
}
with aioresponses() as mocked:
mocked.post(
post_url,
status=404,
payload=resp_body,
repeat=True,
with pytest.raises(RetryError) as exc_info:
await client._get_client_certificate(
"my-project", "my-region", "my-cluster", ""
)
with pytest.raises(ClientResponseError) as exc_info:
await client._get_client_certificate(
"my-project", "my-region", "my-cluster", ""
)
assert exc_info.value.status == 404
assert exc_info.value.message == "The AlloyDB instance does not exist."
print(exc_info)
await client.close()


Expand All @@ -234,10 +151,11 @@ async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None:
"""
client = AlloyDBClient("www.test-endpoint.com", "my-quota-project", credentials)
# verify base endpoint is set
assert client._alloydb_api_endpoint == "www.test-endpoint.com"
assert client._client.api_endpoint == "www.test-endpoint.com"
# verify proper headers are set
assert client._client.headers["User-Agent"] == f"alloydb-python-connector/{version}"
assert client._client.headers["x-goog-user-project"] == "my-quota-project"
got_user_agent = client._client.transport._wrapped_methods[client._client.transport.list_clusters]._metadata[0][1]
assert got_user_agent.startswith(f"alloydb-python-connector/{version}")
assert client._client._client._client_options.quota_project_id == "my-quota-project"
# close client
await client.close()

Expand All @@ -255,10 +173,8 @@ async def test_AlloyDBClient_init_custom_user_agent(
credentials,
user_agent="custom-agent/v1.0.0 other-agent/v2.0.0",
)
assert (
client._client.headers["User-Agent"]
== f"alloydb-python-connector/{version} custom-agent/v1.0.0 other-agent/v2.0.0"
)
got_user_agent = client._client.transport._wrapped_methods[client._client.transport.list_clusters]._metadata[0][1]
assert got_user_agent.startswith(f"alloydb-python-connector/{version} custom-agent/v1.0.0 other-agent/v2.0.0")
await client.close()


Expand All @@ -277,10 +193,11 @@ async def test_AlloyDBClient_user_agent(
client = AlloyDBClient(
"www.test-endpoint.com", "my-quota-project", credentials, driver=driver
)
got_user_agent = client._client.transport._wrapped_methods[client._client.transport.list_clusters]._metadata[0][1]
if driver is None:
assert client._user_agent == f"alloydb-python-connector/{version}"
assert got_user_agent.startswith(f"alloydb-python-connector/{version}")
else:
assert client._user_agent == f"alloydb-python-connector/{version}+{driver}"
assert got_user_agent.startswith(f"alloydb-python-connector/{version}+{driver}")
# close client
await client.close()

Expand Down
Loading

0 comments on commit 44d084b

Please sign in to comment.