Skip to content

Commit

Permalink
Factor out common workflow annotation iteration
Browse files Browse the repository at this point in the history
Multiple methods in _io_mixins resolve the same problem of iterating
through all fields, constructing a workflow annotation based on the
user's annotations, defaulting the name if unset or falling back to a
Parameter if no annotation is found. Factor that logic out into an
iterator, simplifying the various methods and reducing the risk of
discrepancies and bugs in future.

Signed-off-by: Alice Purcell <alicederyn@gmail.com>
  • Loading branch information
alicederyn committed Sep 5, 2024
1 parent 70fd617 commit 519f83d
Showing 1 changed file with 45 additions and 85 deletions.
130 changes: 45 additions & 85 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import sys
import warnings
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from pydantic.fields import FieldInfo

from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields
from hera.shared._type_util import get_workflow_annotation
from hera.shared.serialization import MISSING, serialize
Expand Down Expand Up @@ -39,6 +41,22 @@
BaseModel = object # type: ignore


def _get_workflow_annotations(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]:
"""Constructs a workflow annotation for all Pydantic fields based on their annotations.
If a field has a workflow annotation, a copy will be returned, with name added if missing.
Otherwise, a Parameter annotation will be constructed.
"""
annotations = get_field_annotations(cls)
for field, field_info in get_fields(cls).items():
if annotation := get_workflow_annotation(annotations[field]):
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or field
yield field, field_info, annotation_copy
else:
yield field, field_info, Parameter(name=field)


class InputMixin(BaseModel):
def __new__(cls, **kwargs):
if _context.declaring:
Expand All @@ -62,15 +80,10 @@ def __init__(self, /, **kwargs):
@classmethod
def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Parameter]:
parameters = []
annotations = get_field_annotations(cls)

for field, field_info in get_fields(cls).items():
param = get_workflow_annotation(annotations[field])
for field, field_info, param in _get_workflow_annotations(cls):
if isinstance(param, Parameter):
# Copy so as to not modify the Input fields themselves
param = param.copy()
if param.name is None:
param.name = field
if param.default is not None:
warnings.warn(
"Using the default field for Parameters in Annotations is deprecated since v5.16"
Expand All @@ -82,29 +95,15 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet
# Serialize the value (usually done in Parameter's validator)
param.default = serialize(field_info.default) # type: ignore
parameters.append(param)
elif param is None:
# Create a Parameter from basic type annotations
default = getattr(object_override, field) if object_override else field_info.default

# For users on Pydantic 2 but using V1 BaseModel, we still need to check if `default` is None
if default is None or default == PydanticUndefined:
default = MISSING

parameters.append(Parameter(name=field, default=default))

return parameters

@classmethod
def _get_artifacts(cls) -> List[Artifact]:
artifacts = []
annotations = get_field_annotations(cls)

for field in get_fields(cls):
if (artifact := get_workflow_annotation(annotations[field])) and isinstance(artifact, Artifact):
# Copy so as to not modify the Input fields themselves
artifact = artifact.copy()
if artifact.name is None:
artifact.name = field
for _, _, artifact in _get_workflow_annotations(cls):
if isinstance(artifact, Artifact):
if artifact.path is None:
artifact.path = artifact._get_default_inputs_path()
artifacts.append(artifact)
Expand All @@ -118,44 +117,31 @@ def _get_inputs(cls) -> List[Union[Artifact, Parameter]]:
def _get_as_templated_arguments(cls) -> Self:
"""Returns the Input with templated values to propagate through a DAG/Steps function."""
object_dict = {}
cls_fields = get_fields(cls)
annotations = get_field_annotations(cls)

for field in cls_fields:
if param_or_artifact := get_workflow_annotation(annotations[field]):
name = param_or_artifact.name or field
if isinstance(param_or_artifact, Parameter):
object_dict[field] = "{{inputs.parameters." + f"{name}" + "}}"
else:
object_dict[field] = "{{inputs.artifacts." + f"{name}" + "}}"
else:
object_dict[field] = "{{inputs.parameters." + f"{field}" + "}}"
for field, _, annotation in _get_workflow_annotations(cls):
input_type = "parameters" if isinstance(annotation, Parameter) else "artifacts"
object_dict[field] = f"{{{{inputs.{input_type}.{annotation.name}}}}}"

return cls.construct(None, **object_dict)

def _get_as_arguments(self) -> ModelArguments:
params = []
artifacts = []
annotations = get_field_annotations(type(self))

if isinstance(self, V1BaseModel):
self_dict = self.dict()
elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel):
self_dict = self.model_dump()

