Skip to content

Commit

Permalink
Connect Toil to MiniWDL write_* file cache
Browse files Browse the repository at this point in the history
  • Loading branch information
adamnovak committed Sep 27, 2024
1 parent e38eeac commit 99c68f3
Showing 1 changed file with 132 additions and 17 deletions.
149 changes: 132 additions & 17 deletions src/toil/wdl/wdltoil.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Callable,
Dict,
Generator,
IO,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -86,6 +87,21 @@
from toil.lib.threading import global_mutex
from toil.provisioners.clusterScaler import JobTooBigError

import hashlib

try:
    # Python 3.11+ ships an efficient C implementation.
    from hashlib import file_digest
except ImportError:
    # Polyfill file_digest from 3.11+ for older Pythons.
    #
    # NOTE: the return annotation is a string because hashlib._Hash is a
    # typeshed-only name with no runtime attribute on the hashlib module;
    # an unquoted annotation is evaluated at definition time and would
    # raise AttributeError when this module is imported.
    def file_digest(f: IO[bytes], alg_name: str) -> "hashlib._Hash":
        """
        Hash the full contents of an open binary file.

        :param f: Binary file-like object, open for reading and positioned
            at the start of the data to hash.
        :param alg_name: Name of a hashlib algorithm (e.g. "sha256").
        :return: The hasher object, updated with the file's contents.
        """
        # Read in 1 MiB chunks so we never hold the whole file in memory.
        BUFFER_SIZE = 1024 * 1024
        hasher = hashlib.new(alg_name)
        buffer = f.read(BUFFER_SIZE)
        while buffer:
            hasher.update(buffer)
            buffer = f.read(BUFFER_SIZE)
        return hasher



logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -669,8 +685,8 @@ def set_shared_fs_path(file: Union[str, WDL.Value.File], path: str) -> None:
Accepts either a WDL-level File or the actual str value of one.
Returns a str that has to be assigned back to the WDL File's value, which
may be the same one.
Returns a str that has to be assigned back to the WDL File's value, if the
input was not a WDL File.
"""
if isinstance(file, WDL.Value.File):
file_value = file.value
Expand All @@ -680,6 +696,9 @@ def set_shared_fs_path(file: Union[str, WDL.Value.File], path: str) -> None:
# Make it be a str subclass we can set attributes on
file_value = AttrStr(file_value)
setattr(file_value, SHARED_PATH_ATTR, path)
if isinstance(file, WDL.Value.File):
# Commit mutation
file.value = file_value
return file_value

def view_shared_fs_paths(bindings: WDL.Env.Bindings[WDL.Value.Base]) -> WDL.Env.Bindings[WDL.Value.Base]:
Expand All @@ -695,16 +714,6 @@ def file_path_to_use(stored_file_string: str) -> str:
return map_over_files_in_bindings(bindings, file_path_to_use)


def get_miniwdl_input_digest(bindings: WDL.Env.Bindings[WDL.Value.Base]) -> str:
    """
    Compute the digest used to look up a task call with the given inputs.

    All files are represented by their shared filesystem paths, so cache
    entries written by MiniWDL can be reused by Toil.
    """
    shared_path_view = view_shared_fs_paths(bindings)
    return WDL.Value.digest_env(shared_path_view)


DirectoryNamingStateDict = Dict[str, Tuple[Dict[str, str], Set[str]]]
def choose_human_readable_directory(root_dir: str, source_task_path: str, parent_id: str, state: DirectoryNamingStateDict) -> str:
"""
Expand Down Expand Up @@ -1123,6 +1132,109 @@ def _virtualize_filename(self, filename: str) -> str:
self._virtualized_to_devirtualized[result] = abs_filename
return result

