
Add taint to user and worker nodes #2605

Open · wants to merge 31 commits into main from node-taint
Changes shown are from 22 of 31 commits.

Commits (31)
5000f06  save progress (Adam-D-Lewis, Jun 26, 2024)
7ce8555  Merge branch 'develop' into node-taint (Adam-D-Lewis, Aug 16, 2024)
a661514  fix node taint check (Adam-D-Lewis, Aug 16, 2024)
fb55fab  Merge branch 'develop' into node-taint (Adam-D-Lewis, Aug 19, 2024)
7f1800d  fix node taints on gcp (Adam-D-Lewis, Aug 19, 2024)
40940f6  add latest changes (Adam-D-Lewis, Aug 19, 2024)
cdac5c6  merge develop (Adam-D-Lewis, Aug 21, 2024)
6382c7b  allow daemonsets to run on user node group (Adam-D-Lewis, Aug 21, 2024)
e9d9dd9  recreate node groups when taints change (Adam-D-Lewis, Aug 21, 2024)
c55cd5f  quick attempt to get scheduler running on tainted worker node group (Adam-D-Lewis, Aug 21, 2024)
57e6e09  Merge branch 'main' into node-taint (Adam-D-Lewis, Oct 25, 2024)
a1370c9  add default options to options_handler (Adam-D-Lewis, Oct 25, 2024)
0e7e11c  add comments (Adam-D-Lewis, Oct 28, 2024)
adb9d74  rename variable (Adam-D-Lewis, Oct 31, 2024)
7944071  add comment (Adam-D-Lewis, Oct 31, 2024)
fa81fb9  make work for all providers (Adam-D-Lewis, Oct 31, 2024)
da9fd82  move var back (Adam-D-Lewis, Oct 31, 2024)
6a1f81d  move var back (Adam-D-Lewis, Oct 31, 2024)
b4c08f3  move var back (Adam-D-Lewis, Oct 31, 2024)
9bae2a1  move var back (Adam-D-Lewis, Oct 31, 2024)
b3dbeda  add reference (Adam-D-Lewis, Oct 31, 2024)
97858d0  refactor (Adam-D-Lewis, Nov 1, 2024)
4ac7b9c  various fixes for aws and azure providers (Adam-D-Lewis, Nov 1, 2024)
480647b  Merge branch 'main' into node-taint (Adam-D-Lewis, Nov 1, 2024)
f6b9a4f  add taint conversion for AWS (Adam-D-Lewis, Nov 4, 2024)
e752a3a  add DEFAULT_.*_TAINT vars (Adam-D-Lewis, Nov 4, 2024)
59daa0c  clean up fixed TODOs (Adam-D-Lewis, Nov 4, 2024)
e05f143  more clean up (Adam-D-Lewis, Nov 4, 2024)
3a4ae6b  Merge branch 'main' into node-taint (Adam-D-Lewis, Nov 4, 2024)
f3cb2e9  fix test (Adam-D-Lewis, Nov 4, 2024)
b125e8c  fix test error (Adam-D-Lewis, Nov 4, 2024)
93 changes: 62 additions & 31 deletions src/_nebari/stages/infrastructure/__init__.py
@@ -43,10 +43,33 @@ class ExistingInputVars(schema.Base):
    kube_context: str


-class DigitalOceanNodeGroup(schema.Base):

(review comment, @Adam-D-Lewis, Member Author, Aug 19, 2024: "Duplicate class, so I deleted it")

+class NodeGroup(schema.Base):
    instance: str
-    min_nodes: int
-    max_nodes: int
+    min_nodes: Annotated[int, Field(ge=0)] = 0
+    max_nodes: Annotated[int, Field(ge=1)] = 1
+    taints: Optional[List[schema.Taint]] = []

+    @field_validator("taints", mode="before")
+    def validate_taint_strings(cls, value: List[str | schema.Taint]):
+        TAINT_STR_REGEX = re.compile(r"(\w+)=(\w+):(\w+)")
+        parsed_taints = []
+        for taint in value:
+            if not isinstance(taint, (str, schema.Taint)):
+                raise ValueError(
+                    f"Unable to parse type: {type(taint)} as taint. Must be a string or Taint object."
+                )
+
+            if isinstance(taint, schema.Taint):
+                parsed_taint = taint
+            elif isinstance(taint, str):
+                match = TAINT_STR_REGEX.match(taint)
+                if not match:
+                    raise ValueError(f"Invalid taint string: {taint}")
+                key, value, effect = match.groups()
+                parsed_taint = schema.Taint(key=key, value=value, effect=effect)
+            parsed_taints.append(parsed_taint)
+
+        return parsed_taints


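The validator above accepts either schema.Taint objects or "key=value:effect" shorthand strings. Below is a minimal standalone sketch of the string form (Taint here is a hypothetical stand-in for schema.Taint, for illustration only); note that the \w+ pattern matches only word characters, so dotted or slashed keys such as nvidia.com/gpu would be rejected:

import re
from dataclasses import dataclass

# Same parsing rule as NodeGroup.validate_taint_strings above
TAINT_STR_REGEX = re.compile(r"(\w+)=(\w+):(\w+)")

@dataclass
class Taint:  # stand-in for schema.Taint, illustration only
    key: str
    value: str
    effect: str

def parse_taint(taint: str) -> Taint:
    match = TAINT_STR_REGEX.match(taint)
    if not match:
        raise ValueError(f"Invalid taint string: {taint}")
    key, value, effect = match.groups()
    return Taint(key=key, value=value, effect=effect)

print(parse_taint("dedicated=user:NoSchedule"))
# Taint(key='dedicated', value='user', effect='NoSchedule')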
@@ -55,7 +78,7 @@ class DigitalOceanInputVars(schema.Base):
    region: str
    tags: List[str]
    kubernetes_version: str
-    node_groups: Dict[str, DigitalOceanNodeGroup]
+    node_groups: Dict[str, "DigitalOceanNodeGroup"]
    kubeconfig_filename: str = get_kubeconfig_filename()


@@ -64,10 +87,26 @@ class GCPNodeGroupInputVars(schema.Base):
    instance_type: str
    min_size: int
    max_size: int
+    node_taints: List[dict]
    labels: Dict[str, str]
    preemptible: bool
    guest_accelerators: List["GCPGuestAccelerator"]
+
+    @field_validator("node_taints", mode="before")
+    def convert_taints(cls, value: Optional[List[schema.Taint]]):
+        return [
+            dict(
+                key=taint.key,
+                value=taint.value,
+                effect={
+                    schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
+                    schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
+                    schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
+                }[taint.effect],
+            )
+            for taint in value
+        ]


class GCPPrivateClusterConfig(schema.Base):
    enable_private_nodes: bool
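convert_taints exists because the GKE API spells taint effects differently from the Kubernetes API (NO_SCHEDULE vs. NoSchedule, and so on). A standalone sketch of the mapping, with a hypothetical stand-in for schema.TaintEffectEnum (illustration only):

import enum

class TaintEffectEnum(str, enum.Enum):  # stand-in for schema.TaintEffectEnum
    NoSchedule = "NoSchedule"
    PreferNoSchedule = "PreferNoSchedule"
    NoExecute = "NoExecute"

# Kubernetes-style effect -> GKE API effect, mirroring convert_taints above
GKE_EFFECTS = {
    TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
    TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
    TaintEffectEnum.NoExecute: "NO_EXECUTE",
}

print(GKE_EFFECTS[TaintEffectEnum("NoSchedule")])  # NO_SCHEDULE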
@@ -261,16 +300,14 @@ class KeyValueDict(schema.Base):
    value: str


-class DigitalOceanNodeGroup(schema.Base):
+class DigitalOceanNodeGroup(NodeGroup):
    """Representation of a node group with Digital Ocean

    - Kubernetes limits: https://docs.digitalocean.com/products/kubernetes/details/limits/
    - Available instance types: https://slugs.do-api.dev/
    """

-    instance: str
-    min_nodes: Annotated[int, Field(ge=1)] = 1
-    max_nodes: Annotated[int, Field(ge=1)] = 1


DEFAULT_DO_NODE_GROUPS = {
@@ -349,19 +386,26 @@ class GCPGuestAccelerator(schema.Base):
    count: Annotated[int, Field(ge=1)] = 1


-class GCPNodeGroup(schema.Base):
-    instance: str
-    min_nodes: Annotated[int, Field(ge=0)] = 0
-    max_nodes: Annotated[int, Field(ge=1)] = 1
+class GCPNodeGroup(NodeGroup):
    preemptible: bool = False
    labels: Dict[str, str] = {}
    guest_accelerators: List[GCPGuestAccelerator] = []


DEFAULT_GCP_NODE_GROUPS = {
    "general": GCPNodeGroup(instance="e2-standard-8", min_nodes=1, max_nodes=1),
-    "user": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
-    "worker": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
+    "user": GCPNodeGroup(
+        instance="e2-standard-4",
+        min_nodes=0,
+        max_nodes=5,
+        taints=[schema.Taint(key="dedicated", value="user", effect="NoSchedule")],
+    ),
+    "worker": GCPNodeGroup(
+        instance="e2-standard-4",
+        min_nodes=0,
+        max_nodes=5,
+        taints=[schema.Taint(key="dedicated", value="worker", effect="NoSchedule")],
+    ),
}
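Because NodeGroup.validate_taint_strings also accepts the key=value:effect shorthand, the same defaults could be expressed with plain strings, e.g. in a user's nebari-config.yaml. A sketch (illustrative only, not part of this diff):

# Equivalent construction using the string shorthand parsed by validate_taint_strings
user_group = GCPNodeGroup(
    instance="e2-standard-4",
    min_nodes=0,
    max_nodes=5,
    taints=["dedicated=user:NoSchedule"],  # parsed into schema.Taint by the validator
)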


@@ -398,10 +442,8 @@ def _check_input(cls, data: Any) -> Any:
        return data


-class AzureNodeGroup(schema.Base):
-    instance: str
-    min_nodes: int
-    max_nodes: int
+class AzureNodeGroup(NodeGroup):
+    pass


DEFAULT_AZURE_NODE_GROUPS = {
@@ -469,10 +511,7 @@ def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]:
        return value if value is None else azure_cloud.validate_tags(value)


-class AWSNodeGroup(schema.Base):
-    instance: str
-    min_nodes: int = 0
-    max_nodes: int
+class AWSNodeGroup(NodeGroup):
    gpu: bool = False
    single_subnet: bool = False
    permissions_boundary: Optional[str] = None
@@ -584,17 +623,8 @@ class ExistingProvider(schema.Base):
    schema.ProviderEnum.do: DigitalOceanProvider,
}

-provider_enum_name_map: Dict[schema.ProviderEnum, str] = {
-    schema.ProviderEnum.local: "local",
-    schema.ProviderEnum.existing: "existing",
-    schema.ProviderEnum.gcp: "google_cloud_platform",
-    schema.ProviderEnum.aws: "amazon_web_services",
-    schema.ProviderEnum.azure: "azure",
-    schema.ProviderEnum.do: "digital_ocean",
-}
-
provider_name_abbreviation_map: Dict[str, str] = {
-    value: key.value for key, value in provider_enum_name_map.items()
+    value: key.value for key, value in schema.provider_enum_name_map.items()
}

provider_enum_default_node_groups_map: Dict[schema.ProviderEnum, Any] = {
@@ -786,6 +816,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
                instance_type=node_group.instance,
                min_size=node_group.min_nodes,
                max_size=node_group.max_nodes,
+                node_taints=node_group.taints,
                preemptible=node_group.preemptible,
                guest_accelerators=node_group.guest_accelerators,
            )
@@ -86,6 +86,17 @@ resource "aws_eks_node_group" "main" {
    max_size = var.node_groups[count.index].max_size
  }

+  # TODO: add node_taints (var.node_groups.node_taints) to the node group, check the node taints below are working
+  # https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/eks_node_group#node_taints
+  # dynamic "taint" {
+  #   for_each = var.node_groups[count.index].node_taints
+  #   content {
+  #     key    = taint.value.key
+  #     value  = taint.value.value
+  #     effect = taint.value.effect
+  #   }
+  # }
+
  # Only set launch_template if its node_group counterpart parameter is not null
  dynamic "launch_template" {
    for_each = var.node_groups[count.index].launch_template != null ? [0] : []
@@ -40,6 +40,7 @@ variable "node_groups" {
    single_subnet   = bool
    launch_template = map(any)
    ami_type        = string
+    node_taints     = list(any)
  }))
}

1 change: 1 addition & 0 deletions src/_nebari/stages/infrastructure/template/azure/main.tf
@@ -38,6 +38,7 @@ module "kubernetes" {
      instance_type = config.instance
      min_size      = config.min_nodes
      max_size      = config.max_nodes
+      node_taints   = config.node_taints
    }
  ]
  vnet_subnet_id = var.vnet_subnet_id
@@ -36,6 +36,8 @@ resource "azurerm_kubernetes_cluster" "main" {
    min_count = var.node_groups[0].min_size
    max_count = var.node_groups[0].max_size
    max_pods  = var.max_pods
+    # TODO: I don't think it's possible to add node_taints to the default node pool so we should throw an error somewhere if people try to do this
+    # see https://github.com/hashicorp/terraform-provider-azurerm/issues/9183 for more info

    orchestrator_version = var.kubernetes_version
    node_labels = {
@@ -81,4 +83,5 @@ resource "azurerm_kubernetes_cluster_node_pool" "node_group" {
  orchestrator_version = var.kubernetes_version
  tags                 = var.tags
  vnet_subnet_id       = var.vnet_subnet_id
+  node_taints          = each.value.node_taints # TODO: check this is working
}
7 changes: 4 additions & 3 deletions src/_nebari/stages/infrastructure/template/azure/variables.tf
@@ -21,9 +21,10 @@ variable "kubernetes_version" {
variable "node_groups" {
description = "Azure node groups"
type = map(object({
instance = string
min_nodes = number
max_nodes = number
instance = string
min_nodes = number
max_nodes = number
node_taints = list(any)
}))
}

@@ -93,6 +93,15 @@ resource "google_container_node_pool" "main" {

  oauth_scopes = local.node_group_oauth_scopes

+  dynamic "taint" {
+    for_each = local.merged_node_groups[count.index].node_taints
+    content {
+      key    = taint.value.key
+      value  = taint.value.value
+      effect = taint.value.effect
+    }
+  }
+
  metadata = {
    disable-legacy-endpoints = "true"
  }
@@ -108,9 +117,4 @@ resource "google_container_node_pool" "main" {
    tags = var.tags
  }

-  lifecycle {
-    ignore_changes = [
-      node_config[0].taint
-    ]
-  }
}
@@ -50,20 +50,23 @@ variable "node_groups" {
      min_size = 1
      max_size = 1
      labels   = {}
+      node_taints = []
    },
    {
      name          = "user"
      instance_type = "n1-standard-2"
      min_size      = 0
      max_size      = 2
      labels        = {}
+      node_taints   = [] # TODO: Do this for other cloud providers
    },
    {
      name          = "worker"
      instance_type = "n1-standard-2"
      min_size      = 0
      max_size      = 5
      labels        = {}
+      node_taints   = []
    }
  ]
}
51 changes: 51 additions & 0 deletions src/_nebari/stages/kubernetes_services/__init__.py
@@ -453,6 +453,28 @@ def handle_units(cls, value: Optional[str]) -> float:
        return byte_unit_conversion(value, "GiB")


+class TolerationOperatorEnum(str, enum.Enum):
+    Equal = "Equal"
+    Exists = "Exists"
+
+    @classmethod
+    def to_yaml(cls, representer, node):
+        return representer.represent_str(node.value)
+
+
+class Toleration(schema.Taint):
+    operator: TolerationOperatorEnum = TolerationOperatorEnum.Equal
+
+    @classmethod
+    def from_taint(
+        cls, taint: schema.Taint, operator: None | TolerationOperatorEnum = None
+    ):
+        kwargs = {}
+        if operator:
+            kwargs["operator"] = operator
+        return cls(**taint.model_dump(), **kwargs)


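A minimal sketch of the intended from_taint usage, with self-contained stand-ins for schema.Taint and the enums (assumes pydantic v2; illustration only; every taint becomes an equivalent toleration, with operator defaulting to Equal):

import enum
from pydantic import BaseModel

class TaintEffectEnum(str, enum.Enum):  # stand-in for schema.TaintEffectEnum
    NoSchedule = "NoSchedule"
    PreferNoSchedule = "PreferNoSchedule"
    NoExecute = "NoExecute"

class Taint(BaseModel):  # stand-in for schema.Taint
    key: str
    value: str
    effect: TaintEffectEnum

class TolerationOperatorEnum(str, enum.Enum):
    Equal = "Equal"
    Exists = "Exists"

class Toleration(Taint):
    operator: TolerationOperatorEnum = TolerationOperatorEnum.Equal

    @classmethod
    def from_taint(cls, taint: Taint, operator: TolerationOperatorEnum | None = None):
        # Copy the taint's fields; override the operator only when one is given
        kwargs = {"operator": operator} if operator else {}
        return cls(**taint.model_dump(), **kwargs)

taint = Taint(key="dedicated", value="user", effect="NoSchedule")
print(Toleration.from_taint(taint).model_dump(mode="json"))
# {'key': 'dedicated', 'value': 'user', 'effect': 'NoSchedule', 'operator': 'Equal'}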
class JupyterhubInputVars(schema.Base):
    jupyterhub_theme: Dict[str, Any] = Field(alias="jupyterhub-theme")
    jupyterlab_image: ImageNameTag = Field(alias="jupyterlab-image")
@@ -478,6 +500,9 @@ class JupyterhubInputVars(schema.Base):
    cloud_provider: str = Field(alias="cloud-provider")
    jupyterlab_preferred_dir: Optional[str] = Field(alias="jupyterlab-preferred-dir")
    shared_fs_type: SharedFsEnum
+    user_taint_tolerations: Optional[List[Toleration]] = Field(
+        alias="node-taint-tolerations"
+    )

    @field_validator("jupyterhub_shared_storage", mode="before")
    @classmethod
@@ -490,6 +515,9 @@ class DaskGatewayInputVars(schema.Base):
    dask_gateway_profiles: Dict[str, Any] = Field(alias="dask-gateway-profiles")
    cloud_provider: str = Field(alias="cloud-provider")
    forwardauth_middleware_name: str = _forwardauth_middleware_name
+    worker_taint_tolerations: Optional[list[Toleration]] = Field(
+        alias="worker-taint-tolerations"
+    )


class MonitoringInputVars(schema.Base):
@@ -592,6 +620,27 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
        ):
            jupyterhub_theme.update({"version": f"v{self.config.nebari_version}"})

+        def _node_taint_tolerations(node_group_name: str) -> List[Toleration]:
+            tolerations = []
+            provider = getattr(
+                self.config, schema.provider_enum_name_map[self.config.provider]
+            )
+            if not (
+                hasattr(provider, "node_groups")
+                and provider.node_groups.get(node_group_name, {})
+                and hasattr(provider.node_groups[node_group_name], "taints")
+            ):
+                return tolerations
+            tolerations = [
+                Toleration.from_taint(taint)
+                for taint in getattr(
+                    self.config, schema.provider_enum_name_map[self.config.provider]
+                )
+                .node_groups[node_group_name]
+                .taints
+            ]
+            return tolerations

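Given the GCP defaults above, _node_taint_tolerations yields one toleration per node-group taint, which is then wired into the JupyterHub and Dask Gateway input vars below. For the default user group the result would serialize to roughly (illustrative, assuming the from_taint sketch earlier):

# Approximate JSON shape of user_taint_tolerations for the default GCP node groups
[{"key": "dedicated", "value": "user", "effect": "NoSchedule", "operator": "Equal"}]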
        kubernetes_services_vars = KubernetesServicesInputVars(
            name=self.config.project_name,
            environment=self.config.namespace,
Expand Down Expand Up @@ -646,6 +695,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
            jupyterlab_default_settings=self.config.jupyterlab.default_settings,
            jupyterlab_gallery_settings=self.config.jupyterlab.gallery_settings,
            jupyterlab_preferred_dir=self.config.jupyterlab.preferred_dir,
+            user_taint_tolerations=_node_taint_tolerations(node_group_name="user"),
            shared_fs_type=(
                # efs is equivalent to nfs in these modules
                SharedFsEnum.nfs
@@ -660,6 +710,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
            ),
            dask_gateway_profiles=self.config.profiles.model_dump()["dask_worker"],
            cloud_provider=cloud_provider,
+            worker_taint_tolerations=_node_taint_tolerations(node_group_name="worker"),
        )

        monitoring_vars = MonitoringInputVars(