diff --git a/pathwaysutils/persistence/pathways_orbax_handler.py b/pathwaysutils/persistence/pathways_orbax_handler.py index 80bfcdb..ec4f95e 100644 --- a/pathwaysutils/persistence/pathways_orbax_handler.py +++ b/pathwaysutils/persistence/pathways_orbax_handler.py @@ -13,7 +13,7 @@ # limitations under the License. """TypeHandlers supporting Pathways backend.""" -from multiprocessing import Pool +import concurrent.futures import collections import concurrent.futures import datetime @@ -175,8 +175,13 @@ async def deserialize( grouped_shardings, ) ] - with Pool() as p: - grouped_arrays_and_futures = p.starmap(f, args_list) + with concurrent.futures.ThreadPoolExecutor() as e: + side_channel_call_futures = [e.submit(f, *args) for args in args_list] + concurrent.futures.wait( + side_channel_call_futures, + return_when=concurrent.futures.ALL_COMPLETED, + ) + grouped_arrays_and_futures = [side_channel_call_future.result() for side_channel_call_future in side_channel_call_futures] # grouped_arrays_and_futures = [ # f( # location=location,