diff --git a/src/hera/workflows/__init__.py b/src/hera/workflows/__init__.py index b133f5488..7ca9e54a6 100644 --- a/src/hera/workflows/__init__.py +++ b/src/hera/workflows/__init__.py @@ -9,6 +9,7 @@ from hera.workflows.archive import ArchiveStrategy, NoneArchiveStrategy, TarArchiveStrategy, ZipArchiveStrategy from hera.workflows.artifact import ( Artifact, + ArtifactLoader, ArtifactoryArtifact, AzureArtifact, GCSArtifact, @@ -82,6 +83,7 @@ "AccessMode", "ArchiveStrategy", "Artifact", + "ArtifactLoader", "ArtifactoryArtifact", "AzureArtifact", "AzureDiskVolumeVolume", diff --git a/src/hera/workflows/artifact.py b/src/hera/workflows/artifact.py index 2b861009c..c50d5fe9e 100644 --- a/src/hera/workflows/artifact.py +++ b/src/hera/workflows/artifact.py @@ -2,6 +2,7 @@ See https://argoproj.github.io/argo-workflows/walk-through/artifacts/ for a tutorial on Artifacts. """ +from enum import Enum from typing import List, Optional, Union, cast from hera.shared._base_model import BaseModel @@ -24,6 +25,13 @@ ) +class ArtifactLoader(Enum): + """Enum for artifact loader options.""" + + json = "json" + file = "file" + + class Artifact(BaseModel): """Base artifact representation.""" @@ -63,6 +71,9 @@ class Artifact(BaseModel): sub_path: Optional[str] = None """allows the specification of an artifact from a subpath within the main source.""" + loader: Optional[ArtifactLoader] = None + """used in Artifact annotations for determining how to load the data""" + def _build_archive(self) -> Optional[_ModelArchiveStrategy]: if self.archive is None: return None @@ -97,6 +108,18 @@ def as_name(self, name: str) -> _ModelArtifact: artifact.name = name return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return [ + "mode", + "name", + "optional", + "path", + "recurseMode", + "subPath", + ] + class ArtifactoryArtifact(_ModelArtifactoryArtifact, Artifact): """An artifact sourced from Artifactory.""" @@ -108,6 +131,11 @@ def _build_artifact(self) -> _ModelArtifact: ) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + ["url", "password_secret", "username_secret"] + class AzureArtifact(_ModelAzureArtifact, Artifact): """An artifact sourced from Microsoft Azure.""" @@ -123,6 +151,17 @@ def _build_artifact(self) -> _ModelArtifact: ) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + [ + "endpoint", + "container", + "blob", + "account_key_secret", + "use_sdk_creds", + ] + class GCSArtifact(_ModelGCSArtifact, Artifact): """An artifact sourced from Google Cloud Storage.""" @@ -136,6 +175,11 @@ def _build_artifact(self) -> _ModelArtifact: ) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + ["bucket", "key", "service_account_key_secret"] + class GitArtifact(_ModelGitArtifact, Artifact): """An artifact sourced from GitHub.""" @@ -157,6 +201,23 @@ def _build_artifact(self) -> _ModelArtifact: ) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + [ + "branch", + "depth", + "disable_submodules", + "fetch", + "insecure_ignore_host_key", + "password_secret", + "repo", + 
"revision", + "single_branch", + "ssh_private_key_secret", + "username_secret", + ] + class HDFSArtifact(Artifact): """A Hadoop File System artifact. @@ -193,6 +254,22 @@ def _build_artifact(self) -> _ModelArtifact: ) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + [ + "addresses", + "force", + "hdfs_path", + "hdfs_user", + "krb_c_cache_secret", + "krb_config_config_map", + "krb_keytab_secret", + "krb_realm", + "krb_service_principal_name", + "krb_username", + ] + class HTTPArtifact(_ModelHTTPArtifact, Artifact): """An artifact sourced from an HTTP URL.""" @@ -206,6 +283,11 @@ def _build_artifact(self) -> _ModelArtifact: ) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + ["auth", "headers", "url"] + class OSSArtifact(_ModelOSSArtifact, Artifact): """An artifact sourced from OSS.""" @@ -224,6 +306,20 @@ def _build_artifact(self) -> _ModelArtifact: ) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + [ + "access_key_secret", + "bucket", + "create_bucket_if_not_present", + "endpoint", + "key", + "lifecycle_rule", + "secret_key_secret", + "security_token", + ] + class RawArtifact(_ModelRawArtifact, Artifact): """A raw bytes artifact representation.""" @@ -233,6 +329,11 @@ def _build_artifact(self) -> _ModelArtifact: artifact.raw = _ModelRawArtifact(data=self.data) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + ["data"] + class S3Artifact(_ModelS3Artifact, Artifact): """An artifact sourced from AWS S3.""" @@ -254,8 +355,26 @@ def _build_artifact(self) -> _ModelArtifact: ) return artifact + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input artifact annotations.""" + return super().get_input_attributes() + [ + "access_key_secret", + "bucket", + "create_bucket_if_not_present", + "encryption_options", + "endpoint", + "insecure", + "key", + "region", + "role_arn", + "secret_key_secret", + "use_sdk_creds", + ] + __all__ = [ "Artifact", *[c.__name__ for c in Artifact.__subclasses__()], + "ArtifactLoader", ] diff --git a/src/hera/workflows/container_set.py b/src/hera/workflows/container_set.py index 8324a2486..ecb37038f 100644 --- a/src/hera/workflows/container_set.py +++ b/src/hera/workflows/container_set.py @@ -45,11 +45,11 @@ def next(self, other: ContainerNode) -> ContainerNode: """Sets the given container as a dependency of this container and returns the given container. Examples: - >>> from hera.workflows import ContainerNode - >>> a, b = ContainerNode(name="a"), ContainerNode(name="b") - >>> a.next(b) - >>> b.dependencies - ['a'] + from hera.workflows import ContainerNode + # normally, you use the following within a `hera.workflows.ContainerSet` context. 
+ a, b = ContainerNode(name="a"), ContainerNode(name="b") + a.next(b) + b.dependencies # prints ['a'] """ assert issubclass(other.__class__, ContainerNode) if other.dependencies is None: @@ -65,11 +65,11 @@ def __rrshift__(self, other: List[ContainerNode]) -> ContainerNode: Practically, the `__rrshift__` allows us to express statements such as `[a, b, c] >> d`, where `d` is `self.` Examples: - >>> from hera.workflows import ContainerNode - >>> a, b, c = ContainerNode(name="a"), ContainerNode(name="b"), ContainerNode(name="c") - >>> [a, b] >> c - >>> c.dependencies - ['a', 'b'] + from hera.workflows import ContainerNode + # normally, you use the following within a `hera.workflows.ContainerSet` context. + a, b, c = ContainerNode(name="a"), ContainerNode(name="b"), ContainerNode(name="c") + [a, b] >> c + c.dependencies # prints ['a', 'b'] """ assert isinstance(other, list), f"Unknown type {type(other)} specified using reverse right bitshift operator" for o in other: @@ -82,11 +82,11 @@ def __rshift__( """Sets the given container as a dependency of this container and returns the given container. Examples: - >>> from hera.workflows import ContainerNode - >>> a, b = ContainerNode(name="a"), ContainerNode(name="b") - >>> a >> b - >>> b.dependencies - ['a'] + from hera.workflows import ContainerNode + # normally, you use the following within a `hera.workflows.ContainerSet` context. + a, b = ContainerNode(name="a"), ContainerNode(name="b") + a >> b + b.dependencies # prints ['a'] """ if isinstance(other, ContainerNode): return self.next(other) @@ -101,7 +101,6 @@ def __rshift__( def _build_container_node(self) -> _ModelContainerNode: """Builds the generated `ContainerNode`.""" - image_pull_policy = self._build_image_pull_policy() return _ModelContainerNode( args=self.args, command=self.command, @@ -109,7 +108,7 @@ def _build_container_node(self) -> _ModelContainerNode: env=self._build_env(), env_from=self._build_env_from(), image=self.image, - image_pull_policy=None if image_pull_policy is None else image_pull_policy.value, + image_pull_policy=self._build_image_pull_policy(), lifecycle=self.lifecycle, liveness_probe=self.liveness_probe, name=self.name, @@ -143,9 +142,10 @@ class ContainerSet( The containers are run within the same pod. Examples: - >>> with ContainerSet(...) as cs: - >>> ContainerNode(...) - >>> ContainerNode(...) + -------- + >>> with ContainerSet(...) as cs: + >>> ContainerNode(...) + >>> ContainerNode(...) """ containers: List[Union[ContainerNode, _ModelContainerNode]] = [] diff --git a/src/hera/workflows/cron_workflow.py b/src/hera/workflows/cron_workflow.py index c70c44112..62fa1c6ef 100644 --- a/src/hera/workflows/cron_workflow.py +++ b/src/hera/workflows/cron_workflow.py @@ -183,9 +183,8 @@ def from_dict(cls, model_dict: Dict) -> ModelMapperMixin: """Create a CronWorkflow from a CronWorkflow contained in a dict. Examples: - >>> my_cron_workflow = CronWorkflow(name="my-cron-wf") - >>> my_cron_workflow == CronWorkflow.from_dict(my_cron_workflow.to_dict()) - True + my_cron_workflow = CronWorkflow(...) + my_cron_workflow == CronWorkflow.from_dict(my_cron_workflow.to_dict()) """ return cls._from_dict(model_dict, _ModelCronWorkflow) diff --git a/src/hera/workflows/dag.py b/src/hera/workflows/dag.py index 4bbbfd25e..803022858 100644 --- a/src/hera/workflows/dag.py +++ b/src/hera/workflows/dag.py @@ -29,11 +29,13 @@ class DAG( `hera.workflows.task.Task` objects instantiated will be added to the DAG's list of Tasks. 
Examples: - >>> @script() - >>> def foo() -> None: - >>> print(42) - >>> with DAG(...) as dag: - >>> foo() + -------- + >>> @script() + >>> def foo() -> None: + >>> print(42) + >>> + >>> with DAG(...) as dag: + >>> foo() """ fail_fast: Optional[bool] = None diff --git a/src/hera/workflows/parameter.py b/src/hera/workflows/parameter.py index 9e74bb0a6..1996eee4d 100644 --- a/src/hera/workflows/parameter.py +++ b/src/hera/workflows/parameter.py @@ -101,5 +101,10 @@ def as_output(self) -> _ModelParameter: value_from=self.value_from, ) + @classmethod + def get_input_attributes(cls): + """Return the attributes used for input parameter annotations.""" + return ["enum", "description", "default", "name"] + __all__ = ["Parameter"] diff --git a/src/hera/workflows/runner.py b/src/hera/workflows/runner.py index b959700f3..5f2e99900 100644 --- a/src/hera/workflows/runner.py +++ b/src/hera/workflows/runner.py @@ -12,7 +12,8 @@ from typing_extensions import get_args, get_origin from hera.shared.serialization import serialize -from hera.workflows import Parameter +from hera.workflows import Artifact, Parameter +from hera.workflows.artifact import ArtifactLoader try: from typing import Annotated # type: ignore @@ -63,7 +64,7 @@ def _parse(value, key, f): The parsed value. """ - if _is_str_kwarg_of(key, f): + if _is_str_kwarg_of(key, f) or _is_artifact_loaded(key, f): return value try: return json.loads(value) @@ -84,15 +85,38 @@ def _is_str_kwarg_of(key: str, f: Callable): return False -def _map_keys(function: Callable, kwargs: dict): - """Change the kwargs's keys to use the python name instead of the parameter name which could be kebab case.""" - if os.environ.get("hera__script_annotations", None) is not None: +def _is_artifact_loaded(key, f): + """Check if param `key` of function `f` is actually an Artifact that has already been loaded.""" + param = inspect.signature(f).parameters[key] + return ( + get_origin(param.annotation) is Annotated + and isinstance(get_args(param.annotation)[1], Artifact) + and get_args(param.annotation)[1].loader == ArtifactLoader.json.value + ) + + +def _map_keys(function: Callable, kwargs: dict) -> dict: + """Change the kwargs's keys to use the python name instead of the parameter name which could be kebab case. + + For Parameters, update their name to not contain kebab-case in Python but allow it in YAML. + For Artifacts, parse the contained info to get the path or object as needed. 
+ """ + if os.environ.get("hera__script_annotations", None) is None: return {key.replace("-", "_"): value for key, value in kwargs.items()} mapped_kwargs = {} for param_name, param in inspect.signature(function).parameters.items(): if get_origin(param.annotation) is Annotated and isinstance(get_args(param.annotation)[1], Parameter): mapped_kwargs[param_name] = kwargs[get_args(param.annotation)[1].name] + elif get_origin(param.annotation) is Annotated and isinstance(get_args(param.annotation)[1], Artifact): + if get_args(param.annotation)[1].loader == ArtifactLoader.json.value: + path = Path(get_args(param.annotation)[1].path) + mapped_kwargs[param_name] = json.load(path.open()) + elif get_args(param.annotation)[1].loader == ArtifactLoader.file.value: + path = Path(get_args(param.annotation)[1].path) + mapped_kwargs[param_name] = path.read_text() + elif get_args(param.annotation)[1].loader is None: + mapped_kwargs[param_name] = get_args(param.annotation)[1].path else: mapped_kwargs[param_name] = kwargs[param_name] return mapped_kwargs diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 799c7e70c..d25665ef8 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -14,6 +14,7 @@ Callable, List, Optional, + Tuple, Type, TypeVar, Union, @@ -36,6 +37,9 @@ VolumeMountMixin, ) from hera.workflows._unparse import roundtrip +from hera.workflows.artifact import ( + Artifact, +) from hera.workflows.models import ( EnvVar, Inputs as ModelInputs, @@ -144,21 +148,37 @@ def _constructor_validate(cls, values): def _build_inputs(self) -> Optional[ModelInputs]: inputs = super()._build_inputs() - func_parameters = _get_parameters_from_callable(self.source) if callable(self.source) else None - - if inputs is None and func_parameters is None: + func_parameters: List[Parameter] = [] + func_artifacts: List[Artifact] = [] + if callable(self.source): + if global_config.experimental_features["script_annotations"]: + func_parameters, func_artifacts = _get_parameters_and_artifacts_from_callable(self.source) + else: + func_parameters = _get_parameters_from_callable(self.source) + + if inputs is None and not func_parameters and not func_artifacts: return None - elif func_parameters is None: + elif not func_parameters and not func_artifacts: return inputs elif inputs is None: - inputs = ModelInputs(parameters=func_parameters) + inputs = ModelInputs( + parameters=func_parameters or None, artifacts=[a._build_artifact() for a in func_artifacts] or None + ) already_set_params = {p.name for p in inputs.parameters or []} already_set_artifacts = {p.name for p in inputs.artifacts or []} for param in func_parameters: if param.name not in already_set_params and param.name not in already_set_artifacts: - inputs.parameters = [param] if inputs.parameters is None else inputs.parameters + [param] + if inputs.parameters is None: + inputs.parameters = [] + inputs.parameters.append(param) + + for artifact in func_artifacts: + if artifact.name not in already_set_params and artifact.name not in already_set_artifacts: + if inputs.artifacts is None: + inputs.artifacts = [] + inputs.artifacts.append(artifact._build_artifact()) return inputs @@ -245,7 +265,7 @@ def _build_script(self) -> _ModelScriptTemplate: ) -def _get_parameters_from_callable(source: Callable) -> Optional[List[Parameter]]: +def _get_parameters_from_callable(source: Callable) -> List[Parameter]: # If there are any kwargs arguments associated with the function signature, # we store these as we can set them as default values for 
argo arguments
     parameters = []
@@ -259,25 +279,55 @@ def _get_parameters_from_callable(source: Callable) -> Optional[List[Parameter]]
         param = Parameter(name=p.name, default=default)
         parameters.append(param)
 
-        if not global_config.experimental_features["script_annotations"]:
-            continue
+    return parameters
+
 
-        if get_origin(p.annotation) is not Annotated or not isinstance(get_args(p.annotation)[1], Parameter):
-            continue
+def _get_parameters_and_artifacts_from_callable(
+    source: Callable,
+) -> Tuple[List[Parameter], List[Artifact]]:
+    # If there are any kwargs arguments associated with the function signature,
+    # we store these as we can set them as default values for argo arguments
+    parameters = []
+    artifacts = []
+
+    for p in inspect.signature(source).parameters.values():
+        if get_origin(p.annotation) is Annotated and isinstance(get_args(p.annotation)[1], Artifact):
+            annotation = get_args(p.annotation)[1]
+            mytype = type(annotation)
+            kwargs = {}
+            for attr in mytype.get_input_attributes():
+                if hasattr(annotation, attr) and getattr(annotation, attr) is not None:
+                    kwargs[attr] = getattr(annotation, attr)
+
+            artifact = mytype(**kwargs)
+            artifacts.append(artifact)
+        else:
+            if p.default != inspect.Parameter.empty and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
+                default = p.default
+            else:
+                default = MISSING
+
+            param = Parameter(name=p.name, default=default)
+            parameters.append(param)
+
+            if get_origin(p.annotation) is not Annotated:
+                continue
 
-        annotation = get_args(p.annotation)[1]
-        for attr in ["enum", "description", "default", "name"]:
-            if not hasattr(annotation, attr) or getattr(annotation, attr) is None:
+            annotation = get_args(p.annotation)[1]
+            if not isinstance(annotation, Parameter):
                 continue
 
-            if attr == "default" and param.default is not None:
-                raise ValueError(
-                    "The default cannot be set via both the function parameter default and the annotation's default"
-                )
+            for attr in Parameter.get_input_attributes():
+                if not hasattr(annotation, attr) or getattr(annotation, attr) is None:
+                    continue
 
-            setattr(param, attr, getattr(annotation, attr))
+                if attr == "default" and param.default is not None:
+                    raise ValueError(
+                        "The default cannot be set via both the function parameter default and the annotation's default"
+                    )
+                setattr(param, attr, getattr(annotation, attr))
 
-    return parameters or None
+    return parameters, artifacts
 
 
 FuncIns = ParamSpec("FuncIns")  # For input types of given func to script decorator
diff --git a/src/hera/workflows/task.py b/src/hera/workflows/task.py
index 64d4934dd..a2c9f1aa5 100644
--- a/src/hera/workflows/task.py
+++ b/src/hera/workflows/task.py
@@ -49,7 +49,91 @@ class Task(
     ParameterMixin,
     ItemMixin,
 ):
-    """Task is used to run a given template within a DAG. Must be instantiated under a DAG context."""
+    r"""Task is used to run a given template within a DAG. Must be instantiated under a DAG context.
+
+    ## Dependencies
+
+    Any `Tasks` without a dependency defined will start immediately.
+
+    Dependencies between Tasks can be described using the convenience syntax `>>`, for example:
+
+    ```py
+    A = Task(...)
+    B = Task(...)
+    A >> B
+    ```
+
+    describes the relationships:
+
+    * "A has no dependencies" (so starts immediately)
+    * "B depends on A"
+
+    As a diagram:
+
+    ```
+    A
+    |
+    B
+    ```
+
+    `A >> B` is equivalent to `A.next(B)`.
+
+    ## Lists of Tasks
+
+    A list of Tasks used with the rshift syntax describes an "AND" dependency between the single Task on the left of
+    `>>` and the list of Tasks to the right of `>>` (or vice versa). A list of Tasks on both sides of `>>` is not supported.
+
+    For example:
+
+    ```
+    A = Task(...)
+    B = Task(...)
+    C = Task(...)
+    D = Task(...)
+    A >> [B, C] >> D
+    ```
+
+    describes the relationships:
+
+    * "A has no dependencies"
+    * "B AND C depend on A"
+    * "D depends on B AND C"
+
+    As a diagram:
+
+    ```
+      A
+     / \
+    B   C
+     \ /
+      D
+    ```
+
+    Dependencies can be described over multiple statements:
+
+    ```
+    A = Task(...)
+    B = Task(...)
+    C = Task(...)
+    D = Task(...)
+    A >> [C, D]
+    B >> [C, D]
+    ```
+
+    describes the relationships:
+
+    * "A and B have no dependencies"
+    * "C depends on A AND B"
+    * "D depends on A AND B"
+
+    As a diagram:
+
+    ```
+    A   B
+    | X |
+    C   D
+    ```
+    """
 
     dependencies: Optional[List[str]] = None
     depends: Optional[str] = None
diff --git a/src/hera/workflows/workflow.py b/src/hera/workflows/workflow.py
index aa5eb4f55..6b1350e70 100644
--- a/src/hera/workflows/workflow.py
+++ b/src/hera/workflows/workflow.py
@@ -405,13 +405,6 @@ def wait(self, poll_interval: int = 5) -> TWorkflow:
         assert self.namespace is not None, "workflow namespace not defined"
         assert self.name is not None, "workflow name not defined"
 
-        # here we use the sleep interval to wait for the workflow post creation. This is to address a potential
-        # race conditions such as:
-        # 1. Argo server says "workflow was accepted" but the workflow is not yet created
-        # 2. Hera wants to verify the status of the workflow, but it's not yet defined because it's not created
-        # 3. Argo finally creates the workflow
-        # 4. Hera throws an `AssertionError` because the phase assertion fails
-        time.sleep(poll_interval)
 
         wf = self.workflows_service.get_workflow(self.name, namespace=self.namespace)
         assert wf.metadata.name is not None, f"workflow name not defined for workflow {self.name}"
@@ -465,9 +458,8 @@ def from_dict(cls, model_dict: Dict) -> ModelMapperMixin:
         """Create a Workflow from a Workflow contained in a dict.
 
         Examples:
-            >>> my_workflow = Workflow(name="my-workflow")
-            >>> my_workflow == Workflow.from_dict(my_workflow.to_dict())
-            True
+            my_workflow = Workflow(...)
+            my_workflow == Workflow.from_dict(my_workflow.to_dict())
         """
         return cls._from_dict(model_dict, _ModelWorkflow)
 
diff --git a/src/hera/workflows/workflow_template.py b/src/hera/workflows/workflow_template.py
index df9b7e6ba..0766bb4c0 100644
--- a/src/hera/workflows/workflow_template.py
+++ b/src/hera/workflows/workflow_template.py
@@ -116,9 +116,8 @@ def from_dict(cls, model_dict: Dict) -> ModelMapperMixin:
         """Create a WorkflowTemplate from a WorkflowTemplate contained in a dict.
 
         Examples:
-            >>> my_workflow_template = WorkflowTemplate(name="my-wft")
-            >>> my_workflow_template == WorkflowTemplate.from_dict(my_workflow_template.to_dict())
-            True
+            my_workflow_template = WorkflowTemplate(...)
+            my_workflow_template == WorkflowTemplate.from_dict(my_workflow_template.to_dict())
         """
         return cls._from_dict(model_dict, _ModelWorkflowTemplate)
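A minimal usage sketch of the artifact annotations introduced above, based on the handling added in `script.py` and `runner.py`. The function, artifact names, and paths are illustrative only, and the experimental `script_annotations` feature flag is assumed to be enabled:

```py
# Illustrative sketch only; artifact names and paths are hypothetical.
from typing import Annotated  # Python >= 3.9; otherwise import from typing_extensions

from hera.shared import global_config
from hera.workflows import Artifact, ArtifactLoader, script

# the annotation handling is gated behind the experimental feature flag
global_config.experimental_features["script_annotations"] = True


@script()
def consume(
    # loader=ArtifactLoader.json: the runner json-loads the file at `path` and passes the object
    data: Annotated[dict, Artifact(name="data", path="/tmp/data.json", loader=ArtifactLoader.json)],
    # loader=ArtifactLoader.file: the runner passes the file's contents as a string
    text: Annotated[str, Artifact(name="text", path="/tmp/text.txt", loader=ArtifactLoader.file)],
    # no loader: the runner passes the artifact's path itself
    raw_path: Annotated[str, Artifact(name="raw", path="/tmp/raw.bin")],
):
    print(data, text, raw_path)
```

At template-build time `_get_parameters_and_artifacts_from_callable` turns each annotation into an input artifact, and at run time `_map_keys` resolves the function arguments according to the annotation's `loader`.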