Skip to content

Commit

Permalink
[MISC] Migrate typing syntax to Python 3.10+ (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
sinclert-canonical authored Jan 13, 2025
1 parent 68fb053 commit da22c0f
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 152 deletions.
20 changes: 11 additions & 9 deletions src/backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import time
from datetime import datetime, timezone
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import boto3 as boto3
import botocore
Expand Down Expand Up @@ -89,7 +88,7 @@ def _tls_ca_chain_filename(self) -> str:
return f"{self.charm._storage_path}/pgbackrest-tls-ca-chain.crt"
return ""

def _are_backup_settings_ok(self) -> Tuple[bool, Optional[str]]:
def _are_backup_settings_ok(self) -> tuple[bool, str | None]:
"""Validates whether backup settings are OK."""
if self.model.get_relation(self.relation_name) is None:
return (
Expand Down Expand Up @@ -120,7 +119,7 @@ def _can_initialise_stanza(self) -> bool:
)
)

def _can_unit_perform_backup(self) -> Tuple[bool, Optional[str]]:
def _can_unit_perform_backup(self) -> tuple[bool, str | None]:
"""Validates whether this unit can perform a backup."""
if self.charm.is_blocked:
return False, "Unit is in a blocking state"
Expand Down Expand Up @@ -152,7 +151,7 @@ def _can_unit_perform_backup(self) -> Tuple[bool, Optional[str]]:

return self._are_backup_settings_ok()

def can_use_s3_repository(self) -> Tuple[bool, Optional[str]]:
def can_use_s3_repository(self) -> tuple[bool, str | None]:
"""Returns whether the charm was configured to use another cluster repository."""
# Check model uuid
s3_parameters, _ = self._retrieve_s3_parameters()
Expand Down Expand Up @@ -196,7 +195,7 @@ def can_use_s3_repository(self) -> Tuple[bool, Optional[str]]:
return False, ANOTHER_CLUSTER_REPOSITORY_ERROR_MESSAGE
return True, None

