From 99c68f31d3472da2475137caef8923871942b988 Mon Sep 17 00:00:00 2001
From: Adam Novak
Date: Thu, 26 Sep 2024 22:16:10 -0400
Subject: [PATCH] Connect Toil to MiniWDL write_* file cache

---
 src/toil/wdl/wdltoil.py | 149 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 132 insertions(+), 17 deletions(-)

diff --git a/src/toil/wdl/wdltoil.py b/src/toil/wdl/wdltoil.py
index 672544754c..255a361e9d 100755
--- a/src/toil/wdl/wdltoil.py
+++ b/src/toil/wdl/wdltoil.py
@@ -36,6 +36,7 @@
     Callable,
     Dict,
     Generator,
+    IO,
     Iterable,
     Iterator,
     List,
@@ -86,6 +87,21 @@
 from toil.lib.threading import global_mutex
 from toil.provisioners.clusterScaler import JobTooBigError
 
+import hashlib
+try:
+    from hashlib import file_digest
+except ImportError:
+    # Polyfill file_digest from 3.11+
+    def file_digest(f: IO[bytes], alg_name: str) -> "hashlib._Hash":
+        BUFFER_SIZE = 1024 * 1024
+        hasher = hashlib.new(alg_name)
+        buffer = f.read(BUFFER_SIZE)
+        while buffer:
+            hasher.update(buffer)
+            buffer = f.read(BUFFER_SIZE)
+        return hasher
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -669,8 +685,8 @@ def set_shared_fs_path(file: Union[str, WDL.Value.File], path: str) -> None:
 
     Accepts either a WDL-level File or the actual str value of one.
 
-    Returns a str that has to be assigned back to the WDL File's value, which
-    may be the same one.
+    Returns the path-carrying str. A WDL File input is updated in place; for a
+    plain str input, the caller must use the returned str instead.
     """
     if isinstance(file, WDL.Value.File):
         file_value = file.value
@@ -680,6 +696,9 @@ def set_shared_fs_path(file: Union[str, WDL.Value.File], path: str) -> None:
         # Make it be a str subclass we can set attributes on
         file_value = AttrStr(file_value)
     setattr(file_value, SHARED_PATH_ATTR, path)
+    if isinstance(file, WDL.Value.File):
+        # Commit mutation
+        file.value = file_value
     return file_value
 
 def view_shared_fs_paths(bindings: WDL.Env.Bindings[WDL.Value.Base]) -> WDL.Env.Bindings[WDL.Value.Base]:
@@ -695,16 +714,6 @@ def file_path_to_use(stored_file_string: str) -> str:
 
     return map_over_files_in_bindings(bindings, file_path_to_use)
 
-def get_miniwdl_input_digest(bindings: WDL.Env.Bindings[WDL.Value.Base]) -> str:
-    """
-    Get a digest for looking up the task call with the given inputs.
-
-    Represents all files by their shared filesystem paths so that cache entries written by MiniWDL can be used by Toil.
-    """
-
-    return WDL.Value.digest_env(view_shared_fs_paths(bindings))
-
-
 DirectoryNamingStateDict = Dict[str, Tuple[Dict[str, str], Set[str]]]
 def choose_human_readable_directory(root_dir: str, source_task_path: str, parent_id: str, state: DirectoryNamingStateDict) -> str:
     """
@@ -1123,6 +1132,109 @@ def _virtualize_filename(self, filename: str) -> str:
         self._virtualized_to_devirtualized[result] = abs_filename
         return result
 
