@@ -153,8 +153,8 @@ def process_config(self, io, base=None, model=None, build=None,
ana : dict, optional
    Analysis script configuration dictionary
rank : int, optional
-     Rank of the GPU. If not specified, the model will be run on CPU if
-     `world_size` is 0 and GPU is `world_size` is > 0.
+     Rank of the GPU. The model will be run on CPU if `world_size` is not
+     specified or 0 and on GPU if `world_size` is > 0.

Returns
-------
@@ -170,9 +170,15 @@ def process_config(self, io, base=None, model=None, build=None,
logger.setLevel(verbosity.upper())

# Set GPUs visible to CUDA
- world_size = base.get('world_size', 0)
- os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
-     [str(i) for i in range(world_size)])
+ gpus = base.get('gpus', None)
+ if gpus is not None:
+     os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
+         [str(i) for i in gpus])
+
+ elif not os.environ.get('CUDA_VISIBLE_DEVICES'):
+     world_size = base.get('world_size', 0)
+     os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
+         [str(i) for i in range(world_size)])

# If the seed is not set for the sampler, randomize it. This is done
# here to keep a record of the seeds provided to the samplers
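
A minimal standalone sketch of this selection logic, assuming only the `gpus` and `world_size` keys of the `base` dictionary shown above (the `set_visible_gpus` helper name is hypothetical and not part of the change):

import os

def set_visible_gpus(base):
    """Sketch of the device-visibility logic above."""
    gpus = base.get('gpus', None)
    if gpus is not None:
        # An explicit list of device indexes takes precedence
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in gpus)
    elif not os.environ.get('CUDA_VISIBLE_DEVICES'):
        # Otherwise expose the first `world_size` devices, unless the
        # environment variable was already set by the caller
        world_size = base.get('world_size', 0)
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in range(world_size))

# e.g. {'gpus': [0, 2]}   -> CUDA_VISIBLE_DEVICES='0,2'
#      {'world_size': 2}  -> CUDA_VISIBLE_DEVICES='0,1' (if not already set)
set_visible_gpus({'gpus': [0, 2]})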
@@ -224,7 +230,7 @@ def process_config(self, io, base=None, model=None, build=None,
# Return updated configuration
return base, io, model, build, post, ana

- def initialize_base(self, seed, dtype='float32', world_size=0,
+ def initialize_base(self, seed, dtype='float32', world_size=None, gpus=None,
        log_dir='logs', prefix_log=False, overwrite_log=False,
        parent_path=None, iterations=None, epochs=None,
        unwrap=False, rank=None, log_step=1, distributed=False,
@@ -237,8 +243,10 @@ def initialize_base(self, seed, dtype='float32', world_size=0,
    Random number generator seed
dtype : str, default 'float32'
    Data type of the model parameters and input data
- world_size : int, default 0
+ world_size : int, optional
    Number of GPUs to use in the underlying model
+ gpus : List[int], optional
+     List of indexes of GPUs to expose to the model
log_dir : str, default 'logs'
    Path to the directory where the logs will be written to
prefix_log : bool, default False
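
For illustration, these two parameters mirror the `gpus` and `world_size` keys read from the `base` configuration block in `process_config` above. A minimal sketch of the three possible settings (the dictionary variable names are hypothetical):

# Request GPUs by count only: the first two devices become visible
base_by_count = {'world_size': 2}

# Request specific devices by index: world_size is then inferred as len(gpus)
base_by_index = {'gpus': [0, 2]}

# Neither key: world_size falls back to 0 and the model runs on the CPU
base_cpu = {}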
@@ -279,6 +287,16 @@ def initialize_base(self, seed, dtype='float32', world_size=0,
numba_seed(seed)
torch.manual_seed(seed)

+ # Check on the number of GPUs to use
+ if gpus is not None:
+     assert world_size is None or len(gpus) == world_size, (
+         f"The number of visible GPUs ({len(gpus)}) is not "
+         f"compatible with the world size ({world_size}).")
+     world_size = len(gpus)
+
+ elif world_size is None:
+     world_size = 0
+
# Set up the device the model will run on
if rank is None and world_size > 0:
    assert world_size < 2, (
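
A minimal sketch of this consistency check in isolation, showing how `world_size` is reconciled with an explicit GPU list (the `resolve_world_size` helper name is hypothetical and not part of the change):

def resolve_world_size(world_size=None, gpus=None):
    # An explicit GPU list fixes the world size; if both are given, they
    # must agree
    if gpus is not None:
        assert world_size is None or len(gpus) == world_size, (
            f"The number of visible GPUs ({len(gpus)}) is not "
            f"compatible with the world size ({world_size}).")
        return len(gpus)

    # Neither argument provided: run on CPU
    if world_size is None:
        return 0

    return world_size

assert resolve_world_size(gpus=[0, 2]) == 2
assert resolve_world_size(world_size=1) == 1
assert resolve_world_size() == 0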