diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index 69ebecbb55..df693ca8f0 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -26,7 +26,7 @@ DEFAULT_GCP_NODE_GROUPS, node_groups_to_dict, ) -from _nebari.stages.kubernetes_ingress import LetsEncryptCertificate +from _nebari.stages.kubernetes_ingress import CertificateEnum from _nebari.stages.kubernetes_keycloak import AuthenticationEnum from _nebari.stages.terraform_state import TerraformStateEnum from _nebari.utils import get_latest_kubernetes_version, random_secure_string @@ -194,7 +194,8 @@ def render_config( config["theme"]["jupyterhub"]["hub_subtitle"] = WELCOME_HEADER_TEXT if ssl_cert_email: - config["certificate"] = LetsEncryptCertificate(acme_email=ssl_cert_email) + config["certificate"] = {"type": CertificateEnum.letsencrypt.value} + config["certificate"]["acme_email"] = ssl_cert_email # validate configuration and convert to model from nebari.plugins import nebari_plugin_manager diff --git a/src/_nebari/keycloak.py b/src/_nebari/keycloak.py index 0aee3dc8f2..ea8815940d 100644 --- a/src/_nebari/keycloak.py +++ b/src/_nebari/keycloak.py @@ -7,7 +7,7 @@ import requests import rich -from _nebari.stages.kubernetes_ingress import SelfSignedCertificate +from _nebari.stages.kubernetes_ingress import CertificateEnum from nebari import schema logger = logging.getLogger(__name__) @@ -91,7 +91,7 @@ def get_keycloak_admin_from_config(config: schema.Main): "KEYCLOAK_ADMIN_PASSWORD", config.security.keycloak.initial_root_password ) - should_verify_tls = not isinstance(config.certificate, SelfSignedCertificate) + should_verify_tls = config.certificate.type != CertificateEnum.selfsigned try: keycloak_admin = keycloak.KeycloakAdmin( diff --git a/src/_nebari/stages/kubernetes_ingress/__init__.py b/src/_nebari/stages/kubernetes_ingress/__init__.py index efe4502feb..628d383830 100644 --- a/src/_nebari/stages/kubernetes_ingress/__init__.py +++ b/src/_nebari/stages/kubernetes_ingress/__init__.py @@ -1,12 +1,9 @@ -from __future__ import annotations - +import enum import logging import socket import sys import time -from typing import Any, Dict, List, Literal, Optional, Type, Union - -from pydantic import Field +from typing import Any, Dict, List, Optional, Type from _nebari import constants from _nebari.provider.dns.cloudflare import update_record @@ -115,31 +112,25 @@ def _attempt_dns_lookup( sys.exit(1) -class SelfSignedCertificate(schema.Base): - type: Literal["self-signed"] = Field("self-signed", validate_default=True) - - -class LetsEncryptCertificate(schema.Base): - type: Literal["lets-encrypt"] = Field("lets-encrypt", validate_default=True) - acme_email: str - acme_server: str = "https://acme-v02.api.letsencrypt.org/directory" +@schema.yaml_object(schema.yaml) +class CertificateEnum(str, enum.Enum): + letsencrypt = "lets-encrypt" + selfsigned = "self-signed" + existing = "existing" + disabled = "disabled" + @classmethod + def to_yaml(cls, representer, node): + return representer.represent_str(node.value) -class ExistingCertificate(schema.Base): - type: Literal["existing"] = Field("existing", validate_default=True) - secret_name: str - -class DisabledCertificate(schema.Base): - type: Literal["disabled"] = Field("disabled", validate_default=True) - - -Certificate = Union[ - SelfSignedCertificate, - LetsEncryptCertificate, - ExistingCertificate, - DisabledCertificate, -] +class Certificate(schema.Base): + type: CertificateEnum = CertificateEnum.selfsigned + # existing + secret_name: Optional[str] = None + # lets-encrypt + acme_email: Optional[str] = None + acme_server: str = "https://acme-v02.api.letsencrypt.org/directory" class DnsProvider(schema.Base): @@ -153,7 +144,7 @@ class Ingress(schema.Base): class InputSchema(schema.Base): domain: Optional[str] = None - certificate: Certificate = SelfSignedCertificate() + certificate: Certificate = Certificate() ingress: Ingress = Ingress() dns: DnsProvider = DnsProvider() diff --git a/tests/tests_unit/test_cli.py b/tests/tests_unit/test_cli.py index 4a091f3bb3..d8a4e423b9 100644 --- a/tests/tests_unit/test_cli.py +++ b/tests/tests_unit/test_cli.py @@ -53,7 +53,7 @@ def test_nebari_init(tmp_path, namespace, auth_provider, ci_provider, ssl_cert_e assert config.namespace == namespace assert config.security.authentication.type.lower() == auth_provider assert config.ci_cd.type == ci_provider - assert getattr(config.certificate, "acme_email", None) == ssl_cert_email + assert config.certificate.acme_email == ssl_cert_email @pytest.mark.parametrize(