Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multiple bugs in Hera Input and Output classes #1193

Merged
merged 10 commits into from
Sep 17, 2024
136 changes: 53 additions & 83 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import sys
import warnings
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from pydantic.fields import FieldInfo

from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields
from hera.shared._type_util import get_workflow_annotation, is_annotated
from hera.shared._type_util import get_workflow_annotation
from hera.shared.serialization import MISSING, serialize
from hera.workflows._context import _context
from hera.workflows.artifact import Artifact
Expand Down Expand Up @@ -39,6 +41,23 @@
BaseModel = object # type: ignore


def _get_workflow_annotations(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]:
elliotgunton marked this conversation as resolved.
Show resolved Hide resolved
"""Constructs a workflow annotation for all Pydantic fields based on their annotations.

If a field has a workflow annotation, a copy will be returned, with name added if missing.
Otherwise, a Parameter annotation will be constructed.
"""
annotations = get_field_annotations(cls)
for field, field_info in get_fields(cls).items():
if annotation := get_workflow_annotation(annotations[field]):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or field
yield field, field_info, annotation_copy
else:
yield field, field_info, Parameter(name=field)


class InputMixin(BaseModel):
def __new__(cls, **kwargs):
if _context.declaring:
Expand All @@ -62,14 +81,9 @@ def __init__(self, /, **kwargs):
@classmethod
def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Parameter]:
parameters = []
annotations = get_field_annotations(cls)

for field, field_info in get_fields(cls).items():
if (param := get_workflow_annotation(annotations[field])) and isinstance(param, Parameter):
# Copy so as to not modify the Input fields themselves
param = param.copy()
if param.name is None:
param.name = field
for field, field_info, param in _get_workflow_annotations(cls):
if isinstance(param, Parameter):
elliotgunton marked this conversation as resolved.
Show resolved Hide resolved
if param.default is not None:
warnings.warn(
"Using the default field for Parameters in Annotations is deprecated since v5.16"
Expand All @@ -81,29 +95,15 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet
# Serialize the value (usually done in Parameter's validator)
param.default = serialize(field_info.default) # type: ignore
parameters.append(param)
elif not is_annotated(annotations[field]):
# Create a Parameter from basic type annotations
default = getattr(object_override, field) if object_override else field_info.default

# For users on Pydantic 2 but using V1 BaseModel, we still need to check if `default` is None
if default is None or default == PydanticUndefined:
default = MISSING

parameters.append(Parameter(name=field, default=default))

return parameters

@classmethod
def _get_artifacts(cls) -> List[Artifact]:
artifacts = []
annotations = get_field_annotations(cls)

for field in get_fields(cls):
if (artifact := get_workflow_annotation(annotations[field])) and isinstance(artifact, Artifact):
# Copy so as to not modify the Input fields themselves
artifact = artifact.copy()
if artifact.name is None:
artifact.name = field
for _, _, artifact in _get_workflow_annotations(cls):
if isinstance(artifact, Artifact):
if artifact.path is None:
artifact.path = artifact._get_default_inputs_path()
artifacts.append(artifact)
Expand All @@ -117,42 +117,33 @@ def _get_inputs(cls) -> List[Union[Artifact, Parameter]]:
def _get_as_templated_arguments(cls) -> Self:
"""Returns the Input with templated values to propagate through a DAG/Steps function."""
object_dict = {}
cls_fields = get_fields(cls)
annotations = get_field_annotations(cls)

for field in cls_fields:
if param_or_artifact := get_workflow_annotation(annotations[field]):
if isinstance(param_or_artifact, Parameter):
object_dict[field] = "{{inputs.parameters." + f"{param_or_artifact.name}" + "}}"
else:
object_dict[field] = "{{inputs.artifacts." + f"{param_or_artifact.name}" + "}}"
elif not is_annotated(annotations[field]):
object_dict[field] = "{{inputs.parameters." + f"{field}" + "}}"
for field, _, annotation in _get_workflow_annotations(cls):
input_type = "parameters" if isinstance(annotation, Parameter) else "artifacts"
object_dict[field] = f"{{{{inputs.{input_type}.{annotation.name}}}}}"
elliotgunton marked this conversation as resolved.
Show resolved Hide resolved

return cls.construct(None, **object_dict)

def _get_as_arguments(self) -> ModelArguments:
params = []
artifacts = []
annotations = get_field_annotations(type(self))

if isinstance(self, V1BaseModel):
self_dict = self.dict()
elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel):
self_dict = self.model_dump()

