From 519f83d214226e542920eca5168b8e8111adaf2c Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Thu, 5 Sep 2024 22:45:43 +0100 Subject: [PATCH] Factor out common workflow annotation iteration Multiple methods in _io_mixins resolve the same problem of iterating through all fields, constructing a workflow annotation based on the user's annotations, defaulting the name if unset or falling back to a Parameter if no annotation is found. Factor that logic out into an iterator, simplifying the various methods and reducing the risk of discrepancies and bugs in future. Signed-off-by: Alice Purcell --- src/hera/workflows/io/_io_mixins.py | 130 ++++++++++------------------ 1 file changed, 45 insertions(+), 85 deletions(-) diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py index 750413ab1..75213bad7 100644 --- a/src/hera/workflows/io/_io_mixins.py +++ b/src/hera/workflows/io/_io_mixins.py @@ -1,12 +1,14 @@ import sys import warnings -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self +from pydantic.fields import FieldInfo + from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields from hera.shared._type_util import get_workflow_annotation from hera.shared.serialization import MISSING, serialize @@ -39,6 +41,22 @@ BaseModel = object # type: ignore +def _get_workflow_annotations(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]: + """Constructs a workflow annotation for all Pydantic fields based on their annotations. + + If a field has a workflow annotation, a copy will be returned, with name added if missing. + Otherwise, a Parameter annotation will be constructed. + """ + annotations = get_field_annotations(cls) + for field, field_info in get_fields(cls).items(): + if annotation := get_workflow_annotation(annotations[field]): + annotation_copy = annotation.copy() + annotation_copy.name = annotation.name or field + yield field, field_info, annotation_copy + else: + yield field, field_info, Parameter(name=field) + + class InputMixin(BaseModel): def __new__(cls, **kwargs): if _context.declaring: @@ -62,15 +80,10 @@ def __init__(self, /, **kwargs): @classmethod def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Parameter]: parameters = [] - annotations = get_field_annotations(cls) - for field, field_info in get_fields(cls).items(): - param = get_workflow_annotation(annotations[field]) + for field, field_info, param in _get_workflow_annotations(cls): if isinstance(param, Parameter): # Copy so as to not modify the Input fields themselves - param = param.copy() - if param.name is None: - param.name = field if param.default is not None: warnings.warn( "Using the default field for Parameters in Annotations is deprecated since v5.16" @@ -82,29 +95,15 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet # Serialize the value (usually done in Parameter's validator) param.default = serialize(field_info.default) # type: ignore parameters.append(param) - elif param is None: - # Create a Parameter from basic type annotations - default = getattr(object_override, field) if object_override else field_info.default - - # For users on Pydantic 2 but using V1 BaseModel, we still need to check if `default` is None - if default is None or default == PydanticUndefined: - default = MISSING - - parameters.append(Parameter(name=field, default=default)) return parameters @classmethod def _get_artifacts(cls) -> List[Artifact]: artifacts = [] - annotations = get_field_annotations(cls) - for field in get_fields(cls): - if (artifact := get_workflow_annotation(annotations[field])) and isinstance(artifact, Artifact): - # Copy so as to not modify the Input fields themselves - artifact = artifact.copy() - if artifact.name is None: - artifact.name = field + for _, _, artifact in _get_workflow_annotations(cls): + if isinstance(artifact, Artifact): if artifact.path is None: artifact.path = artifact._get_default_inputs_path() artifacts.append(artifact) @@ -118,44 +117,31 @@ def _get_inputs(cls) -> List[Union[Artifact, Parameter]]: def _get_as_templated_arguments(cls) -> Self: """Returns the Input with templated values to propagate through a DAG/Steps function.""" object_dict = {} - cls_fields = get_fields(cls) - annotations = get_field_annotations(cls) - for field in cls_fields: - if param_or_artifact := get_workflow_annotation(annotations[field]): - name = param_or_artifact.name or field - if isinstance(param_or_artifact, Parameter): - object_dict[field] = "{{inputs.parameters." + f"{name}" + "}}" - else: - object_dict[field] = "{{inputs.artifacts." + f"{name}" + "}}" - else: - object_dict[field] = "{{inputs.parameters." + f"{field}" + "}}" + for field, _, annotation in _get_workflow_annotations(cls): + input_type = "parameters" if isinstance(annotation, Parameter) else "artifacts" + object_dict[field] = f"{{{{inputs.{input_type}.{annotation.name}}}}}" return cls.construct(None, **object_dict) def _get_as_arguments(self) -> ModelArguments: params = [] artifacts = [] - annotations = get_field_annotations(type(self)) if isinstance(self, V1BaseModel): self_dict = self.dict() elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel): self_dict = self.model_dump() - for field in get_fields(type(self)): + for field, _, annotation in _get_workflow_annotations(type(self)): # The value may be a static value (of any time) if it has a default value, so we need to serialize it # If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")`` templated_value = serialize(self_dict[field]) - if param_or_artifact := get_workflow_annotation(annotations[field]): - name = param_or_artifact.name or field - if isinstance(param_or_artifact, Parameter): - params.append(ModelParameter(name=name, value=templated_value)) - else: - artifacts.append(ModelArtifact(name=name, from_=templated_value)) + if isinstance(annotation, Parameter): + params.append(ModelParameter(name=annotation.name, value=templated_value)) else: - params.append(ModelParameter(name=field, value=templated_value)) + artifacts.append(ModelArtifact(name=annotation.name, from_=templated_value)) return ModelArguments(parameters=params or None, artifacts=artifacts or None) @@ -181,43 +167,22 @@ def __init__(self, /, **kwargs): @classmethod def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]: outputs: List[Union[Artifact, Parameter]] = [] - annotations = get_field_annotations(cls) - - model_fields = get_fields(cls) - for field in model_fields: + for field, field_info, annotation in _get_workflow_annotations(cls): if field in {"exit_code", "result"}: continue - if param_or_artifact := get_workflow_annotation(annotations[field]): - param_or_artifact = param_or_artifact.copy() - param_or_artifact.name = param_or_artifact.name or field - if isinstance(param_or_artifact, Parameter): - if param_or_artifact.default is None: - default = model_fields[field].default - if default is not None and default != PydanticUndefined: - param_or_artifact.default = serialize(default) - if add_missing_path and ( - param_or_artifact.value_from is None or param_or_artifact.value_from.path is None - ): - param_or_artifact.value_from = ValueFrom( - path=f"/tmp/hera-outputs/parameters/{param_or_artifact.name}" - ) - outputs.append(param_or_artifact) - else: - if add_missing_path and param_or_artifact.path is None: - param_or_artifact.path = f"/tmp/hera-outputs/artifacts/{param_or_artifact.name}" - outputs.append(param_or_artifact) + if isinstance(annotation, Parameter): + if annotation.default is None: + default = field_info.default + if default is not None and default != PydanticUndefined: + annotation.default = serialize(default) + + if add_missing_path and (annotation.value_from is None or annotation.value_from.path is None): + annotation.value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{annotation.name}") else: - # Create a Parameter from basic type annotations - default = model_fields[field].default - if default is None or default == PydanticUndefined: - default = MISSING - - value_from = None - if add_missing_path: - value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{field}") - - outputs.append(Parameter(name=field, default=default, value_from=value_from)) + if add_missing_path and annotation.path is None: + annotation.path = f"/tmp/hera-outputs/artifacts/{annotation.name}" + outputs.append(annotation) return outputs @classmethod @@ -239,26 +204,21 @@ def _get_as_invocator_output(self) -> List[Union[Artifact, Parameter]]: This lets dags and steps hoist task/step outputs into its own outputs. """ outputs: List[Union[Artifact, Parameter]] = [] - annotations = get_field_annotations(type(self)) if isinstance(self, V1BaseModel): self_dict = self.dict() elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel): self_dict = self.model_dump() - for field in get_fields(type(self)): + for field, _, annotation in _get_workflow_annotations(type(self)): if field in {"exit_code", "result"}: continue templated_value = self_dict[field] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"` - if param_or_artifact := get_workflow_annotation(annotations[field]): - name = param_or_artifact.name or field - if isinstance(param_or_artifact, Parameter): - outputs.append(Parameter(name=name, value_from=ValueFrom(parameter=templated_value))) - else: - outputs.append(Artifact(name=name, from_=templated_value)) + if isinstance(annotation, Parameter): + outputs.append(Parameter(name=annotation.name, value_from=ValueFrom(parameter=templated_value))) else: - outputs.append(Parameter(name=field, value_from=ValueFrom(parameter=templated_value))) + outputs.append(Artifact(name=annotation.name, from_=templated_value)) return outputs