Skip to content

Commit

Permalink
Change supertype calculation when combining scatters (#5101)
Browse files Browse the repository at this point in the history
* Change supertype calculation when combining scatters

* Simplify if else branch
  • Loading branch information
stxue1 committed Sep 24, 2024
1 parent 8faca0f commit 027e89d
Showing 1 changed file with 17 additions and 27 deletions.
44 changes: 17 additions & 27 deletions src/toil/wdl/wdltoil.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,36 +505,26 @@ def log_bindings(log_function: Callable[..., None], message: str, all_bindings:
elif isinstance(bindings, Promise):
log_function("<Unfulfilled promise for bindings>")

def get_supertype(types: Sequence[Optional[WDL.Type.Base]]) -> WDL.Type.Base:
def get_supertype(types: Sequence[WDL.Type.Base]) -> WDL.Type.Base:
"""
Get the supertype that can hold values of all the given types.
"""

if None in types:
# Need to allow optional values
if len(types) == 1:
# Only None is here
return WDL.Type.Any(optional=True)
if len(types) == 2:
# None and something else
for item in types:
if item is not None:
# Return the type that's actually there, but make optional if not already.
return item.copy(optional=True)
raise RuntimeError("Expected non-None in types could not be found")
else:
# Multiple types, and some nulls, so we need an optional Any.
return WDL.Type.Any(optional=True)
else:
if len(types) == 1:
# Only one type. It isn't None.
the_type = types[0]
if the_type is None:
raise RuntimeError("The supertype cannot be None.")
return the_type
supertype = None
optional = False
for typ in types:
if isinstance(typ, WDL.Type.Any):
# ignore an Any type, as we represent a bottom type as Any. See https://miniwdl.readthedocs.io/en/latest/WDL.html#WDL.Type.Any
# and https://github.com/openwdl/wdl/blob/e43e042104b728df1f1ad6e6145945d2b32331a6/SPEC.md?plain=1#L1484
optional = optional or typ.optional
elif supertype is None:
supertype = typ
optional = optional or typ.optional
else:
# Multiple types (or none). Assume Any
return WDL.Type.Any()
# We have conflicting types
raise RuntimeError(f"Cannot generate a supertype from conflicting types: {types}")
if supertype is None:
return WDL.Type.Any(null=optional) # optional flag isn't used in Any
return supertype.copy(optional=optional)


def for_each_node(root: WDL.Tree.WorkflowNode) -> Iterator[WDL.Tree.WorkflowNode]:
Expand Down Expand Up @@ -3227,7 +3217,7 @@ def run(self, file_store: AbstractFileStore) -> WDLBindings:
# Problem: the WDL type types are not hashable, so we need to do bad N^2 deduplication
observed_types = []
for env in new_bindings:
binding_type = env.resolve(name).type if env.has_binding(name) else None
binding_type = env.resolve(name).type if env.has_binding(name) else WDL.Type.Any()
if binding_type not in observed_types:
observed_types.append(binding_type)
# Get the supertype of those types
Expand Down

0 comments on commit 027e89d

Please sign in to comment.