for field in get_fields(type(self)):
for field, _, annotation in _get_workflow_annotations(type(self)):
# The value may be a static value (of any time) if it has a default value, so we need to serialize it
# If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")``
templated_value = serialize(self_dict[field])

if param_or_artifact := get_workflow_annotation(annotations[field]):
name = param_or_artifact.name or field
if isinstance(param_or_artifact, Parameter):
params.append(ModelParameter(name=name, value=templated_value))
else:
artifacts.append(ModelArtifact(name=name, from_=templated_value))
if isinstance(annotation, Parameter):
params.append(ModelParameter(name=annotation.name, value=templated_value))
else:
params.append(ModelParameter(name=field, value=templated_value))
artifacts.append(ModelArtifact(name=annotation.name, from_=templated_value))

return ModelArguments(parameters=params or None, artifacts=artifacts or None)

Expand All @@ -181,43 +167,22 @@ def __init__(self, /, **kwargs):
@classmethod
def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]:
outputs: List[Union[Artifact, Parameter]] = []
annotations = get_field_annotations(cls)

model_fields = get_fields(cls)

for field in model_fields:
for field, field_info, annotation in _get_workflow_annotations(cls):
if field in {"exit_code", "result"}:
continue
if param_or_artifact := get_workflow_annotation(annotations[field]):
param_or_artifact = param_or_artifact.copy()
param_or_artifact.name = param_or_artifact.name or field
if isinstance(param_or_artifact, Parameter):
if param_or_artifact.default is None:
default = model_fields[field].default
if default is not None and default != PydanticUndefined:
param_or_artifact.default = serialize(default)
if add_missing_path and (
param_or_artifact.value_from is None or param_or_artifact.value_from.path is None
):
param_or_artifact.value_from = ValueFrom(
path=f"/tmp/hera-outputs/parameters/{param_or_artifact.name}"
)
outputs.append(param_or_artifact)
else:
if add_missing_path and param_or_artifact.path is None:
param_or_artifact.path = f"/tmp/hera-outputs/artifacts/{param_or_artifact.name}"
outputs.append(param_or_artifact)
if isinstance(annotation, Parameter):
if annotation.default is None:
default = field_info.default
if default is not None and default != PydanticUndefined:
annotation.default = serialize(default)

if add_missing_path and (annotation.value_from is None or annotation.value_from.path is None):
annotation.value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{annotation.name}")
else:
# Create a Parameter from basic type annotations
default = model_fields[field].default
if default is None or default == PydanticUndefined:
default = MISSING

value_from = None
if add_missing_path:
value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{field}")

outputs.append(Parameter(name=field, default=default, value_from=value_from))
if add_missing_path and annotation.path is None:
annotation.path = f"/tmp/hera-outputs/artifacts/{annotation.name}"
outputs.append(annotation)
return outputs

@classmethod
Expand All @@ -239,26 +204,21 @@ def _get_as_invocator_output(self) -> List[Union[Artifact, Parameter]]:
This lets dags and steps hoist task/step outputs into its own outputs.
"""
outputs: List[Union[Artifact, Parameter]] = []
annotations = get_field_annotations(type(self))

if isinstance(self, V1BaseModel):
self_dict = self.dict()
elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel):
self_dict = self.model_dump()

for field in get_fields(type(self)):
for field, _, annotation in _get_workflow_annotations(type(self)):
if field in {"exit_code", "result"}:
continue

templated_value = self_dict[field] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"`

if param_or_artifact := get_workflow_annotation(annotations[field]):
name = param_or_artifact.name or field
if isinstance(param_or_artifact, Parameter):
outputs.append(Parameter(name=name, value_from=ValueFrom(parameter=templated_value)))
else:
outputs.append(Artifact(name=name, from_=templated_value))
if isinstance(annotation, Parameter):
outputs.append(Parameter(name=annotation.name, value_from=ValueFrom(parameter=templated_value)))
else:
outputs.append(Parameter(name=field, value_from=ValueFrom(parameter=templated_value)))
outputs.append(Artifact(name=annotation.name, from_=templated_value))

return outputs

0 comments on commit 519f83d

Please sign in to comment.