diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py
index 559f17bd53..5361b0db5b 100644
--- a/src/_nebari/stages/infrastructure/__init__.py
+++ b/src/_nebari/stages/infrastructure/__init__.py
@@ -43,10 +43,35 @@ class ExistingInputVars(schema.Base):
     kube_context: str
 
 
-class DigitalOceanNodeGroup(schema.Base):
+class NodeGroup(schema.Base):
     instance: str
-    min_nodes: int
-    max_nodes: int
+    min_nodes: Annotated[int, Field(ge=0)] = 0
+    max_nodes: Annotated[int, Field(ge=1)] = 1
+    taints: Optional[List[schema.Taint]] = []
+
+    @field_validator("taints", mode="before")
+    def validate_taint_strings(cls, value: List[Any]):
+        TAINT_STR_REGEX = re.compile(r"(\w+)=(\w+):(\w+)")
+        return_value = []
+        for taint in value or []:
+            if not isinstance(taint, str):
+                return_value.append(taint)
+            else:
+                match = TAINT_STR_REGEX.match(taint)
+                if not match:
+                    raise ValueError(f"Invalid taint string: {taint}")
+                key, taint_value, effect = match.groups()
+                parsed_taint = schema.Taint(key=key, value=taint_value, effect=effect)
+                return_value.append(parsed_taint)
+
+        return return_value
+
+
+DEFAULT_GENERAL_TAINTS = []
+DEFAULT_USER_TAINTS = [schema.Taint(key="dedicated", value="user", effect="NoSchedule")]
+DEFAULT_WORKER_TAINTS = [
+    schema.Taint(key="dedicated", value="worker", effect="NoSchedule")
+]
 
 
 class DigitalOceanInputVars(schema.Base):
@@ -55,7 +80,7 @@ class DigitalOceanInputVars(schema.Base):
     region: str
     tags: List[str]
     kubernetes_version: str
-    node_groups: Dict[str, DigitalOceanNodeGroup]
+    node_groups: Dict[str, "DigitalOceanNodeGroup"]
     kubeconfig_filename: str = get_kubeconfig_filename()
 
 
@@ -64,10 +89,26 @@ class GCPNodeGroupInputVars(schema.Base):
     instance_type: str
     min_size: int
     max_size: int
+    node_taints: List[dict]
     labels: Dict[str, str]
     preemptible: bool
     guest_accelerators: List["GCPGuestAccelerator"]
 
+    @field_validator("node_taints", mode="before")
+    def convert_taints(cls, value: Optional[List[schema.Taint]]):
+        return [
+            dict(
+                key=taint.key,
+                value=taint.value,
+                effect={
+                    schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
+                    schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
+                    schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
+                }[taint.effect],
+            )
+            for taint in value or []
+        ]
+
 
 class GCPPrivateClusterConfig(schema.Base):
     enable_private_nodes: bool
@@ -109,6 +150,11 @@ class AzureNodeGroupInputVars(schema.Base):
     instance: str
     min_nodes: int
     max_nodes: int
+    node_taints: List[str]
+
+    @field_validator("node_taints", mode="before")
+    def convert_taints(cls, value: Optional[List[schema.Taint]]):
+        return [
+            f"{taint.key}={taint.value}:{taint.effect.value}" for taint in value or []
+        ]
 
 
 class AzureInputVars(schema.Base):
@@ -150,6 +196,22 @@ class AWSNodeGroupInputVars(schema.Base):
     permissions_boundary: Optional[str] = None
     ami_type: Optional[AWSAmiTypes] = None
     launch_template: Optional[AWSNodeLaunchTemplate] = None
+    node_taints: List[dict]
+
+    @field_validator("node_taints", mode="before")
+    def convert_taints(cls, value: Optional[List[schema.Taint]]):
+        return [
+            dict(
+                key=taint.key,
+                value=taint.value,
+                effect={
+                    schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
+                    schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
+                    schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
+                }[taint.effect],
+            )
+            for taint in value or []
+        ]
 
 
 def construct_aws_ami_type(gpu_enabled: bool, launch_template: AWSNodeLaunchTemplate):
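For reference, the string form accepted by `validate_taint_strings` can be exercised in isolation. A minimal sketch — the regex is copied from the validator above, the helper name is hypothetical:

```python
import re

TAINT_STR_REGEX = re.compile(r"(\w+)=(\w+):(\w+)")

def parse_taint_string(taint: str) -> dict:
    """Hypothetical standalone mirror of NodeGroup.validate_taint_strings."""
    match = TAINT_STR_REGEX.match(taint)
    if not match:
        raise ValueError(f"Invalid taint string: {taint}")
    key, value, effect = match.groups()
    return {"key": key, "value": value, "effect": effect}

assert parse_taint_string("dedicated=user:NoSchedule") == {
    "key": "dedicated", "value": "user", "effect": "NoSchedule"
}
```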
@field_validator("node_taints", mode="before") + def convert_taints(cls, value: Optional[List[schema.Taint]]): + return [ + dict( + key=taint.key, + value=taint.value, + effect={ + schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE", + schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE", + schema.TaintEffectEnum.NoExecute: "NO_EXECUTE", + }[taint.effect], + ) + for taint in value + ] + class AWSInputVars(schema.Base): name: str @@ -253,16 +330,14 @@ class KeyValueDict(schema.Base): value: str -class DigitalOceanNodeGroup(schema.Base): +class DigitalOceanNodeGroup(NodeGroup): """Representation of a node group with Digital Ocean - Kubernetes limits: https://docs.digitalocean.com/products/kubernetes/details/limits/ - Available instance types: https://slugs.do-api.dev/ """ - instance: str min_nodes: Annotated[int, Field(ge=1)] = 1 - max_nodes: Annotated[int, Field(ge=1)] = 1 DEFAULT_DO_NODE_GROUPS = { @@ -341,19 +416,31 @@ class GCPGuestAccelerator(schema.Base): count: Annotated[int, Field(ge=1)] = 1 -class GCPNodeGroup(schema.Base): - instance: str - min_nodes: Annotated[int, Field(ge=0)] = 0 - max_nodes: Annotated[int, Field(ge=1)] = 1 +class GCPNodeGroup(NodeGroup): preemptible: bool = False labels: Dict[str, str] = {} guest_accelerators: List[GCPGuestAccelerator] = [] DEFAULT_GCP_NODE_GROUPS = { - "general": GCPNodeGroup(instance="e2-standard-8", min_nodes=1, max_nodes=1), - "user": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5), - "worker": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5), + "general": GCPNodeGroup( + instance="e2-standard-8", + min_nodes=1, + max_nodes=1, + taints=DEFAULT_GENERAL_TAINTS, + ), + "user": GCPNodeGroup( + instance="e2-standard-4", + min_nodes=0, + max_nodes=5, + taints=DEFAULT_USER_TAINTS, + ), + "worker": GCPNodeGroup( + instance="e2-standard-4", + min_nodes=0, + max_nodes=5, + taints=DEFAULT_WORKER_TAINTS, + ), } @@ -390,16 +477,26 @@ def _check_input(cls, data: Any) -> Any: return data -class AzureNodeGroup(schema.Base): - instance: str - min_nodes: int - max_nodes: int +class AzureNodeGroup(NodeGroup): + pass DEFAULT_AZURE_NODE_GROUPS = { - "general": AzureNodeGroup(instance="Standard_D8_v3", min_nodes=1, max_nodes=1), - "user": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5), - "worker": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5), + "general": AzureNodeGroup( + instance="Standard_D8_v3", + min_nodes=1, + max_nodes=1, + taints=DEFAULT_GENERAL_TAINTS, + ), + "user": AzureNodeGroup( + instance="Standard_D4_v3", min_nodes=0, max_nodes=5, taints=DEFAULT_USER_TAINTS + ), + "worker": AzureNodeGroup( + instance="Standard_D4_v3", + min_nodes=0, + max_nodes=5, + taints=DEFAULT_WORKER_TAINTS, + ), } @@ -461,10 +558,7 @@ def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]: return value if value is None else azure_cloud.validate_tags(value) -class AWSNodeGroup(schema.Base): - instance: str - min_nodes: int = 0 - max_nodes: int +class AWSNodeGroup(NodeGroup): gpu: bool = False single_subnet: bool = False permissions_boundary: Optional[str] = None @@ -472,12 +566,22 @@ class AWSNodeGroup(schema.Base): DEFAULT_AWS_NODE_GROUPS = { - "general": AWSNodeGroup(instance="m5.2xlarge", min_nodes=1, max_nodes=1), + "general": AWSNodeGroup( + instance="m5.2xlarge", min_nodes=1, max_nodes=1, taints=DEFAULT_GENERAL_TAINTS + ), "user": AWSNodeGroup( - instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False + instance="m5.xlarge", + min_nodes=0, + max_nodes=5, + 
@@ -461,10 +558,7 @@ def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]:
     return value if value is None else azure_cloud.validate_tags(value)
 
 
-class AWSNodeGroup(schema.Base):
-    instance: str
-    min_nodes: int = 0
-    max_nodes: int
+class AWSNodeGroup(NodeGroup):
     gpu: bool = False
     single_subnet: bool = False
     permissions_boundary: Optional[str] = None
@@ -472,12 +566,22 @@ class AWSNodeGroup(schema.Base):
 
 
 DEFAULT_AWS_NODE_GROUPS = {
-    "general": AWSNodeGroup(instance="m5.2xlarge", min_nodes=1, max_nodes=1),
+    "general": AWSNodeGroup(
+        instance="m5.2xlarge", min_nodes=1, max_nodes=1, taints=DEFAULT_GENERAL_TAINTS
+    ),
     "user": AWSNodeGroup(
-        instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False
+        instance="m5.xlarge",
+        min_nodes=0,
+        max_nodes=5,
+        single_subnet=False,
+        taints=DEFAULT_USER_TAINTS,
     ),
     "worker": AWSNodeGroup(
-        instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False
+        instance="m5.xlarge",
+        min_nodes=0,
+        max_nodes=5,
+        single_subnet=False,
+        taints=DEFAULT_WORKER_TAINTS,
     ),
 }
 
@@ -576,17 +680,8 @@ class ExistingProvider(schema.Base):
     schema.ProviderEnum.do: DigitalOceanProvider,
 }
 
-provider_enum_name_map: Dict[schema.ProviderEnum, str] = {
-    schema.ProviderEnum.local: "local",
-    schema.ProviderEnum.existing: "existing",
-    schema.ProviderEnum.gcp: "google_cloud_platform",
-    schema.ProviderEnum.aws: "amazon_web_services",
-    schema.ProviderEnum.azure: "azure",
-    schema.ProviderEnum.do: "digital_ocean",
-}
-
 provider_name_abbreviation_map: Dict[str, str] = {
-    value: key.value for key, value in provider_enum_name_map.items()
+    value: key.value for key, value in schema.provider_enum_name_map.items()
 }
 
 provider_enum_default_node_groups_map: Dict[schema.ProviderEnum, Any] = {
@@ -778,6 +873,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
                     instance_type=node_group.instance,
                     min_size=node_group.min_nodes,
                     max_size=node_group.max_nodes,
+                    node_taints=node_group.taints,
                     preemptible=node_group.preemptible,
                     guest_accelerators=node_group.guest_accelerators,
                 )
@@ -809,6 +905,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
                     instance=node_group.instance,
                     min_nodes=node_group.min_nodes,
                     max_nodes=node_group.max_nodes,
+                    node_taints=node_group.taints,
                 )
                 for name, node_group in self.config.azure.node_groups.items()
             },
@@ -850,6 +947,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
                     single_subnet=node_group.single_subnet,
                     permissions_boundary=node_group.permissions_boundary,
                     launch_template=node_group.launch_template,
+                    node_taints=node_group.taints,
                     ami_type=construct_aws_ami_type(
                         gpu_enabled=node_group.gpu,
                         launch_template=node_group.launch_template,
diff --git a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf
index 5b66201f83..b217cfecdb 100644
--- a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf
+++ b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf
@@ -86,6 +86,15 @@ resource "aws_eks_node_group" "main" {
     max_size = var.node_groups[count.index].max_size
   }
 
+  dynamic "taint" {
+    for_each = var.node_groups[count.index].node_taints
+    content {
+      key    = taint.value.key
+      value  = taint.value.value
+      effect = taint.value.effect
+    }
+  }
+
   # Only set launch_template if its node_group counterpart parameter is not null
   dynamic "launch_template" {
     for_each = var.node_groups[count.index].launch_template != null ? [0] : []
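EKS (like GKE) expects upper-snake effect names, which is why the `convert_taints` validators translate before the values reach the `taint` block above. The mapping in isolation, copied from the diff:

```python
# Kubernetes effect name -> EKS/GKE API effect name (per the validators above).
CLOUD_TAINT_EFFECTS = {
    "NoSchedule": "NO_SCHEDULE",
    "PreferNoSchedule": "PREFER_NO_SCHEDULE",
    "NoExecute": "NO_EXECUTE",
}

assert CLOUD_TAINT_EFFECTS["NoSchedule"] == "NO_SCHEDULE"
```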
diff --git a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf
index 4d38d10a19..703aaba52c 100644
--- a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf
+++ b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf
@@ -53,6 +53,11 @@ variable "node_groups" {
     single_subnet   = bool
     launch_template = map(any)
     ami_type        = string
+    node_taints = list(object({
+      key    = string
+      value  = string
+      effect = string
+    }))
   }))
 }
 
diff --git a/src/_nebari/stages/infrastructure/template/aws/variables.tf b/src/_nebari/stages/infrastructure/template/aws/variables.tf
index a3f37b9eb9..2621686d4b 100644
--- a/src/_nebari/stages/infrastructure/template/aws/variables.tf
+++ b/src/_nebari/stages/infrastructure/template/aws/variables.tf
@@ -40,6 +40,11 @@ variable "node_groups" {
     single_subnet   = bool
     launch_template = map(any)
     ami_type        = string
+    node_taints = list(object({
+      key    = string
+      value  = string
+      effect = string
+    }))
   }))
 }
 
diff --git a/src/_nebari/stages/infrastructure/template/azure/main.tf b/src/_nebari/stages/infrastructure/template/azure/main.tf
index 2d6e2e2afa..0ddff5f583 100644
--- a/src/_nebari/stages/infrastructure/template/azure/main.tf
+++ b/src/_nebari/stages/infrastructure/template/azure/main.tf
@@ -38,6 +38,7 @@ module "kubernetes" {
       instance_type = config.instance
       min_size      = config.min_nodes
      max_size      = config.max_nodes
+      node_taints   = config.node_taints
     }
   ]
   vnet_subnet_id = var.vnet_subnet_id
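AKS, by contrast, takes taints as plain `key=value:effect` strings, which is why `AzureNodeGroupInputVars.convert_taints` formats rather than maps. A sketch with a hypothetical helper name:

```python
def to_aks_taint(key: str, value: str, effect: str) -> str:
    """Hypothetical mirror of the Azure taint string formatting above."""
    return f"{key}={value}:{effect}"

assert to_aks_taint("dedicated", "user", "NoSchedule") == "dedicated=user:NoSchedule"
```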
diff --git a/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/main.tf b/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/main.tf
index f093f048c6..a054147759 100644
--- a/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/main.tf
+++ b/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/main.tf
@@ -36,6 +36,7 @@ resource "azurerm_kubernetes_cluster" "main" {
     min_count = var.node_groups[0].min_size
     max_count = var.node_groups[0].max_size
     max_pods  = var.max_pods
+    # TODO: It's not possible to add node_taints to the default node pool. See https://github.com/hashicorp/terraform-provider-azurerm/issues/9183 for more info
 
     orchestrator_version = var.kubernetes_version
     node_labels = {
@@ -81,4 +82,5 @@ resource "azurerm_kubernetes_cluster_node_pool" "node_group" {
   orchestrator_version = var.kubernetes_version
   tags                 = var.tags
   vnet_subnet_id       = var.vnet_subnet_id
+  node_taints          = each.value.node_taints
 }
diff --git a/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/variables.tf b/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/variables.tf
index b93a9fae2d..01851b842c 100644
--- a/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/variables.tf
+++ b/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/variables.tf
@@ -29,10 +29,16 @@ variable "environment" {
   type        = string
 }
 
-
 variable "node_groups" {
   description = "Node pools to add to Azure Kubernetes Cluster"
-  type        = list(map(any))
+  type = list(object({
+    name          = string
+    auto_scale    = bool
+    instance_type = string
+    min_size      = number
+    max_size      = number
+    node_taints   = list(string)
+  }))
 }
 
 variable "vnet_subnet_id" {
diff --git a/src/_nebari/stages/infrastructure/template/azure/variables.tf b/src/_nebari/stages/infrastructure/template/azure/variables.tf
index dcef2c97cb..5eeb32c02a 100644
--- a/src/_nebari/stages/infrastructure/template/azure/variables.tf
+++ b/src/_nebari/stages/infrastructure/template/azure/variables.tf
@@ -21,9 +21,10 @@ variable "kubernetes_version" {
 variable "node_groups" {
   description = "Azure node groups"
   type = map(object({
-    instance  = string
-    min_nodes = number
-    max_nodes = number
+    instance    = string
+    min_nodes   = number
+    max_nodes   = number
+    node_taints = list(string)
   }))
 }
 
diff --git a/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/main.tf b/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/main.tf
index 57e8d9fc88..182168fada 100644
--- a/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/main.tf
+++ b/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/main.tf
@@ -93,6 +93,15 @@ resource "google_container_node_pool" "main" {
 
     oauth_scopes = local.node_group_oauth_scopes
 
+    dynamic "taint" {
+      for_each = local.merged_node_groups[count.index].node_taints
+      content {
+        key    = taint.value.key
+        value  = taint.value.value
+        effect = taint.value.effect
+      }
+    }
+
     metadata = {
       disable-legacy-endpoints = "true"
     }
@@ -108,9 +117,4 @@ resource "google_container_node_pool" "main" {
     tags = var.tags
   }
 
-  lifecycle {
-    ignore_changes = [
-      node_config[0].taint
-    ]
-  }
 }
diff --git a/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/variables.tf b/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/variables.tf
index 2ee2d78ed5..236a0b9017 100644
--- a/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/variables.tf
+++ b/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/variables.tf
@@ -50,6 +50,7 @@ variable "node_groups" {
       min_size  = 1
       max_size  = 1
       labels    = {}
+      node_taints = []
     },
     {
       name      = "user"
@@ -57,6 +58,7 @@ variable "node_groups" {
       min_size  = 0
       max_size  = 2
       labels    = {}
+      node_taints = []
     },
     {
       name      = "worker"
@@ -64,6 +66,7 @@ variable "node_groups" {
       min_size  = 0
      max_size  = 5
       labels    = {}
+      node_taints = []
     }
   ]
 }
diff --git a/src/_nebari/stages/kubernetes_services/__init__.py b/src/_nebari/stages/kubernetes_services/__init__.py
index fdc413bd40..41025b7737 100644
--- a/src/_nebari/stages/kubernetes_services/__init__.py
+++ b/src/_nebari/stages/kubernetes_services/__init__.py
@@ -453,6 +453,28 @@ def handle_units(cls, value: Optional[str]) -> float:
         return byte_unit_conversion(value, "GiB")
 
 
+class TolerationOperatorEnum(str, enum.Enum):
+    Equal = "Equal"
+    Exists = "Exists"
+
+    @classmethod
+    def to_yaml(cls, representer, node):
+        return representer.represent_str(node.value)
+
+
+class Toleration(schema.Taint):
+    operator: TolerationOperatorEnum = TolerationOperatorEnum.Equal
+
+    @classmethod
+    def from_taint(
+        cls, taint: schema.Taint, operator: Optional[TolerationOperatorEnum] = None
+    ):
+        kwargs = {}
+        if operator:
+            kwargs["operator"] = operator
+        return cls(**taint.model_dump(), **kwargs)
+
+
 class JupyterhubInputVars(schema.Base):
     jupyterhub_theme: Dict[str, Any] = Field(alias="jupyterhub-theme")
     jupyterlab_image: ImageNameTag = Field(alias="jupyterlab-image")
@@ -478,6 +500,9 @@ class JupyterhubInputVars(schema.Base):
     cloud_provider: str = Field(alias="cloud-provider")
     jupyterlab_preferred_dir: Optional[str] = Field(alias="jupyterlab-preferred-dir")
     shared_fs_type: SharedFsEnum
+    user_taint_tolerations: Optional[List[Toleration]] = Field(
+        alias="node-taint-tolerations"
+    )
 
     @field_validator("jupyterhub_shared_storage", mode="before")
     @classmethod
@@ -490,6 +515,9 @@ class DaskGatewayInputVars(schema.Base):
     dask_gateway_profiles: Dict[str, Any] = Field(alias="dask-gateway-profiles")
     cloud_provider: str = Field(alias="cloud-provider")
     forwardauth_middleware_name: str = _forwardauth_middleware_name
+    worker_taint_tolerations: Optional[List[Toleration]] = Field(
+        alias="worker-taint-tolerations"
+    )
 
 
 class MonitoringInputVars(schema.Base):
@@ -592,6 +620,15 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
         ):
             jupyterhub_theme.update({"version": f"v{self.config.nebari_version}"})
 
+        def _node_taint_tolerations(node_group_name: str) -> List[Toleration]:
+            provider = getattr(
+                self.config, schema.provider_enum_name_map[self.config.provider]
+            )
+            node_group = getattr(provider, "node_groups", {}).get(node_group_name)
+            taints = getattr(node_group, "taints", None) or []
+            return [Toleration.from_taint(taint) for taint in taints]
+
         kubernetes_services_vars = KubernetesServicesInputVars(
             name=self.config.project_name,
             environment=self.config.namespace,
@@ -646,6 +695,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
             jupyterlab_default_settings=self.config.jupyterlab.default_settings,
             jupyterlab_gallery_settings=self.config.jupyterlab.gallery_settings,
             jupyterlab_preferred_dir=self.config.jupyterlab.preferred_dir,
+            user_taint_tolerations=_node_taint_tolerations(node_group_name="user"),
             shared_fs_type=(
                 # efs is equivalent to nfs in these modules
                 SharedFsEnum.nfs
@@ -660,6 +710,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
             ),
             dask_gateway_profiles=self.config.profiles.model_dump()["dask_worker"],
             cloud_provider=cloud_provider,
+            worker_taint_tolerations=_node_taint_tolerations(node_group_name="worker"),
         )
 
         monitoring_vars = MonitoringInputVars(
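A hedged usage sketch of `Toleration.from_taint` (import paths follow this diff; the taint value is illustrative):

```python
from _nebari.stages.kubernetes_services import Toleration, TolerationOperatorEnum
from nebari import schema

taint = schema.Taint(key="dedicated", value="user", effect="NoSchedule")
Toleration.from_taint(taint)  # operator defaults to "Equal"
Toleration.from_taint(taint, operator=TolerationOperatorEnum.Exists)
```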
diff --git a/src/_nebari/stages/kubernetes_services/template/dask_gateway.tf b/src/_nebari/stages/kubernetes_services/template/dask_gateway.tf
index a47acee8fa..997a4ab294 100644
--- a/src/_nebari/stages/kubernetes_services/template/dask_gateway.tf
+++ b/src/_nebari/stages/kubernetes_services/template/dask_gateway.tf
@@ -11,6 +11,16 @@ variable "dask-gateway-profiles" {
   description = "Dask Gateway profiles to expose to user"
 }
 
+variable "worker-taint-tolerations" {
+  description = "Tolerations for the worker node taints needed by Dask scheduler/worker pods"
+  type = list(object({
+    key      = string
+    operator = string
+    value    = string
+    effect   = string
+  }))
+}
+
 # =================== RESOURCES =====================
 module "dask-gateway" {
   source = "./modules/kubernetes/services/dask-gateway"
@@ -43,6 +53,15 @@ module "dask-gateway" {
 
   forwardauth_middleware_name = var.forwardauth_middleware_name
 
+  cluster = {
+    scheduler_extra_pod_config = {
+      tolerations = var.worker-taint-tolerations
+    }
+    worker_extra_pod_config = {
+      tolerations = var.worker-taint-tolerations
+    }
+  }
+
   depends_on = [
     module.kubernetes-nfs-server,
     module.rook-ceph
diff --git a/src/_nebari/stages/kubernetes_services/template/jupyterhub.tf b/src/_nebari/stages/kubernetes_services/template/jupyterhub.tf
index 121cff4b22..8759f13f43 100644
--- a/src/_nebari/stages/kubernetes_services/template/jupyterhub.tf
+++ b/src/_nebari/stages/kubernetes_services/template/jupyterhub.tf
@@ -85,6 +85,16 @@ variable "idle-culler-settings" {
   type        = any
 }
 
+variable "node-taint-tolerations" {
+  description = "Tolerations for the user node taints needed by JupyterLab user pods"
+  type = list(object({
+    key      = string
+    operator = string
+    value    = string
+    effect   = string
+  }))
+}
+
 variable "shared_fs_type" {
   type        = string
   description = "Use NFS or Ceph"
@@ -180,6 +190,7 @@ module "jupyterhub" {
   conda-store-service-name    = module.kubernetes-conda-store-server.service_name
   conda-store-jhub-apps-token = module.kubernetes-conda-store-server.service-tokens.jhub-apps
   jhub-apps-enabled           = var.jhub-apps-enabled
+  node-taint-tolerations      = var.node-taint-tolerations
   jhub-apps-overrides         = var.jhub-apps-overrides
 
   extra-mounts = {
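Once these variables are threaded through, each toleration reaches the dask-gateway pods in the usual Kubernetes shape. An illustrative literal (values match the default worker taint; nothing here is computed):

```python
# Shape of the pod config dask-gateway receives for scheduler/worker pods.
worker_extra_pod_config = {
    "tolerations": [
        {
            "key": "dedicated",
            "operator": "Equal",
            "value": "worker",
            "effect": "NoSchedule",
        }
    ]
}
```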
diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/files/gateway_config.py b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/files/gateway_config.py
index c58e3aa90d..427b8734a7 100644
--- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/files/gateway_config.py
+++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/files/gateway_config.py
@@ -15,7 +14,6 @@ def dask_gateway_config(path="/var/lib/dask-gateway/config.json"):
 
 config = dask_gateway_config()
-
 c.DaskGateway.log_level = config["gateway"]["loglevel"]
 
 # Configure addresses
@@ -26,6 +25,8 @@
 c.KubeBackend.gateway_instance = config["gateway_service_name"]
 
 # ========= Dask Cluster Default Configuration =========
+# These settings are overridden by c.Backend.cluster_options when a matching key (e.g. image, scheduler_extra_pod_config) is present
+
 c.KubeClusterConfig.image = (
     f"{config['cluster-image']['name']}:{config['cluster-image']['tag']}"
 )
@@ -40,6 +41,7 @@
 c.KubeClusterConfig.scheduler_extra_container_config = config["cluster"][
     "scheduler_extra_container_config"
 ]
+
 c.KubeClusterConfig.scheduler_extra_pod_config = config["cluster"][
     "scheduler_extra_pod_config"
 ]
@@ -227,18 +229,24 @@ def base_username_mount(username, uid=1000, gid=100):
     }
 
 
-def worker_profile(options, user):
-    namespace, name = options.conda_environment.split("/")
+def options_handler(options, user):
+    namespace, environment_name = options.conda_environment.split("/")
     return functools.reduce(
         deep_merge,
         [
+            # ordering is highest to lowest precedence
+            {},
             base_node_group(options),
-            base_conda_store_mounts(namespace, name),
+            base_conda_store_mounts(namespace, environment_name),
             base_username_mount(user.name),
             config["profiles"][options.profile],
             {"environment": {**options.environment_vars}},
+            # lowest precedence: merge with the cluster-wide defaults
+            {
+                k: config["cluster"][k]
+                for k in ("worker_extra_pod_config", "scheduler_extra_pod_config")
+            },
         ],
-        {},
     )
 
 
@@ -279,7 +287,7 @@ def user_options(user):
 
     return Options(
         *args,
-        handler=worker_profile,
+        handler=options_handler,
     )
 
 
@@ -288,7 +296,7 @@ def user_options(user):
 # ============== utils ============
 def deep_merge(d1, d2):
-    """Deep merge two dictionaries.
+    """Deep merge two dictionaries. Left argument takes precedence.
 
     >>> value_1 = {
     'a': [1, 2],
     'b': {'c': 1, 'z': [5, 6]},
diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/variables.tf b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/variables.tf
index 121405a322..0b3fbcab35 100644
--- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/variables.tf
+++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/variables.tf
@@ -130,23 +130,23 @@ variable "cluster" {
   description = "dask gateway cluster defaults"
   type = object({
     # scheduler configuration
-    scheduler_cores                  = number
-    scheduler_cores_limit            = number
-    scheduler_memory                 = string
-    scheduler_memory_limit           = string
-    scheduler_extra_container_config = any
-    scheduler_extra_pod_config       = any
+    scheduler_cores                  = optional(number, 1)
+    scheduler_cores_limit            = optional(number, 1)
+    scheduler_memory                 = optional(string, "2 G")
+    scheduler_memory_limit           = optional(string, "2 G")
+    scheduler_extra_container_config = optional(any, {})
+    scheduler_extra_pod_config       = optional(any, {})
     # worker configuration
-    worker_cores                  = number
-    worker_cores_limit            = number
-    worker_memory                 = string
-    worker_memory_limit           = string
-    worker_extra_container_config = any
-    worker_extra_pod_config       = any
+    worker_cores                  = optional(number, 1)
+    worker_cores_limit            = optional(number, 1)
+    worker_memory                 = optional(string, "2 G")
+    worker_memory_limit           = optional(string, "2 G")
+    worker_extra_container_config = optional(any, {})
+    worker_extra_pod_config       = optional(any, {})
     # additional fields
-    idle_timeout      = number
-    image_pull_policy = string
-    environment       = map(string)
+    idle_timeout      = optional(number, 1800) # 30 minutes
+    image_pull_policy = optional(string, "IfNotPresent")
+    environment       = optional(map(string), {})
   })
   default = {
     # scheduler configuration
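The precedence comment in `options_handler` relies on `deep_merge` being left-biased when reduced over the list. A toy reimplementation (dict-only; the real function also handles lists) illustrating why profile values beat the trailing cluster defaults:

```python
import functools

def deep_merge(d1, d2):
    """Toy left-biased merge: keys from d1 win, nested dicts merge recursively."""
    out = dict(d2)
    for key, val in d1.items():
        if isinstance(val, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(val, out[key])
        else:
            out[key] = val
    return out

layers = [
    {"worker_extra_pod_config": {"tolerations": [{"key": "dedicated"}]}},  # profile
    {"worker_extra_pod_config": {"tolerations": []}},  # trailing cluster defaults
]
merged = functools.reduce(deep_merge, layers)
assert merged["worker_extra_pod_config"]["tolerations"] == [{"key": "dedicated"}]
```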
diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/files/jupyterhub/03-profiles.py b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/files/jupyterhub/03-profiles.py
index b298ae5ae1..83d8444ac1 100644
--- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/files/jupyterhub/03-profiles.py
+++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/files/jupyterhub/03-profiles.py
@@ -243,6 +243,25 @@ def base_profile_extra_mounts():
     }
 
 
+def node_taint_tolerations():
+    tolerations = z2jh.get_config("custom.node-taint-tolerations")
+
+    if not tolerations:
+        return {}
+
+    return {
+        "tolerations": [
+            {
+                "key": taint["key"],
+                "operator": taint["operator"],
+                "value": taint["value"],
+                "effect": taint["effect"],
+            }
+            for taint in tolerations
+        ]
+    }
+
+
 def configure_user_provisioned_repositories(username):
     # Define paths and configurations
     pvc_home_mount_path = f"home/{username}"
@@ -523,6 +542,7 @@ def render_profile(
             configure_user(username, groups),
             configure_user_provisioned_repositories(username),
             profile_kubespawner_override,
+            node_taint_tolerations(),
         ],
         {},
     )
diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/main.tf b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/main.tf
index a36090f41c..a3694cf883 100644
--- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/main.tf
+++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/main.tf
@@ -79,6 +79,7 @@ resource "helm_release" "jupyterhub" {
       jhub-apps-enabled      = var.jhub-apps-enabled
       jhub-apps-overrides    = var.jhub-apps-overrides
       initial-repositories   = var.initial-repositories
+      node-taint-tolerations = var.node-taint-tolerations
       skel-mount = {
         name      = kubernetes_config_map.etc-skel.metadata.0.name
         namespace = kubernetes_config_map.etc-skel.metadata.0.namespace
diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/variables.tf b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/variables.tf
index f395e08487..cc0c935872 100644
--- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/variables.tf
+++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/variables.tf
@@ -219,3 +219,13 @@ variable "initial-repositories" {
   type        = string
   default     = "[]"
 }
+
+variable "node-taint-tolerations" {
+  description = "Tolerations for the user node taints needed by JupyterLab user pods"
+  type = list(object({
+    key      = string
+    operator = string
+    value    = string
+    effect   = string
+  }))
+}
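Note that `node_taint_tolerations()` returns an empty dict — not `{"tolerations": []}` — when nothing is configured, so the `deep_merge` in `render_profile` leaves KubeSpawner untouched. A toy version of that guard:

```python
def node_taint_tolerations_sketch(tolerations):
    """Toy mirror of the helper above: empty config -> no override key at all."""
    if not tolerations:
        return {}
    return {"tolerations": [dict(t) for t in tolerations]}

assert node_taint_tolerations_sketch([]) == {}
```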
diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/monitoring/loki/main.tf b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/monitoring/loki/main.tf
index 8180d46fb8..3868de9cbf 100644
--- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/monitoring/loki/main.tf
+++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/monitoring/loki/main.tf
@@ -96,6 +96,22 @@ resource "helm_release" "grafana-promtail" {
   values = concat([
     file("${path.module}/values_promtail.yaml"),
     jsonencode({
+      tolerations = [
+        {
+          key      = "node-role.kubernetes.io/master"
+          operator = "Exists"
+          effect   = "NoSchedule"
+        },
+        {
+          key      = "node-role.kubernetes.io/control-plane"
+          operator = "Exists"
+          effect   = "NoSchedule"
+        },
+        {
+          operator = "Exists"
+          effect   = "NoSchedule"
+        },
+      ]
     })
   ], var.grafana-promtail-overrides)
diff --git a/src/_nebari/stages/kubernetes_services/template/rook-ceph.tf b/src/_nebari/stages/kubernetes_services/template/rook-ceph.tf
index c40b6fae33..96cf6131e4 100644
--- a/src/_nebari/stages/kubernetes_services/template/rook-ceph.tf
+++ b/src/_nebari/stages/kubernetes_services/template/rook-ceph.tf
@@ -41,6 +41,13 @@ resource "helm_release" "rook-ceph" {
         },
         csi = {
           enableRbdDriver = false, # necessary to provision block storage, but saves some cpu and memory if not needed
+          provisionerReplicas = 1, # default is 2 on different nodes
+          pluginTolerations = [
+            {
+              operator = "Exists"
+              effect   = "NoSchedule"
+            }
+          ],
         },
       })
     ],
diff --git a/src/_nebari/upgrade.py b/src/_nebari/upgrade.py
index fb95fc391f..19c8654965 100644
--- a/src/_nebari/upgrade.py
+++ b/src/_nebari/upgrade.py
@@ -25,10 +25,7 @@
 from _nebari.config import backup_configuration
 from _nebari.keycloak import get_keycloak_admin
-from _nebari.stages.infrastructure import (
-    provider_enum_default_node_groups_map,
-    provider_enum_name_map,
-)
+from _nebari.stages.infrastructure import provider_enum_default_node_groups_map
 from _nebari.utils import (
     get_k8s_version_prefix,
     get_provider_config_block_name,
@@ -36,7 +33,7 @@
     yaml,
 )
 from _nebari.version import __version__, rounded_ver_parse
-from nebari.schema import ProviderEnum, is_version_accepted
+from nebari.schema import ProviderEnum, is_version_accepted, provider_enum_name_map
 
 logger = logging.getLogger(__name__)
 
diff --git a/src/nebari/schema.py b/src/nebari/schema.py
index 6a809842d7..2f844ba6f9 100644
--- a/src/nebari/schema.py
+++ b/src/nebari/schema.py
@@ -105,3 +105,30 @@ def is_version_accepted(v):
     for deployment with the current Nebari package.
     """
     return Main.is_version_accepted(v)
+
+
+@yaml_object(yaml)
+class TaintEffectEnum(str, enum.Enum):
+    NoSchedule = "NoSchedule"
+    PreferNoSchedule = "PreferNoSchedule"
+    NoExecute = "NoExecute"
+
+    @classmethod
+    def to_yaml(cls, representer, node):
+        return representer.represent_str(node.value)
+
+
+class Taint(Base):
+    key: str
+    value: str
+    effect: TaintEffectEnum
+
+
+provider_enum_name_map: dict[ProviderEnum, str] = {
+    ProviderEnum.local: "local",
+    ProviderEnum.existing: "existing",
+    ProviderEnum.gcp: "google_cloud_platform",
+    ProviderEnum.aws: "amazon_web_services",
+    ProviderEnum.azure: "azure",
+    ProviderEnum.do: "digital_ocean",
+}
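A minimal sketch of the schema-level pieces added above (import path per this diff; the asserted values restate the map and enum from the patch, nothing new):

```python
from nebari.schema import ProviderEnum, Taint, provider_enum_name_map

assert provider_enum_name_map[ProviderEnum.aws] == "amazon_web_services"

# pydantic coerces the plain string into TaintEffectEnum on the str-enum field
taint = Taint(key="dedicated", value="worker", effect="NoSchedule")
assert taint.effect.value == "NoSchedule"
```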