From 21362c4ff1dd47fcbb5ec59ea7979ba87b2e25f8 Mon Sep 17 00:00:00 2001 From: Nazar Mokrynskyi Date: Wed, 20 Nov 2024 02:55:14 +0200 Subject: [PATCH] Simplify GPU plotting logic --- .../commands/cluster/plotter.rs | 10 +- .../src/bin/subspace-farmer/commands/farm.rs | 10 +- .../subspace-farmer/src/plotter/gpu/cuda.rs | 91 ++++--------------- .../subspace-farmer/src/plotter/gpu/rocm.rs | 73 ++++----------- 4 files changed, 36 insertions(+), 148 deletions(-) diff --git a/crates/subspace-farmer/src/bin/subspace-farmer/commands/cluster/plotter.rs b/crates/subspace-farmer/src/bin/subspace-farmer/commands/cluster/plotter.rs index 26927d8262..49af09578d 100644 --- a/crates/subspace-farmer/src/bin/subspace-farmer/commands/cluster/plotter.rs +++ b/crates/subspace-farmer/src/bin/subspace-farmer/commands/cluster/plotter.rs @@ -398,10 +398,7 @@ where cuda_devices .into_iter() .map(|cuda_device| CudaRecordsEncoder::new(cuda_device, Arc::clone(&global_mutex))) - .collect::>() - .map_err(|error| { - anyhow::anyhow!("Failed to create CUDA records encoder: {error}") - })?, + .collect(), global_mutex, kzg, erasure_coding, @@ -480,10 +477,7 @@ where rocm_devices .into_iter() .map(|rocm_device| RocmRecordsEncoder::new(rocm_device, Arc::clone(&global_mutex))) - .collect::>() - .map_err(|error| { - anyhow::anyhow!("Failed to create ROCm records encoder: {error}") - })?, + .collect(), global_mutex, kzg, erasure_coding, diff --git a/crates/subspace-farmer/src/bin/subspace-farmer/commands/farm.rs b/crates/subspace-farmer/src/bin/subspace-farmer/commands/farm.rs index 1379f42025..a6736ed127 100644 --- a/crates/subspace-farmer/src/bin/subspace-farmer/commands/farm.rs +++ b/crates/subspace-farmer/src/bin/subspace-farmer/commands/farm.rs @@ -1072,10 +1072,7 @@ where cuda_devices .into_iter() .map(|cuda_device| CudaRecordsEncoder::new(cuda_device, Arc::clone(&global_mutex))) - .collect::>() - .map_err(|error| { - anyhow::anyhow!("Failed to create CUDA records encoder: {error}") - })?, + .collect(), global_mutex, kzg, erasure_coding, @@ -1154,10 +1151,7 @@ where rocm_devices .into_iter() .map(|rocm_device| RocmRecordsEncoder::new(rocm_device, Arc::clone(&global_mutex))) - .collect::>() - .map_err(|error| { - anyhow::anyhow!("Failed to create ROCm records encoder: {error}") - })?, + .collect(), global_mutex, kzg, erasure_coding, diff --git a/crates/subspace-farmer/src/plotter/gpu/cuda.rs b/crates/subspace-farmer/src/plotter/gpu/cuda.rs index 370439edf5..bbbedff731 100644 --- a/crates/subspace-farmer/src/plotter/gpu/cuda.rs +++ b/crates/subspace-farmer/src/plotter/gpu/cuda.rs @@ -2,9 +2,6 @@ use crate::plotter::gpu::GpuRecordsEncoder; use async_lock::Mutex as AsyncMutex; -use parking_lot::Mutex; -use rayon::{current_thread_index, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder}; -use std::process::exit; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use subspace_core_primitives::pieces::{PieceOffset, Record}; @@ -17,7 +14,6 @@ use subspace_proof_of_space_gpu::cuda::CudaDevice; #[derive(Debug)] pub struct CudaRecordsEncoder { cuda_device: CudaDevice, - thread_pool: ThreadPool, global_mutex: Arc>, } @@ -38,50 +34,23 @@ impl RecordsEncoder for CudaRecordsEncoder { .map_err(|error| anyhow::anyhow!("Failed to convert pieces in sector: {error}"))?; let mut sector_contents_map = SectorContentsMap::new(pieces_in_sector); - self.thread_pool.install(|| { - let iter = Mutex::new( - (PieceOffset::ZERO..) - .zip(records.iter_mut()) - .zip(sector_contents_map.iter_record_bitfields_mut()), - ); - let plotting_error = Mutex::new(None::); + for ((piece_offset, record), mut encoded_chunks_used) in (PieceOffset::ZERO..) + .zip(records.iter_mut()) + .zip(sector_contents_map.iter_record_bitfields_mut()) + { + // Take mutex briefly to make sure encoding is allowed right now + self.global_mutex.lock_blocking(); - rayon::scope(|scope| { - scope.spawn_broadcast(|_scope, _ctx| loop { - // Take mutex briefly to make sure encoding is allowed right now - self.global_mutex.lock_blocking(); + let pos_seed = sector_id.derive_evaluation_seed(piece_offset); - // This instead of `while` above because otherwise mutex will be held for the - // duration of the loop and will limit concurrency to 1 record - let Some(((piece_offset, record), mut encoded_chunks_used)) = - iter.lock().next() - else { - return; - }; - let pos_seed = sector_id.derive_evaluation_seed(piece_offset); + self.cuda_device + .generate_and_encode_pospace(&pos_seed, record, encoded_chunks_used.iter_mut()) + .map_err(anyhow::Error::msg)?; - if let Err(error) = self.cuda_device.generate_and_encode_pospace( - &pos_seed, - record, - encoded_chunks_used.iter_mut(), - ) { - plotting_error.lock().replace(error); - return; - } - - if abort_early.load(Ordering::Relaxed) { - return; - } - }); - }); - - let plotting_error = plotting_error.lock().take(); - if let Some(error) = plotting_error { - return Err(anyhow::Error::msg(error)); + if abort_early.load(Ordering::Relaxed) { + break; } - - Ok(()) - })?; + } Ok(sector_contents_map) } @@ -89,38 +58,10 @@ impl RecordsEncoder for CudaRecordsEncoder { impl CudaRecordsEncoder { /// Create new instance - pub fn new( - cuda_device: CudaDevice, - global_mutex: Arc>, - ) -> Result { - let id = cuda_device.id(); - let thread_name = move |thread_index| format!("cuda-{id}.{thread_index}"); - // TODO: remove this panic handler when rayon logs panic_info - // https://github.com/rayon-rs/rayon/issues/1208 - let panic_handler = move |panic_info| { - if let Some(index) = current_thread_index() { - eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info); - } else { - // We want to guarantee exit, rather than panicking in a panic handler. - eprintln!( - "rayon panic handler called on non-rayon thread: {:?}", - panic_info - ); - } - exit(1); - }; - - let thread_pool = ThreadPoolBuilder::new() - .thread_name(thread_name) - .panic_handler(panic_handler) - // Make sure there is overlap between records, so GPU is almost always busy - .num_threads(2) - .build()?; - - Ok(Self { + pub fn new(cuda_device: CudaDevice, global_mutex: Arc>) -> Self { + Self { cuda_device, - thread_pool, global_mutex, - }) + } } } diff --git a/crates/subspace-farmer/src/plotter/gpu/rocm.rs b/crates/subspace-farmer/src/plotter/gpu/rocm.rs index c86161f47a..2aef4c482a 100644 --- a/crates/subspace-farmer/src/plotter/gpu/rocm.rs +++ b/crates/subspace-farmer/src/plotter/gpu/rocm.rs @@ -2,8 +2,6 @@ use crate::plotter::gpu::GpuRecordsEncoder; use async_lock::Mutex as AsyncMutex; -use parking_lot::Mutex; -use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use subspace_core_primitives::pieces::{PieceOffset, Record}; @@ -16,7 +14,6 @@ use subspace_proof_of_space_gpu::rocm::RocmDevice; #[derive(Debug)] pub struct RocmRecordsEncoder { rocm_device: RocmDevice, - thread_pool: ThreadPool, global_mutex: Arc>, } @@ -37,50 +34,23 @@ impl RecordsEncoder for RocmRecordsEncoder { .map_err(|error| anyhow::anyhow!("Failed to convert pieces in sector: {error}"))?; let mut sector_contents_map = SectorContentsMap::new(pieces_in_sector); - self.thread_pool.install(|| { - let iter = Mutex::new( - (PieceOffset::ZERO..) - .zip(records.iter_mut()) - .zip(sector_contents_map.iter_record_bitfields_mut()), - ); - let plotting_error = Mutex::new(None::); + for ((piece_offset, record), mut encoded_chunks_used) in (PieceOffset::ZERO..) + .zip(records.iter_mut()) + .zip(sector_contents_map.iter_record_bitfields_mut()) + { + // Take mutex briefly to make sure encoding is allowed right now + self.global_mutex.lock_blocking(); - rayon::scope(|scope| { - scope.spawn_broadcast(|_scope, _ctx| loop { - // Take mutex briefly to make sure encoding is allowed right now - self.global_mutex.lock_blocking(); + let pos_seed = sector_id.derive_evaluation_seed(piece_offset); - // This instead of `while` above because otherwise mutex will be held for the - // duration of the loop and will limit concurrency to 1 record - let Some(((piece_offset, record), mut encoded_chunks_used)) = - iter.lock().next() - else { - return; - }; - let pos_seed = sector_id.derive_evaluation_seed(piece_offset); + self.rocm_device + .generate_and_encode_pospace(&pos_seed, record, encoded_chunks_used.iter_mut()) + .map_err(anyhow::Error::msg)?; - if let Err(error) = self.rocm_device.generate_and_encode_pospace( - &pos_seed, - record, - encoded_chunks_used.iter_mut(), - ) { - plotting_error.lock().replace(error); - return; - } - - if abort_early.load(Ordering::Relaxed) { - return; - } - }); - }); - - let plotting_error = plotting_error.lock().take(); - if let Some(error) = plotting_error { - return Err(anyhow::Error::msg(error)); + if abort_early.load(Ordering::Relaxed) { + break; } - - Ok(()) - })?; + } Ok(sector_contents_map) } @@ -88,21 +58,10 @@ impl RecordsEncoder for RocmRecordsEncoder { impl RocmRecordsEncoder { /// Create new instance - pub fn new( - rocm_device: RocmDevice, - global_mutex: Arc>, - ) -> Result { - let id = rocm_device.id(); - let thread_pool = ThreadPoolBuilder::new() - .thread_name(move |thread_index| format!("rocm-{id}.{thread_index}")) - // Make sure there is overlap between records, so GPU is almost always busy - .num_threads(2) - .build()?; - - Ok(Self { + pub fn new(rocm_device: RocmDevice, global_mutex: Arc>) -> Self { + Self { rocm_device, - thread_pool, global_mutex, - }) + } } }