@@ -18,6 +18,8 @@ use subspace_farmer::cluster::plotter::plotter_service;
18
18
use subspace_farmer:: plotter:: cpu:: CpuPlotter ;
19
19
#[ cfg( feature = "cuda" ) ]
20
20
use subspace_farmer:: plotter:: gpu:: cuda:: CudaRecordsEncoder ;
21
+ #[ cfg( feature = "rocm" ) ]
22
+ use subspace_farmer:: plotter:: gpu:: rocm:: RocmRecordsEncoder ;
21
23
#[ cfg( feature = "_gpu" ) ]
22
24
use subspace_farmer:: plotter:: gpu:: GpuPlotter ;
23
25
use subspace_farmer:: plotter:: pool:: PoolPlotter ;
@@ -101,6 +103,24 @@ struct CudaPlottingOptions {
101
103
cuda_gpus : Option < String > ,
102
104
}
103
105
106
+ #[ cfg( feature = "rocm" ) ]
107
+ #[ derive( Debug , Parser ) ]
108
+ struct RocmPlottingOptions {
109
+ /// Defines how many sectors farmer will download concurrently during plotting with ROCm GPU,
110
+ /// allows to limit memory usage of the plotting process, defaults to number of ROCm GPUs found
111
+ /// + 1 to download future sector ahead of time.
112
+ ///
113
+ /// Increase will result in higher memory usage.
114
+ #[ arg( long) ]
115
+ rocm_sector_downloading_concurrency : Option < NonZeroUsize > ,
116
+ /// Specify exact GPUs to be used for plotting instead of using all GPUs (default behavior).
117
+ ///
118
+ /// GPUs are coma-separated: `--rocm-gpus 0,1,3`. Empty string can be specified to disable ROCm
119
+ /// GPU usage.
120
+ #[ arg( long) ]
121
+ rocm_gpus : Option < String > ,
122
+ }
123
+
104
124
/// Arguments for plotter
105
125
#[ derive( Debug , Parser ) ]
106
126
pub ( super ) struct PlotterArgs {
@@ -118,6 +138,10 @@ pub(super) struct PlotterArgs {
118
138
#[ cfg( feature = "cuda" ) ]
119
139
#[ clap( flatten) ]
120
140
cuda_plotting_options : CudaPlottingOptions ,
141
+ /// Plotting options only used by ROCm GPU plotter
142
+ #[ cfg( feature = "rocm" ) ]
143
+ #[ clap( flatten) ]
144
+ rocm_plotting_options : RocmPlottingOptions ,
121
145
/// Additional cluster components
122
146
#[ clap( raw = true ) ]
123
147
pub ( super ) additional_components : Vec < String > ,
@@ -137,6 +161,8 @@ where
137
161
cpu_plotting_options,
138
162
#[ cfg( feature = "cuda" ) ]
139
163
cuda_plotting_options,
164
+ #[ cfg ( feature = "rocm" ) ]
165
+ rocm_plotting_options,
140
166
additional_components: _,
141
167
} = plotter_args;
142
168
@@ -168,6 +194,21 @@ where
168
194
modern_plotters. push ( Box :: new ( cuda_plotter) ) ;
169
195
}
170
196
}
197
+ #[ cfg( feature = "rocm" ) ]
198
+ {
199
+ let maybe_rocm_plotter = init_rocm_plotter (
200
+ rocm_plotting_options,
201
+ piece_getter. clone ( ) ,
202
+ Arc :: clone ( & global_mutex) ,
203
+ kzg. clone ( ) ,
204
+ erasure_coding. clone ( ) ,
205
+ registry,
206
+ ) ?;
207
+
208
+ if let Some ( rocm_plotter) = maybe_rocm_plotter {
209
+ modern_plotters. push ( Box :: new ( rocm_plotter) ) ;
210
+ }
211
+ }
171
212
{
172
213
let cpu_sector_encoding_concurrency = cpu_plotting_options. cpu_sector_encoding_concurrency ;
173
214
let maybe_cpu_plotters = init_cpu_plotters :: < _ , PosTableLegacy , PosTable > (
@@ -401,3 +442,85 @@ where
401
442
. map_err ( |error| anyhow:: anyhow!( "Failed to initialize CUDA plotter: {error}" ) ) ?,
402
443
) )
403
444
}
445
+
446
+ #[ cfg( feature = "rocm" ) ]
447
+ fn init_rocm_plotter < PG > (
448
+ rocm_plotting_options : RocmPlottingOptions ,
449
+ piece_getter : PG ,
450
+ global_mutex : Arc < AsyncMutex < ( ) > > ,
451
+ kzg : Kzg ,
452
+ erasure_coding : ErasureCoding ,
453
+ registry : & mut Registry ,
454
+ ) -> anyhow:: Result < Option < GpuPlotter < PG , RocmRecordsEncoder > > >
455
+ where
456
+ PG : PieceGetter + Clone + Send + Sync + ' static ,
457
+ {
458
+ use std:: collections:: BTreeSet ;
459
+ use subspace_proof_of_space_gpu:: rocm:: rocm_devices;
460
+ use tracing:: { debug, warn} ;
461
+
462
+ let RocmPlottingOptions {
463
+ rocm_sector_downloading_concurrency,
464
+ rocm_gpus,
465
+ } = rocm_plotting_options;
466
+
467
+ let mut rocm_devices = rocm_devices ( ) ;
468
+ let mut used_rocm_devices = ( 0 ..rocm_devices. len ( ) ) . collect :: < Vec < _ > > ( ) ;
469
+
470
+ if let Some ( rocm_gpus) = rocm_gpus {
471
+ if rocm_gpus. is_empty ( ) {
472
+ info ! ( "ROCm GPU plotting was explicitly disabled" ) ;
473
+ return Ok ( None ) ;
474
+ }
475
+
476
+ let mut rocm_gpus_to_use = rocm_gpus
477
+ . split ( ',' )
478
+ . map ( |gpu_index| gpu_index. parse ( ) )
479
+ . collect :: < Result < BTreeSet < usize > , _ > > ( ) ?;
480
+
481
+ ( used_rocm_devices, rocm_devices) = rocm_devices
482
+ . into_iter ( )
483
+ . enumerate ( )
484
+ . filter ( |( index, _rocm_device) | rocm_gpus_to_use. remove ( index) )
485
+ . unzip ( ) ;
486
+
487
+ if !rocm_gpus_to_use. is_empty ( ) {
488
+ warn ! (
489
+ ?rocm_gpus_to_use,
490
+ "Some ROCm GPUs were not found on the system"
491
+ ) ;
492
+ }
493
+ }
494
+
495
+ if rocm_devices. is_empty ( ) {
496
+ debug ! ( "No ROCm GPU devices found" ) ;
497
+ return Ok ( None ) ;
498
+ }
499
+
500
+ info ! ( ?used_rocm_devices, "Using ROCm GPUs" ) ;
501
+
502
+ let rocm_downloading_semaphore = Arc :: new ( Semaphore :: new (
503
+ rocm_sector_downloading_concurrency
504
+ . map ( |rocm_sector_downloading_concurrency| rocm_sector_downloading_concurrency. get ( ) )
505
+ . unwrap_or ( rocm_devices. len ( ) + 1 ) ,
506
+ ) ) ;
507
+
508
+ Ok ( Some (
509
+ GpuPlotter :: new (
510
+ piece_getter,
511
+ rocm_downloading_semaphore,
512
+ rocm_devices
513
+ . into_iter ( )
514
+ . map ( |rocm_device| RocmRecordsEncoder :: new ( rocm_device, Arc :: clone ( & global_mutex) ) )
515
+ . collect :: < Result < _ , _ > > ( )
516
+ . map_err ( |error| {
517
+ anyhow:: anyhow!( "Failed to create ROCm records encoder: {error}" )
518
+ } ) ?,
519
+ global_mutex,
520
+ kzg,
521
+ erasure_coding,
522
+ Some ( registry) ,
523
+ )
524
+ . map_err ( |error| anyhow:: anyhow!( "Failed to initialize ROCm plotter: {error}" ) ) ?,
525
+ ) )
526
+ }
0 commit comments