+class ToilWDLStdLibWorkflow(ToilWDLStdLibBase):
+    """
+    Standard library implementation for workflow scope.
+
+    Handles deduplicating files generated by write_* calls at workflow scope
+    with copies already in the call cache, so that tasks that depend on them
+    can also be fulfilled from the cache.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+        # Set up MiniWDL caching for files
+        miniwdl_logger = logging.getLogger("MiniWDL")
+        # TODO: Ship config from leader? It might not see the right environment.
+        miniwdl_config = WDL.runtime.config.Loader(miniwdl_logger)
+        self._miniwdl_cache = WDL.runtime.cache.new(miniwdl_config, miniwdl_logger)
+
+    # This needs to be hash-compatible with MiniWDL.
+    # MiniWDL hooks _virtualize_filename (the link to its source that followed
+    # here is missing from this copy of the patch), but we probably don't want
+    # to hash all files virtualized at workflow scope, just dynamic ones.
+    #
+    # TODO: Test cache compatibility with MiniWDL when a file is virtualized
+    # from a string at workflow scope!
+    def _write(
+        self, serialize: Callable[[WDL.Value.Base, IO[bytes]], None]
+    ) -> Callable[[WDL.Value.Base], WDL.Value.File]:
+
+        # Get the normal writer
+        writer = super()._write(serialize)
+
+        def wrapper(v: WDL.Value.Base) -> WDL.Value.File:
+            """
+            Call the normal writer, and then deduplicate its result with the cache.
+            """
+            # TODO: If we did this before the _virtualize_filename call in the
+            # base _write, we could let the cache bring the info between nodes
+            # and not need to use the job store.
+
+            virtualized_file = writer(v)
+
+            # TODO: If we did this before the _virtualize_filename call in the
+            # base _write we wouldn't need to immediately devirtualize. But we
+            # have internal caches to lean on.
+            devirtualized_filename = self._devirtualize_filename(virtualized_file.value)
+            # Hash the file to hex, closing the handle as soon as we are done
+            with open(devirtualized_filename, "rb") as stream:
+                hex_digest = file_digest(stream, "sha256").hexdigest()
+            file_input_bindings = WDL.Env.Bindings(WDL.Env.Binding("file_sha256", WDL.Value.String(hex_digest)))
+            # Make an environment binding "file_sha256" to that digest as a WDL
+            # String, digest that, and make a write_ cache key. No need to
+            # transform to shared FS paths since no paths are in it.
+            log_bindings(logger.debug, "Digesting file bindings:", [file_input_bindings])
+            input_digest = WDL.Value.digest_env(file_input_bindings)
+            file_cache_key = "write_/" + input_digest
+            # Construct a description of the types we expect to get from the
+            # cache: just a File-type variable named "file"
+            expected_types = WDL.Env.Bindings(WDL.Env.Binding("file", WDL.Type.File()))
+            # Query the cache
+            file_output_bindings = self._miniwdl_cache.get(file_cache_key, file_input_bindings, expected_types)
+            if file_output_bindings:
+                # File with this hash is cached.
+                # Adjust virtualized_file to carry that path as its local-filesystem path.
+                set_shared_fs_path(virtualized_file, file_output_bindings.resolve("file").value)
+            elif self._miniwdl_cache._cfg["call_cache"].get_bool("put"):
+                # Save our novel file to the cache.
+
+                # Determine where we will save the file.
+                output_directory = os.path.join(self._miniwdl_cache._call_cache_dir, file_cache_key)
+                # This needs to exist before we can export to it
+                os.makedirs(output_directory, exist_ok=True)
+
+                # Export the file to the cache.
+                # write_* files will never really need to be siblings, so we
+                # don't need any real persistent state here.
+                # TODO: Will they secretly be siblings on a first run?
+                exported_path = self.devirtualize_to(
+                    virtualized_file.value,
+                    output_directory,
+                    self._file_store,
+                    self._execution_dir,
+                    {},
+                    {},
+                    {},
+                    enforce_existence=True,
+                    export=True
+                )
+
+                # Save the cache entry pointing to it
+                self._miniwdl_cache.put(
+                    file_cache_key,
+                    WDL.Env.Bindings(WDL.Env.Binding("file", WDL.Value.File(exported_path)))
+                )
+
+                # Apply the shared filesystem path to the virtualized file
+                set_shared_fs_path(virtualized_file, exported_path)
+
+            return virtualized_file
+
+        return wrapper
+
+
 class ToilWDLStdLibTaskCommand(ToilWDLStdLibBase):
     """
     Standard library implementation to use inside a WDL task command evaluation.
@@ -1968,14 +2080,17 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
         # At this point we have what MiniWDL would call the "inputs" to the
         # task (i.e. what you would put in a JSON file, without any defaulted
         # or calculated inputs filled in). So start making cache keys.
-        input_digest = get_miniwdl_input_digest(bindings)
+        # But first we need to view the inputs as shared FS files.
+        transformed_bindings = view_shared_fs_paths(bindings)
+        log_bindings(logger.debug, "Digesting input bindings:", [transformed_bindings])
+        input_digest = WDL.Value.digest_env(transformed_bindings)
         task_digest = self._task.digest
         cache_key=f"{self._task.name}/{task_digest}/{input_digest}"
         miniwdl_logger = logging.getLogger("MiniWDL")
         # TODO: Ship config from leader? It might not see the right environment.
         miniwdl_config = WDL.runtime.config.Loader(miniwdl_logger)
         miniwdl_cache = WDL.runtime.cache.new(miniwdl_config, miniwdl_logger)
-        cached_result: Optional[WDLBindings] = miniwdl_cache.get(cache_key, bindings, self._task.effective_outputs)
+        cached_result: Optional[WDLBindings] = miniwdl_cache.get(cache_key, transformed_bindings, self._task.effective_outputs)
         if cached_result is not None:
             logger.info("Found task call in cache")
             return self.postprocess(cached_result)
@@ -2740,7 +2855,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
         # Combine the bindings we get from previous jobs
         incoming_bindings = combine_bindings(unwrap_all(self._prev_node_results))
         # Set up the WDL standard library
-        standard_library = ToilWDLStdLibBase(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
+        standard_library = ToilWDLStdLibWorkflow(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
         with monkeypatch_coerce(standard_library):
             if isinstance(self._node, WDL.Tree.Decl):
                 # This is a variable assignment
@@ -2831,7 +2946,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
         # Combine the bindings we get from previous jobs
         current_bindings = combine_bindings(unwrap_all(self._prev_node_results))
         # Set up the WDL standard library
-        standard_library = ToilWDLStdLibBase(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
+        standard_library = ToilWDLStdLibWorkflow(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
 
         with monkeypatch_coerce(standard_library):
             for node in self._nodes:
@@ -3508,7 +3623,7 @@ def run(self, file_store: AbstractFileStore) -> Promised[WDLBindings]:
         # For a task we only see the insode-the-task namespace.
         bindings = combine_bindings(unwrap_all(self._prev_node_results))
         # Set up the WDL standard library
-        standard_library = ToilWDLStdLibBase(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
+        standard_library = ToilWDLStdLibWorkflow(file_store, self._task_path, execution_dir=self._wdl_options.get("execution_dir"))
 
         if self._workflow.inputs:
             with monkeypatch_coerce(standard_library):