Skip to content

Commit 7a07866

Browse files
committed
fix: patching of session custom resources
1 parent 9cffdf9 commit 7a07866

File tree

4 files changed

+134
-27
lines changed

4 files changed

+134
-27
lines changed

components/renku_data_services/base_models/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from dataclasses import dataclass, field
88
from datetime import datetime
99
from enum import Enum, StrEnum
10-
from typing import ClassVar, Never, NewType, Optional, Protocol, Self, TypeVar, overload
10+
from typing import Annotated, ClassVar, Never, NewType, Optional, Protocol, Self, TypeVar, overload
1111

12+
from pydantic import PlainSerializer
1213
from sanic import Request
1314

1415
from renku_data_services.errors import errors
@@ -400,9 +401,11 @@ async def authenticate(self, access_token: str, request: Request) -> AnyAPIUser:
400401
...
401402

402403

403-
ResetType = NewType("ResetType", object)
404+
__ResetType = NewType("__ResetType", object)
405+
ResetType = Annotated[__ResetType, PlainSerializer(lambda _: None, return_type=None)]
404406
"""This type represents that a value that may be None should be reset back to None or null.
405407
This type should have only one instance, defined in the same file as this type.
408+
The value will be serialized by pydantic as None.
406409
"""
407410

408411
RESET: ResetType = ResetType(object())

components/renku_data_services/notebooks/core_sessions.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import random
77
import string
8-
from collections.abc import AsyncIterator, Sequence
8+
from collections.abc import AsyncIterator, Mapping, Sequence
99
from datetime import timedelta
1010
from pathlib import PurePosixPath
1111
from typing import Protocol, TypeVar, cast
@@ -20,7 +20,7 @@
2020

2121
import renku_data_services.notebooks.image_check as ic
2222
from renku_data_services.app_config import logging
23-
from renku_data_services.base_models import AnonymousAPIUser, APIUser, AuthenticatedAPIUser
23+
from renku_data_services.base_models import RESET, AnonymousAPIUser, APIUser, AuthenticatedAPIUser, ResetType
2424
from renku_data_services.base_models.metrics import MetricsService
2525
from renku_data_services.connected_services.db import ConnectedServicesRepository
2626
from renku_data_services.crc.db import ClusterRepository, ResourcePoolRepository
@@ -54,6 +54,7 @@
5454
Authentication,
5555
AuthenticationType,
5656
Culling,
57+
CullingPatch,
5758
DataSource,
5859
ExtraContainer,
5960
ExtraVolume,
@@ -69,6 +70,7 @@
6970
Requests,
7071
RequestsStr,
7172
Resources,
73+
ResourcesPatch,
7274
SecretAsVolume,
7375
SecretAsVolumeItem,
7476
Session,
@@ -91,6 +93,7 @@
9193
)
9294
from renku_data_services.notebooks.utils import (
9395
node_affinity_from_resource_class,
96+
node_affinity_patch_from_resource_class,
9497
tolerations_from_resource_class,
9598
)
9699
from renku_data_services.project.db import ProjectRepository, ProjectSessionSecretRepository
@@ -462,6 +465,21 @@ async def request_session_secret_creation(
462465
)
463466

464467