def _construct_endpoint(self, s3_parameters: Dict) -> str:
def _construct_endpoint(self, s3_parameters: dict) -> str:
"""Construct the S3 service endpoint using the region.
This is needed when the provided endpoint is from AWS, and it doesn't contain the region.
Expand Down Expand Up @@ -278,8 +277,11 @@ def _change_connectivity_to_database(self, connectivity: bool) -> None:
self.charm.update_config(is_creating_backup=True)

def _execute_command(
self, command: List[str], timeout: Optional[float] = None, stream: bool = False
) -> Tuple[Optional[str], Optional[str]]:
self,
command: list[str],
timeout: float | None = None,
stream: bool = False,
) -> tuple[str | None, str | None]:
"""Execute a command in the workload container."""
try:
logger.debug("Running command %s", " ".join(command))
Expand Down Expand Up @@ -504,7 +506,7 @@ def _parse_psql_timestamp(self, timestamp: str) -> datetime:
dt = dt.astimezone(tz=timezone.utc)
return dt.replace(tzinfo=None)

def _parse_backup_id(self, label) -> Tuple[str, str]:
def _parse_backup_id(self, label) -> tuple[str, str]:
"""Parse backup ID as a timestamp and its type."""
if label[-1] == "F":
timestamp = label
Expand Down Expand Up @@ -1221,7 +1223,7 @@ def _restart_database(self) -> None:
self.charm.update_config()
self.container.start(self.charm._postgresql_service)

def _retrieve_s3_parameters(self) -> Tuple[Dict, List[str]]:
def _retrieve_s3_parameters(self) -> tuple[dict, list[str]]:
"""Retrieve S3 parameters from the S3 integrator relation."""
s3_parameters = self.s3_client.get_s3_connection_info()
required_parameters = [
Expand Down
46 changes: 23 additions & 23 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
import time
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, get_args
from typing import Literal, get_args

# First platform-specific import, will fail on wrong architecture
try:
Expand Down Expand Up @@ -254,7 +254,7 @@ def __init__(self, *args):
)

@property
def tracing_endpoint(self) -> Optional[str]:
def tracing_endpoint(self) -> str | None:
"""Otlp http endpoint for charm instrumentation."""
if self.tracing.is_ready():
return self.tracing.get_endpoint(TRACING_PROTOCOL)
Expand All @@ -267,7 +267,7 @@ def _pebble_log_forwarding_supported(self) -> bool:
juju_version = JujuVersion.from_environ()
return juju_version > JujuVersion(version="3.3")

def _generate_metrics_jobs(self, enable_tls: bool) -> Dict:
def _generate_metrics_jobs(self, enable_tls: bool) -> dict:
"""Generate spec for Prometheus scraping."""
return [
{"static_configs": [{"targets": [f"*:{METRICS_PORT}"]}]},
Expand All @@ -287,7 +287,7 @@ def app_units(self) -> set[Unit]:
return {self.unit, *self._peers.units}

@property
def app_peer_data(self) -> Dict:
def app_peer_data(self) -> dict:
"""Application peer relation data object."""
relation = self.model.get_relation(PEER)
if relation is None:
Expand All @@ -296,15 +296,15 @@ def app_peer_data(self) -> Dict:
return relation.data[self.app]

@property
def unit_peer_data(self) -> Dict:
def unit_peer_data(self) -> dict:
"""Unit peer relation data object."""
relation = self.model.get_relation(PEER)
if relation is None:
return {}

return relation.data[self.unit]

def _peer_data(self, scope: Scopes) -> Dict:
def _peer_data(self, scope: Scopes) -> dict:
"""Return corresponding databag for app/unit."""
relation = self.model.get_relation(PEER)
if relation is None:
Expand Down Expand Up @@ -333,7 +333,7 @@ def _translate_field_to_secret_key(self, key: str) -> str:
new_key = key.replace("_", "-")
return new_key.strip("-")

def get_secret(self, scope: Scopes, key: str) -> Optional[str]:
def get_secret(self, scope: Scopes, key: str) -> str | None:
"""Get secret from the secret storage."""
if scope not in get_args(Scopes):
raise RuntimeError("Unknown secret scope.")
Expand All @@ -348,7 +348,7 @@ def get_secret(self, scope: Scopes, key: str) -> Optional[str]:

return self.peer_relation_data(scope).get_secret(peers.id, secret_key)

def set_secret(self, scope: Scopes, key: str, value: Optional[str]) -> Optional[str]:
def set_secret(self, scope: Scopes, key: str, value: str | None) -> str | None:
"""Set secret from the secret storage."""
if scope not in get_args(Scopes):
raise RuntimeError("Unknown secret scope.")
Expand Down Expand Up @@ -426,14 +426,14 @@ def get_hostname_by_unit(self, unit_name: str) -> str:
unit_id = unit_name.split("/")[1]
return f"{self.app.name}-{unit_id}.{self.app.name}-endpoints"

def _get_endpoints_to_remove(self) -> List[str]:
def _get_endpoints_to_remove(self) -> list[str]:
"""List the endpoints that were part of the cluster but departed."""
old = self._endpoints
current = [self._get_hostname_from_unit(member) for member in self._hosts]
endpoints_to_remove = list(set(old) - set(current))
return endpoints_to_remove

def get_unit_ip(self, unit: Unit) -> Optional[str]:
def get_unit_ip(self, unit: Unit) -> str | None:
"""Get the IP address of a specific unit."""
# Check if host is current host.
if unit == self.unit:
Expand Down Expand Up @@ -683,7 +683,7 @@ def _on_config_changed(self, event) -> None:
)
return

def enable_disable_extensions(self, database: Optional[str] = None) -> None:
def enable_disable_extensions(self, database: str | None = None) -> None:
"""Enable/disable PostgreSQL extensions set through config options.
Args:
Expand Down Expand Up @@ -1593,7 +1593,7 @@ def _endpoint(self) -> str:
return self._get_hostname_from_unit(self._unit_name_to_pod_name(self.unit.name))

