Skip to content
This repository has been archived by the owner on Sep 15, 2021. It is now read-only.

Commit

Permalink
Format Rust code using rustfmt
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] authored Jul 21, 2021
1 parent 98319c9 commit a47e35e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::comm_ops::CommOpTrait;
use crate::communicators::{BaguaCommunicator, BaguaHierarchicalCommunicator, NCCLGroupGuard};
use crate::datatypes::{BaguaBucket, BaguaTensor, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor};
use crate::datatypes::{
BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorRaw, RawBaguaTensor,
};
use crate::events::BaguaEventChannel;
use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
use crate::{BaguaCommOpChannels, BaguaScheduledCommOp};
Expand Down Expand Up @@ -96,7 +98,5 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
);

*self.step.lock() += 1;

}
}

64 changes: 34 additions & 30 deletions bagua-core-internal/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1113,22 +1113,24 @@ impl BaguaBucket {
let communicator =
BaguaCommunicator::new(communicator_internode, communicator_intranode, hierarchical)
.expect("cannot create communicator");
let comm_op: Arc<dyn CommOpTrait + Send + Sync> = Arc::new(DecentralizedFullPrecisionSynchronous {
communicator,
peer_selection_mode: match peer_selection_mode.as_str() {
"all" => PeerSelectionMode::All,
"shift_one" => PeerSelectionMode::ShiftOne,
&_ => {
unimplemented!("unsupported peer_selection_mode for decentralized algorithm (should be `all` or `shift_one`)")
}
let comm_op: Arc<dyn CommOpTrait + Send + Sync> = Arc::new(
DecentralizedFullPrecisionSynchronous {
communicator,
peer_selection_mode: match peer_selection_mode.as_str() {
"all" => PeerSelectionMode::All,
"shift_one" => PeerSelectionMode::ShiftOne,
&_ => {
unimplemented!("unsupported peer_selection_mode for decentralized algorithm (should be `all` or `shift_one`)")
}
},
step: Default::default(),
peer_weight,
},
step: Default::default(),
peer_weight,
});
);

self.inner.lock().comm_ops.push(comm_op);
}

pub fn append_low_precision_decentralized_synchronous_op(
&mut self,
communicator_internode: Option<&BaguaSingleCommunicator>,
Expand All @@ -1143,30 +1145,32 @@ impl BaguaBucket {
let communicator =
BaguaCommunicator::new(communicator_internode, communicator_intranode, hierarchical)
.expect("cannot create communicator");
let comm_op: Arc<dyn CommOpTrait + Send + Sync> = Arc::new(DecentralizedLowPrecisionSynchronous {
communicator,
peer_selection_mode: match peer_selection_mode.as_str() {
"ring" => PeerSelectionMode::Ring,
&_ => {
unimplemented!("unsupported peer_selection_mode for low precision decentralized algorithm (should be `ring`)")
}
},
compression_method: TensorCompressionMethod::MinMaxUInt8(
MinMaxUInt8CompressionParameters {},
),
weight,
left_peer_weight,
right_peer_weight,
});

let comm_op: Arc<dyn CommOpTrait + Send + Sync> = Arc::new(
DecentralizedLowPrecisionSynchronous {
communicator,
peer_selection_mode: match peer_selection_mode.as_str() {
"ring" => PeerSelectionMode::Ring,
&_ => {
unimplemented!("unsupported peer_selection_mode for low precision decentralized algorithm (should be `ring`)")
}
},
compression_method: TensorCompressionMethod::MinMaxUInt8(
MinMaxUInt8CompressionParameters {},
),
weight,
left_peer_weight,
right_peer_weight,
},
);

self.inner.lock().comm_ops.push(comm_op);
}

pub fn append_python_op(&mut self, op: pyo3::Py<pyo3::PyAny>) {
let comm_op: Arc<dyn CommOpTrait + Send + Sync> = Arc::new(PythonFFIOp { py_callable: op });
self.inner.lock().comm_ops.push(comm_op);
}

/// this function will use communicator_internode to communicate.
/// if hierarchical = True, it will do hierarchical communicator, this requires intranode communicator on each node and inter node communicator on leader GPU. leader GPU will be the GPU whose communicator_intranode rank is 0
pub fn append_centralized_synchronous_op(
Expand Down
24 changes: 12 additions & 12 deletions bagua-core-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ impl BaguaCommBackendPy {
py.allow_threads(|| self.inner.start_upload_telemetry(skip))
.map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))
}

}

#[pyclass(dict)]
Expand Down Expand Up @@ -421,7 +420,7 @@ impl BaguaBucketPy {
);
Ok(())
}

#[args(hierarchical = "false", communication_interval = "1")]
pub fn append_low_precision_decentralized_synchronous_op(
&mut self,
Expand All @@ -434,16 +433,17 @@ impl BaguaBucketPy {
left_peer_weight: PyRef<BaguaTensorPy>,
right_peer_weight: PyRef<BaguaTensorPy>,
) -> PyResult<()> {
self.inner.append_low_precision_decentralized_synchronous_op(
communicator_internode.map(|x| &x.inner),
communicator_intranode.map(|x| &x.inner),
hierarchical,
peer_selection_mode,
compression,
(*weight).inner.clone(),
(*left_peer_weight).inner.clone(),
(*right_peer_weight).inner.clone(),
);
self.inner
.append_low_precision_decentralized_synchronous_op(
communicator_internode.map(|x| &x.inner),
communicator_intranode.map(|x| &x.inner),
hierarchical,
peer_selection_mode,
compression,
(*weight).inner.clone(),
(*left_peer_weight).inner.clone(),
(*right_peer_weight).inner.clone(),
);
Ok(())
}

Expand Down

0 comments on commit a47e35e

Please sign in to comment.