468+
def resources_patch_from_resource_class(resource_class: ResourceClass) -> ResourcesPatch:
469+
"""Convert the resource class to a k8s resources spec."""
470+
gpu_name = GpuKind.NVIDIA.value + "/gpu"
471+
resources = resources_from_resource_class(resource_class)
472+
requests: Mapping[str, Requests | RequestsStr | ResetType] | ResetType
473+
limits: Mapping[str, Limits | LimitsStr | ResetType] | ResetType
474+
defaul_requests = {"memory": RESET, "cpu": RESET, gpu_name: RESET}
475+
default_limits = {"memory": RESET, "cpu": RESET, gpu_name: RESET}
476+
if resources.requests:
477+
requests = RESET if len(resources.requests.keys()) == 0 else {**defaul_requests, **resources.requests}
478+
if resources.limits:
479+
limits = RESET if len(resources.limits.keys()) == 0 else {**default_limits, **resources.limits}
480+
return ResourcesPatch(requests=requests, limits=limits)
481+
482+
465483
def resources_from_resource_class(resource_class: ResourceClass) -> Resources:
466484
"""Convert the resource class to a k8s resources spec."""
467485
requests: dict[str, Requests | RequestsStr] = {
@@ -528,6 +546,31 @@ def get_culling(
528546
)
529547

530548

549+
def get_culling_patch(
550+
user: AuthenticatedAPIUser | AnonymousAPIUser, resource_pool: ResourcePool, nb_config: NotebooksConfig
551+
) -> CullingPatch:
552+
"""Get the patch for the culling durations of a session."""
553+
culling = get_culling(user, resource_pool, nb_config)
554+
patch = CullingPatch(
555+
maxAge=RESET,
556+
maxFailedDuration=RESET,
557+
maxHibernatedDuration=RESET,
558+
maxIdleDuration=RESET,
559+
maxStartingDuration=RESET,
560+
)
561+
if culling.maxAge:
562+
patch.maxAge = culling.maxAge
563+
if culling.maxFailedDuration:
564+
patch.maxFailedDuration = culling.maxFailedDuration
565+
if culling.maxHibernatedDuration:
566+
patch.maxHibernatedDuration = culling.maxHibernatedDuration
567+
if culling.maxIdleDuration:
568+
patch.maxIdleDuration = culling.maxIdleDuration
569+
if culling.maxStartingDuration:
570+
patch.maxStartingDuration = culling.maxStartingDuration
571+
return patch
572+
573+
531574
async def __requires_image_pull_secret(nb_config: NotebooksConfig, image: str, internal_gitlab_user: APIUser) -> bool:
532575
"""Determines if an image requires a pull secret based on its visibility and their GitLab access token."""
533576

@@ -1030,29 +1073,39 @@ async def patch_session(
10301073
)
10311074
)
10321075
rp = await rp_repo.get_resource_pool_from_class(user, body.resource_class_id)
1076+
try:
1077+
old_rp = await rp_repo.get_resource_pool_from_class(user, session.resource_class_id())
1078+
except (errors.MissingResourceError, errors.UnauthorizedError, errors.ForbiddenError):
1079+
old_rp = None
10331080
rc = rp.get_resource_class(body.resource_class_id)
10341081
if not rc:
10351082
raise errors.MissingResourceError(
10361083
message=f"The resource class you requested with ID {body.resource_class_id} does not exist"
10371084
)
1038-
# TODO: reject session classes which change the cluster
1085+
if old_rp is not None and rp.cluster != old_rp.cluster:
1086+
raise errors.ValidationError(message="Changing resource pools with different clusters is not allowed.")
10391087
if not patch.metadata:
10401088
patch.metadata = AmaltheaSessionV1Alpha1MetadataPatch()
1041-
# Patch the resource class ID in the annotations
1089+
# Patch the resource pool and class ID in the annotations
1090+
patch.metadata.annotations = {"renku.io/resource_pool_id": str(rp.id)}
10421091
patch.metadata.annotations = {"renku.io/resource_class_id": str(body.resource_class_id)}
10431092
if not patch.spec.session:
10441093
patch.spec.session = AmaltheaSessionV1Alpha1SpecSessionPatch()
1045-
patch.spec.session.resources = resources_from_resource_class(rc)
1094+
patch.spec.session.resources = resources_patch_from_resource_class(rc)
10461095
# Tolerations
10471096
tolerations = tolerations_from_resource_class(rc, nb_config.sessions.tolerations_model)
10481097
patch.spec.tolerations = tolerations
10491098
# Affinities
1050-
patch.spec.affinity = node_affinity_from_resource_class(rc, nb_config.sessions.affinity_model)
1099+
patch.spec.affinity = node_affinity_patch_from_resource_class(rc, nb_config.sessions.affinity_model)
10511100
# Priority class (if a quota is being used)
1052-
patch.spec.priorityClassName = rc.quota
1053-
patch.spec.culling = get_culling(user, rp, nb_config)
1101+
if rc.quota is None:
1102+
patch.spec.priorityClassName = RESET
1103+
patch.spec.culling = get_culling_patch(user, rp, nb_config)
1104+
# Service account name
10541105
if rp.cluster is not None:
1055-
patch.spec.service_account_name = rp.cluster.service_account_name
1106+
patch.spec.service_account_name = (
1107+
rp.cluster.service_account_name if rp.cluster.service_account_name is not None else RESET
1108+
)
10561109

