Skip to content

Commit

Permalink
feat(sdk): support volume mount in tune API
Browse files Browse the repository at this point in the history
Signed-off-by: truc0 <22969604+truc0@users.noreply.github.com>
  • Loading branch information
truc0 committed Feb 7, 2025
1 parent 4d2a230 commit a0fbd20
Showing 1 changed file with 125 additions and 2 deletions.
127 changes: 125 additions & 2 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import multiprocessing
import time
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union

import grpc
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
Expand All @@ -30,6 +30,14 @@

logger = logging.getLogger(__name__)

class TuneStoragePerTrialType(TypedDict):
    """One entry of the `storage_per_trial` argument of `KatibClient.tune`.

    Declared with class syntax so that the type's name matches the public
    alias `TuneStoragePerTrialType` (the functional form previously passed
    the mismatching typename "TuneStoragePerTrial").
    """

    # Volume to attach to the Trial pod: either a ready-made V1Volume, or a
    # dict carrying at least the "name" and "type" keys, which `tune`
    # converts into a V1Volume.
    volume: Union[client.V1Volume, Dict[str, Any]]
    # Where to mount the volume in the Trial container: either a plain path
    # string, or a fully specified V1VolumeMount.
    mount_path: Union[str, client.V1VolumeMount]


class KatibClient(object):
def __init__(
Expand Down Expand Up @@ -186,6 +194,7 @@ def tune(
env_per_trial: Optional[
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
] = None,
storage_per_trial: Optional[List[TuneStoragePerTrialType]] = None,
algorithm_name: str = "random",
algorithm_settings: Union[
dict, List[models.V1beta1AlgorithmSetting], None
Expand Down Expand Up @@ -276,6 +285,21 @@ class name in this argument.
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
or a kubernetes.client.models.V1EnvFromSource (documented here:
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
storage_per_trial: List of storage configurations for each trial container.
Each element in the list should be a dictionary with two keys:
- volume: Either a kubernetes.client.V1Volume object or a dictionary
containing volume configuration with required fields:
- name: Name of the volume
- type: One of "pvc", "secret", "config_map", or "empty_dir"
Additional fields based on volume type:
- For pvc: claim_name, read_only (optional)
- For secret: secret_name, items (optional), default_mode (optional),
optional (optional)
- For config_map: config_map_name, items (optional), default_mode
(optional), optional (optional)
- For empty_dir: medium (optional), size_limit (optional)
- mount_path: Either a kubernetes.client.V1VolumeMount object or a string
specifying the path where the volume should be mounted in the container
algorithm_name: Search algorithm for the HyperParameter tuning.
algorithm_settings: Settings for the search algorithm given.
For available fields, check this doc:
Expand Down Expand Up @@ -468,6 +492,101 @@ class name in this argument.
f"Incorrect value for env_per_trial: {env_per_trial}"
)

# Translate `storage_per_trial` into the Kubernetes volumes attached to the
# Trial pod and the matching volume mounts for the Trial container.
volumes: List[client.V1Volume] = []
volume_mounts: List[client.V1VolumeMount] = []
if storage_per_trial:
    # Tolerate a single storage dict passed without the surrounding list.
    if isinstance(storage_per_trial, dict):
        storage_per_trial = [storage_per_trial]
    for storage in storage_per_trial:
        volume_spec = storage["volume"]
        if isinstance(volume_spec, client.V1Volume):
            volume = volume_spec
            # Needed below when mount_path is a plain string; previously
            # this name was only bound in the dict branch, so a V1Volume
            # combined with a string mount_path raised NameError (or
            # silently reused the name of the previous entry).
            volume_name = volume.name
        elif isinstance(volume_spec, dict):
            volume_name = volume_spec.get("name")
            volume_type = volume_spec.get("type")

            if not volume_name:
                raise ValueError(
                    "storage_per_trial['volume'] does not have a 'name' key"
                )
            if not volume_type:
                raise ValueError(
                    "storage_per_trial['volume'] does not have a 'type' key"
                )

            if volume_type == "pvc":
                volume_claim_name = volume_spec.get("claim_name")
                if not volume_claim_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a 'claim_name' key for type pvc"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
                        claim_name=volume_claim_name,
                        read_only=volume_spec.get("read_only", False),
                    ),
                )
            elif volume_type == "secret":
                secret_name = volume_spec.get("secret_name")
                # Fail fast, like the pvc branch, instead of submitting a
                # volume source that does not reference any Secret.
                if not secret_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a 'secret_name' key for type secret"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    secret=client.V1SecretVolumeSource(
                        secret_name=secret_name,
                        items=volume_spec.get("items"),
                        default_mode=volume_spec.get("default_mode"),
                        optional=volume_spec.get("optional", False),
                    ),
                )
            elif volume_type == "config_map":
                config_map_name = volume_spec.get("config_map_name")
                if not config_map_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a 'config_map_name' key for type config_map"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    config_map=client.V1ConfigMapVolumeSource(
                        name=config_map_name,
                        # Default to None (not []) so an unspecified
                        # `items` is handled the same way as for secrets.
                        items=volume_spec.get("items"),
                        default_mode=volume_spec.get("default_mode"),
                        optional=volume_spec.get("optional", False),
                    ),
                )
            elif volume_type == "empty_dir":
                volume = client.V1Volume(
                    name=volume_name,
                    empty_dir=client.V1EmptyDirVolumeSource(
                        medium=volume_spec.get("medium"),
                        size_limit=volume_spec.get("size_limit"),
                    ),
                )
            else:
                # The value is a dict, so the problem is the unsupported
                # `type`, not the Python type of the value.
                raise ValueError(
                    f"storage_per_trial['volume'] has unsupported type '{volume_type}', "
                    "it must be one of: 'pvc', 'secret', 'config_map', 'empty_dir'"
                )
        else:
            raise ValueError(
                "storage_per_trial['volume'] must be a client.V1Volume or a dict"
            )

        volumes.append(volume)

        mount_path = storage["mount_path"]
        if isinstance(mount_path, client.V1VolumeMount):
            # A caller-supplied mount is used as is; its `name` is not
            # checked against the volume's name here.
            volume_mounts.append(mount_path)
        elif isinstance(mount_path, str):
            volume_mounts.append(
                client.V1VolumeMount(name=volume_name, mount_path=mount_path)
            )
        else:
            raise ValueError(
                "storage_per_trial['mount_path'] must be a client.V1VolumeMount or a str"
            )

# Create Trial specification.
trial_spec = client.V1Job(
api_version="batch/v1",
Expand All @@ -488,8 +607,12 @@ class name in this argument.
env=env if env else None,
env_from=env_from if env_from else None,
resources=resources_per_trial,
volume_mounts=(
volume_mounts if volume_mounts else None
),
)
],
volumes=volumes if volumes else None,
),
)
),
Expand Down Expand Up @@ -576,7 +699,7 @@ class name in this argument.
f"It must also start and end with an alphanumeric character."
)
elif hasattr(e, "status") and e.status == 409:
print(f"PVC '{name}' already exists in namespace " f"{namespace}.")
print(f"PVC '{name}' already exists in namespace {namespace}.")
else:
raise RuntimeError(f"failed to create PVC. Error: {e}")

Expand Down

0 comments on commit a0fbd20

Please sign in to comment.