Skip to content

Commit

Permalink
Implement script annotations for input artifacts
Browse files Browse the repository at this point in the history
Signed-off-by: Mikolaj Deja <mdeja2@bloomberg.net>
  • Loading branch information
Mikolaj Deja committed Aug 3, 2023
1 parent e0a565d commit ae8bd5d
Show file tree
Hide file tree
Showing 11 changed files with 343 additions and 67 deletions.
2 changes: 2 additions & 0 deletions src/hera/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from hera.workflows.archive import ArchiveStrategy, NoneArchiveStrategy, TarArchiveStrategy, ZipArchiveStrategy
from hera.workflows.artifact import (
Artifact,
ArtifactLoader,
ArtifactoryArtifact,
AzureArtifact,
GCSArtifact,
Expand Down Expand Up @@ -82,6 +83,7 @@
"AccessMode",
"ArchiveStrategy",
"Artifact",
"ArtifactLoader",
"ArtifactoryArtifact",
"AzureArtifact",
"AzureDiskVolumeVolume",
Expand Down
119 changes: 119 additions & 0 deletions src/hera/workflows/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
See https://argoproj.github.io/argo-workflows/walk-through/artifacts/ for a tutorial on Artifacts.
"""
from enum import Enum
from typing import List, Optional, Union, cast

from hera.shared._base_model import BaseModel
Expand All @@ -24,6 +25,13 @@
)


class ArtifactLoader(Enum):
    """Closed set of loader choices that an `Artifact` can carry in its `loader` field.

    Each member's value is the lowercase string matching its own name.
    """

    # NOTE(review): presumably means "deserialize the artifact contents as JSON" — confirm
    # against the code that consumes `Artifact.loader`.
    json = "json"

    # NOTE(review): presumably means "hand over the artifact's file contents as-is" — confirm
    # against the code that consumes `Artifact.loader`.
    file = "file"


class Artifact(BaseModel):
"""Base artifact representation."""

Expand Down Expand Up @@ -63,6 +71,9 @@ class Artifact(BaseModel):
sub_path: Optional[str] = None
"""allows the specification of an artifact from a subpath within the main source."""

loader: Optional[ArtifactLoader] = None
"""used in Artifact annotations for determining how to load the data"""

def _build_archive(self) -> Optional[_ModelArchiveStrategy]:
if self.archive is None:
return None
Expand Down Expand Up @@ -97,6 +108,18 @@ def as_name(self, name: str) -> _ModelArtifact:
artifact.name = name
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        A new list of attribute-name strings on every call.
    """
    # NOTE(review): "recurseMode" and "subPath" are camelCase, whereas the field visible on
    # this class is spelled `sub_path` and the subclass overrides in this module list
    # snake_case names. Confirm how the consumer resolves these names before relying on them.
    common_attributes = [
        "mode",
        "name",
        "optional",
        "path",
        "recurseMode",
        "subPath",
    ]
    return common_attributes


class ArtifactoryArtifact(_ModelArtifactoryArtifact, Artifact):
"""An artifact sourced from Artifactory."""
Expand All @@ -108,6 +131,11 @@ def _build_artifact(self) -> _ModelArtifact:
)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the Artifactory-specific ones.
    """
    return [
        *super().get_input_attributes(),
        "url",
        "password_secret",
        "username_secret",
    ]


class AzureArtifact(_ModelAzureArtifact, Artifact):
"""An artifact sourced from Microsoft Azure."""
Expand All @@ -123,6 +151,17 @@ def _build_artifact(self) -> _ModelArtifact:
)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the Azure-specific ones.
    """
    return [
        *super().get_input_attributes(),
        "endpoint",
        "container",
        "blob",
        "account_key_secret",
        "use_sdk_creds",
    ]


