From a47e35e4b9145f2cafc18e809afa00d2d0e01569 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 21 Jul 2021 09:26:43 +0000 Subject: [PATCH] Format Rust code using rustfmt --- ...ecentralized_full_precision_synchronous.rs | 6 +- bagua-core-internal/src/datatypes/mod.rs | 64 ++++++++++--------- bagua-core-py/src/lib.rs | 24 +++---- 3 files changed, 49 insertions(+), 45 deletions(-) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs index 16fc8ea..97ec3b3 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs @@ -1,6 +1,8 @@ use crate::comm_ops::CommOpTrait; use crate::communicators::{BaguaCommunicator, BaguaHierarchicalCommunicator, NCCLGroupGuard}; -use crate::datatypes::{BaguaBucket, BaguaTensor, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor}; +use crate::datatypes::{ + BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorRaw, RawBaguaTensor, +}; use crate::events::BaguaEventChannel; use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL; use crate::{BaguaCommOpChannels, BaguaScheduledCommOp}; @@ -96,7 +98,5 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous { ); *self.step.lock() += 1; - } } - diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs index 60865cb..350a8f7 100644 --- a/bagua-core-internal/src/datatypes/mod.rs +++ b/bagua-core-internal/src/datatypes/mod.rs @@ -1113,22 +1113,24 @@ impl BaguaBucket { let communicator = BaguaCommunicator::new(communicator_internode, communicator_intranode, hierarchical) .expect("cannot create communicator"); - let comm_op: Arc = Arc::new(DecentralizedFullPrecisionSynchronous { - communicator, - peer_selection_mode: match peer_selection_mode.as_str() { - "all" => PeerSelectionMode::All, - "shift_one" => PeerSelectionMode::ShiftOne, - &_ => { - unimplemented!("unsupported peer_selection_mode for decentralized algorithm (should be `all` or `shift_one`)") - } + let comm_op: Arc = Arc::new( + DecentralizedFullPrecisionSynchronous { + communicator, + peer_selection_mode: match peer_selection_mode.as_str() { + "all" => PeerSelectionMode::All, + "shift_one" => PeerSelectionMode::ShiftOne, + &_ => { + unimplemented!("unsupported peer_selection_mode for decentralized algorithm (should be `all` or `shift_one`)") + } + }, + step: Default::default(), + peer_weight, }, - step: Default::default(), - peer_weight, - }); + ); self.inner.lock().comm_ops.push(comm_op); } - + pub fn append_low_precision_decentralized_synchronous_op( &mut self, communicator_internode: Option<&BaguaSingleCommunicator>, @@ -1143,30 +1145,32 @@ impl BaguaBucket { let communicator = BaguaCommunicator::new(communicator_internode, communicator_intranode, hierarchical) .expect("cannot create communicator"); - let comm_op: Arc = Arc::new(DecentralizedLowPrecisionSynchronous { - communicator, - peer_selection_mode: match peer_selection_mode.as_str() { - "ring" => PeerSelectionMode::Ring, - &_ => { - unimplemented!("unsupported peer_selection_mode for low precision decentralized algorithm (should be `ring`)") - } - }, - compression_method: TensorCompressionMethod::MinMaxUInt8( - MinMaxUInt8CompressionParameters {}, - ), - weight, - left_peer_weight, - right_peer_weight, - }); - + let comm_op: Arc = Arc::new( + DecentralizedLowPrecisionSynchronous { + communicator, + peer_selection_mode: match peer_selection_mode.as_str() { + "ring" => PeerSelectionMode::Ring, + &_ => { + unimplemented!("unsupported peer_selection_mode for low precision decentralized algorithm (should be `ring`)") + } + }, + compression_method: TensorCompressionMethod::MinMaxUInt8( + MinMaxUInt8CompressionParameters {}, + ), + weight, + left_peer_weight, + right_peer_weight, + }, + ); + self.inner.lock().comm_ops.push(comm_op); } - + pub fn append_python_op(&mut self, op: pyo3::Py) { let comm_op: Arc = Arc::new(PythonFFIOp { py_callable: op }); self.inner.lock().comm_ops.push(comm_op); } - + /// this function will use communicator_internode to communicate. /// if hierarchical = True, it will do hierarchical communicator, this requires intranode communicator on each node and inter node communicator on leader GPU. leader GPU will be the GPU whose communicator_intranode rank is 0 pub fn append_centralized_synchronous_op( diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs index c28bcd4..353539f 100644 --- a/bagua-core-py/src/lib.rs +++ b/bagua-core-py/src/lib.rs @@ -344,7 +344,6 @@ impl BaguaCommBackendPy { py.allow_threads(|| self.inner.start_upload_telemetry(skip)) .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) } - } #[pyclass(dict)] @@ -421,7 +420,7 @@ impl BaguaBucketPy { ); Ok(()) } - + #[args(hierarchical = "false", communication_interval = "1")] pub fn append_low_precision_decentralized_synchronous_op( &mut self, @@ -434,16 +433,17 @@ impl BaguaBucketPy { left_peer_weight: PyRef, right_peer_weight: PyRef, ) -> PyResult<()> { - self.inner.append_low_precision_decentralized_synchronous_op( - communicator_internode.map(|x| &x.inner), - communicator_intranode.map(|x| &x.inner), - hierarchical, - peer_selection_mode, - compression, - (*weight).inner.clone(), - (*left_peer_weight).inner.clone(), - (*right_peer_weight).inner.clone(), - ); + self.inner + .append_low_precision_decentralized_synchronous_op( + communicator_internode.map(|x| &x.inner), + communicator_intranode.map(|x| &x.inner), + hierarchical, + peer_selection_mode, + compression, + (*weight).inner.clone(), + (*left_peer_weight).inner.clone(), + (*right_peer_weight).inner.clone(), + ); Ok(()) }