for field in get_fields(type(self)):
for field, _, annotation in _get_workflow_annotations(type(self)):
# The value may be a static value (of any time) if it has a default value, so we need to serialize it
# If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")``
templated_value = serialize(self_dict[field])
name = annotation.name
assert name is not None # guaranteed by _get_workflow_annotations

if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name:
if isinstance(param_or_artifact, Parameter):
params.append(ModelParameter(name=param_or_artifact.name, value=templated_value))
else:
artifacts.append(ModelArtifact(name=param_or_artifact.name, from_=templated_value))
elif not is_annotated(annotations[field]):
params.append(ModelParameter(name=field, value=templated_value))
if isinstance(annotation, Parameter):
params.append(ModelParameter(name=name, value=templated_value))
else:
artifacts.append(ModelArtifact(name=name, from_=templated_value))

return ModelArguments(parameters=params or None, artifacts=artifacts or None)

Expand All @@ -178,37 +169,22 @@ def __init__(self, /, **kwargs):
@classmethod
def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]:
outputs: List[Union[Artifact, Parameter]] = []
annotations = get_field_annotations(cls)

model_fields = get_fields(cls)

for field in model_fields:
for field, field_info, annotation in _get_workflow_annotations(cls):
if field in {"exit_code", "result"}:
continue
if param_or_artifact := get_workflow_annotation(annotations[field]):
if isinstance(param_or_artifact, Parameter):
if add_missing_path and (
param_or_artifact.value_from is None or param_or_artifact.value_from.path is None
):
param_or_artifact.value_from = ValueFrom(
path=f"/tmp/hera-outputs/parameters/{param_or_artifact.name}"
)
outputs.append(param_or_artifact)
else:
if add_missing_path and param_or_artifact.path is None:
param_or_artifact.path = f"/tmp/hera-outputs/artifacts/{param_or_artifact.name}"
outputs.append(param_or_artifact)
elif not is_annotated(annotations[field]):
# Create a Parameter from basic type annotations
default = model_fields[field].default
if default is None or default == PydanticUndefined:
default = MISSING

value_from = None
if add_missing_path:
value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{field}")

outputs.append(Parameter(name=field, default=default, value_from=value_from))
if isinstance(annotation, Parameter):
if annotation.default is None:
default = field_info.default
if default is not None and default != PydanticUndefined:
annotation.default = serialize(default)

if add_missing_path and (annotation.value_from is None or annotation.value_from.path is None):
annotation.value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{annotation.name}")
else:
if add_missing_path and annotation.path is None:
annotation.path = f"/tmp/hera-outputs/artifacts/{annotation.name}"
outputs.append(annotation)
return outputs

@classmethod
Expand All @@ -230,27 +206,21 @@ def _get_as_invocator_output(self) -> List[Union[Artifact, Parameter]]:
This lets dags and steps hoist task/step outputs into its own outputs.
"""
outputs: List[Union[Artifact, Parameter]] = []
annotations = get_field_annotations(type(self))

if isinstance(self, V1BaseModel):
self_dict = self.dict()
elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel):
self_dict = self.model_dump()

for field in get_fields(type(self)):
for field, _, annotation in _get_workflow_annotations(type(self)):
if field in {"exit_code", "result"}:
continue

templated_value = self_dict[field] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"`

if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name:
if isinstance(param_or_artifact, Parameter):
outputs.append(
Parameter(name=param_or_artifact.name, value_from=ValueFrom(parameter=templated_value))
)
else:
outputs.append(Artifact(name=param_or_artifact.name, from_=templated_value))
elif not is_annotated(annotations[field]):
outputs.append(Parameter(name=field, value_from=ValueFrom(parameter=templated_value)))
if isinstance(annotation, Parameter):
outputs.append(Parameter(name=annotation.name, value_from=ValueFrom(parameter=templated_value)))
else:
outputs.append(Artifact(name=annotation.name, from_=templated_value))

return outputs
Loading