From 4795c970796872c86ec76ddebe902f85f99e1228 Mon Sep 17 00:00:00 2001 From: Sadi Kneipp Date: Wed, 27 Nov 2024 22:11:03 +0000 Subject: [PATCH] fix pool bug --- .../persistence/pathways_orbax_handler.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/pathwaysutils/persistence/pathways_orbax_handler.py b/pathwaysutils/persistence/pathways_orbax_handler.py index 86b5f29..fe32247 100644 --- a/pathwaysutils/persistence/pathways_orbax_handler.py +++ b/pathwaysutils/persistence/pathways_orbax_handler.py @@ -165,23 +165,16 @@ async def deserialize( ) grouped_arrays_and_futures = None - with Pool() as p: - grouped_arrays_and_futures = p.map( - f( - location=location, - name=name, - dtype=dtype, - shape=shape, - shardings=sharding, - ) - for location, name, dtype, shape, sharding in zip( + from functools import partial + args_list = [[location, name, dtype, shape, sharding] for location, name, dtype, shape, sharding in zip( locations, names, grouped_dtypes, grouped_global_shapes, grouped_shardings, - ) - ) + )] + with Pool() as p: + grouped_arrays_and_futures = p.apply(f, args_list) # grouped_arrays_and_futures = [ # f( # location=location,