class ToilWDLStdLibWorkflow(ToilWDLStdLibBase):
    """
    Standard library implementation for workflow scope.

    Handles deduplicating files generated by write_* calls at workflow scope
    with copies already in the call cache, so that tasks that depend on them
    can also be fulfilled from the cache.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Set up MiniWDL caching for files
        miniwdl_logger = logging.getLogger("MiniWDL")
        # TODO: Ship config from leader? It might not see the right environment.
        miniwdl_config = WDL.runtime.config.Loader(miniwdl_logger)
        self._miniwdl_cache = WDL.runtime.cache.new(miniwdl_config, miniwdl_logger)

    # This needs to be hash-compatible with MiniWDL.
    # MiniWDL hooks _virtualize_filename
    # <https://github.com/chanzuckerberg/miniwdl/blob/475dd3f3784d1390e6a0e880d43316a620114de3/WDL/runtime/workflow.py#L699-L729>,
    # but we probably don't want to hash all files virtualized at workflow
    # scope, just dynamic ones.
    #
    # TODO: Test cache compatibility with MiniWDL when a file is virtualized
    # from a string at workflow scope!
    def _write(
        self, serialize: Callable[[WDL.Value.Base, IO[bytes]], None]
    ) -> Callable[[WDL.Value.Base], WDL.Value.File]:
        """
        Wrap the base class's write_* implementation so that files it
        produces are deduplicated against, and stored into, the MiniWDL
        call cache under a content-hash key.

        :param serialize: Function that writes a WDL value to a binary stream.
        :return: Function mapping a WDL value to a virtualized WDL File.
        """

        # Get the normal writer
        writer = super()._write(serialize)

        def wrapper(v: WDL.Value.Base) -> WDL.Value.File:
            """
            Call the normal writer, and then deduplicate its result with the cache.
            """
            # TODO: If we did this before the _virtualize_filename call in the
            # base _write, we could let the cache bring the info between nodes
            # and not need to use the job store.

            virtualized_file = writer(v)

            # TODO: If we did this before the _virtualize_filename call in the
            # base _write we wouldn't need to immediately devirtualize. But we
            # have internal caches to lean on.
            devirtualized_filename = self._devirtualize_filename(virtualized_file.value)
            # Hash the file to hex. Use a with block so the handle is always
            # closed instead of leaking until garbage collection.
            with open(devirtualized_filename, "rb") as digest_stream:
                hex_digest = file_digest(digest_stream, "sha256").hexdigest()
            file_input_bindings = WDL.Env.Bindings(WDL.Env.Binding("file_sha256", WDL.Value.String(hex_digest)))
            # Make an environment of "file_sha256" to that as a WDL string, and
            # digest that, and make a write_ cache key. No need to transform to
            # shared FS paths since no paths are in it.
            log_bindings(logger.debug, "Digesting file bindings:", [file_input_bindings])
            input_digest = WDL.Value.digest_env(file_input_bindings)
            file_cache_key = "write_/" + input_digest
            # Construct a description of the types we expect to get from the
            # cache: just a File-type variable named "file"
            expected_types = WDL.Env.Bindings(WDL.Env.Binding("file", WDL.Type.File()))
            # Query the cache
            file_output_bindings = self._miniwdl_cache.get(file_cache_key, file_input_bindings, expected_types)
            if file_output_bindings:
                # File with this hash is cached.
                # Adjust virtualized_file to carry that path as its local-filesystem path.
                set_shared_fs_path(virtualized_file, file_output_bindings.resolve("file").value)
            elif self._miniwdl_cache._cfg["call_cache"].get_bool("put"):
                # Save our novel file to the cache.

                # Determine where we will save the file.
                output_directory = os.path.join(self._miniwdl_cache._call_cache_dir, file_cache_key)
                # This needs to exist before we can export to it
                os.makedirs(output_directory, exist_ok=True)

                # Export the file to the cache.
                # write_* files will never really need to be siblings, so we
                # don't need any real persistent state here.
                # TODO: Will they secretly be siblings on a first run?
                exported_path = self.devirtualize_to(
                    virtualized_file.value,
                    output_directory,
                    self._file_store,
                    self._execution_dir,
                    {},
                    {},
                    {},
                    enforce_existence=True,
                    export=True
                )

                # Save the cache entry pointing to it
                self._miniwdl_cache.put(
                    file_cache_key,
                    WDL.Env.Bindings(WDL.Env.Binding("file", WDL.Value.File(exported_path)))
                )

                # Apply the shared filesystem path to the virtualized file
                set_shared_fs_path(virtualized_file, exported_path)

            return virtualized_file

        return wrapper


class ToilWDLStdLibTaskCommand(ToilWDLStdLibBase):
"""
Standard library implementation to use inside a WDL task command evaluation.
Expand Down Expand Up @@ -1968,14 +2080,17 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
# At this point we have what MiniWDL would call the "inputs" to the
# task (i.e. what you would put in a JSON file, without any defaulted
# or calculated inputs filled in). So start making cache keys.
input_digest = get_miniwdl_input_digest(bindings)
# But first we need to view the inputs as shared FS files.
transformed_bindings = view_shared_fs_paths(bindings)
log_bindings(logger.debug, "Digesting input bindings:", [transformed_bindings])
input_digest = WDL.Value.digest_env(transformed_bindings)
task_digest = self._task.digest
cache_key=f"{self._task.name}/{task_digest}/{input_digest}"
miniwdl_logger = logging.getLogger("MiniWDL")
# TODO: Ship config from leader? It might not see the right environment.
miniwdl_config = WDL.runtime.config.Loader(miniwdl_logger)
miniwdl_cache = WDL.runtime.cache.new(miniwdl_config, miniwdl_logger)
cached_result: Optional[WDLBindings] = miniwdl_cache.get(cache_key, bindings, self._task.effective_outputs)
cached_result: Optional[WDLBindings] = miniwdl_cache.get(cache_key, transformed_bindings, self._task.effective_outputs)
if cached_result is not None:
logger.info("Found task call in cache")
return self.postprocess(cached_result)
Expand Down Expand Up @@ -2740,7 +2855,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
# Combine the bindings we get from previous jobs
incoming_bindings = combine_bindings(unwrap_all(self._prev_node_results))
# Set up the WDL standard library
standard_library = ToilWDLStdLibBase(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
standard_library = ToilWDLStdLibWorkflow(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
with monkeypatch_coerce(standard_library):
if isinstance(self._node, WDL.Tree.Decl):
# This is a variable assignment
Expand Down Expand Up @@ -2831,7 +2946,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
# Combine the bindings we get from previous jobs
current_bindings = combine_bindings(unwrap_all(self._prev_node_results))
# Set up the WDL standard library
standard_library = ToilWDLStdLibBase(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
standard_library = ToilWDLStdLibWorkflow(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))

with monkeypatch_coerce(standard_library):
for node in self._nodes:
Expand Down Expand Up @@ -3508,7 +3623,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
# For a task we only see the inside-the-task namespace.
bindings = combine_bindings(unwrap_all(self._prev_node_results))
# Set up the WDL standard library
standard_library = ToilWDLStdLibBase(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
standard_library = ToilWDLStdLibWorkflow(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))

if self._workflow.inputs:
with monkeypatch_coerce(standard_library):
Expand Down

0 comments on commit 99c68f3

Please sign in to comment.