class GCSArtifact(_ModelGCSArtifact, Artifact):
"""An artifact sourced from Google Cloud Storage."""
Expand All @@ -136,6 +175,11 @@ def _build_artifact(self) -> _ModelArtifact:
)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the GCS-specific ones.
    """
    return [
        *super().get_input_attributes(),
        "bucket",
        "key",
        "service_account_key_secret",
    ]


class GitArtifact(_ModelGitArtifact, Artifact):
"""An artifact sourced from a Git repository."""
Expand All @@ -157,6 +201,23 @@ def _build_artifact(self) -> _ModelArtifact:
)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the Git-specific ones.
    """
    return [
        *super().get_input_attributes(),
        "branch",
        "depth",
        "disable_submodules",
        "fetch",
        "insecure_ignore_host_key",
        "password_secret",
        "repo",
        "revision",
        "single_branch",
        "ssh_private_key_secret",
        "username_secret",
    ]


class HDFSArtifact(Artifact):
"""A Hadoop File System artifact.
Expand Down Expand Up @@ -193,6 +254,22 @@ def _build_artifact(self) -> _ModelArtifact:
)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the HDFS-specific ones.
    """
    return [
        *super().get_input_attributes(),
        "addresses",
        "force",
        "hdfs_path",
        "hdfs_user",
        "krb_c_cache_secret",
        "krb_config_config_map",
        "krb_keytab_secret",
        "krb_realm",
        "krb_service_principal_name",
        "krb_username",
    ]


class HTTPArtifact(_ModelHTTPArtifact, Artifact):
"""An artifact sourced from an HTTP URL."""
Expand All @@ -206,6 +283,11 @@ def _build_artifact(self) -> _ModelArtifact:
)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the HTTP-specific ones.
    """
    return [
        *super().get_input_attributes(),
        "auth",
        "headers",
        "url",
    ]


class OSSArtifact(_ModelOSSArtifact, Artifact):
"""An artifact sourced from OSS."""
Expand All @@ -224,6 +306,20 @@ def _build_artifact(self) -> _ModelArtifact:
)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the OSS-specific ones.
    """
    return [
        *super().get_input_attributes(),
        "access_key_secret",
        "bucket",
        "create_bucket_if_not_present",
        "endpoint",
        "key",
        "lifecycle_rule",
        "secret_key_secret",
        "security_token",
    ]


class RawArtifact(_ModelRawArtifact, Artifact):
"""A raw bytes artifact representation."""
Expand All @@ -233,6 +329,11 @@ def _build_artifact(self) -> _ModelArtifact:
artifact.raw = _ModelRawArtifact(data=self.data)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the raw artifact's `data` attribute.
    """
    return [*super().get_input_attributes(), "data"]


class S3Artifact(_ModelS3Artifact, Artifact):
"""An artifact sourced from AWS S3."""
Expand All @@ -254,8 +355,26 @@ def _build_artifact(self) -> _ModelArtifact:
)
return artifact

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input artifact annotations may use.

    Returns:
        The inherited attribute names followed by the S3-specific ones.
    """
    return [
        *super().get_input_attributes(),
        "access_key_secret",
        "bucket",
        "create_bucket_if_not_present",
        "encryption_options",
        "endpoint",
        "insecure",
        "key",
        "region",
        "role_arn",
        "secret_key_secret",
        "use_sdk_creds",
    ]


