From 4899f503523ccda28dbccbab6da791f3d0863f4a Mon Sep 17 00:00:00 2001
From: NOBLES5E
Date: Thu, 17 Jun 2021 04:43:13 -0700
Subject: [PATCH] feat: initial support for python op (#2)

BREAKING CHANGE: `set_xxx_op`s are renamed to `append_xxx_op` now.
---
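Note (below the commit message, not part of it): a minimal sketch of how the
renamed bucket API might look from Python after this change. Object
construction and the exact module/attribute names are assumptions for
illustration only; the keyword names follow the pyo3 signatures in
bagua-core-py/src/lib.rs below.

    # Hypothetical sketch; `bucket` is assumed to be an already constructed
    # bucket object exposed by bagua-core-py.
    # Before: bucket.set_centralized_synchronous_op(...)  (one op per bucket)
    # After:  ops are appended and executed in order.
    bucket.append_centralized_synchronous_op(
        communicator_internode=None,
        communicator_intranode=None,
        hierarchical=False,
        average=True,
        scattergather=False,
        compression=None,
    )
    # Additionally, an arbitrary Python callable can be appended as an op;
    # it is called with the bucket name when the bucket is communicated.
    bucket.append_python_op(lambda bucket_name: print("bucket ready:", bucket_name))
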
 .gitignore                                     |  3 +
 Cargo.lock                                     |  7 ++-
 bagua-core-c/Cargo.toml                        |  2 +-
 bagua-core-internal/Cargo.toml                 |  5 +-
 ...ecentralized_full_precision_synchronous.rs  |  5 +-
 bagua-core-internal/src/comm_ops/mod.rs        |  1 +
 .../src/comm_ops/python_ffi_op.rs              | 27 ++++++++
 bagua-core-internal/src/datatypes/mod.rs       | 21 +++++--
 bagua-core-internal/src/lib.rs                 | 36 ++++++-----
 bagua-core-py/Cargo.toml                       |  2 +-
 bagua-core-py/src/lib.rs                       | 63 ++++++++++++-------
 11 files changed, 120 insertions(+), 52 deletions(-)
 create mode 100644 bagua-core-internal/src/comm_ops/python_ffi_op.rs

diff --git a/.gitignore b/.gitignore
index a8053c7..6de1857 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,6 @@ push.sh
 __pycache__/
 *.egg-info/
 /dist/
+/.eggs/
+/build/
+.data/
diff --git a/Cargo.lock b/Cargo.lock
index a1cf50d..c135d36 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -75,7 +75,7 @@ dependencies = [
 
 [[package]]
 name = "bagua-core-c"
-version = "0.1.0"
+version = "0.1.2"
 dependencies = [
  "anyhow",
  "bagua-core-internal",
@@ -95,7 +95,7 @@ dependencies = [
 
 [[package]]
 name = "bagua-core-internal"
-version = "0.1.0"
+version = "0.1.2"
 dependencies = [
  "base64",
  "cc",
@@ -110,6 +110,7 @@ dependencies = [
  "once_cell",
  "oneshot",
  "parking_lot",
+ "pyo3",
  "scheduled-thread-pool",
  "serde",
  "serde_json",
@@ -124,7 +125,7 @@ dependencies = [
 
 [[package]]
 name = "bagua-core-py"
-version = "0.1.0"
+version = "0.1.2"
 dependencies = [
  "anyhow",
  "bagua-core-internal",
diff --git a/bagua-core-c/Cargo.toml b/bagua-core-c/Cargo.toml
index 9bffab3..b5e9506 100644
--- a/bagua-core-c/Cargo.toml
+++ b/bagua-core-c/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bagua-core-c"
-version = "0.1.0"
+version = "0.1.2"
 edition = "2018"
 
 [lib]
diff --git a/bagua-core-internal/Cargo.toml b/bagua-core-internal/Cargo.toml
index faa151f..3816d49 100644
--- a/bagua-core-internal/Cargo.toml
+++ b/bagua-core-internal/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bagua-core-internal"
-version = "0.1.0"
+version = "0.1.2"
 authors = ["Xiangru Lian "]
 edition = "2018"
 publish = ["private"]
@@ -27,6 +27,9 @@
 scheduled-thread-pool = "0.2"
 serde_json = "1.0"
 ureq = "2.1"
+[dependencies.pyo3]
+version = "0.13.2"
+
 [build-dependencies]
 shadow-rs = "0.5"
 cpp_build = "0.5"
diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs
index 28c877e..75d7548 100644
--- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs
+++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_synchronous.rs
@@ -105,12 +105,13 @@
             );
 
             if step % comm_interval == 0 {
+                // TODO: move this to .then() python API instead of hard code this in op
                 let post_backward_comm_op = BaguaScheduledCommOp {
                     bucket: bucket.clone(),
-                    op: Arc::new(DecentralizedFullPrecisionSynchronousPostStep {
+                    ops: vec![Arc::new(DecentralizedFullPrecisionSynchronousPostStep {
                         communicator: self.communicator.clone(),
                         result_weight: peer_tensor,
-                    }),
+                    })],
                     event_channel: Default::default(),
                 };
 
diff --git a/bagua-core-internal/src/comm_ops/mod.rs b/bagua-core-internal/src/comm_ops/mod.rs
index dc50449..cc02513 100644
--- a/bagua-core-internal/src/comm_ops/mod.rs
+++ b/bagua-core-internal/src/comm_ops/mod.rs
@@ -1,6 +1,7 @@
 pub mod centralized_full_precision_synchronous;
 pub mod centralized_low_precision_synchronous;
 pub mod decentralized_full_precision_synchronous;
+pub mod python_ffi_op;
 
 use crate::datatypes::BaguaBucket;
 use crate::BaguaCommOpChannels;
diff --git a/bagua-core-internal/src/comm_ops/python_ffi_op.rs b/bagua-core-internal/src/comm_ops/python_ffi_op.rs
new file mode 100644
index 0000000..ad42395
--- /dev/null
+++ b/bagua-core-internal/src/comm_ops/python_ffi_op.rs
@@ -0,0 +1,27 @@
+use crate::comm_ops::CommOpTrait;
+use crate::communicators::BaguaCommunicator;
+use crate::datatypes::{BaguaBucket, BaguaTensorRaw};
+use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
+use crate::BaguaCommOpChannels;
+use pyo3::Python;
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub struct PythonFFIOp {
+    pub py_callable: pyo3::Py<pyo3::PyAny>,
+}
+
+impl CommOpTrait for PythonFFIOp {
+    fn execute_background_communication(
+        &self,
+        bucket: Arc<BaguaBucket>,
+        _comm_op_channels: &BaguaCommOpChannels,
+    ) {
+        Python::with_gil(|python| {
+            let result = self.py_callable.call1(python, (bucket.name.as_str(),));
+            if let Err(e) = result {
+                tracing::error!("python ffi op error: {:?}", e);
+            }
+        });
+    }
+}
diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs
index 1e9a389..8af5b4e 100644
--- a/bagua-core-internal/src/datatypes/mod.rs
+++ b/bagua-core-internal/src/datatypes/mod.rs
@@ -3,6 +3,7 @@ use crate::comm_ops::centralized_low_precision_synchronous::CentralizedLowPrecis
 use crate::comm_ops::decentralized_full_precision_synchronous::{
     DecentralizedFullPrecisionSynchronous, PeerSelectionMode,
 };
+use crate::comm_ops::python_ffi_op::PythonFFIOp;
 use crate::comm_ops::CommOpTrait;
 use crate::communicators::{BaguaCommunicator, BaguaSingleCommunicator};
 use crate::resource_pool::{CudaMemory, CUDA_DEVICE_MEMORY_POOL};
@@ -586,7 +587,7 @@
 pub struct BaguaBucketInner {
     pub tensors: Vec<BaguaTensor>,
     pub dtype: BaguaTensorDtype,
     pub inplace: bool,
-    pub comm_op: Option<Arc<dyn CommOpTrait + Send + Sync>>,
+    pub comm_ops: Vec<Arc<dyn CommOpTrait + Send + Sync>>,
     pub align_bytes: usize,
 }
@@ -734,12 +735,14 @@ impl<'b> Drop for BaguaCommunicationTensor<'b> {
 #[derive(Debug, Clone)]
 pub struct BaguaBucket {
     pub id: u64,
+    pub name: String,
     pub inner: Arc<Mutex<BaguaBucketInner>>,
 }
 
 impl BaguaBucket {
     pub fn new(
         tensors: &[&BaguaTensor],
+        name: &str,
         inplace: bool,
         align_bytes: usize,
     ) -> Result<Self, BaguaCoreError> {
@@ -812,10 +815,11 @@
         let id = lazy_id::Id::lazy().get();
         Ok(Self {
             id,
+            name: name.to_owned(),
             inner: Arc::new(Mutex::new(BaguaBucketInner {
                 inplace,
                 tensors: tensors.iter().map(|x| (**x).clone()).collect(),
-                comm_op: None,
+                comm_ops: vec![],
                 dtype: tensors.first().unwrap().inner.read().raw.dtype.clone(),
                 align_bytes,
             })),
@@ -826,7 +830,7 @@ impl BaguaBucket {
         self.inner.lock().tensors.clone()
     }
 
-    pub fn set_decentralized_synchronous_op(
+    pub fn append_decentralized_synchronous_op(
         &mut self,
         communicator_internode: Option<&BaguaSingleCommunicator>,
         communicator_intranode: Option<&BaguaSingleCommunicator>,
@@ -857,12 +861,17 @@
                 }
             },
         };
-        self.inner.lock().comm_op = Some(comm_op);
+        self.inner.lock().comm_ops.push(comm_op);
+    }
+
+    pub fn append_python_op(&mut self, op: pyo3::Py<pyo3::PyAny>) {
+        let comm_op: Arc<dyn CommOpTrait + Send + Sync> = Arc::new(PythonFFIOp { py_callable: op });
+        self.inner.lock().comm_ops.push(comm_op);
     }
 
     /// this function will use communicator_internode to communicate.
     /// if hierarchical = True, it will do hierarchical communicator, this requires intranode communicator on each node and inter node communicator on leader GPU. leader GPU will be the GPU whose communicator_intranode rank is 0
-    pub fn set_centralized_synchronous_op(
+    pub fn append_centralized_synchronous_op(
         &mut self,
         communicator_internode: Option<&BaguaSingleCommunicator>,
         communicator_intranode: Option<&BaguaSingleCommunicator>,
@@ -893,7 +902,7 @@
                 }
             },
         };
-        self.inner.lock().comm_op = Some(comm_op);
+        self.inner.lock().comm_ops.push(comm_op);
     }
 
     pub fn ready_for_comm(&self) -> bool {
diff --git a/bagua-core-internal/src/lib.rs b/bagua-core-internal/src/lib.rs
index eace251..82d964f 100644
--- a/bagua-core-internal/src/lib.rs
+++ b/bagua-core-internal/src/lib.rs
@@ -54,7 +54,7 @@ pub enum BaguaCoreError {
 
 #[derive(Debug)]
 pub struct BaguaScheduledCommOp {
     pub bucket: Arc<BaguaBucket>,
-    pub op: Arc<dyn CommOpTrait + Send + Sync>,
+    pub ops: Vec<Arc<dyn CommOpTrait + Send + Sync>>,
     pub event_channel: BaguaEventChannel,
 }
@@ -125,14 +125,17 @@ pub struct BaguaCommBackend {
 impl BaguaCommBackend {
     pub fn schedule_comm(&self, bucket: Arc<BaguaBucket>) -> Result<(), BaguaCoreError> {
         let event_channel = BaguaEventChannel::default();
-        self.channels.schedule_channel_sender.send(BaguaScheduledCommOp {
-            op: {
-                let guard = bucket.inner.lock();
-                guard.comm_op.clone().expect("bucket must have communication operator set before scheduled for communication")
-            },
-            bucket,
-            event_channel: event_channel.clone(),
-        }).map_err(|e| BaguaCoreError::InternalChannelError(format!("{:?}", e)))?;
+        self.channels
+            .schedule_channel_sender
+            .send(BaguaScheduledCommOp {
+                ops: {
+                    let guard = bucket.inner.lock();
+                    guard.comm_ops.clone()
+                },
+                bucket,
+                event_channel: event_channel.clone(),
+            })
+            .map_err(|e| BaguaCoreError::InternalChannelError(format!("{:?}", e)))?;
         Ok(self
             .channels
             .not_waited_events_sender
@@ -187,9 +190,12 @@
                     "worker received scheduled communication operation {:?}",
                     comm_op
                 );
-                comm_op
-                    .op
-                    .execute_background_communication(comm_op.bucket.clone(), &channels_clone);
+                for op in &comm_op.ops {
+                    op.execute_background_communication(
+                        comm_op.bucket.clone(),
+                        &channels_clone,
+                    );
+                }
                 tracing::debug!("comm op executed: {:?}", comm_op);
                 comm_op.event_channel.finish();
                 tracing::debug!("comm op marked finished: {:?}", comm_op);
@@ -292,9 +298,9 @@
         match comm_op {
             Ok(comm_op) => {
                 tracing::debug!("received post step communication operation {:?}", comm_op);
-                comm_op
-                    .op
-                    .execute_background_communication(comm_op.bucket.clone(), &self.channels);
+                for op in &comm_op.ops {
+                    op.execute_background_communication(comm_op.bucket.clone(), &self.channels);
+                }
                 tracing::debug!("comm op executed: {:?}", comm_op);
                 comm_op.event_channel.finish();
                 tracing::debug!("comm op marked finished: {:?}", comm_op);
diff --git a/bagua-core-py/Cargo.toml b/bagua-core-py/Cargo.toml
index 20f30cc..3721441 100644
--- a/bagua-core-py/Cargo.toml
+++ b/bagua-core-py/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bagua-core-py"
-version = "0.1.0"
+version = "0.1.2"
 authors = ["Xiangru Lian "]
 edition = "2018"
 publish = ["private"]
diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs
index 4f99254..de8a6d5 100644
--- a/bagua-core-py/src/lib.rs
+++ b/bagua-core-py/src/lib.rs
@@ -6,6 +6,8 @@ use bagua_core_internal::BaguaCommBackend;
 use numpy::{IntoPyArray, PyArray1};
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
+use pyo3::PyNativeType;
+use std::sync::Arc;
 
 #[pyclass(dict)]
 pub struct BaguaSingleCommunicatorPy {
@@ -197,9 +199,8 @@ impl BaguaCommBackendPy {
             .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))
     }
 
-    pub fn wait_pending_comm_ops(&self) -> PyResult<usize> {
-        self.inner
-            .wait_pending_comm_ops()
+    pub fn wait_pending_comm_ops(&self, py: Python) -> PyResult<usize> {
+        py.allow_threads(|| self.inner.wait_pending_comm_ops())
             .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))
     }
 
@@ -215,9 +216,8 @@ impl BaguaCommBackendPy {
             .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))
     }
 
-    pub fn wait_pending_post_backward_comm_ops(&self) -> PyResult<usize> {
-        self.inner
-            .wait_pending_post_backward_comm_ops()
+    pub fn wait_pending_post_backward_comm_ops(&self, py: Python) -> PyResult<usize> {
+        py.allow_threads(|| self.inner.wait_pending_post_backward_comm_ops())
             .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))
     }
 }
@@ -232,6 +232,7 @@ impl BaguaBucketPy {
     #[new]
     #[args(align_bytes = "0")]
     pub fn new(
+        name: &str,
         tensors: Vec<PyRef<BaguaTensorPy>>,
         inplace: bool,
         align_bytes: usize,
@@ -241,7 +242,7 @@
             tensors_inner.push(&t.inner)
         }
         Ok(Self {
-            inner: BaguaBucket::new(tensors_inner.as_slice(), inplace, align_bytes)
+            inner: BaguaBucket::new(tensors_inner.as_slice(), name, inplace, align_bytes)
                 .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?,
         })
     }
@@ -254,50 +255,66 @@
             .collect()
     }
 
-    #[args(hierarchical = "false", communication_interval = "1")]
-    pub fn set_decentralized_synchronous_op(
+    pub fn append_python_op(&mut self, op: &PyAny) -> PyResult<()> {
+        assert!(op.is_callable(), "python op should be a callable");
+        self.inner.append_python_op(op.into_py(op.py()));
+        Ok(())
+    }
+
+    /// this function will use communicator_internode to communicate.
+    /// if hierarchical = True, it will do hierarchical communicator, this requires intranode communicator on each node and inter node communicator on leader GPU. leader GPU will be the GPU whose communicator_intranode rank is 0
+    #[args(average = "true", hierarchical = "false", scattergather = "false")]
+    pub fn append_centralized_synchronous_op(
         &mut self,
         communicator_internode: Option<&BaguaSingleCommunicatorPy>,
         communicator_intranode: Option<&BaguaSingleCommunicatorPy>,
         hierarchical: bool,
-        peer_selection_mode: String,
-        communication_interval: usize,
+        average: bool,
+        scattergather: bool,
         compression: Option<String>,
     ) -> PyResult<()> {
-        self.inner.set_decentralized_synchronous_op(
+        self.inner.append_centralized_synchronous_op(
             communicator_internode.map(|x| &x.inner),
             communicator_intranode.map(|x| &x.inner),
             hierarchical,
-            peer_selection_mode,
-            communication_interval,
+            average,
+            scattergather,
             compression,
         );
         Ok(())
     }
 
-    /// this function will use communicator_internode to communicate.
-    /// if hierarchical = True, it will do hierarchical communicator, this requires intranode communicator on each node and inter node communicator on leader GPU. leader GPU will be the GPU whose communicator_intranode rank is 0
-    #[args(average = "true", hierarchical = "false", scattergather = "false")]
-    pub fn set_centralized_synchronous_op(
+    #[args(hierarchical = "false", communication_interval = "1")]
+    pub fn append_decentralized_synchronous_op(
         &mut self,
         communicator_internode: Option<&BaguaSingleCommunicatorPy>,
         communicator_intranode: Option<&BaguaSingleCommunicatorPy>,
         hierarchical: bool,
-        average: bool,
-        scattergather: bool,
+        peer_selection_mode: String,
+        communication_interval: usize,
         compression: Option<String>,
     ) -> PyResult<()> {
-        self.inner.set_centralized_synchronous_op(
+        self.inner.append_decentralized_synchronous_op(
             communicator_internode.map(|x| &x.inner),
             communicator_intranode.map(|x| &x.inner),
             hierarchical,
-            average,
-            scattergather,
+            peer_selection_mode,
+            communication_interval,
             compression,
         );
         Ok(())
     }
 
+    pub fn print_ops(&self) -> PyResult<()> {
+        println!("{:?}", self.inner.inner.lock().comm_ops);
+        Ok(())
+    }
+
+    pub fn clear_ops(&mut self) -> PyResult<()> {
+        self.inner.inner.lock().comm_ops.clear();
+        Ok(())
+    }
+
     pub fn ready_for_comm(&self) -> bool {
         self.inner.ready_for_comm()
     }
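
Usage note (not part of the patch): the Rust side invokes a registered Python
callable with a single positional argument, the bucket name, via
call1(python, (bucket.name.as_str(),)) in python_ffi_op.rs above, and only
logs any Python exception. A callable passed to append_python_op could
therefore look like the sketch below; the bucket construction and how the
module is imported are assumptions for illustration.

    # Hypothetical sketch of a Python op; it runs on the background
    # communication thread while the GIL is held, so it should return quickly.
    def my_python_op(bucket_name: str) -> None:
        # `bucket_name` is the name given to the bucket constructor.
        print(f"communication finished for bucket {bucket_name}")

    bucket.append_python_op(my_python_op)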