10571110
# If the session is being hibernated we do not need to patch anything else that is
10581111
# not specifically called for in the request body, we can refresh things when the user resumes.
@@ -1126,6 +1179,8 @@ async def patch_session(
11261179
if image_pull_secret:
11271180
session_extras.concat(SessionExtraResources(secrets=[image_pull_secret]))
11281181
patch.spec.imagePullSecrets = [ImagePullSecret(name=image_pull_secret.name, adopt=image_pull_secret.adopt)]
1182+
else:
1183+
patch.spec.imagePullSecrets = RESET
11291184

11301185
# Construct session patch
11311186
patch.spec.extraContainers = _make_patch_spec_list(

components/renku_data_services/notebooks/crs.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pydantic.types import HashableItemType
1515
from ulid import ULID
1616

17+
from renku_data_services.base_models.core import ResetType
1718
from renku_data_services.errors import errors
1819
from renku_data_services.notebooks import apispec
1920
from renku_data_services.notebooks.constants import AMALTHEA_SESSION_GVK, JUPYTER_SESSION_GVK
@@ -33,6 +34,8 @@
3334
MatchExpression,
3435
NodeAffinity,
3536
NodeSelectorTerm,
37+
PodAffinity,
38+
PodAntiAffinity,
3639
Preference,
3740
PreferredDuringSchedulingIgnoredDuringExecutionItem,
3841
ReconcileStrategy,
@@ -375,36 +378,61 @@ def base_url(self) -> str | None:
375378
return url
376379

377380

381+
class ResourcesPatch(BaseCRD):
382+
"""Resource requests and limits patch."""
383+
384+
limits: Mapping[str, LimitsStr | Limits | ResetType] | ResetType | None = None
385+
requests: Mapping[str, RequestsStr | Requests | ResetType] | ResetType | None = None
386+
387+
378388
class AmaltheaSessionV1Alpha1SpecSessionPatch(BaseCRD):
379389
"""Patch for the main session config."""
380390

381-
resources: Resources | None = None
382-
shmSize: int | str | None = None
383-
storage: Storage | None = None
391+
resources: ResourcesPatch | ResetType | None = None
392+
shmSize: int | str | ResetType | None = None
393+
storage: Storage | ResetType | None = None
384394
imagePullPolicy: ImagePullPolicy | None = None
385-
extraVolumeMounts: list[ExtraVolumeMount] | None = None
395+
extraVolumeMounts: list[ExtraVolumeMount] | ResetType | None = None
386396

387397

388398
class AmaltheaSessionV1Alpha1MetadataPatch(BaseCRD):
389399
"""Patch for the metadata of an amalthea session."""
390400

391-
annotations: dict[str, str] | None = None
401+
annotations: dict[str, str | ResetType] | ResetType | None = None
402+
403+
404+
class AffinityPatch(BaseCRD):
405+
"""Patch for the affinity of a session."""
406+
407+
nodeAffinity: NodeAffinity | ResetType | None = None
408+
podAffinity: PodAffinity | ResetType | None = None
409+
podAntiAffinity: PodAntiAffinity | ResetType | None = None
410+
411+
412+
class CullingPatch(Culling):
413+
"""Patch for the culling durations of a session."""
414+
415+
maxAge: timedelta | ResetType | None = None # type:ignore[assignment]
416+
maxFailedDuration: timedelta | ResetType | None = None # type:ignore[assignment]
417+
maxHibernatedDuration: timedelta | ResetType | None = None # type:ignore[assignment]
418+
maxIdleDuration: timedelta | ResetType | None = None # type:ignore[assignment]
419+
maxStartingDuration: timedelta | ResetType | None = None # type:ignore[assignment]
392420

393421

394422
class AmaltheaSessionV1Alpha1SpecPatch(BaseCRD):
395423
"""Patch for the spec of an amalthea session."""
396424

397-
extraContainers: list[ExtraContainer] | None = None
398-
extraVolumes: list[ExtraVolume] | None = None
425+
extraContainers: list[ExtraContainer] | ResetType | None = None
426+
extraVolumes: list[ExtraVolume] | ResetType | None = None
399427
hibernated: bool | None = None
400-
initContainers: list[InitContainer] | None = None
401-
imagePullSecrets: list[ImagePullSecret] | None = None
402-
priorityClassName: str | None = None
403-
tolerations: list[Toleration] | None = None
404-
affinity: Affinity | None = None
428+
initContainers: list[InitContainer] | ResetType | None = None
429+
imagePullSecrets: list[ImagePullSecret] | ResetType | None = None
430+
priorityClassName: str | ResetType | None = None
431+
tolerations: list[Toleration] | ResetType | None = None
432+
affinity: AffinityPatch | ResetType | None = None
405433
session: AmaltheaSessionV1Alpha1SpecSessionPatch | None = None
406-
culling: Culling | None = None
407-
service_account_name: str | None = None
434+
culling: CullingPatch | ResetType | None = None
435+
service_account_name: str | ResetType | None = None
408436

409437

410438
class AmaltheaSessionV1Alpha1Patch(BaseCRD):
@@ -414,8 +442,12 @@ class AmaltheaSessionV1Alpha1Patch(BaseCRD):
414442
spec: AmaltheaSessionV1Alpha1SpecPatch
415443

416444
def to_rfc7386(self) -> dict[str, Any]:
417-
"""Generate the patch to be applied to the session."""
418-
return self.model_dump(exclude_none=True)
445+
"""Generate the patch to be applied to the session.
446+
447+
Note that when the value for a key in the patch is anything other than a
448+
dictionary then the rfc7386 patch will replace, not merge.
449+
"""
450+
return self.model_dump(exclude_none=True, mode="json")
419451

420452

421453
def safe_parse_duration(val: Any) -> timedelta:

components/renku_data_services/notebooks/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Utilities for notebooks."""
22

33
import renku_data_services.crc.models as crc_models
4+
from renku_data_services.base_models.core import RESET
45
from renku_data_services.notebooks.crs import (
56
Affinity,
7+
AffinityPatch,
68
MatchExpression,
79
NodeAffinity,
810
NodeSelectorTerm,
@@ -128,6 +130,21 @@ def node_affinity_from_resource_class(
128130
return affinity
129131

130132

133+
def node_affinity_patch_from_resource_class(
134+
resource_class: crc_models.ResourceClass, default_affinity: Affinity
135+
) -> AffinityPatch:
136+
"""Create a patch for the session affinity."""
137+
affinity = node_affinity_from_resource_class(resource_class, default_affinity)
138+
patch = AffinityPatch(nodeAffinity=RESET, podAffinity=RESET, podAntiAffinity=RESET)
139+
if affinity.nodeAffinity:
140+
patch.nodeAffinity = affinity.nodeAffinity
141+
if affinity.podAffinity:
142+
patch.podAffinity = affinity.podAffinity
143+
if affinity.podAntiAffinity:
144+
patch.podAntiAffinity = affinity.podAntiAffinity
145+
return patch
146+
147+
131148
def tolerations_from_resource_class(
132149
resource_class: crc_models.ResourceClass, default_tolerations: list[Toleration]
133150
) -> list[Toleration]:

0 commit comments

Comments
 (0)