@property
def _endpoints(self) -> List[str]:
def _endpoints(self) -> list[str]:
"""Cluster members hostnames."""
if self._peers:
return json.loads(self._peers.data[self.app].get("endpoints", "[]"))
Expand All @@ -1602,7 +1602,7 @@ def _endpoints(self) -> List[str]:
return [self._endpoint]

@property
def peer_members_endpoints(self) -> List[str]:
def peer_members_endpoints(self) -> list[str]:
"""Fetch current list of peer members endpoints.
Returns:
Expand All @@ -1621,14 +1621,14 @@ def _add_to_endpoints(self, endpoint) -> None:
"""Add one endpoint to the members list."""
self._update_endpoints(endpoint_to_add=endpoint)

def _remove_from_endpoints(self, endpoints: List[str]) -> None:
def _remove_from_endpoints(self, endpoints: list[str]) -> None:
"""Remove endpoints from the members list."""
self._update_endpoints(endpoints_to_remove=endpoints)

def _update_endpoints(
self,
endpoint_to_add: Optional[str] = None,
endpoints_to_remove: Optional[List[str]] = None,
endpoint_to_add: str | None = None,
endpoints_to_remove: list[str] | None = None,
) -> None:
"""Update members IPs."""
# Allow leader to reset which members are part of the cluster.
Expand All @@ -1643,7 +1643,7 @@ def _update_endpoints(
endpoints.remove(endpoint)
self._peers.data[self.app]["endpoints"] = json.dumps(endpoints)

def _generate_metrics_service(self) -> Dict:
def _generate_metrics_service(self) -> dict:
"""Generate the metrics service definition."""
return {
"override": "replace",
Expand Down Expand Up @@ -2012,7 +2012,7 @@ def _get_node_name_for_pod(self) -> str:
)
return pod.spec.nodeName

def get_resources_limits(self, container_name: str) -> Dict:
def get_resources_limits(self, container_name: str) -> dict:
"""Return resources limits for a given container.
Args:
Expand Down Expand Up @@ -2040,7 +2040,7 @@ def get_node_cpu_cores(self) -> int:
node = client.get(Node, name=self._get_node_name_for_pod(), namespace=self._namespace)
return any_cpu_to_cores(node.status.allocatable["cpu"])

def get_available_resources(self) -> Tuple[int, int]:
def get_available_resources(self) -> tuple[int, int]:
"""Get available CPU cores and memory (in bytes) for the container."""
cpu_cores = self.get_node_cpu_cores()
allocable_memory = self.get_node_allocable_memory()
Expand Down Expand Up @@ -2073,7 +2073,7 @@ def on_deployed_without_trust(self) -> None:
)

@property
def client_relations(self) -> List[Relation]:
def client_relations(self) -> list[Relation]:
"""Return the list of established client relations."""
relations = []
for relation_name in ["database", "db", "db-admin"]:
Expand Down Expand Up @@ -2150,15 +2150,15 @@ def restore_patroni_on_failure_condition(self) -> None:
else:
logger.warning("not restoring patroni on-failure condition as it's not overridden")

def is_pitr_failed(self, container: Container) -> Tuple[bool, bool]:
def is_pitr_failed(self, container: Container) -> tuple[bool, bool]:
"""Check if Patroni service failed to bootstrap cluster during point-in-time-recovery.
Typically, this means that database service failed to reach point-in-time-recovery target or has been
supplied with bad PITR parameter. Also, remembers last state and can provide info is it new event, or
it belongs to previous action. Executes only on current unit.
Returns:
Tuple[bool, bool]:
tuple[bool, bool]:
- Is patroni service failed to bootstrap cluster.
- Is it new fail, that wasn't observed previously.
"""
Expand Down Expand Up @@ -2228,7 +2228,7 @@ def log_pitr_last_transaction_time(self) -> None:
else:
logger.error("Can't tell last completed transaction time")

def get_plugins(self) -> List[str]:
def get_plugins(self) -> list[str]:
"""Return a list of installed plugins."""
plugins = [
"_".join(plugin.split("_")[1:-1])
Expand Down
Loading

0 comments on commit da22c0f

Please sign in to comment.