# Public names, in order: the base class, every direct subclass of `Artifact` that exists
# when this statement runs, and finally the loader enum.
__all__ = (
    ["Artifact"]
    + [subclass.__name__ for subclass in Artifact.__subclasses__()]
    + ["ArtifactLoader"]
)
40 changes: 20 additions & 20 deletions src/hera/workflows/container_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def next(self, other: ContainerNode) -> ContainerNode:
"""Sets the given container as a dependency of this container and returns the given container.
Examples:
>>> from hera.workflows import ContainerNode
>>> a, b = ContainerNode(name="a"), ContainerNode(name="b")
>>> a.next(b)
>>> b.dependencies
['a']
from hera.workflows import ContainerNode
# normally, you use the following within a `hera.workflows.ContainerSet` context.
a, b = ContainerNode(name="a"), ContainerNode(name="b")
a.next(b)
b.dependencies # prints ['a']
"""
assert issubclass(other.__class__, ContainerNode)
if other.dependencies is None:
Expand All @@ -65,11 +65,11 @@ def __rrshift__(self, other: List[ContainerNode]) -> ContainerNode:
Practically, the `__rrshift__` allows us to express statements such as `[a, b, c] >> d`, where `d` is `self.`
Examples:
>>> from hera.workflows import ContainerNode
>>> a, b, c = ContainerNode(name="a"), ContainerNode(name="b"), ContainerNode(name="c")
>>> [a, b] >> c
>>> c.dependencies
['a', 'b']
from hera.workflows import ContainerNode
# normally, you use the following within a `hera.workflows.ContainerSet` context.
a, b, c = ContainerNode(name="a"), ContainerNode(name="b"), ContainerNode(name="c")
[a, b] >> c
c.dependencies # prints ['a', 'b']
"""
assert isinstance(other, list), f"Unknown type {type(other)} specified using reverse right bitshift operator"
for o in other:
Expand All @@ -82,11 +82,11 @@ def __rshift__(
"""Sets the given container as a dependency of this container and returns the given container.
Examples:
>>> from hera.workflows import ContainerNode
>>> a, b = ContainerNode(name="a"), ContainerNode(name="b")
>>> a >> b
>>> b.dependencies
['a']
from hera.workflows import ContainerNode
# normally, you use the following within a `hera.workflows.ContainerSet` context.
a, b = ContainerNode(name="a"), ContainerNode(name="b")
a >> b
b.dependencies # prints ['a']
"""
if isinstance(other, ContainerNode):
return self.next(other)
Expand All @@ -101,15 +101,14 @@ def __rshift__(

def _build_container_node(self) -> _ModelContainerNode:
"""Builds the generated `ContainerNode`."""
image_pull_policy = self._build_image_pull_policy()
return _ModelContainerNode(
args=self.args,
command=self.command,
dependencies=self.dependencies,
env=self._build_env(),
env_from=self._build_env_from(),
image=self.image,
image_pull_policy=None if image_pull_policy is None else image_pull_policy.value,
image_pull_policy=self._build_image_pull_policy(),
lifecycle=self.lifecycle,
liveness_probe=self.liveness_probe,
name=self.name,
Expand Down Expand Up @@ -143,9 +142,10 @@ class ContainerSet(
The containers are run within the same pod.
Examples:
>>> with ContainerSet(...) as cs:
>>> ContainerNode(...)
>>> ContainerNode(...)
--------
>>> with ContainerSet(...) as cs:
>>> ContainerNode(...)
>>> ContainerNode(...)
"""

containers: List[Union[ContainerNode, _ModelContainerNode]] = []
Expand Down
5 changes: 2 additions & 3 deletions src/hera/workflows/cron_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,8 @@ def from_dict(cls, model_dict: Dict) -> ModelMapperMixin:
"""Create a CronWorkflow from a CronWorkflow contained in a dict.
Examples:
>>> my_cron_workflow = CronWorkflow(name="my-cron-wf")
>>> my_cron_workflow == CronWorkflow.from_dict(my_cron_workflow.to_dict())
True
my_cron_workflow = CronWorkflow(...)
my_cron_workflow == CronWorkflow.from_dict(my_cron_workflow.to_dict())
"""
return cls._from_dict(model_dict, _ModelCronWorkflow)

Expand Down
12 changes: 7 additions & 5 deletions src/hera/workflows/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ class DAG(
`hera.workflows.task.Task` objects instantiated will be added to the DAG's list of Tasks.
Examples:
>>> @script()
>>> def foo() -> None:
>>> print(42)
>>> with DAG(...) as dag:
>>> foo()
--------
>>> @script()
>>> def foo() -> None:
>>> print(42)
>>>
>>> with DAG(...) as dag:
>>> foo()
"""

fail_fast: Optional[bool] = None
Expand Down
5 changes: 5 additions & 0 deletions src/hera/workflows/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,10 @@ def as_output(self) -> _ModelParameter:
value_from=self.value_from,
)

@classmethod
def get_input_attributes(cls):
    """List the attribute names that input parameter annotations may use.

    Returns:
        A new list of attribute-name strings on every call.
    """
    parameter_attributes = [
        "enum",
        "description",
        "default",
        "name",
    ]
    return parameter_attributes


__all__ = ["Parameter"]
Loading

0 comments on commit ae8bd5d

Please sign in to comment.