Skip to content

Commit

Permalink
Merge pull request PolusAI#219 from vjaganat90/pyapi
Browse files Browse the repository at this point in the history
Python API : Full scattering support (with example)
  • Loading branch information
sameeul committed May 7, 2024
2 parents 7e8f0b4 + 67b8192 commit 12e4bca
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 36 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/echo_multi_scatter.wic
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ steps:
in:
message1: !* filt_message
message2: !* filt_message
message3: !ii scalar
message3: !ii scalar
58 changes: 58 additions & 0 deletions examples/scripts/scatter_pyapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from pathlib import Path

from wic.api.pythonapi import Step, Workflow


def build_workflow() -> Workflow:
# scatter on all inputs
# step array_ind
array_ind = Step(clt_path='../../cwl_adapters/array_indices.cwl')
array_ind.input_array = ["hello world", "not", "what world?"]
array_ind.input_indices = [0, 1]
# step echo
echo = Step(clt_path='../../cwl_adapters/echo.cwl')
echo.message = array_ind.output_array
# set up inputs for scattering
scatter_inps = echo.inputs[0]
# assign the scatter and scatterMethod fields
echo.scatter = [scatter_inps]

# arrange steps
steps = [array_ind, echo]

# create workflow
filename = 'scatter_pyapi_py' # .yml
wkflw = Workflow(steps, filename)
return wkflw


def build_workflow_e3() -> Workflow:
# scatter on a subset of inputs
# step array_indices
array_ind = Step(clt_path='../../cwl_adapters/array_indices.cwl')
array_ind.input_array = ["hello world", "not", "what world?"]
array_ind.input_indices = [0, 2]
# step echo_3
echo_3 = Step(clt_path='../../cwl_adapters/echo_3.cwl')
echo_3.message1 = array_ind.output_array
echo_3.message2 = array_ind.output_array
echo_3.message3 = 'scalar'
# set up inputs for scattering
msg1 = echo_3.inputs[0]
msg2 = echo_3.inputs[1]
# assign the scatter and scatterMethod fields
echo_3.scatter = [msg1, msg2]
echo_3.scatterMethod = 'flat_crossproduct'

# arrange steps
steps = [array_ind, echo_3]

# create workflow
filename = 'scatter_pyapi_py' # .yml
wkflw = Workflow(steps, filename)
return wkflw


scatter_wic = build_workflow_e3()
scatter_wic.compile(True) # Do NOT .run() here
scatter_wic.run()
16 changes: 4 additions & 12 deletions src/wic/api/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,7 @@ class CWLTypesEnum(str, Enum):
DIRECTORY = "Directory"


CWL_TYPES_DICT: dict[str, object] = {
"null": None,
"boolean": bool,
"int": int,
"long": int,
"float": float,
"double": float,
"string": str,
"File": Path,
"Directory": Path,
"Any": Any,
}
class ScatterMethod(Enum):
dotproduct = "dotproduct"
flat_crossproduct = "flat_crossproduct"
nested_crossproduct = "nested_crossproduct"
75 changes: 52 additions & 23 deletions src/wic/api/pythonapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from cwl_utils.parser import load_document_by_uri, load_document_by_yaml
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator

from wic.api._types import CWL_TYPES_DICT
from wic import compiler, input_output, plugins
from wic import compiler, input_output, plugins, utils_cwl
from wic import run_local as run_local_module
from wic.cli import get_args
from wic.utils_graphs import get_graph_reps
from wic.wic_types import CompilerInfo, RoseTree, StepId, Tool, Tools, YamlTree

from ._types import ScatterMethod


global_config: Tools = {}

Expand Down Expand Up @@ -87,6 +88,7 @@ class ProcessInput(BaseModel): # pylint: disable=too-few-public-methods
linked: bool = False

def __init__(self, name: str, inp_type: Any) -> None:
inp_type = utils_cwl.canonicalize_type(inp_type)
if isinstance(inp_type, list) and "null" in inp_type:
required = False
else:
Expand All @@ -105,7 +107,7 @@ def set_inp_type(cls, inp: Any) -> Any:
"""Return inp_type."""
if isinstance(inp, list): # optional inps
inp = inp[1]
return CWL_TYPES_DICT[inp]
return inp

