diff --git a/flytekit/extras/pod_templates/__init__.py b/flytekit/extras/pod_templates/__init__.py new file mode 100644 index 0000000000..846b4723bf --- /dev/null +++ b/flytekit/extras/pod_templates/__init__.py @@ -0,0 +1,3 @@ +from flytekit.extras.pod_templates.attach_shm import attach_shm + +__all__ = ["attach_shm"] diff --git a/flytekit/extras/pod_templates/attach_shm.py b/flytekit/extras/pod_templates/attach_shm.py new file mode 100644 index 0000000000..b8a961bd84 --- /dev/null +++ b/flytekit/extras/pod_templates/attach_shm.py @@ -0,0 +1,19 @@ +from flytekit.core.pod_template import PodTemplate + + +def attach_shm(name: str, size: str) -> PodTemplate: + from kubernetes.client.models import ( + V1Container, + V1EmptyDirVolumeSource, + V1PodSpec, + V1Volume, + V1VolumeMount, + ) + + return PodTemplate( + primary_container_name=name, + pod_spec=V1PodSpec( + containers=[V1Container(name=name, volume_mounts=[V1VolumeMount(mount_path="/dev/shm", name="dshm")])], + volumes=[V1Volume(name="dshm", empty_dir=V1EmptyDirVolumeSource(medium="", size_limit=size))], + ), + ) diff --git a/tests/flytekit/unit/extras/templates/__init__.py b/tests/flytekit/unit/extras/templates/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/templates/test_pod_templates.py b/tests/flytekit/unit/extras/templates/test_pod_templates.py new file mode 100644 index 0000000000..1d8c51ab75 --- /dev/null +++ b/tests/flytekit/unit/extras/templates/test_pod_templates.py @@ -0,0 +1,14 @@ +from flytekit.extras.pod_templates import attach_shm +from flytekit.core.task import task + +def test_attach_shm(): + + shm = attach_shm("SHM", "5Gi") + assert shm.name == "SHM" + assert shm.size == "5Gi" + + def my_task(): + pass + + # Verify pod template is attached to task + assert my_task.pod_template == shm