diff --git a/api/app/settings/common.py b/api/app/settings/common.py index 60a43f09a5df..cb84d5f2627b 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -28,6 +28,7 @@ from task_processor.task_run_method import TaskRunMethod # type: ignore[import-untyped] from app.routers import ReplicaReadStrategy +from app.utils import get_numbered_env_vars_with_prefix env = Env() @@ -177,11 +178,15 @@ ), } REPLICA_DATABASE_URLS_DELIMITER = env("REPLICA_DATABASE_URLS_DELIMITER", ",") - REPLICA_DATABASE_URLS = env.list( - "REPLICA_DATABASE_URLS", - subcast=str, - default=[], - delimiter=REPLICA_DATABASE_URLS_DELIMITER, + REPLICA_DATABASE_URLS = ( + env.list( + "REPLICA_DATABASE_URLS", + subcast=str, + default=[], + delimiter=REPLICA_DATABASE_URLS_DELIMITER, + ) + if not os.getenv("REPLICA_DATABASE_URL_0") + else get_numbered_env_vars_with_prefix("REPLICA_DATABASE_URL_") ) NUM_DB_REPLICAS = len(REPLICA_DATABASE_URLS) @@ -190,11 +195,15 @@ CROSS_REGION_REPLICA_DATABASE_URLS_DELIMITER = env( "CROSS_REGION_REPLICA_DATABASE_URLS_DELIMITER", "," ) - CROSS_REGION_REPLICA_DATABASE_URLS: list[str] = env.list( - "CROSS_REGION_REPLICA_DATABASE_URLS", - subcast=str, - default=[], - delimiter=CROSS_REGION_REPLICA_DATABASE_URLS_DELIMITER, + CROSS_REGION_REPLICA_DATABASE_URLS = ( + env.list( + "CROSS_REGION_REPLICA_DATABASE_URLS", + subcast=str, + default=[], + delimiter=CROSS_REGION_REPLICA_DATABASE_URLS_DELIMITER, + ) + if not os.getenv("CROSS_REGION_REPLICA_DATABASE_URL_0") + else get_numbered_env_vars_with_prefix("CROSS_REGION_REPLICA_DATABASE_URL_") ) NUM_CROSS_REGION_DB_REPLICAS = len(CROSS_REGION_REPLICA_DATABASE_URLS) diff --git a/api/app/utils.py b/api/app/utils.py index 7be80731b8a5..044adea5a6f2 100644 --- a/api/app/utils.py +++ b/api/app/utils.py @@ -1,6 +1,24 @@ +import os + import shortuuid def create_hash() -> str: """Helper function to create a short hash""" return shortuuid.uuid() + + +def get_numbered_env_vars_with_prefix(prefix: str) -> list[str]: + """ + Returns a list containing the values of all environment variables whose names have a given prefix followed by an + integer, starting from 0, until no more variables with that prefix are found. + """ + db_urls = [] + i = 0 + while True: + db_url = os.getenv(f"{prefix}{i}") + if not db_url: + break + db_urls.append(db_url) + i += 1 + return db_urls diff --git a/api/tests/unit/app/test_unit_app_utils.py b/api/tests/unit/app/test_unit_app_utils.py new file mode 100644 index 000000000000..4fb613a918c7 --- /dev/null +++ b/api/tests/unit/app/test_unit_app_utils.py @@ -0,0 +1,16 @@ +import pytest + +from app.utils import get_numbered_env_vars_with_prefix + + +def test_get_numbered_env_vars_with_prefix(monkeypatch: pytest.MonkeyPatch) -> None: + # Given + monkeypatch.setenv("DB_URL_0", "0") + monkeypatch.setenv("DB_URL_1", "1") + monkeypatch.setenv("DB_URL_3", "3") + + # When + env_vars = get_numbered_env_vars_with_prefix("DB_URL_") + + # Then + assert env_vars == ["0", "1"]