diff --git a/src/backups.py b/src/backups.py index 5fb9a20533..eb75df42b1 100644 --- a/src/backups.py +++ b/src/backups.py @@ -11,7 +11,6 @@ import time from datetime import datetime, timezone from io import BytesIO -from typing import Dict, List, Optional, Tuple import boto3 as boto3 import botocore @@ -89,7 +88,7 @@ def _tls_ca_chain_filename(self) -> str: return f"{self.charm._storage_path}/pgbackrest-tls-ca-chain.crt" return "" - def _are_backup_settings_ok(self) -> Tuple[bool, Optional[str]]: + def _are_backup_settings_ok(self) -> tuple[bool, str | None]: """Validates whether backup settings are OK.""" if self.model.get_relation(self.relation_name) is None: return ( @@ -120,7 +119,7 @@ def _can_initialise_stanza(self) -> bool: ) ) - def _can_unit_perform_backup(self) -> Tuple[bool, Optional[str]]: + def _can_unit_perform_backup(self) -> tuple[bool, str | None]: """Validates whether this unit can perform a backup.""" if self.charm.is_blocked: return False, "Unit is in a blocking state" @@ -152,7 +151,7 @@ def _can_unit_perform_backup(self) -> Tuple[bool, Optional[str]]: return self._are_backup_settings_ok() - def can_use_s3_repository(self) -> Tuple[bool, Optional[str]]: + def can_use_s3_repository(self) -> tuple[bool, str | None]: """Returns whether the charm was configured to use another cluster repository.""" # Check model uuid s3_parameters, _ = self._retrieve_s3_parameters() @@ -196,7 +195,7 @@ def can_use_s3_repository(self) -> Tuple[bool, Optional[str]]: return False, ANOTHER_CLUSTER_REPOSITORY_ERROR_MESSAGE return True, None - def _construct_endpoint(self, s3_parameters: Dict) -> str: + def _construct_endpoint(self, s3_parameters: dict) -> str: """Construct the S3 service endpoint using the region. This is needed when the provided endpoint is from AWS, and it doesn't contain the region. @@ -278,8 +277,11 @@ def _change_connectivity_to_database(self, connectivity: bool) -> None: self.charm.update_config(is_creating_backup=True) def _execute_command( - self, command: List[str], timeout: Optional[float] = None, stream: bool = False - ) -> Tuple[Optional[str], Optional[str]]: + self, + command: list[str], + timeout: float | None = None, + stream: bool = False, + ) -> tuple[str | None, str | None]: """Execute a command in the workload container.""" try: logger.debug("Running command %s", " ".join(command)) @@ -504,7 +506,7 @@ def _parse_psql_timestamp(self, timestamp: str) -> datetime: dt = dt.astimezone(tz=timezone.utc) return dt.replace(tzinfo=None) - def _parse_backup_id(self, label) -> Tuple[str, str]: + def _parse_backup_id(self, label) -> tuple[str, str]: """Parse backup ID as a timestamp and its type.""" if label[-1] == "F": timestamp = label @@ -1221,7 +1223,7 @@ def _restart_database(self) -> None: self.charm.update_config() self.container.start(self.charm._postgresql_service) - def _retrieve_s3_parameters(self) -> Tuple[Dict, List[str]]: + def _retrieve_s3_parameters(self) -> tuple[dict, list[str]]: """Retrieve S3 parameters from the S3 integrator relation.""" s3_parameters = self.s3_client.get_s3_connection_info() required_parameters = [ diff --git a/src/charm.py b/src/charm.py index 6266d1d91f..d44b786e7c 100755 --- a/src/charm.py +++ b/src/charm.py @@ -13,7 +13,7 @@ import sys import time from pathlib import Path -from typing import Dict, List, Literal, Optional, Tuple, get_args +from typing import Literal, get_args # First platform-specific import, will fail on wrong architecture try: @@ -254,7 +254,7 @@ def __init__(self, *args): ) @property - def tracing_endpoint(self) -> Optional[str]: + def tracing_endpoint(self) -> str | None: """Otlp http endpoint for charm instrumentation.""" if self.tracing.is_ready(): return self.tracing.get_endpoint(TRACING_PROTOCOL) @@ -267,7 +267,7 @@ def _pebble_log_forwarding_supported(self) -> bool: juju_version = JujuVersion.from_environ() return juju_version > JujuVersion(version="3.3") - def _generate_metrics_jobs(self, enable_tls: bool) -> Dict: + def _generate_metrics_jobs(self, enable_tls: bool) -> dict: """Generate spec for Prometheus scraping.""" return [ {"static_configs": [{"targets": [f"*:{METRICS_PORT}"]}]}, @@ -287,7 +287,7 @@ def app_units(self) -> set[Unit]: return {self.unit, *self._peers.units} @property - def app_peer_data(self) -> Dict: + def app_peer_data(self) -> dict: """Application peer relation data object.""" relation = self.model.get_relation(PEER) if relation is None: @@ -296,7 +296,7 @@ def app_peer_data(self) -> Dict: return relation.data[self.app] @property - def unit_peer_data(self) -> Dict: + def unit_peer_data(self) -> dict: """Unit peer relation data object.""" relation = self.model.get_relation(PEER) if relation is None: @@ -304,7 +304,7 @@ def unit_peer_data(self) -> Dict: return relation.data[self.unit] - def _peer_data(self, scope: Scopes) -> Dict: + def _peer_data(self, scope: Scopes) -> dict: """Return corresponding databag for app/unit.""" relation = self.model.get_relation(PEER) if relation is None: @@ -333,7 +333,7 @@ def _translate_field_to_secret_key(self, key: str) -> str: new_key = key.replace("_", "-") return new_key.strip("-") - def get_secret(self, scope: Scopes, key: str) -> Optional[str]: + def get_secret(self, scope: Scopes, key: str) -> str | None: """Get secret from the secret storage.""" if scope not in get_args(Scopes): raise RuntimeError("Unknown secret scope.") @@ -348,7 +348,7 @@ def get_secret(self, scope: Scopes, key: str) -> Optional[str]: return self.peer_relation_data(scope).get_secret(peers.id, secret_key) - def set_secret(self, scope: Scopes, key: str, value: Optional[str]) -> Optional[str]: + def set_secret(self, scope: Scopes, key: str, value: str | None) -> str | None: """Set secret from the secret storage.""" if scope not in get_args(Scopes): raise RuntimeError("Unknown secret scope.") @@ -426,14 +426,14 @@ def get_hostname_by_unit(self, unit_name: str) -> str: unit_id = unit_name.split("/")[1] return f"{self.app.name}-{unit_id}.{self.app.name}-endpoints" - def _get_endpoints_to_remove(self) -> List[str]: + def _get_endpoints_to_remove(self) -> list[str]: """List the endpoints that were part of the cluster but departed.""" old = self._endpoints current = [self._get_hostname_from_unit(member) for member in self._hosts] endpoints_to_remove = list(set(old) - set(current)) return endpoints_to_remove - def get_unit_ip(self, unit: Unit) -> Optional[str]: + def get_unit_ip(self, unit: Unit) -> str | None: """Get the IP address of a specific unit.""" # Check if host is current host. if unit == self.unit: @@ -683,7 +683,7 @@ def _on_config_changed(self, event) -> None: ) return - def enable_disable_extensions(self, database: Optional[str] = None) -> None: + def enable_disable_extensions(self, database: str | None = None) -> None: """Enable/disable PostgreSQL extensions set through config options. Args: @@ -1593,7 +1593,7 @@ def _endpoint(self) -> str: return self._get_hostname_from_unit(self._unit_name_to_pod_name(self.unit.name)) @property - def _endpoints(self) -> List[str]: + def _endpoints(self) -> list[str]: """Cluster members hostnames.""" if self._peers: return json.loads(self._peers.data[self.app].get("endpoints", "[]")) @@ -1602,7 +1602,7 @@ def _endpoints(self) -> List[str]: return [self._endpoint] @property - def peer_members_endpoints(self) -> List[str]: + def peer_members_endpoints(self) -> list[str]: """Fetch current list of peer members endpoints. Returns: @@ -1621,14 +1621,14 @@ def _add_to_endpoints(self, endpoint) -> None: """Add one endpoint to the members list.""" self._update_endpoints(endpoint_to_add=endpoint) - def _remove_from_endpoints(self, endpoints: List[str]) -> None: + def _remove_from_endpoints(self, endpoints: list[str]) -> None: """Remove endpoints from the members list.""" self._update_endpoints(endpoints_to_remove=endpoints) def _update_endpoints( self, - endpoint_to_add: Optional[str] = None, - endpoints_to_remove: Optional[List[str]] = None, + endpoint_to_add: str | None = None, + endpoints_to_remove: list[str] | None = None, ) -> None: """Update members IPs.""" # Allow leader to reset which members are part of the cluster. @@ -1643,7 +1643,7 @@ def _update_endpoints( endpoints.remove(endpoint) self._peers.data[self.app]["endpoints"] = json.dumps(endpoints) - def _generate_metrics_service(self) -> Dict: + def _generate_metrics_service(self) -> dict: """Generate the metrics service definition.""" return { "override": "replace", @@ -2012,7 +2012,7 @@ def _get_node_name_for_pod(self) -> str: ) return pod.spec.nodeName - def get_resources_limits(self, container_name: str) -> Dict: + def get_resources_limits(self, container_name: str) -> dict: """Return resources limits for a given container. Args: @@ -2040,7 +2040,7 @@ def get_node_cpu_cores(self) -> int: node = client.get(Node, name=self._get_node_name_for_pod(), namespace=self._namespace) return any_cpu_to_cores(node.status.allocatable["cpu"]) - def get_available_resources(self) -> Tuple[int, int]: + def get_available_resources(self) -> tuple[int, int]: """Get available CPU cores and memory (in bytes) for the container.""" cpu_cores = self.get_node_cpu_cores() allocable_memory = self.get_node_allocable_memory() @@ -2073,7 +2073,7 @@ def on_deployed_without_trust(self) -> None: ) @property - def client_relations(self) -> List[Relation]: + def client_relations(self) -> list[Relation]: """Return the list of established client relations.""" relations = [] for relation_name in ["database", "db", "db-admin"]: @@ -2150,7 +2150,7 @@ def restore_patroni_on_failure_condition(self) -> None: else: logger.warning("not restoring patroni on-failure condition as it's not overridden") - def is_pitr_failed(self, container: Container) -> Tuple[bool, bool]: + def is_pitr_failed(self, container: Container) -> tuple[bool, bool]: """Check if Patroni service failed to bootstrap cluster during point-in-time-recovery. Typically, this means that database service failed to reach point-in-time-recovery target or has been @@ -2158,7 +2158,7 @@ def is_pitr_failed(self, container: Container) -> Tuple[bool, bool]: it belongs to previous action. Executes only on current unit. Returns: - Tuple[bool, bool]: + tuple[bool, bool]: - Is patroni service failed to bootstrap cluster. - Is it new fail, that wasn't observed previously. """ @@ -2228,7 +2228,7 @@ def log_pitr_last_transaction_time(self) -> None: else: logger.error("Can't tell last completed transaction time") - def get_plugins(self) -> List[str]: + def get_plugins(self) -> list[str]: """Return a list of installed plugins.""" plugins = [ "_".join(plugin.split("_")[1:-1]) diff --git a/src/config.py b/src/config.py index b5f41ec5b4..82e479b2f5 100644 --- a/src/config.py +++ b/src/config.py @@ -5,7 +5,6 @@ """Structured configuration for the PostgreSQL charm.""" import logging -from typing import Optional from charms.data_platform_libs.v0.data_models import BaseConfigModel from pydantic import validator @@ -16,24 +15,24 @@ class CharmConfig(BaseConfigModel): """Manager for the structured configuration.""" - durability_synchronous_commit: Optional[str] - instance_default_text_search_config: Optional[str] - instance_password_encryption: Optional[str] - logging_log_connections: Optional[bool] - logging_log_disconnections: Optional[bool] - logging_log_lock_waits: Optional[bool] - logging_log_min_duration_statement: Optional[int] - memory_maintenance_work_mem: Optional[int] - memory_max_prepared_transactions: Optional[int] - memory_shared_buffers: Optional[int] - memory_temp_buffers: Optional[int] - memory_work_mem: Optional[int] - optimizer_constraint_exclusion: Optional[str] - optimizer_default_statistics_target: Optional[int] - optimizer_from_collapse_limit: Optional[int] - optimizer_join_collapse_limit: Optional[int] + durability_synchronous_commit: str | None + instance_default_text_search_config: str | None + instance_password_encryption: str | None + logging_log_connections: bool | None + logging_log_disconnections: bool | None + logging_log_lock_waits: bool | None + logging_log_min_duration_statement: int | None + memory_maintenance_work_mem: int | None + memory_max_prepared_transactions: int | None + memory_shared_buffers: int | None + memory_temp_buffers: int | None + memory_work_mem: int | None + optimizer_constraint_exclusion: str | None + optimizer_default_statistics_target: int | None + optimizer_from_collapse_limit: int | None + optimizer_join_collapse_limit: int | None profile: str - profile_limit_memory: Optional[int] + profile_limit_memory: int | None plugin_audit_enable: bool plugin_citext_enable: bool plugin_debversion_enable: bool @@ -86,20 +85,20 @@ class CharmConfig(BaseConfigModel): plugin_postgis_raster_enable: bool plugin_vector_enable: bool plugin_timescaledb_enable: bool - request_date_style: Optional[str] - request_standard_conforming_strings: Optional[bool] - request_time_zone: Optional[str] - response_bytea_output: Optional[str] - response_lc_monetary: Optional[str] - response_lc_numeric: Optional[str] - response_lc_time: Optional[str] - vacuum_autovacuum_analyze_scale_factor: Optional[float] - vacuum_autovacuum_analyze_threshold: Optional[int] - vacuum_autovacuum_freeze_max_age: Optional[int] - vacuum_autovacuum_vacuum_cost_delay: Optional[float] - vacuum_autovacuum_vacuum_scale_factor: Optional[float] - vacuum_vacuum_freeze_table_age: Optional[int] - experimental_max_connections: Optional[int] + request_date_style: str | None + request_standard_conforming_strings: bool | None + request_time_zone: str | None + response_bytea_output: str | None + response_lc_monetary: str | None + response_lc_numeric: str | None + response_lc_time: str | None + vacuum_autovacuum_analyze_scale_factor: float | None + vacuum_autovacuum_analyze_threshold: int | None + vacuum_autovacuum_freeze_max_age: int | None + vacuum_autovacuum_vacuum_cost_delay: float | None + vacuum_autovacuum_vacuum_scale_factor: float | None + vacuum_vacuum_freeze_table_age: int | None + experimental_max_connections: int | None @classmethod def keys(cls) -> list[str]: @@ -113,7 +112,7 @@ def plugin_keys(cls) -> filter: @validator("durability_synchronous_commit") @classmethod - def durability_synchronous_commit_values(cls, value: str) -> Optional[str]: + def durability_synchronous_commit_values(cls, value: str) -> str | None: """Check durability_synchronous_commit config option is one of `on`, `remote_apply` or `remote_write`.""" if value not in ["on", "remote_apply", "remote_write"]: raise ValueError("Value not one of 'on', 'remote_apply' or 'remote_write'") @@ -122,7 +121,7 @@ def durability_synchronous_commit_values(cls, value: str) -> Optional[str]: @validator("instance_password_encryption") @classmethod - def instance_password_encryption_values(cls, value: str) -> Optional[str]: + def instance_password_encryption_values(cls, value: str) -> str | None: """Check instance_password_encryption config option is one of `md5` or `scram-sha-256`.""" if value not in ["md5", "scram-sha-256"]: raise ValueError("Value not one of 'md5' or 'scram-sha-256'") @@ -131,7 +130,7 @@ def instance_password_encryption_values(cls, value: str) -> Optional[str]: @validator("logging_log_min_duration_statement") @classmethod - def logging_log_min_duration_statement_values(cls, value: int) -> Optional[int]: + def logging_log_min_duration_statement_values(cls, value: int) -> int | None: """Check logging_log_min_duration_statement config option is between -1 and 2147483647.""" if value < -1 or value > 2147483647: raise ValueError("Value is not between -1 and 2147483647") @@ -140,7 +139,7 @@ def logging_log_min_duration_statement_values(cls, value: int) -> Optional[int]: @validator("memory_maintenance_work_mem") @classmethod - def memory_maintenance_work_mem_values(cls, value: int) -> Optional[int]: + def memory_maintenance_work_mem_values(cls, value: int) -> int | None: """Check memory_maintenance_work_mem config option is between 1024 and 2147483647.""" if value < 1024 or value > 2147483647: raise ValueError("Value is not between 1024 and 2147483647") @@ -149,7 +148,7 @@ def memory_maintenance_work_mem_values(cls, value: int) -> Optional[int]: @validator("memory_max_prepared_transactions") @classmethod - def memory_max_prepared_transactions_values(cls, value: int) -> Optional[int]: + def memory_max_prepared_transactions_values(cls, value: int) -> int | None: """Check memory_max_prepared_transactions config option is between 0 and 262143.""" if value < 0 or value > 262143: raise ValueError("Value is not between 0 and 262143") @@ -158,7 +157,7 @@ def memory_max_prepared_transactions_values(cls, value: int) -> Optional[int]: @validator("memory_shared_buffers") @classmethod - def memory_shared_buffers_values(cls, value: int) -> Optional[int]: + def memory_shared_buffers_values(cls, value: int) -> int | None: """Check memory_shared_buffers config option is greater or equal than 16.""" if value < 16 or value > 1073741823: raise ValueError("Shared buffers config option should be at least 16") @@ -167,7 +166,7 @@ def memory_shared_buffers_values(cls, value: int) -> Optional[int]: @validator("memory_temp_buffers") @classmethod - def memory_temp_buffers_values(cls, value: int) -> Optional[int]: + def memory_temp_buffers_values(cls, value: int) -> int | None: """Check memory_temp_buffers config option is between 100 and 1073741823.""" if value < 100 or value > 1073741823: raise ValueError("Value is not between 100 and 1073741823") @@ -176,7 +175,7 @@ def memory_temp_buffers_values(cls, value: int) -> Optional[int]: @validator("memory_work_mem") @classmethod - def memory_work_mem_values(cls, value: int) -> Optional[int]: + def memory_work_mem_values(cls, value: int) -> int | None: """Check memory_work_mem config option is between 64 and 2147483647.""" if value < 64 or value > 2147483647: raise ValueError("Value is not between 64 and 2147483647") @@ -185,7 +184,7 @@ def memory_work_mem_values(cls, value: int) -> Optional[int]: @validator("optimizer_constraint_exclusion") @classmethod - def optimizer_constraint_exclusion_values(cls, value: str) -> Optional[str]: + def optimizer_constraint_exclusion_values(cls, value: str) -> str | None: """Check optimizer_constraint_exclusion config option is one of `on`, `off` or `partition`.""" if value not in ["on", "off", "partition"]: raise ValueError("Value not one of 'on', 'off' or 'partition'") @@ -194,7 +193,7 @@ def optimizer_constraint_exclusion_values(cls, value: str) -> Optional[str]: @validator("optimizer_default_statistics_target") @classmethod - def optimizer_default_statistics_target_values(cls, value: int) -> Optional[int]: + def optimizer_default_statistics_target_values(cls, value: int) -> int | None: """Check optimizer_default_statistics_target config option is between 1 and 10000.""" if value < 1 or value > 10000: raise ValueError("Value is not between 1 and 10000") @@ -203,7 +202,7 @@ def optimizer_default_statistics_target_values(cls, value: int) -> Optional[int] @validator("optimizer_from_collapse_limit", "optimizer_join_collapse_limit") @classmethod - def optimizer_collapse_limit_values(cls, value: int) -> Optional[int]: + def optimizer_collapse_limit_values(cls, value: int) -> int | None: """Check optimizer collapse_limit config option is between 1 and 2147483647.""" if value < 1 or value > 2147483647: raise ValueError("Value is not between 1 and 2147483647") @@ -212,7 +211,7 @@ def optimizer_collapse_limit_values(cls, value: int) -> Optional[int]: @validator("profile") @classmethod - def profile_values(cls, value: str) -> Optional[str]: + def profile_values(cls, value: str) -> str | None: """Check profile config option is one of `testing` or `production`.""" if value not in ["testing", "production"]: raise ValueError("Value not one of 'testing' or 'production'") @@ -221,7 +220,7 @@ def profile_values(cls, value: str) -> Optional[str]: @validator("profile_limit_memory") @classmethod - def profile_limit_memory_validator(cls, value: int) -> Optional[int]: + def profile_limit_memory_validator(cls, value: int) -> int | None: """Check profile limit memory.""" if value < 128: raise ValueError("PostgreSQL Charm requires at least 128MB") @@ -232,7 +231,7 @@ def profile_limit_memory_validator(cls, value: int) -> Optional[int]: @validator("response_bytea_output") @classmethod - def response_bytea_output_values(cls, value: str) -> Optional[str]: + def response_bytea_output_values(cls, value: str) -> str | None: """Check response_bytea_output config option is one of `escape` or `hex`.""" if value not in ["escape", "hex"]: raise ValueError("Value not one of 'escape' or 'hex'") @@ -241,7 +240,7 @@ def response_bytea_output_values(cls, value: str) -> Optional[str]: @validator("vacuum_autovacuum_analyze_scale_factor", "vacuum_autovacuum_vacuum_scale_factor") @classmethod - def vacuum_autovacuum_vacuum_scale_factor_values(cls, value: float) -> Optional[float]: + def vacuum_autovacuum_vacuum_scale_factor_values(cls, value: float) -> float | None: """Check autovacuum scale_factor config option is between 0 and 100.""" if value < 0 or value > 100: raise ValueError("Value is not between 0 and 100") @@ -250,7 +249,7 @@ def vacuum_autovacuum_vacuum_scale_factor_values(cls, value: float) -> Optional[ @validator("vacuum_autovacuum_analyze_threshold") @classmethod - def vacuum_autovacuum_analyze_threshold_values(cls, value: int) -> Optional[int]: + def vacuum_autovacuum_analyze_threshold_values(cls, value: int) -> int | None: """Check vacuum_autovacuum_analyze_threshold config option is between 0 and 2147483647.""" if value < 0 or value > 2147483647: raise ValueError("Value is not between 0 and 2147483647") @@ -259,7 +258,7 @@ def vacuum_autovacuum_analyze_threshold_values(cls, value: int) -> Optional[int] @validator("vacuum_autovacuum_freeze_max_age") @classmethod - def vacuum_autovacuum_freeze_max_age_values(cls, value: int) -> Optional[int]: + def vacuum_autovacuum_freeze_max_age_values(cls, value: int) -> int | None: """Check vacuum_autovacuum_freeze_max_age config option is between 100000 and 2000000000.""" if value < 100000 or value > 2000000000: raise ValueError("Value is not between 100000 and 2000000000") @@ -268,7 +267,7 @@ def vacuum_autovacuum_freeze_max_age_values(cls, value: int) -> Optional[int]: @validator("vacuum_autovacuum_vacuum_cost_delay") @classmethod - def vacuum_autovacuum_vacuum_cost_delay_values(cls, value: float) -> Optional[float]: + def vacuum_autovacuum_vacuum_cost_delay_values(cls, value: float) -> float | None: """Check vacuum_autovacuum_vacuum_cost_delay config option is between -1 and 100.""" if value < -1 or value > 100: raise ValueError("Value is not between -1 and 100") @@ -277,7 +276,7 @@ def vacuum_autovacuum_vacuum_cost_delay_values(cls, value: float) -> Optional[fl @validator("vacuum_vacuum_freeze_table_age") @classmethod - def vacuum_vacuum_freeze_table_age_values(cls, value: int) -> Optional[int]: + def vacuum_vacuum_freeze_table_age_values(cls, value: int) -> int | None: """Check vacuum_vacuum_freeze_table_age config option is between 0 and 2000000000.""" if value < 0 or value > 2000000000: raise ValueError("Value is not between 0 and 2000000000") diff --git a/src/patroni.py b/src/patroni.py index d219b6e593..148c77f865 100644 --- a/src/patroni.py +++ b/src/patroni.py @@ -7,7 +7,7 @@ import logging import os import pwd -from typing import Any, Dict, List, Optional +from typing import Any import requests import yaml @@ -60,7 +60,7 @@ def __init__( self, charm, endpoint: str, - endpoints: List[str], + endpoints: list[str], primary_endpoint: str, namespace: str, storage_path: str, @@ -97,7 +97,7 @@ def _patroni_url(self) -> str: return f"{'https' if self._tls_enabled else 'http'}://{self._endpoint}:8008" @property - def rock_postgresql_version(self) -> Optional[str]: + def rock_postgresql_version(self) -> str | None: """Version of Postgresql installed in the Rock image.""" container = self._charm.unit.get_container("postgresql") if not container.can_connect(): @@ -107,7 +107,7 @@ def rock_postgresql_version(self) -> Optional[str]: return yaml.safe_load(snap_meta)["version"] def _get_alternative_patroni_url( - self, attempt: AttemptManager, alternative_endpoints: Optional[List[str]] = None + self, attempt: AttemptManager, alternative_endpoints: list[str] | None = None ) -> str: """Get an alternative REST API URL from another member each time. @@ -127,7 +127,7 @@ def _get_alternative_patroni_url( return url def get_primary( - self, unit_name_pattern=False, alternative_endpoints: Optional[List[str]] = None + self, unit_name_pattern=False, alternative_endpoints: list[str] | None = None ) -> str: """Get primary instance. @@ -157,7 +157,7 @@ def get_primary( def get_standby_leader( self, unit_name_pattern=False, check_whether_is_running: bool = False - ) -> Optional[str]: + ) -> str | None: """Get standby leader instance. Args: @@ -190,7 +190,7 @@ def get_standby_leader( break return standby_leader - def get_sync_standby_names(self) -> List[str]: + def get_sync_standby_names(self) -> list[str]: """Get the list of sync standby unit names.""" sync_standbys = [] # Request info from cluster endpoint (which returns all members of the cluster). @@ -400,7 +400,7 @@ def is_database_running(self) -> bool: return any(process for process in postgresql_processes if process.split()[7] != "T") @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10)) - def bulk_update_parameters_controller_by_patroni(self, parameters: Dict[str, Any]) -> None: + def bulk_update_parameters_controller_by_patroni(self, parameters: dict[str, Any]) -> None: """Update the value of a parameter controller by Patroni. For more information, check https://patroni.readthedocs.io/en/latest/patroni_configuration.html#postgresql-parameters-controlled-by-patroni. @@ -473,14 +473,14 @@ def render_patroni_yml_file( is_creating_backup: bool = False, enable_tls: bool = False, is_no_sync_member: bool = False, - stanza: Optional[str] = None, - restore_stanza: Optional[str] = None, + stanza: str | None = None, + restore_stanza: str | None = None, disable_pgbackrest_archiving: bool = False, - backup_id: Optional[str] = None, - pitr_target: Optional[str] = None, - restore_timeline: Optional[str] = None, + backup_id: str | None = None, + pitr_target: str | None = None, + restore_timeline: str | None = None, restore_to_latest: bool = False, - parameters: Optional[dict[str, str]] = None, + parameters: dict[str, str] | None = None, ) -> None: """Render the Patroni configuration file. @@ -578,7 +578,7 @@ def restart_postgresql(self) -> None: timeout=PATRONI_TIMEOUT, ) - def switchover(self, candidate: Optional[str] = None) -> None: + def switchover(self, candidate: str | None = None) -> None: """Trigger a switchover.""" # Try to trigger the switchover. if candidate is not None: diff --git a/src/relations/async_replication.py b/src/relations/async_replication.py index 29b6d2ee25..838b1bba20 100644 --- a/src/relations/async_replication.py +++ b/src/relations/async_replication.py @@ -18,7 +18,6 @@ import json import logging from datetime import datetime -from typing import List, Optional, Tuple from lightkube import ApiError, Client from lightkube.resources.core_v1 import Endpoints, Service @@ -217,7 +216,7 @@ def _get_highest_promoted_cluster_counter_value(self) -> str: promoted_cluster_counter = relation_promoted_cluster_counter return promoted_cluster_counter - def get_primary_cluster(self) -> Optional[Application]: + def get_primary_cluster(self) -> Application | None: """Return the primary cluster.""" primary_cluster = None promoted_cluster_counter = "0" @@ -238,7 +237,7 @@ def get_primary_cluster(self) -> Optional[Application]: primary_cluster = app return primary_cluster - def get_primary_cluster_endpoint(self) -> Optional[str]: + def get_primary_cluster_endpoint(self) -> str | None: """Return the primary cluster endpoint.""" primary_cluster = self.get_primary_cluster() if primary_cluster is None or self.charm.app == primary_cluster: @@ -249,7 +248,7 @@ def get_primary_cluster_endpoint(self) -> Optional[str]: return None return json.loads(primary_cluster_data).get("endpoint") - def get_all_primary_cluster_endpoints(self) -> List[str]: + def get_all_primary_cluster_endpoints(self) -> list[str]: """Return all the primary cluster endpoints.""" relation = self._relation primary_cluster = self.get_primary_cluster() @@ -291,7 +290,7 @@ def _get_secret(self) -> Secret: return self.charm.model.app.add_secret(content=shared_content, label=SECRET_LABEL) - def get_standby_endpoints(self) -> List[str]: + def get_standby_endpoints(self) -> list[str]: """Return the standby endpoints.""" relation = self._relation primary_cluster = self.get_primary_cluster() @@ -309,7 +308,7 @@ def get_standby_endpoints(self) -> List[str]: if relation.data[unit].get("unit-address") is not None ] - def get_system_identifier(self) -> Tuple[Optional[str], Optional[str]]: + def get_system_identifier(self) -> tuple[str | None, str | None]: """Returns the PostgreSQL system identifier from this instance.""" try: system_identifier, error = self.container.exec( @@ -753,8 +752,8 @@ def _update_internal_secret(self) -> bool: def _update_primary_cluster_data( self, - promoted_cluster_counter: Optional[int] = None, - system_identifier: Optional[str] = None, + promoted_cluster_counter: int | None = None, + system_identifier: str | None = None, ) -> None: """Update the primary cluster data.""" async_relation = self._relation diff --git a/src/relations/db.py b/src/relations/db.py index 233b8d45fe..dee16696aa 100644 --- a/src/relations/db.py +++ b/src/relations/db.py @@ -4,7 +4,7 @@ """Postgres db and db-admin relation hooks & helpers.""" import logging -from typing import Iterable, List, Set, Tuple +from typing import Iterable from charms.postgresql_k8s.v0.postgresql import ( PostgreSQLCreateDatabaseError, @@ -119,7 +119,7 @@ def _check_multiple_endpoints(self) -> bool: return True return False - def _get_extensions(self, relation: Relation) -> Tuple[List, Set]: + def _get_extensions(self, relation: Relation) -> tuple[list, set]: """Returns the list of required and disabled extensions.""" requested_extensions = relation.data.get(relation.app, {}).get("extensions", "").split(",") for unit in relation.units: diff --git a/tests/integration/ha_tests/helpers.py b/tests/integration/ha_tests/helpers.py index 1b641b5dc2..a67bde151a 100644 --- a/tests/integration/ha_tests/helpers.py +++ b/tests/integration/ha_tests/helpers.py @@ -12,7 +12,6 @@ import zipfile from datetime import datetime from pathlib import Path -from typing import Dict, Optional, Set, Tuple, Union import kubernetes as kubernetes import psycopg2 @@ -101,7 +100,7 @@ async def are_all_db_processes_down(ops_test: OpsTest, process: str, signal: str return True -def get_patroni_cluster(unit_ip: str) -> Dict[str, str]: +def get_patroni_cluster(unit_ip: str) -> dict[str, str]: for attempt in Retrying(stop=stop_after_delay(30), wait=wait_fixed(3)): with attempt: resp = requests.get(f"http://{unit_ip}:8008/cluster") @@ -109,7 +108,7 @@ def get_patroni_cluster(unit_ip: str) -> Dict[str, str]: async def change_patroni_setting( - ops_test: OpsTest, setting: str, value: Union[int, str], password: str, tls: bool = False + ops_test: OpsTest, setting: str, value: str | int, password: str, tls: bool = False ) -> None: """Change the value of one of the Patroni settings. @@ -194,7 +193,7 @@ async def is_cluster_updated(ops_test: OpsTest, primary_name: str) -> None: ) -def get_member_lag(cluster: Dict, member_name: str) -> int: +def get_member_lag(cluster: dict, member_name: str) -> int: """Return the lag of a specific member.""" for member in cluster["members"]: if member["name"] == member_name.replace("/", "-"): @@ -315,7 +314,7 @@ def copy_file_into_pod( async def count_writes( ops_test: OpsTest, down_unit: str | None = None, extra_model: Model = None -) -> Tuple[Dict[str, int], Dict[str, int]]: +) -> tuple[dict[str, int], dict[str, int]]: """Count the number of writes in the database.""" app = await app_name(ops_test) password = await get_password(ops_test, database_app_name=app, down_unit=down_unit) @@ -435,7 +434,7 @@ def get_host_ip(host: str) -> str: return member_ips -async def get_patroni_setting(ops_test: OpsTest, setting: str, tls: bool = False) -> Optional[int]: +async def get_patroni_setting(ops_test: OpsTest, setting: str, tls: bool = False) -> int | None: """Get the value of one of the integer Patroni settings. Args: @@ -476,7 +475,7 @@ async def get_instances_roles(ops_test: OpsTest): return labels -async def get_postgresql_parameter(ops_test: OpsTest, parameter_name: str) -> Optional[int]: +async def get_postgresql_parameter(ops_test: OpsTest, parameter_name: str) -> int | None: """Get the value of a PostgreSQL parameter from Patroni API. Args: @@ -543,7 +542,7 @@ async def get_sync_standby(model: Model, application_name: str) -> str: async def inject_dependency_fault( - ops_test: OpsTest, application_name: str, charm_file: Union[str, Path] + ops_test: OpsTest, application_name: str, charm_file: str | Path ) -> None: """Inject a dependency fault into the PostgreSQL charm.""" # Query running dependency to overwrite with incompatible version. @@ -615,7 +614,7 @@ async def is_replica(ops_test: OpsTest, unit_name: str) -> bool: return False -async def list_wal_files(ops_test: OpsTest, app: str) -> Set: +async def list_wal_files(ops_test: OpsTest, app: str) -> set: """Returns the list of WAL segment files in each unit.""" units = [unit.name for unit in ops_test.model.applications[app].units] command = "ls -1 /var/lib/postgresql/data/pgdata/pg_wal/" diff --git a/tests/integration/ha_tests/test_async_replication.py b/tests/integration/ha_tests/test_async_replication.py index df1eaf3b52..df04ee61fb 100644 --- a/tests/integration/ha_tests/test_async_replication.py +++ b/tests/integration/ha_tests/test_async_replication.py @@ -5,7 +5,6 @@ import logging import subprocess from asyncio import gather -from typing import Optional import psycopg2 import pytest as pytest @@ -46,9 +45,7 @@ @contextlib.asynccontextmanager -async def fast_forward( - model: Model, fast_interval: str = "10s", slow_interval: Optional[str] = None -): +async def fast_forward(model: Model, fast_interval: str = "10s", slow_interval: str | None = None): """Adaptation of OpsTest.fast_forward to work with different models.""" update_interval_key = "update-status-hook-interval" interval_after = ( diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index 3d791fb880..9fb6a2160f 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -8,7 +8,6 @@ from multiprocessing import ProcessError from pathlib import Path from subprocess import check_call -from typing import List, Optional import botocore import psycopg2 @@ -53,7 +52,7 @@ async def app_name( ops_test: OpsTest, application_name: str = "postgresql-k8s", model: Model = None -) -> Optional[str]: +) -> str | None: """Returns the name of the cluster running PostgreSQL. This is important since not all deployments of the PostgreSQL charm have the application name @@ -119,8 +118,8 @@ async def build_and_deploy( async def check_database_users_existence( ops_test: OpsTest, - users_that_should_exist: List[str], - users_that_should_not_exist: List[str], + users_that_should_exist: list[str], + users_that_should_not_exist: list[str], admin: bool = False, database_app_name: str = DATABASE_APP_NAME, ) -> None: @@ -239,7 +238,7 @@ def construct_endpoint(endpoint: str, region: str) -> str: return endpoint -def convert_records_to_dict(records: List[tuple]) -> dict: +def convert_records_to_dict(records: list[tuple]) -> dict: """Converts psycopg2 records list to a dict.""" records_dict = {} for record in records: @@ -331,7 +330,7 @@ async def execute_query_on_unit( password: str, query: str, database: str = "postgres", - sslmode: Optional[str] = None, + sslmode: str | None = None, ): """Execute given PostgreSQL query on a unit. @@ -355,7 +354,7 @@ async def execute_query_on_unit( return output -def get_cluster_members(endpoint: str) -> List[str]: +def get_cluster_members(endpoint: str) -> list[str]: """List of current Patroni cluster members. Args: @@ -368,7 +367,7 @@ def get_cluster_members(endpoint: str) -> List[str]: return [member["name"] for member in r.json()["members"]] -def get_application_units(ops_test: OpsTest, application_name: str) -> List[str]: +def get_application_units(ops_test: OpsTest, application_name: str) -> list[str]: """List the unit names of an application. Args: @@ -434,7 +433,7 @@ def get_expected_k8s_resources(application: str) -> set: } -async def get_leader_unit(ops_test: OpsTest, app: str, model: Model = None) -> Optional[Unit]: +async def get_leader_unit(ops_test: OpsTest, app: str, model: Model = None) -> Unit | None: leader_unit = None if model is None: model = ops_test.model @@ -458,8 +457,8 @@ async def get_password( ops_test: OpsTest, username: str = "operator", database_app_name: str = DATABASE_APP_NAME, - down_unit: Optional[str] = None, - unit_name: Optional[str] = None, + down_unit: str | None = None, + unit_name: str | None = None, ): """Retrieve a user password using the action.""" for unit in ops_test.model.applications[database_app_name].units: @@ -476,7 +475,7 @@ async def get_password( wait=wait_exponential(multiplier=1, min=2, max=30), ) async def get_primary( - ops_test: OpsTest, database_app_name: str = DATABASE_APP_NAME, down_unit: Optional[str] = None + ops_test: OpsTest, database_app_name: str = DATABASE_APP_NAME, down_unit: str | None = None ) -> str: """Get the primary unit. @@ -512,7 +511,7 @@ async def get_unit_address(ops_test: OpsTest, unit_name: str) -> str: return status["applications"][unit_name.split("/")[0]].units[unit_name]["address"] -def get_unit_by_index(app: str, units: list, index: int) -> Optional[Unit]: +def get_unit_by_index(app: str, units: list, index: int) -> Unit | None: """Get unit by index. Args: @@ -725,7 +724,7 @@ async def scale_application( async def set_password( - ops_test: OpsTest, unit_name: str, username: str = "operator", password: Optional[str] = None + ops_test: OpsTest, unit_name: str, username: str = "operator", password: str | None = None ): """Set a user password using the action.""" unit = ops_test.model.units.get(unit_name) @@ -738,7 +737,7 @@ async def set_password( async def switchover( - ops_test: OpsTest, current_primary: str, password: str, candidate: Optional[str] = None + ops_test: OpsTest, current_primary: str, password: str, candidate: str | None = None ) -> None: """Trigger a switchover. diff --git a/tests/integration/new_relations/helpers.py b/tests/integration/new_relations/helpers.py index bae62263f9..fafe5e86dc 100644 --- a/tests/integration/new_relations/helpers.py +++ b/tests/integration/new_relations/helpers.py @@ -2,7 +2,6 @@ # Copyright 2022 Canonical Ltd. # See LICENSE file for licensing details. import json -from typing import Dict, Optional import yaml from lightkube import AsyncClient @@ -11,7 +10,7 @@ from tenacity import RetryError, Retrying, stop_after_attempt, wait_exponential -async def get_juju_secret(ops_test: OpsTest, secret_uri: str) -> Dict[str, str]: +async def get_juju_secret(ops_test: OpsTest, secret_uri: str) -> dict[str, str]: """Retrieve juju secret.""" secret_unique_id = secret_uri.split("/")[-1] complete_command = f"show-secret {secret_uri} --reveal --format=json" @@ -24,10 +23,10 @@ async def build_connection_string( application_name: str, relation_name: str, *, - relation_id: Optional[str] = None, - relation_alias: Optional[str] = None, + relation_id: str | None = None, + relation_alias: str | None = None, read_only_endpoint: bool = False, - database: Optional[str] = None, + database: str | None = None, ) -> str: """Build a PostgreSQL connection string. @@ -130,7 +129,7 @@ async def check_relation_data_existence( async def get_alias_from_relation_data( ops_test: OpsTest, unit_name: str, related_unit_name: str -) -> Optional[str]: +) -> str | None: """Get the alias that the unit assigned to the related unit application/cluster. Args: @@ -171,9 +170,9 @@ async def get_application_relation_data( application_name: str, relation_name: str, key: str, - relation_id: Optional[str] = None, - relation_alias: Optional[str] = None, -) -> Optional[str]: + relation_id: str | None = None, + relation_alias: str | None = None, +) -> str | None: """Get relation data for an application. Args: diff --git a/tests/integration/test_backups.py b/tests/integration/test_backups.py index 90748df022..bf2a5ea469 100644 --- a/tests/integration/test_backups.py +++ b/tests/integration/test_backups.py @@ -3,7 +3,6 @@ # See LICENSE file for licensing details. import logging import uuid -from typing import Dict, Tuple import boto3 import pytest as pytest @@ -97,7 +96,7 @@ async def cloud_configs(ops_test: OpsTest, github_secrets) -> None: @pytest.mark.group("AWS") @pytest.mark.abort_on_fail -async def test_backup_aws(ops_test: OpsTest, cloud_configs: Tuple[Dict, Dict]) -> None: +async def test_backup_aws(ops_test: OpsTest, cloud_configs: tuple[dict, dict]) -> None: """Build and deploy two units of PostgreSQL in AWS and then test the backup and restore actions.""" config = cloud_configs[0][AWS] credentials = cloud_configs[1][AWS] @@ -191,7 +190,7 @@ async def test_backup_aws(ops_test: OpsTest, cloud_configs: Tuple[Dict, Dict]) - @pytest.mark.group("GCP") @pytest.mark.abort_on_fail -async def test_backup_gcp(ops_test: OpsTest, cloud_configs: Tuple[Dict, Dict]) -> None: +async def test_backup_gcp(ops_test: OpsTest, cloud_configs: tuple[dict, dict]) -> None: """Build and deploy two units of PostgreSQL in GCP and then test the backup and restore actions.""" config = cloud_configs[0][GCP] credentials = cloud_configs[1][GCP] @@ -314,7 +313,7 @@ async def test_restore_on_new_cluster(ops_test: OpsTest, github_secrets) -> None @pytest.mark.group("GCP") async def test_invalid_config_and_recovery_after_fixing_it( - ops_test: OpsTest, cloud_configs: Tuple[Dict, Dict] + ops_test: OpsTest, cloud_configs: tuple[dict, dict] ) -> None: """Test that the charm can handle invalid and valid backup configurations.""" database_app_name = f"new-{DATABASE_APP_NAME}" diff --git a/tests/integration/test_backups_pitr.py b/tests/integration/test_backups_pitr.py index b58187e4df..2ea0d66d86 100644 --- a/tests/integration/test_backups_pitr.py +++ b/tests/integration/test_backups_pitr.py @@ -3,7 +3,6 @@ # See LICENSE file for licensing details. import logging import uuid -from typing import Dict, Tuple import boto3 import pytest as pytest @@ -382,7 +381,7 @@ async def pitr_backup_operations( @pytest.mark.group("AWS") @pytest.mark.abort_on_fail -async def test_pitr_backup_aws(ops_test: OpsTest, cloud_configs: Tuple[Dict, Dict]) -> None: +async def test_pitr_backup_aws(ops_test: OpsTest, cloud_configs: tuple[dict, dict]) -> None: """Build and deploy two units of PostgreSQL in AWS and then test PITR backup and restore actions.""" config = cloud_configs[0][AWS] credentials = cloud_configs[1][AWS] @@ -402,7 +401,7 @@ async def test_pitr_backup_aws(ops_test: OpsTest, cloud_configs: Tuple[Dict, Dic @pytest.mark.group("GCP") @pytest.mark.abort_on_fail -async def test_pitr_backup_gcp(ops_test: OpsTest, cloud_configs: Tuple[Dict, Dict]) -> None: +async def test_pitr_backup_gcp(ops_test: OpsTest, cloud_configs: tuple[dict, dict]) -> None: """Build and deploy two units of PostgreSQL in GCP and then test PITR backup and restore actions.""" config = cloud_configs[0][GCP] credentials = cloud_configs[1][GCP]