Skip to content

Commit 018b004

Browse files
Merge pull request #60 from francois-drielsma/develop
Add option to specify visible devices manually
2 parents 5e560be + 91a3420 commit 018b004

File tree

2 files changed

+27
-14
lines changed

2 files changed

+27
-14
lines changed

bin/run.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -107,22 +107,17 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly,
107107
if not 'train' in cfg['base']:
108108
raise KeyError("--weight_prefix flag provided: must specify "
109109
"`train` in the `base` block.")
110-
cfg['base']['train']['weight_prefix']=weight_prefix
110+
cfg['base']['train']['weight_prefix'] = weight_prefix
111111

112112
if weight_path is not None:
113-
cfg['model']['weight_path']=weight_path
113+
cfg['model']['weight_path'] = weight_path
114114

115115
# Turn on PyTorch anomaly detection, if requested
116116
if detect_anomaly is not None:
117117
assert 'model' in cfg, (
118118
"There is no model to detect anomalies for, add `model` block.")
119119
cfg['model']['detect_anomaly'] = detect_anomaly
120120

121-
# If the -1 option for GPUs is selected, expose the process to all GPUs
122-
if os.environ.get('CUDA_VISIBLE_DEVICES') is not None \
123-
and cfg['base'].get('gpus', '') == '-1':
124-
cfg['base']['gpus'] = os.environ.get('CUDA_VISIBLE_DEVICES')
125-
126121
# Execute train/validation process
127122
run(cfg)
128123

spine/driver.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def process_config(self, io, base=None, model=None, build=None,
153153
ana : dict, optional
154154
Analysis script configuration dictionary
155155
rank : int, optional
156-
Rank of the GPU. If not specified, the model will be run on CPU if
157-
`world_size` is 0 and on GPU if `world_size` is > 0.
156+
Rank of the GPU. The model will be run on CPU if `world_size` is not
157+
specified or 0 and on GPU if `world_size` is > 0.
158158
159159
Returns
160160
-------
@@ -170,9 +170,15 @@ def process_config(self, io, base=None, model=None, build=None,
170170
logger.setLevel(verbosity.upper())
171171

172172
# Set GPUs visible to CUDA
173-
world_size = base.get('world_size', 0)
174-
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
175-
[str(i) for i in range(world_size)])
173+
gpus = base.get('gpus', None)
174+
if gpus is not None:
175+
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
176+
[str(i) for i in gpus])
177+
178+
elif not os.environ['CUDA_VISIBLE_DEVICES']:
179+
world_size = base.get('world_size', 0)
180+
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
181+
[str(i) for i in range(world_size)])
176182

177183
# If the seed is not set for the sampler, randomize it. This is done
178184
# here to keep a record of the seeds provided to the samplers
@@ -224,7 +230,7 @@ def process_config(self, io, base=None, model=None, build=None,
224230
# Return updated configuration
225231
return base, io, model, build, post, ana
226232

227-
def initialize_base(self, seed, dtype='float32', world_size=0,
233+
def initialize_base(self, seed, dtype='float32', world_size=None, gpus=None,
228234
log_dir='logs', prefix_log=False, overwrite_log=False,
229235
parent_path=None, iterations=None, epochs=None,
230236
unwrap=False, rank=None, log_step=1, distributed=False,
@@ -237,8 +243,10 @@ def initialize_base(self, seed, dtype='float32', world_size=0,
237243
Random number generator seed
238244
dtype : str, default 'float32'
239245
Data type of the model parameters and input data
240-
world_size : int, default 0
246+
world_size : int, optional
241247
Number of GPUs to use in the underlying model
248+
gpus : List[int], optional
249+
List of indexes of GPUs to expose to the model
242250
log_dir : str, default 'logs'
243251
Path to the directory where the logs will be written to
244252
prefix_log : bool, default False
@@ -279,6 +287,16 @@ def initialize_base(self, seed, dtype='float32', world_size=0,
279287
numba_seed(seed)
280288
torch.manual_seed(seed)
281289

290+
# Check on the number of GPUs to use
291+
if gpus is not None:
292+
assert world_size is None or len(gpus) == world_size, (
293+
f"The number of visible GPUs ({len(gpus)}) is not "
294+
f"compatible with the world size ({world_size}).")
295+
world_size = len(gpus)
296+
297+
elif world_size is None:
298+
world_size = 0
299+
282300
# Set up the device the model will run on
283301
if rank is None and world_size > 0:
284302
assert world_size < 2, (

0 commit comments

Comments
 (0)