From be741d84d0266e11217df1c1ac00ded9f47477cf Mon Sep 17 00:00:00 2001
From: Dzejkop
Date: Wed, 31 Jan 2024 14:12:19 +0100
Subject: [PATCH 1/3] TEMP: Fix errors

---
 src/coordinator.rs | 126 +++++++++++++++++++++++----------------------
 1 file changed, 64 insertions(+), 62 deletions(-)

diff --git a/src/coordinator.rs b/src/coordinator.rs
index 351aa56..01ee038 100644
--- a/src/coordinator.rs
+++ b/src/coordinator.rs
@@ -140,7 +140,7 @@ impl Coordinator {
         JoinHandle<eyre::Result<()>>,
     ) {
         // Collect batches of shares
-        let (processed_shares_tx, mut processed_shares_rx) = mpsc::channel(4);
+        // let (processed_shares_tx, mut processed_shares_rx) = mpsc::channel(4);
 
         let streams_future =
             future::try_join_all(self.participants.iter_mut().enumerate().map(
@@ -158,7 +158,7 @@ impl Coordinator {
                         let bytes_read = stream.read_buf(&mut buffer).await?;
                         if bytes_read == 0 {
                             let n_incomplete = (buffer.len()
-                                + std::mem::size_of::<[u16; 31]>() //TODO: make this a const 
+                                + std::mem::size_of::<[u16; 31]>() //TODO: make this a const
                                 - 1)
                                 / std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
                             batch.truncate(batch.len() - n_incomplete);
@@ -170,67 +170,69 @@ impl Coordinator {
             },
         ));
 
-        let batch_worker = tokio::task::spawn(async move {
-            loop {
-                // Collect futures of denominator and share batches
-                let streams_future = future::try_join_all(
-                    self.participants.iter_mut().enumerate().map(
-                        |(i, stream)| async move {
-                            let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
-                            let mut buffer: &mut [u8] =
-                                bytemuck::cast_slice_mut(batch.as_mut_slice());
-
-                            // We can not use read_exact here as we might get EOF before the
-                            // buffer is full But we should
-                            // still try to fill the entire buffer.
-                            // If nothing else, this guarantees that we read batches at a
-                            // [u16;31] boundary.
-                            while !buffer.is_empty() {
-                                let bytes_read =
-                                    stream.read_buf(&mut buffer).await?;
-                                if bytes_read == 0 {
-                                    let n_incomplete = (buffer.len()
-                                        + std::mem::size_of::<[u16; 31]>() //TODO: make this a const
-                                        - 1)
-                                        / std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
-                                    batch.truncate(batch.len() - n_incomplete);
-                                    break;
-                                }
-                            }
-
-                            Ok::<_, eyre::Report>(batch)
-                        },
-                    ),
-                );
-
-                // Wait on all parts concurrently
-                let (denom, shares) =
-                    tokio::join!(denominator_rx.recv(), streams_future);
-
-                let mut denom = denom.unwrap_or_default();
-                let mut shares = shares?;
-
-                // Find the shortest prefix
-                let batch_size = shares
-                    .iter()
-                    .map(Vec::len)
-                    .fold(denom.len(), core::cmp::min);
-
-                denom.truncate(batch_size);
-                shares
-                    .iter_mut()
-                    .for_each(|batch| batch.truncate(batch_size));
-
-                // Send batches
-                processed_shares_tx.send((denom, shares)).await?;
-                if batch_size == 0 {
-                    break;
-                }
-            }
-            Ok(())
-        });
+        // let batch_worker = tokio::task::spawn(async move {
+        //     loop {
+        //         // Collect futures of denominator and share batches
+        //         let streams_future = future::try_join_all(
+        //             self.participants.iter_mut().enumerate().map(
+        //                 |(i, stream)| async move {
+        //                     let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
+        //                     let mut buffer: &mut [u8] =
+        //                         bytemuck::cast_slice_mut(batch.as_mut_slice());
+
+        //                     // We can not use read_exact here as we might get EOF before the
+        //                     // buffer is full But we should
+        //                     // still try to fill the entire buffer.
+        //                     // If nothing else, this guarantees that we read batches at a
+        //                     // [u16;31] boundary.
+        //                     while !buffer.is_empty() {
+        //                         let bytes_read =
+        //                             stream.read_buf(&mut buffer).await?;
+        //                         if bytes_read == 0 {
+        //                             let n_incomplete = (buffer.len()
+        //                                 + std::mem::size_of::<[u16; 31]>() //TODO: make this a const
+        //                                 - 1)
+        //                                 / std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
+        //                             batch.truncate(batch.len() - n_incomplete);
+        //                             break;
+        //                         }
+        //                     }
+
+        //                     Ok::<_, eyre::Report>(batch)
+        //                 },
+        //             ),
+        //         );
+
+        //         // Wait on all parts concurrently
+        //         let (denom, shares) =
+        //             tokio::join!(denominator_rx.recv(), streams_future);
+
+        //         let mut denom = denom.unwrap_or_default();
+        //         let mut shares = shares?;
+
+        //         // Find the shortest prefix
+        //         let batch_size = shares
+        //             .iter()
+        //             .map(Vec::len)
+        //             .fold(denom.len(), core::cmp::min);
+
+        //         denom.truncate(batch_size);
+        //         shares
+        //             .iter_mut()
+        //             .for_each(|batch| batch.truncate(batch_size));
+
+        //         // Send batches
+        //         processed_shares_tx.send((denom, shares)).await?;
+        //         if batch_size == 0 {
+        //             break;
+        //         }
+        //     }
+        //     Ok(())
+        // });
+
+        // (processed_shares_rx, batch_worker)
 
-        (processed_shares_rx, batch_worker)
+        todo!()
     }
 
     pub async fn process_results(

From 46608b56f94d29070c09560d9caed7928e21505d Mon Sep 17 00:00:00 2001
From: Dzejkop
Date: Wed, 31 Jan 2024 14:51:30 +0100
Subject: [PATCH 2/3] Coordinator DB

---
 migrations/coordinator/001_init.sql | 4 ++++
 migrations/participant/001_init.sql | 4 ++++
 2 files changed, 8 insertions(+)
 create mode 100644 migrations/coordinator/001_init.sql
 create mode 100644 migrations/participant/001_init.sql

diff --git a/migrations/coordinator/001_init.sql b/migrations/coordinator/001_init.sql
new file mode 100644
index 0000000..d1d74d1
--- /dev/null
+++ b/migrations/coordinator/001_init.sql
@@ -0,0 +1,4 @@
+CREATE TABLE masks (
+    id   BIGINT PRIMARY KEY,
+    mask BYTEA NOT NULL
+);

diff --git a/migrations/participant/001_init.sql b/migrations/participant/001_init.sql
new file mode 100644
index 0000000..ecd7af4
--- /dev/null
+++ b/migrations/participant/001_init.sql
@@ -0,0 +1,4 @@
+CREATE TABLE shares (
+    id    BIGINT PRIMARY KEY,
+    share BYTEA NOT NULL
+);

From ed6bb43d60143d983b6f129be5c4c409b9abe4c2 Mon Sep 17 00:00:00 2001
From: Dzejkop
Date: Wed, 31 Jan 2024 14:51:40 +0100
Subject: [PATCH 3/3] Implementation

---
 Cargo.lock            |  45 +++++++++++++++
 Cargo.toml            |   1 +
 src/db.rs             |   3 +
 src/db/coordinator.rs | 127 ++++++++++++++++++++++++++++++++++++++++++
 src/db/impls.rs       |  44 +++++++++++++++
 src/db/participant.rs |   0
 src/lib.rs            |   1 +
 7 files changed, 221 insertions(+)
 create mode 100644 src/db.rs
 create mode 100644 src/db/coordinator.rs
 create mode 100644 src/db/impls.rs
 create mode 100644 src/db/participant.rs

diff --git a/Cargo.lock b/Cargo.lock
index e0e7663..b969455 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -942,6 +942,16 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0688c2a7f92e427f44895cd63841bff7b29f8d7a1648b9e7e07a4a365b2e1257"
 
+[[package]]
+name = "docker-db"
+version = "0.1.0"
+source = "git+https://github.com/Dzejkop/docker-db?rev=ef9a4dfccd9cb9b4babeebfea0ba815e36afbafe#ef9a4dfccd9cb9b4babeebfea0ba815e36afbafe"
+dependencies = [
+ "test-case",
+ "thiserror",
+ "tokio",
+]
+
 [[package]]
 name = "dotenv"
 version = "0.15.0"
@@ -1695,6 +1705,7 @@ dependencies = [
  "clap",
  "config",
  "criterion",
+ "docker-db",
  "dotenv",
  "eyre",
  "float_eq",
@@ -3079,6 +3090,39 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "test-case"
+version = "3.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2550dd13afcd286853192af8601920d959b14c401fcece38071d53bf0768a8" +dependencies = [ + "test-case-macros", +] + +[[package]] +name = "test-case-core" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcb7fd841cd518e279be3d5a3eb0636409487998a4aff22f3de87b81e88384f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "test-case-macros" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", + "test-case-core", +] + [[package]] name = "thiserror" version = "1.0.56" @@ -3174,6 +3218,7 @@ dependencies = [ "libc", "mio", "num_cpus", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", diff --git a/Cargo.toml b/Cargo.toml index 7ab0796..46ac0cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ sqlx = { version = "0.7.3", features = [ "postgres", "chrono", ] } +docker-db = { git = "https://github.com/Dzejkop/docker-db", rev = "ef9a4dfccd9cb9b4babeebfea0ba815e36afbafe" } telemetry-batteries = { git = "https://github.com/worldcoin/telemetry-batteries.git", rev = "c6816624415ae194da5203a5161621a9e10ad3b0" } tokio = { version = "1.35.1", features = ["macros"] } tracing = "0.1.40" diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..3ba0474 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,3 @@ +pub mod coordinator; +pub mod participant; +pub mod impls; diff --git a/src/db/coordinator.rs b/src/db/coordinator.rs new file mode 100644 index 0000000..d97bd1f --- /dev/null +++ b/src/db/coordinator.rs @@ -0,0 +1,127 @@ +use sqlx::migrate::Migrator; +use sqlx::{Postgres, QueryBuilder}; + +use crate::bits::Bits; + +static MIGRATOR: Migrator = sqlx::migrate!("migrations/coordinator/"); + +pub struct CoordinatorDb { + pool: sqlx::Pool, +} + +impl CoordinatorDb { + pub async fn new(url: &str) -> eyre::Result { + let pool = sqlx::Pool::connect(url).await?; + + MIGRATOR.run(&pool).await?; + + Ok(Self { pool }) + } + + pub async fn fetch_masks(&self, id: usize) -> eyre::Result> { + let masks: Vec<(Bits,)> = sqlx::query_as( + r#" + SELECT mask + FROM masks + WHERE id >= $1 + ORDER BY id ASC + "#, + ) + .bind(id as i64) + .fetch_all(&self.pool) + .await?; + + Ok(masks.into_iter().map(|(mask,)| mask).collect()) + } + + pub async fn insert_masks( + &self, + masks: &[(u64, Bits)], + ) -> eyre::Result<()> { + let mut builder = + QueryBuilder::new("INSERT INTO masks (id, mask) VALUES "); + + for (idx, (id, mask)) in masks.iter().enumerate() { + if idx > 0 { + builder.push(", "); + } + builder.push("("); + builder.push_bind(*id as i64); + builder.push(", "); + builder.push_bind(mask); + builder.push(")"); + } + + builder.push(" ON CONFLICT (id) DO NOTHING"); + + let query = builder.build(); + + query.execute(&self.pool).await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, Rng}; + + use super::*; + + #[tokio::test] + async fn fetch_on_empty() -> eyre::Result<()> { + let db = docker_db::Postgres::spawn().await?; + let url = format!("postgres://postgres:postgres@{}", db.socket_addr()); + + let db = CoordinatorDb::new(&url).await?; + + let masks = db.fetch_masks(0).await?; + + assert!(masks.is_empty()); + + Ok(()) + } + + #[tokio::test] + async fn insert_and_fetch() -> eyre::Result<()> { + let db = docker_db::Postgres::spawn().await?; + 
+
+        let db = CoordinatorDb::new(&url).await?;
+
+        let mut rng = thread_rng();
+
+        let masks = vec![(0, rng.gen::<Bits>()), (1, rng.gen::<Bits>())];
+
+        db.insert_masks(&masks).await?;
+
+        let fetched_masks = db.fetch_masks(0).await?;
+        let masks_without_ids =
+            masks.iter().map(|(_, mask)| *mask).collect::<Vec<_>>();
+
+        assert_eq!(fetched_masks.len(), 2);
+        assert_eq!(fetched_masks, masks_without_ids);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn partial_fetch() -> eyre::Result<()> {
+        let db = docker_db::Postgres::spawn().await?;
+        let url = format!("postgres://postgres:postgres@{}", db.socket_addr());
+
+        let db = CoordinatorDb::new(&url).await?;
+
+        let mut rng = thread_rng();
+
+        let masks = vec![(0, rng.gen::<Bits>()), (1, rng.gen::<Bits>())];
+
+        db.insert_masks(&masks).await?;
+
+        let fetched_masks = db.fetch_masks(1).await?;
+
+        assert_eq!(fetched_masks[0], masks[1].1);
+
+        Ok(())
+    }
+}

diff --git a/src/db/impls.rs b/src/db/impls.rs
new file mode 100644
index 0000000..338ccea
--- /dev/null
+++ b/src/db/impls.rs
@@ -0,0 +1,44 @@
+use crate::bits::{Bits, BITS};
+
+const BYTES: usize = BITS / 8;
+
+impl<DB> sqlx::Type<DB> for Bits
+where
+    DB: sqlx::Database,
+    [u8; BYTES]: sqlx::Type<DB>,
+{
+    fn type_info() -> DB::TypeInfo {
+        <[u8; BYTES] as sqlx::Type<DB>>::type_info()
+    }
+}
+
+impl<'r, DB> sqlx::Decode<'r, DB> for Bits
+where
+    DB: sqlx::Database,
+    [u8; BYTES]: sqlx::Decode<'r, DB>,
+{
+    fn decode(
+        value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef,
+    ) -> Result<Self, sqlx::error::BoxDynError> {
+        let bytes = <[u8; BYTES] as sqlx::Decode<'r, DB>>::decode(value)?;
+
+        Ok(bytemuck::pod_read_unaligned(&bytes))
+    }
+}
+
+impl<'q, DB> sqlx::Encode<'q, DB> for Bits
+where
+    DB: sqlx::Database,
+    [u8; BYTES]: sqlx::Encode<'q, DB>,
+{
+    fn encode_by_ref(
+        &self,
+        buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer,
+    ) -> sqlx::encode::IsNull {
+        // Copy into a fixed-size byte array so we can delegate to the
+        // [u8; BYTES] impl
+        let bytes = bytemuck::bytes_of(self);
+        let bytes: [u8; BYTES] = bytes.try_into().expect("Wrong size");
+
+        <[u8; BYTES] as sqlx::Encode<'q, DB>>::encode(bytes, buf)
+    }
+}

diff --git a/src/db/participant.rs b/src/db/participant.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/lib.rs b/src/lib.rs
index 979567e..58bfd06 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,3 +5,4 @@ pub mod coordinator;
 pub mod distance;
 pub mod encoded_bits;
 pub mod template;
+pub mod db;
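
src/db/participant.rs is created empty in PATCH 3/3. Below is a minimal sketch
of how it could be filled in, mirroring CoordinatorDb against the shares table
from migrations/participant/001_init.sql. The ParticipantDb name and the raw
Vec<u8> representation of a share are illustrative assumptions, not part of
these patches; a dedicated share type with sqlx impls, like the ones Bits gets
in src/db/impls.rs, could replace Vec<u8> later.

    use sqlx::migrate::Migrator;
    use sqlx::{Postgres, QueryBuilder};

    static MIGRATOR: Migrator = sqlx::migrate!("migrations/participant/");

    pub struct ParticipantDb {
        pool: sqlx::Pool<Postgres>,
    }

    impl ParticipantDb {
        pub async fn new(url: &str) -> eyre::Result<Self> {
            let pool = sqlx::Pool::connect(url).await?;

            // Run this module's own migrations, exactly as CoordinatorDb does
            MIGRATOR.run(&pool).await?;

            Ok(Self { pool })
        }

        // Fetch all shares with id >= the given id, in ascending id order
        pub async fn fetch_shares(&self, id: usize) -> eyre::Result<Vec<Vec<u8>>> {
            let shares: Vec<(Vec<u8>,)> = sqlx::query_as(
                r#"
                SELECT share
                FROM shares
                WHERE id >= $1
                ORDER BY id ASC
            "#,
            )
            .bind(id as i64)
            .fetch_all(&self.pool)
            .await?;

            Ok(shares.into_iter().map(|(share,)| share).collect())
        }

        // Multi-row insert via QueryBuilder, same pattern as insert_masks
        pub async fn insert_shares(
            &self,
            shares: &[(u64, Vec<u8>)],
        ) -> eyre::Result<()> {
            let mut builder =
                QueryBuilder::new("INSERT INTO shares (id, share) VALUES ");

            for (idx, (id, share)) in shares.iter().enumerate() {
                if idx > 0 {
                    builder.push(", ");
                }
                builder.push("(");
                builder.push_bind(*id as i64);
                builder.push(", ");
                builder.push_bind(share.as_slice());
                builder.push(")");
            }

            builder.push(" ON CONFLICT (id) DO NOTHING");

            builder.build().execute(&self.pool).await?;

            Ok(())
        }
    }

Because the pool setup and migrations mirror CoordinatorDb, the same docker-db
test pattern from src/db/coordinator.rs (spawn a Postgres container, connect,
insert, fetch) applies unchanged.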