Skip to content

Commit

Permalink
chore: update type hints to standard collections (#410)
Browse files Browse the repository at this point in the history
This change is analogous to CloudSQL's change:
GoogleCloudPlatform/cloud-sql-python-connector#1183
  • Loading branch information
rhatgadkar-goog authored Jan 9, 2025
1 parent ea2a524 commit d6ec5ad
Show file tree
Hide file tree
Showing 16 changed files with 37 additions and 47 deletions.
6 changes: 3 additions & 3 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import asyncio
import logging
from types import TracebackType
from typing import Any, Dict, Optional, Type, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import google.auth
from google.auth.credentials import with_scopes_if_required
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
user_agent: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
) -> None:
self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
Expand Down Expand Up @@ -223,7 +223,7 @@ async def __aenter__(self) -> Any:

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import asyncio
import logging
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING

import aiohttp
from cryptography import x509
Expand Down Expand Up @@ -100,7 +100,7 @@ async def _get_metadata(
region: str,
cluster: str,
name: str,
) -> Dict[str, Optional[str]]:
) -> dict[str, Optional[str]]:
"""
Fetch the metadata for a given AlloyDB instance.
Expand Down Expand Up @@ -156,7 +156,7 @@ async def _get_client_certificate(
region: str,
cluster: str,
pub_key: str,
) -> Tuple[str, List[str]]:
) -> tuple[str, list[str]]:
"""
Fetch a client certificate for the given AlloyDB cluster.
Expand All @@ -172,7 +172,7 @@ async def _get_client_certificate(
pub_key (str): PEM-encoded client public key.
Returns:
Tuple[str, list[str]]: Tuple containing the CA certificate
tuple[str, list[str]]: tuple containing the CA certificate
and certificate chain for the AlloyDB instance.
"""
headers = {
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/alloydb/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass
import logging
import ssl
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING

from aiofiles.tempfile import TemporaryDirectory

Expand All @@ -39,10 +39,10 @@ class ConnectionInfo:
"""Contains all necessary information to connect securely to the
server-side Proxy running on an AlloyDB instance."""

cert_chain: List[str]
cert_chain: list[str]
ca_cert: str
key: rsa.RSAPrivateKey
ip_addrs: Dict[str, Optional[str]]
ip_addrs: dict[str, Optional[str]]
expiration: datetime.datetime
context: Optional[ssl.SSLContext] = None

Expand Down
6 changes: 3 additions & 3 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import struct
from threading import Thread
from types import TracebackType
from typing import Any, Dict, Optional, Type, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union

from google.auth import default
from google.auth.credentials import TokenState
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread = Thread(target=self._loop.run_forever, daemon=True)
self._thread.start()
self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
Expand Down Expand Up @@ -355,7 +355,7 @@ def __enter__(self) -> "Connector":

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from datetime import timezone
import logging
import re
from typing import Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING

from google.cloud.alloydb.connector.connection_info import ConnectionInfo
from google.cloud.alloydb.connector.exceptions import RefreshError
Expand All @@ -40,7 +40,7 @@
)


def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]:
def _parse_instance_uri(instance_uri: str) -> tuple[str, str, str, str]:
# should take form "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>"
if INSTANCE_URI_REGEX.fullmatch(instance_uri) is None:
raise ValueError(
Expand Down Expand Up @@ -69,14 +69,14 @@ class RefreshAheadCache:
instance_uri (str): The instance URI of the AlloyDB instance.
ex. projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>
client (AlloyDBClient): Client used to make requests to AlloyDB APIs.
keys (Tuple[rsa.RSAPrivateKey, str]): Private and Public key pair.
keys (tuple[rsa.RSAPrivateKey, str]): Private and Public key pair.
"""

def __init__(
self,
instance_uri: str,
client: AlloyDBClient,
keys: asyncio.Future[Tuple[rsa.RSAPrivateKey, str]],
keys: asyncio.Future[tuple[rsa.RSAPrivateKey, str]],
) -> None:
# validate and parse instance_uri
self._project, self._region, self._cluster, self._name = _parse_instance_uri(
Expand Down
8 changes: 3 additions & 5 deletions google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@

from __future__ import annotations

from typing import List, Tuple

import aiofiles
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa


async def _write_to_file(
dir_path: str, ca_cert: str, cert_chain: List[str], key: rsa.RSAPrivateKey
) -> Tuple[str, str, str]:
dir_path: str, ca_cert: str, cert_chain: list[str], key: rsa.RSAPrivateKey
) -> tuple[str, str, str]:
"""
Helper function to write the server_ca, client certificate and
private key to .pem files in a given directory.
Expand All @@ -48,7 +46,7 @@ async def _write_to_file(
return (ca_filename, cert_chain_filename, key_filename)


async def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
async def generate_keys() -> tuple[rsa.RSAPrivateKey, str]:
priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
pub_key = (
priv_key.public_key()
Expand Down
4 changes: 2 additions & 2 deletions tests/system/test_asyncpg_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import os
from typing import Any, Tuple
from typing import Any

# [START alloydb_sqlalchemy_connect_async_connector]
import asyncpg
Expand All @@ -29,7 +29,7 @@ async def create_sqlalchemy_engine(
password: str,
db: str,
refresh_strategy: str = "background",
) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_asyncpg_iam_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from datetime import datetime
import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_async_connector_iam_authn]
import asyncpg
Expand All @@ -26,7 +25,7 @@

async def create_sqlalchemy_engine(
inst_uri: str, user: str, db: str, refresh_strategy: str = "background"
) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_asyncpg_psc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
from typing import Tuple

import asyncpg
import pytest
Expand All @@ -28,7 +27,7 @@ async def create_sqlalchemy_engine(
user: str,
password: str,
db: str,
) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_asyncpg_public_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_async_connector_public_ip]
import asyncpg
Expand All @@ -29,7 +28,7 @@ async def create_sqlalchemy_engine(
user: str,
password: str,
db: str,
) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_pg8000_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from datetime import datetime
import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_connector]
import pg8000
Expand All @@ -29,7 +28,7 @@ def create_sqlalchemy_engine(
password: str,
db: str,
refresh_strategy: str = "background",
) -> Tuple[sqlalchemy.engine.Engine, Connector]:
) -> tuple[sqlalchemy.engine.Engine, Connector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_pg8000_iam_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from datetime import datetime
import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_connector_iam_authn]
import pg8000
Expand All @@ -25,7 +24,7 @@

def create_sqlalchemy_engine(
inst_uri: str, user: str, db: str, refresh_strategy: str = "background"
) -> Tuple[sqlalchemy.engine.Engine, Connector]:
) -> tuple[sqlalchemy.engine.Engine, Connector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_pg8000_psc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from datetime import datetime
import os
from typing import Tuple

import pg8000
import sqlalchemy
Expand All @@ -27,7 +26,7 @@ def create_sqlalchemy_engine(
user: str,
password: str,
db: str,
) -> Tuple[sqlalchemy.engine.Engine, Connector]:
) -> tuple[sqlalchemy.engine.Engine, Connector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_pg8000_public_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from datetime import datetime
import os
from typing import Tuple

# [START alloydb_sqlalchemy_connect_connector_public_ip]
import pg8000
Expand All @@ -28,7 +27,7 @@ def create_sqlalchemy_engine(
user: str,
password: str,
db: str,
) -> Tuple[sqlalchemy.engine.Engine, Connector]:
) -> tuple[sqlalchemy.engine.Engine, Connector]:
"""Creates a connection pool for an AlloyDB instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
connector.
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import ipaddress
import ssl
import struct
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
from typing import Any, Callable, Literal, Optional

from cryptography import x509
from cryptography.hazmat.primitives import hashes
Expand Down Expand Up @@ -86,7 +86,7 @@ def token_state(

def generate_cert(
common_name: str, expires_in: int = 60, server_cert: bool = False
) -> Tuple[x509.CertificateBuilder, rsa.RSAPrivateKey]:
) -> tuple[x509.CertificateBuilder, rsa.RSAPrivateKey]:
"""
Generate a private key and cert object to be used in testing.
Expand All @@ -96,7 +96,7 @@ def generate_cert(
server_cert (bool): Whether it is a server certificate.
Returns:
Tuple[x509.CertificateBuilder, rsa.RSAPrivateKey]
tuple[x509.CertificateBuilder, rsa.RSAPrivateKey]
"""
# generate private key
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
Expand Down Expand Up @@ -146,7 +146,7 @@ def __init__(
region: str = "test-region",
cluster: str = "test-cluster",
name: str = "test-instance",
ip_addrs: Dict = {
ip_addrs: dict = {
"PRIVATE": "127.0.0.1",
"PUBLIC": "0.0.0.0",
"PSC": "x.y.alloydb.goog",
Expand Down Expand Up @@ -181,7 +181,7 @@ def __init__(
# create server cert signed by root cert
self.server_cert = self.server_cert.sign(self.root_key, hashes.SHA256())

def get_pem_certs(self) -> Tuple[str, str, str]:
def get_pem_certs(self) -> tuple[str, str, str]:
"""Helper method to get all certs in pem string format."""
pem_root = self.root_cert.public_bytes(
encoding=serialization.Encoding.PEM
Expand Down Expand Up @@ -215,7 +215,7 @@ async def _get_client_certificate(
region: str,
cluster: str,
pub_key: str,
) -> Tuple[str, List[str]]:
) -> tuple[str, list[str]]:
root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs()
# encode public key to bytes
pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key(
Expand Down Expand Up @@ -365,7 +365,7 @@ def connect_info(self) -> Any:
f.set_result(self)
return f

def get_preferred_ip(self, ip_type: Any) -> Tuple[str, Any]:
def get_preferred_ip(self, ip_type: Any) -> tuple[str, Any]:
f = asyncio.Future()
f.set_result("10.0.0.1")
return f
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import asyncio
from datetime import datetime
from datetime import timedelta
from typing import Tuple

import aiohttp
from mocks import FakeAlloyDBClient
Expand Down Expand Up @@ -48,7 +47,7 @@
],
)
def test_parse_instance_uri(
instance_uri: str, expected: Tuple[str, str, str, str]
instance_uri: str, expected: tuple[str, str, str, str]
) -> None:
"""
Test that _parse_instance_uri works correctly on
Expand Down

0 comments on commit d6ec5ad

Please sign in to comment.