def _set_value(
self, __value: Any, linked: bool = False
Expand All @@ -128,6 +130,7 @@ class ProcessOutput(BaseModel): # pylint: disable=too-few-public-methods
linked: bool = False

def __init__(self, name: str, out_type: Any) -> None:
out_type = utils_cwl.canonicalize_type(out_type)
if isinstance(out_type, list) and "null" in out_type:
required = False
else:
Expand All @@ -146,7 +149,7 @@ def set_out_type(cls, out: Any) -> Any:
"""Return out_type."""
if isinstance(out, list): # optional outs
out = out[1]
return CWL_TYPES_DICT[out]
return out

def _set_value(
self, __value: Any, linked: bool = False
Expand Down Expand Up @@ -186,12 +189,12 @@ def set_input_Step_Workflow(process_self: Any, __name: str, __value: Any) -> Any
try:
local_input = process_self.inputs[index]
# NOTE: Relax exact equality for Any type
if not local_input.inp_type == __value.out_type and not local_input.inp_type == Any and not __value.out_type == Any:
raise InvalidLinkError(
f"links must have the same input type. "
f"cannot link {local_input.name} to {__value.name}"
f"with types {local_input.inp_type} to {__value.out_type}"
)
# if not local_input.inp_type == __value.out_type and not local_input.inp_type == Any and not __value.out_type == Any:
# raise InvalidLinkError(
# f"links must have the same input type. "
# f"cannot link {local_input.name} to {__value.name}"
# f"with types {local_input.inp_type} to {__value.out_type}"
# )
if isinstance(process_other, Workflow):
tmp = __value.name # Use the formal parameter / variable name
local_input._set_value(f"{tmp}", linked=True)
Expand All @@ -211,12 +214,12 @@ def set_input_Step_Workflow(process_self: Any, __name: str, __value: Any) -> Any
try:
local_input = process_self.inputs[index]
# NOTE: Relax exact equality for Any type
if not local_input.inp_type == __value.out_type and not local_input.inp_type == Any and not __value.out_type == Any:
raise InvalidLinkError(
f"links must have the same input type. "
f"cannot link {local_input.name} to {__value.name} "
f"with types {local_input.inp_type} to {__value.out_type}"
)
# if not local_input.inp_type == __value.out_type and not local_input.inp_type == Any and not __value.out_type == Any:
# raise InvalidLinkError(
# f"links must have the same input type. "
# f"cannot link {local_input.name} to {__value.name} "
# f"with types {local_input.inp_type} to {__value.out_type}"
# )
if isinstance(process_other, Workflow):
tmp = __value.name # Use the formal parameter / variable name
local_input._set_value(f"{tmp}", linked=True)
Expand All @@ -229,14 +232,14 @@ def set_input_Step_Workflow(process_self: Any, __name: str, __value: Any) -> Any
raise exc

else:
obj = process_self.inputs[index]
# obj = process_self.inputs[index]
# NOTE: "TypeError: typing.Any cannot be used with isinstance()"
if not obj.inp_type == Any and not isinstance(__value, obj.inp_type):
raise TypeError(
f"invalid attribute type for {obj.name}: "
f"got {__value.__class__.__name__}, "
f"expected {obj.inp_type.__name__}"
)
# if not obj.inp_type == Any and not isinstance(__value, obj.inp_type) and not isinstance(__value, list):
# raise TypeError(
# f"invalid attribute type for {obj.name}: "
# f"got {__value.__class__.__name__}, "
# f"expected {obj.inp_type.__name__}"
# )
ii_dict = {'wic_inline_input': __value}
process_self.inputs[index]._set_value(ii_dict)

Expand All @@ -253,6 +256,10 @@ class Step(BaseModel): # pylint: disable=too-few-public-methods
outputs: list[ProcessOutput]
yaml: dict[str, Any]
cfg_yaml: dict = Field(default_factory=_default_dict)

# these are not part of 'clt data'
scatter: list[ProcessInput] = []
scatterMethod: str = ''
_input_names: list[str] = PrivateAttr(default_factory=list)
_output_names: list[str] = PrivateAttr(default_factory=list)

Expand Down Expand Up @@ -349,6 +356,18 @@ def __setattr__(self, __name: str, __value: Any) -> Any:
if __name in ["inputs", "outputs", "yaml", "cfg_yaml", "process_name", "_input_names", "_output_names",
"__private_attributes__", "__pydantic_private__"]:
return super().__setattr__(__name, __value)
if __name == "scatterMethod":
if hasattr(ScatterMethod, __value):
return super().__setattr__(__name, __value)
else:
raise ValueError(
f"Invalid value for scatterMethod the valid values are : \n {ScatterMethod.dotproduct.value} "
f"{ScatterMethod.flat_crossproduct.value} {ScatterMethod.nested_crossproduct.value}\n")
if __name == "scatter":
if not all([isinstance(x, ProcessInput) for x in __value]):
raise TypeError("all scatter inputs must be ProcessInput type")
return super().__setattr__(__name, __value)

if hasattr(self, "_input_names") and __name in self._input_names:
set_input_Step_Workflow(self, __name, __value)
else:
Expand Down Expand Up @@ -409,11 +428,21 @@ def _yml(self) -> dict:

out_list: list = [] # The out: tag is a list, not a dict
out_list = [{out.name: out.value} for out in self.outputs if out.value]
# list of inputs to be scattered on
scatter_list: list = []
scatter_list = [sc_inp.name for sc_inp in self.scatter]

d = {
"id": self.process_name,
"in": in_dict,
"out": out_list,
}
# scatter operates on input sink
if self.scatter:
d[self.process_name]["scatter"] = scatter_list
if '' == self.scatterMethod:
self.scatterMethod = ScatterMethod.dotproduct.value
d[self.process_name]["scatterMethod"] = self.scatterMethod
return d

def _save_cwl(self, path: Path) -> None:
Expand Down

0 comments on commit 12e4bca

Please sign in to comment.