Skip to content

Commit

Permalink
Merge pull request #5 from worldcoin/dzejkop/db-sync
Browse files Browse the repository at this point in the history
Dzejkop/db-sync
  • Loading branch information
Dzejkop authored Jan 31, 2024
2 parents 0002790 + ed6bb43 commit 333312f
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 62 deletions.
45 changes: 45 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ sqlx = { version = "0.7.3", features = [
"postgres",
"chrono",
] }
docker-db = { git = "https://github.com/Dzejkop/docker-db", rev = "ef9a4dfccd9cb9b4babeebfea0ba815e36afbafe" }
telemetry-batteries = { git = "https://github.com/worldcoin/telemetry-batteries.git", rev = "c6816624415ae194da5203a5161621a9e10ad3b0" }
tokio = { version = "1.35.1", features = ["macros"] }
tracing = "0.1.40"
Expand Down
4 changes: 4 additions & 0 deletions migrations/coordinator/001_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Coordinator-side storage for masks.
-- id: application-assigned position (BIGINT, not auto-generated) — inserts
--     supply explicit ids and conflicts are ignored by the writer.
-- mask: raw mask bytes; encoding is defined by the application's Bits type
--       (see src/db/coordinator.rs) — TODO confirm exact byte layout there.
CREATE TABLE masks (
    id BIGINT PRIMARY KEY,
    mask BYTEA NOT NULL
);
4 changes: 4 additions & 0 deletions migrations/participant/001_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Participant-side storage for secret shares.
-- id: application-assigned position (BIGINT, not auto-generated).
-- share: raw share bytes; encoding is defined by the application —
--        TODO confirm against src/db/participant.rs.
CREATE TABLE shares (
    id BIGINT PRIMARY KEY,
    share BYTEA NOT NULL
);
126 changes: 64 additions & 62 deletions src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl Coordinator {
JoinHandle<eyre::Result<()>>,
) {
// Collect batches of shares
let (processed_shares_tx, mut processed_shares_rx) = mpsc::channel(4);
// let (processed_shares_tx, mut processed_shares_rx) = mpsc::channel(4);

let streams_future =
future::try_join_all(self.participants.iter_mut().enumerate().map(
Expand All @@ -158,7 +158,7 @@ impl Coordinator {
let bytes_read = stream.read_buf(&mut buffer).await?;
if bytes_read == 0 {
let n_incomplete = (buffer.len()
+ std::mem::size_of::<[u16; 31]>() //TODO: make this a const
+ std::mem::size_of::<[u16; 31]>() //TODO: make this a const
- 1)
/ std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
batch.truncate(batch.len() - n_incomplete);
Expand All @@ -170,67 +170,69 @@ impl Coordinator {
},
));

let batch_worker = tokio::task::spawn(async move {
loop {
// Collect futures of denominator and share batches
let streams_future = future::try_join_all(
self.participants.iter_mut().enumerate().map(
|(i, stream)| async move {
let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
let mut buffer: &mut [u8] =
bytemuck::cast_slice_mut(batch.as_mut_slice());

// We can not use read_exact here as we might get EOF before the
// buffer is full But we should
// still try to fill the entire buffer.
// If nothing else, this guarantees that we read batches at a
// [u16;31] boundary.
while !buffer.is_empty() {
let bytes_read =
stream.read_buf(&mut buffer).await?;
if bytes_read == 0 {
let n_incomplete = (buffer.len()
+ std::mem::size_of::<[u16; 31]>() //TODO: make this a const
- 1)
/ std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
batch.truncate(batch.len() - n_incomplete);
break;
}
}

Ok::<_, eyre::Report>(batch)
},
),
);

// Wait on all parts concurrently
let (denom, shares) =
tokio::join!(denominator_rx.recv(), streams_future);

let mut denom = denom.unwrap_or_default();
let mut shares = shares?;

// Find the shortest prefix
let batch_size = shares
.iter()
.map(Vec::len)
.fold(denom.len(), core::cmp::min);

denom.truncate(batch_size);
shares
.iter_mut()
.for_each(|batch| batch.truncate(batch_size));

// Send batches
processed_shares_tx.send((denom, shares)).await?;
if batch_size == 0 {
break;
}
}
Ok(())
});
// let batch_worker = tokio::task::spawn(async move {
// loop {
// // Collect futures of denominator and share batches
// let streams_future = future::try_join_all(
// self.participants.iter_mut().enumerate().map(
// |(i, stream)| async move {
// let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
// let mut buffer: &mut [u8] =
// bytemuck::cast_slice_mut(batch.as_mut_slice());

// // We can not use read_exact here as we might get EOF before the
// // buffer is full But we should
// // still try to fill the entire buffer.
// // If nothing else, this guarantees that we read batches at a
// // [u16;31] boundary.
// while !buffer.is_empty() {
// let bytes_read =
// stream.read_buf(&mut buffer).await?;
// if bytes_read == 0 {
// let n_incomplete = (buffer.len()
// + std::mem::size_of::<[u16; 31]>() //TODO: make this a const
// - 1)
// / std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
// batch.truncate(batch.len() - n_incomplete);
// break;
// }
// }

// Ok::<_, eyre::Report>(batch)
// },
// ),
// );

// // Wait on all parts concurrently
// let (denom, shares) =
// tokio::join!(denominator_rx.recv(), streams_future);

// let mut denom = denom.unwrap_or_default();
// let mut shares = shares?;

// // Find the shortest prefix
// let batch_size = shares
// .iter()
// .map(Vec::len)
// .fold(denom.len(), core::cmp::min);

// denom.truncate(batch_size);
// shares
// .iter_mut()
// .for_each(|batch| batch.truncate(batch_size));

// // Send batches
// processed_shares_tx.send((denom, shares)).await?;
// if batch_size == 0 {
// break;
// }
// }
// Ok(())
// });

// (processed_shares_rx, batch_worker)

(processed_shares_rx, batch_worker)
todo!()
}

pub async fn process_results(
Expand Down
3 changes: 3 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
//! Database access layer, split by node role.

pub mod coordinator; // coordinator-side storage (masks table)
pub mod participant; // participant-side storage (shares table)
pub mod impls; // shared impls — contents not shown here; see src/db/impls.rs
127 changes: 127 additions & 0 deletions src/db/coordinator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use sqlx::migrate::Migrator;
use sqlx::{Postgres, QueryBuilder};

use crate::bits::Bits;

/// Coordinator schema migrations, embedded at compile time from
/// `migrations/coordinator/` by the `sqlx::migrate!` macro and applied
/// in [`CoordinatorDb::new`].
static MIGRATOR: Migrator = sqlx::migrate!("migrations/coordinator/");

/// Handle to the coordinator's Postgres database (the `masks` table).
pub struct CoordinatorDb {
    // Async connection pool; all queries in the impl go through it.
    pool: sqlx::Pool<Postgres>,
}

impl CoordinatorDb {
    /// Connects to the Postgres database at `url` and applies any pending
    /// coordinator migrations before returning the handle.
    ///
    /// # Errors
    /// Returns an error if the connection cannot be established or a
    /// migration fails.
    pub async fn new(url: &str) -> eyre::Result<Self> {
        let pool = sqlx::Pool::connect(url).await?;

        MIGRATOR.run(&pool).await?;

        Ok(Self { pool })
    }

    /// Fetches all masks whose id is `>= id`, ordered by id ascending.
    ///
    /// Only the mask payloads are returned; callers relying on positions
    /// should note that rows are contiguous only if ids were inserted
    /// without gaps.
    pub async fn fetch_masks(&self, id: usize) -> eyre::Result<Vec<Bits>> {
        let masks: Vec<(Bits,)> = sqlx::query_as(
            r#"
            SELECT mask
            FROM masks
            WHERE id >= $1
            ORDER BY id ASC
        "#,
        )
        .bind(id as i64)
        .fetch_all(&self.pool)
        .await?;

        Ok(masks.into_iter().map(|(mask,)| mask).collect())
    }

    /// Bulk-inserts `(id, mask)` pairs into the `masks` table. Rows whose
    /// id already exists are left untouched (`ON CONFLICT DO NOTHING`).
    ///
    /// An empty slice is a no-op. (Previously it produced the malformed
    /// statement `INSERT INTO masks (id, mask) VALUES  ON CONFLICT ...`,
    /// which fails at runtime with a SQL syntax error.)
    pub async fn insert_masks(
        &self,
        masks: &[(u64, Bits)],
    ) -> eyre::Result<()> {
        // Guard: an empty VALUES list is invalid SQL.
        if masks.is_empty() {
            return Ok(());
        }

        let mut builder =
            QueryBuilder::new("INSERT INTO masks (id, mask) VALUES ");

        // push_values emits the per-row parentheses and separators for us,
        // replacing the manual `if idx > 0 { push(", ") }` loop.
        builder.push_values(masks.iter(), |mut row, (id, mask)| {
            // ids are stored as BIGINT; the u64 -> i64 cast mirrors the
            // usize -> i64 cast in fetch_masks (wraps above i64::MAX).
            row.push_bind(*id as i64).push_bind(mask);
        });

        builder.push(" ON CONFLICT (id) DO NOTHING");

        builder.build().execute(&self.pool).await?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use rand::{thread_rng, Rng};

    use super::*;

    /// Spins up a throwaway Postgres container and connects a
    /// `CoordinatorDb` to it. The container handle is returned alongside
    /// the db so it stays alive for the duration of the test.
    async fn setup() -> eyre::Result<(docker_db::Postgres, CoordinatorDb)> {
        let pg = docker_db::Postgres::spawn().await?;
        let url = format!("postgres://postgres:postgres@{}", pg.socket_addr());
        let db = CoordinatorDb::new(&url).await?;
        Ok((pg, db))
    }

    #[tokio::test]
    async fn fetch_on_empty() -> eyre::Result<()> {
        let (_pg, db) = setup().await?;

        // A fresh database holds no masks.
        assert!(db.fetch_masks(0).await?.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn insert_and_fetch() -> eyre::Result<()> {
        let (_pg, db) = setup().await?;

        let mut rng = thread_rng();
        let masks = vec![(0, rng.gen::<Bits>()), (1, rng.gen::<Bits>())];

        db.insert_masks(&masks).await?;

        let fetched = db.fetch_masks(0).await?;
        let expected: Vec<_> = masks.iter().map(|(_, mask)| *mask).collect();

        assert_eq!(fetched.len(), 2);
        assert_eq!(fetched, expected);

        Ok(())
    }

    #[tokio::test]
    async fn partial_fetch() -> eyre::Result<()> {
        let (_pg, db) = setup().await?;

        let mut rng = thread_rng();
        let masks = vec![(0, rng.gen::<Bits>()), (1, rng.gen::<Bits>())];

        db.insert_masks(&masks).await?;

        // Fetching from id 1 skips the first row.
        let fetched = db.fetch_masks(1).await?;
        assert_eq!(fetched[0], masks[1].1);

        Ok(())
    }
}
Loading

0 comments on commit 333312f

Please sign in to comment.